hatecomputers.club/dns/server.go

117 lines
2.7 KiB
Go

package dns
import (
"database/sql"
"fmt"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"github.com/miekg/dns"
"log"
)
// MAX_RECURSION is the depth budget ServeDNS hands to resolveRecursive for each
// question. Every locally stored CNAME that resolveRecursive follows consumes
// one level; once the budget reaches zero, resolveRecursive returns a
// "too much recursion" error instead of looking anything up.
const MAX_RECURSION = 10
// resolveRecursive answers one question (domain, qtype), preferring records
// stored in the local database and falling back to upstream resolvers.
//
// Local answers are assembled in this order:
//  1. Unless qtype is itself CNAME, every CNAME stored for domain, each
//     followed by the result of resolving its target with the same qtype.
//     Every hop consumes one level of maxDepth.
//  2. Every record of type qtype stored for domain.
//
// Only when the database yields nothing is the question forwarded upstream:
// dnsResolvers are tried in order and the answer section of the first
// response received without a transport error is returned.
//
// It returns an error when maxDepth reaches zero, when qtype has no known
// mnemonic, when a database lookup or record parse fails, or when no upstream
// resolver responds (including when dnsResolvers is empty).
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
	if maxDepth == 0 {
		return nil, fmt.Errorf("too much recursion")
	}
	// Validate the type up front so an unanswerable question does no
	// database work or CNAME chasing first.
	qtypeName := dns.TypeToString[qtype]
	if qtypeName == "" {
		return nil, fmt.Errorf("invalid query type %d", qtype)
	}

	var answers []dns.RR

	// A CNAME query asks for the alias record itself, which the typed lookup
	// below already returns; chasing aliases here as well would emit every
	// CNAME twice (plus records for its target).
	if qtype != dns.TypeCNAME {
		cnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
		if err != nil {
			return nil, fmt.Errorf("looking up CNAME records for %s: %w", domain, err)
		}
		for _, record := range cnames {
			cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
			if err != nil {
				return nil, fmt.Errorf("parsing CNAME record for %s: %w", record.Name, err)
			}
			answers = append(answers, cname)
			// Best effort: if the target cannot be resolved (upstream failure,
			// or an alias loop exhausting maxDepth), still return the CNAME
			// itself rather than failing the whole question.
			target, err := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
			if err == nil {
				answers = append(answers, target...)
			}
		}
	}

	records, err := database.FindDNSRecords(dbConn, domain, qtypeName)
	if err != nil {
		return nil, fmt.Errorf("looking up %s records for %s: %w", qtypeName, domain, err)
	}
	for _, record := range records {
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
		if err != nil {
			return nil, fmt.Errorf("parsing %s record for %s: %w", record.Type, record.Name, err)
		}
		answers = append(answers, answer)
	}
	if len(answers) > 0 {
		// Base case: the database answered the question.
		return answers, nil
	}

	message := new(dns.Msg)
	message.SetQuestion(dns.Fqdn(domain), qtype)
	message.RecursionDesired = true
	response, err := queryUpstream(message, dnsResolvers)
	if err != nil {
		return nil, fmt.Errorf("forwarding %s %s upstream: %w", domain, qtypeName, err)
	}
	return response.Answer, nil
}

// queryUpstream sends message to each resolver in turn and returns the first
// response obtained without a transport error. It reports an error (instead
// of panicking on an out-of-range index) when no resolvers are configured, and
// returns the last resolver's error when every resolver fails.
func queryUpstream(message *dns.Msg, dnsResolvers []string) (*dns.Msg, error) {
	if len(dnsResolvers) == 0 {
		return nil, fmt.Errorf("no upstream resolvers configured")
	}
	client := new(dns.Client)
	var lastErr error
	for _, resolver := range dnsResolvers {
		response, _, err := client.Exchange(message, resolver)
		if err == nil {
			return response, nil
		}
		lastErr = err
	}
	return nil, lastErr
}
// DnsHandler answers DNS queries (see ServeDNS) from records stored in a
// database, forwarding questions with no local records to upstream resolvers.
type DnsHandler struct {
// DnsResolvers are the upstream resolver addresses passed to
// dns.Client.Exchange, tried in order. MakeServer fills this from
// argv.DnsRecursion.
DnsResolvers []string
// DbConn is the database searched via database.FindDNSRecords.
DbConn *sql.DB
}
// ServeDNS resolves every question in r with resolveRecursive and writes a
// single reply marked authoritative that carries all answers found.
//
// If any question cannot be resolved, the error is logged and the reply's
// response code is set to SERVFAIL, so clients do not mistake the failure for
// an empty (NODATA) answer. Answers for the other questions are still
// included.
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	msg := new(dns.Msg)
	msg.SetReply(r)
	// NOTE(review): this flag is also set on replies whose answers came from
	// upstream resolvers — confirm that is intended.
	msg.Authoritative = true
	for _, question := range r.Question {
		answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
		if err != nil {
			log.Printf("resolving %s: %v", question.Name, err)
			msg.Rcode = dns.RcodeServerFailure
			continue
		}
		msg.Answer = append(msg.Answer, answers...)
	}
	log.Println(msg.Answer)
	if err := w.WriteMsg(msg); err != nil {
		log.Println(err)
	}
}
// MakeServer returns a UDP DNS server configured for address ":<argv.DnsPort>".
// It only builds the server value. Queries are answered by a DnsHandler that
// reads records through dbConn and uses argv.DnsRecursion as its upstream
// resolvers.
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
	listenAddr := fmt.Sprintf(":%d", argv.DnsPort)
	server := &dns.Server{
		Addr:      listenAddr,
		Net:       "udp",
		UDPSize:   65535,
		ReusePort: true,
	}
	server.Handler = &DnsHandler{
		DbConn:       dbConn,
		DnsResolvers: argv.DnsRecursion,
	}
	return server
}