package hcdns

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"testing"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
	"github.com/miekg/dns"
)

const (
	testDBPath = "test.db"
	address    = "127.0.0.1:8353"
	dnsPort    = 8353
)

// setup opens and migrates the database at dbPath, seeds a user with ID "test",
// and starts an hcdns server on dnsPort in a background goroutine.
//
// It blocks until the server reports that it is serving, so callers may query
// it immediately. The returned WaitGroup is released once ListenAndServe
// returns (i.e. after the server has been shut down). If the server fails to
// start, for example because the port is already in use, the database is
// destroyed and setup panics instead of letting the test race or hang.
func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) {
	testDb := database.MakeConn(&dbPath)
	database.Migrate(testDb)

	testUser := &database.User{
		ID: "test",
	}
	database.FindOrSaveUser(testDb, testUser)

	server := hcdns.MakeServer(&args.Arguments{
		DnsPort: dnsPort,
	}, testDb)

	// Wait for the socket to actually be bound. Otherwise the first query can
	// arrive before the listener exists, and a Shutdown issued before the
	// server is marked started fails, leaving it running on the shared port.
	// Any hook already installed by MakeServer is preserved.
	started := make(chan struct{})
	previousNotify := server.NotifyStartedFunc
	server.NotifyStartedFunc = func() {
		if previousNotify != nil {
			previousNotify()
		}
		close(started)
	}

	// Buffered so the goroutine's send never blocks, including the normal
	// post-Shutdown return when nobody is receiving any more.
	listenErr := make(chan error, 1)
	waitGroup := &sync.WaitGroup{}
	waitGroup.Add(1)
	go func() {
		defer waitGroup.Done()
		listenErr <- server.ListenAndServe()
	}()

	select {
	case <-started:
	case err := <-listenErr:
		destroy(testDb, dbPath)
		panic(fmt.Sprintf("hcdns test server failed to start on %s: %v", address, err))
	}

	return testDb, server, waitGroup
}

// destroy closes conn and removes the database file at path. Both steps are
// best-effort teardown, so their errors are deliberately ignored.
func destroy(conn *sql.DB, path string) {
	conn.Close()
	os.Remove(path)
}

// startTestServer runs setup against a database inside a per-test temporary
// directory and returns the connection so the test can seed records.
//
// Registered cleanup shuts the server down and waits for its goroutine to exit,
// so the port is free for the next test. It then closes and removes the
// database. Cleanups run in LIFO order, so this one runs before t.TempDir's
// removal of the directory.
func startTestServer(t *testing.T) *sql.DB {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), testDBPath)
	testDb, server, waitGroup := setup(dbPath)

	t.Cleanup(func() {
		if err := server.Shutdown(); err != nil {
			// The serving goroutine may still be running; waiting would hang.
			t.Errorf("shutting down dns server: %v", err)
		} else {
			waitGroup.Wait()
		}
		destroy(testDb, dbPath)
	})

	return testDb
}

// query sends a single question for (domain, qtype) to the test server and
// returns the response, failing the test if the exchange itself errors.
func query(t *testing.T, domain string, qtype uint16) *dns.Msg {
	t.Helper()

	message := new(dns.Msg)
	message.SetQuestion(dns.Fqdn(domain), qtype)

	in, _, err := new(dns.Client).Exchange(message, address)
	if err != nil {
		t.Fatal(err)
	}
	return in
}

func TestWhenCNAMEIsResolved(t *testing.T) {
	testDb := startTestServer(t)

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "res.example.com.",
		TTL:      300,
		Internal: true,
	}
	a := &database.DNSRecord{
		ID:       "2",
		UserID:   "test",
		Name:     "res.example.com.",
		Type:     "A",
		Content:  "127.0.0.1",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)
	database.SaveDNSRecord(testDb, a)

	in := query(t, cname.Name, dns.TypeA)

	if len(in.Answer) != 2 {
		t.Fatalf("expected 2 answers, got %d", len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}

	// Comma-ok assertions turn a wrong record type into a test failure
	// instead of a panic.
	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
	}
	aAnswer, ok := in.Answer[1].(*dns.A)
	if !ok {
		t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
	}

	if cnameAnswer.Hdr.Name != cname.Name {
		t.Fatalf("expected %s, got %s", cname.Name, cnameAnswer.Hdr.Name)
	}
	if cnameAnswer.Hdr.Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %d", cnameAnswer.Hdr.Rrtype)
	}
	if cnameAnswer.Target != a.Name {
		t.Fatalf("expected %s, got %s", a.Name, cnameAnswer.Target)
	}
	if int(cnameAnswer.Hdr.Ttl) != cname.TTL {
		t.Fatalf("expected %d, got %d", cname.TTL, cnameAnswer.Hdr.Ttl)
	}

	if aAnswer.Hdr.Name != a.Name {
		t.Fatalf("expected %s, got %s", a.Name, aAnswer.Hdr.Name)
	}
	if aAnswer.Hdr.Rrtype != dns.TypeA {
		t.Fatalf("expected A, got %d", aAnswer.Hdr.Rrtype)
	}
	if aAnswer.A.String() != a.Content {
		t.Fatalf("expected %s, got %s", a.Content, aAnswer.A.String())
	}
}

func TestWhenNoRecordNxDomain(t *testing.T) {
	startTestServer(t)

	in := query(t, "nonexistant.example.com.", dns.TypeA)

	if len(in.Answer) != 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeNameError {
		t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
	}
}

func TestWhenUnresolvingCNAME(t *testing.T) {
	testDb := startTestServer(t)

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "nonexistant.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := query(t, cname.Name, dns.TypeA)

	if len(in.Answer) != 1 {
		t.Fatalf("expected 1 answer, got %d", len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}

	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
	}
	if cnameAnswer.Hdr.Name != cname.Name {
		t.Fatalf("expected %s, got %s", cname.Name, cnameAnswer.Hdr.Name)
	}
	if cnameAnswer.Hdr.Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %d", cnameAnswer.Hdr.Rrtype)
	}
	if cnameAnswer.Target != cname.Content {
		t.Fatalf("expected %s, got %s", cname.Content, cnameAnswer.Target)
	}

	// The CNAME itself exists, so a dangling target must not be reported as
	// NXDOMAIN.
	if in.Rcode == dns.RcodeNameError {
		t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
	}
}

func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
	testDb := startTestServer(t)

	// A self-referencing CNAME never terminates; the server is expected to
	// give up and answer SERVFAIL with no records.
	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "cname.internal.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := query(t, cname.Name, dns.TypeA)

	if len(in.Answer) > 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeServerFailure {
		t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
	}
}