reimplement-recursive-resolver #6
			
				
			
		
		
		
	| 
						 | 
				
			
			@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
 | 
			
		|||
 | 
			
		||||
EXPOSE 8080
 | 
			
		||||
 | 
			
		||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
 | 
			
		||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-resolvers", "1.1.1.1,1.0.0.1"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										12
									
								
								args/args.go
								
								
								
								
							
							
						
						
									
										12
									
								
								args/args.go
								
								
								
								
							| 
						 | 
				
			
			@ -22,8 +22,9 @@ type Arguments struct {
 | 
			
		|||
	OauthConfig      *oauth2.Config
 | 
			
		||||
	OauthUserInfoURI string
 | 
			
		||||
 | 
			
		||||
	Dns     bool
 | 
			
		||||
	DnsPort int
 | 
			
		||||
	DnsResolvers []string
 | 
			
		||||
	Dns          bool
 | 
			
		||||
	DnsPort      int
 | 
			
		||||
 | 
			
		||||
	CloudflareToken string
 | 
			
		||||
	CloudflareZone  string
 | 
			
		||||
| 
						 | 
				
			
			@ -36,6 +37,7 @@ func GetArgs() (*Arguments, error) {
 | 
			
		|||
	databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database")
 | 
			
		||||
	templatePath := flag.String("template-path", "./templates", "Path to the template directory")
 | 
			
		||||
	staticPath := flag.String("static-path", "./static", "Path to the static directory")
 | 
			
		||||
	dnsResolvers := flag.String("dns-resolvers", "1.1.1.1,1.0.0.1", "Comma-separated list of DNS resolvers")
 | 
			
		||||
 | 
			
		||||
	scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron")
 | 
			
		||||
	migrate := flag.Bool("migrate", false, "Run the migrations")
 | 
			
		||||
| 
						 | 
				
			
			@ -101,8 +103,10 @@ func GetArgs() (*Arguments, error) {
 | 
			
		|||
		Server:          *server,
 | 
			
		||||
		Migrate:         *migrate,
 | 
			
		||||
		Scheduler:       *scheduler,
 | 
			
		||||
		Dns:             *dns,
 | 
			
		||||
		DnsPort:         *dnsPort,
 | 
			
		||||
 | 
			
		||||
		Dns:          *dns,
 | 
			
		||||
		DnsPort:      *dnsPort,
 | 
			
		||||
		DnsResolvers: strings.Split(*dnsResolvers, ","),
 | 
			
		||||
 | 
			
		||||
		OauthConfig:      oauthConfig,
 | 
			
		||||
		OauthUserInfoURI: oauthUserInfoURI,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										189
									
								
								hcdns/server.go
								
								
								
								
							
							
						
						
									
										189
									
								
								hcdns/server.go
								
								
								
								
							| 
						 | 
				
			
			@ -11,74 +11,142 @@ import (
 | 
			
		|||
 | 
			
		||||
const MAX_RECURSION = 15
 | 
			
		||||
 | 
			
		||||
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
 | 
			
		||||
	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var answers []dns.RR
 | 
			
		||||
	for _, record := range internalCnames {
 | 
			
		||||
		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, cname)
 | 
			
		||||
 | 
			
		||||
		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, cnameRecursive...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	qtypeName := dns.TypeToString[qtype]
 | 
			
		||||
	if qtypeName == "" {
 | 
			
		||||
		return nil, fmt.Errorf("invalid query type %d", qtype)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for _, record := range typeDnsRecords {
 | 
			
		||||
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, answer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return answers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
 | 
			
		||||
	if maxDepth == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("too much recursion")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return answers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DnsHandler struct {
 | 
			
		||||
	DnsResolvers []string
 | 
			
		||||
	DbConn       *sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
 | 
			
		||||
	client := &dns.Client{}
 | 
			
		||||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(dns.Fqdn(domain), qtype)
 | 
			
		||||
	message.RecursionDesired = true
 | 
			
		||||
 | 
			
		||||
	if len(h.DnsResolvers) == 0 {
 | 
			
		||||
		return []dns.RR{}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	i := 0
 | 
			
		||||
	in, _, err := client.Exchange(message, h.DnsResolvers[i])
 | 
			
		||||
	for err != nil && i < len(h.DnsResolvers) {
 | 
			
		||||
		i++
 | 
			
		||||
		in, _, err = client.Exchange(message, h.DnsResolvers[i])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(in.Answer) == 0 {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return in.Answer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool {
 | 
			
		||||
	for _, answer := range answers {
 | 
			
		||||
		if answer.Header().Name == domain && answer.Header().Rrtype == qtype {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
 | 
			
		||||
	internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, true, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	authoritative := true
 | 
			
		||||
	var answers []dns.RR
 | 
			
		||||
	for _, record := range internalCnames {
 | 
			
		||||
		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return nil, authoritative, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, cname)
 | 
			
		||||
 | 
			
		||||
		cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
 | 
			
		||||
		authoritative = authoritative && cnameAuth
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return nil, authoritative, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		answers = append(answers, cnameRecursive...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	qtypeName := dns.TypeToString[qtype]
 | 
			
		||||
	records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, authoritative, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, record := range records {
 | 
			
		||||
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, authoritative, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, answer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return answers, authoritative, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
 | 
			
		||||
	log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
 | 
			
		||||
	if maxDepth == 0 {
 | 
			
		||||
		return nil, false, fmt.Errorf("too much recursion")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(answers) > 0 { // base case - we got the answer
 | 
			
		||||
		return answers, authoritative, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	externalAnswers, err := h.resolveExternal(domain, qtype)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	answers = append(answers, externalAnswers...)
 | 
			
		||||
	if resultSetFound(externalAnswers, domain, qtype) {
 | 
			
		||||
		return answers, false, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, answer := range externalAnswers {
 | 
			
		||||
		cname, ok := answer.(*dns.CNAME)
 | 
			
		||||
		if !ok {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
 | 
			
		||||
		authoritative = authoritative && cnameAuth
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, false, err
 | 
			
		||||
		}
 | 
			
		||||
		answers = append(answers, cnameAnswers...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	authoritative = authoritative && len(externalAnswers) == 0
 | 
			
		||||
	return answers, authoritative, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 | 
			
		||||
	msg := new(dns.Msg)
 | 
			
		||||
	msg := &dns.Msg{}
 | 
			
		||||
	msg.SetReply(r)
 | 
			
		||||
	msg.Authoritative = true
 | 
			
		||||
	msg.Authoritative = false
 | 
			
		||||
 | 
			
		||||
	for _, question := range r.Question {
 | 
			
		||||
		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
 | 
			
		||||
		answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION)
 | 
			
		||||
		msg.Authoritative = authoritative
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			msg.SetRcode(r, dns.RcodeServerFailure)
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +166,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 | 
			
		|||
 | 
			
		||||
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
 | 
			
		||||
	handler := &DnsHandler{
 | 
			
		||||
		DbConn: dbConn,
 | 
			
		||||
		DbConn:       dbConn,
 | 
			
		||||
		DnsResolvers: argv.DnsResolvers,
 | 
			
		||||
	}
 | 
			
		||||
	addr := fmt.Sprintf(":%d", argv.DnsPort)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,9 +19,8 @@ func randomPort() int {
 | 
			
		|||
	return rand.Intn(3000) + 5192
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
 | 
			
		||||
func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
 | 
			
		||||
	randomDb := utils.RandomId()
 | 
			
		||||
	dnsPort := randomPort()
 | 
			
		||||
 | 
			
		||||
	testDb := database.MakeConn(&randomDb)
 | 
			
		||||
	database.Migrate(testDb)
 | 
			
		||||
| 
						 | 
				
			
			@ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
 | 
			
		|||
	}
 | 
			
		||||
	database.FindOrSaveUser(testDb, testUser)
 | 
			
		||||
 | 
			
		||||
	dnsArguments := arguments
 | 
			
		||||
	if dnsArguments == nil {
 | 
			
		||||
		dnsArguments = &args.Arguments{
 | 
			
		||||
			DnsPort: randomPort(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	waitLock := &sync.Mutex{}
 | 
			
		||||
	server := hcdns.MakeServer(&args.Arguments{
 | 
			
		||||
		DnsPort: dnsPort,
 | 
			
		||||
	}, testDb)
 | 
			
		||||
	server := hcdns.MakeServer(dnsArguments, testDb)
 | 
			
		||||
	server.NotifyStartedFunc = func() {
 | 
			
		||||
		waitLock.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
 | 
			
		|||
	}()
 | 
			
		||||
	waitLock.Lock()
 | 
			
		||||
 | 
			
		||||
	address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
 | 
			
		||||
	return testDb, server, &address, waitLock, func() {
 | 
			
		||||
	address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort)
 | 
			
		||||
	return testDb, server, address, func() {
 | 
			
		||||
		waitLock.Unlock()
 | 
			
		||||
		server.Shutdown()
 | 
			
		||||
 | 
			
		||||
		testDb.Close()
 | 
			
		||||
| 
						 | 
				
			
			@ -54,11 +59,8 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func TestWhenCNAMEIsResolved(t *testing.T) {
 | 
			
		||||
	t.Log("TestWhenCNAMEIsResolved")
 | 
			
		||||
 | 
			
		||||
	testDb, _, addr, lock, cleanup := setup()
 | 
			
		||||
	testDb, _, addr, cleanup := setup(nil)
 | 
			
		||||
	defer cleanup()
 | 
			
		||||
	defer lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	records := []*database.DNSRecord{
 | 
			
		||||
		{
 | 
			
		||||
| 
						 | 
				
			
			@ -99,7 +101,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
 | 
			
		|||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(domain, qtype)
 | 
			
		||||
 | 
			
		||||
	in, _, err := client.Exchange(message, *addr)
 | 
			
		||||
	in, _, err := client.Exchange(message, addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -132,11 +134,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func TestWhenNoRecordNxDomain(t *testing.T) {
 | 
			
		||||
	t.Log("TestWhenNoRecordNxDomain")
 | 
			
		||||
 | 
			
		||||
	_, _, addr, lock, cleanup := setup()
 | 
			
		||||
	_, _, addr, cleanup := setup(nil)
 | 
			
		||||
	defer cleanup()
 | 
			
		||||
	defer lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	qtype := dns.TypeA
 | 
			
		||||
	domain := dns.Fqdn("nonexistant.example.com.")
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +143,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
 | 
			
		|||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(domain, qtype)
 | 
			
		||||
 | 
			
		||||
	in, _, err := client.Exchange(message, *addr)
 | 
			
		||||
	in, _, err := client.Exchange(message, addr)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -160,11 +159,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func TestWhenUnresolvingCNAME(t *testing.T) {
 | 
			
		||||
	t.Log("TestWhenUnresolvingCNAME")
 | 
			
		||||
 | 
			
		||||
	testDb, _, addr, lock, cleanup := setup()
 | 
			
		||||
	testDb, _, addr, cleanup := setup(nil)
 | 
			
		||||
	defer cleanup()
 | 
			
		||||
	defer lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	cname := &database.DNSRecord{
 | 
			
		||||
		ID:       "1",
 | 
			
		||||
| 
						 | 
				
			
			@ -183,7 +179,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
 | 
			
		|||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(domain, qtype)
 | 
			
		||||
 | 
			
		||||
	in, _, err := client.Exchange(message, *addr)
 | 
			
		||||
	in, _, err := client.Exchange(message, addr)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -215,11 +211,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
 | 
			
		||||
	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
 | 
			
		||||
 | 
			
		||||
	testDb, _, addr, lock, cleanup := setup()
 | 
			
		||||
	testDb, _, addr, cleanup := setup(nil)
 | 
			
		||||
	defer cleanup()
 | 
			
		||||
	defer lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	cname := &database.DNSRecord{
 | 
			
		||||
		ID:       "1",
 | 
			
		||||
| 
						 | 
				
			
			@ -238,7 +231,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
 | 
			
		|||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(domain, qtype)
 | 
			
		||||
 | 
			
		||||
	in, _, err := client.Exchange(message, *addr)
 | 
			
		||||
	in, _, err := client.Exchange(message, addr)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -252,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
 | 
			
		|||
		t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWhenExternalDomain(t *testing.T) {
 | 
			
		||||
	externalDb, _, externalAddr, externalCleanup := setup(nil)
 | 
			
		||||
	internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
 | 
			
		||||
		DnsPort:      randomPort(),
 | 
			
		||||
		DnsResolvers: []string{externalAddr},
 | 
			
		||||
	})
 | 
			
		||||
	defer internalCleanup()
 | 
			
		||||
	defer externalCleanup()
 | 
			
		||||
 | 
			
		||||
	authoritativeRecords := []database.DNSRecord{
 | 
			
		||||
		{
 | 
			
		||||
			ID:      "1",
 | 
			
		||||
			UserID:  "test",
 | 
			
		||||
			Name:    "external.example.com.",
 | 
			
		||||
			Type:    "CNAME",
 | 
			
		||||
			Content: "external.internal.example.com.",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			ID:      "2",
 | 
			
		||||
			UserID:  "test",
 | 
			
		||||
			Name:    "final.example.com.",
 | 
			
		||||
			Type:    "A",
 | 
			
		||||
			Content: "127.0.0.1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	internalRecords := []database.DNSRecord{
 | 
			
		||||
		{
 | 
			
		||||
			ID:      "1",
 | 
			
		||||
			UserID:  "test",
 | 
			
		||||
			Name:    "external.internal.example.com.",
 | 
			
		||||
			Type:    "CNAME",
 | 
			
		||||
			Content: "final.example.com",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			ID:      "2",
 | 
			
		||||
			UserID:  "test",
 | 
			
		||||
			Name:    "test.internal.example.com.",
 | 
			
		||||
			Type:    "CNAME",
 | 
			
		||||
			Content: "external.example.com.",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, record := range authoritativeRecords {
 | 
			
		||||
		database.SaveDNSRecord(externalDb, &record)
 | 
			
		||||
	}
 | 
			
		||||
	for _, record := range internalRecords {
 | 
			
		||||
		database.SaveDNSRecord(internalDb, &record)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ensure that if the record doesn't exist in the internal database, it will
 | 
			
		||||
	// go and query the external dns resolvers, then loop back to the internal
 | 
			
		||||
 | 
			
		||||
	qtype := dns.TypeA
 | 
			
		||||
	domain := dns.Fqdn("test.internal.example.com.")
 | 
			
		||||
	client := &dns.Client{}
 | 
			
		||||
	message := &dns.Msg{}
 | 
			
		||||
	message.SetQuestion(domain, qtype)
 | 
			
		||||
 | 
			
		||||
	in, _, err := client.Exchange(message, internalAddr)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(in.Answer) != 4 {
 | 
			
		||||
		t.Fatalf("expected 4 answers, got %d", len(in.Answer))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	aRecord := in.Answer[3]
 | 
			
		||||
	if aRecord.Header().Name != authoritativeRecords[1].Name {
 | 
			
		||||
		t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
 | 
			
		||||
	}
 | 
			
		||||
	if aRecord.Header().Rrtype != dns.TypeA {
 | 
			
		||||
		t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
 | 
			
		||||
	}
 | 
			
		||||
	if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content {
 | 
			
		||||
		t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String())
 | 
			
		||||
	}
 | 
			
		||||
	if in.Authoritative {
 | 
			
		||||
		t.Fatalf("expected non-authoritative response")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue