diff --git a/Dockerfile b/Dockerfile index a46f6c4..591423f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers EXPOSE 8080 -CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] +CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"] diff --git a/args/args.go b/args/args.go index 40dd1af..f71e8e3 100644 --- a/args/args.go +++ b/args/args.go @@ -22,9 +22,8 @@ type Arguments struct { OauthConfig *oauth2.Config OauthUserInfoURI string - Dns bool - DnsRecursion []string - DnsPort int + Dns bool + DnsPort int CloudflareToken string CloudflareZone string @@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) { server := flag.Bool("server", false, "Run the server") dns := flag.Bool("dns", false, "Run DNS resolver") - dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers") dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") flag.Parse() @@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) { Migrate: *migrate, Scheduler: *scheduler, Dns: *dns, - DnsRecursion: strings.Split(*dnsRecursion, ","), DnsPort: *dnsPort, OauthConfig: oauthConfig, diff --git a/dns/server.go b/dns/server.go index f5365e8..9b3e5e9 100644 --- a/dns/server.go +++ b/dns/server.go @@ -11,17 +11,13 @@ import ( const MAX_RECURSION = 10 -func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") - } - +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") if err != nil { return nil, err } - answers := []dns.RR{} + var answers []dns.RR for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { @@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp } answers = append(answers, cname) - cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + return nil, err + } answers = append(answers, cnameRecursive...) } @@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp return nil, err } for _, record := range typeDnsRecords { - answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) if err != nil { return nil, err } answers = append(answers, answer) } + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + if len(answers) > 0 { - // base case; we found the answer return answers, nil } - message := new(dns.Msg) - message.SetQuestion(dns.Fqdn(domain), qtype) - message.RecursionDesired = true - - client := new(dns.Client) - - i := 0 - in, _, err := client.Exchange(message, dnsResolvers[i]) - for err != nil { - i += 1 - if i == len(dnsResolvers) { - log.Println(err) - return nil, err - } - in, _, err = client.Exchange(message, dnsResolvers[i]) - } - - answers = append(answers, in.Answer...) - return answers, nil + return nil, fmt.Errorf("no records found for %s", domain) } type DnsHandler struct { @@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Authoritative = true for _, question := range r.Question { - answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) if err != nil { fmt.Println(err) continue @@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Answer = append(msg.Answer, answers...) } + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + log.Println(msg.Answer) w.WriteMsg(msg) }