hatecomputers.club/hcdns/server.go

182 lines
4.3 KiB
Go
Raw Normal View History

2024-04-02 18:26:39 -04:00
package hcdns
import (
"database/sql"
"fmt"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"github.com/miekg/dns"
"log"
)
2024-04-02 18:26:39 -04:00
// MAX_RECURSION is the depth budget ServeDNS hands to resolveDNS for each
// question. Every CNAME hop consumes one level, and resolveDNS fails with an
// error once the remaining depth reaches zero.
const MAX_RECURSION = 15
2024-04-07 19:04:43 -04:00
// DnsHandler answers DNS questions from the local record database and falls
// back to upstream resolvers for names it has no records for.
type DnsHandler struct {
// DnsResolvers lists the upstream servers handed to dns.Client.Exchange, in
// the order they are tried.
// NOTE(review): presumably "host:port" addresses — confirm against the
// argument parsing code.
DnsResolvers []string
// DbConn is the database handle passed to database.FindDNSRecords.
DbConn *sql.DB
}
func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(dns.Fqdn(domain), qtype)
message.RecursionDesired = true
if len(h.DnsResolvers) == 0 {
return []dns.RR{}, nil
}
i := 0
in, _, err := client.Exchange(message, h.DnsResolvers[i])
for err != nil && i < len(h.DnsResolvers) {
i++
in, _, err = client.Exchange(message, h.DnsResolvers[i])
}
if err != nil {
return nil, err
}
2024-04-07 19:04:43 -04:00
if len(in.Answer) == 0 {
return nil, nil
}
return in.Answer, nil
}
// resultSetFound reports whether answers holds at least one record whose
// owner name equals domain and whose type equals qtype. Names are compared
// with ==, so domain must be spelled exactly as in the record headers.
func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool {
	for i := range answers {
		hdr := answers[i].Header()
		if hdr.Rrtype != qtype {
			continue
		}
		if hdr.Name == domain {
			return true
		}
	}
	return false
}
func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
if err != nil {
return nil, true, err
}
authoritative := true
var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
2024-04-02 18:26:39 -04:00
log.Println(err)
2024-04-07 19:04:43 -04:00
return nil, authoritative, err
}
answers = append(answers, cname)
2024-04-07 19:04:43 -04:00
cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
authoritative = authoritative && cnameAuth
if err != nil {
2024-04-02 18:26:39 -04:00
log.Println(err)
2024-04-07 19:04:43 -04:00
return nil, authoritative, err
}
answers = append(answers, cnameRecursive...)
}
qtypeName := dns.TypeToString[qtype]
2024-04-07 19:04:43 -04:00
records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName)
if err != nil {
2024-04-07 19:04:43 -04:00
return nil, authoritative, err
}
2024-04-07 19:04:43 -04:00
for _, record := range records {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
2024-04-07 19:04:43 -04:00
return nil, authoritative, err
}
answers = append(answers, answer)
}
2024-04-07 19:04:43 -04:00
return answers, authoritative, nil
}
2024-04-07 19:04:43 -04:00
func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
if maxDepth == 0 {
2024-04-07 19:04:43 -04:00
return nil, false, fmt.Errorf("too much recursion")
}
2024-04-07 19:04:43 -04:00
answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth)
if err != nil {
2024-04-07 19:04:43 -04:00
return nil, false, err
}
2024-04-07 19:04:43 -04:00
if len(answers) > 0 { // base case - we got the answer
return answers, authoritative, nil
}
2024-04-07 19:04:43 -04:00
externalAnswers, err := h.resolveExternal(domain, qtype)
if err != nil {
return nil, false, err
}
answers = append(answers, externalAnswers...)
if resultSetFound(externalAnswers, domain, qtype) {
return answers, false, nil
}
for _, answer := range externalAnswers {
cname, ok := answer.(*dns.CNAME)
if !ok {
continue
}
cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
authoritative = authoritative && cnameAuth
2024-04-07 19:04:43 -04:00
if err != nil {
return nil, false, err
}
answers = append(answers, cnameAnswers...)
}
authoritative = authoritative && len(externalAnswers) == 0
return answers, authoritative, nil
}
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
2024-04-07 19:04:43 -04:00
msg := &dns.Msg{}
msg.SetReply(r)
2024-04-07 19:04:43 -04:00
msg.Authoritative = false
for _, question := range r.Question {
2024-04-07 19:04:43 -04:00
answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION)
msg.Authoritative = authoritative
if err != nil {
fmt.Println(err)
2024-04-02 18:26:39 -04:00
msg.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(msg)
return
}
msg.Answer = append(msg.Answer, answers...)
}
if len(msg.Answer) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
}
log.Println(msg.Answer)
w.WriteMsg(msg)
}
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{
2024-04-07 19:04:43 -04:00
DbConn: dbConn,
DnsResolvers: argv.DnsResolvers,
}
addr := fmt.Sprintf(":%d", argv.DnsPort)
return &dns.Server{
Addr: addr,
Net: "udp",
Handler: handler,
UDPSize: 65535,
ReusePort: true,
}
}