Compare commits

..

4 Commits

Author SHA1 Message Date
Elizabeth 07c272b809
add test step to ci
continuous-integration/drone/pr Build is failing Details
2024-04-02 16:32:03 -06:00
Elizabeth bcdcc508ef
add integration tests for dns server 2024-04-02 16:26:39 -06:00
Elizabeth 657be66948
defer body close after encoding json 2024-04-02 14:53:50 -06:00
Elizabeth d7843d18d0
stop being authoritative for stuff not in internal dns 2024-04-02 14:49:18 -06:00
7 changed files with 297 additions and 47 deletions

View File

@ -1,9 +1,15 @@
--- ---
kind: pipeline kind: pipeline
type: docker type: docker
name: build, publish docker image, deploy name: deployment
steps: steps:
- name: run tests
image: golang
commands:
- go build
- go test -v ./...
- name: docker - name: docker
image: plugins/docker image: plugins/docker
settings: settings:
@ -13,9 +19,10 @@ steps:
from_secret: gitea_packpub_password from_secret: gitea_packpub_password
registry: git.hatecomputers.club registry: git.hatecomputers.club
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
tags: when:
- latest branch:
- main - main
- name: ssh - name: ssh
image: appleboy/drone-ssh image: appleboy/drone-ssh
settings: settings:
@ -27,6 +34,10 @@ steps:
command_timeout: 2m command_timeout: 2m
script: script:
- systemctl restart docker-compose@hatecomputers-club - systemctl restart docker-compose@hatecomputers-club
when:
branch:
- main
trigger: trigger:
branch: branch:
- main - main

View File

@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080 EXPOSE 8080
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]

View File

@ -259,11 +259,13 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
} }
func createUserFromResponse(response *http.Response) (*database.User, error) { func createUserFromResponse(response *http.Response) (*database.User, error) {
defer response.Body.Close()
user := &database.User{ user := &database.User{
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
err := json.NewDecoder(response.Body).Decode(user) err := json.NewDecoder(response.Body).Decode(user)
defer response.Body.Close()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return nil, err return nil, err

View File

@ -23,7 +23,6 @@ type Arguments struct {
OauthUserInfoURI string OauthUserInfoURI string
Dns bool Dns bool
DnsRecursion []string
DnsPort int DnsPort int
CloudflareToken string CloudflareToken string
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server") server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver") dns := flag.Bool("dns", false, "Run DNS resolver")
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse() flag.Parse()
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate, Migrate: *migrate,
Scheduler: *scheduler, Scheduler: *scheduler,
Dns: *dns, Dns: *dns,
DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort, DnsPort: *dnsPort,
OauthConfig: oauthConfig, OauthConfig: oauthConfig,

View File

@ -1,4 +1,4 @@
package dns package hcdns
import ( import (
"database/sql" "database/sql"
@ -9,27 +9,28 @@ import (
"log" "log"
) )
const MAX_RECURSION = 10 const MAX_RECURSION = 15
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
}
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers := []dns.RR{} var answers []dns.RR
for _, record := range internalCnames { for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil { if err != nil {
log.Println(err)
return nil, err return nil, err
} }
answers = append(answers, cname) answers = append(answers, cname)
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
if err != nil {
log.Println(err)
return nil, err
}
answers = append(answers, cnameRecursive...) answers = append(answers, cnameRecursive...)
} }
@ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err return nil, err
} }
for _, record := range typeDnsRecords { for _, record := range typeDnsRecords {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
if len(answers) > 0 {
// base case; we found the answer
return answers, nil return answers, nil
} }
message := new(dns.Msg) func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
message.SetQuestion(dns.Fqdn(domain), qtype) if maxDepth == 0 {
message.RecursionDesired = true return nil, fmt.Errorf("too much recursion")
}
client := new(dns.Client) answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
if err != nil {
i := 0
in, _, err := client.Exchange(message, dnsResolvers[i])
for err != nil {
i += 1
if i == len(dnsResolvers) {
log.Println(err)
return nil, err return nil, err
} }
in, _, err = client.Exchange(message, dnsResolvers[i])
}
answers = append(answers, in.Answer...)
return answers, nil return answers, nil
} }
@ -87,21 +78,26 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true msg.Authoritative = true
for _, question := range r.Question { for _, question := range r.Question {
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
continue msg.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(msg)
return
} }
msg.Answer = append(msg.Answer, answers...) msg.Answer = append(msg.Answer, answers...)
} }
if len(msg.Answer) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
}
log.Println(msg.Answer) log.Println(msg.Answer)
w.WriteMsg(msg) w.WriteMsg(msg)
} }
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{ handler := &DnsHandler{
DnsResolvers: argv.DnsRecursion,
DbConn: dbConn, DbConn: dbConn,
} }
addr := fmt.Sprintf(":%d", argv.DnsPort) addr := fmt.Sprintf(":%d", argv.DnsPort)

