add integration tests for dns server
This commit is contained in:
		
							parent
							
								
									657be66948
								
							
						
					
					
						commit
						bcdcc508ef
					
				|  | @ -1,4 +1,4 @@ | |||
| package dns | ||||
| package hcdns | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
|  | @ -9,7 +9,7 @@ import ( | |||
| 	"log" | ||||
| ) | ||||
| 
 | ||||
| const MAX_RECURSION = 10 | ||||
| const MAX_RECURSION = 15 | ||||
| 
 | ||||
| func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | ||||
| 	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") | ||||
|  | @ -21,12 +21,14 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth | |||
| 	for _, record := range internalCnames { | ||||
| 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		answers = append(answers, cname) | ||||
| 
 | ||||
| 		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		answers = append(answers, cnameRecursive...) | ||||
|  | @ -62,13 +64,9 @@ func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dn | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if len(answers) > 0 { | ||||
| 	return answers, nil | ||||
| } | ||||
| 
 | ||||
| 	return nil, fmt.Errorf("no records found for %s", domain) | ||||
| } | ||||
| 
 | ||||
| type DnsHandler struct { | ||||
| 	DnsResolvers []string | ||||
| 	DbConn       *sql.DB | ||||
|  | @ -83,7 +81,9 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||
| 		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) | ||||
| 		if err != nil { | ||||
| 			fmt.Println(err) | ||||
| 			continue | ||||
| 			msg.SetRcode(r, dns.RcodeServerFailure) | ||||
| 			w.WriteMsg(msg) | ||||
| 			return | ||||
| 		} | ||||
| 		msg.Answer = append(msg.Answer, answers...) | ||||
| 	} | ||||
|  | @ -98,7 +98,6 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||
| 
 | ||||
| func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | ||||
| 	handler := &DnsHandler{ | ||||
| 		DnsResolvers: argv.DnsRecursion, | ||||
| 		DbConn: dbConn, | ||||
| 	} | ||||
| 	addr := fmt.Sprintf(":%d", argv.DnsPort) | ||||
							
								
								
									
										4
									
								
								main.go
								
								
								
								
							
							
						
						
									
										4
									
								
								main.go
								
								
								
								
							|  | @ -6,7 +6,7 @@ import ( | |||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" | ||||
| 	"github.com/joho/godotenv" | ||||
| ) | ||||
|  | @ -52,7 +52,7 @@ func main() { | |||
| 	} | ||||
| 
 | ||||
| 	if argv.Dns { | ||||
| 		server := dns.MakeServer(argv, dbConn) | ||||
| 		server := hcdns.MakeServer(argv, dbConn) | ||||
| 		log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) | ||||
| 		go func() { | ||||
| 			err = server.ListenAndServe() | ||||
|  |  | |||
|  | @ -0,0 +1,244 @@ | |||
| package hcdns | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	testDBPath = "test.db" | ||||
| 	address    = "127.0.0.1:8353" | ||||
| 	dnsPort    = 8353 | ||||
| ) | ||||
| 
 | ||||
