package dns_test

import (
	"database/sql"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"testing"
	"time"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)

// MAX_USER_RECORDS is the per-user record limit handed to
// dns.CreateDNSRecordContinuation in these tests. TestRecordCountCannotExceed
// pre-fills exactly this many records and expects the next create to be refused.
const MAX_USER_RECORDS = 10

// USER_OWNED_INTERNAL_FMT_DOMAINS holds fmt patterns that are filled in with a
// username and passed to dns.CreateDNSRecordContinuation.
// TestThatUserCanAddToPublicEndpoints expects every name produced from them
// to be accepted for that user.
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{
	"%s",
	"%s.endpoints",
}

// IdContinuation is a pass-through continuation for wiring a handler under
// test. Its chain always invokes the success continuation with the same
// context, request and response writer, and never calls the failure one.
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	chain := func(success types.Continuation, _ types.Continuation) types.ContinuationChain {
		// The failure continuation is intentionally ignored.
		return success(context, req, resp)
	}
	return chain
}

// setup opens a fresh database connection named by a random id, runs the
// migrations on it, and saves a base user whose ID and username are "test".
// It returns the connection, a RequestContext bound to that user, and a
// cleanup func. The cleanup func closes the connection and removes the path
// named by the random id (presumably the database file created by
// database.MakeConn; confirm there).
//
// setup panics if the test user cannot be saved. It has no *testing.T to
// report through, and every test relies on that user existing.
func setup() (*sql.DB, *types.RequestContext, func()) {
	randomDb := utils.RandomId()

	testDb := database.MakeConn(&randomDb)
	cleanup := func() {
		testDb.Close()
		os.Remove(randomDb)
	}

	database.Migrate(testDb)

	user := &database.User{
		ID:          "test",
		Username:    "test",
		Mail:        "test@test.com",
		DisplayName: "test",
	}
	if _, err := database.FindOrSaveBaseUser(testDb, user); err != nil {
		// Release the connection and file before aborting so nothing is left behind.
		cleanup()
		panic(fmt.Errorf("saving test user: %w", err))
	}

	context := &types.RequestContext{
		DBConn:       testDb,
		Args:         &args.Arguments{},
		TemplateData: &(map[string]interface{}{}),
		User:         user,
	}

	return testDb, context, cleanup
}

// SignallingExternalDnsAdapter is a test double for the external DNS adapter
// passed to dns.CreateDNSRecordContinuation and dns.DeleteDNSRecordContinuation.
// It does not contact any provider. Each call forwards its argument to the
// matching channel so a test can observe which external operations happened.
//
// Each send runs on its own goroutine, so the handler under test never waits
// for a reader. A nil channel means the test does not observe that operation,
// and the send is skipped: a send on a nil channel blocks forever and would
// leak the goroutine.
type SignallingExternalDnsAdapter struct {
	// AddChannel receives every record passed to CreateDNSRecord.
	AddChannel chan *database.DNSRecord
	// RmChannel receives every id passed to DeleteDNSRecord.
	RmChannel chan string
	// UpdateChan receives every record passed to UpdateDNSRecord.
	UpdateChan chan *database.DNSRecord
}

// CreateDNSRecord signals record on AddChannel (when set) and returns a fresh
// random id as the "external" record id. It never fails.
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
	id := utils.RandomId()
	if adapter.AddChannel != nil {
		go func() { adapter.AddChannel <- record }()
	}

	return id, nil
}

// DeleteDNSRecord signals id on RmChannel (when set). It never fails.
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
	if adapter.RmChannel != nil {
		go func() { adapter.RmChannel <- id }()
	}

	return nil
}

// UpdateDNSRecord signals record on UpdateChan (when set). It never fails.
func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error {
	if adapter.UpdateChan != nil {
		go func() { adapter.UpdateChan <- record }()
	}

	return nil
}

// TestThatOwnerCanPutRecordInDomain checks three cases for the owner of
// "test.domain.":
//   - an internal CNAME beneath that domain is accepted (200);
//   - the same name submitted as a non-internal record is refused (401);
//   - an internal record under a domain the user does not own is refused (401).
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	// Every assertion below depends on this ownership row existing.
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	records, err := database.GetUserDNSRecords(db, context.User.ID)
	if err != nil {
		t.Fatal(err)
	}
	if len(records) > 0 {
		t.Errorf("expected no records, got records")
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	validOwner := httptest.NewRequest("POST", testServer.URL, nil)
	validOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"new.test.domain."},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	// The cases below mutate the same request's Form in place. Each case gets
	// its own recorder because a ResponseRecorder keeps the first status written.
	validOwnerNonInternalRecorder := httptest.NewRecorder()
	validOwner.Form["internal"] = []string{"off"}
	testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
	if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
	}

	invalidOwnerRecorder := httptest.NewRecorder()
	invalidOwner := validOwner
	invalidOwner.Form["internal"] = []string{"on"}
	invalidOwner.Form["name"] = []string{"new.invalid.domain."}
	testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
	if invalidOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
	}
}

// TestThatUserCanAddToPublicEndpoints checks that, for every pattern in
// USER_OWNED_INTERNAL_FMT_DOMAINS filled with the user's username, a
// non-internal CNAME create returns 200 and the record can then be found by
// name in the database.
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	req := httptest.NewRequest("POST", testServer.URL, nil)
	for _, format := range USER_OWNED_INTERNAL_FMT_DOMAINS {
		name := fmt.Sprintf(format, context.User.Username)

		req.Form = map[string][]string{
			"internal": {"off"},
			"name":     {name},
			"type":     {"CNAME"},
			"ttl":      {"43000"},
			"content":  {"test.domain."},
		}

		// Use a fresh recorder per request. A ResponseRecorder ignores every
		// WriteHeader after the first, so a shared one would keep reporting the
		// first iteration's status and hide failures in later iterations.
		responseRecorder := httptest.NewRecorder()
		testServer.Config.Handler.ServeHTTP(responseRecorder, req)
		if responseRecorder.Code != http.StatusOK {
			t.Errorf("expected valid return, got %d", responseRecorder.Code)
		}

		namedRecords, err := database.FindDNSRecords(db, name, "CNAME")
		if err != nil {
			t.Fatal(err)
		}
		if len(namedRecords) == 0 {
			t.Errorf("saved record not found")
		}
	}
}

// TestThatUserCanUpdateExistingRecord checks two things:
//   - submitting a create request with an "id" for a record that does not
//     exist yet returns 500;
//   - once a record with that id exists, the same request returns 200, the
//     external adapter's UpdateDNSRecord is signalled, and the stored record
//     has the new content.
func TestThatUserCanUpdateExistingRecord(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	updateChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		UpdateChan: updateChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	id := "1"
	name := "test." + context.User.Username
	updateRequest := httptest.NewRequest("POST", testServer.URL, nil)
	updateRequest.Form = map[string][]string{
		"id":       {id},
		"internal": {"off"},
		"name":     {name},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"new.domain."},
	}

	// No record with this id exists yet, so the update must be rejected.
	nonexistentRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(nonexistentRecorder, updateRequest)
	if nonexistentRecorder.Code != http.StatusInternalServerError {
		t.Errorf("expected internal server error return, got %d", nonexistentRecorder.Code)
	}

	record := &database.DNSRecord{
		ID:       id,
		Internal: false,
		Name:     name,
		Type:     "CNAME",
		Content:  "test.domain.",
		TTL:      43000,
		UserID:   context.User.ID,
	}
	// The rest of the test is meaningless without this record, so stop on failure.
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	existentRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(existentRecorder, updateRequest)
	if existentRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", existentRecorder.Code)
	}
	select {
	case updated := <-updateChannel:
		newRecord, err := database.GetDNSRecord(db, updated.ID)
		if err != nil {
			// Stop here: newRecord must not be used when the lookup failed.
			t.Fatal(err)
		}
		if newRecord.Content != "new.domain." {
			t.Errorf("expected updated record, got %s", newRecord.Content)
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("expected updated record channel")
	}
}

// TestThatExternalDnsSaves checks two cases:
//   - a non-internal create returns 200 and forwards the record to the
//     external adapter's CreateDNSRecord;
//   - an internal create under a domain the user owns returns 200 and does
//     not reach the external adapter.
func TestThatExternalDnsSaves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	externalRecorder := httptest.NewRecorder()
	externalRequest := httptest.NewRequest("POST", testServer.URL, nil)

	name := "test." + context.User.Username
	externalRequest.Form = map[string][]string{
		"internal": {"off"},
		"name":     {name},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"test.domain."},
	}

	testServer.Config.Handler.ServeHTTP(externalRecorder, externalRequest)
	if externalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalRecorder.Code)
	}
	select {
	case res := <-addChannel:
		if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external addition")
	}

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	// Reuse the request, switching it to an internal record under the owned domain.
	internalRequest := externalRequest
	internalRequest.Form["internal"] = []string{"on"}
	internalRequest.Form["name"] = []string{"test.domain."}

	// Use a fresh recorder: the first one already holds a status, and a
	// ResponseRecorder ignores every WriteHeader after the first.
	internalRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalRecorder, internalRequest)
	if internalRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalRecorder.Code)
	}
	select {
	case <-addChannel:
		t.Errorf("expected nothing in the add channel")
	case <-time.After(100 * time.Millisecond):
	}
}

// TestThatUserMustOwnRecordToRemove checks that deleting a record owned by
// another user returns 401, while deleting a record owned by the requesting
// user returns 200.
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	// The fixtures below are preconditions: stop immediately if any cannot be saved.
	nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
	if _, err := database.FindOrSaveBaseUser(db, nonOwnerUser); err != nil {
		t.Fatal(err)
	}

	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   nonOwnerUser.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	nonOwnerRecorder := httptest.NewRecorder()
	nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
	nonOwner.Form = map[string][]string{
		"id": {record.ID},
	}

	testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
	if nonOwnerRecorder.Code != http.StatusUnauthorized {
		t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
	}

	// Save a second copy of the record under id "2", owned by the requesting user.
	record.UserID = context.User.ID
	record.ID = "2"
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	owner := nonOwner
	owner.Form["id"] = []string{"2"}
	ownerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
	if ownerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", ownerRecorder.Code)
	}
}

// TestThatExternalDnsRemoves checks two cases:
//   - deleting a non-internal record returns 200 and forwards its id to the
//     external adapter's DeleteDNSRecord;
//   - deleting an internal record under an owned domain returns 200 and does
//     not reach the external adapter.
func TestThatExternalDnsRemoves(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		ID:       "1",
		Internal: false,
		Name:     "test",
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}

	rmChannel := make(chan string)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		RmChannel: rmChannel,
	}

	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	externalResponseRecorder := httptest.NewRecorder()
	deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)

	deleteRequest.Form = map[string][]string{
		"id": {record.ID},
	}

	testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
	if externalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
	}
	select {
	case res := <-rmChannel:
		if res != record.ID {
			t.Errorf("received the wrong external record")
		}
	case <-time.After(100 * time.Millisecond):
		t.Errorf("timed out in waiting for external removal")
	}

	// Re-save the same id as an internal record under a domain the user owns.
	record.Internal = true
	record.Name = "test.domain."
	if _, err := database.SaveDNSRecord(db, record); err != nil {
		t.Fatal(err)
	}
	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.domain.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	internalResponseRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
	if internalResponseRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
	}
	select {
	case <-rmChannel:
		t.Errorf("expected nothing in the rmchannel")
	case <-time.After(100 * time.Millisecond):
	}
}

// TestRecordCountCannotExceed pre-fills the user with MAX_USER_RECORDS
// records, then checks that one more create request returns 429.
func TestRecordCountCannotExceed(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	record := &database.DNSRecord{
		Internal: false,
		Name:     context.User.Username,
		Type:     "CNAME",
		Content:  "asdf",
		TTL:      1000,
		UserID:   context.User.ID,
	}

	// Each iteration prefixes the previous name with the new id, so every
	// saved record gets a distinct name ("1.test", "2.1.test", ...).
	for i := 1; i <= MAX_USER_RECORDS; i++ {
		record.ID = strconv.Itoa(i)
		record.Name = record.ID + "." + record.Name
		// If a save fails the limit is never reached and the 429 check is meaningless.
		if _, err := database.SaveDNSRecord(db, record); err != nil {
			t.Fatalf("saving record %d: %v", i, err)
		}
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	req := httptest.NewRequest("POST", testServer.URL, nil)
	req.Form = map[string][]string{
		"internal": {"off"},
		"name":     {record.Name},
		"type":     {record.Type},
		"ttl":      {"43000"},
		"content":  {record.Content},
	}

	recorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(recorder, req)
	if recorder.Code != http.StatusTooManyRequests {
		t.Errorf("expected too many requests code return, got %d", recorder.Code)
	}
}

// TestInternalRecordAppendsTopLevelDot checks that an internal record
// submitted as "test.internal" (no trailing dot) is stored as
// "test.internal.": exactly one CNAME under the dotted name, none under the
// undotted one.
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
	db, context, cleanup := setup()
	defer cleanup()

	domainOwner := &database.DomainOwner{
		UserID: context.User.ID,
		Domain: "test.internal.",
	}
	if _, err := database.SaveDomainOwner(db, domainOwner); err != nil {
		t.Fatal(err)
	}

	addChannel := make(chan *database.DNSRecord)
	signallingDnsAdapter := &SignallingExternalDnsAdapter{
		AddChannel: addChannel,
	}
	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
	}))
	defer testServer.Close()

	validOwner := httptest.NewRequest("POST", testServer.URL, nil)
	validOwner.Form = map[string][]string{
		"internal": {"on"},
		"name":     {"test.internal"},
		"type":     {"CNAME"},
		"ttl":      {"43000"},
		"content":  {"asdf.internal"},
	}

	validOwnerRecorder := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
	if validOwnerRecorder.Code != http.StatusOK {
		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
	}

	recordsAppendedDot, err := database.FindDNSRecords(db, "test.internal.", "CNAME")
	if err != nil {
		t.Fatal(err)
	}
	recordsWithoutDot, err := database.FindDNSRecords(db, "test.internal", "CNAME")
	if err != nil {
		t.Fatal(err)
	}

	// Fail if either condition is violated. The previous && only failed when
	// both were wrong at the same time.
	if len(recordsAppendedDot) != 1 || len(recordsWithoutDot) != 0 {
		t.Errorf("expected dot appended, got %d records with dot and %d without",
			len(recordsAppendedDot), len(recordsWithoutDot))
	}
}