package dns import ( "database/sql" "fmt" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "github.com/miekg/dns" "log" ) const MAX_RECURSION = 10 func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { if maxDepth == 0 { return nil, fmt.Errorf("too much recursion") } internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") if err != nil { return nil, err } answers := []dns.RR{} for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { return nil, err } answers = append(answers, cname) cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) answers = append(answers, cnameRecursive...) } qtypeName := dns.TypeToString[qtype] if qtypeName == "" { return nil, fmt.Errorf("invalid query type %d", qtype) } typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) if err != nil { return nil, err } for _, record := range typeDnsRecords { answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) if err != nil { return nil, err } answers = append(answers, answer) } if len(answers) > 0 { // base case; we found the answer return answers, nil } message := new(dns.Msg) message.SetQuestion(dns.Fqdn(domain), qtype) message.RecursionDesired = true client := new(dns.Client) i := 0 in, _, err := client.Exchange(message, dnsResolvers[i]) for err != nil { i += 1 if i == len(dnsResolvers) { log.Println(err) return nil, err } in, _, err = client.Exchange(message, dnsResolvers[i]) } answers = append(answers, in.Answer...) return answers, nil } type DnsHandler struct { DnsResolvers []string DbConn *sql.DB } func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg := new(dns.Msg) msg.SetReply(r) msg.Authoritative = true for _, question := range r.Question { answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) if err != nil { fmt.Println(err) continue } msg.Answer = append(msg.Answer, answers...) } log.Println(msg.Answer) w.WriteMsg(msg) } func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { handler := &DnsHandler{ DnsResolvers: argv.DnsRecursion, DbConn: dbConn, } addr := fmt.Sprintf(":%d", argv.DnsPort) return &dns.Server{ Addr: addr, Net: "udp", Handler: handler, UDPSize: 65535, ReusePort: true, } }