testing | dont be recursive for external domains | finalize oauth #5

Merged
simponic merged 24 commits from dont-be-authoritative into main 2024-04-06 15:43:19 -04:00
4 changed files with 129 additions and 107 deletions
Showing only changes of commit cc33a90bfd - Show all commits

View File

@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
Result database.DNSRecord `json:"result"` Result database.DNSRecord `json:"result"`
} }
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { type CloudflareExternalDNSAdapter struct {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) ZoneId string
APIToken string
}
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
payload := strings.NewReader(reqBody) payload := strings.NewReader(reqBody)
req, _ := http.NewRequest("POST", url, payload) req, _ := http.NewRequest("POST", url, payload)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
return result.ID, nil return result.ID, nil
} }
func DeleteDNSRecord(zoneId string, apiToken string, id string) error { func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
req, _ := http.NewRequest("DELETE", url, nil) req, _ := http.NewRequest("DELETE", url, nil)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {

8
adapters/external_dns.go Normal file
View File

@ -0,0 +1,8 @@
package external_dns
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
type ExternalDNSAdapter interface {
CreateDNSRecord(record *database.DNSRecord) (string, error)
DeleteDNSRecord(id string) error
}

View File

@ -8,7 +8,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
) )
@ -64,116 +64,119 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp
} }
} }
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
formErrors := FormError{ return func(success Continuation, failure Continuation) ContinuationChain {
Errors: []string{}, formErrors := FormError{
} Errors: []string{},
internal := req.FormValue("internal") == "on"
name := req.FormValue("name")
if internal && !strings.HasSuffix(name, ".") {
name += "."
}
recordType := req.FormValue("type")
recordType = strings.ToUpper(recordType)
recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
}
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if dnsRecordCount >= MAX_USER_RECORDS {
formErrors.Errors = append(formErrors.Errors, "max records reached")
}
dnsRecord := &database.DNSRecord{
UserID: context.User.ID,
Name: name,
Type: recordType,
Content: recordContent,
TTL: ttlNum,
Internal: internal,
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
if len(formErrors.Errors) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, err.Error())
}
dnsRecord.ID = cloudflareRecordId
} }
}
if len(formErrors.Errors) == 0 { internal := req.FormValue("internal") == "on"
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord) name := req.FormValue("name")
if internal && !strings.HasSuffix(name, ".") {
name += "."
}
recordType := req.FormValue("type")
recordType = strings.ToUpper(recordType)
recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil { if err != nil {
log.Println(err) formErrors.Errors = append(formErrors.Errors, "invalid ttl")
formErrors.Errors = append(formErrors.Errors, "error saving record")
} }
}
if len(formErrors.Errors) == 0 { dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
(*context.TemplateData)["FormError"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp)
}
}
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
recordId := req.FormValue("id")
record, err := database.GetDNSRecord(context.DBConn, recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
if !record.Internal {
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
} if dnsRecordCount >= MAX_USER_RECORDS {
formErrors.Errors = append(formErrors.Errors, "max records reached")
}
err = database.DeleteDNSRecord(context.DBConn, recordId) dnsRecord := &database.DNSRecord{
if err != nil { UserID: context.User.ID,
resp.WriteHeader(http.StatusInternalServerError) Name: name,
Type: recordType,
Content: recordContent,
TTL: ttlNum,
Internal: internal,
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
if len(formErrors.Errors) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, err.Error())
}
}
}
if len(formErrors.Errors) == 0 {
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, "error saving record")
}
}
if len(formErrors.Errors) == 0 {
http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
(*context.TemplateData)["FormError"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp) return failure(context, req, resp)
} }
}
http.Redirect(resp, req, "/dns", http.StatusFound) }
return success(context, req, resp)
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
recordId := req.FormValue("id")
record, err := database.GetDNSRecord(context.DBConn, recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
if !record.Internal {
err = dnsAdapter.DeleteDNSRecord(recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
}
err = database.DeleteDNSRecord(context.DBConn, recordId)
if err != nil {
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
} }
} }

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"time" "time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
@ -80,6 +81,11 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
fileServer := http.FileServer(http.Dir(argv.StaticPath)) fileServer := http.FileServer(http.Dir(argv.StaticPath))
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
APIToken: argv.CloudflareToken,
ZoneId: argv.CloudflareZone,
}
makeRequestContext := func() *RequestContext { makeRequestContext := func() *RequestContext {
return &RequestContext{ return &RequestContext{
DBConn: dbConn, DBConn: dbConn,
@ -126,12 +132,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {