diff --git a/Dockerfile b/Dockerfile index 591423f..82f411a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers EXPOSE 8080 -CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"] +CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-resolvers", "1.1.1.1,1.0.0.1"] diff --git a/args/args.go b/args/args.go index f71e8e3..8465fc8 100644 --- a/args/args.go +++ b/args/args.go @@ -22,8 +22,9 @@ type Arguments struct { OauthConfig *oauth2.Config OauthUserInfoURI string - Dns bool - DnsPort int + DnsResolvers []string + Dns bool + DnsPort int CloudflareToken string CloudflareZone string @@ -36,6 +37,7 @@ func GetArgs() (*Arguments, error) { databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database") templatePath := flag.String("template-path", "./templates", "Path to the template directory") staticPath := flag.String("static-path", "./static", "Path to the static directory") + dnsResolvers := flag.String("dns-resolvers", "1.1.1.1,1.0.0.1", "Comma-separated list of DNS resolvers") scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron") migrate := flag.Bool("migrate", false, "Run the migrations") @@ -101,8 +103,10 @@ func GetArgs() (*Arguments, error) { Server: *server, Migrate: *migrate, Scheduler: *scheduler, - Dns: *dns, - DnsPort: *dnsPort, + + Dns: *dns, + DnsPort: *dnsPort, + DnsResolvers: strings.Split(*dnsResolvers, ","), OauthConfig: oauthConfig, OauthUserInfoURI: oauthUserInfoURI, diff --git a/hcdns/server.go b/hcdns/server.go index ce7894b..2e110e8 100644 --- a/hcdns/server.go +++ b/hcdns/server.go @@ -11,74 +11,142 @@ import ( const MAX_RECURSION = 15 -func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") - if err != nil { - return nil, err - } - - var answers []dns.RR - for _, record := range internalCnames { - cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) - if err != nil { - log.Println(err) - return nil, err - } - answers = append(answers, cname) - - cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) - if err != nil { - log.Println(err) - return nil, err - } - answers = append(answers, cnameRecursive...) - } - - qtypeName := dns.TypeToString[qtype] - if qtypeName == "" { - return nil, fmt.Errorf("invalid query type %d", qtype) - } - - typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) - if err != nil { - return nil, err - } - for _, record := range typeDnsRecords { - answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) - if err != nil { - return nil, err - } - answers = append(answers, answer) - } - - return answers, nil -} - -func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") - } - - answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) - if err != nil { - return nil, err - } - - return answers, nil -} - type DnsHandler struct { DnsResolvers []string DbConn *sql.DB } +func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) { + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(dns.Fqdn(domain), qtype) + message.RecursionDesired = true + + if len(h.DnsResolvers) == 0 { + return []dns.RR{}, nil + } + + i := 0 + in, _, err := client.Exchange(message, h.DnsResolvers[i]) + for err != nil && i < len(h.DnsResolvers) { + i++ + in, _, err = client.Exchange(message, h.DnsResolvers[i]) + } + + if err != nil { + return nil, err + } + + if len(in.Answer) == 0 { + return nil, nil + } + + return in.Answer, nil +} + +func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool { + for _, answer := range answers { + if answer.Header().Name == domain && answer.Header().Rrtype == qtype { + return true + } + } + return false +} + +func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { + internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME") + if err != nil { + return nil, true, err + } + + authoritative := true + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, authoritative, err + } + answers = append(answers, cname) + + cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) + authoritative = authoritative && cnameAuth + if err != nil { + log.Println(err) + return nil, authoritative, err + } + + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName) + if err != nil { + return nil, authoritative, err + } + + for _, record := range records { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, authoritative, err + } + answers = append(answers, answer) + } + + return answers, authoritative, nil +} + +func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { + log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth) + if maxDepth == 0 { + return nil, false, fmt.Errorf("too much recursion") + } + + answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth) + if err != nil { + return nil, false, err + } + + if len(answers) > 0 { // base case - we got the answer + return answers, authoritative, nil + } + + externalAnswers, err := h.resolveExternal(domain, qtype) + if err != nil { + return nil, false, err + } + + answers = append(answers, externalAnswers...) + if resultSetFound(externalAnswers, domain, qtype) { + return answers, false, nil + } + + for _, answer := range externalAnswers { + cname, ok := answer.(*dns.CNAME) + if !ok { + continue + } + + cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) + authoritative = authoritative && cnameAuth + if err != nil { + return nil, false, err + } + answers = append(answers, cnameAnswers...) + } + + authoritative = authoritative && len(externalAnswers) == 0 + return answers, authoritative, nil +} + func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - msg := new(dns.Msg) + msg := &dns.Msg{} msg.SetReply(r) - msg.Authoritative = true + msg.Authoritative = false for _, question := range r.Question { - answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION) + msg.Authoritative = authoritative if err != nil { fmt.Println(err) msg.SetRcode(r, dns.RcodeServerFailure) @@ -98,7 +166,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { handler := &DnsHandler{ - DbConn: dbConn, + DbConn: dbConn, + DnsResolvers: argv.DnsResolvers, } addr := fmt.Sprintf(":%d", argv.DnsPort) diff --git a/hcdns/server_test.go b/hcdns/server_test.go index 177def4..f1b283f 100644 --- a/hcdns/server_test.go +++ b/hcdns/server_test.go @@ -19,9 +19,8 @@ func randomPort() int { return rand.Intn(3000) + 5192 } -func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { +func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) { randomDb := utils.RandomId() - dnsPort := randomPort() testDb := database.MakeConn(&randomDb) database.Migrate(testDb) @@ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { } database.FindOrSaveUser(testDb, testUser) + dnsArguments := arguments + if dnsArguments == nil { + dnsArguments = &args.Arguments{ + DnsPort: randomPort(), + } + } + waitLock := &sync.Mutex{} - server := hcdns.MakeServer(&args.Arguments{ - DnsPort: dnsPort, - }, testDb) + server := hcdns.MakeServer(dnsArguments, testDb) server.NotifyStartedFunc = func() { waitLock.Unlock() } @@ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { }() waitLock.Lock() - address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, &address, waitLock, func() { + address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort) + return testDb, server, address, func() { + waitLock.Unlock() server.Shutdown() testDb.Close() @@ -54,11 +59,8 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { } func TestWhenCNAMEIsResolved(t *testing.T) { - t.Log("TestWhenCNAMEIsResolved") - - testDb, _, addr, lock, cleanup := setup() + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() records := []*database.DNSRecord{ { @@ -99,7 +101,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) } @@ -132,11 +134,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { } func TestWhenNoRecordNxDomain(t *testing.T) { - t.Log("TestWhenNoRecordNxDomain") - - _, _, addr, lock, cleanup := setup() + _, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") @@ -144,7 +143,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) @@ -160,11 +159,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) { } func TestWhenUnresolvingCNAME(t *testing.T) { - t.Log("TestWhenUnresolvingCNAME") - - testDb, _, addr, lock, cleanup := setup() + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -183,7 +179,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) @@ -215,11 +211,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { } func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { - t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - - testDb, _, addr, lock, cleanup := setup() + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -238,7 +231,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) @@ -252,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) } } + +func TestWhenExternalDomain(t *testing.T) { + externalDb, _, externalAddr, externalCleanup := setup(nil) + internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{ + DnsPort: randomPort(), + DnsResolvers: []string{externalAddr}, + }) + defer internalCleanup() + defer externalCleanup() + + authoritativeRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.example.com.", + Type: "CNAME", + Content: "external.internal.example.com.", + }, + { + ID: "2", + UserID: "test", + Name: "final.example.com.", + Type: "A", + Content: "127.0.0.1", + }, + } + internalRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.internal.example.com.", + Type: "CNAME", + Content: "final.example.com", + }, + { + ID: "2", + UserID: "test", + Name: "test.internal.example.com.", + Type: "CNAME", + Content: "external.example.com.", + }, + } + + for _, record := range authoritativeRecords { + database.SaveDNSRecord(externalDb, &record) + } + for _, record := range internalRecords { + database.SaveDNSRecord(internalDb, &record) + } + + // ensure that if the record doesn't exist in the internal database, it will + // go and query the external dns resolvers, then loop back to the internal + + qtype := dns.TypeA + domain := dns.Fqdn("test.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, internalAddr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 4 { + t.Fatalf("expected 4 answers, got %d", len(in.Answer)) + } + + aRecord := in.Answer[3] + if aRecord.Header().Name != authoritativeRecords[1].Name { + t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name) + } + if aRecord.Header().Rrtype != dns.TypeA { + t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type) + } + if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content { + t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String()) + } + if in.Authoritative { + t.Fatalf("expected non-authoritative response") + } +}