diff --git a/adapters/external_dns/cloudflare/cloudflare.go b/adapters/external_dns/cloudflare/cloudflare.go index c302037..4f9e208 100644 --- a/adapters/external_dns/cloudflare/cloudflare.go +++ b/adapters/external_dns/cloudflare/cloudflare.go @@ -21,15 +21,11 @@ type CloudflareExternalDNSAdapter struct { func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId) - - reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) - payload := strings.NewReader(reqBody) + payload := strings.NewReader(encodeDNSRecord(record)) req, _ := http.NewRequest("POST", url, payload) - req.Header.Add("Authorization", "Bearer "+adapter.APIToken) req.Header.Add("Content-Type", "application/json") - res, err := http.DefaultClient.Do(req) if err != nil { return "", err @@ -53,6 +49,27 @@ func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DN return result.ID, nil } +func (adapter *CloudflareExternalDNSAdapter) UpdateDNSRecord(record *database.DNSRecord) error { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, record.ID) + payload := strings.NewReader(encodeDNSRecord(record)) + + req, _ := http.NewRequest("PUT", url, payload) + req.Header.Add("Authorization", "Bearer "+adapter.APIToken) + req.Header.Add("Content-Type", "application/json") + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("error updating dns record: %s", body) + } + + return nil +} + func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id) @@ -74,3 +91,7 @@ func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { return nil } + +func encodeDNSRecord(record *database.DNSRecord) string { + return fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) +} diff --git a/adapters/external_dns/external_dns.go b/adapters/external_dns/external_dns.go index c861283..3894f6c 100644 --- a/adapters/external_dns/external_dns.go +++ b/adapters/external_dns/external_dns.go @@ -4,5 +4,6 @@ import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" type ExternalDNSAdapter interface { CreateDNSRecord(record *database.DNSRecord) (string, error) + UpdateDNSRecord(record *database.DNSRecord) error DeleteDNSRecord(id string) error } diff --git a/api/dns/dns.go b/api/dns/dns.go index 7e9c7c7..c24fa4a 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.BannerMessages{ Messages: []string{}, } - internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." + dnsRecord := &database.DNSRecord{} + id := req.FormValue("id") + isNewRecord := id == "" + if !isNewRecord { + retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, "error getting record from id") + } else { + dnsRecord = retrievedDnsRecord + } + } else { + dnsRecord.UserID = context.User.ID + } + + dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true" + + dnsRecord.Name = req.FormValue("name") + if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") { + dnsRecord.Name += "." } recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) + dnsRecord.Type = strings.ToUpper(recordType) + + dnsRecord.Content = req.FormValue("content") - recordContent := req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { resp.WriteHeader(http.StatusBadRequest) formErrors.Messages = append(formErrors.Messages, "invalid ttl") } + dnsRecord.TTL = ttlNum dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) if err != nil { @@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max formErrors.Messages = append(formErrors.Messages, "max records reached") } - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { resp.WriteHeader(http.StatusUnauthorized) - formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } - if len(formErrors.Messages) == 0 { + if isNewRecord && len(formErrors.Messages) == 0 { if dnsRecord.Internal { dnsRecord.ID = utils.RandomId() } else { - dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord) if err != nil { - log.Println(err) + log.Println("error creating external dns record", err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, err.Error()) + } + } + } + + if !isNewRecord && len(formErrors.Messages) == 0 { + if !dnsRecord.Internal { + err = externalDnsAdapter.UpdateDNSRecord(dnsRecord) + if err != nil { + log.Println("error updating external dns record", err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, err.Error()) } @@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max if len(formErrors.Messages) == 0 { formSuccess := types.BannerMessages{ - Messages: []string{"record added."}, + Messages: []string{"record saved."}, } (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } - (*context.TemplateData)[""] = &formErrors + log.Println(formErrors.Messages) + (*context.TemplateData)["Error"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord return failure(context, req, resp) } } } -func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { recordId := req.FormValue("id") @@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if !record.Internal { - err = dnsAdapter.DeleteDNSRecord(recordId) + err = externalDnsAdapter.DeleteDNSRecord(recordId) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go index 30baedf..c4c581b 100644 --- a/api/dns/dns_test.go +++ b/api/dns/dns_test.go @@ -57,6 +57,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) { type SignallingExternalDnsAdapter struct { AddChannel chan *database.DNSRecord RmChannel chan string + UpdateChan chan *database.DNSRecord } func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { @@ -72,6 +73,12 @@ func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { return nil } +func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error { + go func() { adapter.UpdateChan <- record }() + + return nil +} + func TestThatOwnerCanPutRecordInDomain(t *testing.T) { db, context, cleanup := setup() defer cleanup() @@ -172,6 +179,73 @@ func TestThatUserCanAddToPublicEndpoints(t *testing.T) { } } +func TestThatUserCanUpdateExistingRecord(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + updateChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + UpdateChan: updateChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + nonexistantRecord := httptest.NewRequest("POST", testServer.URL, nil) + + id := "1" + name := "test." + context.User.Username + nonexistantRecord.Form = map[string][]string{ + "id": {id}, + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"new.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, nonexistantRecord) + if responseRecorder.Code != http.StatusInternalServerError { + t.Errorf("expected internal server error return, got %d", responseRecorder.Code) + } + + record := &database.DNSRecord{ + ID: id, + Internal: false, + Name: name, + Type: "CNAME", + Content: "test.domain.", + TTL: 43000, + UserID: context.User.ID, + } + _, err := database.SaveDNSRecord(db, record) + if err != nil { + t.Error(err) + } + + existantRecord := nonexistantRecord + existantRecordRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(existantRecordRecorder, existantRecord) + if existantRecordRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", existantRecordRecorder.Code) + } + select { + case req := <-updateChannel: + newRecord, err := database.GetDNSRecord(db, req.ID) + if err != nil { + t.Error(err) + } + if newRecord.Content != "new.domain." { + t.Errorf("expected updated record, got %s", newRecord.Content) + } + case <-time.After(100 * time.Millisecond): + t.Errorf("expected updated record channel") + } +} + func TestThatExternalDnsSaves(t *testing.T) { db, context, cleanup := setup() defer cleanup() diff --git a/api/template/template.go b/api/template/template.go index 2875649..ad6a573 100644 --- a/api/template/template.go +++ b/api/template/template.go @@ -46,7 +46,7 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase func TemplateContinuation(path string, showBase bool) types.Continuation { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - html, err := renderTemplate(context, path, true) + html, err := renderTemplate(context, path, showBase) if errors.Is(err, os.ErrNotExist) { resp.WriteHeader(404) html, err = renderTemplate(context, "404.html", true)