View File

@ -6,7 +6,7 @@ import (
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" "git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
@ -52,7 +52,7 @@ func main() {
} }
if argv.Dns { if argv.Dns {
server := dns.MakeServer(argv, dbConn) server := hcdns.MakeServer(argv, dbConn)
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
go func() { go func() {
err = server.ListenAndServe() err = server.ListenAndServe()

244
test/dns_test.go Normal file
View File

@ -0,0 +1,244 @@
package hcdns
import (
"database/sql"
"os"
"sync"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"github.com/miekg/dns"
)
const (
testDBPath = "test.db"
address = "127.0.0.1:8353"
dnsPort = 8353
)
func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) {
testDb := database.MakeConn(&dbPath)
database.Migrate(testDb)
testUser := &database.User{
ID: "test",
}
database.FindOrSaveUser(testDb, testUser)
server := hcdns.MakeServer(&args.Arguments{
DnsPort: dnsPort,
}, testDb)
waitGroup := sync.WaitGroup{}
waitGroup.Add(1)
go func() {
server.ListenAndServe()
waitGroup.Done()
}()
return testDb, server, &waitGroup
}
func destroy(conn *sql.DB, path string) {
conn.Close()
os.Remove(path)
}
func TestWhenCNAMEIsResolved(t *testing.T) {
t.Log("TestWhenCNAMEIsResolved")
testDb, server, _ := setup(testDBPath)
defer destroy(testDb, testDBPath)
defer server.Shutdown()
cname := &database.DNSRecord{
ID: "1",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "res.example.com.",
TTL: 300,
Internal: true,
}
a := &database.DNSRecord{
ID: "2",
UserID: "test",
Name: "res.example.com.",
Type: "A",
Content: "127.0.0.1",
TTL: 300,
Internal: true,
}
database.SaveDNSRecord(testDb, cname)
database.SaveDNSRecord(testDb, a)
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
client := new(dns.Client)
message := new(dns.Msg)
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, address)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 2 {
t.Fatalf("expected 2 answers, got %d", len(in.Answer))
}
if in.Answer[0].Header().Name != cname.Name {
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
}
if in.Answer[1].Header().Name != a.Name {
t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name)
}
if in.Answer[0].(*dns.CNAME).Target != a.Name {
t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
}
if in.Answer[1].(*dns.A).A.String() != a.Content {
t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
}
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
}
if in.Answer[1].Header().Rrtype != dns.TypeA {
t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
}
if int(in.Answer[0].Header().Ttl) != cname.TTL {
t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
}
if !in.Authoritative {
t.Fatalf("expected authoritative response")
}
}
func TestWhenNoRecordNxDomain(t *testing.T) {
t.Log("TestWhenNoRecordNxDomain")
testDb, server, _ := setup(testDBPath)
defer destroy(testDb, testDBPath)
defer server.Shutdown()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
client := new(dns.Client)
message := new(dns.Msg)
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, address)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
if in.Rcode != dns.RcodeNameError {
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
}
}
func TestWhenUnresolvingCNAME(t *testing.T) {
t.Log("TestWhenUnresolvingCNAME")
testDb, server, _ := setup(testDBPath)
defer destroy(testDb, testDBPath)
defer server.Shutdown()
cname := &database.DNSRecord{
ID: "1",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "nonexistant.example.com.",
TTL: 300,
Internal: true,
}
database.SaveDNSRecord(testDb, cname)
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
client := new(dns.Client)
message := new(dns.Msg)
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, address)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 1 {
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
}
if !in.Authoritative {
t.Fatalf("expected authoritative response")
}
if in.Answer[0].Header().Name != cname.Name {
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
}
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
}
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
}
if in.Rcode == dns.RcodeNameError {
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
}
}
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
testDb, server, _ := setup(testDBPath)
defer destroy(testDb, testDBPath)
defer server.Shutdown()
cname := &database.DNSRecord{
ID: "1",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "cname.internal.example.com.",
TTL: 300,
Internal: true,
}
database.SaveDNSRecord(testDb, cname)
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
client := new(dns.Client)
message := new(dns.Msg)
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, address)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) > 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
if in.Rcode != dns.RcodeServerFailure {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}
}