Compare commits

..

No commits in common. "86c4ad160a0442713680ff1eaa85ead635b10f8f" and "83cc6267fd5ce2f61200314424c5f400f65ff2ba" have entirely different histories.

4 changed files with 59 additions and 208 deletions

View File

@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080 EXPOSE 8080
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-resolvers", "1.1.1.1,1.0.0.1"] CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]

View File

@ -22,7 +22,6 @@ type Arguments struct {
OauthConfig *oauth2.Config OauthConfig *oauth2.Config
OauthUserInfoURI string OauthUserInfoURI string
DnsResolvers []string
Dns bool Dns bool
DnsPort int DnsPort int
@ -37,7 +36,6 @@ func GetArgs() (*Arguments, error) {
databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database") databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database")
templatePath := flag.String("template-path", "./templates", "Path to the template directory") templatePath := flag.String("template-path", "./templates", "Path to the template directory")
staticPath := flag.String("static-path", "./static", "Path to the static directory") staticPath := flag.String("static-path", "./static", "Path to the static directory")
dnsResolvers := flag.String("dns-resolvers", "1.1.1.1,1.0.0.1", "Comma-separated list of DNS resolvers")
scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron") scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron")
migrate := flag.Bool("migrate", false, "Run the migrations") migrate := flag.Bool("migrate", false, "Run the migrations")
@ -103,10 +101,8 @@ func GetArgs() (*Arguments, error) {
Server: *server, Server: *server,
Migrate: *migrate, Migrate: *migrate,
Scheduler: *scheduler, Scheduler: *scheduler,
Dns: *dns, Dns: *dns,
DnsPort: *dnsPort, DnsPort: *dnsPort,
DnsResolvers: strings.Split(*dnsResolvers, ","),
OauthConfig: oauthConfig, OauthConfig: oauthConfig,
OauthUserInfoURI: oauthUserInfoURI, OauthUserInfoURI: oauthUserInfoURI,

View File

@ -11,142 +11,74 @@ import (
const MAX_RECURSION = 15 const MAX_RECURSION = 15
type DnsHandler struct { func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
DnsResolvers []string internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
DbConn *sql.DB
}
func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(dns.Fqdn(domain), qtype)
message.RecursionDesired = true
if len(h.DnsResolvers) == 0 {
return []dns.RR{}, nil
}
i := 0
in, _, err := client.Exchange(message, h.DnsResolvers[i])
for err != nil && i < len(h.DnsResolvers) {
i++
in, _, err = client.Exchange(message, h.DnsResolvers[i])
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(in.Answer) == 0 {
return nil, nil
}
return in.Answer, nil
}
func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool {
for _, answer := range answers {
if answer.Header().Name == domain && answer.Header().Rrtype == qtype {
return true
}
}
return false
}
func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
if err != nil {
return nil, true, err
}
authoritative := true
var answers []dns.RR var answers []dns.RR
for _, record := range internalCnames { for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return nil, authoritative, err return nil, err
} }
answers = append(answers, cname) answers = append(answers, cname)
cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
authoritative = authoritative && cnameAuth
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return nil, authoritative, err return nil, err
} }
answers = append(answers, cnameRecursive...) answers = append(answers, cnameRecursive...)
} }
qtypeName := dns.TypeToString[qtype] qtypeName := dns.TypeToString[qtype]
records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName) if qtypeName == "" {
if err != nil { return nil, fmt.Errorf("invalid query type %d", qtype)
return nil, authoritative, err
} }
for _, record := range records { typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
if err != nil {
return nil, err
}
for _, record := range typeDnsRecords {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil { if err != nil {
return nil, authoritative, err return nil, err
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
return answers, authoritative, nil return answers, nil
} }
func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
if maxDepth == 0 { if maxDepth == 0 {
return nil, false, fmt.Errorf("too much recursion") return nil, fmt.Errorf("too much recursion")
} }
answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth) answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
if len(answers) > 0 { // base case - we got the answer return answers, nil
return answers, authoritative, nil }
}
externalAnswers, err := h.resolveExternal(domain, qtype) type DnsHandler struct {
if err != nil { DnsResolvers []string
return nil, false, err DbConn *sql.DB
}
answers = append(answers, externalAnswers...)
if resultSetFound(externalAnswers, domain, qtype) {
return answers, false, nil
}
for _, answer := range externalAnswers {
cname, ok := answer.(*dns.CNAME)
if !ok {
continue
}
cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
authoritative = authoritative && cnameAuth
if err != nil {
return nil, false, err
}
answers = append(answers, cnameAnswers...)
}
authoritative = authoritative && len(externalAnswers) == 0
return answers, authoritative, nil
} }
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg := &dns.Msg{} msg := new(dns.Msg)
msg.SetReply(r) msg.SetReply(r)
msg.Authoritative = false msg.Authoritative = true
for _, question := range r.Question { for _, question := range r.Question {
answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION) answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
msg.Authoritative = authoritative
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
msg.SetRcode(r, dns.RcodeServerFailure) msg.SetRcode(r, dns.RcodeServerFailure)
@ -167,7 +99,6 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{ handler := &DnsHandler{
DbConn: dbConn, DbConn: dbConn,
DnsResolvers: argv.DnsResolvers,
} }
addr := fmt.Sprintf(":%d", argv.DnsPort) addr := fmt.Sprintf(":%d", argv.DnsPort)

View File

@ -19,8 +19,9 @@ func randomPort() int {
return rand.Intn(3000) + 5192 return rand.Intn(3000) + 5192
} }
func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) { func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
randomDb := utils.RandomId() randomDb := utils.RandomId()
dnsPort := randomPort()
testDb := database.MakeConn(&randomDb) testDb := database.MakeConn(&randomDb)
database.Migrate(testDb) database.Migrate(testDb)
@ -29,15 +30,10 @@ func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
} }
database.FindOrSaveUser(testDb, testUser) database.FindOrSaveUser(testDb, testUser)
dnsArguments := arguments
if dnsArguments == nil {
dnsArguments = &args.Arguments{
DnsPort: randomPort(),
}
}
waitLock := &sync.Mutex{} waitLock := &sync.Mutex{}
server := hcdns.MakeServer(dnsArguments, testDb) server := hcdns.MakeServer(&args.Arguments{
DnsPort: dnsPort,
}, testDb)
server.NotifyStartedFunc = func() { server.NotifyStartedFunc = func() {
waitLock.Unlock() waitLock.Unlock()
} }
@ -48,9 +44,8 @@ func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
}() }()
waitLock.Lock() waitLock.Lock()
address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort) address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
return testDb, server, address, func() { return testDb, server, &address, waitLock, func() {
waitLock.Unlock()
server.Shutdown() server.Shutdown()
testDb.Close() testDb.Close()
@ -59,8 +54,11 @@ func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
} }
func TestWhenCNAMEIsResolved(t *testing.T) { func TestWhenCNAMEIsResolved(t *testing.T) {
testDb, _, addr, cleanup := setup(nil) t.Log("TestWhenCNAMEIsResolved")
testDb, _, addr, lock, cleanup := setup()
defer cleanup() defer cleanup()
defer lock.Unlock()
records := []*database.DNSRecord{ records := []*database.DNSRecord{
{ {
@ -101,7 +99,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -134,8 +132,11 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
} }
func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) {
_, _, addr, cleanup := setup(nil) t.Log("TestWhenNoRecordNxDomain")
_, _, addr, lock, cleanup := setup()
defer cleanup() defer cleanup()
defer lock.Unlock()
qtype := dns.TypeA qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.") domain := dns.Fqdn("nonexistant.example.com.")
@ -143,7 +144,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -159,8 +160,11 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
} }
func TestWhenUnresolvingCNAME(t *testing.T) { func TestWhenUnresolvingCNAME(t *testing.T) {
testDb, _, addr, cleanup := setup(nil) t.Log("TestWhenUnresolvingCNAME")
testDb, _, addr, lock, cleanup := setup()
defer cleanup() defer cleanup()
defer lock.Unlock()
cname := &database.DNSRecord{ cname := &database.DNSRecord{
ID: "1", ID: "1",
@ -179,7 +183,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -211,8 +215,11 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
} }
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
testDb, _, addr, cleanup := setup(nil) t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
testDb, _, addr, lock, cleanup := setup()
defer cleanup() defer cleanup()
defer lock.Unlock()
cname := &database.DNSRecord{ cname := &database.DNSRecord{
ID: "1", ID: "1",
@ -231,7 +238,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -245,86 +252,3 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode) t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
} }
} }
func TestWhenExternalDomain(t *testing.T) {
externalDb, _, externalAddr, externalCleanup := setup(nil)
internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
DnsPort: randomPort(),
DnsResolvers: []string{externalAddr},
})
defer internalCleanup()
defer externalCleanup()
authoritativeRecords := []database.DNSRecord{
{
ID: "1",
UserID: "test",
Name: "external.example.com.",
Type: "CNAME",
Content: "external.internal.example.com.",
},
{
ID: "2",
UserID: "test",
Name: "final.example.com.",
Type: "A",
Content: "127.0.0.1",
},
}
internalRecords := []database.DNSRecord{
{
ID: "1",
UserID: "test",
Name: "external.internal.example.com.",
Type: "CNAME",
Content: "final.example.com",
},
{
ID: "2",
UserID: "test",
Name: "test.internal.example.com.",
Type: "CNAME",
Content: "external.example.com.",
},
}
for _, record := range authoritativeRecords {
database.SaveDNSRecord(externalDb, &record)
}
for _, record := range internalRecords {
database.SaveDNSRecord(internalDb, &record)
}
// ensure that if the record doesn't exist in the internal database, it will
// go and query the external dns resolvers, then loop back to the internal
qtype := dns.TypeA
domain := dns.Fqdn("test.internal.example.com.")
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, internalAddr)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 4 {
t.Fatalf("expected 4 answers, got %d", len(in.Answer))
}
aRecord := in.Answer[3]
if aRecord.Header().Name != authoritativeRecords[1].Name {
t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
}
if aRecord.Header().Rrtype != dns.TypeA {
t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
}
if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content {
t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String())
}
if in.Authoritative {
t.Fatalf("expected non-authoritative response")
}
}