From bcdcc508ef4a0ae646937c91d0994a90bde719e1 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:26:39 -0600 Subject: [PATCH] add integration tests for dns server --- {dns => hcdns}/server.go | 19 ++- main.go | 4 +- test/dns_test.go | 244 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 255 insertions(+), 12 deletions(-) rename {dns => hcdns}/server.go (91%) create mode 100644 test/dns_test.go diff --git a/dns/server.go b/hcdns/server.go similarity index 91% rename from dns/server.go rename to hcdns/server.go index 9b3e5e9..ce7894b 100644 --- a/dns/server.go +++ b/hcdns/server.go @@ -1,4 +1,4 @@ -package dns +package hcdns import ( "database/sql" @@ -9,7 +9,7 @@ import ( "log" ) -const MAX_RECURSION = 10 +const MAX_RECURSION = 15 func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") @@ -21,12 +21,14 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { + log.Println(err) return nil, err } answers = append(answers, cname) cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) if err != nil { + log.Println(err) return nil, err } answers = append(answers, cnameRecursive...) @@ -62,11 +64,7 @@ func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dn return nil, err } - if len(answers) > 0 { - return answers, nil - } - - return nil, fmt.Errorf("no records found for %s", domain) + return answers, nil } type DnsHandler struct { @@ -83,7 +81,9 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) if err != nil { fmt.Println(err) - continue + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return } msg.Answer = append(msg.Answer, answers...) } @@ -98,8 +98,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { handler := &DnsHandler{ - DnsResolvers: argv.DnsRecursion, - DbConn: dbConn, + DbConn: dbConn, } addr := fmt.Sprintf(":%d", argv.DnsPort) diff --git a/main.go b/main.go index 2991821..e0f3e55 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" "github.com/joho/godotenv" ) @@ -52,7 +52,7 @@ func main() { } if argv.Dns { - server := dns.MakeServer(argv, dbConn) + server := hcdns.MakeServer(argv, dbConn) log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) go func() { err = server.ListenAndServe() diff --git a/test/dns_test.go b/test/dns_test.go new file mode 100644 index 0000000..ce6deb5 --- /dev/null +++ b/test/dns_test.go @@ -0,0 +1,244 @@ +package hcdns + +import ( + "database/sql" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "github.com/miekg/dns" +) + +const ( + testDBPath = "test.db" + address = "127.0.0.1:8353" + dnsPort = 8353 +) + +func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { + testDb := database.MakeConn(&dbPath) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + go func() { + server.ListenAndServe() + waitGroup.Done() + }() + + return testDb, server, &waitGroup +} + +func destroy(conn *sql.DB, path string) { + conn.Close() + os.Remove(path) +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + } + a := &database.DNSRecord{ + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "127.0.0.1", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + database.SaveDNSRecord(testDb, a) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 2 { + t.Fatalf("expected 2 answers, got %d", len(in.Answer)) + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[1].Header().Name != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) + } + + if in.Answer[0].(*dns.CNAME).Target != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Answer[1].(*dns.A).A.String() != a.Content { + t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[1].Header().Rrtype != dns.TypeA { + t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) + } + + if int(in.Answer[0].Header().Ttl) != cname.TTL { + t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +}