package hcdns

import (
	"database/sql"
	"fmt"
	"math/rand"
	"os"
	"sync"
	"testing"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
	"github.com/miekg/dns"
)

// destroy closes conn and removes the file at path. Both errors are ignored:
// this is best-effort teardown of a throwaway test database.
func destroy(conn *sql.DB, path string) {
	conn.Close()
	os.Remove(path)
}

// randomPort returns a pseudo-random port in the range [10000, 13000).
func randomPort() int {
	return rand.Intn(3000) + 10000
}

// setup creates a fresh migrated database containing the user "test", starts
// a DNS server backed by it on a random local port, and blocks until that
// server reports it is listening.
//
// It returns the database handle, the server, the port, a pointer to the
// "127.0.0.1:<port>" address clients should query, and a cleanup function
// that stops the server and removes the database. setup panics if the server
// exits before it ever starts listening.
func setup() (*sql.DB, *dns.Server, int, *string, func()) {
	randomDb := utils.RandomId()
	dnsPort := randomPort()

	testDb := database.MakeConn(&randomDb)
	database.Migrate(testDb)
	database.FindOrSaveUser(testDb, &database.User{ID: "test"})

	server := hcdns.MakeServer(&args.Arguments{DnsPort: dnsPort}, testDb)

	// Signal readiness through the server's start hook, preserving any hook
	// MakeServer may already have installed. Without this, a test can send
	// its query before the listener is bound.
	started := make(chan struct{})
	var startedOnce sync.Once
	previousNotify := server.NotifyStartedFunc
	server.NotifyStartedFunc = func() {
		if previousNotify != nil {
			previousNotify()
		}
		startedOnce.Do(func() { close(started) })
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when nobody is receiving any more.
	serveErr := make(chan error, 1)
	waitGroup := sync.WaitGroup{}
	waitGroup.Add(1)
	go func() {
		defer waitGroup.Done()
		serveErr <- server.ListenAndServe()
	}()

	address := fmt.Sprintf("127.0.0.1:%d", dnsPort)

	select {
	case <-started:
	case err := <-serveErr:
		// The server returned without ever listening (e.g. the port is
		// taken); running the tests against it would only produce
		// misleading failures.
		destroy(testDb, randomDb)
		panic(fmt.Sprintf("dns test server did not start on %s: %v", address, err))
	}

	cleanup := func() {
		// Stop the server before tearing down the database it reads from.
		// A shutdown error is ignored: there is nothing useful to do with
		// it during teardown.
		server.Shutdown()
		waitGroup.Wait()
		destroy(testDb, randomDb)
	}

	return testDb, server, dnsPort, &address, cleanup
}

// queryA sends a single A question for name to the DNS server at addr and
// returns the reply, failing the test immediately if the exchange errors.
func queryA(t *testing.T, addr string, name string) *dns.Msg {
	t.Helper()

	message := new(dns.Msg)
	message.SetQuestion(dns.Fqdn(name), dns.TypeA)

	in, _, err := new(dns.Client).Exchange(message, addr)
	if err != nil {
		t.Fatal(err)
	}
	return in
}

// TestWhenCNAMEIsResolved checks that an A query for a CNAME whose target has
// an A record is answered authoritatively with the CNAME followed by the A.
func TestWhenCNAMEIsResolved(t *testing.T) {
	t.Log("TestWhenCNAMEIsResolved")

	testDb, _, _, addr, cleanup := setup()
	defer cleanup()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "res.example.com.",
		TTL:      300,
		Internal: true,
	}
	a := &database.DNSRecord{
		ID:       "2",
		UserID:   "test",
		Name:     "res.example.com.",
		Type:     "A",
		Content:  "127.0.0.1",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)
	database.SaveDNSRecord(testDb, a)

	in := queryA(t, *addr, cname.Name)

	if len(in.Answer) != 2 {
		t.Fatalf("expected 2 answers, got %d", len(in.Answer))
	}

	// Verify the record types before touching the concrete types, so that a
	// wrong type is reported as a test failure rather than a panic.
	if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
	}
	if in.Answer[1].Header().Rrtype != dns.TypeA {
		t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
	}
	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected *dns.CNAME, got %T", in.Answer[0])
	}
	aAnswer, ok := in.Answer[1].(*dns.A)
	if !ok {
		t.Fatalf("expected *dns.A, got %T", in.Answer[1])
	}

	if cnameAnswer.Header().Name != cname.Name {
		t.Fatalf("expected cname.internal.example.com., got %s", cnameAnswer.Header().Name)
	}
	if aAnswer.Header().Name != a.Name {
		t.Fatalf("expected res.example.com., got %s", aAnswer.Header().Name)
	}
	if cnameAnswer.Target != a.Name {
		t.Fatalf("expected res.example.com., got %s", cnameAnswer.Target)
	}
	if aAnswer.A.String() != a.Content {
		t.Fatalf("expected %s, got %s", a.Content, aAnswer.A.String())
	}
	if int(cnameAnswer.Header().Ttl) != cname.TTL {
		t.Fatalf("expected %d, got %d", cname.TTL, cnameAnswer.Header().Ttl)
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}
}

// TestWhenNoRecordNxDomain checks that a name with no records yields an empty
// answer section and an NXDOMAIN response code.
func TestWhenNoRecordNxDomain(t *testing.T) {
	t.Log("TestWhenNoRecordNxDomain")

	_, _, _, addr, cleanup := setup()
	defer cleanup()

	in := queryA(t, *addr, "nonexistant.example.com.")

	if len(in.Answer) != 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeNameError {
		t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
	}
}

// TestWhenUnresolvingCNAME checks that a CNAME whose target has no records is
// still answered authoritatively with just the CNAME, and not as NXDOMAIN.
func TestWhenUnresolvingCNAME(t *testing.T) {
	t.Log("TestWhenUnresolvingCNAME")

	testDb, _, _, addr, cleanup := setup()
	defer cleanup()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "nonexistant.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := queryA(t, *addr, cname.Name)

	if len(in.Answer) != 1 {
		t.Fatalf("expected 1 answer, got %d", len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}
	if in.Answer[0].Header().Name != cname.Name {
		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
	}
	if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
	}
	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected *dns.CNAME, got %T", in.Answer[0])
	}
	if cnameAnswer.Target != cname.Content {
		t.Fatalf("expected nonexistant.example.com., got %s", cnameAnswer.Target)
	}
	if in.Rcode == dns.RcodeNameError {
		t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
	}
}

// TestWhenUnresolvingCNAMEWithMaxDepth checks that a CNAME pointing at itself
// produces no answers and a SERVFAIL response code.
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")

	testDb, _, _, addr, cleanup := setup()
	defer cleanup()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "cname.internal.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := queryA(t, *addr, cname.Name)

	if len(in.Answer) > 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeServerFailure {
		t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
	}
}