package dns_test

import (
	"database/sql"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"testing"
	"time"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)

// MAX_USER_RECORDS is the per-user record limit handed to
// dns.CreateDNSRecordContinuation in these tests.
const MAX_USER_RECORDS = 10

// USER_OWNED_INTERNAL_FMT_DOMAINS are the fmt patterns handed to
// dns.CreateDNSRecordContinuation; each "%s" is filled with a username
// (see TestThatUserCanAddToPublicEndpoints).
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}

// IdContinuation returns a chain that always calls the success continuation
// with the same context, request and response, ignoring the failure one. It
// lets a continuation under test run to completion without other middleware.
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
		return success(context, req, resp)
	}
}

// setup opens a database connection named by a fresh random id, runs
// database.Migrate on it and saves a base user with ID/username "test".
// It returns the connection, a request context bound to that user, and a
// cleanup func that closes the connection and removes the file named by the
// random id (presumably the database file MakeConn created — confirm).
//
// setup panics if the base user cannot be saved: every test depends on it
// and setup has no *testing.T to report through.
func setup() (*sql.DB, *types.RequestContext, func()) {
	randomDb := utils.RandomId()

	testDb := database.MakeConn(&randomDb)
	database.Migrate(testDb)

	user := &database.User{
		ID:          "test",
		Username:    "test",
		Mail:        "test@test.com",
		DisplayName: "test",
	}
	if _, err := database.FindOrSaveBaseUser(testDb, user); err != nil {
		testDb.Close()
		os.Remove(randomDb)
		panic(fmt.Errorf("setup: saving base user: %w", err))
	}

	context := &types.RequestContext{
		DBConn:       testDb,
		Args:         &args.Arguments{},
		TemplateData: &(map[string]interface{}{}),
		User:         user,
	}

	return testDb, context, func() {
		testDb.Close()
		os.Remove(randomDb)
	}
}

// SignallingExternalDnsAdapter is a fake external DNS adapter. Instead of
// talking to a provider it reports every call on a channel so tests can
// assert what would have been created in or removed from external DNS.
//
// Each report is sent from its own goroutine so the handler under test never
// blocks on it. A report sent on a nil or never-read channel leaves that
// goroutine blocked until the test binary exits.
type SignallingExternalDnsAdapter struct {
	AddChannel chan *database.DNSRecord
	RmChannel  chan string
}

// CreateDNSRecord reports record on AddChannel and returns a fresh random id.
// It never fails.
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
	id := utils.RandomId()
	go func() { adapter.AddChannel <- record }()
	return id, nil
}

// DeleteDNSRecord reports id on RmChannel. It never fails.
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
	go func() { adapter.RmChannel <- id }()
	return nil
}

// TestThatOwnerCanPutRecordInDomain checks that the owner of "test.domain."
// can create an internal record under it. It also checks that the same name
// sent as a non-internal record, and an internal record outside the owned
// domain, are both refused with 401.
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	records, err := database.GetUserDNSRecords(db, context.User.ID)
	if err != nil {
		t.Fatal(err)
	}
	if len(records) > 0 {
		t.Errorf("expected no records, got records")
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	validOwner := httptest.NewRequest("POST", testServer.URL, nil)
	validOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"new.test.domain."},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	// The same name as an external (non-internal) record is not the user's.
	validOwnerNonInternalRecorder := httptest.NewRecorder()
	validOwner.Form["internal"] = []string{"off"}
	testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
	if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
	}

	// An internal record outside the owned domain.
	invalidOwner := httptest.NewRequest("POST", testServer.URL, nil)
	invalidOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"new.invalid.domain."},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}
	invalidOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
	if invalidOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
	}
}

// TestThatUserCanAddToPublicEndpoints checks that a user can create an
// external record for every name produced by USER_OWNED_INTERNAL_FMT_DOMAINS
// with their username, and that each one is stored.
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	for _, format := range USER_OWNED_INTERNAL_FMT_DOMAINS {
		name := fmt.Sprintf(format, context.User.Username)

		req := httptest.NewRequest("POST", testServer.URL, nil)
		req.Form = map[string][]string{
			"internal": {"off"},
			"name":     {name},
			"type":     {"CNAME"},
			"ttl":      {"43000"},
			"content":  {"test.domain."},
		}

		// A ResponseRecorder keeps the first status code written, so each
		// request needs its own recorder or later failures would be hidden.
		responseRecorder := httptest.NewRecorder()
		testServer.Config.Handler.ServeHTTP(responseRecorder, req)
		if responseRecorder.Code != http.StatusOK {
			t.Errorf("expected valid return, got %d", responseRecorder.Code)
		}

		namedRecords, err := database.FindDNSRecords(db, name, "CNAME")
		if err != nil {
			t.Fatal(err)
		}
		if len(namedRecords) == 0 {
			t.Errorf("saved record not found")
		}
	}
}

// TestThatExternalDnsSaves checks that a non-internal record is pushed to the
// external adapter, while an internal record on an owned domain is accepted
// without reaching the adapter.
func TestThatExternalDnsSaves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	externalRecorder := httptest.NewRecorder()
	externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
	name := "test." + context.User.Username
	externalRequest.Form = map[string][]string{
		"internal": {"off"},
		"name":     {name},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	testServer.Config.Handler.ServeHTTP(externalRecorder, externalRequest)
	if externalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalRecorder.Code)
	}

	select {
	case res := <-addChannel:
		if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external addition")
	}

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	internalRequest := httptest.NewRequest("POST", testServer.URL, nil)
	internalRequest.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"test.domain."},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	// Fresh recorder: reusing the first one would always report its 200.
	internalRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalRecorder, internalRequest)
	if internalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalRecorder.Code)
	}

	select {
	case <-addChannel:
		t.Errorf("expected nothing in the add channel")
	case <-time.After(100 * time.Millisecond):
	}
}

// TestThatUserMustOwnRecordToRemove checks that deleting another user's
// record is refused with 401 while deleting one's own record succeeds.
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
	if _, err := database.FindOrSaveBaseUser(db, nonOwnerUser); err != nil {
		t.Fatal(err)
	}

	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   nonOwnerUser.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	nonOwnerRecorder := httptest.NewRecorder()
	nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
	nonOwner.Form = map[string][]string{
		"id": {record.ID},
	}
	testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
	if nonOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
	}

	// Save a second record, this time owned by the context user.
	record.UserID = context.User.ID
	record.ID = "2"
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	owner := httptest.NewRequest("POST", testServer.URL, nil)
	owner.Form = map[string][]string{
		"id": {record.ID},
	}
	ownerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
	if ownerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", ownerRecorder.Code)
	}
}

// TestThatExternalDnsRemoves checks that deleting a non-internal record
// removes it from the external adapter, while deleting an internal record on
// an owned domain succeeds without reaching the adapter.
func TestThatExternalDnsRemoves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	externalResponseRecorder := httptest.NewRecorder()
	deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
	deleteRequest.Form = map[string][]string{
		"id": {record.ID},
	}

	testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
	if externalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
	}

	select {
	case res := <-rmChannel:
		if res != record.ID {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external removal")
	}

	// Re-save the (now deleted) record as an internal one on an owned domain.
	record.Internal = true
	record.Name = "test.domain."
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	internalResponseRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
	if internalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
	}

	select {
	case <-rmChannel:
		t.Errorf("expected nothing in the rmchannel")
	case <-time.After(100 * time.Millisecond):
	}
}

// TestRecordCountCannotExceed fills the user up to MAX_USER_RECORDS records
// and checks that creating one more is refused with 429.
func TestRecordCountCannotExceed(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		Internal: false,
		Name:     context.User.Username,
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}

	// Each iteration prefixes the previous name, giving distinct names
	// "1.test", "2.1.test", ... with distinct ids.
	for i := 1; i <= MAX_USER_RECORDS; i++ {
		record.ID = strconv.Itoa(i)
		record.Name = record.ID + "." + record.Name
		if _, err := database.SaveDNSRecord(db, record); err != nil {
			t.Fatal(err)
		}
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	req := httptest.NewRequest("POST", testServer.URL, nil)
	req.Form = map[string][]string{
		"internal": {"off"},
		"name":     {record.Name},
		"type":     {record.Type},
		"ttl":      {"43000"},
		"content":  {record.Content},
	}

	recorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(recorder, req)
	if recorder.Code != http.StatusTooManyRequests {
		t.Errorf("expected too many requests code return, got %d", recorder.Code)
	}
}

// TestInternalRecordAppendsTopLevelDot checks that an internal record named
// without a trailing dot is stored under the dotted name, and only there.
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.internal.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	validOwner := httptest.NewRequest("POST", testServer.URL, nil)
	validOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"test.internal"},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"asdf.internal"},
	}

	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	recordsAppendedDot, err := database.FindDNSRecords(db, "test.internal.", "CNAME")
	if err != nil {
		t.Fatal(err)
	}
	recordsWithoutDot, err := database.FindDNSRecords(db, "test.internal", "CNAME")
	if err != nil {
		t.Fatal(err)
	}

	// Both conditions must hold; either failing means the dot was not
	// appended (the previous && let a half-failure pass).
	if len(recordsAppendedDot) != 1 || len(recordsWithoutDot) != 0 {
		t.Errorf("expected dot appended, got %d records with dot and %d without",
			len(recordsAppendedDot), len(recordsWithoutDot))
	}
}