package hcdns_test

import (
	"database/sql"
	"fmt"
	"math/rand"
	"os"
	"sync"
	"testing"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
	"github.com/miekg/dns"
)

// randomPort returns a pseudo-random port in the range [5192, 8191] for a test
// DNS server to listen on. Collisions are possible; setup reports them.
func randomPort() int {
	return rand.Intn(3000) + 5192
}

// setup creates a fresh, migrated database containing a user with ID "test",
// starts an hcdns server backed by it, and waits until the server reports that
// it is listening.
//
// When arguments is nil the server is configured with only a random DnsPort.
// It returns the database, the server, the "127.0.0.1:<port>" address to send
// queries to, and a cleanup function that shuts the server down, closes the
// database and removes the randomDb path.
//
// If ListenAndServe fails before the server starts (e.g. the port is already
// in use), setup releases the database and panics with the error instead of
// blocking forever.
func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
	randomDb := utils.RandomId()
	testDb := database.MakeConn(&randomDb)
	database.Migrate(testDb)

	testUser := &database.User{
		ID: "test",
	}
	database.FindOrSaveUser(testDb, testUser)

	dnsArguments := arguments
	if dnsArguments == nil {
		dnsArguments = &args.Arguments{
			DnsPort: randomPort(),
		}
	}

	server := hcdns.MakeServer(dnsArguments, testDb)

	// started is closed once the server signals it is listening. serveErr is
	// buffered so the serving goroutine can always deliver its result and exit,
	// even after setup has stopped waiting on it (e.g. after Shutdown).
	started := make(chan struct{})
	serveErr := make(chan error, 1)
	var startOnce sync.Once
	server.NotifyStartedFunc = func() {
		// Closing a channel twice panics; guard against repeated notifications.
		startOnce.Do(func() { close(started) })
	}

	go func() {
		serveErr <- server.ListenAndServe()
	}()

	select {
	case <-started:
	case err := <-serveErr:
		testDb.Close()
		os.Remove(randomDb)
		panic(fmt.Sprintf("hcdns test server on port %d failed to start: %v", dnsArguments.DnsPort, err))
	}

	address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort)

	return testDb, server, address, func() {
		server.Shutdown()
		testDb.Close()
		os.Remove(randomDb)
	}
}

// exchangeQuestion sends a single-question query for name (made fully
// qualified) and qtype to the DNS server at addr and returns the response.
// Transport errors fail the calling test immediately.
func exchangeQuestion(t *testing.T, addr string, name string, qtype uint16) *dns.Msg {
	t.Helper()

	client := &dns.Client{}
	message := &dns.Msg{}
	message.SetQuestion(dns.Fqdn(name), qtype)

	in, _, err := client.Exchange(message, addr)
	if err != nil {
		t.Fatal(err)
	}
	return in
}

// TestWhenExternalDomain ensures that if a record doesn't exist in the
// internal database, the server queries the configured external DNS resolvers
// and then loops back to the internal records to finish the CNAME chain.
func TestWhenExternalDomain(t *testing.T) {
	externalDb, _, externalAddr, externalCleanup := setup(nil)
	defer externalCleanup()

	internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
		DnsPort:      randomPort(),
		DnsResolvers: []string{externalAddr},
	})
	defer internalCleanup()

	authoritativeRecords := []database.DNSRecord{
		{
			ID:      "1",
			UserID:  "test",
			Name:    "external.example.com.",
			Type:    "CNAME",
			Content: "external.internal.example.com.",
		},
	}

	internalRecords := []database.DNSRecord{
		{
			ID:      "1",
			UserID:  "test",
			Name:    "external.internal.example.com.",
			Type:    "A",
			Content: "127.0.0.1",
		},
		{
			ID:      "2",
			UserID:  "test",
			Name:    "test.internal.example.com.",
			Type:    "CNAME",
			Content: "external.example.com.",
		},
	}

	// Index into the slices rather than taking the address of the range
	// variable, so each saved pointer refers to its own element.
	for i := range authoritativeRecords {
		database.SaveDNSRecord(externalDb, &authoritativeRecords[i])
	}
	for i := range internalRecords {
		database.SaveDNSRecord(internalDb, &internalRecords[i])
	}

	in := exchangeQuestion(t, internalAddr, "test.internal.example.com.", dns.TypeA)

	if len(in.Answer) != 3 {
		t.Fatalf("expected 3 answers, got %d", len(in.Answer))
	}

	// The last answer must be the internal A record the chain resolves to.
	want := internalRecords[0]
	aRecord := in.Answer[2]
	if aRecord.Header().Name != want.Name {
		t.Fatalf("expected %s, got %s", want.Name, aRecord.Header().Name)
	}
	if aRecord.Header().Rrtype != dns.TypeA {
		t.Fatalf("expected %s, got %s", want.Type, dns.TypeToString[aRecord.Header().Rrtype])
	}
	a, ok := aRecord.(*dns.A)
	if !ok {
		t.Fatalf("expected *dns.A, got %T", aRecord)
	}
	if a.A.String() != want.Content {
		t.Fatalf("expected %s, got %s", want.Content, a.A.String())
	}

	// Part of the answer came from an upstream resolver.
	if in.Authoritative {
		t.Fatalf("expected non-authoritative response")
	}
}

