From d9d39a01f24922b6de6ad65ceebcb3da501d2790 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 4 Apr 2024 15:08:50 -0600 Subject: [PATCH] dns api tests --- api/dns/dns.go | 22 +-- api/dns/dns_test.go | 395 +++++++++++++++++++++++++++++++++++++++++++- api/serve.go | 6 +- database/dns.go | 23 ++- 4 files changed, 422 insertions(+), 24 deletions(-) diff --git a/api/dns/dns.go b/api/dns/dns.go index 4805146..aa2f356 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -14,10 +14,6 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -const MAX_USER_RECORDS = 65 - -var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} - func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { ownedByUser := (user.ID == record.UserID) if !ownedByUser { @@ -60,14 +56,14 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.FormError{ Errors: []string{}, } - internal := req.FormValue("internal") == "on" + internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" name := req.FormValue("name") if internal && !strings.HasSuffix(name, ".") { name += "." @@ -80,6 +76,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { + resp.WriteHeader(http.StatusBadRequest) formErrors.Errors = append(formErrors.Errors, "invalid ttl") } @@ -89,7 +86,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if dnsRecordCount >= MAX_USER_RECORDS { + if dnsRecordCount >= maxUserRecords { + resp.WriteHeader(http.StatusTooManyRequests) formErrors.Errors = append(formErrors.Errors, "max records reached") } @@ -102,7 +100,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun Internal: internal, } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + resp.WriteHeader(http.StatusUnauthorized) formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } @@ -113,6 +112,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) if err != nil { log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) formErrors.Errors = append(formErrors.Errors, err.Error()) } } @@ -127,14 +127,11 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } (*context.TemplateData)["FormError"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } } @@ -151,7 +148,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !(record.UserID == context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -171,7 +168,6 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } } diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go index cc56120..43dc680 100644 --- a/api/dns/dns_test.go +++ b/api/dns/dns_test.go @@ -2,18 +2,25 @@ package dns_test import ( "database/sql" + "fmt" "net/http" "net/http/httptest" "os" + "strconv" "testing" + "time" - // "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +const MAX_USER_RECORDS = 10 + +var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} + func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { return success(context, req, resp) @@ -26,10 +33,19 @@ func setup() (*sql.DB, *types.RequestContext, func()) { testDb := database.MakeConn(&randomDb) database.Migrate(testDb) + user := &database.User{ + ID: "test", + Username: "test", + Mail: "test@test.com", + DisplayName: "test", + } + database.FindOrSaveUser(testDb, user) + context := &types.RequestContext{ DBConn: testDb, Args: &args.Arguments{}, TemplateData: &(map[string]interface{}{}), + User: user, } return testDb, context, func() { @@ -38,14 +54,33 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } +type SignallingExternalDnsAdapter struct { + AddChannel chan *database.DNSRecord + RmChannel chan string +} + +func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { + id := utils.RandomId() + go func() { adapter.AddChannel <- record }() + + return id, nil +} + +func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { + go func() { adapter.RmChannel <- id }() + + return nil +} + func TestThatOwnerCanPutRecordInDomain(t *testing.T) { db, context, cleanup := setup() defer cleanup() - _ = &database.User{ - ID: "test", - Username: "test", + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", } + domainOwner, _ = database.SaveDomainOwner(db, domainOwner) records, err := database.GetUserDNSRecords(db, context.User.ID) if err != nil { @@ -55,9 +90,353 @@ func TestThatOwnerCanPutRecordInDomain(t *testing.T) { t.Errorf("expected no records, got records") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation) - })) - defer ts.Close() + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + validOwner := httptest.NewRequest("POST", testServer.URL, nil) + validOwner.Form = map[string][]string{ + "internal": {"on"}, + "name": {"new.test.domain."}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + validOwnerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) + if validOwnerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) + } + + validOwnerNonInternalRecorder := httptest.NewRecorder() + validOwner.Form["internal"] = []string{"off"} + testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner) + if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code) + } + + invalidOwnerRecorder := httptest.NewRecorder() + invalidOwner := validOwner + invalidOwner.Form["internal"] = []string{"on"} + invalidOwner.Form["name"] = []string{"new.invalid.domain."} + testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner) + if invalidOwnerRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code) + } +} + +func TestThatUserCanAddToPublicEndpoints(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + req := httptest.NewRequest("POST", testServer.URL, nil) + fmts := USER_OWNED_INTERNAL_FMT_DOMAINS + for _, format := range fmts { + name := fmt.Sprintf(format, context.User.Username) + + req.Form = map[string][]string{ + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, req) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + + namedRecords, _ := database.FindDNSRecords(db, name, "CNAME") + if len(namedRecords) == 0 { + t.Errorf("saved record not found") + } + } +} + +func TestThatExternalDnsSaves(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + externalRequest := httptest.NewRequest("POST", testServer.URL, nil) + + name := "test." + context.User.Username + externalRequest.Form = map[string][]string{ + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + select { + case res := <-addChannel: + if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." { + t.Errorf("received the wrong external record") + } + case <-time.After(100 * time.Millisecond): + t.Errorf("timed out in waiting for external addition") + } + + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", + } + domainOwner, _ = database.SaveDomainOwner(db, domainOwner) + internalRequest := externalRequest + internalRequest.Form["internal"] = []string{"on"} + internalRequest.Form["name"] = []string{"test.domain."} + + testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + select { + case _ = <-addChannel: + t.Errorf("expected nothing in the add channel") + case <-time.After(100 * time.Millisecond): + } +} + +func TestThatUserMustOwnRecordToRemove(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + rmChannel := make(chan string) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + RmChannel: rmChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"} + _, err := database.FindOrSaveUser(db, nonOwnerUser) + if err != nil { + t.Error(err) + } + + record := &database.DNSRecord{ + ID: "1", + Internal: false, + Name: "test", + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: nonOwnerUser.ID, + } + _, err = database.SaveDNSRecord(db, record) + if err != nil { + t.Error(err) + } + + nonOwnerRecorder := httptest.NewRecorder() + nonOwner := httptest.NewRequest("POST", testServer.URL, nil) + nonOwner.Form = map[string][]string{ + "id": {record.ID}, + } + + testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner) + if nonOwnerRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code) + } + + record.UserID = context.User.ID + record.ID = "2" + database.SaveDNSRecord(db, record) + + owner := nonOwner + owner.Form["id"] = []string{"2"} + ownerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(ownerRecorder, owner) + if ownerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", ownerRecorder.Code) + } +} + +func TestThatExternalDnsRemoves(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + record := &database.DNSRecord{ + ID: "1", + Internal: false, + Name: "test", + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: context.User.ID, + } + database.SaveDNSRecord(db, record) + + rmChannel := make(chan string) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + RmChannel: rmChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + externalResponseRecorder := httptest.NewRecorder() + deleteRequest := httptest.NewRequest("POST", testServer.URL, nil) + + deleteRequest.Form = map[string][]string{ + "id": {record.ID}, + } + + testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest) + if externalResponseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", externalResponseRecorder.Code) + } + select { + case res := <-rmChannel: + if res != record.ID { + t.Errorf("received the wrong external record") + } + case <-time.After(100 * time.Millisecond): + t.Errorf("timed out in waiting for external addition") + } + + record.Internal = true + record.Name = "test.domain." + database.SaveDNSRecord(db, record) + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", + } + database.SaveDomainOwner(db, domainOwner) + + internalResponseRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest) + if internalResponseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", internalResponseRecorder.Code) + } + select { + case _ = <-rmChannel: + t.Errorf("expected nothing in the rmchannel") + case <-time.After(100 * time.Millisecond): + } +} + +func TestRecordCountCannotExceed(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + record := &database.DNSRecord{ + Internal: false, + Name: context.User.Username, + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: context.User.ID, + } + + for i := 1; i <= MAX_USER_RECORDS; i++ { + record.ID = strconv.Itoa(i) + record.Name = record.ID + "." + record.Name + database.SaveDNSRecord(db, record) + } + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + req := httptest.NewRequest("POST", testServer.URL, nil) + req.Form = map[string][]string{ + "internal": {"off"}, + "name": {record.Name}, + "type": {record.Type}, + "ttl": {"43000"}, + "content": {record.Content}, + } + + recorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(recorder, req) + if recorder.Code != http.StatusTooManyRequests { + t.Errorf("expected too many requests code return, got %d", recorder.Code) + } +} + +func TestInternalRecordAppendsTopLevelDot(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.internal.", + } + database.SaveDomainOwner(db, domainOwner) + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + validOwner := httptest.NewRequest("POST", testServer.URL, nil) + validOwner.Form = map[string][]string{ + "internal": {"on"}, + "name": {"test.internal"}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"asdf.internal"}, + } + + validOwnerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) + if validOwnerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) + } + + recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME") + recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME") + + if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 { + t.Errorf("expected dot appended") + } } diff --git a/api/serve.go b/api/serve.go index 6d8c59c..2b0eba4 100644 --- a/api/serve.go +++ b/api/serve.go @@ -116,14 +116,16 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + const MAX_USER_RECORDS = 100 + var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { diff --git a/database/dns.go b/database/dns.go index fc01347..7851ab4 100644 --- a/database/dns.go +++ b/database/dns.go @@ -9,6 +9,12 @@ import ( "time" ) +type DomainOwner struct { + UserID string `json:"user_id"` + Domain string `json:"domain"` + CreatedAt time.Time `json:"created_at"` +} + type DNSRecord struct { ID string `json:"id"` UserID string `json:"user_id"` @@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { log.Println("saving dns record", record.ID) - record.CreatedAt = time.Now() + if (record.CreatedAt == time.Time{}) { + record.CreatedAt = time.Now() + } + _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) if err != nil { @@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err return records, nil } + +func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) { + log.Println("saving domain owner", domainOwner.Domain) + + domainOwner.CreatedAt = time.Now() + _, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt) + + if err != nil { + return nil, err + } + return domainOwner, nil +}