Compare commits

..

6 Commits

Author SHA1 Message Date
Elizabeth e398cf0540
checkpoint to save work; had to get on the bus
continuous-integration/drone/pr Build is failing Details
2024-04-03 16:22:19 -06:00
Elizabeth b74a955dcb
add guestbook tests 2024-04-03 16:07:40 -06:00
Elizabeth 8c7d9b3762
dont always 200 on template render 2024-04-03 15:59:12 -06:00
Elizabeth 47cc8feefa
rename auth redirect login name 2024-04-03 15:58:44 -06:00
Elizabeth da6b6011fc
refactor dns server test a bit 2024-04-03 15:33:02 -06:00
Elizabeth cc33a90bfd
abstract dns adapter 2024-04-03 14:27:55 -06:00
12 changed files with 496 additions and 243 deletions

View File

@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
Result database.DNSRecord `json:"result"` Result database.DNSRecord `json:"result"`
} }
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { type CloudflareExternalDNSAdapter struct {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) ZoneId string
APIToken string
}
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
payload := strings.NewReader(reqBody) payload := strings.NewReader(reqBody)
req, _ := http.NewRequest("POST", url, payload) req, _ := http.NewRequest("POST", url, payload)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
return result.ID, nil return result.ID, nil
} }
func DeleteDNSRecord(zoneId string, apiToken string, id string) error { func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
req, _ := http.NewRequest("DELETE", url, nil) req, _ := http.NewRequest("DELETE", url, nil)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {

8
adapters/external_dns.go Normal file
View File

@ -0,0 +1,8 @@
package external_dns
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
type ExternalDNSAdapter interface {
CreateDNSRecord(record *database.DNSRecord) (string, error)
DeleteDNSRecord(id string) error
}

View File

@ -50,7 +50,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
} }
} }
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func InterceptOauthCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain {
state := req.URL.Query().Get("state") state := req.URL.Query().Get("state")
code := req.URL.Query().Get("code") code := req.URL.Query().Get("code")

37
api/auth_test.go Normal file
View File

@ -0,0 +1,37 @@
package api_test
import (
"database/sql"
"os"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func setup() (*sql.DB, *api.RequestContext, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
context := &api.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
}
return testDb, context, func() {
testDb.Close()
os.Remove(randomDb)
}
}
/*
todo: test api key creation
+ api key attached to user
+ user session is unique
+ goLogin goes to page in cookie
*/

View File

@ -8,30 +8,25 @@ import (
"strconv" "strconv"
"strings" "strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
) )
const MAX_USER_RECORDS = 65 const MAX_USER_RECORDS = 65
type FormError struct { var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
Errors []string
}
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool { func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
ownedByUser := (user.ID == record.UserID) ownedByUser := (user.ID == record.UserID)
if !ownedByUser { if !ownedByUser {
return false return false
} }
if !record.Internal { if !record.Internal {
userOwnedDomains := []string{ for _, format := range ownedInternalDomainFormats {
fmt.Sprintf("%s", user.Username), domain := fmt.Sprintf(format, user.Username)
fmt.Sprintf("%s.endpoints", user.Username),
}
for _, domain := range userOwnedDomains {
isInSubDomain := strings.HasSuffix(record.Name, "."+domain) isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
if domain == record.Name || isInSubDomain { if domain == record.Name || isInSubDomain {
return true return true
@ -64,7 +59,8 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp
} }
} }
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain {
formErrors := FormError{ formErrors := FormError{
Errors: []string{}, Errors: []string{},
@ -104,7 +100,8 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
TTL: ttlNum, TTL: ttlNum,
Internal: internal, Internal: internal,
} }
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) {
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
} }
@ -112,13 +109,11 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
if dnsRecord.Internal { if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId() dnsRecord.ID = utils.RandomId()
} else { } else {
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
formErrors.Errors = append(formErrors.Errors, err.Error()) formErrors.Errors = append(formErrors.Errors, err.Error())
} }
dnsRecord.ID = cloudflareRecordId
} }
} }
@ -142,8 +137,10 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
return failure(context, req, resp) return failure(context, req, resp)
} }
} }
}
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain {
recordId := req.FormValue("id") recordId := req.FormValue("id")
record, err := database.GetDNSRecord(context.DBConn, recordId) record, err := database.GetDNSRecord(context.DBConn, recordId)
@ -153,13 +150,13 @@ func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, res
return failure(context, req, resp) return failure(context, req, resp)
} }
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) {
resp.WriteHeader(http.StatusUnauthorized) resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp) return failure(context, req, resp)
} }
if !record.Internal { if !record.Internal {
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) err = dnsAdapter.DeleteDNSRecord(recordId)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
@ -177,3 +174,4 @@ func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, res
return success(context, req, resp) return success(context, req, resp)
} }
} }
}

56
api/dns_test.go Normal file
View File

