Merge pull request 'reimplement-recursive-resolver' (#6) from reimplement-recursive-resolver into main
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	Reviewed-on: #6
This commit is contained in:
		
						commit
						86c4ad160a
					
				|  | @ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers | ||||||
| 
 | 
 | ||||||
| EXPOSE 8080 | EXPOSE 8080 | ||||||
| 
 | 
 | ||||||
| CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"] | CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-resolvers", "1.1.1.1,1.0.0.1"] | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								args/args.go
								
								
								
								
							
							
						
						
									
										12
									
								
								args/args.go
								
								
								
								
							|  | @ -22,8 +22,9 @@ type Arguments struct { | ||||||
| 	OauthConfig      *oauth2.Config | 	OauthConfig      *oauth2.Config | ||||||
| 	OauthUserInfoURI string | 	OauthUserInfoURI string | ||||||
| 
 | 
 | ||||||
| 	Dns     bool | 	DnsResolvers []string | ||||||
| 	DnsPort int | 	Dns          bool | ||||||
|  | 	DnsPort      int | ||||||
| 
 | 
 | ||||||
| 	CloudflareToken string | 	CloudflareToken string | ||||||
| 	CloudflareZone  string | 	CloudflareZone  string | ||||||
|  | @ -36,6 +37,7 @@ func GetArgs() (*Arguments, error) { | ||||||
| 	databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database") | 	databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database") | ||||||
| 	templatePath := flag.String("template-path", "./templates", "Path to the template directory") | 	templatePath := flag.String("template-path", "./templates", "Path to the template directory") | ||||||
| 	staticPath := flag.String("static-path", "./static", "Path to the static directory") | 	staticPath := flag.String("static-path", "./static", "Path to the static directory") | ||||||
|  | 	dnsResolvers := flag.String("dns-resolvers", "1.1.1.1,1.0.0.1", "Comma-separated list of DNS resolvers") | ||||||
| 
 | 
 | ||||||
| 	scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron") | 	scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron") | ||||||
| 	migrate := flag.Bool("migrate", false, "Run the migrations") | 	migrate := flag.Bool("migrate", false, "Run the migrations") | ||||||
|  | @ -101,8 +103,10 @@ func GetArgs() (*Arguments, error) { | ||||||
| 		Server:          *server, | 		Server:          *server, | ||||||
| 		Migrate:         *migrate, | 		Migrate:         *migrate, | ||||||
| 		Scheduler:       *scheduler, | 		Scheduler:       *scheduler, | ||||||
| 		Dns:             *dns, | 
 | ||||||
| 		DnsPort:         *dnsPort, | 		Dns:          *dns, | ||||||
|  | 		DnsPort:      *dnsPort, | ||||||
|  | 		DnsResolvers: strings.Split(*dnsResolvers, ","), | ||||||
| 
 | 
 | ||||||
| 		OauthConfig:      oauthConfig, | 		OauthConfig:      oauthConfig, | ||||||
| 		OauthUserInfoURI: oauthUserInfoURI, | 		OauthUserInfoURI: oauthUserInfoURI, | ||||||
|  |  | ||||||
							
								
								
									
										189
									
								
								hcdns/server.go
								
								
								
								
							
							
						
						
									
										189
									
								
								hcdns/server.go
								
								
								
								
							|  | @ -11,74 +11,142 @@ import ( | ||||||
| 
 | 
 | ||||||
| const MAX_RECURSION = 15 | const MAX_RECURSION = 15 | ||||||
| 
 | 
 | ||||||
| func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { |  | ||||||
| 	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var answers []dns.RR |  | ||||||
| 	for _, record := range internalCnames { |  | ||||||
| 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Println(err) |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		answers = append(answers, cname) |  | ||||||
| 
 |  | ||||||
| 		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Println(err) |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		answers = append(answers, cnameRecursive...) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	qtypeName := dns.TypeToString[qtype] |  | ||||||
| 	if qtypeName == "" { |  | ||||||
| 		return nil, fmt.Errorf("invalid query type %d", qtype) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	for _, record := range typeDnsRecords { |  | ||||||
| 		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		answers = append(answers, answer) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return answers, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { |  | ||||||
| 	if maxDepth == 0 { |  | ||||||
| 		return nil, fmt.Errorf("too much recursion") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return answers, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type DnsHandler struct { | type DnsHandler struct { | ||||||
| 	DnsResolvers []string | 	DnsResolvers []string | ||||||
| 	DbConn       *sql.DB | 	DbConn       *sql.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) { | ||||||
|  | 	client := &dns.Client{} | ||||||
|  | 	message := &dns.Msg{} | ||||||
|  | 	message.SetQuestion(dns.Fqdn(domain), qtype) | ||||||
|  | 	message.RecursionDesired = true | ||||||
|  | 
 | ||||||
|  | 	if len(h.DnsResolvers) == 0 { | ||||||
|  | 		return []dns.RR{}, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	i := 0 | ||||||
|  | 	in, _, err := client.Exchange(message, h.DnsResolvers[i]) | ||||||
|  | 	for err != nil && i < len(h.DnsResolvers) { | ||||||
|  | 		i++ | ||||||
|  | 		in, _, err = client.Exchange(message, h.DnsResolvers[i]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) == 0 { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return in.Answer, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool { | ||||||
|  | 	for _, answer := range answers { | ||||||
|  | 		if answer.Header().Name == domain && answer.Header().Rrtype == qtype { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { | ||||||
|  | 	internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, true, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	authoritative := true | ||||||
|  | 	var answers []dns.RR | ||||||
|  | 	for _, record := range internalCnames { | ||||||
|  | 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Println(err) | ||||||
|  | 			return nil, authoritative, err | ||||||
|  | 		} | ||||||
|  | 		answers = append(answers, cname) | ||||||
|  | 
 | ||||||
|  | 		cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) | ||||||
|  | 		authoritative = authoritative && cnameAuth | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Println(err) | ||||||
|  | 			return nil, authoritative, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		answers = append(answers, cnameRecursive...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	qtypeName := dns.TypeToString[qtype] | ||||||
|  | 	records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, authoritative, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, record := range records { | ||||||
|  | 		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, authoritative, err | ||||||
|  | 		} | ||||||
|  | 		answers = append(answers, answer) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return answers, authoritative, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { | ||||||
|  | 	log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth) | ||||||
|  | 	if maxDepth == 0 { | ||||||
|  | 		return nil, false, fmt.Errorf("too much recursion") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(answers) > 0 { // base case - we got the answer
 | ||||||
|  | 		return answers, authoritative, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	externalAnswers, err := h.resolveExternal(domain, qtype) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	answers = append(answers, externalAnswers...) | ||||||
|  | 	if resultSetFound(externalAnswers, domain, qtype) { | ||||||
|  | 		return answers, false, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, answer := range externalAnswers { | ||||||
|  | 		cname, ok := answer.(*dns.CNAME) | ||||||
|  | 		if !ok { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) | ||||||
|  | 		authoritative = authoritative && cnameAuth | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, false, err | ||||||
|  | 		} | ||||||
|  | 		answers = append(answers, cnameAnswers...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	authoritative = authoritative && len(externalAnswers) == 0 | ||||||
|  | 	return answers, authoritative, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||||
| 	msg := new(dns.Msg) | 	msg := &dns.Msg{} | ||||||
| 	msg.SetReply(r) | 	msg.SetReply(r) | ||||||
| 	msg.Authoritative = true | 	msg.Authoritative = false | ||||||
| 
 | 
 | ||||||
| 	for _, question := range r.Question { | 	for _, question := range r.Question { | ||||||
| 		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) | 		answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION) | ||||||
|  | 		msg.Authoritative = authoritative | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			fmt.Println(err) | 			fmt.Println(err) | ||||||
| 			msg.SetRcode(r, dns.RcodeServerFailure) | 			msg.SetRcode(r, dns.RcodeServerFailure) | ||||||
|  | @ -98,7 +166,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||||
| 
 | 
 | ||||||
| func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | ||||||
| 	handler := &DnsHandler{ | 	handler := &DnsHandler{ | ||||||
| 		DbConn: dbConn, | 		DbConn:       dbConn, | ||||||
|  | 		DnsResolvers: argv.DnsResolvers, | ||||||
| 	} | 	} | ||||||
| 	addr := fmt.Sprintf(":%d", argv.DnsPort) | 	addr := fmt.Sprintf(":%d", argv.DnsPort) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -19,9 +19,8 @@ func randomPort() int { | ||||||
| 	return rand.Intn(3000) + 5192 | 	return rand.Intn(3000) + 5192 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { | func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) { | ||||||
| 	randomDb := utils.RandomId() | 	randomDb := utils.RandomId() | ||||||
| 	dnsPort := randomPort() |  | ||||||
| 
 | 
 | ||||||
| 	testDb := database.MakeConn(&randomDb) | 	testDb := database.MakeConn(&randomDb) | ||||||
| 	database.Migrate(testDb) | 	database.Migrate(testDb) | ||||||
|  | @ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { | ||||||
| 	} | 	} | ||||||
| 	database.FindOrSaveUser(testDb, testUser) | 	database.FindOrSaveUser(testDb, testUser) | ||||||
| 
 | 
 | ||||||
|  | 	dnsArguments := arguments | ||||||
|  | 	if dnsArguments == nil { | ||||||
|  | 		dnsArguments = &args.Arguments{ | ||||||
|  | 			DnsPort: randomPort(), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	waitLock := &sync.Mutex{} | 	waitLock := &sync.Mutex{} | ||||||
| 	server := hcdns.MakeServer(&args.Arguments{ | 	server := hcdns.MakeServer(dnsArguments, testDb) | ||||||
| 		DnsPort: dnsPort, |  | ||||||
| 	}, testDb) |  | ||||||
| 	server.NotifyStartedFunc = func() { | 	server.NotifyStartedFunc = func() { | ||||||
| 		waitLock.Unlock() | 		waitLock.Unlock() | ||||||
| 	} | 	} | ||||||
|  | @ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { | ||||||
| 	}() | 	}() | ||||||
| 	waitLock.Lock() | 	waitLock.Lock() | ||||||
| 
 | 
 | ||||||
| 	address := fmt.Sprintf("127.0.0.1:%d", dnsPort) | 	address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort) | ||||||
| 	return testDb, server, &address, waitLock, func() { | 	return testDb, server, address, func() { | ||||||
|  | 		waitLock.Unlock() | ||||||
| 		server.Shutdown() | 		server.Shutdown() | ||||||
| 
 | 
 | ||||||
| 		testDb.Close() | 		testDb.Close() | ||||||
|  | @ -54,11 +59,8 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestWhenCNAMEIsResolved(t *testing.T) { | func TestWhenCNAMEIsResolved(t *testing.T) { | ||||||
| 	t.Log("TestWhenCNAMEIsResolved") | 	testDb, _, addr, cleanup := setup(nil) | ||||||
| 
 |  | ||||||
| 	testDb, _, addr, lock, cleanup := setup() |  | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	defer lock.Unlock() |  | ||||||
| 
 | 
 | ||||||
| 	records := []*database.DNSRecord{ | 	records := []*database.DNSRecord{ | ||||||
| 		{ | 		{ | ||||||
|  | @ -99,7 +101,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) { | ||||||
| 	message := &dns.Msg{} | 	message := &dns.Msg{} | ||||||
| 	message.SetQuestion(domain, qtype) | 	message.SetQuestion(domain, qtype) | ||||||
| 
 | 
 | ||||||
| 	in, _, err := client.Exchange(message, *addr) | 	in, _, err := client.Exchange(message, addr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | @ -132,11 +134,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestWhenNoRecordNxDomain(t *testing.T) { | func TestWhenNoRecordNxDomain(t *testing.T) { | ||||||
| 	t.Log("TestWhenNoRecordNxDomain") | 	_, _, addr, cleanup := setup(nil) | ||||||
| 
 |  | ||||||
| 	_, _, addr, lock, cleanup := setup() |  | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	defer lock.Unlock() |  | ||||||
| 
 | 
 | ||||||
| 	qtype := dns.TypeA | 	qtype := dns.TypeA | ||||||
| 	domain := dns.Fqdn("nonexistant.example.com.") | 	domain := dns.Fqdn("nonexistant.example.com.") | ||||||
|  | @ -144,7 +143,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) { | ||||||
| 	message := &dns.Msg{} | 	message := &dns.Msg{} | ||||||
| 	message.SetQuestion(domain, qtype) | 	message.SetQuestion(domain, qtype) | ||||||
| 
 | 
 | ||||||
| 	in, _, err := client.Exchange(message, *addr) | 	in, _, err := client.Exchange(message, addr) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|  | @ -160,11 +159,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestWhenUnresolvingCNAME(t *testing.T) { | func TestWhenUnresolvingCNAME(t *testing.T) { | ||||||
| 	t.Log("TestWhenUnresolvingCNAME") | 	testDb, _, addr, cleanup := setup(nil) | ||||||
| 
 |  | ||||||
| 	testDb, _, addr, lock, cleanup := setup() |  | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	defer lock.Unlock() |  | ||||||
| 
 | 
 | ||||||
| 	cname := &database.DNSRecord{ | 	cname := &database.DNSRecord{ | ||||||
| 		ID:       "1", | 		ID:       "1", | ||||||
|  | @ -183,7 +179,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) { | ||||||
| 	message := &dns.Msg{} | 	message := &dns.Msg{} | ||||||
| 	message.SetQuestion(domain, qtype) | 	message.SetQuestion(domain, qtype) | ||||||
| 
 | 
 | ||||||
| 	in, _, err := client.Exchange(message, *addr) | 	in, _, err := client.Exchange(message, addr) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|  | @ -215,11 +211,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||||
| 	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") | 	testDb, _, addr, cleanup := setup(nil) | ||||||
| 
 |  | ||||||
| 	testDb, _, addr, lock, cleanup := setup() |  | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	defer lock.Unlock() |  | ||||||
| 
 | 
 | ||||||
| 	cname := &database.DNSRecord{ | 	cname := &database.DNSRecord{ | ||||||
| 		ID:       "1", | 		ID:       "1", | ||||||
|  | @ -238,7 +231,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||||
| 	message := &dns.Msg{} | 	message := &dns.Msg{} | ||||||
| 	message.SetQuestion(domain, qtype) | 	message.SetQuestion(domain, qtype) | ||||||
| 
 | 
 | ||||||
| 	in, _, err := client.Exchange(message, *addr) | 	in, _, err := client.Exchange(message, addr) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|  | @ -252,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||||
| 		t.Fatalf("expected SERVFAIL, got %d", in.Rcode) | 		t.Fatalf("expected SERVFAIL, got %d", in.Rcode) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestWhenExternalDomain(t *testing.T) { | ||||||
|  | 	externalDb, _, externalAddr, externalCleanup := setup(nil) | ||||||
|  | 	internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{ | ||||||
|  | 		DnsPort:      randomPort(), | ||||||
|  | 		DnsResolvers: []string{externalAddr}, | ||||||
|  | 	}) | ||||||
|  | 	defer internalCleanup() | ||||||
|  | 	defer externalCleanup() | ||||||
|  | 
 | ||||||
|  | 	authoritativeRecords := []database.DNSRecord{ | ||||||
|  | 		{ | ||||||
|  | 			ID:      "1", | ||||||
|  | 			UserID:  "test", | ||||||
|  | 			Name:    "external.example.com.", | ||||||
|  | 			Type:    "CNAME", | ||||||
|  | 			Content: "external.internal.example.com.", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ID:      "2", | ||||||
|  | 			UserID:  "test", | ||||||
|  | 			Name:    "final.example.com.", | ||||||
|  | 			Type:    "A", | ||||||
|  | 			Content: "127.0.0.1", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	internalRecords := []database.DNSRecord{ | ||||||
|  | 		{ | ||||||
|  | 			ID:      "1", | ||||||
|  | 			UserID:  "test", | ||||||
|  | 			Name:    "external.internal.example.com.", | ||||||
|  | 			Type:    "CNAME", | ||||||
|  | 			Content: "final.example.com", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ID:      "2", | ||||||
|  | 			UserID:  "test", | ||||||
|  | 			Name:    "test.internal.example.com.", | ||||||
|  | 			Type:    "CNAME", | ||||||
|  | 			Content: "external.example.com.", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, record := range authoritativeRecords { | ||||||
|  | 		database.SaveDNSRecord(externalDb, &record) | ||||||
|  | 	} | ||||||
|  | 	for _, record := range internalRecords { | ||||||
|  | 		database.SaveDNSRecord(internalDb, &record) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// ensure that if the record doesn't exist in the internal database, it will
 | ||||||
|  | 	// go and query the external dns resolvers, then loop back to the internal
 | ||||||
|  | 
 | ||||||
|  | 	qtype := dns.TypeA | ||||||
|  | 	domain := dns.Fqdn("test.internal.example.com.") | ||||||
|  | 	client := &dns.Client{} | ||||||
|  | 	message := &dns.Msg{} | ||||||
|  | 	message.SetQuestion(domain, qtype) | ||||||
|  | 
 | ||||||
|  | 	in, _, err := client.Exchange(message, internalAddr) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(in.Answer) != 4 { | ||||||
|  | 		t.Fatalf("expected 4 answers, got %d", len(in.Answer)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	aRecord := in.Answer[3] | ||||||
|  | 	if aRecord.Header().Name != authoritativeRecords[1].Name { | ||||||
|  | 		t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name) | ||||||
|  | 	} | ||||||
|  | 	if aRecord.Header().Rrtype != dns.TypeA { | ||||||
|  | 		t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type) | ||||||
|  | 	} | ||||||
|  | 	if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content { | ||||||
|  | 		t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String()) | ||||||
|  | 	} | ||||||
|  | 	if in.Authoritative { | ||||||
|  | 		t.Fatalf("expected non-authoritative response") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue