package hcdns

import (
	"database/sql"
	"fmt"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"github.com/miekg/dns"
	"log"
)

// MAX_RECURSION is the initial depth budget ServeDNS hands to resolveDNS for
// each question. Every CNAME hop that is followed consumes one unit, and
// resolution fails with "too much recursion" once the budget reaches zero.
const MAX_RECURSION = 10

// DnsHandler answers DNS queries from records stored in the database and
// falls back to upstream resolvers when no stored record matches.
type DnsHandler struct {
	// DnsResolvers lists the upstream resolver addresses passed to
	// dns.Client.Exchange (presumably "host:port" — confirm against the
	// argument parser).
	DnsResolvers []string
	// DbConn is the database handle passed to database.FindDNSRecords.
	DbConn       *sql.DB
}

// resolveExternal forwards a recursive query for domain/qtype to the
// configured upstream resolvers and returns the answer section of the first
// response that arrives without a transport error.
//
// Resolvers are tried in the order they appear in h.DnsResolvers, each at most
// once per call. If every resolver fails, the error from the last attempt is
// returned. With no resolvers configured the result is an empty, non-nil slice
// and a nil error; a response with an empty answer section yields (nil, nil).
func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
	if len(h.DnsResolvers) == 0 {
		return []dns.RR{}, nil
	}

	client := &dns.Client{}
	message := &dns.Msg{}
	message.SetQuestion(dns.Fqdn(domain), qtype)
	message.RecursionDesired = true

	var (
		in  *dns.Msg
		err error
	)
	// Query each resolver exactly once, stopping at the first that answers.
	for _, resolver := range h.DnsResolvers {
		in, _, err = client.Exchange(message, resolver)
		if err == nil {
			break
		}
	}
	if err != nil {
		return nil, err
	}

	// Guard against a nil response so the dereference below cannot panic.
	if in == nil || len(in.Answer) == 0 {
		return nil, nil
	}

	return in.Answer, nil
}

// resultSetFound reports whether at least one record in answers carries the
// requested record type.
func resultSetFound(answers []dns.RR, qtype uint16) bool {
	for i := range answers {
		if hdr := answers[i].Header(); hdr.Rrtype == qtype {
			return true
		}
	}
	return false
}

// recursiveResolve assembles an answer set for domain from the records held
// in the database.
//
// Every stored CNAME for domain is emitted and its target is then resolved
// through resolveDNS with one less unit of depth, so the chain is followed
// until it ends or the depth budget runs out. Records of the requested type
// stored directly under domain are appended afterwards.
//
// The boolean result is true only while every followed CNAME target was itself
// resolved authoritatively. On error the answer slice is nil.
func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
	internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
	if err != nil {
		return nil, true, err
	}

	authoritative := true
	var answers []dns.RR
	for _, record := range internalCnames {
		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
		if err != nil {
			return nil, authoritative, fmt.Errorf("building CNAME record for %q: %w", domain, err)
		}
		answers = append(answers, cname)

		cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
		authoritative = authoritative && cnameAuth
		if err != nil {
			return nil, authoritative, err
		}

		answers = append(answers, cnameRecursive...)
	}

	// The CNAME rows were already emitted above; looking them up a second time
	// for a CNAME query would duplicate every record in the answer.
	if qtype == dns.TypeCNAME {
		return answers, authoritative, nil
	}

	qtypeName := dns.TypeToString[qtype]
	records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName)
	if err != nil {
		return nil, authoritative, err
	}

	for _, record := range records {
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
		if err != nil {
			return nil, authoritative, fmt.Errorf("building %s record for %q: %w", qtypeName, domain, err)
		}
		answers = append(answers, answer)
	}

	return answers, authoritative, nil
}

// resolveDNS resolves domain/qtype, preferring records from the database and
// falling back to the upstream resolvers when the database yields nothing.
//
// maxDepth is the remaining recursion budget; a value of zero aborts with an
// error. When the upstream response contains no record of the requested type,
// each CNAME in that response is followed with the budget reduced by one.
//
// The boolean result reports whether the answer is authoritative: it is always
// false once any upstream record contributed to the answer, and false on error.
func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
	log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
	if maxDepth == 0 {
		return nil, false, fmt.Errorf("too much recursion")
	}

	found, isAuth, err := h.recursiveResolve(domain, qtype, maxDepth)
	switch {
	case err != nil:
		return nil, false, err
	case len(found) > 0:
		// Base case: the database produced the answer.
		return found, isAuth, nil
	}

	upstream, err := h.resolveExternal(domain, qtype)
	if err != nil {
		return nil, false, err
	}

	found = append(found, upstream...)
	if resultSetFound(found, qtype) {
		return found, false, nil
	}

	// No record of the requested type came back; chase any CNAMEs the upstream
	// resolver returned.
	for _, rr := range upstream {
		if alias, ok := rr.(*dns.CNAME); ok {
			chased, chasedAuth, chaseErr := h.resolveDNS(alias.Target, qtype, maxDepth-1)
			isAuth = isAuth && chasedAuth
			if chaseErr != nil {
				return nil, false, chaseErr
			}
			found = append(found, chased...)
		}
	}

	return found, isAuth && len(upstream) == 0, nil
}

// ServeDNS resolves every question in r and writes a single reply to w.
//
// If any question fails to resolve, a SERVFAIL reply is sent and processing
// stops. A reply whose answer section ends up empty is sent as NXDOMAIN. The
// AA flag is set only when every question was resolved authoritatively.
// Failures to write the reply are logged.
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	msg := &dns.Msg{}
	msg.SetReply(r)
	msg.Authoritative = false

	// Start from true only when there is something to resolve, then AND in the
	// result of each question so one non-authoritative answer clears the flag.
	authoritative := len(r.Question) > 0
	for _, question := range r.Question {
		answers, questionAuth, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION)
		if err != nil {
			log.Println(err)
			msg.SetRcode(r, dns.RcodeServerFailure)
			if writeErr := w.WriteMsg(msg); writeErr != nil {
				log.Println(writeErr)
			}
			return
		}
		authoritative = authoritative && questionAuth
		msg.Answer = append(msg.Answer, answers...)
	}
	msg.Authoritative = authoritative

	if len(msg.Answer) == 0 {
		msg.SetRcode(r, dns.RcodeNameError)
	}

	log.Println(msg.Answer)
	if err := w.WriteMsg(msg); err != nil {
		log.Println(err)
	}
}

// MakeServer builds a UDP dns.Server that listens on every interface at
// argv.DnsPort and serves queries through a DnsHandler backed by dbConn and
// the upstream resolvers in argv.DnsResolvers. The server is returned without
// being started.
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
	server := &dns.Server{
		Net:       "udp",
		UDPSize:   65535,
		ReusePort: true,
	}
	server.Addr = fmt.Sprintf(":%d", argv.DnsPort)
	server.Handler = &DnsHandler{
		DnsResolvers: argv.DnsResolvers,
		DbConn:       dbConn,
	}
	return server
}