add integration tests for dns server
This commit is contained in:
parent
657be66948
commit
bcdcc508ef
|
@ -1,4 +1,4 @@
|
|||
package dns
|
||||
package hcdns
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
"log"
|
||||
)
|
||||
|
||||
const MAX_RECURSION = 10
|
||||
const MAX_RECURSION = 15
|
||||
|
||||
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
||||
|
@ -21,12 +21,14 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth
|
|||
for _, record := range internalCnames {
|
||||
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
answers = append(answers, cname)
|
||||
|
||||
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
answers = append(answers, cnameRecursive...)
|
||||
|
@ -62,13 +64,9 @@ func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dn
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if len(answers) > 0 {
|
||||
return answers, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no records found for %s", domain)
|
||||
}
|
||||
|
||||
type DnsHandler struct {
|
||||
DnsResolvers []string
|
||||
DbConn *sql.DB
|
||||
|
@ -83,7 +81,9 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
msg.SetRcode(r, dns.RcodeServerFailure)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
msg.Answer = append(msg.Answer, answers...)
|
||||
}
|
||||
|
@ -98,7 +98,6 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
|
||||
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
||||
handler := &DnsHandler{
|
||||
DnsResolvers: argv.DnsRecursion,
|
||||
DbConn: dbConn,
|
||||
}
|
||||
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
4
main.go
4
main.go
|
@ -6,7 +6,7 @@ import (
|
|||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
@ -52,7 +52,7 @@ func main() {
|
|||
}
|
||||
|
||||
if argv.Dns {
|
||||
server := dns.MakeServer(argv, dbConn)
|
||||
server := hcdns.MakeServer(argv, dbConn)
|
||||
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
||||
go func() {
|
||||
err = server.ListenAndServe()
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
package hcdns
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
testDBPath = "test.db"
|
||||
address = "127.0.0.1:8353"
|
||||
dnsPort = 8353
|
||||
)
|
||||
|
||||
func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) {
|
||||
testDb := database.MakeConn(&dbPath)
|
||||
database.Migrate(testDb)
|
||||
testUser := &database.User{
|
||||
ID: "test",
|
||||
}
|
||||
database.FindOrSaveUser(testDb, testUser)
|
||||
|
||||
server := hcdns.MakeServer(&args.Arguments{
|
||||
DnsPort: dnsPort,
|
||||
}, testDb)
|
||||
|
||||
waitGroup := sync.WaitGroup{}
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
server.ListenAndServe()
|
||||
waitGroup.Done()
|
||||
}()
|
||||
|
||||
return testDb, server, &waitGroup
|
||||
}
|
||||
|
||||
func destroy(conn *sql.DB, path string) {
|
||||
conn.Close()
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
func TestWhenCNAMEIsResolved(t *testing.T) {
|
||||
t.Log("TestWhenCNAMEIsResolved")
|
||||
|
||||
testDb, server, _ := setup(testDBPath)
|
||||
defer destroy(testDb, testDBPath)
|
||||
defer server.Shutdown()
|
||||
|
||||
cname := &database.DNSRecord{
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "res.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
a := &database.DNSRecord{
|
||||
ID: "2",
|
||||
UserID: "test",
|
||||
Name: "res.example.com.",
|
||||
Type: "A",
|
||||
Content: "127.0.0.1",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
database.SaveDNSRecord(testDb, cname)
|
||||
database.SaveDNSRecord(testDb, a)
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn(cname.Name)
|
||||
client := new(dns.Client)
|
||||
message := new(dns.Msg)
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, address)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 2 {
|
||||
t.Fatalf("expected 2 answers, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Name != cname.Name {
|
||||
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
||||
}
|
||||
|
||||
if in.Answer[1].Header().Name != a.Name {
|
||||
t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name)
|
||||
}
|
||||
|
||||
if in.Answer[0].(*dns.CNAME).Target != a.Name {
|
||||
t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
||||
}
|
||||
|
||||
if in.Answer[1].(*dns.A).A.String() != a.Content {
|
||||
t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
||||
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
||||
}
|
||||
|
||||
if in.Answer[1].Header().Rrtype != dns.TypeA {
|
||||
t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
|
||||
}
|
||||
|
||||
if int(in.Answer[0].Header().Ttl) != cname.TTL {
|
||||
t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
|
||||
}
|
||||
|
||||
if !in.Authoritative {
|
||||
t.Fatalf("expected authoritative response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenNoRecordNxDomain(t *testing.T) {
|
||||
t.Log("TestWhenNoRecordNxDomain")
|
||||
|
||||
testDb, server, _ := setup(testDBPath)
|
||||
defer destroy(testDb, testDBPath)
|
||||
defer server.Shutdown()
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn("nonexistant.example.com.")
|
||||
client := new(dns.Client)
|
||||
message := new(dns.Msg)
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, address)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 0 {
|
||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if in.Rcode != dns.RcodeNameError {
|
||||
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenUnresolvingCNAME(t *testing.T) {
|
||||
t.Log("TestWhenUnresolvingCNAME")
|
||||
|
||||
testDb, server, _ := setup(testDBPath)
|
||||
defer destroy(testDb, testDBPath)
|
||||
defer server.Shutdown()
|
||||
|
||||
cname := &database.DNSRecord{
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "nonexistant.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
database.SaveDNSRecord(testDb, cname)
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn(cname.Name)
|
||||
client := new(dns.Client)
|
||||
message := new(dns.Msg)
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, address)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 1 {
|
||||
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if !in.Authoritative {
|
||||
t.Fatalf("expected authoritative response")
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Name != cname.Name {
|
||||
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
||||
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
||||
}
|
||||
|
||||
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
|
||||
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
||||
}
|
||||
|
||||
if in.Rcode == dns.RcodeNameError {
|
||||
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
|
||||
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
|
||||
|
||||
testDb, server, _ := setup(testDBPath)
|
||||
defer destroy(testDb, testDBPath)
|
||||
defer server.Shutdown()
|
||||
|
||||
cname := &database.DNSRecord{
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "cname.internal.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
database.SaveDNSRecord(testDb, cname)
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn(cname.Name)
|
||||
client := new(dns.Client)
|
||||
message := new(dns.Msg)
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, address)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) > 0 {
|
||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||
}
|
||||
if in.Rcode != dns.RcodeServerFailure {
|
||||
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue