testing | dont be recursive for external domains | finalize oauth #5
|
@ -14,10 +14,6 @@ import (
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_USER_RECORDS = 65
|
|
||||||
|
|
||||||
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
|
||||||
|
|
||||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
||||||
ownedByUser := (user.ID == record.UserID)
|
ownedByUser := (user.ID == record.UserID)
|
||||||
if !ownedByUser {
|
if !ownedByUser {
|
||||||
|
@ -60,14 +56,14 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
formErrors := types.FormError{
|
formErrors := types.FormError{
|
||||||
Errors: []string{},
|
Errors: []string{},
|
||||||
}
|
}
|
||||||
|
|
||||||
internal := req.FormValue("internal") == "on"
|
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
||||||
name := req.FormValue("name")
|
name := req.FormValue("name")
|
||||||
if internal && !strings.HasSuffix(name, ".") {
|
if internal && !strings.HasSuffix(name, ".") {
|
||||||
name += "."
|
name += "."
|
||||||
|
@ -80,6 +76,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
ttl := req.FormValue("ttl")
|
ttl := req.FormValue("ttl")
|
||||||
ttlNum, err := strconv.Atoi(ttl)
|
ttlNum, err := strconv.Atoi(ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +86,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
if dnsRecordCount >= MAX_USER_RECORDS {
|
if dnsRecordCount >= maxUserRecords {
|
||||||
|
resp.WriteHeader(http.StatusTooManyRequests)
|
||||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +100,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
Internal: internal,
|
Internal: internal,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) {
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +112,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,14 +127,11 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
if len(formErrors.Errors) == 0 {
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
(*context.TemplateData)["FormError"] = &formErrors
|
(*context.TemplateData)["FormError"] = &formErrors
|
||||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||||
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -151,7 +148,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) {
|
if !(record.UserID == context.User.ID) {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
@ -171,7 +168,6 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,18 +2,25 @@ package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
// "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MAX_USER_RECORDS = 10
|
||||||
|
|
||||||
|
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
|
@ -26,10 +33,19 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
testDb := database.MakeConn(&randomDb)
|
testDb := database.MakeConn(&randomDb)
|
||||||
database.Migrate(testDb)
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
user := &database.User{
|
||||||
|
ID: "test",
|
||||||
|
Username: "test",
|
||||||
|
Mail: "test@test.com",
|
||||||
|
DisplayName: "test",
|
||||||
|
}
|
||||||
|
database.FindOrSaveUser(testDb, user)
|
||||||
|
|
||||||
context := &types.RequestContext{
|
context := &types.RequestContext{
|
||||||
DBConn: testDb,
|
DBConn: testDb,
|
||||||
Args: &args.Arguments{},
|
Args: &args.Arguments{},
|
||||||
TemplateData: &(map[string]interface{}{}),
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
User: user,
|
||||||
}
|
}
|
||||||
|
|
||||||
return testDb, context, func() {
|
return testDb, context, func() {
|
||||||
|
@ -38,14 +54,33 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SignallingExternalDnsAdapter struct {
|
||||||
|
AddChannel chan *database.DNSRecord
|
||||||
|
RmChannel chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
|
id := utils.RandomId()
|
||||||
|
go func() { adapter.AddChannel <- record }()
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
|
||||||
|
go func() { adapter.RmChannel <- id }()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
_ = &database.User{
|
domainOwner := &database.DomainOwner{
|
||||||
ID: "test",
|
UserID: context.User.ID,
|
||||||
Username: "test",
|
Domain: "test.domain.",
|
||||||
}
|
}
|
||||||
|
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -55,9 +90,353 @@ func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||||
t.Errorf("expected no records, got records")
|
t.Errorf("expected no records, got records")
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
addChannel := make(chan *database.DNSRecord)
|
||||||
// dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation)
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
}))
|
AddChannel: addChannel,
|
||||||
defer ts.Close()
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
validOwner.Form = map[string][]string{
|
||||||
|
"internal": {"on"},
|
||||||
|
"name": {"new.test.domain."},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||||
|
if validOwnerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerNonInternalRecorder := httptest.NewRecorder()
|
||||||
|
validOwner.Form["internal"] = []string{"off"}
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
|
||||||
|
if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidOwnerRecorder := httptest.NewRecorder()
|
||||||
|
invalidOwner := validOwner
|
||||||
|
invalidOwner.Form["internal"] = []string{"on"}
|
||||||
|
invalidOwner.Form["name"] = []string{"new.invalid.domain."}
|
||||||
|
testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
|
||||||
|
if invalidOwnerRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
|
||||||
|
for _, format := range fmts {
|
||||||
|
name := fmt.Sprintf(format, context.User.Username)
|
||||||
|
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {name},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
|
||||||
|
if len(namedRecords) == 0 {
|
||||||
|
t.Errorf("saved record not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatExternalDnsSaves(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
|
||||||
|
name := "test." + context.User.Username
|
||||||
|
externalRequest.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {name},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case res := <-addChannel:
|
||||||
|
if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
|
||||||
|
t.Errorf("received the wrong external record")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("timed out in waiting for external addition")
|
||||||
|
}
|
||||||
|
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.domain.",
|
||||||
|
}
|
||||||
|
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||||
|
internalRequest := externalRequest
|
||||||
|
internalRequest.Form["internal"] = []string{"on"}
|
||||||
|
internalRequest.Form["name"] = []string{"test.domain."}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case _ = <-addChannel:
|
||||||
|
t.Errorf("expected nothing in the add channel")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
rmChannel := make(chan string)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
RmChannel: rmChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
|
||||||
|
_, err := database.FindOrSaveUser(db, nonOwnerUser)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
Internal: false,
|
||||||
|
Name: "test",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: nonOwnerUser.ID,
|
||||||
|
}
|
||||||
|
_, err = database.SaveDNSRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonOwnerRecorder := httptest.NewRecorder()
|
||||||
|
nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
nonOwner.Form = map[string][]string{
|
||||||
|
"id": {record.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
|
||||||
|
if nonOwnerRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
record.UserID = context.User.ID
|
||||||
|
record.ID = "2"
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
|
||||||
|
owner := nonOwner
|
||||||
|
owner.Form["id"] = []string{"2"}
|
||||||
|
ownerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
|
||||||
|
if ownerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", ownerRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatExternalDnsRemoves(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
Internal: false,
|
||||||
|
Name: "test",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: context.User.ID,
|
||||||
|
}
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
|
||||||
|
rmChannel := make(chan string)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
RmChannel: rmChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
externalResponseRecorder := httptest.NewRecorder()
|
||||||
|
deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
|
||||||
|
deleteRequest.Form = map[string][]string{
|
||||||
|
"id": {record.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
|
||||||
|
if externalResponseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case res := <-rmChannel:
|
||||||
|
if res != record.ID {
|
||||||
|
t.Errorf("received the wrong external record")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("timed out in waiting for external addition")
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Internal = true
|
||||||
|
record.Name = "test.domain."
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.domain.",
|
||||||
|
}
|
||||||
|
database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
|
internalResponseRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
|
||||||
|
if internalResponseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case _ = <-rmChannel:
|
||||||
|
t.Errorf("expected nothing in the rmchannel")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordCountCannotExceed(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
Internal: false,
|
||||||
|
Name: context.User.Username,
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: context.User.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= MAX_USER_RECORDS; i++ {
|
||||||
|
record.ID = strconv.Itoa(i)
|
||||||
|
record.Name = record.ID + "." + record.Name
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {record.Name},
|
||||||
|
"type": {record.Type},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {record.Content},
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(recorder, req)
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected too many requests code return, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.internal.",
|
||||||
|
}
|
||||||
|
database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
validOwner.Form = map[string][]string{
|
||||||
|
"internal": {"on"},
|
||||||
|
"name": {"test.internal"},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"asdf.internal"},
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||||
|
if validOwnerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
|
||||||
|
recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
|
||||||
|
|
||||||
|
if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
|
||||||
|
t.Errorf("expected dot appended")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,14 +116,16 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const MAX_USER_RECORDS = 100
|
||||||
|
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||||
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -9,6 +9,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DomainOwner struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type DNSRecord struct {
|
type DNSRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
|
||||||
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
||||||
log.Println("saving dns record", record.ID)
|
log.Println("saving dns record", record.ID)
|
||||||
|
|
||||||
|
if (record.CreatedAt == time.Time{}) {
|
||||||
record.CreatedAt = time.Now()
|
record.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
|
||||||
|
log.Println("saving domain owner", domainOwner.Domain)
|
||||||
|
|
||||||
|
domainOwner.CreatedAt = time.Now()
|
||||||
|
_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return domainOwner, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue