diff --git a/adapters/cloudflare/cloudflare.go b/adapters/cloudflare/cloudflare.go new file mode 100644 index 0000000..bfcbea6 --- /dev/null +++ b/adapters/cloudflare/cloudflare.go @@ -0,0 +1,73 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" +) + +type CloudflareDNSResponse struct { + Result database.DNSRecord `json:"result"` +} + +func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) + + reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) + log.Println(reqBody) + payload := strings.NewReader(reqBody) + + req, _ := http.NewRequest("POST", url, payload) + + req.Header.Add("Authorization", "Bearer "+apiToken) + req.Header.Add("Content-Type", "application/json") + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + + if res.StatusCode != 200 { + return "", fmt.Errorf("error creating dns record: %s", body) + } + + var response CloudflareDNSResponse + err = json.Unmarshal(body, &response) + if err != nil { + return "", err + } + + result := &response.Result + + return result.ID, nil +} + +func DeleteDNSRecord(zoneId string, apiToken string, id string) error { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) + + req, _ := http.NewRequest("DELETE", url, nil) + + req.Header.Add("Authorization", "Bearer "+apiToken) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + + if res.StatusCode != 200 { + return fmt.Errorf("error deleting dns record: %s", body) + } + + return nil +} diff --git a/api/dns.go b/api/dns.go index 3105f91..0822fbc 100644 --- a/api/dns.go +++ b/api/dns.go @@ -3,10 +3,21 @@ package api import ( "log" "net/http" + "strconv" + "strings" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" ) +type FormError struct { + Errors []string +} + +func userCanFuckWithDNSRecord(user *database.User, record *database.DNSRecord) bool { + return user.ID == record.UserID && (record.Name == user.Username || strings.HasSuffix(record.Name, "."+user.Username)) +} + func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) @@ -17,7 +28,106 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp } (*context.TemplateData)["DNSRecords"] = dnsRecords - + return success(context, req, resp) + } +} + +func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, + } + + name := req.FormValue("name") + recordType := req.FormValue("type") + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + } + + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } + + if !userCanFuckWithDNSRecord(context.User, dnsRecord) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username) + } + + if len(formErrors.Errors) == 0 { + cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + + dnsRecord.ID = cloudflareRecordId + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["DNSRecords"] = dnsRecords + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } +} + +func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !userCanFuckWithDNSRecord(context.User, record) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index 38b65b2..09e2072 100644 --- a/api/serve.go +++ b/api/serve.go @@ -70,7 +70,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux := http.NewServeMux() fileServer := http.FileServer(http.Dir(argv.StaticPath)) - mux.Handle("/static/", http.StripPrefix("/static/", fileServer)) + mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer)) makeRequestContext := func() *RequestContext { return &RequestContext{ @@ -81,7 +81,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { } } - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) @@ -116,6 +116,16 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateDNSRecordContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") diff --git a/database/dns.go b/database/dns.go index 17487b7..bb5c1ef 100644 --- a/database/dns.go +++ b/database/dns.go @@ -8,13 +8,13 @@ import ( ) type DNSRecord struct { - ID string - UserID string - Name string - Type string - Content string - TTL int - CreatedAt time.Time + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + Content string `json:"content"` + TTL int `json:"ttl"` + CreatedAt time.Time `json:"created_at"` } func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { @@ -38,3 +38,37 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { return records, nil } + +func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { + log.Println("saving dns record", record) + + record.CreatedAt = time.Now() + _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.CreatedAt) + + if err != nil { + return nil, err + } + return record, nil +} + +func GetDNSRecord(db *sql.DB, recordID string) (*DNSRecord, error) { + log.Println("getting dns record", recordID) + + row := db.QueryRow("SELECT * FROM dns_records WHERE id = ?", recordID) + var record DNSRecord + err := row.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.CreatedAt) + if err != nil { + return nil, err + } + return &record, nil +} + +func DeleteDNSRecord(db *sql.DB, recordID string) error { + log.Println("deleting dns record", recordID) + + _, err := db.Exec("DELETE FROM dns_records WHERE id = ?", recordID) + if err != nil { + return err + } + return nil +} diff --git a/static/css/colors.css b/static/css/colors.css index 69e3e4b..c68bf8e 100644 --- a/static/css/colors.css +++ b/static/css/colors.css @@ -5,6 +5,7 @@ --link-color-light: #d291bc; --container-bg-light: #fff7f87a; --border-color-light: #692fcc; + --error-color-light: #a83254; --background-color-dark: #333; --background-color-dark-2: #2c2c2c; @@ -12,6 +13,7 @@ --link-color-dark: #b86b77; --container-bg-dark: #424242ea; --border-color-dark: #956ade; + --error-color-dark: #851736; } [data-theme="DARK"] { @@ -21,6 +23,7 @@ --link-color: var(--link-color-dark); --container-bg: var(--container-bg-dark); --border-color: var(--border-color-dark); + --error-color: var(--error-color-dark); } [data-theme="LIGHT"] { @@ -30,4 +33,10 @@ --link-color: var(--link-color-light); --container-bg: var(--container-bg-light); --border-color: var(--border-color-light); + --error-color: var(--error-color-light); +} + +.error { + background-color: var(--error-color); + padding: 1rem; } diff --git a/static/css/form.css b/static/css/form.css index 4e14b68..1378d75 100644 --- a/static/css/form.css +++ b/static/css/form.css @@ -1,4 +1,4 @@ -form { +.form { max-width: 600px; padding: 1em; background: var(--background-color-2); diff --git a/static/css/table.css b/static/css/table.css index 640ad83..75a961d 100644 --- a/static/css/table.css +++ b/static/css/table.css @@ -11,8 +11,13 @@ td { border-bottom: 1px solid var(--border-color); } +th, +thead { + background-color: var(--background-color-2); +} + tbody tr:nth-child(odd) { - background-color: var(--link-color); + background-color: var(--background-color); color: var(--text-color); } diff --git a/templates/dns.html b/templates/dns.html index 0a40cab..e317d05 100644 --- a/templates/dns.html +++ b/templates/dns.html @@ -5,10 +5,11 @@