package dns import ( "database/sql" "fmt" "log" "net/http" "strconv" "strings" "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/external_dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) const MaxUserRecords = 100 var UserOwnedInternalFmtDomains = []string{"%s", "%s.endpoints"} func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { if context.User == nil { return failure(context, req, resp) } dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } (*context.TemplateData)["DNSRecords"] = dnsRecords return success(context, req, resp) } } func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.BannerMessages{ Messages: []string{}, } dnsRecord := &database.DNSRecord{} id := req.FormValue("id") isNewRecord := id == "" if !isNewRecord { retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, "error getting record from id") } else { dnsRecord = retrievedDnsRecord } } else { dnsRecord.UserID = context.User.ID } dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true" dnsRecord.Name = req.FormValue("name") if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") { dnsRecord.Name += "." } recordType := req.FormValue("type") dnsRecord.Type = strings.ToUpper(recordType) dnsRecord.Content = req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { resp.WriteHeader(http.StatusBadRequest) formErrors.Messages = append(formErrors.Messages, "invalid ttl") } dnsRecord.TTL = ttlNum dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } if dnsRecordCount >= maxUserRecords { resp.WriteHeader(http.StatusTooManyRequests) formErrors.Messages = append(formErrors.Messages, "max records reached") } if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { resp.WriteHeader(http.StatusUnauthorized) formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } if isNewRecord && len(formErrors.Messages) == 0 { if dnsRecord.Internal { dnsRecord.ID = utils.RandomId() } else { dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord) if err != nil { log.Println("error creating external dns record", err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, err.Error()) } } } if !isNewRecord && len(formErrors.Messages) == 0 { if !dnsRecord.Internal { err = externalDnsAdapter.UpdateDNSRecord(dnsRecord) if err != nil { log.Println("error updating external dns record", err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, err.Error()) } } } if len(formErrors.Messages) == 0 { _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) if err != nil { log.Println(err) formErrors.Messages = append(formErrors.Messages, "error saving record") } } if len(formErrors.Messages) == 0 { formSuccess := types.BannerMessages{ Messages: []string{"record saved."}, } (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } log.Println(formErrors.Messages) (*context.TemplateData)["Error"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord return failure(context, req, resp) } } } func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { recordId := req.FormValue("id") record, err := database.GetDNSRecord(context.DBConn, recordId) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } if !(record.UserID == context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } if !record.Internal { err = externalDnsAdapter.DeleteDNSRecord(recordId) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } } err = database.DeleteDNSRecord(context.DBConn, recordId) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } formSuccess := types.BannerMessages{ Messages: []string{"record deleted."}, } (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } } } func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { ownedByUser := (user.ID == record.UserID) if !ownedByUser { return false } if !record.Internal { for _, format := range ownedInternalDomainFormats { domain := fmt.Sprintf(format, user.Username) isInSubDomain := strings.HasSuffix(record.Name, "."+domain) if domain == record.Name || isInSubDomain { return true } } return false } owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) if err != nil { log.Println(err) return false } userIsOwnerOfDomain := owner == user.ID return ownedByUser && userIsOwnerOfDomain }