| func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { | ||||
| 	testDb := database.MakeConn(&dbPath) | ||||
| 	database.Migrate(testDb) | ||||
| 	testUser := &database.User{ | ||||
| 		ID: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(testDb, testUser) | ||||
| 
 | ||||
| 	server := hcdns.MakeServer(&args.Arguments{ | ||||
| 		DnsPort: dnsPort, | ||||
| 	}, testDb) | ||||
| 
 | ||||
| 	waitGroup := sync.WaitGroup{} | ||||
| 	waitGroup.Add(1) | ||||
| 	go func() { | ||||
| 		server.ListenAndServe() | ||||
| 		waitGroup.Done() | ||||
| 	}() | ||||
| 
 | ||||
| 	return testDb, server, &waitGroup | ||||
| } | ||||
| 
 | ||||
| func destroy(conn *sql.DB, path string) { | ||||
| 	conn.Close() | ||||
| 	os.Remove(path) | ||||
| } | ||||
| 
 | ||||
| func TestWhenCNAMEIsResolved(t *testing.T) { | ||||
| 	t.Log("TestWhenCNAMEIsResolved") | ||||
| 
 | ||||
| 	testDb, server, _ := setup(testDBPath) | ||||
| 	defer destroy(testDb, testDBPath) | ||||
| 	defer server.Shutdown() | ||||
| 
 | ||||
| 	cname := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "cname.internal.example.com.", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "res.example.com.", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	a := &database.DNSRecord{ | ||||
| 		ID:       "2", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "res.example.com.", | ||||
| 		Type:     "A", | ||||
| 		Content:  "127.0.0.1", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(testDb, cname) | ||||
| 	database.SaveDNSRecord(testDb, a) | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn(cname.Name) | ||||
| 	client := new(dns.Client) | ||||
| 	message := new(dns.Msg) | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, address) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 2 { | ||||
| 		t.Fatalf("expected 2 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Name != cname.Name { | ||||
| 		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[1].Header().Name != a.Name { | ||||
| 		t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].(*dns.CNAME).Target != a.Name { | ||||
| 		t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[1].(*dns.A).A.String() != a.Content { | ||||
| 		t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Rrtype != dns.TypeCNAME { | ||||
| 		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[1].Header().Rrtype != dns.TypeA { | ||||
| 		t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) | ||||
| 	} | ||||
| 
 | ||||
| 	if int(in.Answer[0].Header().Ttl) != cname.TTL { | ||||
| 		t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) | ||||
| 	} | ||||
| 
 | ||||
| 	if !in.Authoritative { | ||||
| 		t.Fatalf("expected authoritative response") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenNoRecordNxDomain(t *testing.T) { | ||||
| 	t.Log("TestWhenNoRecordNxDomain") | ||||
| 
 | ||||
| 	testDb, server, _ := setup(testDBPath) | ||||
| 	defer destroy(testDb, testDBPath) | ||||
| 	defer server.Shutdown() | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn("nonexistant.example.com.") | ||||
| 	client := new(dns.Client) | ||||
| 	message := new(dns.Msg) | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, address) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 0 { | ||||
| 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Rcode != dns.RcodeNameError { | ||||
| 		t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenUnresolvingCNAME(t *testing.T) { | ||||
| 	t.Log("TestWhenUnresolvingCNAME") | ||||
| 
 | ||||
| 	testDb, server, _ := setup(testDBPath) | ||||
| 	defer destroy(testDb, testDBPath) | ||||
| 	defer server.Shutdown() | ||||
| 
 | ||||
| 	cname := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "cname.internal.example.com.", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "nonexistant.example.com.", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(testDb, cname) | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn(cname.Name) | ||||
| 	client := new(dns.Client) | ||||
| 	message := new(dns.Msg) | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, address) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 1 { | ||||
| 		t.Fatalf("expected 1 answer, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if !in.Authoritative { | ||||
| 		t.Fatalf("expected authoritative response") | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Name != cname.Name { | ||||
| 		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Rrtype != dns.TypeCNAME { | ||||
| 		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].(*dns.CNAME).Target != cname.Content { | ||||
| 		t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Rcode == dns.RcodeNameError { | ||||
| 		t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||
| 	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") | ||||
| 
 | ||||
| 	testDb, server, _ := setup(testDBPath) | ||||
| 	defer destroy(testDb, testDBPath) | ||||
| 	defer server.Shutdown() | ||||
| 
 | ||||
| 	cname := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "cname.internal.example.com.", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "cname.internal.example.com.", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(testDb, cname) | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn(cname.Name) | ||||
| 	client := new(dns.Client) | ||||
| 	message := new(dns.Msg) | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, address) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) > 0 { | ||||
| 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 	if in.Rcode != dns.RcodeServerFailure { | ||||
| 		t.Fatalf("expected SERVFAIL, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
		Reference in New Issue