// TestWhenCNAMEIsResolved checks that a chain of internal CNAMEs ending in an
// A record is fully resolved and returned in order, authoritatively, with the
// stored TTLs.
func TestWhenCNAMEIsResolved(t *testing.T) {
	testDb, _, addr, cleanup := setup(nil)
	defer cleanup()

	records := []*database.DNSRecord{
		{
			ID:       "0",
			UserID:   "test",
			Name:     "cname.internal.example.com.",
			Type:     "CNAME",
			Content:  "next.internal.example.com.",
			TTL:      300,
			Internal: true,
		},
		{
			ID:       "1",
			UserID:   "test",
			Name:     "next.internal.example.com.",
			Type:     "CNAME",
			Content:  "res.example.com.",
			TTL:      300,
			Internal: true,
		},
		{
			ID:       "2",
			UserID:   "test",
			Name:     "res.example.com.",
			Type:     "A",
			Content:  "1.2.3.2",
			TTL:      300,
			Internal: true,
		},
	}

	for _, record := range records {
		database.SaveDNSRecord(testDb, record)
	}

	in := exchangeQuestion(t, addr, "cname.internal.example.com.", dns.TypeA)

	if len(in.Answer) != len(records) {
		t.Fatalf("expected %d answers, got %d", len(records), len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}

	for i, record := range records {
		got := in.Answer[i].Header()
		if got.Name != record.Name {
			t.Fatalf("expected %s, got %s", record.Name, got.Name)
		}
		if got.Rrtype != dns.StringToType[record.Type] {
			t.Fatalf("expected %s, got %s", record.Type, dns.TypeToString[got.Rrtype])
		}
		if int(got.Ttl) != record.TTL {
			t.Fatalf("expected %d, got %d", record.TTL, got.Ttl)
		}
	}

	final, ok := in.Answer[2].(*dns.A)
	if !ok || final.A.String() != "1.2.3.2" {
		t.Fatalf("expected final record to be the A record with correct IP")
	}
}

// TestWhenNoRecordNxDomain checks that querying an unknown name yields no
// answers and an NXDOMAIN response code.
func TestWhenNoRecordNxDomain(t *testing.T) {
	_, _, addr, cleanup := setup(nil)
	defer cleanup()

	in := exchangeQuestion(t, addr, "nonexistant.example.com.", dns.TypeA)

	if len(in.Answer) != 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeNameError {
		t.Fatalf("expected NXDOMAIN, got %s", dns.RcodeToString[in.Rcode])
	}
}

// TestWhenUnresolvingCNAME checks that a CNAME whose target cannot be resolved
// is still returned (authoritatively) as the sole answer, without NXDOMAIN.
func TestWhenUnresolvingCNAME(t *testing.T) {
	testDb, _, addr, cleanup := setup(nil)
	defer cleanup()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "nonexistant.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := exchangeQuestion(t, addr, cname.Name, dns.TypeA)

	if len(in.Answer) != 1 {
		t.Fatalf("expected 1 answer, got %d", len(in.Answer))
	}
	if !in.Authoritative {
		t.Fatalf("expected authoritative response")
	}

	answer := in.Answer[0]
	if answer.Header().Name != cname.Name {
		t.Fatalf("expected %s, got %s", cname.Name, answer.Header().Name)
	}
	if answer.Header().Rrtype != dns.TypeCNAME {
		t.Fatalf("expected CNAME, got %s", dns.TypeToString[answer.Header().Rrtype])
	}
	target, ok := answer.(*dns.CNAME)
	if !ok {
		t.Fatalf("expected *dns.CNAME, got %T", answer)
	}
	if target.Target != cname.Content {
		t.Fatalf("expected %s, got %s", cname.Content, target.Target)
	}
	if in.Rcode == dns.RcodeNameError {
		t.Fatalf("expected no NXDOMAIN, got %s", dns.RcodeToString[in.Rcode])
	}
}

// TestWhenUnresolvingCNAMEWithMaxDepth checks that a self-referencing CNAME
// loop yields no answers and a SERVFAIL response code.
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
	testDb, _, addr, cleanup := setup(nil)
	defer cleanup()

	cname := &database.DNSRecord{
		ID:       "1",
		UserID:   "test",
		Name:     "cname.internal.example.com.",
		Type:     "CNAME",
		Content:  "cname.internal.example.com.",
		TTL:      300,
		Internal: true,
	}
	database.SaveDNSRecord(testDb, cname)

	in := exchangeQuestion(t, addr, cname.Name, dns.TypeA)

	if len(in.Answer) > 0 {
		t.Fatalf("expected 0 answers, got %d", len(in.Answer))
	}
	if in.Rcode != dns.RcodeServerFailure {
		t.Fatalf("expected SERVFAIL, got %s", dns.RcodeToString[in.Rcode])
	}
}