package hcdns

import (
	"database/sql"
	"fmt"
	"math/rand"
	"os"
	"sync"
	"testing"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
	"github.com/miekg/dns"
)

// randomPort returns a pseudo-random port in [1024, 4024) for a test DNS
// server. The port is not checked for availability; setup reports a failed
// bind instead of hanging.
func randomPort() int {
	return rand.Intn(3000) + 1024
}

// setup creates a fresh database (migrated, containing a user with ID "test"),
// starts the DNS server built by hcdns.MakeServer on a random port, and waits
// until the server's NotifyStartedFunc fires.
//
// It returns the database, the server, the "127.0.0.1:<port>" address of the
// server, a mutex that is LOCKED on return (callers defer its Unlock), and a
// cleanup function that shuts the server down, closes the database and removes
// the database file.
//
// If ListenAndServe returns before the server reports it has started (e.g. the
// port is already in use), setup cleans up and panics instead of blocking
// forever.
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
	randomDb := utils.RandomId()
	dnsPort := randomPort()

	testDb := database.MakeConn(&randomDb)
	database.Migrate(testDb)
	testUser := &database.User{
		ID: "test",
	}
	database.FindOrSaveUser(testDb, testUser)

	server := hcdns.MakeServer(&args.Arguments{
		DnsPort: dnsPort,
	}, testDb)

	// Signal startup over a channel rather than a mutex so that a failed
	// ListenAndServe can be observed as well; the Once makes a repeated
	// callback harmless.
	started := make(chan struct{})
	var startedOnce sync.Once
	server.NotifyStartedFunc = func() {
		startedOnce.Do(func() { close(started) })
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even after a normal Shutdown when nobody is receiving.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.ListenAndServe()
	}()

	select {
	case <-started:
	case err := <-serveErr:
		testDb.Close()
		os.Remove(randomDb)
		panic(fmt.Sprintf("dns server on port %d did not start: %v", dnsPort, err))
	}

	// Callers `defer lock.Unlock()`, so the mutex must be handed back locked.
	waitLock := &sync.Mutex{}
	waitLock.Lock()

	address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
	return testDb, server, &address, waitLock, func() {
		server.Shutdown()
		testDb.Close()
		os.Remove(randomDb)
	}
}

// mustExchange sends a single question (name as an FQDN, qtype) to the DNS
// server at addr and returns the response, failing the test on a transport
// error.
func mustExchange(t *testing.T, addr string, name string, qtype uint16) *dns.Msg {
	t.Helper()

	message := &dns.Msg{}
	message.SetQuestion(dns.Fqdn(name), qtype)

	client := &dns.Client{}
	in, _, err := client.Exchange(message, addr)
	if err != nil {
		t.Fatal(err)
	}
	return in
}

func TestWhenCNAMEIsResolved(t *testing.T) {
	t.Log("TestWhenCNAMEIsResolved")
	testDb, _, addr, lock, cleanup := setup()
	defer cleanup()
	defer lock.Unlock()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "res.example.com.",
		TTL:      300,
		Internal: true,
	}
	a := &database.DNSRecord{
		ID:       "2",
		UserID:   "test",
		Name:     "res.example.com.",
		Type:     "A",
		Content:  "127.0.0.1",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)
	database.SaveDNSRecord(testDb, a)

	in := mustExchange(t, *addr, cname.Name, dns.TypeA)

	if len(in.Answer) != 2 {
		t.Fatalf("expected 2 answers, got %d", len(in.Answer))
	}
	if in.Answer[0].Header().Name != cname.Name {
		t.Fatalf("expected %s, got %s", cname.Name, in.Answer[0].Header().Name)
	}
	if in.Answer[1].Header().Name != a.Name {
		t.Fatalf("expected %s, got %s", a.Name, in.Answer[1].Header().Name)
	}

	// Check record types before asserting on concrete types so a wrong answer
	// produces a test failure instead of a panic.
	if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %s", dns.TypeToString[in.Answer[0].Header().Rrtype])
	}
	if in.Answer[1].Header().Rrtype != dns.TypeA {
		t.Fatalf("expected A, got %s", dns.TypeToString[in.Answer[1].Header().Rrtype])
	}
	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected *dns.CNAME, got %T", in.Answer[0])
	}
	aAnswer, ok := in.Answer[1].(*dns.A)
	if !ok {
		t.Fatalf("expected *dns.A, got %T", in.Answer[1])
	}

	if cnameAnswer.Target != a.Name {
		t.Fatalf("expected %s, got %s", a.Name, cnameAnswer.Target)
	}
	if aAnswer.A.String() != a.Content {
		t.Fatalf("expected %s, got %s", a.Content, aAnswer.A.String())
	}
	if int(cnameAnswer.Header().Ttl) != cname.TTL {
		t.Fatalf("expected %d, got %d", cname.TTL, cnameAnswer.Header().Ttl)
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}
}

func TestWhenNoRecordNxDomain(t *testing.T) {
	t.Log("TestWhenNoRecordNxDomain")
	_, _, addr, lock, cleanup := setup()
	defer cleanup()
	defer lock.Unlock()

	in := mustExchange(t, *addr, "nonexistant.example.com.", dns.TypeA)

	if len(in.Answer) != 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeNameError {
		t.Fatalf("expected NXDOMAIN, got %s", dns.RcodeToString[in.Rcode])
	}
}

func TestWhenUnresolvingCNAME(t *testing.T) {
	t.Log("TestWhenUnresolvingCNAME")
	testDb, _, addr, lock, cleanup := setup()
	defer cleanup()
	defer lock.Unlock()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "nonexistant.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := mustExchange(t, *addr, cname.Name, dns.TypeA)

	if len(in.Answer) != 1 {
		t.Fatalf("expected 1 answer, got %d", len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}
	if in.Answer[0].Header().Name != cname.Name {
		t.Fatalf("expected %s, got %s", cname.Name, in.Answer[0].Header().Name)
	}
	if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %s", dns.TypeToString[in.Answer[0].Header().Rrtype])
	}
	cnameAnswer, ok := in.Answer[0].(*dns.CNAME)
	if !ok {
		t.Fatalf("expected *dns.CNAME, got %T", in.Answer[0])
	}
	if cnameAnswer.Target != cname.Content {
		t.Fatalf("expected %s, got %s", cname.Content, cnameAnswer.Target)
	}
	// A dangling CNAME still proves the queried name exists.
	if in.Rcode == dns.RcodeNameError {
		t.Fatalf("expected no NXDOMAIN, got %s", dns.RcodeToString[in.Rcode])
	}
}

func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
	testDb, _, addr, lock, cleanup := setup()
	defer cleanup()
	defer lock.Unlock()

	// A CNAME pointing at itself forms a loop the server has to give up on.
	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "cname.internal.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := mustExchange(t, *addr, cname.Name, dns.TypeA)

	if len(in.Answer) > 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeServerFailure {
		t.Fatalf("expected SERVFAIL, got %s", dns.RcodeToString[in.Rcode])
	}
}