package hcdns import ( "database/sql" "fmt" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "github.com/miekg/dns" "log" ) const MAX_RECURSION = 15 type DnsHandler struct { DnsResolvers []string DbConn *sql.DB } func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) { client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(dns.Fqdn(domain), qtype) message.RecursionDesired = true if len(h.DnsResolvers) == 0 { return []dns.RR{}, nil } i := 0 in, _, err := client.Exchange(message, h.DnsResolvers[i]) for err != nil && i < len(h.DnsResolvers) { i++ in, _, err = client.Exchange(message, h.DnsResolvers[i]) } if err != nil { return nil, err } if len(in.Answer) == 0 { return nil, nil } return in.Answer, nil } func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool { for _, answer := range answers { if answer.Header().Name == domain && answer.Header().Rrtype == qtype { return true } } return false } func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME") if err != nil { return nil, true, err } authoritative := true var answers []dns.RR for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { log.Println(err) return nil, authoritative, err } answers = append(answers, cname) cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) if err != nil { log.Println(err) return nil, authoritative, err } authoritative = authoritative && cnameAuth answers = append(answers, cnameRecursive...) } qtypeName := dns.TypeToString[qtype] records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName) if err != nil { return nil, authoritative, err } for _, record := range records { answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) if err != nil { return nil, authoritative, err } answers = append(answers, answer) } return answers, authoritative, nil } func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth) if maxDepth == 0 { return nil, false, fmt.Errorf("too much recursion") } answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth) if err != nil { return nil, false, err } if len(answers) > 0 { // base case - we got the answer return answers, authoritative, nil } externalAnswers, err := h.resolveExternal(domain, qtype) if err != nil { return nil, false, err } answers = append(answers, externalAnswers...) if resultSetFound(externalAnswers, domain, qtype) { return answers, false, nil } for _, answer := range externalAnswers { cname, ok := answer.(*dns.CNAME) if !ok { continue } cnameAnswers, _, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) if err != nil { return nil, false, err } answers = append(answers, cnameAnswers...) } return answers, false, nil } func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg := &dns.Msg{} msg.SetReply(r) msg.Authoritative = false for _, question := range r.Question { answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION) msg.Authoritative = authoritative if err != nil { fmt.Println(err) msg.SetRcode(r, dns.RcodeServerFailure) w.WriteMsg(msg) return } msg.Answer = append(msg.Answer, answers...) } if len(msg.Answer) == 0 { msg.SetRcode(r, dns.RcodeNameError) } log.Println(msg.Answer) w.WriteMsg(msg) } func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { handler := &DnsHandler{ DbConn: dbConn, DnsResolvers: argv.DnsResolvers, } addr := fmt.Sprintf(":%d", argv.DnsPort) return &dns.Server{ Addr: addr, Net: "udp", Handler: handler, UDPSize: 65535, ReusePort: true, } }