diff --git a/hcdns/server.go b/hcdns/server.go index e5a8d29..2e110e8 100644 --- a/hcdns/server.go +++ b/hcdns/server.go @@ -70,11 +70,12 @@ func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) answers = append(answers, cname) cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) + authoritative = authoritative && cnameAuth if err != nil { log.Println(err) return nil, authoritative, err } - authoritative = authoritative && cnameAuth + answers = append(answers, cnameRecursive...) } @@ -126,14 +127,16 @@ func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dn continue } - cnameAnswers, _, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) + cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) + authoritative = authoritative && cnameAuth if err != nil { return nil, false, err } answers = append(answers, cnameAnswers...) } - return answers, false, nil + authoritative = authoritative && len(externalAnswers) == 0 + return answers, authoritative, nil } func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { diff --git a/hcdns/server_test.go b/hcdns/server_test.go index 9993bbf..f1b283f 100644 --- a/hcdns/server_test.go +++ b/hcdns/server_test.go @@ -58,83 +58,6 @@ func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) { } } -func TestWhenExternalDomain(t *testing.T) { - externalDb, _, externalAddr, externalCleanup := setup(nil) - internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{ - DnsPort: randomPort(), - DnsResolvers: []string{externalAddr}, - }) - defer internalCleanup() - defer externalCleanup() - - authoritativeRecords := []database.DNSRecord{ - { - ID: "1", - UserID: "test", - Name: "external.example.com.", - Type: "CNAME", - Content: "external.internal.example.com.", - }, - } - internalRecords := []database.DNSRecord{ - { - ID: "1", - UserID: "test", - Name: "external.internal.example.com.", - Type: "A", - Content: "127.0.0.1", - }, - { - ID: "2", - UserID: "test", - Name: "test.internal.example.com.", - Type: "CNAME", - Content: "external.example.com.", - }, - } - - for _, record := range authoritativeRecords { - database.SaveDNSRecord(externalDb, &record) - } - for _, record := range internalRecords { - database.SaveDNSRecord(internalDb, &record) - } - - // ensure that if the record doesn't exist in the internal database, it will - // go and query the external dns resolvers, then loop back to the internal - - qtype := dns.TypeA - domain := dns.Fqdn("test.internal.example.com.") - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, internalAddr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 3 { - t.Fatalf("expected 3 answers, got %d", len(in.Answer)) - } - - aRecord := in.Answer[2] - if aRecord.Header().Name != internalRecords[0].Name { - t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name) - } - if aRecord.Header().Rrtype != dns.TypeA { - t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type) - } - if aRecord.(*dns.A).A.String() != internalRecords[0].Content { - t.Fatalf("expected %s, got %s", internalRecords[0].Content, aRecord.(*dns.A).A.String()) - } - - if in.Authoritative { - t.Fatalf("expected non-authoritative response") - } -} - func TestWhenCNAMEIsResolved(t *testing.T) { testDb, _, addr, cleanup := setup(nil) defer cleanup() @@ -322,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) } } + +func TestWhenExternalDomain(t *testing.T) { + externalDb, _, externalAddr, externalCleanup := setup(nil) + internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{ + DnsPort: randomPort(), + DnsResolvers: []string{externalAddr}, + }) + defer internalCleanup() + defer externalCleanup() + + authoritativeRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.example.com.", + Type: "CNAME", + Content: "external.internal.example.com.", + }, + { + ID: "2", + UserID: "test", + Name: "final.example.com.", + Type: "A", + Content: "127.0.0.1", + }, + } + internalRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.internal.example.com.", + Type: "CNAME", + Content: "final.example.com", + }, + { + ID: "2", + UserID: "test", + Name: "test.internal.example.com.", + Type: "CNAME", + Content: "external.example.com.", + }, + } + + for _, record := range authoritativeRecords { + database.SaveDNSRecord(externalDb, &record) + } + for _, record := range internalRecords { + database.SaveDNSRecord(internalDb, &record) + } + + // ensure that if the record doesn't exist in the internal database, it will + // go and query the external dns resolvers, then loop back to the internal + + qtype := dns.TypeA + domain := dns.Fqdn("test.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, internalAddr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 4 { + t.Fatalf("expected 4 answers, got %d", len(in.Answer)) + } + + aRecord := in.Answer[3] + if aRecord.Header().Name != authoritativeRecords[1].Name { + t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name) + } + if aRecord.Header().Rrtype != dns.TypeA { + t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type) + } + if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content { + t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String()) + } + if in.Authoritative { + t.Fatalf("expected non-authoritative response") + } +}