From 385d4a84eb813ce6f777b6ab10642ad447f93321 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Tue, 2 Apr 2024 20:26:24 -0600 Subject: [PATCH] fix dns race condition --- .drone.yml | 7 ++++++- test/dns_test.go | 49 +++++++++++++++++++++++++++--------------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/.drone.yml b/.drone.yml index b96d25e..d056e69 100644 --- a/.drone.yml +++ b/.drone.yml @@ -12,7 +12,7 @@ steps: trigger: event: - - push + - pull_request --- kind: pipeline @@ -20,6 +20,11 @@ type: docker name: deploy steps: + - name: run tests + image: golang + commands: + - go build + - go test -p 1 -v ./... - name: docker image: plugins/docker settings: diff --git a/test/dns_test.go b/test/dns_test.go index 55bb060..2caabe4 100644 --- a/test/dns_test.go +++ b/test/dns_test.go @@ -21,10 +21,10 @@ func destroy(conn *sql.DB, path string) { } func randomPort() int { - return rand.Intn(3000) + 10000 + return rand.Intn(3000) + 1024 } -func setup() (*sql.DB, *dns.Server, int, *string, func()) { +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { randomDb := utils.RandomId() dnsPort := randomPort() @@ -35,32 +35,35 @@ func setup() (*sql.DB, *dns.Server, int, *string, func()) { } database.FindOrSaveUser(testDb, testUser) + waitLock := &sync.Mutex{} server := hcdns.MakeServer(&args.Arguments{ DnsPort: dnsPort, }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() - waitGroup := sync.WaitGroup{} - waitGroup.Add(1) go func() { server.ListenAndServe() - waitGroup.Done() }() + waitLock.Lock() address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, dnsPort, &address, func() { + return testDb, server, &address, waitLock, func() { + server.Shutdown() + testDb.Close() os.Remove(randomDb) - - server.Shutdown() - waitGroup.Wait() } } func TestWhenCNAMEIsResolved(t *testing.T) { t.Log("TestWhenCNAMEIsResolved") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -85,8 +88,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -135,13 +138,14 @@ func TestWhenCNAMEIsResolved(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) { t.Log("TestWhenNoRecordNxDomain") - _, _, _, addr, cleanup := setup() + _, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -162,8 +166,9 @@ func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenUnresolvingCNAME(t *testing.T) { t.Log("TestWhenUnresolvingCNAME") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -178,8 +183,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -216,8 +221,9 @@ func TestWhenUnresolvingCNAME(t *testing.T) { func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -232,8 +238,8 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -245,6 +251,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { if len(in.Answer) > 0 { t.Fatalf("expected 0 answers, got %d", len(in.Answer)) } + if in.Rcode != dns.RcodeServerFailure { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) }