@ -0,0 +1,56 @@
package api_test
import (
"database/sql"
"net/http"
"net/http/httptest"
"os"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func setup() (*sql.DB, *api.RequestContext, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
context := &api.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
}
return testDb, context, func() {
testDb.Close()
os.Remove(randomDb)
}
}
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
testUser := &database.User{
ID: "test",
Username: "test",
}
records, err := database.GetUserDNSRecords(db, context.User.ID)
if err != nil {
t.Fatal(err)
}
if len(records) > 0 {
t.Errorf("expected no records, got records")
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
api.PutDNSRecordContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
}))
defer ts.Close()
}

View File

@ -1,8 +1,6 @@
package api package api
import ( import (
"encoding/json"
"fmt"
"log" "log"
"net/http" "net/http"
"strings" "strings"
@ -43,16 +41,11 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp
return func(success Continuation, failure Continuation) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain {
name := req.FormValue("name") name := req.FormValue("name")
message := req.FormValue("message") message := req.FormValue("message")
hCaptchaResponse := req.FormValue("h-captcha-response")
formErrors := FormError{ formErrors := FormError{
Errors: []string{}, Errors: []string{},
} }
if hCaptchaResponse == "" {
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
}
entry := &database.GuestbookEntry{ entry := &database.GuestbookEntry{
ID: utils.RandomId(), ID: utils.RandomId(),
Name: name, Name: name,
@ -60,22 +53,19 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp
} }
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse) if len(formErrors.Errors) == 0 {
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
} }
}
if len(formErrors.Errors) > 0 { if len(formErrors.Errors) > 0 {
(*context.TemplateData)["FormError"] = formErrors (*context.TemplateData)["FormError"] = formErrors
(*context.TemplateData)["EntryForm"] = entry (*context.TemplateData)["EntryForm"] = entry
return failure(context, req, resp) resp.WriteHeader(http.StatusBadRequest)
}
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
@ -96,46 +86,3 @@ func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp
return success(context, req, resp) return success(context, req, resp)
} }
} }
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
SiteKey: context.Args.HcaptchaSiteKey,
}
log.Println(context.Args.HcaptchaSiteKey)
return success(context, req, resp)
}
}
func verifyHCaptcha(secret, response string) error {
verifyURL := "https://hcaptcha.com/siteverify"
body := strings.NewReader("secret=" + secret + "&response=" + response)
req, err := http.NewRequest("POST", verifyURL, body)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
jsonResponse := struct {
Success bool `json:"success"`
}{}
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
if err != nil {
return err
}
if !jsonResponse.Success {
return fmt.Errorf("hcaptcha verification failed")
}
defer resp.Body.Close()
return nil
}

129
api/guestbook_test.go Normal file
View File

@ -0,0 +1,129 @@
package api_test
import (
"database/sql"
"net/http"
"net/http/httptest"
"os"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func setup() (*sql.DB, *api.RequestContext, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
context := &api.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
}
return testDb, context, func() {
testDb.Close()
os.Remove(randomDb)
}
}
func TestValidGuestbookPutsInDatabase(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
entries, err := database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) > 0 {
t.Errorf("expected no entries, got entries")
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
}))
defer ts.Close()
req := httptest.NewRequest("POST", ts.URL, nil)
req.Form = map[string][]string{
"name": {"test"},
"message": {"test"},
}
w := httptest.NewRecorder()
ts.Config.Handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status code 200, got %d", w.Code)
}
entries, err = database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) != 1 {
t.Errorf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != req.FormValue("name") {
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
}
}
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
entries, err := database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) > 0 {
t.Errorf("expected no entries, got entries")
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation)
}))
defer testServer.Close()
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
invalidRequests := []struct {
name string
message string
}{
{"", "test"},
{"test", ""},
{"", ""},
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
}
for _, form := range invalidRequests {
req := httptest.NewRequest("POST", testServer.URL, nil)
req.Form = map[string][]string{
"name": {form.name},
"message": {form.message},
}
responseRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
if responseRecorder.Code != http.StatusBadRequest {
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
}
}
entries, err = database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) != 0 {
t.Errorf("expected 0 entries, got %d", len(entries))
}
}

69
api/hcaptcha.go Normal file
View File

@ -0,0 +1,69 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
func verifyCaptcha(secret, response string) error {
verifyURL := "https://hcaptcha.com/siteverify"
body := strings.NewReader("secret=" + secret + "&response=" + response)
req, err := http.NewRequest("POST", verifyURL, body)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
jsonResponse := struct {
Success bool `json:"success"`
}{}
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
if err != nil {
return err
}
if !jsonResponse.Success {
return fmt.Errorf("hcaptcha verification failed")
}
defer resp.Body.Close()
return nil
}
func CaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
SiteKey: context.Args.HcaptchaSiteKey,
}
return success(context, req, resp)
}
}
func CaptchaVerificationContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
hCaptchaResponse := req.FormValue("h-captcha-response")
secretKey := context.Args.HcaptchaSecret
err := verifyCaptcha(secretKey, hCaptchaResponse)
if err != nil {
(*context.TemplateData)["FormError"] = FormError{
Errors: []string{"hCaptcha verification failed"},
}
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"time" "time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
@ -23,6 +24,10 @@ type RequestContext struct {
User *database.User User *database.User
} }
type FormError struct {
Errors []string
}
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
type ContinuationChain func(Continuation, Continuation) ContinuationChain type ContinuationChain func(Continuation, Continuation) ContinuationChain
@ -80,11 +85,15 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
fileServer := http.FileServer(http.Dir(argv.StaticPath)) fileServer := http.FileServer(http.Dir(argv.StaticPath))
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
APIToken: argv.CloudflareToken,
ZoneId: argv.CloudflareZone,
}
makeRequestContext := func() *RequestContext { makeRequestContext := func() *RequestContext {
return &RequestContext{ return &RequestContext{
DBConn: dbConn, DBConn: dbConn,
Args: argv, Args: argv,
TemplateData: &map[string]interface{}{}, TemplateData: &map[string]interface{}{},
} }
} }
@ -94,7 +103,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
@ -106,12 +115,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
@ -126,12 +130,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
@ -151,12 +155,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaVerificationContinuation, CaptchaVerificationContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {

View File

@ -66,7 +66,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
return failure(context, req, resp) return failure(context, req, resp)
} }
resp.WriteHeader(200)
resp.Header().Set("Content-Type", "text/html") resp.Header().Set("Content-Type", "text/html")
resp.Write(html.Bytes()) resp.Write(html.Bytes())
return success(context, req, resp) return success(context, req, resp)

View File

@ -1,4 +1,4 @@
package hcdns package hcdns_test
import ( import (
"database/sql" "database/sql"
@ -16,7 +16,7 @@ import (
) )
func randomPort() int { func randomPort() int {
return rand.Intn(3000) + 1024 return rand.Intn(3000) + 5192
} }
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
@ -60,69 +60,65 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
defer cleanup() defer cleanup()
defer lock.Unlock() defer lock.Unlock()
cname := &database.DNSRecord{ records := []*database.DNSRecord{
ID: "1", {
ID: "0",
UserID: "test", UserID: "test",
Name: "cname.internal.example.com.", Name: "cname.internal.example.com.",
Type: "CNAME", Type: "CNAME",
Content: "next.internal.example.com.",
TTL: 300,
Internal: true,
}, {
ID: "1",
UserID: "test",
Name: "next.internal.example.com.",
Type: "CNAME",
Content: "res.example.com.", Content: "res.example.com.",
TTL: 300, TTL: 300,
Internal: true, Internal: true,
} },
a := &database.DNSRecord{ {
ID: "2", ID: "2",
UserID: "test", UserID: "test",
Name: "res.example.com.", Name: "res.example.com.",
Type: "A", Type: "A",
Content: "127.0.0.1", Content: "1.2.3.2",
TTL: 300, TTL: 300,
Internal: true, Internal: true,
},
}
for _, record := range records {
database.SaveDNSRecord(testDb, record)
} }
database.SaveDNSRecord(testDb, cname)
database.SaveDNSRecord(testDb, a)
qtype := dns.TypeA qtype := dns.TypeA
domain := dns.Fqdn(cname.Name) domain := dns.Fqdn("cname.internal.example.com.")
client := &dns.Client{} client := &dns.Client{}
message := &dns.Msg{} message := &dns.Msg{}
message.SetQuestion(domain, qtype) message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr) in, _, err := client.Exchange(message, *addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(in.Answer) != 2 { if len(in.Answer) != 3 {
t.Fatalf("expected 2 answers, got %d", len(in.Answer)) t.Fatalf("expected 3 answers, got %d", len(in.Answer))
} }
if in.Answer[0].Header().Name != cname.Name { for i, record := range records {
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) if in.Answer[i].Header().Name != record.Name {
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
} }
if in.Answer[1].Header().Name != a.Name { if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
} }
if in.Answer[0].(*dns.CNAME).Target != a.Name { if int(in.Answer[i].Header().Ttl) != record.TTL {
t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
}
if in.Answer[1].(*dns.A).A.String() != a.Content {
t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
}
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
}
if in.Answer[1].Header().Rrtype != dns.TypeA {
t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
}
if int(in.Answer[0].Header().Ttl) != cname.TTL {
t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
} }
if !in.Authoritative { if !in.Authoritative {
@ -130,6 +126,11 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
} }
} }
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
t.Fatalf("expected final record to be the A record with correct IP")
}
}
func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) {
t.Log("TestWhenNoRecordNxDomain") t.Log("TestWhenNoRecordNxDomain")