testing | dont be recursive for external domains | finalize oauth #5

Merged
simponic merged 24 commits from dont-be-authoritative into main 2024-04-06 15:43:19 -04:00
27 changed files with 1839 additions and 575 deletions

View File

@ -2,3 +2,4 @@
hatecomputers.club hatecomputers.club
Dockerfile Dockerfile
*.db *.db
.drone.yml

View File

@ -1,9 +1,30 @@
--- ---
kind: pipeline kind: pipeline
type: docker type: docker
name: build, publish docker image, deploy name: build
steps: steps:
- name: run tests
image: golang
commands:
- go build
- go test -p 1 -v ./...
trigger:
event:
- pull_request
---
kind: pipeline
type: docker
name: deploy
steps:
- name: run tests
image: golang
commands:
- go build
- go test -p 1 -v ./...
- name: docker - name: docker
image: plugins/docker image: plugins/docker
settings: settings:
@ -13,9 +34,6 @@ steps:
from_secret: gitea_packpub_password from_secret: gitea_packpub_password
registry: git.hatecomputers.club registry: git.hatecomputers.club
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
tags:
- latest
- main
- name: ssh - name: ssh
image: appleboy/drone-ssh image: appleboy/drone-ssh
settings: settings:
@ -27,6 +45,9 @@ steps:
command_timeout: 2m command_timeout: 2m
script: script:
- systemctl restart docker-compose@hatecomputers-club - systemctl restart docker-compose@hatecomputers-club
trigger: trigger:
branch: branch:
- main - main
event:
- push

View File

@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080 EXPOSE 8080
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]

View File

@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
Result database.DNSRecord `json:"result"` Result database.DNSRecord `json:"result"`
} }
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { type CloudflareExternalDNSAdapter struct {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) ZoneId string
APIToken string
}
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
payload := strings.NewReader(reqBody) payload := strings.NewReader(reqBody)
req, _ := http.NewRequest("POST", url, payload) req, _ := http.NewRequest("POST", url, payload)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
return result.ID, nil return result.ID, nil
} }
func DeleteDNSRecord(zoneId string, apiToken string, id string) error { func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
req, _ := http.NewRequest("DELETE", url, nil) req, _ := http.NewRequest("DELETE", url, nil)
req.Header.Add("Authorization", "Bearer "+apiToken) req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {

8
adapters/external_dns.go Normal file
View File

@ -0,0 +1,8 @@
package external_dns
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
type ExternalDNSAdapter interface {
CreateDNSRecord(record *database.DNSRecord) (string, error)
DeleteDNSRecord(id string) error
}

View File

@ -1,4 +1,4 @@
package api package auth
import ( import (
"crypto/sha256" "crypto/sha256"
@ -12,13 +12,14 @@ import (
"strings" "strings"
"time" "time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
verifier := utils.RandomId() + utils.RandomId() verifier := utils.RandomId() + utils.RandomId()
sha2 := sha256.New() sha2 := sha256.New()
@ -34,7 +35,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
Path: "/", Path: "/",
Secure: true, Secure: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
MaxAge: 60, MaxAge: 200,
}) })
http.SetCookie(resp, &http.Cookie{ http.SetCookie(resp, &http.Cookie{
Name: "state", Name: "state",
@ -42,7 +43,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
Path: "/", Path: "/",
Secure: true, Secure: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
MaxAge: 60, MaxAge: 200,
}) })
http.Redirect(resp, req, url, http.StatusFound) http.Redirect(resp, req, url, http.StatusFound)
@ -50,8 +51,8 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
} }
} }
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
state := req.URL.Query().Get("state") state := req.URL.Query().Get("state")
code := req.URL.Query().Get("code") code := req.URL.Query().Get("code")
@ -73,7 +74,6 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
reqContext := req.Context() reqContext := req.Context()
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
if err != nil { if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
@ -101,6 +101,16 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
Secure: true, Secure: true,
}) })
http.SetCookie(resp, &http.Cookie{
Name: "verifier",
Value: "",
MaxAge: 0,
})
http.SetCookie(resp, &http.Cookie{
Name: "state",
Value: "",
MaxAge: 0,
})
redirect := "/" redirect := "/"
redirectCookie, err := req.Cookie("redirect") redirectCookie, err := req.Cookie("redirect")
@ -109,6 +119,7 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
http.SetCookie(resp, &http.Cookie{ http.SetCookie(resp, &http.Cookie{
Name: "redirect", Name: "redirect",
MaxAge: 0, MaxAge: 0,
Value: "",
}) })
} }
@ -117,6 +128,127 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
} }
} }
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
authHeader := req.Header.Get("Authorization")
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
sessionCookie, err := req.Cookie("session")
if err == nil && sessionCookie.Value != "" {
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
}
if userErr != nil || user == nil {
log.Println(userErr, user)
http.SetCookie(resp, &http.Cookie{
Name: "session",
Value: "",
MaxAge: 0,
})
context.User = nil
return failure(context, req, resp)
}
context.User = user
return success(context, req, resp)
}
}
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
http.SetCookie(resp, &http.Cookie{
Name: "redirect",
Value: req.URL.Path,
Path: "/",
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.Redirect(resp, req, "/login", http.StatusFound)
return failure(context, req, resp)
}
}
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err != nil {
return failure(context, req, resp)
}
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
if err != nil {
return failure(context, req, resp)
}
return success(context, req, resp)
}
}
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err == nil && sessionCookie.Value != "" {
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
}
http.SetCookie(resp, &http.Cookie{
Name: "session",
MaxAge: 0,
Value: "",
})
http.Redirect(resp, req, "/", http.StatusFound)
return success(context, req, resp)
}
}
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
userResponse, err := client.Get(uri)
if err != nil {
return nil, err
}
userStruct, err := createUserFromOauthResponse(userResponse)
if err != nil {
return nil, err
}
user, err := database.FindOrSaveUser(dbConn, userStruct)
if err != nil {
return nil, err
}
return user, nil
}
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
user := &database.User{}
err := json.NewDecoder(response.Body).Decode(user)
defer response.Body.Close()
if err != nil {
log.Println(err)
return nil, err
}
user.Username = strings.ToLower(user.Username)
user.Username = strings.Split(user.Username, "@")[0]
return user, nil
}
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
cookie, err := req.Cookie(stateCookieName)
if err != nil || cookie.Value != expectedState {
return false
}
return true
}
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
if bearerToken == "" { if bearerToken == "" {
return nil, nil return nil, nil
@ -127,15 +259,15 @@ func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User,
return nil, nil return nil, nil
} }
apiKey, err := database.GetAPIKey(dbConn, parts[1]) key, err := database.GetAPIKey(dbConn, parts[1])
if err != nil { if err != nil {
return nil, err return nil, err
} }
if apiKey == nil { if key == nil {
return nil, nil return nil, nil
} }
user, err := database.GetUser(dbConn, apiKey.UserID) user, err := database.GetUser(dbConn, key.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,124 +294,3 @@ func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error
return user, nil return user, nil
} }
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
authHeader := req.Header.Get("Authorization")
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
sessionCookie, err := req.Cookie("session")
if err == nil && sessionCookie.Value != "" {
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
}
if userErr != nil || user == nil {
log.Println(userErr, user)
http.SetCookie(resp, &http.Cookie{
Name: "session",
MaxAge: 0, // reset session cookie in case
})
context.User = nil
return failure(context, req, resp)
}
context.User = user
return success(context, req, resp)
}
}
func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
http.SetCookie(resp, &http.Cookie{
Name: "redirect",
Value: req.URL.Path,
Path: "/",
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.Redirect(resp, req, "/login", http.StatusFound)
return failure(context, req, resp)
}
}
func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err != nil {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
if err != nil {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}
func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err == nil && sessionCookie.Value != "" {
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
}
http.Redirect(resp, req, "/", http.StatusFound)
http.SetCookie(resp, &http.Cookie{
Name: "session",
MaxAge: 0,
})
return success(context, req, resp)
}
}
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
userResponse, err := client.Get(uri)
if err != nil {
return nil, err
}
userStruct, err := createUserFromResponse(userResponse)
if err != nil {
return nil, err
}
user, err := database.FindOrSaveUser(dbConn, userStruct)
if err != nil {
return nil, err
}
return user, nil
}
func createUserFromResponse(response *http.Response) (*database.User, error) {
defer response.Body.Close()
user := &database.User{
CreatedAt: time.Now(),
}
err := json.NewDecoder(response.Body).Decode(user)
if err != nil {
log.Println(err)
return nil, err
}
user.Username = strings.ToLower(user.Username)
user.Username = strings.Split(user.Username, "@")[0]
return user, nil
}
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
cookie, err := req.Cookie(stateCookieName)
if err != nil || cookie.Value != expectedState {
return false
}
return true
}

307
api/auth/auth_test.go Normal file
View File

@ -0,0 +1,307 @@
package auth_test
import (
"database/sql"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"golang.org/x/oauth2"
)
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp)
}
}
func FakedOauthServer() *httptest.Server {
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/auth" {
code := utils.RandomId()
state := r.URL.Query().Get("state")
redirectPath := r.URL.Query().Get("redirect_uri")
redirectPath += "?code=" + code + "&state=" + state
http.Redirect(w, r, redirectPath, http.StatusFound)
}
if r.URL.Path == "/token" {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
}
if r.URL.Path == "/user" {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
}
}))
return oauthServer
}
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
resp.Write([]byte(context.User.Username))
return success(context, req, resp)
}
}
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/protected-path" {
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
}
if r.URL.Path == "/login" {
log.Println("login")
auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
}
if r.URL.Path == "/callback" {
log.Println("callback")
auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
}
if r.URL.Path == "/me" {
auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
}
if r.URL.Path == "/logout" {
auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
}
}))
return testServer
}
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
context := &types.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
}
oauthServer := FakedOauthServer()
testServer := MockUserEndpointServer(context)
return testDb, context, oauthServer, testServer, func() {
oauthServer.Close()
testServer.Close()
testDb.Close()
os.Remove(randomDb)
}
}
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
return &oauth2.Config{
ClientID: "test",
ClientSecret: "test",
Scopes: []string{"test"},
Endpoint: oauth2.Endpoint{
AuthURL: oauthServerURL + "/auth",
TokenURL: oauthServerURL + "/token",
},
RedirectURL: testServerURL + "/callback",
}, oauthServerURL + "/user"
}
func FollowAuthentication(
oauthServer *httptest.Server,
testServer *httptest.Server,
cookies map[string]*http.Cookie,
location string,
) (map[string]*http.Cookie, string) {
resp := httptest.NewRecorder()
resp.Code = 0
for resp.Code == 0 || resp.Code == http.StatusFound {
req := httptest.NewRequest("GET", location, nil)
resp = httptest.NewRecorder()
for _, cookie := range cookies {
req.AddCookie(cookie)
}
if strings.HasPrefix(location, oauthServer.URL) {
oauthServer.Config.Handler.ServeHTTP(resp, req)
} else {
testServer.Config.Handler.ServeHTTP(resp, req)
}
for _, cookie := range resp.Result().Cookies() {
cookies[cookie.Name] = cookie
}
if resp.Code == http.StatusFound {
location = resp.Header().Get("Location")
}
}
return cookies, location
}
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
user, _ := database.GetUser(db, "test")
if user != nil {
t.Errorf("expected no user, got user")
}
cookies := make(map[string]*http.Cookie)
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
user, _ = database.GetUser(db, "test")
if user == nil {
t.Errorf("expected a user to be created, could not find user")
}
if user.Username != "test" {
t.Errorf("expected username to be test, got %s", user.Username)
}
}
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
_, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
req := httptest.NewRequest("GET", "/protected-path", nil)
resp := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
location := resp.Header().Get("Location")
if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
}
cookies := make(map[string]*http.Cookie)
cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
if !(strings.HasSuffix(location, "/protected-page")) {
t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
}
}
func TestOauthSetsUniqueSession(t *testing.T) {
db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
cookiesAgain := make(map[string]*http.Cookie)
cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
sessionOne := cookies["session"].Value
sessionTwo := cookiesAgain["session"].Value
if sessionOne == sessionTwo {
t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
}
session, _ := database.GetSession(db, sessionOne)
if session.UserID != "test" {
t.Errorf("expected session to be associated with user test, got %s", session.UserID)
}
}
func TestLogoutClearsSession(t *testing.T) {
db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
req := httptest.NewRequest("GET", "/logout", nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
for _, cookie := range resp.Result().Cookies() {
cookies[cookie.Name] = cookie
}
req = httptest.NewRequest("GET", "/me", nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
}
session, _ := database.GetSession(db, cookies["session"].Value)
if session != nil {
t.Errorf("expected session to be deleted, got session")
}
}
func TestRefreshUpdatesExpiration(t *testing.T) {
db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
session, _ := database.GetSession(db, cookies["session"].Value)
req := httptest.NewRequest("GET", "/me", nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
updatedSession, _ := database.GetSession(db, cookies["session"].Value)
if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
}
}
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
session, _ := database.GetSession(db, cookies["session"].Value)
session.ExpireAt = time.Now().Add(-time.Hour)
database.SaveSession(db, session)
req := httptest.NewRequest("GET", "/me", nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
}
}

View File

@ -1,179 +0,0 @@
package api
import (
"database/sql"
"fmt"
"log"
"net/http"
"strconv"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
const MAX_USER_RECORDS = 65
type FormError struct {
Errors []string
}
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
ownedByUser := (user.ID == record.UserID)
if !ownedByUser {
return false
}
if !record.Internal {
userOwnedDomains := []string{
fmt.Sprintf("%s", user.Username),
fmt.Sprintf("%s.endpoints", user.Username),
}
for _, domain := range userOwnedDomains {
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
if domain == record.Name || isInSubDomain {
return true
}
}
return false
}
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
if err != nil {
log.Println(err)
return false
}
userIsOwnerOfDomain := owner == user.ID
return ownedByUser && userIsOwnerOfDomain
}
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
(*context.TemplateData)["DNSRecords"] = dnsRecords
return success(context, req, resp)
}
}
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
formErrors := FormError{
Errors: []string{},
}
internal := req.FormValue("internal") == "on"
name := req.FormValue("name")
if internal && !strings.HasSuffix(name, ".") {
name += "."
}
recordType := req.FormValue("type")
recordType = strings.ToUpper(recordType)
recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
}
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if dnsRecordCount >= MAX_USER_RECORDS {
formErrors.Errors = append(formErrors.Errors, "max records reached")
}
dnsRecord := &database.DNSRecord{
UserID: context.User.ID,
Name: name,
Type: recordType,
Content: recordContent,
TTL: ttlNum,
Internal: internal,
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
if len(formErrors.Errors) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, err.Error())
}
dnsRecord.ID = cloudflareRecordId
}
}
if len(formErrors.Errors) == 0 {
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, "error saving record")
}
}
if len(formErrors.Errors) == 0 {
http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
(*context.TemplateData)["FormError"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp)
}
}
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
recordId := req.FormValue("id")
record, err := database.GetDNSRecord(context.DBConn, recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
if !record.Internal {
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
}
err = database.DeleteDNSRecord(context.DBConn, recordId)
if err != nil {
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
}

174
api/dns/dns.go Normal file
View File

@ -0,0 +1,174 @@
package dns
import (
"database/sql"
"fmt"
"log"
"net/http"
"strconv"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
ownedByUser := (user.ID == record.UserID)
if !ownedByUser {
return false
}
if !record.Internal {
for _, format := range ownedInternalDomainFormats {
domain := fmt.Sprintf(format, user.Username)
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
if domain == record.Name || isInSubDomain {
return true
}
}
return false
}
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
if err != nil {
log.Println(err)
return false
}
userIsOwnerOfDomain := owner == user.ID
return ownedByUser && userIsOwnerOfDomain
}
func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
(*context.TemplateData)["DNSRecords"] = dnsRecords
return success(context, req, resp)
}
}
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
formErrors := types.FormError{
Errors: []string{},
}
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
name := req.FormValue("name")
if internal && !strings.HasSuffix(name, ".") {
name += "."
}
recordType := req.FormValue("type")
recordType = strings.ToUpper(recordType)
recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
}
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if dnsRecordCount >= maxUserRecords {
resp.WriteHeader(http.StatusTooManyRequests)
formErrors.Errors = append(formErrors.Errors, "max records reached")
}
dnsRecord := &database.DNSRecord{
UserID: context.User.ID,
Name: name,
Type: recordType,
Content: recordContent,
TTL: ttlNum,
Internal: internal,
}
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
resp.WriteHeader(http.StatusUnauthorized)
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
if len(formErrors.Errors) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
formErrors.Errors = append(formErrors.Errors, err.Error())
}
}
}
if len(formErrors.Errors) == 0 {
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, "error saving record")
}
}
if len(formErrors.Errors) == 0 {
return success(context, req, resp)
}
(*context.TemplateData)["FormError"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
return failure(context, req, resp)
}
}
}
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
recordId := req.FormValue("id")
record, err := database.GetDNSRecord(context.DBConn, recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
if !(record.UserID == context.User.ID) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
if !record.Internal {
err = dnsAdapter.DeleteDNSRecord(recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
}
err = database.DeleteDNSRecord(context.DBConn, recordId)
if err != nil {
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}
}

442
api/dns/dns_test.go Normal file
View File

@ -0,0 +1,442 @@
package dns_test
import (
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
const MAX_USER_RECORDS = 10
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp)
}
}
func setup() (*sql.DB, *types.RequestContext, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
user := &database.User{
ID: "test",
Username: "test",
Mail: "test@test.com",
DisplayName: "test",
}
database.FindOrSaveUser(testDb, user)
context := &types.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
User: user,
}
return testDb, context, func() {
testDb.Close()
os.Remove(randomDb)
}
}
type SignallingExternalDnsAdapter struct {
AddChannel chan *database.DNSRecord
RmChannel chan string
}
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
id := utils.RandomId()
go func() { adapter.AddChannel <- record }()
return id, nil
}
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
go func() { adapter.RmChannel <- id }()
return nil
}
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
domainOwner := &database.DomainOwner{
UserID: context.User.ID,
Domain: "test.domain.",
}
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
records, err := database.GetUserDNSRecords(db, context.User.ID)
if err != nil {
t.Fatal(err)
}
if len(records) > 0 {
t.Errorf("expected no records, got records")
}
addChannel := make(chan *database.DNSRecord)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
AddChannel: addChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
validOwner.Form = map[string][]string{
"internal": {"on"},
"name": {"new.test.domain."},
"type": {"CNAME"},
"ttl": {"43000"},
"content": {"test.domain."},
}
validOwnerRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
if validOwnerRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
}
validOwnerNonInternalRecorder := httptest.NewRecorder()
validOwner.Form["internal"] = []string{"off"}
testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
}
invalidOwnerRecorder := httptest.NewRecorder()
invalidOwner := validOwner
invalidOwner.Form["internal"] = []string{"on"}
invalidOwner.Form["name"] = []string{"new.invalid.domain."}
testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
if invalidOwnerRecorder.Code != http.StatusUnauthorized {
t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
}
}
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
addChannel := make(chan *database.DNSRecord)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
AddChannel: addChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
responseRecorder := httptest.NewRecorder()
req := httptest.NewRequest("POST", testServer.URL, nil)
fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
for _, format := range fmts {
name := fmt.Sprintf(format, context.User.Username)
req.Form = map[string][]string{
"internal": {"off"},
"name": {name},
"type": {"CNAME"},
"ttl": {"43000"},
"content": {"test.domain."},
}
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
if responseRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", responseRecorder.Code)
}
namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
if len(namedRecords) == 0 {
t.Errorf("saved record not found")
}
}
}
func TestThatExternalDnsSaves(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
addChannel := make(chan *database.DNSRecord)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
AddChannel: addChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
responseRecorder := httptest.NewRecorder()
externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
name := "test." + context.User.Username
externalRequest.Form = map[string][]string{
"internal": {"off"},
"name": {name},
"type": {"CNAME"},
"ttl": {"43000"},
"content": {"test.domain."},
}
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
if responseRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", responseRecorder.Code)
}
select {
case res := <-addChannel:
if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
t.Errorf("received the wrong external record")
}
case <-time.After(100 * time.Millisecond):
t.Errorf("timed out in waiting for external addition")
}
domainOwner := &database.DomainOwner{
UserID: context.User.ID,
Domain: "test.domain.",
}
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
internalRequest := externalRequest
internalRequest.Form["internal"] = []string{"on"}
internalRequest.Form["name"] = []string{"test.domain."}
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
if responseRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", responseRecorder.Code)
}
select {
case _ = <-addChannel:
t.Errorf("expected nothing in the add channel")
case <-time.After(100 * time.Millisecond):
}
}
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
rmChannel := make(chan string)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
RmChannel: rmChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
_, err := database.FindOrSaveUser(db, nonOwnerUser)
if err != nil {
t.Error(err)
}
record := &database.DNSRecord{
ID: "1",
Internal: false,
Name: "test",
Type: "CNAME",
Content: "asdf",
TTL: 1000,
UserID: nonOwnerUser.ID,
}
_, err = database.SaveDNSRecord(db, record)
if err != nil {
t.Error(err)
}
nonOwnerRecorder := httptest.NewRecorder()
nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
nonOwner.Form = map[string][]string{
"id": {record.ID},
}
testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
if nonOwnerRecorder.Code != http.StatusUnauthorized {
t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
}
record.UserID = context.User.ID
record.ID = "2"
database.SaveDNSRecord(db, record)
owner := nonOwner
owner.Form["id"] = []string{"2"}
ownerRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
if ownerRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", ownerRecorder.Code)
}
}
func TestThatExternalDnsRemoves(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
record := &database.DNSRecord{
ID: "1",
Internal: false,
Name: "test",
Type: "CNAME",
Content: "asdf",
TTL: 1000,
UserID: context.User.ID,
}
database.SaveDNSRecord(db, record)
rmChannel := make(chan string)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
RmChannel: rmChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
externalResponseRecorder := httptest.NewRecorder()
deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
deleteRequest.Form = map[string][]string{
"id": {record.ID},
}
testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
if externalResponseRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
}
select {
case res := <-rmChannel:
if res != record.ID {
t.Errorf("received the wrong external record")
}
case <-time.After(100 * time.Millisecond):
t.Errorf("timed out in waiting for external addition")
}
record.Internal = true
record.Name = "test.domain."
database.SaveDNSRecord(db, record)
domainOwner := &database.DomainOwner{
UserID: context.User.ID,
Domain: "test.domain.",
}
database.SaveDomainOwner(db, domainOwner)
internalResponseRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
if internalResponseRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
}
select {
case _ = <-rmChannel:
t.Errorf("expected nothing in the rmchannel")
case <-time.After(100 * time.Millisecond):
}
}
func TestRecordCountCannotExceed(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
record := &database.DNSRecord{
Internal: false,
Name: context.User.Username,
Type: "CNAME",
Content: "asdf",
TTL: 1000,
UserID: context.User.ID,
}
for i := 1; i <= MAX_USER_RECORDS; i++ {
record.ID = strconv.Itoa(i)
record.Name = record.ID + "." + record.Name
database.SaveDNSRecord(db, record)
}
addChannel := make(chan *database.DNSRecord)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
AddChannel: addChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
req := httptest.NewRequest("POST", testServer.URL, nil)
req.Form = map[string][]string{
"internal": {"off"},
"name": {record.Name},
"type": {record.Type},
"ttl": {"43000"},
"content": {record.Content},
}
recorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusTooManyRequests {
t.Errorf("expected too many requests code return, got %d", recorder.Code)
}
}
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
domainOwner := &database.DomainOwner{
UserID: context.User.ID,
Domain: "test.internal.",
}
database.SaveDomainOwner(db, domainOwner)
addChannel := make(chan *database.DNSRecord)
signallingDnsAdapter := &SignallingExternalDnsAdapter{
AddChannel: addChannel,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
validOwner.Form = map[string][]string{
"internal": {"on"},
"name": {"test.internal"},
"type": {"CNAME"},
"ttl": {"43000"},
"content": {"asdf.internal"},
}
validOwnerRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
if validOwnerRecorder.Code != http.StatusOK {
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
}
recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
t.Errorf("expected dot appended")
}
}

View File

@ -1,141 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
type HcaptchaArgs struct {
SiteKey string
}
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
errors := []string{}
if entry.Name == "" {
errors = append(errors, "name is required")
}
if entry.Message == "" {
errors = append(errors, "message is required")
}
messageLength := len(entry.Message)
if messageLength > 500 {
errors = append(errors, "message cannot be longer than 500 characters")
}
newLines := strings.Count(entry.Message, "\n")
if newLines > 10 {
errors = append(errors, "message cannot contain more than 10 new lines")
}
return errors
}
func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
name := req.FormValue("name")
message := req.FormValue("message")
hCaptchaResponse := req.FormValue("h-captcha-response")
formErrors := FormError{
Errors: []string{},
}
if hCaptchaResponse == "" {
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
}
entry := &database.GuestbookEntry{
ID: utils.RandomId(),
Name: name,
Message: message,
}
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
}
if len(formErrors.Errors) > 0 {
(*context.TemplateData)["FormError"] = formErrors
(*context.TemplateData)["EntryForm"] = entry
return failure(context, req, resp)
}
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}
func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
entries, err := database.GetGuestbookEntries(context.DBConn)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
(*context.TemplateData)["GuestbookEntries"] = entries
return success(context, req, resp)
}
}
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
SiteKey: context.Args.HcaptchaSiteKey,
}
log.Println(context.Args.HcaptchaSiteKey)
return success(context, req, resp)
}
}
func verifyHCaptcha(secret, response string) error {
verifyURL := "https://hcaptcha.com/siteverify"
body := strings.NewReader("secret=" + secret + "&response=" + response)
req, err := http.NewRequest("POST", verifyURL, body)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
jsonResponse := struct {
Success bool `json:"success"`
}{}
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
if err != nil {
return err
}
if !jsonResponse.Success {
return fmt.Errorf("hcaptcha verification failed")
}
defer resp.Body.Close()
return nil
}

View File

@ -0,0 +1,85 @@
package guestbook
import (
"log"
"net/http"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
errors := []string{}
if entry.Name == "" {
errors = append(errors, "name is required")
}
if entry.Message == "" {
errors = append(errors, "message is required")
}
messageLength := len(entry.Message)
if messageLength > 500 {
errors = append(errors, "message cannot be longer than 500 characters")
}
newLines := strings.Count(entry.Message, "\n")
if newLines > 10 {
errors = append(errors, "message cannot contain more than 10 new lines")
}
return errors
}
func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
name := req.FormValue("name")
message := req.FormValue("message")
formErrors := types.FormError{
Errors: []string{},
}
entry := &database.GuestbookEntry{
ID: utils.RandomId(),
Name: name,
Message: message,
}
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
if len(formErrors.Errors) == 0 {
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
if err != nil {
log.Println(err)
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
}
}
if len(formErrors.Errors) > 0 {
(*context.TemplateData)["FormError"] = formErrors
(*context.TemplateData)["EntryForm"] = entry
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}
func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
entries, err := database.GetGuestbookEntries(context.DBConn)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
(*context.TemplateData)["GuestbookEntries"] = entries
return success(context, req, resp)
}
}

View File

@ -0,0 +1,136 @@
package guestbook_test
import (
"database/sql"
"net/http"
"net/http/httptest"
"os"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp)
}
}
func setup() (*sql.DB, *types.RequestContext, func()) {
randomDb := utils.RandomId()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
context := &types.RequestContext{
DBConn: testDb,
Args: &args.Arguments{},
TemplateData: &(map[string]interface{}{}),
}
return testDb, context, func() {
testDb.Close()
os.Remove(randomDb)
}
}
func TestValidGuestbookPutsInDatabase(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
entries, err := database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) > 0 {
t.Errorf("expected no entries, got entries")
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
}))
defer ts.Close()
req := httptest.NewRequest("POST", ts.URL, nil)
req.Form = map[string][]string{
"name": {"test"},
"message": {"test"},
}
w := httptest.NewRecorder()
ts.Config.Handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status code 200, got %d", w.Code)
}
entries, err = database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) != 1 {
t.Errorf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != req.FormValue("name") {
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
}
}
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
entries, err := database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) > 0 {
t.Errorf("expected no entries, got entries")
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
}))
defer testServer.Close()
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
invalidRequests := []struct {
name string
message string
}{
{"", "test"},
{"test", ""},
{"", ""},
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
}
for _, form := range invalidRequests {
req := httptest.NewRequest("POST", testServer.URL, nil)
req.Form = map[string][]string{
"name": {form.name},
"message": {form.message},
}
responseRecorder := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
if responseRecorder.Code != http.StatusBadRequest {
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
}
}
entries, err = database.GetGuestbookEntries(db)
if err != nil {
t.Fatal(err)
}
if len(entries) != 0 {
t.Errorf("expected 0 entries, got %d", len(entries))
}
}

75
api/hcaptcha/hcaptcha.go Normal file
View File

@ -0,0 +1,75 @@
package hcaptcha
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
)
type HcaptchaArgs struct {
SiteKey string
}
func verifyCaptcha(secret, response string) error {
verifyURL := "https://hcaptcha.com/siteverify"
body := strings.NewReader("secret=" + secret + "&response=" + response)
req, err := http.NewRequest("POST", verifyURL, body)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
jsonResponse := struct {
Success bool `json:"success"`
}{}
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
if err != nil {
return err
}
if !jsonResponse.Success {
return fmt.Errorf("hcaptcha verification failed")
}
defer resp.Body.Close()
return nil
}
func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
SiteKey: context.Args.HcaptchaSiteKey,
}
return success(context, req, resp)
}
}
func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
hCaptchaResponse := req.FormValue("h-captcha-response")
secretKey := context.Args.HcaptchaSecret
err := verifyCaptcha(secretKey, hCaptchaResponse)
if err != nil {
(*context.TemplateData)["FormError"] = types.FormError{
Errors: []string{"hCaptcha verification failed"},
}
resp.WriteHeader(http.StatusBadRequest)
return failure(context, req, resp)
}
return success(context, req, resp)
}
}

View File

@ -1,32 +1,33 @@
package api package keys
import ( import (
"log" "log"
"net/http" "net/http"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
) )
const MAX_USER_API_KEYS = 5 const MAX_USER_API_KEYS = 5
func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
(*context.TemplateData)["APIKeys"] = apiKeys (*context.TemplateData)["APIKeys"] = typesKeys
return success(context, req, resp) return success(context, req, resp)
} }
} }
func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
formErrors := FormError{ formErrors := types.FormError{
Errors: []string{}, Errors: []string{},
} }
@ -38,7 +39,7 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
} }
if numKeys >= MAX_USER_API_KEYS { if numKeys >= MAX_USER_API_KEYS {
formErrors.Errors = append(formErrors.Errors, "max api keys reached") formErrors.Errors = append(formErrors.Errors, "max types keys reached")
} }
if len(formErrors.Errors) > 0 { if len(formErrors.Errors) > 0 {
@ -59,29 +60,28 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
} }
} }
func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
key := req.FormValue("key") apiKey := req.FormValue("key")
apiKey, err := database.GetAPIKey(context.DBConn, key) key, err := database.GetAPIKey(context.DBConn, apiKey)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
if (apiKey == nil) || (apiKey.UserID != context.User.ID) { if (key == nil) || (key.UserID != context.User.ID) {
resp.WriteHeader(http.StatusUnauthorized) resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp) return failure(context, req, resp)
} }
err = database.DeleteAPIKey(context.DBConn, key) err = database.DeleteAPIKey(context.DBConn, apiKey)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp) return failure(context, req, resp)
} }
http.Redirect(resp, req, "/keys", http.StatusFound)
return success(context, req, resp) return success(context, req, resp)
} }
} }

View File

@ -7,27 +7,20 @@ import (
"net/http" "net/http"
"time" "time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
) )
type RequestContext struct { func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
DBConn *sql.DB return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
Args *args.Arguments
Id string
Start time.Time
TemplateData *map[string]interface{}
User *database.User
}
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
type ContinuationChain func(Continuation, Continuation) ContinuationChain
func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, _failure Continuation) ContinuationChain {
context.Start = time.Now() context.Start = time.Now()
context.Id = utils.RandomId() context.Id = utils.RandomId()
@ -36,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt
} }
} }
func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, _failure Continuation) ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
end := time.Now() end := time.Now()
log.Println(context.Id, "took", end.Sub(context.Start)) log.Println(context.Id, "took", end.Sub(context.Start))
@ -46,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re
} }
} }
func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, _failure Continuation) ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
resp.WriteHeader(200) resp.WriteHeader(200)
resp.Write([]byte("healthy")) resp.Write([]byte("healthy"))
return success(context, req, resp) return success(context, req, resp)
} }
} }
func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(_success Continuation, failure Continuation) ContinuationChain { return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain {
return failure(context, req, resp) return failure(context, req, resp)
} }
} }
func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, _failure Continuation) ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp) return success(context, req, resp)
} }
} }
@ -80,89 +73,90 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
fileServer := http.FileServer(http.Dir(argv.StaticPath)) fileServer := http.FileServer(http.Dir(argv.StaticPath))
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
makeRequestContext := func() *RequestContext { cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
return &RequestContext{ APIToken: argv.CloudflareToken,
ZoneId: argv.CloudflareZone,
}
makeRequestContext := func() *types.RequestContext {
return &types.RequestContext{
DBConn: dbConn, DBConn: dbConn,
Args: argv, Args: argv,
TemplateData: &map[string]interface{}{}, TemplateData: &map[string]interface{}{},
} }
} }
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
const MAX_USER_RECORDS = 100
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext() requestContext := makeRequestContext()
name := r.PathValue("name") name := r.PathValue("name")
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
}) })
return &http.Server{ return &http.Server{

View File

@ -1,4 +1,4 @@
package api package template
import ( import (
"bytes" "bytes"
@ -7,9 +7,11 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
) )
func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
templatePath := context.Args.TemplatePath templatePath := context.Args.TemplatePath
basePath := templatePath + "/base_empty.html" basePath := templatePath + "/base_empty.html"
if showBaseHtml { if showBaseHtml {
@ -41,9 +43,9 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b
return buffer, nil return buffer, nil
} }
func TemplateContinuation(path string, showBase bool) Continuation { func TemplateContinuation(path string, showBase bool) types.Continuation {
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
html, err := renderTemplate(context, path, true) html, err := renderTemplate(context, path, true)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
resp.WriteHeader(404) resp.WriteHeader(404)
@ -66,7 +68,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
return failure(context, req, resp) return failure(context, req, resp)
} }
resp.WriteHeader(200)
resp.Header().Set("Content-Type", "text/html") resp.Header().Set("Content-Type", "text/html")
resp.Write(html.Bytes()) resp.Write(html.Bytes())
return success(context, req, resp) return success(context, req, resp)

28
api/types/types.go Normal file
View File

@ -0,0 +1,28 @@
package types
import (
"database/sql"
"net/http"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
)
type RequestContext struct {
DBConn *sql.DB
Args *args.Arguments
Id string
Start time.Time
TemplateData *map[string]interface{}
User *database.User
}
type FormError struct {
Errors []string
}
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
type ContinuationChain func(Continuation, Continuation) ContinuationChain

View File

@ -23,7 +23,6 @@ type Arguments struct {
OauthUserInfoURI string OauthUserInfoURI string
Dns bool Dns bool
DnsRecursion []string
DnsPort int DnsPort int
CloudflareToken string CloudflareToken string
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server") server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver") dns := flag.Bool("dns", false, "Run DNS resolver")
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse() flag.Parse()
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate, Migrate: *migrate,
Scheduler: *scheduler, Scheduler: *scheduler,
Dns: *dns, Dns: *dns,
DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort, DnsPort: *dnsPort,
OauthConfig: oauthConfig, OauthConfig: oauthConfig,

View File

@ -9,6 +9,12 @@ import (
"time" "time"
) )
type DomainOwner struct {
UserID string `json:"user_id"`
Domain string `json:"domain"`
CreatedAt time.Time `json:"created_at"`
}
type DNSRecord struct { type DNSRecord struct {
ID string `json:"id"` ID string `json:"id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
log.Println("saving dns record", record.ID) log.Println("saving dns record", record.ID)
if (record.CreatedAt == time.Time{}) {
record.CreatedAt = time.Now() record.CreatedAt = time.Now()
}
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
if err != nil { if err != nil {
@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
return records, nil return records, nil
} }
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
log.Println("saving domain owner", domainOwner.Domain)
domainOwner.CreatedAt = time.Now()
_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
if err != nil {
return nil, err
}
return domainOwner, nil
}

View File

@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
return nil return nil
} }
func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
log.Println("saving session", session.ID)
_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
if err != nil {
log.Println(err)
return nil, err
}
return session, nil
}
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
newExpireAt := time.Now().Add(ExpiryDuration) newExpireAt := time.Now().Add(ExpiryDuration)

View File

@ -1,4 +1,4 @@
package dns package hcdns
import ( import (
"database/sql" "database/sql"
@ -9,27 +9,28 @@ import (
"log" "log"
) )
const MAX_RECURSION = 10 const MAX_RECURSION = 15
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
}
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers := []dns.RR{} var answers []dns.RR
for _, record := range internalCnames { for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil { if err != nil {
log.Println(err)
return nil, err return nil, err
} }
answers = append(answers, cname) answers = append(answers, cname)
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
if err != nil {
log.Println(err)
return nil, err
}
answers = append(answers, cnameRecursive...) answers = append(answers, cnameRecursive...)
} }
@ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err return nil, err
} }
for _, record := range typeDnsRecords { for _, record := range typeDnsRecords {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil { if err != nil {
return nil, err return nil, err
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
if len(answers) > 0 {
// base case; we found the answer
return answers, nil return answers, nil
}
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
if maxDepth == 0 {
return nil, fmt.Errorf("too much recursion")
} }
message := new(dns.Msg) answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
message.SetQuestion(dns.Fqdn(domain), qtype) if err != nil {
message.RecursionDesired = true
client := new(dns.Client)
i := 0
in, _, err := client.Exchange(message, dnsResolvers[i])
for err != nil {
i += 1
if i == len(dnsResolvers) {
log.Println(err)
return nil, err return nil, err
} }
in, _, err = client.Exchange(message, dnsResolvers[i])
}
answers = append(answers, in.Answer...)
return answers, nil return answers, nil
} }
@ -87,21 +78,26 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true msg.Authoritative = true
for _, question := range r.Question { for _, question := range r.Question {
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
continue msg.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(msg)
return
} }
msg.Answer = append(msg.Answer, answers...) msg.Answer = append(msg.Answer, answers...)
} }
if len(msg.Answer) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
}
log.Println(msg.Answer) log.Println(msg.Answer)
w.WriteMsg(msg) w.WriteMsg(msg)
} }
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{ handler := &DnsHandler{
DnsResolvers: argv.DnsRecursion,
DbConn: dbConn, DbConn: dbConn,
} }
addr := fmt.Sprintf(":%d", argv.DnsPort) addr := fmt.Sprintf(":%d", argv.DnsPort)

254
hcdns/server_test.go Normal file
View File

@ -0,0 +1,254 @@
package hcdns_test
import (
"database/sql"
"fmt"
"math/rand"
"os"
"sync"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"github.com/miekg/dns"
)
func randomPort() int {
return rand.Intn(3000) + 5192
}
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
randomDb := utils.RandomId()
dnsPort := randomPort()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
testUser := &database.User{
ID: "test",
}
database.FindOrSaveUser(testDb, testUser)
waitLock := &sync.Mutex{}
server := hcdns.MakeServer(&args.Arguments{
DnsPort: dnsPort,
}, testDb)
server.NotifyStartedFunc = func() {
waitLock.Unlock()
}
waitLock.Lock()
go func() {
server.ListenAndServe()
}()
waitLock.Lock()
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
return testDb, server, &address, waitLock, func() {
server.Shutdown()
testDb.Close()
os.Remove(randomDb)
}
}
func TestWhenCNAMEIsResolved(t *testing.T) {
t.Log("TestWhenCNAMEIsResolved")
testDb, _, addr, lock, cleanup := setup()
defer cleanup()
defer lock.Unlock()
records := []*database.DNSRecord{
{
ID: "0",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "next.internal.example.com.",
TTL: 300,
Internal: true,
}, {
ID: "1",
UserID: "test",
Name: "next.internal.example.com.",
Type: "CNAME",
Content: "res.example.com.",
TTL: 300,
Internal: true,
},
{
ID: "2",
UserID: "test",
Name: "res.example.com.",
Type: "A",
Content: "1.2.3.2",
TTL: 300,
Internal: true,
},
}
for _, record := range records {
database.SaveDNSRecord(testDb, record)
}
qtype := dns.TypeA
domain := dns.Fqdn("cname.internal.example.com.")
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 3 {
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
}
for i, record := range records {
if in.Answer[i].Header().Name != record.Name {
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
}
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
}
if int(in.Answer[i].Header().Ttl) != record.TTL {
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
}
if !in.Authoritative {
t.Fatalf("expected authoritative response")
}
}
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
t.Fatalf("expected final record to be the A record with correct IP")
}
}
func TestWhenNoRecordNxDomain(t *testing.T) {
t.Log("TestWhenNoRecordNxDomain")
_, _, addr, lock, cleanup := setup()
defer cleanup()
defer lock.Unlock()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
if in.Rcode != dns.RcodeNameError {
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
}
}
func TestWhenUnresolvingCNAME(t *testing.T) {
t.Log("TestWhenUnresolvingCNAME")
testDb, _, addr, lock, cleanup := setup()
defer cleanup()
defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "nonexistant.example.com.",
TTL: 300,
Internal: true,
}
database.SaveDNSRecord(testDb, cname)
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) != 1 {
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
}
if !in.Authoritative {
t.Fatalf("expected authoritative response")
}
if in.Answer[0].Header().Name != cname.Name {
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
}
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
}
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
}
if in.Rcode == dns.RcodeNameError {
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
}
}
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
testDb, _, addr, lock, cleanup := setup()
defer cleanup()
defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
UserID: "test",
Name: "cname.internal.example.com.",
Type: "CNAME",
Content: "cname.internal.example.com.",
TTL: 300,
Internal: true,
}
database.SaveDNSRecord(testDb, cname)
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
client := &dns.Client{}
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
if err != nil {
t.Fatal(err)
}
if len(in.Answer) > 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
if in.Rcode != dns.RcodeServerFailure {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}
}

View File

@ -6,7 +6,7 @@ import (
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" "git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
@ -52,7 +52,7 @@ func main() {
} }
if argv.Dns { if argv.Dns {
server := dns.MakeServer(argv, dbConn) server := hcdns.MakeServer(argv, dbConn)
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
go func() { go func() {
err = server.ListenAndServe() err = server.ListenAndServe()

View File

@ -15,6 +15,22 @@
padding: 0; padding: 0;
color: var(--text-color); color: var(--text-color);
font-family: "ComicSans", sans-serif; font-family: "ComicSans", sans-serif;
cursor: url("/static/img/cursor-1.png"), auto;
-webkit-animation: cursor 400ms infinite;
animation: cursor 400ms infinite;
}
@-webkit-keyframes cursor {
0% {cursor: url("/static/img/cursor-2.png"), auto;}
50% {cursor: url("/static/img/cursor-1.png"), auto;}
100% {cursor: url("/static/img/cursor-2.png"), auto;}
}
@keyframes cursor {
0% {cursor: url("/static/img/cursor-2.png"), auto;}
50% {cursor: url("/static/img/cursor-1.png"), auto;}
100% {cursor: url("/static/img/cursor-2.png"), auto;}
} }
body { body {

BIN
static/img/cursor-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 570 B

BIN
static/img/cursor-2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 563 B