325 lines
7.4 KiB
Go
325 lines
7.4 KiB
Go
package hcdns_test
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"math/rand"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// randomPort returns a pseudo-random port in the half-open range
// [5192, 8192) for a test DNS server to listen on. Ports are not checked for
// availability, so two calls may return the same value.
func randomPort() int {
	const (
		basePort  = 5192
		portSpread = 3000
	)
	return basePort + rand.Intn(portSpread)
}
|
|
|
|
func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
|
|
randomDb := utils.RandomId()
|
|
|
|
testDb := database.MakeConn(&randomDb)
|
|
database.Migrate(testDb)
|
|
testUser := &database.User{
|
|
ID: "test",
|
|
}
|
|
database.FindOrSaveUser(testDb, testUser)
|
|
|
|
dnsArguments := arguments
|
|
if dnsArguments == nil {
|
|
dnsArguments = &args.Arguments{
|
|
DnsPort: randomPort(),
|
|
}
|
|
}
|
|
|
|
waitLock := &sync.Mutex{}
|
|
server := hcdns.MakeServer(dnsArguments, testDb)
|
|
server.NotifyStartedFunc = func() {
|
|
waitLock.Unlock()
|
|
}
|
|
waitLock.Lock()
|
|
|
|
go func() {
|
|
server.ListenAndServe()
|
|
}()
|
|
waitLock.Lock()
|
|
|
|
address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort)
|
|
return testDb, server, address, func() {
|
|
waitLock.Unlock()
|
|
server.Shutdown()
|
|
|
|
testDb.Close()
|
|
os.Remove(randomDb)
|
|
}
|
|
}
|
|
|
|
func TestWhenExternalDomain(t *testing.T) {
|
|
externalDb, _, externalAddr, externalCleanup := setup(nil)
|
|
internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
|
|
DnsPort: randomPort(),
|
|
DnsResolvers: []string{externalAddr},
|
|
})
|
|
defer internalCleanup()
|
|
defer externalCleanup()
|
|
|
|
authoritativeRecords := []database.DNSRecord{
|
|
{
|
|
ID: "1",
|
|
UserID: "test",
|
|
Name: "external.example.com.",
|
|
Type: "CNAME",
|
|
Content: "external.internal.example.com.",
|
|
},
|
|
}
|
|
internalRecords := []database.DNSRecord{
|
|
{
|
|
ID: "1",
|
|
UserID: "test",
|
|
Name: "external.internal.example.com.",
|
|
Type: "A",
|
|
Content: "127.0.0.1",
|
|
},
|
|
{
|
|
ID: "2",
|
|
UserID: "test",
|
|
Name: "test.internal.example.com.",
|
|
Type: "CNAME",
|
|
Content: "external.example.com.",
|
|
},
|
|
}
|
|
|
|
for _, record := range authoritativeRecords {
|
|
database.SaveDNSRecord(externalDb, &record)
|
|
}
|
|
for _, record := range internalRecords {
|
|
database.SaveDNSRecord(internalDb, &record)
|
|
}
|
|
|
|
// ensure that if the record doesn't exist in the internal database, it will
|
|
// go and query the external dns resolvers, then loop back to the internal
|
|
|
|
qtype := dns.TypeA
|
|
domain := dns.Fqdn("test.internal.example.com.")
|
|
client := &dns.Client{}
|
|
message := &dns.Msg{}
|
|
message.SetQuestion(domain, qtype)
|
|
|
|
in, _, err := client.Exchange(message, internalAddr)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(in.Answer) != 3 {
|
|
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
|
}
|
|
|
|
aRecord := in.Answer[2]
|
|
if aRecord.Header().Name != internalRecords[0].Name {
|
|
t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
|
|
}
|
|
if aRecord.Header().Rrtype != dns.TypeA {
|
|
t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
|
|
}
|
|
if aRecord.(*dns.A).A.String() != internalRecords[0].Content {
|
|
t.Fatalf("expected %s, got %s", internalRecords[0].Content, aRecord.(*dns.A).A.String())
|
|
}
|
|
|
|
if in.Authoritative {
|
|
t.Fatalf("expected non-authoritative response")
|
|
}
|
|
}
|
|
|
|
func TestWhenCNAMEIsResolved(t *testing.T) {
|
|
testDb, _, addr, cleanup := setup(nil)
|
|
defer cleanup()
|
|
|
|
records := []*database.DNSRecord{
|
|
{
|
|
ID: "0",
|
|
UserID: "test",
|
|
Name: "cname.internal.example.com.",
|
|
Type: "CNAME",
|
|
Content: "next.internal.example.com.",
|
|
TTL: 300,
|
|
Internal: true,
|
|
}, {
|
|
ID: "1",
|
|
UserID: "test",
|
|
Name: "next.internal.example.com.",
|
|
Type: "CNAME",
|
|
Content: "res.example.com.",
|
|
TTL: 300,
|
|
Internal: true,
|
|
},
|
|
{
|
|
ID: "2",
|
|
UserID: "test",
|
|
Name: "res.example.com.",
|
|
Type: "A",
|
|
Content: "1.2.3.2",
|
|
TTL: 300,
|
|
Internal: true,
|
|
},
|
|
}
|
|
|
|
for _, record := range records {
|
|
database.SaveDNSRecord(testDb, record)
|
|
}
|
|
|
|
qtype := dns.TypeA
|
|
domain := dns.Fqdn("cname.internal.example.com.")
|
|
client := &dns.Client{}
|
|
message := &dns.Msg{}
|
|
message.SetQuestion(domain, qtype)
|
|
|
|
in, _, err := client.Exchange(message, addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(in.Answer) != 3 {
|
|
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
|
}
|
|
|
|
for i, record := range records {
|
|
if in.Answer[i].Header().Name != record.Name {
|
|
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
|
|
}
|
|
|
|
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
|
|
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
|
|
}
|
|
|
|
if int(in.Answer[i].Header().Ttl) != record.TTL {
|
|
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
|
|
}
|
|
|
|
if !in.Authoritative {
|
|
t.Fatalf("expected authoritative response")
|
|
}
|
|
}
|
|
|
|
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
|
|
t.Fatalf("expected final record to be the A record with correct IP")
|
|
}
|
|
}
|
|
|
|
func TestWhenNoRecordNxDomain(t *testing.T) {
|
|
_, _, addr, cleanup := setup(nil)
|
|
defer cleanup()
|
|
|
|
qtype := dns.TypeA
|
|
domain := dns.Fqdn("nonexistant.example.com.")
|
|
client := &dns.Client{}
|
|
message := &dns.Msg{}
|
|
message.SetQuestion(domain, qtype)
|
|
|
|
in, _, err := client.Exchange(message, addr)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(in.Answer) != 0 {
|
|
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
|
}
|
|
|
|
if in.Rcode != dns.RcodeNameError {
|
|
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
|
|
}
|
|
}
|
|
|
|
func TestWhenUnresolvingCNAME(t *testing.T) {
|
|
testDb, _, addr, cleanup := setup(nil)
|
|
defer cleanup()
|
|
|
|
cname := &database.DNSRecord{
|
|
ID: "1",
|
|
UserID: "test",
|
|
Name: "cname.internal.example.com.",
|
|
Type: "CNAME",
|
|
Content: "nonexistant.example.com.",
|
|
TTL: 300,
|
|
Internal: true,
|
|
}
|
|
database.SaveDNSRecord(testDb, cname)
|
|
|
|
qtype := dns.TypeA
|
|
domain := dns.Fqdn(cname.Name)
|
|
client := &dns.Client{}
|
|
message := &dns.Msg{}
|
|
message.SetQuestion(domain, qtype)
|
|
|
|
in, _, err := client.Exchange(message, addr)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(in.Answer) != 1 {
|
|
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
|
|
}
|
|
|
|
if !in.Authoritative {
|
|
t.Fatalf("expected authoritative response")
|
|
}
|
|
|
|
if in.Answer[0].Header().Name != cname.Name {
|
|
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
|
}
|
|
|
|
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
|
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
|
}
|
|
|
|
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
|
|
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
|
}
|
|
|
|
if in.Rcode == dns.RcodeNameError {
|
|
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
|
|
}
|
|
}
|
|
|
|
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
|
|
testDb, _, addr, cleanup := setup(nil)
|
|
defer cleanup()
|
|
|
|
cname := &database.DNSRecord{
|
|
ID: "1",
|
|
UserID: "test",
|
|
Name: "cname.internal.example.com.",
|
|
Type: "CNAME",
|
|
Content: "cname.internal.example.com.",
|
|
TTL: 300,
|
|
Internal: true,
|
|
}
|
|
database.SaveDNSRecord(testDb, cname)
|
|
|
|
qtype := dns.TypeA
|
|
domain := dns.Fqdn(cname.Name)
|
|
client := &dns.Client{}
|
|
message := &dns.Msg{}
|
|
message.SetQuestion(domain, qtype)
|
|
|
|
in, _, err := client.Exchange(message, addr)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(in.Answer) > 0 {
|
|
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
|
}
|
|
|
|
if in.Rcode != dns.RcodeServerFailure {
|
|
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
|
|
}
|
|
}
|