add integration tests for dns server
This commit is contained in:
		
							parent
							
								
									657be66948
								
							
						
					
					
						commit
						bcdcc508ef
					
				|  | @ -1,4 +1,4 @@ | ||||||
| package dns | package hcdns | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
|  | @ -9,7 +9,7 @@ import ( | ||||||
| 	"log" | 	"log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const MAX_RECURSION = 10 | const MAX_RECURSION = 15 | ||||||
| 
 | 
 | ||||||
| func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | ||||||
| 	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") | 	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") | ||||||
|  | @ -21,12 +21,14 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth | ||||||
| 	for _, record := range internalCnames { | 	for _, record := range internalCnames { | ||||||
| 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) | 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Println(err) | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		answers = append(answers, cname) | 		answers = append(answers, cname) | ||||||
| 
 | 
 | ||||||
| 		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) | 		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Println(err) | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		answers = append(answers, cnameRecursive...) | 		answers = append(answers, cnameRecursive...) | ||||||
|  | @ -62,13 +64,9 @@ func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dn | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(answers) > 0 { |  | ||||||
| 	return answers, nil | 	return answers, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 	return nil, fmt.Errorf("no records found for %s", domain) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type DnsHandler struct { | type DnsHandler struct { | ||||||
| 	DnsResolvers []string | 	DnsResolvers []string | ||||||
| 	DbConn       *sql.DB | 	DbConn       *sql.DB | ||||||
|  | @ -83,7 +81,9 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||||
| 		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) | 		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			fmt.Println(err) | 			fmt.Println(err) | ||||||
| 			continue | 			msg.SetRcode(r, dns.RcodeServerFailure) | ||||||
|  | 			w.WriteMsg(msg) | ||||||
|  | 			return | ||||||
| 		} | 		} | ||||||
| 		msg.Answer = append(msg.Answer, answers...) | 		msg.Answer = append(msg.Answer, answers...) | ||||||
| 	} | 	} | ||||||
|  | @ -98,7 +98,6 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||||
| 
 | 
 | ||||||
| func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | ||||||
| 	handler := &DnsHandler{ | 	handler := &DnsHandler{ | ||||||
| 		DnsResolvers: argv.DnsRecursion, |  | ||||||
| 		DbConn: dbConn, | 		DbConn: dbConn, | ||||||
| 	} | 	} | ||||||
| 	addr := fmt.Sprintf(":%d", argv.DnsPort) | 	addr := fmt.Sprintf(":%d", argv.DnsPort) | ||||||
							
								
								
									
										4
									
								
								main.go
								
								
								
								
							
							
						
						
									
										4
									
								
								main.go
								
								
								
								
							|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" | ||||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" | ||||||
