Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
Reviewed-on: #5
This commit is contained in:
commit
83cc6267fd
|
@ -2,3 +2,4 @@
|
|||
hatecomputers.club
|
||||
Dockerfile
|
||||
*.db
|
||||
.drone.yml
|
||||
|
|
31
.drone.yml
31
.drone.yml
|
@ -1,9 +1,30 @@
|
|||
---
|
||||
kind: pipeline
|
||||
type: docker
|
||||
name: build, publish docker image, deploy
|
||||
name: build
|
||||
|
||||
steps:
|
||||
- name: run tests
|
||||
image: golang
|
||||
commands:
|
||||
- go build
|
||||
- go test -p 1 -v ./...
|
||||
|
||||
trigger:
|
||||
event:
|
||||
- pull_request
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
type: docker
|
||||
name: deploy
|
||||
|
||||
steps:
|
||||
- name: run tests
|
||||
image: golang
|
||||
commands:
|
||||
- go build
|
||||
- go test -p 1 -v ./...
|
||||
- name: docker
|
||||
image: plugins/docker
|
||||
settings:
|
||||
|
@ -13,9 +34,6 @@ steps:
|
|||
from_secret: gitea_packpub_password
|
||||
registry: git.hatecomputers.club
|
||||
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
|
||||
tags:
|
||||
- latest
|
||||
- main
|
||||
- name: ssh
|
||||
image: appleboy/drone-ssh
|
||||
settings:
|
||||
|
@ -27,6 +45,9 @@ steps:
|
|||
command_timeout: 2m
|
||||
script:
|
||||
- systemctl restart docker-compose@hatecomputers-club
|
||||
|
||||
trigger:
|
||||
branch:
|
||||
- main
|
||||
- main
|
||||
event:
|
||||
- push
|
||||
|
|
|
@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
|
|||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
|
||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
|
||||
|
|
|
@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
|
|||
Result database.DNSRecord `json:"result"`
|
||||
}
|
||||
|
||||
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) {
|
||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
|
||||
type CloudflareExternalDNSAdapter struct {
|
||||
ZoneId string
|
||||
APIToken string
|
||||
}
|
||||
|
||||
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
||||
|
||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||
payload := strings.NewReader(reqBody)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, payload)
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
||||
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
|
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
|
|||
return result.ID, nil
|
||||
}
|
||||
|
||||
func DeleteDNSRecord(zoneId string, apiToken string, id string) error {
|
||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id)
|
||||
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
||||
|
||||
req, _ := http.NewRequest("DELETE", url, nil)
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
||||
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
package external_dns
|
||||
|
||||
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
|
||||
type ExternalDNSAdapter interface {
|
||||
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
||||
DeleteDNSRecord(id string) error
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
|
@ -12,13 +12,14 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
verifier := utils.RandomId() + utils.RandomId()
|
||||
|
||||
sha2 := sha256.New()
|
||||
|
@ -34,7 +35,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
|||
Path: "/",
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 60,
|
||||
MaxAge: 200,
|
||||
})
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "state",
|
||||
|
@ -42,7 +43,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
|||
Path: "/",
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 60,
|
||||
MaxAge: 200,
|
||||
})
|
||||
|
||||
http.Redirect(resp, req, url, http.StatusFound)
|
||||
|
@ -50,8 +51,8 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
|||
}
|
||||
}
|
||||
|
||||
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
state := req.URL.Query().Get("state")
|
||||
code := req.URL.Query().Get("code")
|
||||
|
||||
|
@ -73,7 +74,6 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
|||
reqContext := req.Context()
|
||||
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
@ -101,6 +101,16 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
|||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: true,
|
||||
})
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "verifier",
|
||||
Value: "",
|
||||
MaxAge: 0,
|
||||
})
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "state",
|
||||
Value: "",
|
||||
MaxAge: 0,
|
||||
})
|
||||
|
||||
redirect := "/"
|
||||
redirectCookie, err := req.Cookie("redirect")
|
||||
|
@ -109,6 +119,7 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
|||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "redirect",
|
||||
MaxAge: 0,
|
||||
Value: "",
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -117,6 +128,127 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
|||
}
|
||||
}
|
||||
|
||||
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
||||
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err == nil && sessionCookie.Value != "" {
|
||||
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
|
||||
if userErr != nil || user == nil {
|
||||
log.Println(userErr, user)
|
||||
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "session",
|
||||
Value: "",
|
||||
MaxAge: 0,
|
||||
})
|
||||
|
||||
context.User = nil
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
context.User = user
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "redirect",
|
||||
Value: req.URL.Path,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(resp, req, "/login", http.StatusFound)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err != nil {
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
||||
if err != nil {
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err == nil && sessionCookie.Value != "" {
|
||||
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "session",
|
||||
MaxAge: 0,
|
||||
Value: "",
|
||||
})
|
||||
http.Redirect(resp, req, "/", http.StatusFound)
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
||||
userResponse, err := client.Get(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userStruct, err := createUserFromOauthResponse(userResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
||||
user := &database.User{}
|
||||
err := json.NewDecoder(response.Body).Decode(user)
|
||||
defer response.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.Username = strings.ToLower(user.Username)
|
||||
user.Username = strings.Split(user.Username, "@")[0]
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
||||
cookie, err := req.Cookie(stateCookieName)
|
||||
if err != nil || cookie.Value != expectedState {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
||||
if bearerToken == "" {
|
||||
return nil, nil
|
||||
|
@ -127,15 +259,15 @@ func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User,
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
apiKey, err := database.GetAPIKey(dbConn, parts[1])
|
||||
key, err := database.GetAPIKey(dbConn, parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey == nil {
|
||||
if key == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
user, err := database.GetUser(dbConn, apiKey.UserID)
|
||||
user, err := database.GetUser(dbConn, key.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -162,124 +294,3 @@ func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error
|
|||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
||||
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err == nil && sessionCookie.Value != "" {
|
||||
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
|
||||
if userErr != nil || user == nil {
|
||||
log.Println(userErr, user)
|
||||
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "session",
|
||||
MaxAge: 0, // reset session cookie in case
|
||||
})
|
||||
|
||||
context.User = nil
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
context.User = user
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "redirect",
|
||||
Value: req.URL.Path,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(resp, req, "/login", http.StatusFound)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err == nil && sessionCookie.Value != "" {
|
||||
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
|
||||
http.Redirect(resp, req, "/", http.StatusFound)
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "session",
|
||||
MaxAge: 0,
|
||||
})
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
||||
userResponse, err := client.Get(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userStruct, err := createUserFromResponse(userResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func createUserFromResponse(response *http.Response) (*database.User, error) {
|
||||
defer response.Body.Close()
|
||||
user := &database.User{
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
err := json.NewDecoder(response.Body).Decode(user)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.Username = strings.ToLower(user.Username)
|
||||
user.Username = strings.Split(user.Username, "@")[0]
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
||||
cookie, err := req.Cookie(stateCookieName)
|
||||
if err != nil || cookie.Value != expectedState {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,307 @@
|
|||
package auth_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func FakedOauthServer() *httptest.Server {
|
||||
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/auth" {
|
||||
code := utils.RandomId()
|
||||
|
||||
state := r.URL.Query().Get("state")
|
||||
redirectPath := r.URL.Query().Get("redirect_uri")
|
||||
redirectPath += "?code=" + code + "&state=" + state
|
||||
|
||||
http.Redirect(w, r, redirectPath, http.StatusFound)
|
||||
}
|
||||
if r.URL.Path == "/token" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
|
||||
}
|
||||
if r.URL.Path == "/user" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
|
||||
}
|
||||
}))
|
||||
|
||||
return oauthServer
|
||||
}
|
||||
|
||||
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
resp.Write([]byte(context.User.Username))
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/protected-path" {
|
||||
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
||||
}
|
||||
|
||||
if r.URL.Path == "/login" {
|
||||
log.Println("login")
|
||||
auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||
}
|
||||
|
||||
if r.URL.Path == "/callback" {
|
||||
log.Println("callback")
|
||||
auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||
}
|
||||
|
||||
if r.URL.Path == "/me" {
|
||||
auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
|
||||
}
|
||||
|
||||
if r.URL.Path == "/logout" {
|
||||
auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||
}
|
||||
}))
|
||||
return testServer
|
||||
}
|
||||
|
||||
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
|
||||
randomDb := utils.RandomId()
|
||||
|
||||
testDb := database.MakeConn(&randomDb)
|
||||
database.Migrate(testDb)
|
||||
|
||||
context := &types.RequestContext{
|
||||
DBConn: testDb,
|
||||
Args: &args.Arguments{},
|
||||
TemplateData: &(map[string]interface{}{}),
|
||||
}
|
||||
|
||||
oauthServer := FakedOauthServer()
|
||||
testServer := MockUserEndpointServer(context)
|
||||
|
||||
return testDb, context, oauthServer, testServer, func() {
|
||||
oauthServer.Close()
|
||||
testServer.Close()
|
||||
|
||||
testDb.Close()
|
||||
os.Remove(randomDb)
|
||||
}
|
||||
}
|
||||
|
||||
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
||||
return &oauth2.Config{
|
||||
ClientID: "test",
|
||||
ClientSecret: "test",
|
||||
Scopes: []string{"test"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: oauthServerURL + "/auth",
|
||||
TokenURL: oauthServerURL + "/token",
|
||||
},
|
||||
RedirectURL: testServerURL + "/callback",
|
||||
}, oauthServerURL + "/user"
|
||||
}
|
||||
|
||||
func FollowAuthentication(
|
||||
oauthServer *httptest.Server,
|
||||
testServer *httptest.Server,
|
||||
cookies map[string]*http.Cookie,
|
||||
location string,
|
||||
) (map[string]*http.Cookie, string) {
|
||||
resp := httptest.NewRecorder()
|
||||
resp.Code = 0
|
||||
|
||||
for resp.Code == 0 || resp.Code == http.StatusFound {
|
||||
req := httptest.NewRequest("GET", location, nil)
|
||||
resp = httptest.NewRecorder()
|
||||
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
if strings.HasPrefix(location, oauthServer.URL) {
|
||||
oauthServer.Config.Handler.ServeHTTP(resp, req)
|
||||
} else {
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
}
|
||||
for _, cookie := range resp.Result().Cookies() {
|
||||
cookies[cookie.Name] = cookie
|
||||
}
|
||||
|
||||
if resp.Code == http.StatusFound {
|
||||
location = resp.Header().Get("Location")
|
||||
}
|
||||
}
|
||||
|
||||
return cookies, location
|
||||
}
|
||||
|
||||
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
||||
db, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
user, _ := database.GetUser(db, "test")
|
||||
if user != nil {
|
||||
t.Errorf("expected no user, got user")
|
||||
}
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||
|
||||
user, _ = database.GetUser(db, "test")
|
||||
if user == nil {
|
||||
t.Errorf("expected a user to be created, could not find user")
|
||||
}
|
||||
if user.Username != "test" {
|
||||
t.Errorf("expected username to be test, got %s", user.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
||||
_, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected-path", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
location := resp.Header().Get("Location")
|
||||
if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
|
||||
t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||
}
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
|
||||
|
||||
if !(strings.HasSuffix(location, "/protected-page")) {
|
||||
t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthSetsUniqueSession(t *testing.T) {
|
||||
db, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||
|
||||
cookiesAgain := make(map[string]*http.Cookie)
|
||||
cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
|
||||
|
||||
sessionOne := cookies["session"].Value
|
||||
sessionTwo := cookiesAgain["session"].Value
|
||||
if sessionOne == sessionTwo {
|
||||
t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
|
||||
}
|
||||
|
||||
session, _ := database.GetSession(db, sessionOne)
|
||||
if session.UserID != "test" {
|
||||
t.Errorf("expected session to be associated with user test, got %s", session.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutClearsSession(t *testing.T) {
|
||||
db, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||
|
||||
req := httptest.NewRequest("GET", "/logout", nil)
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
for _, cookie := range resp.Result().Cookies() {
|
||||
cookies[cookie.Name] = cookie
|
||||
}
|
||||
|
||||
req = httptest.NewRequest("GET", "/me", nil)
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||
t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||
}
|
||||
|
||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||
if session != nil {
|
||||
t.Errorf("expected session to be deleted, got session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshUpdatesExpiration(t *testing.T) {
|
||||
db, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||
|
||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||
|
||||
req := httptest.NewRequest("GET", "/me", nil)
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
|
||||
updatedSession, _ := database.GetSession(db, cookies["session"].Value)
|
||||
|
||||
if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
|
||||
t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
||||
db, context, oauthServer, testServer, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||
|
||||
cookies := make(map[string]*http.Cookie)
|
||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||
|
||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||
session.ExpireAt = time.Now().Add(-time.Hour)
|
||||
database.SaveSession(db, session)
|
||||
|
||||
req := httptest.NewRequest("GET", "/me", nil)
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||
t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||
}
|
||||
}
|
179
api/dns.go
179
api/dns.go
|
@ -1,179 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
const MAX_USER_RECORDS = 65
|
||||
|
||||
type FormError struct {
|
||||
Errors []string
|
||||
}
|
||||
|
||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
|
||||
ownedByUser := (user.ID == record.UserID)
|
||||
if !ownedByUser {
|
||||
return false
|
||||
}
|
||||
|
||||
if !record.Internal {
|
||||
userOwnedDomains := []string{
|
||||
fmt.Sprintf("%s", user.Username),
|
||||
fmt.Sprintf("%s.endpoints", user.Username),
|
||||
}
|
||||
|
||||
for _, domain := range userOwnedDomains {
|
||||
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
||||
if domain == record.Name || isInSubDomain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return false
|
||||
}
|
||||
|
||||
userIsOwnerOfDomain := owner == user.ID
|
||||
return ownedByUser && userIsOwnerOfDomain
|
||||
}
|
||||
|
||||
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
formErrors := FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
internal := req.FormValue("internal") == "on"
|
||||
name := req.FormValue("name")
|
||||
if internal && !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
recordType := req.FormValue("type")
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
recordContent := req.FormValue("content")
|
||||
ttl := req.FormValue("ttl")
|
||||
ttlNum, err := strconv.Atoi(ttl)
|
||||
if err != nil {
|
||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||
}
|
||||
|
||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
if dnsRecordCount >= MAX_USER_RECORDS {
|
||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||
}
|
||||
|
||||
dnsRecord := &database.DNSRecord{
|
||||
UserID: context.User.ID,
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
Content: recordContent,
|
||||
TTL: ttlNum,
|
||||
Internal: internal,
|
||||
}
|
||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
|
||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
if dnsRecord.Internal {
|
||||
dnsRecord.ID = utils.RandomId()
|
||||
} else {
|
||||
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||
}
|
||||
|
||||
dnsRecord.ID = cloudflareRecordId
|
||||
}
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
||||
}
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["FormError"] = &formErrors
|
||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
recordId := req.FormValue("id")
|
||||
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
if !record.Internal {
|
||||
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
||||
ownedByUser := (user.ID == record.UserID)
|
||||
if !ownedByUser {
|
||||
return false
|
||||
}
|
||||
|
||||
if !record.Internal {
|
||||
for _, format := range ownedInternalDomainFormats {
|
||||
domain := fmt.Sprintf(format, user.Username)
|
||||
|
||||
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
||||
if domain == record.Name || isInSubDomain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return false
|
||||
}
|
||||
|
||||
userIsOwnerOfDomain := owner == user.ID
|
||||
return ownedByUser && userIsOwnerOfDomain
|
||||
}
|
||||
|
||||
func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
formErrors := types.FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
||||
name := req.FormValue("name")
|
||||
if internal && !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
recordType := req.FormValue("type")
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
recordContent := req.FormValue("content")
|
||||
ttl := req.FormValue("ttl")
|
||||
ttlNum, err := strconv.Atoi(ttl)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||
}
|
||||
|
||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
if dnsRecordCount >= maxUserRecords {
|
||||
resp.WriteHeader(http.StatusTooManyRequests)
|
||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||
}
|
||||
|
||||
dnsRecord := &database.DNSRecord{
|
||||
UserID: context.User.ID,
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
Content: recordContent,
|
||||
TTL: ttlNum,
|
||||
Internal: internal,
|
||||
}
|
||||
|
||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
if dnsRecord.Internal {
|
||||
dnsRecord.ID = utils.RandomId()
|
||||
} else {
|
||||
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
||||
}
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
return success(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["FormError"] = &formErrors
|
||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
recordId := req.FormValue("id")
|
||||
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
if !(record.UserID == context.User.ID) {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
if !record.Internal {
|
||||
err = dnsAdapter.DeleteDNSRecord(recordId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,442 @@
|
|||
package dns_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
const MAX_USER_RECORDS = 10
|
||||
|
||||
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||
|
||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||
randomDb := utils.RandomId()
|
||||
|
||||
testDb := database.MakeConn(&randomDb)
|
||||
database.Migrate(testDb)
|
||||
|
||||
user := &database.User{
|
||||
ID: "test",
|
||||
Username: "test",
|
||||
Mail: "test@test.com",
|
||||
DisplayName: "test",
|
||||
}
|
||||
database.FindOrSaveUser(testDb, user)
|
||||
|
||||
context := &types.RequestContext{
|
||||
DBConn: testDb,
|
||||
Args: &args.Arguments{},
|
||||
TemplateData: &(map[string]interface{}{}),
|
||||
User: user,
|
||||
}
|
||||
|
||||
return testDb, context, func() {
|
||||
testDb.Close()
|
||||
os.Remove(randomDb)
|
||||
}
|
||||
}
|
||||
|
||||
type SignallingExternalDnsAdapter struct {
|
||||
AddChannel chan *database.DNSRecord
|
||||
RmChannel chan string
|
||||
}
|
||||
|
||||
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||
id := utils.RandomId()
|
||||
go func() { adapter.AddChannel <- record }()
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
|
||||
go func() { adapter.RmChannel <- id }()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
domainOwner := &database.DomainOwner{
|
||||
UserID: context.User.ID,
|
||||
Domain: "test.domain.",
|
||||
}
|
||||
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||
|
||||
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
t.Errorf("expected no records, got records")
|
||||
}
|
||||
|
||||
addChannel := make(chan *database.DNSRecord)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
AddChannel: addChannel,
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
validOwner.Form = map[string][]string{
|
||||
"internal": {"on"},
|
||||
"name": {"new.test.domain."},
|
||||
"type": {"CNAME"},
|
||||
"ttl": {"43000"},
|
||||
"content": {"test.domain."},
|
||||
}
|
||||
|
||||
validOwnerRecorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||
if validOwnerRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||
}
|
||||
|
||||
validOwnerNonInternalRecorder := httptest.NewRecorder()
|
||||
validOwner.Form["internal"] = []string{"off"}
|
||||
testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
|
||||
if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
|
||||
}
|
||||
|
||||
invalidOwnerRecorder := httptest.NewRecorder()
|
||||
invalidOwner := validOwner
|
||||
invalidOwner.Form["internal"] = []string{"on"}
|
||||
invalidOwner.Form["name"] = []string{"new.invalid.domain."}
|
||||
testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
|
||||
if invalidOwnerRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
addChannel := make(chan *database.DNSRecord)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
AddChannel: addChannel,
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
|
||||
for _, format := range fmts {
|
||||
name := fmt.Sprintf(format, context.User.Username)
|
||||
|
||||
req.Form = map[string][]string{
|
||||
"internal": {"off"},
|
||||
"name": {name},
|
||||
"type": {"CNAME"},
|
||||
"ttl": {"43000"},
|
||||
"content": {"test.domain."},
|
||||
}
|
||||
|
||||
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||
}
|
||||
|
||||
namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
|
||||
if len(namedRecords) == 0 {
|
||||
t.Errorf("saved record not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThatExternalDnsSaves(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
addChannel := make(chan *database.DNSRecord)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
AddChannel: addChannel,
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
|
||||
name := "test." + context.User.Username
|
||||
externalRequest.Form = map[string][]string{
|
||||
"internal": {"off"},
|
||||
"name": {name},
|
||||
"type": {"CNAME"},
|
||||
"ttl": {"43000"},
|
||||
"content": {"test.domain."},
|
||||
}
|
||||
|
||||
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||
}
|
||||
select {
|
||||
case res := <-addChannel:
|
||||
if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
|
||||
t.Errorf("received the wrong external record")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("timed out in waiting for external addition")
|
||||
}
|
||||
|
||||
domainOwner := &database.DomainOwner{
|
||||
UserID: context.User.ID,
|
||||
Domain: "test.domain.",
|
||||
}
|
||||
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||
internalRequest := externalRequest
|
||||
internalRequest.Form["internal"] = []string{"on"}
|
||||
internalRequest.Form["name"] = []string{"test.domain."}
|
||||
|
||||
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||
}
|
||||
select {
|
||||
case _ = <-addChannel:
|
||||
t.Errorf("expected nothing in the add channel")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
rmChannel := make(chan string)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
RmChannel: rmChannel,
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
|
||||
_, err := database.FindOrSaveUser(db, nonOwnerUser)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
record := &database.DNSRecord{
|
||||
ID: "1",
|
||||
Internal: false,
|
||||
Name: "test",
|
||||
Type: "CNAME",
|
||||
Content: "asdf",
|
||||
TTL: 1000,
|
||||
UserID: nonOwnerUser.ID,
|
||||
}
|
||||
_, err = database.SaveDNSRecord(db, record)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
nonOwnerRecorder := httptest.NewRecorder()
|
||||
nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
nonOwner.Form = map[string][]string{
|
||||
"id": {record.ID},
|
||||
}
|
||||
|
||||
testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
|
||||
if nonOwnerRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
|
||||
}
|
||||
|
||||
record.UserID = context.User.ID
|
||||
record.ID = "2"
|
||||
database.SaveDNSRecord(db, record)
|
||||
|
||||
owner := nonOwner
|
||||
owner.Form["id"] = []string{"2"}
|
||||
ownerRecorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
|
||||
if ownerRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", ownerRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThatExternalDnsRemoves(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
record := &database.DNSRecord{
|
||||
ID: "1",
|
||||
Internal: false,
|
||||
Name: "test",
|
||||
Type: "CNAME",
|
||||
Content: "asdf",
|
||||
TTL: 1000,
|
||||
UserID: context.User.ID,
|
||||
}
|
||||
database.SaveDNSRecord(db, record)
|
||||
|
||||
rmChannel := make(chan string)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
RmChannel: rmChannel,
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
externalResponseRecorder := httptest.NewRecorder()
|
||||
deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
|
||||
deleteRequest.Form = map[string][]string{
|
||||
"id": {record.ID},
|
||||
}
|
||||
|
||||
testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
|
||||
if externalResponseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
|
||||
}
|
||||
select {
|
||||
case res := <-rmChannel:
|
||||
if res != record.ID {
|
||||
t.Errorf("received the wrong external record")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("timed out in waiting for external addition")
|
||||
}
|
||||
|
||||
record.Internal = true
|
||||
record.Name = "test.domain."
|
||||
database.SaveDNSRecord(db, record)
|
||||
domainOwner := &database.DomainOwner{
|
||||
UserID: context.User.ID,
|
||||
Domain: "test.domain.",
|
||||
}
|
||||
database.SaveDomainOwner(db, domainOwner)
|
||||
|
||||
internalResponseRecorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
|
||||
if internalResponseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
|
||||
}
|
||||
select {
|
||||
case _ = <-rmChannel:
|
||||
t.Errorf("expected nothing in the rmchannel")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCountCannotExceed(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
record := &database.DNSRecord{
|
||||
Internal: false,
|
||||
Name: context.User.Username,
|
||||
Type: "CNAME",
|
||||
Content: "asdf",
|
||||
TTL: 1000,
|
||||
UserID: context.User.ID,
|
||||
}
|
||||
|
||||
for i := 1; i <= MAX_USER_RECORDS; i++ {
|
||||
record.ID = strconv.Itoa(i)
|
||||
record.Name = record.ID + "." + record.Name
|
||||
database.SaveDNSRecord(db, record)
|
||||
}
|
||||
|
||||
addChannel := make(chan *database.DNSRecord)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
AddChannel: addChannel,
|
||||
}
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
req.Form = map[string][]string{
|
||||
"internal": {"off"},
|
||||
"name": {record.Name},
|
||||
"type": {record.Type},
|
||||
"ttl": {"43000"},
|
||||
"content": {record.Content},
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(recorder, req)
|
||||
if recorder.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected too many requests code return, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
domainOwner := &database.DomainOwner{
|
||||
UserID: context.User.ID,
|
||||
Domain: "test.internal.",
|
||||
}
|
||||
database.SaveDomainOwner(db, domainOwner)
|
||||
|
||||
addChannel := make(chan *database.DNSRecord)
|
||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||
AddChannel: addChannel,
|
||||
}
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
validOwner.Form = map[string][]string{
|
||||
"internal": {"on"},
|
||||
"name": {"test.internal"},
|
||||
"type": {"CNAME"},
|
||||
"ttl": {"43000"},
|
||||
"content": {"asdf.internal"},
|
||||
}
|
||||
|
||||
validOwnerRecorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||
if validOwnerRecorder.Code != http.StatusOK {
|
||||
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||
}
|
||||
|
||||
recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
|
||||
recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
|
||||
|
||||
if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
|
||||
t.Errorf("expected dot appended")
|
||||
}
|
||||
}
|
141
api/guestbook.go
141
api/guestbook.go
|
@ -1,141 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
type HcaptchaArgs struct {
|
||||
SiteKey string
|
||||
}
|
||||
|
||||
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
||||
errors := []string{}
|
||||
|
||||
if entry.Name == "" {
|
||||
errors = append(errors, "name is required")
|
||||
}
|
||||
|
||||
if entry.Message == "" {
|
||||
errors = append(errors, "message is required")
|
||||
}
|
||||
|
||||
messageLength := len(entry.Message)
|
||||
if messageLength > 500 {
|
||||
errors = append(errors, "message cannot be longer than 500 characters")
|
||||
}
|
||||
|
||||
newLines := strings.Count(entry.Message, "\n")
|
||||
if newLines > 10 {
|
||||
errors = append(errors, "message cannot contain more than 10 new lines")
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
name := req.FormValue("name")
|
||||
message := req.FormValue("message")
|
||||
hCaptchaResponse := req.FormValue("h-captcha-response")
|
||||
|
||||
formErrors := FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
if hCaptchaResponse == "" {
|
||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
|
||||
}
|
||||
|
||||
entry := &database.GuestbookEntry{
|
||||
ID: utils.RandomId(),
|
||||
Name: name,
|
||||
Message: message,
|
||||
}
|
||||
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
||||
|
||||
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
|
||||
}
|
||||
if len(formErrors.Errors) > 0 {
|
||||
(*context.TemplateData)["FormError"] = formErrors
|
||||
(*context.TemplateData)["EntryForm"] = entry
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
entries, err := database.GetGuestbookEntries(context.DBConn)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["GuestbookEntries"] = entries
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
||||
SiteKey: context.Args.HcaptchaSiteKey,
|
||||
}
|
||||
log.Println(context.Args.HcaptchaSiteKey)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyHCaptcha(secret, response string) error {
|
||||
verifyURL := "https://hcaptcha.com/siteverify"
|
||||
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
||||
|
||||
req, err := http.NewRequest("POST", verifyURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonResponse := struct {
|
||||
Success bool `json:"success"`
|
||||
}{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !jsonResponse.Success {
|
||||
return fmt.Errorf("hcaptcha verification failed")
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
package guestbook
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
||||
errors := []string{}
|
||||
|
||||
if entry.Name == "" {
|
||||
errors = append(errors, "name is required")
|
||||
}
|
||||
|
||||
if entry.Message == "" {
|
||||
errors = append(errors, "message is required")
|
||||
}
|
||||
|
||||
messageLength := len(entry.Message)
|
||||
if messageLength > 500 {
|
||||
errors = append(errors, "message cannot be longer than 500 characters")
|
||||
}
|
||||
|
||||
newLines := strings.Count(entry.Message, "\n")
|
||||
if newLines > 10 {
|
||||
errors = append(errors, "message cannot contain more than 10 new lines")
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
name := req.FormValue("name")
|
||||
message := req.FormValue("message")
|
||||
|
||||
formErrors := types.FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
entry := &database.GuestbookEntry{
|
||||
ID: utils.RandomId(),
|
||||
Name: name,
|
||||
Message: message,
|
||||
}
|
||||
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
||||
|
||||
if len(formErrors.Errors) == 0 {
|
||||
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
|
||||
}
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) > 0 {
|
||||
(*context.TemplateData)["FormError"] = formErrors
|
||||
(*context.TemplateData)["EntryForm"] = entry
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
entries, err := database.GetGuestbookEntries(context.DBConn)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["GuestbookEntries"] = entries
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
package guestbook_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||
randomDb := utils.RandomId()
|
||||
|
||||
testDb := database.MakeConn(&randomDb)
|
||||
database.Migrate(testDb)
|
||||
|
||||
context := &types.RequestContext{
|
||||
DBConn: testDb,
|
||||
Args: &args.Arguments{},
|
||||
TemplateData: &(map[string]interface{}{}),
|
||||
}
|
||||
|
||||
return testDb, context, func() {
|
||||
testDb.Close()
|
||||
os.Remove(randomDb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidGuestbookPutsInDatabase(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
entries, err := database.GetGuestbookEntries(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) > 0 {
|
||||
t.Errorf("expected no entries, got entries")
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("POST", ts.URL, nil)
|
||||
req.Form = map[string][]string{
|
||||
"name": {"test"},
|
||||
"message": {"test"},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ts.Config.Handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
entries, err = database.GetGuestbookEntries(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("expected 1 entry, got %d", len(entries))
|
||||
}
|
||||
|
||||
if entries[0].Name != req.FormValue("name") {
|
||||
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
|
||||
db, context, cleanup := setup()
|
||||
defer cleanup()
|
||||
|
||||
entries, err := database.GetGuestbookEntries(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) > 0 {
|
||||
t.Errorf("expected no entries, got entries")
|
||||
}
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
|
||||
invalidRequests := []struct {
|
||||
name string
|
||||
message string
|
||||
}{
|
||||
{"", "test"},
|
||||
{"test", ""},
|
||||
{"", ""},
|
||||
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
|
||||
}
|
||||
|
||||
for _, form := range invalidRequests {
|
||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||
req.Form = map[string][]string{
|
||||
"name": {form.name},
|
||||
"message": {form.message},
|
||||
}
|
||||
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||
|
||||
if responseRecorder.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err = database.GetGuestbookEntries(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("expected 0 entries, got %d", len(entries))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package hcaptcha
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
)
|
||||
|
||||
type HcaptchaArgs struct {
|
||||
SiteKey string
|
||||
}
|
||||
|
||||
func verifyCaptcha(secret, response string) error {
|
||||
verifyURL := "https://hcaptcha.com/siteverify"
|
||||
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
||||
|
||||
req, err := http.NewRequest("POST", verifyURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonResponse := struct {
|
||||
Success bool `json:"success"`
|
||||
}{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !jsonResponse.Success {
|
||||
return fmt.Errorf("hcaptcha verification failed")
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
||||
SiteKey: context.Args.HcaptchaSiteKey,
|
||||
}
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
hCaptchaResponse := req.FormValue("h-captcha-response")
|
||||
secretKey := context.Args.HcaptchaSecret
|
||||
|
||||
err := verifyCaptcha(secretKey, hCaptchaResponse)
|
||||
if err != nil {
|
||||
(*context.TemplateData)["FormError"] = types.FormError{
|
||||
Errors: []string{"hCaptcha verification failed"},
|
||||
}
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
|
@ -1,32 +1,33 @@
|
|||
package api
|
||||
package keys
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
const MAX_USER_API_KEYS = 5
|
||||
|
||||
func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||
func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["APIKeys"] = apiKeys
|
||||
(*context.TemplateData)["APIKeys"] = typesKeys
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
formErrors := FormError{
|
||||
func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
formErrors := types.FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
|
@ -38,7 +39,7 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
|
|||
}
|
||||
|
||||
if numKeys >= MAX_USER_API_KEYS {
|
||||
formErrors.Errors = append(formErrors.Errors, "max api keys reached")
|
||||
formErrors.Errors = append(formErrors.Errors, "max types keys reached")
|
||||
}
|
||||
|
||||
if len(formErrors.Errors) > 0 {
|
||||
|
@ -59,29 +60,28 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
|
|||
}
|
||||
}
|
||||
|
||||
func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
key := req.FormValue("key")
|
||||
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
apiKey := req.FormValue("key")
|
||||
|
||||
apiKey, err := database.GetAPIKey(context.DBConn, key)
|
||||
key, err := database.GetAPIKey(context.DBConn, apiKey)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
|
||||
if (key == nil) || (key.UserID != context.User.ID) {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
err = database.DeleteAPIKey(context.DBConn, key)
|
||||
err = database.DeleteAPIKey(context.DBConn, apiKey)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
http.Redirect(resp, req, "/keys", http.StatusFound)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
90
api/serve.go
90
api/serve.go
|
@ -7,27 +7,20 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
type RequestContext struct {
|
||||
DBConn *sql.DB
|
||||
Args *args.Arguments
|
||||
|
||||
Id string
|
||||
Start time.Time
|
||||
|
||||
TemplateData *map[string]interface{}
|
||||
User *database.User
|
||||
}
|
||||
|
||||
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
||||
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
||||
|
||||
func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||
func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
context.Start = time.Now()
|
||||
context.Id = utils.RandomId()
|
||||
|
||||
|
@ -36,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt
|
|||
}
|
||||
}
|
||||
|
||||
func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||
func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
end := time.Now()
|
||||
|
||||
log.Println(context.Id, "took", end.Sub(context.Start))
|
||||
|
@ -46,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re
|
|||
}
|
||||
}
|
||||
|
||||
func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||
func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
resp.WriteHeader(200)
|
||||
resp.Write([]byte("healthy"))
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(_success Continuation, failure Continuation) ContinuationChain {
|
||||
func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
@ -80,89 +73,90 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
|||
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
||||
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
||||
|
||||
makeRequestContext := func() *RequestContext {
|
||||
return &RequestContext{
|
||||
DBConn: dbConn,
|
||||
Args: argv,
|
||||
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
|
||||
APIToken: argv.CloudflareToken,
|
||||
ZoneId: argv.CloudflareZone,
|
||||
}
|
||||
|
||||
makeRequestContext := func() *types.RequestContext {
|
||||
return &types.RequestContext{
|
||||
DBConn: dbConn,
|
||||
Args: argv,
|
||||
TemplateData: &map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
const MAX_USER_RECORDS = 100
|
||||
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
name := r.PathValue("name")
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
return &http.Server{
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package template
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -7,9 +7,11 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||
)
|
||||
|
||||
func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
||||
func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
||||
templatePath := context.Args.TemplatePath
|
||||
basePath := templatePath + "/base_empty.html"
|
||||
if showBaseHtml {
|
||||
|
@ -41,9 +43,9 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b
|
|||
return buffer, nil
|
||||
}
|
||||
|
||||
func TemplateContinuation(path string, showBase bool) Continuation {
|
||||
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
func TemplateContinuation(path string, showBase bool) types.Continuation {
|
||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||
html, err := renderTemplate(context, path, true)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
resp.WriteHeader(404)
|
||||
|
@ -66,7 +68,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
|
|||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
resp.WriteHeader(200)
|
||||
resp.Header().Set("Content-Type", "text/html")
|
||||
resp.Write(html.Bytes())
|
||||
return success(context, req, resp)
|
|
@ -0,0 +1,28 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
)
|
||||
|
||||
type RequestContext struct {
|
||||
DBConn *sql.DB
|
||||
Args *args.Arguments
|
||||
|
||||
Id string
|
||||
Start time.Time
|
||||
|
||||
TemplateData *map[string]interface{}
|
||||
User *database.User
|
||||
}
|
||||
|
||||
type FormError struct {
|
||||
Errors []string
|
||||
}
|
||||
|
||||
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
||||
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
|
@ -22,9 +22,8 @@ type Arguments struct {
|
|||
OauthConfig *oauth2.Config
|
||||
OauthUserInfoURI string
|
||||
|
||||
Dns bool
|
||||
DnsRecursion []string
|
||||
DnsPort int
|
||||
Dns bool
|
||||
DnsPort int
|
||||
|
||||
CloudflareToken string
|
||||
CloudflareZone string
|
||||
|
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
|
|||
server := flag.Bool("server", false, "Run the server")
|
||||
|
||||
dns := flag.Bool("dns", false, "Run DNS resolver")
|
||||
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
|
||||
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
|
||||
|
||||
flag.Parse()
|
||||
|
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
|
|||
Migrate: *migrate,
|
||||
Scheduler: *scheduler,
|
||||
Dns: *dns,
|
||||
DnsRecursion: strings.Split(*dnsRecursion, ","),
|
||||
DnsPort: *dnsPort,
|
||||
|
||||
OauthConfig: oauthConfig,
|
||||
|
|
|
@ -9,6 +9,12 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type DomainOwner struct {
|
||||
UserID string `json:"user_id"`
|
||||
Domain string `json:"domain"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type DNSRecord struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
|
@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
|
|||
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
||||
log.Println("saving dns record", record.ID)
|
||||
|
||||
record.CreatedAt = time.Now()
|
||||
if (record.CreatedAt == time.Time{}) {
|
||||
record.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
|
@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
|
|||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
|
||||
log.Println("saving domain owner", domainOwner.Domain)
|
||||
|
||||
domainOwner.CreatedAt = time.Now()
|
||||
_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return domainOwner, nil
|
||||
}
|
||||
|
|
|
@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
|
||||
log.Println("saving session", session.ID)
|
||||
|
||||
_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
||||
newExpireAt := time.Now().Add(ExpiryDuration)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package dns
|
||||
package hcdns
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
@ -9,27 +9,28 @@ import (
|
|||
"log"
|
||||
)
|
||||
|
||||
const MAX_RECURSION = 10
|
||||
|
||||
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||
if maxDepth == 0 {
|
||||
return nil, fmt.Errorf("too much recursion")
|
||||
}
|
||||
const MAX_RECURSION = 15
|
||||
|
||||
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
answers := []dns.RR{}
|
||||
var answers []dns.RR
|
||||
for _, record := range internalCnames {
|
||||
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
answers = append(answers, cname)
|
||||
|
||||
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
|
||||
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
answers = append(answers, cnameRecursive...)
|
||||
}
|
||||
|
||||
|
@ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
|
|||
return nil, err
|
||||
}
|
||||
for _, record := range typeDnsRecords {
|
||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
|
||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
answers = append(answers, answer)
|
||||
}
|
||||
|
||||
if len(answers) > 0 {
|
||||
// base case; we found the answer
|
||||
return answers, nil
|
||||
return answers, nil
|
||||
}
|
||||
|
||||
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||
if maxDepth == 0 {
|
||||
return nil, fmt.Errorf("too much recursion")
|
||||
}
|
||||
|
||||
message := new(dns.Msg)
|
||||
message.SetQuestion(dns.Fqdn(domain), qtype)
|
||||
message.RecursionDesired = true
|
||||
|
||||
client := new(dns.Client)
|
||||
|
||||
i := 0
|
||||
in, _, err := client.Exchange(message, dnsResolvers[i])
|
||||
for err != nil {
|
||||
i += 1
|
||||
if i == len(dnsResolvers) {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
in, _, err = client.Exchange(message, dnsResolvers[i])
|
||||
answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
answers = append(answers, in.Answer...)
|
||||
return answers, nil
|
||||
}
|
||||
|
||||
|
@ -87,22 +78,27 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
msg.Authoritative = true
|
||||
|
||||
for _, question := range r.Question {
|
||||
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
|
||||
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
msg.SetRcode(r, dns.RcodeServerFailure)
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
msg.Answer = append(msg.Answer, answers...)
|
||||
}
|
||||
|
||||
if len(msg.Answer) == 0 {
|
||||
msg.SetRcode(r, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
log.Println(msg.Answer)
|
||||
w.WriteMsg(msg)
|
||||
}
|
||||
|
||||
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
||||
handler := &DnsHandler{
|
||||
DnsResolvers: argv.DnsRecursion,
|
||||
DbConn: dbConn,
|
||||
DbConn: dbConn,
|
||||
}
|
||||
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
package hcdns_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func randomPort() int {
|
||||
return rand.Intn(3000) + 5192
|
||||
}
|
||||
|
||||
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
|
||||
randomDb := utils.RandomId()
|
||||
dnsPort := randomPort()
|
||||
|
||||
testDb := database.MakeConn(&randomDb)
|
||||
database.Migrate(testDb)
|
||||
testUser := &database.User{
|
||||
ID: "test",
|
||||
}
|
||||
database.FindOrSaveUser(testDb, testUser)
|
||||
|
||||
waitLock := &sync.Mutex{}
|
||||
server := hcdns.MakeServer(&args.Arguments{
|
||||
DnsPort: dnsPort,
|
||||
}, testDb)
|
||||
server.NotifyStartedFunc = func() {
|
||||
waitLock.Unlock()
|
||||
}
|
||||
waitLock.Lock()
|
||||
|
||||
go func() {
|
||||
server.ListenAndServe()
|
||||
}()
|
||||
waitLock.Lock()
|
||||
|
||||
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
|
||||
return testDb, server, &address, waitLock, func() {
|
||||
server.Shutdown()
|
||||
|
||||
testDb.Close()
|
||||
os.Remove(randomDb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenCNAMEIsResolved(t *testing.T) {
|
||||
t.Log("TestWhenCNAMEIsResolved")
|
||||
|
||||
testDb, _, addr, lock, cleanup := setup()
|
||||
defer cleanup()
|
||||
defer lock.Unlock()
|
||||
|
||||
records := []*database.DNSRecord{
|
||||
{
|
||||
ID: "0",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "next.internal.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}, {
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "next.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "res.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
UserID: "test",
|
||||
Name: "res.example.com.",
|
||||
Type: "A",
|
||||
Content: "1.2.3.2",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
database.SaveDNSRecord(testDb, record)
|
||||
}
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn("cname.internal.example.com.")
|
||||
client := &dns.Client{}
|
||||
message := &dns.Msg{}
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, *addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 3 {
|
||||
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
for i, record := range records {
|
||||
if in.Answer[i].Header().Name != record.Name {
|
||||
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
|
||||
}
|
||||
|
||||
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
|
||||
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
|
||||
}
|
||||
|
||||
if int(in.Answer[i].Header().Ttl) != record.TTL {
|
||||
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
|
||||
}
|
||||
|
||||
if !in.Authoritative {
|
||||
t.Fatalf("expected authoritative response")
|
||||
}
|
||||
}
|
||||
|
||||
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
|
||||
t.Fatalf("expected final record to be the A record with correct IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenNoRecordNxDomain(t *testing.T) {
|
||||
t.Log("TestWhenNoRecordNxDomain")
|
||||
|
||||
_, _, addr, lock, cleanup := setup()
|
||||
defer cleanup()
|
||||
defer lock.Unlock()
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn("nonexistant.example.com.")
|
||||
client := &dns.Client{}
|
||||
message := &dns.Msg{}
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, *addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 0 {
|
||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if in.Rcode != dns.RcodeNameError {
|
||||
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenUnresolvingCNAME(t *testing.T) {
|
||||
t.Log("TestWhenUnresolvingCNAME")
|
||||
|
||||
testDb, _, addr, lock, cleanup := setup()
|
||||
defer cleanup()
|
||||
defer lock.Unlock()
|
||||
|
||||
cname := &database.DNSRecord{
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "nonexistant.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
database.SaveDNSRecord(testDb, cname)
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn(cname.Name)
|
||||
client := &dns.Client{}
|
||||
message := &dns.Msg{}
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, *addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 1 {
|
||||
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if !in.Authoritative {
|
||||
t.Fatalf("expected authoritative response")
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Name != cname.Name {
|
||||
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
||||
}
|
||||
|
||||
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
||||
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
||||
}
|
||||
|
||||
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
|
||||
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
||||
}
|
||||
|
||||
if in.Rcode == dns.RcodeNameError {
|
||||
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
|
||||
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
|
||||
|
||||
testDb, _, addr, lock, cleanup := setup()
|
||||
defer cleanup()
|
||||
defer lock.Unlock()
|
||||
|
||||
cname := &database.DNSRecord{
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
Name: "cname.internal.example.com.",
|
||||
Type: "CNAME",
|
||||
Content: "cname.internal.example.com.",
|
||||
TTL: 300,
|
||||
Internal: true,
|
||||
}
|
||||
database.SaveDNSRecord(testDb, cname)
|
||||
|
||||
qtype := dns.TypeA
|
||||
domain := dns.Fqdn(cname.Name)
|
||||
client := &dns.Client{}
|
||||
message := &dns.Msg{}
|
||||
message.SetQuestion(domain, qtype)
|
||||
|
||||
in, _, err := client.Exchange(message, *addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(in.Answer) > 0 {
|
||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||
}
|
||||
|
||||
if in.Rcode != dns.RcodeServerFailure {
|
||||
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
|
||||
}
|
||||
}
|
4
main.go
4
main.go
|
@ -6,7 +6,7 @@ import (
|
|||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
@ -52,7 +52,7 @@ func main() {
|
|||
}
|
||||
|
||||
if argv.Dns {
|
||||
server := dns.MakeServer(argv, dbConn)
|
||||
server := hcdns.MakeServer(argv, dbConn)
|
||||
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
||||
go func() {
|
||||
err = server.ListenAndServe()
|
||||
|
|
|
@ -15,6 +15,22 @@
|
|||
padding: 0;
|
||||
color: var(--text-color);
|
||||
font-family: "ComicSans", sans-serif;
|
||||
|
||||
cursor: url("/static/img/cursor-1.png"), auto;
|
||||
-webkit-animation: cursor 400ms infinite;
|
||||
animation: cursor 400ms infinite;
|
||||
}
|
||||
|
||||
@-webkit-keyframes cursor {
|
||||
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
||||
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||
}
|
||||
|
||||
@keyframes cursor {
|
||||
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
||||
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||
}
|
||||
|
||||
body {
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 570 B |
Binary file not shown.
After Width: | Height: | Size: 563 B |
Loading…
Reference in New Issue