package hcdns_test import ( "database/sql" "fmt" "math/rand" "os" "sync" "testing" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "github.com/miekg/dns" ) func randomPort() int { return rand.Intn(3000) + 5192 } func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { randomDb := utils.RandomId() dnsPort := randomPort() testDb := database.MakeConn(&randomDb) database.Migrate(testDb) testUser := &database.User{ ID: "test", } database.FindOrSaveUser(testDb, testUser) waitLock := &sync.Mutex{} server := hcdns.MakeServer(&args.Arguments{ DnsPort: dnsPort, }, testDb) server.NotifyStartedFunc = func() { waitLock.Unlock() } waitLock.Lock() go func() { server.ListenAndServe() }() waitLock.Lock() address := fmt.Sprintf("127.0.0.1:%d", dnsPort) return testDb, server, &address, waitLock, func() { server.Shutdown() testDb.Close() os.Remove(randomDb) } } func TestWhenCNAMEIsResolved(t *testing.T) { t.Log("TestWhenCNAMEIsResolved") testDb, _, addr, lock, cleanup := setup() defer cleanup() defer lock.Unlock() records := []*database.DNSRecord{ { ID: "0", UserID: "test", Name: "cname.internal.example.com.", Type: "CNAME", Content: "next.internal.example.com.", TTL: 300, Internal: true, }, { ID: "1", UserID: "test", Name: "next.internal.example.com.", Type: "CNAME", Content: "res.example.com.", TTL: 300, Internal: true, }, { ID: "2", UserID: "test", Name: "res.example.com.", Type: "A", Content: "1.2.3.2", TTL: 300, Internal: true, }, } for _, record := range records { database.SaveDNSRecord(testDb, record) } qtype := dns.TypeA domain := dns.Fqdn("cname.internal.example.com.") client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) } if len(in.Answer) != 3 { t.Fatalf("expected 3 answers, got %d", len(in.Answer)) } for i, record := range records { if in.Answer[i].Header().Name != record.Name { t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) } if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) } if int(in.Answer[i].Header().Ttl) != record.TTL { t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) } if !in.Authoritative { t.Fatalf("expected authoritative response") } } if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { t.Fatalf("expected final record to be the A record with correct IP") } } func TestWhenNoRecordNxDomain(t *testing.T) { t.Log("TestWhenNoRecordNxDomain") _, _, addr, lock, cleanup := setup() defer cleanup() defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) } if len(in.Answer) != 0 { t.Fatalf("expected 0 answers, got %d", len(in.Answer)) } if in.Rcode != dns.RcodeNameError { t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) } } func TestWhenUnresolvingCNAME(t *testing.T) { t.Log("TestWhenUnresolvingCNAME") testDb, _, addr, lock, cleanup := setup() defer cleanup() defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", UserID: "test", Name: "cname.internal.example.com.", Type: "CNAME", Content: "nonexistant.example.com.", TTL: 300, Internal: true, } database.SaveDNSRecord(testDb, cname) qtype := dns.TypeA domain := dns.Fqdn(cname.Name) client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) } if len(in.Answer) != 1 { t.Fatalf("expected 1 answer, got %d", len(in.Answer)) } if !in.Authoritative { t.Fatalf("expected authoritative response") } if in.Answer[0].Header().Name != cname.Name { t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) } if in.Answer[0].Header().Rrtype != dns.TypeCNAME { t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) } if in.Answer[0].(*dns.CNAME).Target != cname.Content { t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) } if in.Rcode == dns.RcodeNameError { t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) } } func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") testDb, _, addr, lock, cleanup := setup() defer cleanup() defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", UserID: "test", Name: "cname.internal.example.com.", Type: "CNAME", Content: "cname.internal.example.com.", TTL: 300, Internal: true, } database.SaveDNSRecord(testDb, cname) qtype := dns.TypeA domain := dns.Fqdn(cname.Name) client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) } if len(in.Answer) > 0 { t.Fatalf("expected 0 answers, got %d", len(in.Answer)) } if in.Rcode != dns.RcodeServerFailure { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) } }