| 	"github.com/joho/godotenv" | 	"github.com/joho/godotenv" | ||||||
| ) | ) | ||||||
|  | @ -52,7 +52,7 @@ func main() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if argv.Dns { | 	if argv.Dns { | ||||||
| 		server := dns.MakeServer(argv, dbConn) | 		server := hcdns.MakeServer(argv, dbConn) | ||||||
| 		log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) | 		log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) | ||||||
| 		go func() { | 		go func() { | ||||||
| 			err = server.ListenAndServe() | 			err = server.ListenAndServe() | ||||||
|  |  | ||||||
|  | @ -0,0 +1,244 @@ | ||||||
|  | package hcdns | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"os" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||||
|  | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||||
|  | 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	testDBPath = "test.db" | ||||||
|  | 	address    = "127.0.0.1:8353" | ||||||
|  | 	dnsPort    = 8353 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { | ||||||
|  | 	testDb := database.MakeConn(&dbPath) | ||||||
|  | 	database.Migrate(testDb) | ||||||
|  | 	testUser := &database.User{ | ||||||
|  | 		ID: "test", | ||||||
|  | 	} | ||||||
|  | 	database.FindOrSaveUser(testDb, testUser) | ||||||
|  | 
 | ||||||
|  | 	server := hcdns.MakeServer(&args.Arguments{ | ||||||
|  | 		DnsPort: dnsPort, | ||||||
|  | 	}, testDb) | ||||||
|  | 
 | ||||||
|  | 	waitGroup := sync.WaitGroup{} | ||||||
|  | 	waitGroup.Add(1) | ||||||
|  | 	go func() { | ||||||
|  | 		server.ListenAndServe() | ||||||
|  | 		waitGroup.Done() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	return testDb, server, &waitGroup | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func destroy(conn *sql.DB, path string) { | ||||||
|  | 	conn.Close() | ||||||
|  | 	os.Remove(path) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWhenCNAMEIsResolved(t *testing.T) { | ||||||
|  | 	t.Log("TestWhenCNAMEIsResolved") | ||||||
|  | 
 | ||||||
|  | 	testDb, server, _ := setup(testDBPath) | ||||||
|  | 	defer destroy(testDb, testDBPath) | ||||||
|  | 	defer server.Shutdown() | ||||||
|  | 
 | ||||||
|  | 	cname := &database.DNSRecord{ | ||||||
|  | 		ID:       "1", | ||||||
|  | 		UserID:   "test", | ||||||
|  | 		Name:     "cname.internal.example.com.", | ||||||
|  | 		Type:     "CNAME", | ||||||
|  | 		Content:  "res.example.com.", | ||||||
|  | 		TTL:      300, | ||||||
|  | 		Internal: true, | ||||||
|  | 	} | ||||||
|  | 	a := &database.DNSRecord{ | ||||||
|  | 		ID:       "2", | ||||||
|  | 		UserID:   "test", | ||||||
|  | 		Name:     "res.example.com.", | ||||||
|  | 		Type:     "A", | ||||||
|  | 		Content:  "127.0.0.1", | ||||||
|  | 		TTL:      300, | ||||||
|  | 		Internal: true, | ||||||
|  | 	} | ||||||
|  | 	database.SaveDNSRecord(testDb, cname) | ||||||
|  | 	database.SaveDNSRecord(testDb, a) | ||||||
|  | 
 | ||||||
|  | 	qtype := dns.TypeA | ||||||
|  | 	domain := dns.Fqdn(cname.Name) | ||||||
|  | 	client := new(dns.Client) | ||||||
|  | 	message := new(dns.Msg) | ||||||
|  | 	message.SetQuestion(domain, qtype) | ||||||
|  | 
 | ||||||
|  | 	in, _, err := client.Exchange(message, address) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) != 2 { | ||||||
|  | 		t.Fatalf("expected 2 answers, got %d", len(in.Answer)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].Header().Name != cname.Name { | ||||||
|  | 		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[1].Header().Name != a.Name { | ||||||
|  | 		t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].(*dns.CNAME).Target != a.Name { | ||||||
|  | 		t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[1].(*dns.A).A.String() != a.Content { | ||||||
|  | 		t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].Header().Rrtype != dns.TypeCNAME { | ||||||
|  | 		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[1].Header().Rrtype != dns.TypeA { | ||||||
|  | 		t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if int(in.Answer[0].Header().Ttl) != cname.TTL { | ||||||
|  | 		t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !in.Authoritative { | ||||||
|  | 		t.Fatalf("expected authoritative response") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWhenNoRecordNxDomain(t *testing.T) { | ||||||
|  | 	t.Log("TestWhenNoRecordNxDomain") | ||||||
|  | 
 | ||||||
|  | 	testDb, server, _ := setup(testDBPath) | ||||||
|  | 	defer destroy(testDb, testDBPath) | ||||||
|  | 	defer server.Shutdown() | ||||||
|  | 
 | ||||||
|  | 	qtype := dns.TypeA | ||||||
|  | 	domain := dns.Fqdn("nonexistant.example.com.") | ||||||
|  | 	client := new(dns.Client) | ||||||
|  | 	message := new(dns.Msg) | ||||||
|  | 	message.SetQuestion(domain, qtype) | ||||||
|  | 
 | ||||||
|  | 	in, _, err := client.Exchange(message, address) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) != 0 { | ||||||
|  | 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Rcode != dns.RcodeNameError { | ||||||
|  | 		t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWhenUnresolvingCNAME(t *testing.T) { | ||||||
|  | 	t.Log("TestWhenUnresolvingCNAME") | ||||||
|  | 
 | ||||||
|  | 	testDb, server, _ := setup(testDBPath) | ||||||
|  | 	defer destroy(testDb, testDBPath) | ||||||
|  | 	defer server.Shutdown() | ||||||
|  | 
 | ||||||
|  | 	cname := &database.DNSRecord{ | ||||||
|  | 		ID:       "1", | ||||||
|  | 		UserID:   "test", | ||||||
|  | 		Name:     "cname.internal.example.com.", | ||||||
|  | 		Type:     "CNAME", | ||||||
|  | 		Content:  "nonexistant.example.com.", | ||||||
|  | 		TTL:      300, | ||||||
|  | 		Internal: true, | ||||||
|  | 	} | ||||||
|  | 	database.SaveDNSRecord(testDb, cname) | ||||||
|  | 
 | ||||||
|  | 	qtype := dns.TypeA | ||||||
|  | 	domain := dns.Fqdn(cname.Name) | ||||||
|  | 	client := new(dns.Client) | ||||||
|  | 	message := new(dns.Msg) | ||||||
|  | 	message.SetQuestion(domain, qtype) | ||||||
|  | 
 | ||||||
|  | 	in, _, err := client.Exchange(message, address) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) != 1 { | ||||||
|  | 		t.Fatalf("expected 1 answer, got %d", len(in.Answer)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !in.Authoritative { | ||||||
|  | 		t.Fatalf("expected authoritative response") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].Header().Name != cname.Name { | ||||||
|  | 		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].Header().Rrtype != dns.TypeCNAME { | ||||||
|  | 		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Answer[0].(*dns.CNAME).Target != cname.Content { | ||||||
|  | 		t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if in.Rcode == dns.RcodeNameError { | ||||||
|  | 		t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||||
|  | 	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") | ||||||
|  | 
 | ||||||
|  | 	testDb, server, _ := setup(testDBPath) | ||||||
|  | 	defer destroy(testDb, testDBPath) | ||||||
|  | 	defer server.Shutdown() | ||||||
|  | 
 | ||||||
|  | 	cname := &database.DNSRecord{ | ||||||
|  | 		ID:       "1", | ||||||
|  | 		UserID:   "test", | ||||||
|  | 		Name:     "cname.internal.example.com.", | ||||||
|  | 		Type:     "CNAME", | ||||||
|  | 		Content:  "cname.internal.example.com.", | ||||||
|  | 		TTL:      300, | ||||||
|  | 		Internal: true, | ||||||
|  | 	} | ||||||
|  | 	database.SaveDNSRecord(testDb, cname) | ||||||
|  | 
 | ||||||
|  | 	qtype := dns.TypeA | ||||||
|  | 	domain := dns.Fqdn(cname.Name) | ||||||
|  | 	client := new(dns.Client) | ||||||
|  | 	message := new(dns.Msg) | ||||||
|  | 	message.SetQuestion(domain, qtype) | ||||||
|  | 
 | ||||||
|  | 	in, _, err := client.Exchange(message, address) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) > 0 { | ||||||
|  | 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||||
|  | 	} | ||||||
|  | 	if in.Rcode != dns.RcodeServerFailure { | ||||||
|  | 		t.Fatalf("expected SERVFAIL, got %d", in.Rcode) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue