refactor dns server test a bit

This commit is contained in:
Elizabeth 2024-04-03 15:33:02 -06:00
parent cc33a90bfd
commit da6b6011fc
Signed by: simponic
GPG Key ID: 2909B9A7FF6213EE
1 changed files with 53 additions and 52 deletions

View File

@ -1,4 +1,4 @@
package hcdns package hcdns_test
import ( import (
"database/sql" "database/sql"
@ -16,7 +16,7 @@ import (
) )
func randomPort() int { func randomPort() int {
return rand.Intn(3000) + 1024 return rand.Intn(3000) + 5192
} }
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
@ -60,74 +60,75 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
defer cleanup() defer cleanup()
defer lock.Unlock() defer lock.Unlock()
cname := &database.DNSRecord{ records := []*database.DNSRecord{
ID: "1", {
ID: "0",
UserID: "test", UserID: "test",
Name: "cname.internal.example.com.", Name: "cname.internal.example.com.",
Type: "CNAME", Type: "CNAME",
Content: "next.internal.example.com.",
TTL: 300,
Internal: true,
}, {
ID: "1",
UserID: "test",
Name: "next.internal.example.com.",
Type: "CNAME",
Content: "res.example.com.", Content: "res.example.com.",
TTL: 300, TTL: 300,
Internal: true, Internal: true,
} },
a := &database.DNSRecord{ {
ID: "2", ID: "2",
UserID: "test", UserID: "test",
Name: "res.example.com.", Name: "res.example.com.",
Type: "A", Type: "A",
Content: "127.0.0.1", Content: "1.2.3.2",
TTL: 300, TTL: 300,
Internal: true, Internal: true,
},
}
for _, record := range records {
database.SaveDNSRecord(testDb, record)
} }
database.SaveDNSRecord(testDb, cname)
database.SaveDNSRecord(testDb, a)
qtype := dns.TypeA qtype := dns.TypeA
domain := dns.Fqdn(cname.Name) domain := dns.Fqdn("cname.internal.example.com.")
client := &dns.Client{} client := &dns.Client{}
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(in.Answer) != 2 { if len(in.Answer) != 3 {
t.Fatalf("expected 2 answers, got %d", len(in.Answer)) t.Fatalf("expected 3 answers, got %d", len(in.Answer))
} }
if in.Answer[0].Header().Name != cname.Name { for i, record := range records {
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) if in.Answer[i].Header().Name != record.Name {
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
} }
if in.Answer[1].Header().Name != a.Name { if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
} }
if in.Answer[0].(*dns.CNAME).Target != a.Name { if int(in.Answer[i].Header().Ttl) != record.TTL {
t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
}
if in.Answer[1].(*dns.A).A.String() != a.Content {
t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
}
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
}
if in.Answer[1].Header().Rrtype != dns.TypeA {
t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
}
if int(in.Answer[0].Header().Ttl) != cname.TTL {
t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
} }
if !in.Authoritative { if !in.Authoritative {
t.Fatalf("expected authoritative response") t.Fatalf("expected authoritative response")
} }
}
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
t.Fatalf("expected final record to be the A record with correct IP")
}
} }
func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) {