hatecomputers.club/api/dns/dns_test.go

443 lines
13 KiB
Go

package dns_test
import (
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
// MAX_USER_RECORDS is the per-user record cap handed to
// dns.CreateDNSRecordContinuation by every creation test in this file.
const MAX_USER_RECORDS = 10

// USER_OWNED_INTERNAL_FMT_DOMAINS holds fmt patterns that are filled with a
// username; the tests expect a user to be allowed to create (non-internal)
// records for the resulting names.
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{
	"%s",
	"%s.endpoints",
}
// IdContinuation is a pass-through continuation. The chain it returns always
// calls the success continuation with the same context, request and response
// writer it was given; the failure continuation is never used.
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	passThrough := func(onSuccess types.Continuation, _ types.Continuation) types.ContinuationChain {
		return onSuccess(context, req, resp)
	}
	return passThrough
}
// setup opens a connection to a database file with a random name, runs the
// migrations on it, saves a "test" user, and builds a RequestContext acting as
// that user.
//
// It returns the connection, the context, and a cleanup function that closes
// the connection and removes the database file. Every test depends on the
// test user existing, so setup panics if it cannot be saved. It removes the
// database file first so a failed setup does not leave it behind.
func setup() (*sql.DB, *types.RequestContext, func()) {
	randomDb := utils.RandomId()
	testDb := database.MakeConn(&randomDb)
	cleanup := func() {
		testDb.Close()
		os.Remove(randomDb)
	}
	database.Migrate(testDb)

	user := &database.User{
		ID:          "test",
		Username:    "test",
		Mail:        "test@test.com",
		DisplayName: "test",
	}
	if _, err := database.FindOrSaveUser(testDb, user); err != nil {
		cleanup()
		panic(fmt.Errorf("setup: saving test user: %w", err))
	}

	context := &types.RequestContext{
		DBConn:       testDb,
		Args:         &args.Arguments{},
		TemplateData: &(map[string]interface{}{}),
		User:         user,
	}
	return testDb, context, cleanup
}
// adapterSignalTimeout bounds how long a SignallingExternalDnsAdapter goroutine
// waits for a test to receive its signal. Some tests never read one of the
// channels, or leave it nil. Without a bound, each such signal would leak a
// goroutine blocked on the send for the rest of the test binary's life.
const adapterSignalTimeout = 5 * time.Second

// SignallingExternalDnsAdapter is the external DNS adapter handed to the
// continuations under test. It performs no real DNS change. Instead it reports
// each call on a channel so tests can assert which records were (or were not)
// pushed externally.
type SignallingExternalDnsAdapter struct {
	// AddChannel receives every record passed to CreateDNSRecord; may be nil.
	AddChannel chan *database.DNSRecord
	// RmChannel receives every id passed to DeleteDNSRecord; may be nil.
	RmChannel chan string
}

// CreateDNSRecord returns a fresh random id and, asynchronously, sends record
// on AddChannel. The send happens in a goroutine because tests serve the
// handler synchronously and only read the channel after it returns. The send
// is skipped when AddChannel is nil and abandoned after adapterSignalTimeout
// if no one receives it.
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
	id := utils.RandomId()
	if adapter.AddChannel != nil {
		go func(ch chan<- *database.DNSRecord) {
			select {
			case ch <- record:
			case <-time.After(adapterSignalTimeout):
			}
		}(adapter.AddChannel)
	}
	return id, nil
}

