dns api tests
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	This commit is contained in:
		
							parent
							
								
									f38e8719c2
								
							
						
					
					
						commit
						d9d39a01f2
					
				|  | @ -14,10 +14,6 @@ import ( | |||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| const MAX_USER_RECORDS = 65 | ||||
| 
 | ||||
| var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} | ||||
| 
 | ||||
| func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { | ||||
| 	ownedByUser := (user.ID == record.UserID) | ||||
| 	if !ownedByUser { | ||||
|  | @ -60,14 +56,14 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			formErrors := types.FormError{ | ||||
| 				Errors: []string{}, | ||||
| 			} | ||||
| 
 | ||||
| 			internal := req.FormValue("internal") == "on" | ||||
| 			internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" | ||||
| 			name := req.FormValue("name") | ||||
| 			if internal && !strings.HasSuffix(name, ".") { | ||||
| 				name += "." | ||||
|  | @ -80,6 +76,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 			ttl := req.FormValue("ttl") | ||||
| 			ttlNum, err := strconv.Atoi(ttl) | ||||
| 			if err != nil { | ||||
| 				resp.WriteHeader(http.StatusBadRequest) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "invalid ttl") | ||||
| 			} | ||||
| 
 | ||||
|  | @ -89,7 +86,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 				resp.WriteHeader(http.StatusInternalServerError) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 			if dnsRecordCount >= MAX_USER_RECORDS { | ||||
| 			if dnsRecordCount >= maxUserRecords { | ||||
| 				resp.WriteHeader(http.StatusTooManyRequests) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "max records reached") | ||||
| 			} | ||||
| 
 | ||||
|  | @ -102,7 +100,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 				Internal: internal, | ||||
| 			} | ||||
| 
 | ||||
| 			if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { | ||||
| 			if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { | ||||
| 				resp.WriteHeader(http.StatusUnauthorized) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") | ||||
| 			} | ||||
| 
 | ||||
|  | @ -113,6 +112,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 					dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) | ||||
| 					if err != nil { | ||||
| 						log.Println(err) | ||||
| 						resp.WriteHeader(http.StatusInternalServerError) | ||||
| 						formErrors.Errors = append(formErrors.Errors, err.Error()) | ||||
| 					} | ||||
| 				} | ||||
|  | @ -127,14 +127,11 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 			} | ||||
| 
 | ||||
| 			if len(formErrors.Errors) == 0 { | ||||
| 				http.Redirect(resp, req, "/dns", http.StatusFound) | ||||
| 				return success(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			(*context.TemplateData)["FormError"] = &formErrors | ||||
| 			(*context.TemplateData)["RecordForm"] = dnsRecord | ||||
| 
 | ||||
| 			resp.WriteHeader(http.StatusBadRequest) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 	} | ||||
|  | @ -151,7 +148,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { | ||||
| 			if !(record.UserID == context.User.ID) { | ||||
| 				resp.WriteHeader(http.StatusUnauthorized) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
|  | @ -171,7 +168,6 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			http.Redirect(resp, req, "/dns", http.StatusFound) | ||||
| 			return success(context, req, resp) | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
|  | @ -2,18 +2,25 @@ package dns_test | |||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	//	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| const MAX_USER_RECORDS = 10 | ||||
| 
 | ||||
| var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
|  | @ -26,10 +33,19 @@ func setup() (*sql.DB, *types.RequestContext, func()) { | |||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	user := &database.User{ | ||||
| 		ID:          "test", | ||||
| 		Username:    "test", | ||||
| 		Mail:        "test@test.com", | ||||
| 		DisplayName: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(testDb, user) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 		User:         user, | ||||
| 	} | ||||
| 
 | ||||
| 	return testDb, context, func() { | ||||
|  | @ -38,14 +54,33 @@ func setup() (*sql.DB, *types.RequestContext, func()) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| type SignallingExternalDnsAdapter struct { | ||||
| 	AddChannel chan *database.DNSRecord | ||||
| 	RmChannel  chan string | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { | ||||
| 	id := utils.RandomId() | ||||
| 	go func() { adapter.AddChannel <- record }() | ||||
| 
 | ||||
| 	return id, nil | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { | ||||
| 	go func() { adapter.RmChannel <- id }() | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestThatOwnerCanPutRecordInDomain(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	_ = &database.User{ | ||||
| 		ID:       "test", | ||||
| 		Username: "test", | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	domainOwner, _ = database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	records, err := database.GetUserDNSRecords(db, context.User.ID) | ||||
| 	if err != nil { | ||||
|  | @ -55,9 +90,353 @@ func TestThatOwnerCanPutRecordInDomain(t *testing.T) { | |||
| 		t.Errorf("expected no records, got records") | ||||
| 	} | ||||
| 
 | ||||
| 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		//		dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation)
 | ||||
| 	})) | ||||
| 	defer ts.Close() | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	validOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	validOwner.Form = map[string][]string{ | ||||
| 		"internal": {"on"}, | ||||
| 		"name":     {"new.test.domain."}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"test.domain."}, | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) | ||||
| 	if validOwnerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerNonInternalRecorder := httptest.NewRecorder() | ||||
| 	validOwner.Form["internal"] = []string{"off"} | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner) | ||||
| 	if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	invalidOwnerRecorder := httptest.NewRecorder() | ||||
| 	invalidOwner := validOwner | ||||
| 	invalidOwner.Form["internal"] = []string{"on"} | ||||
| 	invalidOwner.Form["name"] = []string{"new.invalid.domain."} | ||||
| 	testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner) | ||||
| 	if invalidOwnerRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatUserCanAddToPublicEndpoints(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	responseRecorder := httptest.NewRecorder() | ||||
| 	req := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	fmts := USER_OWNED_INTERNAL_FMT_DOMAINS | ||||
| 	for _, format := range fmts { | ||||
| 		name := fmt.Sprintf(format, context.User.Username) | ||||
| 
 | ||||
| 		req.Form = map[string][]string{ | ||||
| 			"internal": {"off"}, | ||||
| 			"name":     {name}, | ||||
| 			"type":     {"CNAME"}, | ||||
| 			"ttl":      {"43000"}, | ||||
| 			"content":  {"test.domain."}, | ||||
| 		} | ||||
| 
 | ||||
| 		testServer.Config.Handler.ServeHTTP(responseRecorder, req) | ||||
| 		if responseRecorder.Code != http.StatusOK { | ||||
| 			t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 		} | ||||
| 
 | ||||
| 		namedRecords, _ := database.FindDNSRecords(db, name, "CNAME") | ||||
| 		if len(namedRecords) == 0 { | ||||
| 			t.Errorf("saved record not found") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatExternalDnsSaves(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	responseRecorder := httptest.NewRecorder() | ||||
| 	externalRequest := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 
 | ||||
| 	name := "test." + context.User.Username | ||||
| 	externalRequest.Form = map[string][]string{ | ||||
| 		"internal": {"off"}, | ||||
| 		"name":     {name}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"test.domain."}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) | ||||
| 	if responseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case res := <-addChannel: | ||||
| 		if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." { | ||||
| 			t.Errorf("received the wrong external record") | ||||
| 		} | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 		t.Errorf("timed out in waiting for external addition") | ||||
| 	} | ||||
| 
 | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	domainOwner, _ = database.SaveDomainOwner(db, domainOwner) | ||||
| 	internalRequest := externalRequest | ||||
| 	internalRequest.Form["internal"] = []string{"on"} | ||||
| 	internalRequest.Form["name"] = []string{"test.domain."} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) | ||||
| 	if responseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case _ = <-addChannel: | ||||
| 		t.Errorf("expected nothing in the add channel") | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatUserMustOwnRecordToRemove(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	rmChannel := make(chan string) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		RmChannel: rmChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"} | ||||
| 	_, err := database.FindOrSaveUser(db, nonOwnerUser) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		Internal: false, | ||||
| 		Name:     "test", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   nonOwnerUser.ID, | ||||
| 	} | ||||
| 	_, err = database.SaveDNSRecord(db, record) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	nonOwnerRecorder := httptest.NewRecorder() | ||||
| 	nonOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	nonOwner.Form = map[string][]string{ | ||||
| 		"id": {record.ID}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner) | ||||
| 	if nonOwnerRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	record.UserID = context.User.ID | ||||
| 	record.ID = "2" | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 
 | ||||
| 	owner := nonOwner | ||||
| 	owner.Form["id"] = []string{"2"} | ||||
| 	ownerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(ownerRecorder, owner) | ||||
| 	if ownerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", ownerRecorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatExternalDnsRemoves(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		Internal: false, | ||||
| 		Name:     "test", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   context.User.ID, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 
 | ||||
| 	rmChannel := make(chan string) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		RmChannel: rmChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	externalResponseRecorder := httptest.NewRecorder() | ||||
| 	deleteRequest := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 
 | ||||
| 	deleteRequest.Form = map[string][]string{ | ||||
| 		"id": {record.ID}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest) | ||||
| 	if externalResponseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", externalResponseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case res := <-rmChannel: | ||||
| 		if res != record.ID { | ||||
| 			t.Errorf("received the wrong external record") | ||||
| 		} | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 		t.Errorf("timed out in waiting for external addition") | ||||
| 	} | ||||
| 
 | ||||
| 	record.Internal = true | ||||
| 	record.Name = "test.domain." | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	internalResponseRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest) | ||||
| 	if internalResponseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", internalResponseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case _ = <-rmChannel: | ||||
| 		t.Errorf("expected nothing in the rmchannel") | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRecordCountCannotExceed(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		Internal: false, | ||||
| 		Name:     context.User.Username, | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   context.User.ID, | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 1; i <= MAX_USER_RECORDS; i++ { | ||||
| 		record.ID = strconv.Itoa(i) | ||||
| 		record.Name = record.ID + "." + record.Name | ||||
| 		database.SaveDNSRecord(db, record) | ||||
| 	} | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	req := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	req.Form = map[string][]string{ | ||||
| 		"internal": {"off"}, | ||||
| 		"name":     {record.Name}, | ||||
| 		"type":     {record.Type}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {record.Content}, | ||||
| 	} | ||||
| 
 | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(recorder, req) | ||||
| 	if recorder.Code != http.StatusTooManyRequests { | ||||
| 		t.Errorf("expected too many requests code return, got %d", recorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestInternalRecordAppendsTopLevelDot(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.internal.", | ||||
| 	} | ||||
| 	database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	validOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	validOwner.Form = map[string][]string{ | ||||
| 		"internal": {"on"}, | ||||
| 		"name":     {"test.internal"}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"asdf.internal"}, | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) | ||||
| 	if validOwnerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME") | ||||
| 	recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME") | ||||
| 
 | ||||
| 	if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 { | ||||
| 		t.Errorf("expected dot appended") | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -116,14 +116,16 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { | |||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	const MAX_USER_RECORDS = 100 | ||||
| 	var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} | ||||
| 	mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { | ||||
|  |  | |||
|  | @ -9,6 +9,12 @@ import ( | |||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type DomainOwner struct { | ||||
| 	UserID    string    `json:"user_id"` | ||||
| 	Domain    string    `json:"domain"` | ||||
| 	CreatedAt time.Time `json:"created_at"` | ||||
| } | ||||
| 
 | ||||
| type DNSRecord struct { | ||||
| 	ID        string    `json:"id"` | ||||
| 	UserID    string    `json:"user_id"` | ||||
|  | @ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { | |||
| func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { | ||||
| 	log.Println("saving dns record", record.ID) | ||||
| 
 | ||||
| 	record.CreatedAt = time.Now() | ||||
| 	if (record.CreatedAt == time.Time{}) { | ||||
| 		record.CreatedAt = time.Now() | ||||
| 	} | ||||
| 
 | ||||
| 	_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) | ||||
| 
 | ||||
| 	if err != nil { | ||||
|  | @ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err | |||
| 
 | ||||
| 	return records, nil | ||||
| } | ||||
| 
 | ||||
| func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) { | ||||
| 	log.Println("saving domain owner", domainOwner.Domain) | ||||
| 
 | ||||
| 	domainOwner.CreatedAt = time.Now() | ||||
| 	_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return domainOwner, nil | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue