diff --git a/test/dns_test.go b/hcdns/server_test.go similarity index 72% rename from test/dns_test.go rename to hcdns/server_test.go index d875f3f..177def4 100644 --- a/test/dns_test.go +++ b/hcdns/server_test.go @@ -1,4 +1,4 @@ -package hcdns +package hcdns_test import ( "database/sql" @@ -16,7 +16,7 @@ import ( ) func randomPort() int { - return rand.Intn(3000) + 1024 + return rand.Intn(3000) + 5192 } func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { @@ -60,73 +60,74 @@ func TestWhenCNAMEIsResolved(t *testing.T) { defer cleanup() defer lock.Unlock() - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "res.example.com.", - TTL: 300, - Internal: true, + records := []*database.DNSRecord{ + { + ID: "0", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "next.internal.example.com.", + TTL: 300, + Internal: true, + }, { + ID: "1", + UserID: "test", + Name: "next.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + }, + { + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "1.2.3.2", + TTL: 300, + Internal: true, + }, } - a := &database.DNSRecord{ - ID: "2", - UserID: "test", - Name: "res.example.com.", - Type: "A", - Content: "127.0.0.1", - TTL: 300, - Internal: true, + + for _, record := range records { + database.SaveDNSRecord(testDb, record) } - database.SaveDNSRecord(testDb, cname) - database.SaveDNSRecord(testDb, a) qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) + domain := dns.Fqdn("cname.internal.example.com.") client := &dns.Client{} message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) - if err != nil { t.Fatal(err) } - if len(in.Answer) != 2 { - t.Fatalf("expected 2 answers, got %d", len(in.Answer)) + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) } - if in.Answer[0].Header().Name != cname.Name { - t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + for i, record := range records { + if in.Answer[i].Header().Name != record.Name { + t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) + } + + if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { + t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) + } + + if int(in.Answer[i].Header().Ttl) != record.TTL { + t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } } - if in.Answer[1].Header().Name != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) - } - - if in.Answer[0].(*dns.CNAME).Target != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) - } - - if in.Answer[1].(*dns.A).A.String() != a.Content { - t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) - } - - if in.Answer[0].Header().Rrtype != dns.TypeCNAME { - t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) - } - - if in.Answer[1].Header().Rrtype != dns.TypeA { - t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) - } - - if int(in.Answer[0].Header().Ttl) != cname.TTL { - t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) - } - - if !in.Authoritative { - t.Fatalf("expected authoritative response") + if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { + t.Fatalf("expected final record to be the A record with correct IP") } }