testing | dont be recursive for external domains | finalize oauth #5

Merged
simponic merged 24 commits from dont-be-authoritative into main 2024-04-06 15:43:19 -04:00
3 changed files with 29 additions and 35 deletions
Showing only changes of commit d7843d18d0 - Show all commits

View File

@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]

View File

@ -22,9 +22,8 @@ type Arguments struct {
OauthConfig *oauth2.Config
OauthUserInfoURI string
Dns bool
DnsRecursion []string
DnsPort int
Dns bool
DnsPort int
CloudflareToken string
CloudflareZone string
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver")
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse()
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate,
Scheduler: *scheduler,
Dns: *dns,
DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort,
OauthConfig: oauthConfig,

View File

@ -11,17 +11,13 @@ import (
const MAX_RECURSION = 10
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
}
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil {
return nil, err
}
answers := []dns.RR{}
var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
}
answers = append(answers, cname)
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
if err != nil {
return nil, err
}
answers = append(answers, cnameRecursive...)
}
@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err
}
for _, record := range typeDnsRecords {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
return nil, err
}
answers = append(answers, answer)
}
return answers, nil
}
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
}
answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
if err != nil {
return nil, err
}
if len(answers) > 0 {
// base case; we found the answer
return answers, nil
}
message := new(dns.Msg)
message.SetQuestion(dns.Fqdn(domain), qtype)
message.RecursionDesired = true
client := new(dns.Client)
i := 0
in, _, err := client.Exchange(message, dnsResolvers[i])
for err != nil {
i += 1
if i == len(dnsResolvers) {
log.Println(err)
return nil, err
}
in, _, err = client.Exchange(message, dnsResolvers[i])
}
answers = append(answers, in.Answer...)
return answers, nil
return nil, fmt.Errorf("no records found for %s", domain)
}
type DnsHandler struct {
@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true
for _, question := range r.Question {
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil {
fmt.Println(err)
continue
@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Answer = append(msg.Answer, answers...)
}
if len(msg.Answer) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
}
log.Println(msg.Answer)
w.WriteMsg(msg)
}