testing | dont be recursive for external domains | finalize oauth #5

Merged
simponic merged 24 commits from dont-be-authoritative into main 2024-04-06 15:43:19 -04:00
3 changed files with 29 additions and 35 deletions
Showing only changes of commit d7843d18d0 - Show all commits

View File

@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080 EXPOSE 8080
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]

View File

@ -23,7 +23,6 @@ type Arguments struct {
OauthUserInfoURI string OauthUserInfoURI string
Dns bool Dns bool
DnsRecursion []string
DnsPort int DnsPort int
CloudflareToken string CloudflareToken string
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server") server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver") dns := flag.Bool("dns", false, "Run DNS resolver")
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse() flag.Parse()
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate, Migrate: *migrate,
Scheduler: *scheduler, Scheduler: *scheduler,
Dns: *dns, Dns: *dns,
DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort, DnsPort: *dnsPort,
OauthConfig: oauthConfig, OauthConfig: oauthConfig,

View File

@ -11,17 +11,13 @@ import (
const MAX_RECURSION = 10 const MAX_RECURSION = 10
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
}
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers := []dns.RR{} var answers []dns.RR
for _, record := range internalCnames { for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil { if err != nil {
@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
} }
answers = append(answers, cname) answers = append(answers, cname)
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
if err != nil {
return nil, err
}
answers = append(answers, cnameRecursive...) answers = append(answers, cnameRecursive...)
} }
@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err return nil, err
} }
for _, record := range typeDnsRecords { for _, record := range typeDnsRecords {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
if len(answers) > 0 {
// base case; we found the answer
return answers, nil return answers, nil
} }
message := new(dns.Msg) func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
message.SetQuestion(dns.Fqdn(domain), qtype) if maxDepth == 0 {
message.RecursionDesired = true return nil, fmt.Errorf("too much recursion")
}
client := new(dns.Client) answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
if err != nil {
i := 0
in, _, err := client.Exchange(message, dnsResolvers[i])
for err != nil {
i += 1
if i == len(dnsResolvers) {
log.Println(err)
return nil, err return nil, err
} }
in, _, err = client.Exchange(message, dnsResolvers[i])
if len(answers) > 0 {
return answers, nil
} }
answers = append(answers, in.Answer...) return nil, fmt.Errorf("no records found for %s", domain)
return answers, nil
} }
type DnsHandler struct { type DnsHandler struct {
@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true msg.Authoritative = true
for _, question := range r.Question { for _, question := range r.Question {
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
continue continue
@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Answer = append(msg.Answer, answers...) msg.Answer = append(msg.Answer, answers...)
} }
if len(msg.Answer) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
}
log.Println(msg.Answer) log.Println(msg.Answer)
w.WriteMsg(msg) w.WriteMsg(msg)
} }