Compare commits
6 Commits
c32ca84e8a
...
e398cf0540
Author | SHA1 | Date |
---|---|---|
Elizabeth | e398cf0540 | |
Elizabeth | b74a955dcb | |
Elizabeth | 8c7d9b3762 | |
Elizabeth | 47cc8feefa | |
Elizabeth | da6b6011fc | |
Elizabeth | cc33a90bfd |
|
@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
|
||||||
Result database.DNSRecord `json:"result"`
|
Result database.DNSRecord `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) {
|
type CloudflareExternalDNSAdapter struct {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
|
ZoneId string
|
||||||
|
APIToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||||
payload := strings.NewReader(reqBody)
|
payload := strings.NewReader(reqBody)
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST", url, payload)
|
req, _ := http.NewRequest("POST", url, payload)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
|
||||||
return result.ID, nil
|
return result.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDNSRecord(zoneId string, apiToken string, id string) error {
|
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id)
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
||||||
|
|
||||||
req, _ := http.NewRequest("DELETE", url, nil)
|
req, _ := http.NewRequest("DELETE", url, nil)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
package external_dns
|
||||||
|
|
||||||
|
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
|
||||||
|
type ExternalDNSAdapter interface {
|
||||||
|
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
||||||
|
DeleteDNSRecord(id string) error
|
||||||
|
}
|
|
@ -50,7 +50,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func InterceptOauthCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
state := req.URL.Query().Get("state")
|
state := req.URL.Query().Get("state")
|
||||||
code := req.URL.Query().Get("code")
|
code := req.URL.Query().Get("code")
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *api.RequestContext, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &api.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return testDb, context, func() {
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
todo: test api key creation
|
||||||
|
+ api key attached to user
|
||||||
|
+ user session is unique
|
||||||
|
+ goLogin goes to page in cookie
|
||||||
|
*/
|
214
api/dns.go
214
api/dns.go
|
@ -8,30 +8,25 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_USER_RECORDS = 65
|
const MAX_USER_RECORDS = 65
|
||||||
|
|
||||||
type FormError struct {
|
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||||
Errors []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
|
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
||||||
ownedByUser := (user.ID == record.UserID)
|
ownedByUser := (user.ID == record.UserID)
|
||||||
if !ownedByUser {
|
if !ownedByUser {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !record.Internal {
|
if !record.Internal {
|
||||||
userOwnedDomains := []string{
|
for _, format := range ownedInternalDomainFormats {
|
||||||
fmt.Sprintf("%s", user.Username),
|
domain := fmt.Sprintf(format, user.Username)
|
||||||
fmt.Sprintf("%s.endpoints", user.Username),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range userOwnedDomains {
|
|
||||||
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
||||||
if domain == record.Name || isInSubDomain {
|
if domain == record.Name || isInSubDomain {
|
||||||
return true
|
return true
|
||||||
|
@ -64,116 +59,119 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
formErrors := FormError{
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
Errors: []string{},
|
formErrors := FormError{
|
||||||
}
|
Errors: []string{},
|
||||||
|
|
||||||
internal := req.FormValue("internal") == "on"
|
|
||||||
name := req.FormValue("name")
|
|
||||||
if internal && !strings.HasSuffix(name, ".") {
|
|
||||||
name += "."
|
|
||||||
}
|
|
||||||
|
|
||||||
recordType := req.FormValue("type")
|
|
||||||
recordType = strings.ToUpper(recordType)
|
|
||||||
|
|
||||||
recordContent := req.FormValue("content")
|
|
||||||
ttl := req.FormValue("ttl")
|
|
||||||
ttlNum, err := strconv.Atoi(ttl)
|
|
||||||
if err != nil {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
if dnsRecordCount >= MAX_USER_RECORDS {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecord := &database.DNSRecord{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Name: name,
|
|
||||||
Type: recordType,
|
|
||||||
Content: recordContent,
|
|
||||||
TTL: ttlNum,
|
|
||||||
Internal: internal,
|
|
||||||
}
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
if dnsRecord.Internal {
|
|
||||||
dnsRecord.ID = utils.RandomId()
|
|
||||||
} else {
|
|
||||||
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecord.ID = cloudflareRecordId
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
internal := req.FormValue("internal") == "on"
|
||||||
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
name := req.FormValue("name")
|
||||||
|
if internal && !strings.HasSuffix(name, ".") {
|
||||||
|
name += "."
|
||||||
|
}
|
||||||
|
|
||||||
|
recordType := req.FormValue("type")
|
||||||
|
recordType = strings.ToUpper(recordType)
|
||||||
|
|
||||||
|
recordContent := req.FormValue("content")
|
||||||
|
ttl := req.FormValue("ttl")
|
||||||
|
ttlNum, err := strconv.Atoi(ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||||
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["FormError"] = &formErrors
|
|
||||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
|
||||||
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
recordId := req.FormValue("id")
|
|
||||||
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !record.Internal {
|
|
||||||
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
if dnsRecordCount >= MAX_USER_RECORDS {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||||
|
}
|
||||||
|
|
||||||
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
dnsRecord := &database.DNSRecord{
|
||||||
if err != nil {
|
UserID: context.User.ID,
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
Name: name,
|
||||||
|
Type: recordType,
|
||||||
|
Content: recordContent,
|
||||||
|
TTL: ttlNum,
|
||||||
|
Internal: internal,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
if dnsRecord.Internal {
|
||||||
|
dnsRecord.ID = utils.RandomId()
|
||||||
|
} else {
|
||||||
|
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["FormError"] = &formErrors
|
||||||
|
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||||
|
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
}
|
||||||
return success(context, req, resp)
|
|
||||||
|
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
recordId := req.FormValue("id")
|
||||||
|
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.Internal {
|
||||||
|
err = dnsAdapter.DeleteDNSRecord(recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *api.RequestContext, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &api.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return testDb, context, func() {
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
testUser := &database.User{
|
||||||
|
ID: "test",
|
||||||
|
Username: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(records) > 0 {
|
||||||
|
t.Errorf("expected no records, got records")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
api.PutDNSRecordContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
}
|
|
@ -1,8 +1,6 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -43,16 +41,11 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
name := req.FormValue("name")
|
name := req.FormValue("name")
|
||||||
message := req.FormValue("message")
|
message := req.FormValue("message")
|
||||||
hCaptchaResponse := req.FormValue("h-captcha-response")
|
|
||||||
|
|
||||||
formErrors := FormError{
|
formErrors := FormError{
|
||||||
Errors: []string{},
|
Errors: []string{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if hCaptchaResponse == "" {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &database.GuestbookEntry{
|
entry := &database.GuestbookEntry{
|
||||||
ID: utils.RandomId(),
|
ID: utils.RandomId(),
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -60,22 +53,19 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp
|
||||||
}
|
}
|
||||||
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
||||||
|
|
||||||
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
|
if len(formErrors.Errors) == 0 {
|
||||||
if err != nil {
|
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
|
||||||
log.Println(err)
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
|
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(formErrors.Errors) > 0 {
|
if len(formErrors.Errors) > 0 {
|
||||||
(*context.TemplateData)["FormError"] = formErrors
|
(*context.TemplateData)["FormError"] = formErrors
|
||||||
(*context.TemplateData)["EntryForm"] = entry
|
(*context.TemplateData)["EntryForm"] = entry
|
||||||
return failure(context, req, resp)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
}
|
|
||||||
|
|
||||||
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,46 +86,3 @@ func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
|
||||||
SiteKey: context.Args.HcaptchaSiteKey,
|
|
||||||
}
|
|
||||||
log.Println(context.Args.HcaptchaSiteKey)
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyHCaptcha(secret, response string) error {
|
|
||||||
verifyURL := "https://hcaptcha.com/siteverify"
|
|
||||||
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", verifyURL, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse := struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
}{}
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !jsonResponse.Success {
|
|
||||||
return fmt.Errorf("hcaptcha verification failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *api.RequestContext, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &api.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return testDb, context, func() {
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidGuestbookPutsInDatabase(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
entries, err := database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(entries) > 0 {
|
||||||
|
t.Errorf("expected no entries, got entries")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", ts.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"name": {"test"},
|
||||||
|
"message": {"test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ts.Config.Handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status code 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err = database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Errorf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
if entries[0].Name != req.FormValue("name") {
|
||||||
|
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
entries, err := database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(entries) > 0 {
|
||||||
|
t.Errorf("expected no entries, got entries")
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
|
||||||
|
invalidRequests := []struct {
|
||||||
|
name string
|
||||||
|
message string
|
||||||
|
}{
|
||||||
|
{"", "test"},
|
||||||
|
{"test", ""},
|
||||||
|
{"", ""},
|
||||||
|
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, form := range invalidRequests {
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"name": {form.name},
|
||||||
|
"message": {form.message},
|
||||||
|
}
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||||
|
|
||||||
|
if responseRecorder.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err = database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("expected 0 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyCaptcha(secret, response string) error {
|
||||||
|
verifyURL := "https://hcaptcha.com/siteverify"
|
||||||
|
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", verifyURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse := struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}{}
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jsonResponse.Success {
|
||||||
|
return fmt.Errorf("hcaptcha verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
||||||
|
SiteKey: context.Args.HcaptchaSiteKey,
|
||||||
|
}
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CaptchaVerificationContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
hCaptchaResponse := req.FormValue("h-captcha-response")
|
||||||
|
secretKey := context.Args.HcaptchaSecret
|
||||||
|
|
||||||
|
err := verifyCaptcha(secretKey, hCaptchaResponse)
|
||||||
|
if err != nil {
|
||||||
|
(*context.TemplateData)["FormError"] = FormError{
|
||||||
|
Errors: []string{"hCaptcha verification failed"},
|
||||||
|
}
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
32
api/serve.go
32
api/serve.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
@ -23,6 +24,10 @@ type RequestContext struct {
|
||||||
User *database.User
|
User *database.User
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FormError struct {
|
||||||
|
Errors []string
|
||||||
|
}
|
||||||
|
|
||||||
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
||||||
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
||||||
|
|
||||||
|
@ -80,11 +85,15 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
||||||
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
||||||
|
|
||||||
|
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
|
||||||
|
APIToken: argv.CloudflareToken,
|
||||||
|
ZoneId: argv.CloudflareZone,
|
||||||
|
}
|
||||||
|
|
||||||
makeRequestContext := func() *RequestContext {
|
makeRequestContext := func() *RequestContext {
|
||||||
return &RequestContext{
|
return &RequestContext{
|
||||||
DBConn: dbConn,
|
DBConn: dbConn,
|
||||||
Args: argv,
|
Args: argv,
|
||||||
|
|
||||||
TemplateData: &map[string]interface{}{},
|
TemplateData: &map[string]interface{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,7 +103,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
@ -106,12 +115,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
|
|
||||||
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
requestContext := makeRequestContext()
|
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -126,12 +130,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
|
|
||||||
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -151,12 +155,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
|
|
||||||
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaVerificationContinuation, CaptchaVerificationContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -66,7 +66,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.WriteHeader(200)
|
|
||||||
resp.Header().Set("Content-Type", "text/html")
|
resp.Header().Set("Content-Type", "text/html")
|
||||||
resp.Write(html.Bytes())
|
resp.Write(html.Bytes())
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package hcdns
|
package hcdns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -16,7 +16,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func randomPort() int {
|
func randomPort() int {
|
||||||
return rand.Intn(3000) + 1024
|
return rand.Intn(3000) + 5192
|
||||||
}
|
}
|
||||||
|
|
||||||
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
|
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
|
||||||
|
@ -60,73 +60,74 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
cname := &database.DNSRecord{
|
records := []*database.DNSRecord{
|
||||||
ID: "1",
|
{
|
||||||
UserID: "test",
|
ID: "0",
|
||||||
Name: "cname.internal.example.com.",
|
UserID: "test",
|
||||||
Type: "CNAME",
|
Name: "cname.internal.example.com.",
|
||||||
Content: "res.example.com.",
|
Type: "CNAME",
|
||||||
TTL: 300,
|
Content: "next.internal.example.com.",
|
||||||
Internal: true,
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
}, {
|
||||||
|
ID: "1",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "next.internal.example.com.",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "res.example.com.",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "2",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "res.example.com.",
|
||||||
|
Type: "A",
|
||||||
|
Content: "1.2.3.2",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
a := &database.DNSRecord{
|
|
||||||
ID: "2",
|
for _, record := range records {
|
||||||
UserID: "test",
|
database.SaveDNSRecord(testDb, record)
|
||||||
Name: "res.example.com.",
|
|
||||||
Type: "A",
|
|
||||||
Content: "127.0.0.1",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
}
|
}
|
||||||
database.SaveDNSRecord(testDb, cname)
|
|
||||||
database.SaveDNSRecord(testDb, a)
|
|
||||||
|
|
||||||
qtype := dns.TypeA
|
qtype := dns.TypeA
|
||||||
domain := dns.Fqdn(cname.Name)
|
domain := dns.Fqdn("cname.internal.example.com.")
|
||||||
client := &dns.Client{}
|
client := &dns.Client{}
|
||||||
message := &dns.Msg{}
|
message := &dns.Msg{}
|
||||||
message.SetQuestion(domain, qtype)
|
message.SetQuestion(domain, qtype)
|
||||||
|
|
||||||
in, _, err := client.Exchange(message, *addr)
|
in, _, err := client.Exchange(message, *addr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(in.Answer) != 2 {
|
if len(in.Answer) != 3 {
|
||||||
t.Fatalf("expected 2 answers, got %d", len(in.Answer))
|
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
||||||
}
|
}
|
||||||
|
|
||||||
if in.Answer[0].Header().Name != cname.Name {
|
for i, record := range records {
|
||||||
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
if in.Answer[i].Header().Name != record.Name {
|
||||||
|
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
|
||||||
|
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(in.Answer[i].Header().Ttl) != record.TTL {
|
||||||
|
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !in.Authoritative {
|
||||||
|
t.Fatalf("expected authoritative response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if in.Answer[1].Header().Name != a.Name {
|
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
|
||||||
t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name)
|
t.Fatalf("expected final record to be the A record with correct IP")
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[0].(*dns.CNAME).Target != a.Name {
|
|
||||||
t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[1].(*dns.A).A.String() != a.Content {
|
|
||||||
t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
|
||||||
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[1].Header().Rrtype != dns.TypeA {
|
|
||||||
t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(in.Answer[0].Header().Ttl) != cname.TTL {
|
|
||||||
t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !in.Authoritative {
|
|
||||||
t.Fatalf("expected authoritative response")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue