POST record with id to update to fix cloudflare 500
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	This commit is contained in:
		
							parent
							
								
									fca8f5d8ad
								
							
						
					
					
						commit
						dbd548d428
					
				|  | @ -21,15 +21,11 @@ type CloudflareExternalDNSAdapter struct { | |||
| 
 | ||||
| func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId) | ||||
| 
 | ||||
| 	reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) | ||||
| 	payload := strings.NewReader(reqBody) | ||||
| 	payload := strings.NewReader(encodeDNSRecord(record)) | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("POST", url, payload) | ||||
| 
 | ||||
| 	req.Header.Add("Authorization", "Bearer "+adapter.APIToken) | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 
 | ||||
| 	res, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
|  | @ -53,6 +49,27 @@ func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DN | |||
| 	return result.ID, nil | ||||
| } | ||||
| 
 | ||||
| func (adapter *CloudflareExternalDNSAdapter) UpdateDNSRecord(record *database.DNSRecord) error { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, record.ID) | ||||
| 	payload := strings.NewReader(encodeDNSRecord(record)) | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("PUT", url, payload) | ||||
| 	req.Header.Add("Authorization", "Bearer "+adapter.APIToken) | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	res, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	defer res.Body.Close() | ||||
| 	body, _ := io.ReadAll(res.Body) | ||||
| 	if res.StatusCode != 200 { | ||||
| 		return fmt.Errorf("error updating dns record: %s", body) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id) | ||||
| 
 | ||||
|  | @ -74,3 +91,7 @@ func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { | |||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func encodeDNSRecord(record *database.DNSRecord) string { | ||||
| 	return fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) | ||||
| } | ||||
|  |  | |||
|  | @ -4,5 +4,6 @@ import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | |||
| 
 | ||||
| type ExternalDNSAdapter interface { | ||||
| 	CreateDNSRecord(record *database.DNSRecord) (string, error) | ||||
| 	UpdateDNSRecord(record *database.DNSRecord) error | ||||
| 	DeleteDNSRecord(id string) error | ||||
| } | ||||
|  |  | |||
|  | @ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			formErrors := types.BannerMessages{ | ||||
| 				Messages: []string{}, | ||||
| 			} | ||||
| 
 | ||||
| 			internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" | ||||
| 			name := req.FormValue("name") | ||||
| 			if internal && !strings.HasSuffix(name, ".") { | ||||
| 				name += "." | ||||
| 			dnsRecord := &database.DNSRecord{} | ||||
| 			id := req.FormValue("id") | ||||
| 			isNewRecord := id == "" | ||||
| 			if !isNewRecord { | ||||
| 				retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id) | ||||
| 				if err != nil { | ||||
| 					log.Println(err) | ||||
| 					resp.WriteHeader(http.StatusInternalServerError) | ||||
| 					formErrors.Messages = append(formErrors.Messages, "error getting record from id") | ||||
| 				} else { | ||||
| 					dnsRecord = retrievedDnsRecord | ||||
| 				} | ||||
| 			} else { | ||||
| 				dnsRecord.UserID = context.User.ID | ||||
| 			} | ||||
| 
 | ||||
| 			dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true" | ||||
| 
 | ||||
| 			dnsRecord.Name = req.FormValue("name") | ||||
| 			if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") { | ||||
| 				dnsRecord.Name += "." | ||||
| 			} | ||||
| 
 | ||||
| 			recordType := req.FormValue("type") | ||||
| 			recordType = strings.ToUpper(recordType) | ||||
| 			dnsRecord.Type = strings.ToUpper(recordType) | ||||
| 
 | ||||
| 			dnsRecord.Content = req.FormValue("content") | ||||
| 
 | ||||
| 			recordContent := req.FormValue("content") | ||||
| 			ttl := req.FormValue("ttl") | ||||
| 			ttlNum, err := strconv.Atoi(ttl) | ||||
| 			if err != nil { | ||||
| 				resp.WriteHeader(http.StatusBadRequest) | ||||
| 				formErrors.Messages = append(formErrors.Messages, "invalid ttl") | ||||
| 			} | ||||
| 			dnsRecord.TTL = ttlNum | ||||
| 
 | ||||
| 			dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) | ||||
| 			if err != nil { | ||||
|  | @ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max | |||
| 				formErrors.Messages = append(formErrors.Messages, "max records reached") | ||||
| 			} | ||||
| 
 | ||||
| 			dnsRecord := &database.DNSRecord{ | ||||
| 				UserID:   context.User.ID, | ||||
| 				Name:     name, | ||||
| 				Type:     recordType, | ||||
| 				Content:  recordContent, | ||||
| 				TTL:      ttlNum, | ||||
| 				Internal: internal, | ||||
| 			} | ||||
| 
 | ||||
| 			if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { | ||||
| 			if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { | ||||
| 				resp.WriteHeader(http.StatusUnauthorized) | ||||
| 				formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") | ||||
| 				formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") | ||||
| 			} | ||||
| 
 | ||||
| 			if len(formErrors.Messages) == 0 { | ||||
| 			if isNewRecord && len(formErrors.Messages) == 0 { | ||||
| 				if dnsRecord.Internal { | ||||
| 					dnsRecord.ID = utils.RandomId() | ||||
| 				} else { | ||||
| 					dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) | ||||
| 					dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord) | ||||
| 					if err != nil { | ||||
| 						log.Println(err) | ||||
| 						log.Println("error creating external dns record", err) | ||||
| 						resp.WriteHeader(http.StatusInternalServerError) | ||||
| 						formErrors.Messages = append(formErrors.Messages, err.Error()) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if !isNewRecord && len(formErrors.Messages) == 0 { | ||||
| 				if !dnsRecord.Internal { | ||||
| 					err = externalDnsAdapter.UpdateDNSRecord(dnsRecord) | ||||
| 					if err != nil { | ||||
| 						log.Println("error updating external dns record", err) | ||||
| 						resp.WriteHeader(http.StatusInternalServerError) | ||||
| 						formErrors.Messages = append(formErrors.Messages, err.Error()) | ||||
| 					} | ||||
|  | @ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max | |||
| 
 | ||||
| 			if len(formErrors.Messages) == 0 { | ||||
| 				formSuccess := types.BannerMessages{ | ||||
| 					Messages: []string{"record added."}, | ||||
| 					Messages: []string{"record saved."}, | ||||
| 				} | ||||
| 				(*context.TemplateData)["Success"] = formSuccess | ||||
| 				return success(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			(*context.TemplateData)[""] = &formErrors | ||||
| 			log.Println(formErrors.Messages) | ||||
| 			(*context.TemplateData)["Error"] = &formErrors | ||||
| 			(*context.TemplateData)["RecordForm"] = dnsRecord | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			recordId := req.FormValue("id") | ||||
|  | @ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun | |||
| 			} | ||||
| 
 | ||||
| 			if !record.Internal { | ||||
| 				err = dnsAdapter.DeleteDNSRecord(recordId) | ||||
| 				err = externalDnsAdapter.DeleteDNSRecord(recordId) | ||||
| 				if err != nil { | ||||
| 					log.Println(err) | ||||
| 					resp.WriteHeader(http.StatusInternalServerError) | ||||
|  |  | |||
|  | @ -57,6 +57,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) { | |||
| type SignallingExternalDnsAdapter struct { | ||||
| 	AddChannel chan *database.DNSRecord | ||||
| 	RmChannel  chan string | ||||
| 	UpdateChan chan *database.DNSRecord | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { | ||||
|  | @ -72,6 +73,12 @@ func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error { | ||||
| 	go func() { adapter.UpdateChan <- record }() | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestThatOwnerCanPutRecordInDomain(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
|  | @ -172,6 +179,73 @@ func TestThatUserCanAddToPublicEndpoints(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatUserCanUpdateExistingRecord(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	updateChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		UpdateChan: updateChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	responseRecorder := httptest.NewRecorder() | ||||
| 	nonexistantRecord := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 
 | ||||
| 	id := "1" | ||||
| 	name := "test." + context.User.Username | ||||
| 	nonexistantRecord.Form = map[string][]string{ | ||||
| 		"id":       {id}, | ||||
| 		"internal": {"off"}, | ||||
| 		"name":     {name}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"new.domain."}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(responseRecorder, nonexistantRecord) | ||||
| 	if responseRecorder.Code != http.StatusInternalServerError { | ||||
| 		t.Errorf("expected internal server error return, got %d", responseRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		ID:       id, | ||||
| 		Internal: false, | ||||
| 		Name:     name, | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "test.domain.", | ||||
| 		TTL:      43000, | ||||
| 		UserID:   context.User.ID, | ||||
| 	} | ||||
| 	_, err := database.SaveDNSRecord(db, record) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	existantRecord := nonexistantRecord | ||||
| 	existantRecordRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(existantRecordRecorder, existantRecord) | ||||
| 	if existantRecordRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", existantRecordRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case req := <-updateChannel: | ||||
| 		newRecord, err := database.GetDNSRecord(db, req.ID) | ||||
| 		if err != nil { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 		if newRecord.Content != "new.domain." { | ||||
| 			t.Errorf("expected updated record, got %s", newRecord.Content) | ||||
| 		} | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 		t.Errorf("expected updated record channel") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatExternalDnsSaves(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
|  |  | |||
|  | @ -46,7 +46,7 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase | |||
| func TemplateContinuation(path string, showBase bool) types.Continuation { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			html, err := renderTemplate(context, path, true) | ||||
| 			html, err := renderTemplate(context, path, showBase) | ||||
| 			if errors.Is(err, os.ErrNotExist) { | ||||
| 				resp.WriteHeader(404) | ||||
| 				html, err = renderTemplate(context, "404.html", true) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue