diff --git a/adapters/cloudflare/cloudflare.go b/adapters/cloudflare/cloudflare.go index 40b04a5..c302037 100644 --- a/adapters/cloudflare/cloudflare.go +++ b/adapters/cloudflare/cloudflare.go @@ -14,15 +14,20 @@ type CloudflareDNSResponse struct { Result database.DNSRecord `json:"result"` } -func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) +type CloudflareExternalDNSAdapter struct { + ZoneId string + APIToken string +} + +func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) payload := strings.NewReader(reqBody) req, _ := http.NewRequest("POST", url, payload) - req.Header.Add("Authorization", "Bearer "+apiToken) + req.Header.Add("Authorization", "Bearer "+adapter.APIToken) req.Header.Add("Content-Type", "application/json") res, err := http.DefaultClient.Do(req) @@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) return result.ID, nil } -func DeleteDNSRecord(zoneId string, apiToken string, id string) error { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) +func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id) req, _ := http.NewRequest("DELETE", url, nil) - req.Header.Add("Authorization", "Bearer "+apiToken) + req.Header.Add("Authorization", "Bearer "+adapter.APIToken) res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/adapters/external_dns.go b/adapters/external_dns.go new file mode 100644 index 0000000..c861283 --- /dev/null +++ b/adapters/external_dns.go @@ -0,0 +1,8 @@ +package external_dns + +import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + +type ExternalDNSAdapter interface { + CreateDNSRecord(record *database.DNSRecord) (string, error) + DeleteDNSRecord(id string) error +} diff --git a/api/dns.go b/api/dns.go index ad41103..6f0e1fd 100644 --- a/api/dns.go +++ b/api/dns.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) @@ -64,116 +64,119 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp } } -func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } - - internal := req.FormValue("internal") == "on" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." - } - - recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) - - recordContent := req.FormValue("content") - ttl := req.FormValue("ttl") - ttlNum, err := strconv.Atoi(ttl) - if err != nil { - formErrors.Errors = append(formErrors.Errors, "invalid ttl") - } - - dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if dnsRecordCount >= MAX_USER_RECORDS { - formErrors.Errors = append(formErrors.Errors, "max records reached") - } - - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") - } - - if len(formErrors.Errors) == 0 { - if dnsRecord.Internal { - dnsRecord.ID = utils.RandomId() - } else { - cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, err.Error()) - } - - dnsRecord.ID = cloudflareRecordId +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, } - } - if len(formErrors.Errors) == 0 { - _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + internal := req.FormValue("internal") == "on" + name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } + + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) + + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "error saving record") + formErrors.Errors = append(formErrors.Errors, "invalid ttl") } - } - if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - - (*context.TemplateData)["FormError"] = &formErrors - (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } -} - -func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - recordId := req.FormValue("id") - record, err := database.GetDNSRecord(context.DBConn, recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - if !record.Internal { - err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) + dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - } + if dnsRecordCount >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") + } - err = database.DeleteDNSRecord(context.DBConn, recordId) - if err != nil { - resp.WriteHeader(http.StatusInternalServerError) + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + } + + if len(formErrors.Errors) == 0 { + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + } + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + + resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } - - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) + } +} + +func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + if !record.Internal { + err = dnsAdapter.DeleteDNSRecord(recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } } } diff --git a/api/serve.go b/api/serve.go index f71001d..1b632a1 100644 --- a/api/serve.go +++ b/api/serve.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" @@ -80,6 +81,11 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { fileServer := http.FileServer(http.Dir(argv.StaticPath)) mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) + cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{ + APIToken: argv.CloudflareToken, + ZoneId: argv.CloudflareZone, + } + makeRequestContext := func() *RequestContext { return &RequestContext{ DBConn: dbConn, @@ -126,12 +132,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {