hatecomputers.club/api/dns.go

138 lines
4.1 KiB
Go

package api
import (
"log"
"net/http"
"strconv"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
)
// MAX_USER_RECORDS is the maximum number of DNS records a single user may
// own. CreateDNSRecordContinuation reports "max records reached" when the
// user already has this many records.
const MAX_USER_RECORDS = 20
// FormError collects the error messages produced while handling a form
// submission. CreateDNSRecordContinuation exposes it to templates under the
// "FormError" key of the template data.
type FormError struct {
// Errors holds one message per problem, in the order they were detected.
Errors []string
}
// userCanFuckWithDNSRecord reports whether user is allowed to manage record.
//
// Two conditions must both hold: the record's UserID matches the user's ID,
// and the record name is either exactly the user's username or ends with
// "." followed by the username.
func userCanFuckWithDNSRecord(user *database.User, record *database.DNSRecord) bool {
	// Ownership is checked first; the name is irrelevant for foreign records.
	if user.ID != record.UserID {
		return false
	}

	if record.Name == user.Username {
		return true
	}

	// The leading dot prevents "evilalice" from matching user "alice".
	return strings.HasSuffix(record.Name, "."+user.Username)
}
// ListDNSRecordsContinuation fetches the DNS records owned by the current
// user and stores them in the template data under the "DNSRecords" key.
//
// On success the returned chain invokes the success continuation. If the
// lookup fails, the error is logged, a 500 status is written, and the failure
// continuation is invoked instead.
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
	return func(success Continuation, failure Continuation) ContinuationChain {
		records, lookupErr := database.GetUserDNSRecords(context.DBConn, context.User.ID)
		if lookupErr == nil {
			(*context.TemplateData)["DNSRecords"] = records
			return success(context, req, resp)
		}

		log.Println(lookupErr)
		resp.WriteHeader(http.StatusInternalServerError)
		return failure(context, req, resp)
	}
}
// CreateDNSRecordContinuation handles the "create DNS record" form.
//
// It reads the "name", "type", "content" and "ttl" form values, validates
// them, creates the record in Cloudflare and then persists it in the
// database. Validation problems are accumulated in a FormError rather than
// aborting on the first one, so the user sees every issue at once.
//
// Outcomes:
//   - Success: redirects to "/dns" (302) and invokes the success continuation.
//   - The user's existing records cannot be loaded: logs the error, writes a
//     500 status and invokes the failure continuation.
//   - Any validation or processing error: stores "DNSRecords", "FormError"
//     and "RecordForm" in the template data, writes a 400 status and invokes
//     the failure continuation.
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
	return func(success Continuation, failure Continuation) ContinuationChain {
		formErrors := FormError{
			Errors: []string{},
		}

		// A bad TTL leaves ttlNum at 0; the recorded error prevents the
		// record from ever being created with that value.
		ttlNum, err := strconv.Atoi(req.FormValue("ttl"))
		if err != nil {
			formErrors.Errors = append(formErrors.Errors, "invalid ttl")
		}

		dnsRecord := &database.DNSRecord{
			UserID:  context.User.ID,
			Name:    req.FormValue("name"),
			Type:    req.FormValue("type"),
			Content: req.FormValue("content"),
			TTL:     ttlNum,
		}

		// The existing records are needed both for the quota check and to
		// re-render the list alongside the form on failure.
		dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
		if err != nil {
			log.Println(err)
			resp.WriteHeader(http.StatusInternalServerError)
			return failure(context, req, resp)
		}

		if len(dnsRecords) >= MAX_USER_RECORDS {
			formErrors.Errors = append(formErrors.Errors, "max records reached")
		}

		if !userCanFuckWithDNSRecord(context.User, dnsRecord) {
			formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username)
		}

		// Only touch Cloudflare once all local validation has passed.
		if len(formErrors.Errors) == 0 {
			cloudflareRecordID, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
			if err != nil {
				log.Println(err)
				formErrors.Errors = append(formErrors.Errors, err.Error())
			} else {
				// Adopt the Cloudflare ID only when creation succeeded, so a
				// failed call never leaks a meaningless ID into the form data.
				dnsRecord.ID = cloudflareRecordID
			}
		}

		if len(formErrors.Errors) == 0 {
			if _, err := database.SaveDNSRecord(context.DBConn, dnsRecord); err != nil {
				log.Println(err)
				formErrors.Errors = append(formErrors.Errors, "error saving record")

				// The record now exists in Cloudflare but not in the database,
				// so the user could neither see nor delete it. Roll it back on
				// a best-effort basis; a rollback failure is logged only.
				if err := cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord.ID); err != nil {
					log.Println(err)
				}
			}
		}

		if len(formErrors.Errors) == 0 {
			http.Redirect(resp, req, "/dns", http.StatusFound)
			return success(context, req, resp)
		}

		(*context.TemplateData)["DNSRecords"] = dnsRecords
		(*context.TemplateData)["FormError"] = &formErrors
		(*context.TemplateData)["RecordForm"] = dnsRecord

		resp.WriteHeader(http.StatusBadRequest)
		return failure(context, req, resp)
	}
}
// DeleteDNSRecordContinuation handles deletion of the DNS record whose ID is
// given in the "id" form value.
//
// The record is loaded from the database, checked against the current user
// with userCanFuckWithDNSRecord, deleted from Cloudflare and then deleted from
// the database.
//
// Outcomes:
//   - Success: redirects to "/dns" (302) and invokes the success continuation.
//   - The record cannot be loaded, or either delete fails: logs the error,
//     writes a 500 status and invokes the failure continuation.
//   - The user may not manage the record: writes a 401 status and invokes the
//     failure continuation.
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
	return func(success Continuation, failure Continuation) ContinuationChain {
		recordID := req.FormValue("id")

		record, err := database.GetDNSRecord(context.DBConn, recordID)
		if err != nil {
			log.Println(err)
			resp.WriteHeader(http.StatusInternalServerError)
			return failure(context, req, resp)
		}

		// NOTE(review): 403 Forbidden may describe this case better than 401;
		// the status is left unchanged because callers may depend on it.
		if !userCanFuckWithDNSRecord(context.User, record) {
			resp.WriteHeader(http.StatusUnauthorized)
			return failure(context, req, resp)
		}

		// Remove the record from Cloudflare first; the database row is only
		// dropped once the upstream record is gone.
		err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordID)
		if err != nil {
			log.Println(err)
			resp.WriteHeader(http.StatusInternalServerError)
			return failure(context, req, resp)
		}

		err = database.DeleteDNSRecord(context.DBConn, recordID)
		if err != nil {
			// Log like every other failure path so this one is not silent.
			log.Println(err)
			resp.WriteHeader(http.StatusInternalServerError)
			return failure(context, req, resp)
		}

		http.Redirect(resp, req, "/dns", http.StatusFound)
		return success(context, req, resp)
	}
}