POST record with id to update
This commit is contained in:
parent
fca8f5d8ad
commit
483bf39444
|
@ -21,15 +21,11 @@ type CloudflareExternalDNSAdapter struct {
|
||||||
|
|
||||||
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
||||||
|
payload := strings.NewReader(encodeDNSRecord(record))
|
||||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
|
||||||
payload := strings.NewReader(reqBody)
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST", url, payload)
|
req, _ := http.NewRequest("POST", url, payload)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -53,6 +49,27 @@ func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DN
|
||||||
return result.ID, nil
|
return result.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (adapter *CloudflareExternalDNSAdapter) UpdateDNSRecord(record *database.DNSRecord) error {
|
||||||
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, record.ID)
|
||||||
|
payload := strings.NewReader(encodeDNSRecord(record))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("PUT", url, payload)
|
||||||
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, _ := io.ReadAll(res.Body)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
return fmt.Errorf("error updating dns record: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
||||||
|
|
||||||
|
@ -74,3 +91,7 @@ func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func encodeDNSRecord(record *database.DNSRecord) string {
|
||||||
|
return fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||||
|
}
|
||||||
|
|
|
@ -4,5 +4,6 @@ import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
|
||||||
type ExternalDNSAdapter interface {
|
type ExternalDNSAdapter interface {
|
||||||
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
||||||
|
UpdateDNSRecord(record *database.DNSRecord) error
|
||||||
DeleteDNSRecord(id string) error
|
DeleteDNSRecord(id string) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
formErrors := types.BannerMessages{
|
formErrors := types.BannerMessages{
|
||||||
Messages: []string{},
|
Messages: []string{},
|
||||||
}
|
}
|
||||||
|
|
||||||
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
dnsRecord := &database.DNSRecord{}
|
||||||
name := req.FormValue("name")
|
id := req.FormValue("id")
|
||||||
if internal && !strings.HasSuffix(name, ".") {
|
isNewRecord := id == ""
|
||||||
name += "."
|
if !isNewRecord {
|
||||||
|
retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
formErrors.Messages = append(formErrors.Messages, "error getting record from id")
|
||||||
|
} else {
|
||||||
|
dnsRecord = retrievedDnsRecord
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dnsRecord.UserID = context.User.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
||||||
|
|
||||||
|
dnsRecord.Name = req.FormValue("name")
|
||||||
|
if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") {
|
||||||
|
dnsRecord.Name += "."
|
||||||
}
|
}
|
||||||
|
|
||||||
recordType := req.FormValue("type")
|
recordType := req.FormValue("type")
|
||||||
recordType = strings.ToUpper(recordType)
|
dnsRecord.Type = strings.ToUpper(recordType)
|
||||||
|
|
||||||
|
dnsRecord.Content = req.FormValue("content")
|
||||||
|
|
||||||
recordContent := req.FormValue("content")
|
|
||||||
ttl := req.FormValue("ttl")
|
ttl := req.FormValue("ttl")
|
||||||
ttlNum, err := strconv.Atoi(ttl)
|
ttlNum, err := strconv.Atoi(ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
formErrors.Messages = append(formErrors.Messages, "invalid ttl")
|
formErrors.Messages = append(formErrors.Messages, "invalid ttl")
|
||||||
}
|
}
|
||||||
|
dnsRecord.TTL = ttlNum
|
||||||
|
|
||||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
|
||||||
formErrors.Messages = append(formErrors.Messages, "max records reached")
|
formErrors.Messages = append(formErrors.Messages, "max records reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRecord := &database.DNSRecord{
|
if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
||||||
UserID: context.User.ID,
|
|
||||||
Name: name,
|
|
||||||
Type: recordType,
|
|
||||||
Content: recordContent,
|
|
||||||
TTL: ttlNum,
|
|
||||||
Internal: internal,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(formErrors.Messages) == 0 {
|
if isNewRecord && len(formErrors.Messages) == 0 {
|
||||||
if dnsRecord.Internal {
|
if dnsRecord.Internal {
|
||||||
dnsRecord.ID = utils.RandomId()
|
dnsRecord.ID = utils.RandomId()
|
||||||
} else {
|
} else {
|
||||||
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println("error creating external dns record", err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
formErrors.Messages = append(formErrors.Messages, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNewRecord && len(formErrors.Messages) == 0 {
|
||||||
|
if !dnsRecord.Internal {
|
||||||
|
err = externalDnsAdapter.UpdateDNSRecord(dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("error updating external dns record", err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
formErrors.Messages = append(formErrors.Messages, err.Error())
|
formErrors.Messages = append(formErrors.Messages, err.Error())
|
||||||
}
|
}
|
||||||
|
@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
|
||||||
|
|
||||||
if len(formErrors.Messages) == 0 {
|
if len(formErrors.Messages) == 0 {
|
||||||
formSuccess := types.BannerMessages{
|
formSuccess := types.BannerMessages{
|
||||||
Messages: []string{"record added."},
|
Messages: []string{"record saved."},
|
||||||
}
|
}
|
||||||
(*context.TemplateData)["Success"] = formSuccess
|
(*context.TemplateData)["Success"] = formSuccess
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
(*context.TemplateData)[""] = &formErrors
|
log.Println(formErrors.Messages)
|
||||||
|
(*context.TemplateData)["Error"] = &formErrors
|
||||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
recordId := req.FormValue("id")
|
recordId := req.FormValue("id")
|
||||||
|
@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
}
|
}
|
||||||
|
|
||||||
if !record.Internal {
|
if !record.Internal {
|
||||||
err = dnsAdapter.DeleteDNSRecord(recordId)
|
err = externalDnsAdapter.DeleteDNSRecord(recordId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
|
@ -57,6 +57,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
type SignallingExternalDnsAdapter struct {
|
type SignallingExternalDnsAdapter struct {
|
||||||
AddChannel chan *database.DNSRecord
|
AddChannel chan *database.DNSRecord
|
||||||
RmChannel chan string
|
RmChannel chan string
|
||||||
|
UpdateChan chan *database.DNSRecord
|
||||||
}
|
}
|
||||||
|
|
||||||
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
|
@ -72,6 +73,12 @@ func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error {
|
||||||
|
go func() { adapter.UpdateChan <- record }()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
@ -172,6 +179,73 @@ func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestThatUserCanUpdateExistingRecord(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
updateChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
UpdateChan: updateChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
nonexistantRecord := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
|
||||||
|
id := "1"
|
||||||
|
name := "test." + context.User.Username
|
||||||
|
nonexistantRecord.Form = map[string][]string{
|
||||||
|
"id": {id},
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {name},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"new.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, nonexistantRecord)
|
||||||
|
if responseRecorder.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected internal server error return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
ID: id,
|
||||||
|
Internal: false,
|
||||||
|
Name: name,
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "test.domain.",
|
||||||
|
TTL: 43000,
|
||||||
|
UserID: context.User.ID,
|
||||||
|
}
|
||||||
|
_, err := database.SaveDNSRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
existantRecord := nonexistantRecord
|
||||||
|
existantRecordRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(existantRecordRecorder, existantRecord)
|
||||||
|
if existantRecordRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", existantRecordRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case req := <-updateChannel:
|
||||||
|
newRecord, err := database.GetDNSRecord(db, req.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if newRecord.Content != "new.domain." {
|
||||||
|
t.Errorf("expected updated record, got %s", newRecord.Content)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("expected updated record channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestThatExternalDnsSaves(t *testing.T) {
|
func TestThatExternalDnsSaves(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
|
@ -46,7 +46,7 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase
|
||||||
func TemplateContinuation(path string, showBase bool) types.Continuation {
|
func TemplateContinuation(path string, showBase bool) types.Continuation {
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
html, err := renderTemplate(context, path, true)
|
html, err := renderTemplate(context, path, showBase)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
resp.WriteHeader(404)
|
resp.WriteHeader(404)
|
||||||
html, err = renderTemplate(context, "404.html", true)
|
html, err = renderTemplate(context, "404.html", true)
|
||||||
|
|
Loading…
Reference in New Issue