// DeleteDNSRecord asynchronously sends id on RmChannel and reports success.
// Like CreateDNSRecord, it skips the send when RmChannel is nil and gives up
// after adapterSignalTimeout if no one receives it.
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
	if adapter.RmChannel != nil {
		go func(ch chan<- string) {
			select {
			case ch <- id:
			case <-time.After(adapterSignalTimeout):
			}
		}(adapter.RmChannel)
	}
	return nil
}
// TestThatOwnerCanPutRecordInDomain checks three cases for the owner of
// "test.domain.":
//   - creating an internal record beneath that domain succeeds;
//   - requesting the same record as non-internal returns 401;
//   - requesting an internal record under a domain the user does not own
//     returns 401.
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	// Every assertion below depends on this ownership row existing.
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	records, err := database.GetUserDNSRecords(db, context.User.ID)
	if err != nil {
		t.Fatal(err)
	}
	if len(records) > 0 {
		t.Errorf("expected no records, got records")
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	// One request is mutated and replayed for each case. Each case gets its
	// own recorder so an earlier status cannot mask a later one.
	req := httptest.NewRequest("POST", testServer.URL, nil)
	req.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"new.test.domain."},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, req)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	// Same name, but requested as a non-internal record.
	req.Form["internal"] = []string{"off"}
	validOwnerNonInternalRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, req)
	if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
	}

	// Internal again, but under a domain the user does not own.
	req.Form["internal"] = []string{"on"}
	req.Form["name"] = []string{"new.invalid.domain."}
	invalidOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, req)
	if invalidOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
	}
}
// TestThatUserCanAddToPublicEndpoints checks each name produced by filling
// USER_OWNED_INTERNAL_FMT_DOMAINS with the user's username. For every such
// name, a non-internal CNAME request must succeed and the record must then be
// findable in the database.
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	req := httptest.NewRequest("POST", testServer.URL, nil)
	for _, format := range USER_OWNED_INTERNAL_FMT_DOMAINS {
		name := fmt.Sprintf(format, context.User.Username)
		req.Form = map[string][]string{
			"internal": {"off"},
			"name":     {name},
			"type":     {"CNAME"},
			"ttl":      {"43000"},
			"content":  {"test.domain."},
		}

		// Use a fresh recorder per iteration. A ResponseRecorder ignores every
		// WriteHeader after the first, so a shared one could keep reporting an
		// earlier iteration's status.
		responseRecorder := httptest.NewRecorder()
		testServer.Config.Handler.ServeHTTP(responseRecorder, req)
		if responseRecorder.Code != http.StatusOK {
			t.Errorf("expected valid return, got %d", responseRecorder.Code)
		}

		namedRecords, err := database.FindDNSRecords(db, name, "CNAME")
		if err != nil {
			t.Fatal(err)
		}
		if len(namedRecords) == 0 {
			t.Errorf("saved record not found")
		}
	}
}
// TestThatExternalDnsSaves checks that a non-internal record is forwarded to
// the external adapter with the requested name, type and content. It then
// checks that an internal record under an owned domain succeeds without
// reaching the external adapter.
func TestThatExternalDnsSaves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	req := httptest.NewRequest("POST", testServer.URL, nil)
	name := "test." + context.User.Username
	req.Form = map[string][]string{
		"internal": {"off"},
		"name":     {name},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	externalRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(externalRecorder, req)
	if externalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalRecorder.Code)
	}
	select {
	case res := <-addChannel:
		if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external addition")
	}

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	// Replay the request as an internal record under the owned domain. It gets
	// its own recorder because a ResponseRecorder keeps the first status
	// written to it.
	req.Form["internal"] = []string{"on"}
	req.Form["name"] = []string{"test.domain."}
	internalRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalRecorder, req)
	if internalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalRecorder.Code)
	}
	select {
	case <-addChannel:
		t.Errorf("expected nothing in the add channel")
	case <-time.After(100 * time.Millisecond):
	}
}
// TestThatUserMustOwnRecordToRemove checks that deleting a record owned by
// another user returns 401. It then checks that deleting a record owned by the
// requesting user returns 200.
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	// The remaining assertions are meaningless if seeding fails, so stop here.
	nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
	if _, err := database.FindOrSaveUser(db, nonOwnerUser); err != nil {
		t.Fatal(err)
	}
	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   nonOwnerUser.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest("POST", testServer.URL, nil)
	req.Form = map[string][]string{
		"id": {record.ID},
	}
	nonOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, req)
	if nonOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
	}

	// Save a second record, this time owned by the requesting user, and delete it.
	record.UserID = context.User.ID
	record.ID = "2"
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}
	req.Form["id"] = []string{record.ID}
	ownerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(ownerRecorder, req)
	if ownerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", ownerRecorder.Code)
	}
}
// TestThatExternalDnsRemoves checks that deleting a non-internal record sends
// its id to the external adapter. It then checks that deleting an internal
// record under an owned domain succeeds without reaching the external adapter.
func TestThatExternalDnsRemoves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
	deleteRequest.Form = map[string][]string{
		"id": {record.ID},
	}

	externalResponseRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
	if externalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
	}
	select {
	case res := <-rmChannel:
		if res != record.ID {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external removal")
	}

	// Re-create the same id as an internal record under an owned domain.
	record.Internal = true
	record.Name = "test.domain."
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}
	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	internalResponseRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
	if internalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
	}
	select {
	case <-rmChannel:
		t.Errorf("expected nothing in the rmchannel")
	case <-time.After(100 * time.Millisecond):
	}
}
// TestRecordCountCannotExceed saves MAX_USER_RECORDS records for the user. It
// then checks that one more creation request is rejected with
// 429 Too Many Requests.
func TestRecordCountCannotExceed(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		Internal: false,
		Name:     context.User.Username,
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}
	for i := 1; i <= MAX_USER_RECORDS; i++ {
		record.ID = strconv.Itoa(i)
		// Prefix the previous name so every saved record gets a distinct name
		// ("1.test", "2.1.test", ...).
		record.Name = record.ID + "." + record.Name
		// If seeding fails, the quota is never reached and the 429 check below
		// would fail for the wrong reason.
		if _, err := database.SaveDNSRecord(db, record); err != nil {
			t.Fatalf("saving record %d: %v", i, err)
		}
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	// NOTE(review): the request reuses the last saved record's name; the test
	// expects the record-count limit to be what rejects it.
	req := httptest.NewRequest("POST", testServer.URL, nil)
	req.Form = map[string][]string{
		"internal": {"off"},
		"name":     {record.Name},
		"type":     {record.Type},
		"ttl":      {"43000"},
		"content":  {record.Content},
	}
	recorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(recorder, req)
	if recorder.Code != http.StatusTooManyRequests {
		t.Errorf("expected too many requests code return, got %d", recorder.Code)
	}
}
// TestInternalRecordAppendsTopLevelDot checks that an internal record named
// "test.internal" (no trailing dot) is stored as "test.internal." and that no
// record exists under the undotted name.
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.internal.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	validOwner := httptest.NewRequest("POST", testServer.URL, nil)
	validOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"test.internal"},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"asdf.internal"},
	}
	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	recordsAppendedDot, err := database.FindDNSRecords(db, "test.internal.", "CNAME")
	if err != nil {
		t.Fatal(err)
	}
	recordsWithoutDot, err := database.FindDNSRecords(db, "test.internal", "CNAME")
	if err != nil {
		t.Fatal(err)
	}
	// Either condition alone means the dot was not appended. The original
	// `&&` only failed when both were wrong.
	if len(recordsAppendedDot) != 1 || len(recordsWithoutDot) != 0 {
		t.Errorf("expected dot appended")
	}
}