Compare commits
No commits in common. "83cc6267fd5ce2f61200314424c5f400f65ff2ba" and "569d2788ebfb90774faf361f62bfe7968e091465" have entirely different histories.
83cc6267fd
...
569d2788eb
|
@ -2,4 +2,3 @@
|
||||||
hatecomputers.club
|
hatecomputers.club
|
||||||
Dockerfile
|
Dockerfile
|
||||||
*.db
|
*.db
|
||||||
.drone.yml
|
|
||||||
|
|
29
.drone.yml
29
.drone.yml
|
@ -1,30 +1,9 @@
|
||||||
---
|
---
|
||||||
kind: pipeline
|
kind: pipeline
|
||||||
type: docker
|
type: docker
|
||||||
name: build
|
name: build, publish docker image, deploy
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: run tests
|
|
||||||
image: golang
|
|
||||||
commands:
|
|
||||||
- go build
|
|
||||||
- go test -p 1 -v ./...
|
|
||||||
|
|
||||||
trigger:
|
|
||||||
event:
|
|
||||||
- pull_request
|
|
||||||
|
|
||||||
---
|
|
||||||
kind: pipeline
|
|
||||||
type: docker
|
|
||||||
name: deploy
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: run tests
|
|
||||||
image: golang
|
|
||||||
commands:
|
|
||||||
- go build
|
|
||||||
- go test -p 1 -v ./...
|
|
||||||
- name: docker
|
- name: docker
|
||||||
image: plugins/docker
|
image: plugins/docker
|
||||||
settings:
|
settings:
|
||||||
|
@ -34,6 +13,9 @@ steps:
|
||||||
from_secret: gitea_packpub_password
|
from_secret: gitea_packpub_password
|
||||||
registry: git.hatecomputers.club
|
registry: git.hatecomputers.club
|
||||||
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
|
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
|
||||||
|
tags:
|
||||||
|
- latest
|
||||||
|
- main
|
||||||
- name: ssh
|
- name: ssh
|
||||||
image: appleboy/drone-ssh
|
image: appleboy/drone-ssh
|
||||||
settings:
|
settings:
|
||||||
|
@ -45,9 +27,6 @@ steps:
|
||||||
command_timeout: 2m
|
command_timeout: 2m
|
||||||
script:
|
script:
|
||||||
- systemctl restart docker-compose@hatecomputers-club
|
- systemctl restart docker-compose@hatecomputers-club
|
||||||
|
|
||||||
trigger:
|
trigger:
|
||||||
branch:
|
branch:
|
||||||
- main
|
- main
|
||||||
event:
|
|
||||||
- push
|
|
||||||
|
|
|
@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
|
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
|
||||||
|
|
|
@ -14,20 +14,15 @@ type CloudflareDNSResponse struct {
|
||||||
Result database.DNSRecord `json:"result"`
|
Result database.DNSRecord `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloudflareExternalDNSAdapter struct {
|
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) {
|
||||||
ZoneId string
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
|
||||||
APIToken string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||||
payload := strings.NewReader(reqBody)
|
payload := strings.NewReader(reqBody)
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST", url, payload)
|
req, _ := http.NewRequest("POST", url, payload)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
req.Header.Add("Authorization", "Bearer "+apiToken)
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
@ -53,12 +48,12 @@ func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DN
|
||||||
return result.ID, nil
|
return result.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
func DeleteDNSRecord(zoneId string, apiToken string, id string) error {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id)
|
||||||
|
|
||||||
req, _ := http.NewRequest("DELETE", url, nil)
|
req, _ := http.NewRequest("DELETE", url, nil)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
req.Header.Add("Authorization", "Bearer "+apiToken)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
package external_dns
|
|
||||||
|
|
||||||
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
|
|
||||||
type ExternalDNSAdapter interface {
|
|
||||||
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
|
||||||
DeleteDNSRecord(id string) error
|
|
||||||
}
|
|
|
@ -1,33 +1,32 @@
|
||||||
package keys
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_USER_API_KEYS = 5
|
const MAX_USER_API_KEYS = 5
|
||||||
|
|
||||||
func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
(*context.TemplateData)["APIKeys"] = typesKeys
|
(*context.TemplateData)["APIKeys"] = apiKeys
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
formErrors := types.FormError{
|
formErrors := FormError{
|
||||||
Errors: []string{},
|
Errors: []string{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +38,7 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
if numKeys >= MAX_USER_API_KEYS {
|
if numKeys >= MAX_USER_API_KEYS {
|
||||||
formErrors.Errors = append(formErrors.Errors, "max types keys reached")
|
formErrors.Errors = append(formErrors.Errors, "max api keys reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(formErrors.Errors) > 0 {
|
if len(formErrors.Errors) > 0 {
|
||||||
|
@ -60,28 +59,29 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
apiKey := req.FormValue("key")
|
key := req.FormValue("key")
|
||||||
|
|
||||||
key, err := database.GetAPIKey(context.DBConn, apiKey)
|
apiKey, err := database.GetAPIKey(context.DBConn, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
if (key == nil) || (key.UserID != context.User.ID) {
|
if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = database.DeleteAPIKey(context.DBConn, apiKey)
|
err = database.DeleteAPIKey(context.DBConn, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/keys", http.StatusFound)
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package auth
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
@ -12,14 +12,13 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
verifier := utils.RandomId() + utils.RandomId()
|
verifier := utils.RandomId() + utils.RandomId()
|
||||||
|
|
||||||
sha2 := sha256.New()
|
sha2 := sha256.New()
|
||||||
|
@ -35,7 +34,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 200,
|
MaxAge: 60,
|
||||||
})
|
})
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "state",
|
Name: "state",
|
||||||
|
@ -43,7 +42,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 200,
|
MaxAge: 60,
|
||||||
})
|
})
|
||||||
|
|
||||||
http.Redirect(resp, req, url, http.StatusFound)
|
http.Redirect(resp, req, url, http.StatusFound)
|
||||||
|
@ -51,8 +50,8 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
state := req.URL.Query().Get("state")
|
state := req.URL.Query().Get("state")
|
||||||
code := req.URL.Query().Get("code")
|
code := req.URL.Query().Get("code")
|
||||||
|
|
||||||
|
@ -74,6 +73,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
reqContext := req.Context()
|
reqContext := req.Context()
|
||||||
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
@ -101,16 +101,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
})
|
})
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "verifier",
|
|
||||||
Value: "",
|
|
||||||
MaxAge: 0,
|
|
||||||
})
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "state",
|
|
||||||
Value: "",
|
|
||||||
MaxAge: 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
redirect := "/"
|
redirect := "/"
|
||||||
redirectCookie, err := req.Cookie("redirect")
|
redirectCookie, err := req.Cookie("redirect")
|
||||||
|
@ -119,7 +109,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "redirect",
|
Name: "redirect",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
Value: "",
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,127 +117,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
authHeader := req.Header.Get("Authorization")
|
|
||||||
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
|
||||||
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err == nil && sessionCookie.Value != "" {
|
|
||||||
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
if userErr != nil || user == nil {
|
|
||||||
log.Println(userErr, user)
|
|
||||||
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "session",
|
|
||||||
Value: "",
|
|
||||||
MaxAge: 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
context.User = nil
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
context.User = user
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "redirect",
|
|
||||||
Value: req.URL.Path,
|
|
||||||
Path: "/",
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
http.Redirect(resp, req, "/login", http.StatusFound)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err != nil {
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
|
||||||
if err != nil {
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err == nil && sessionCookie.Value != "" {
|
|
||||||
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "session",
|
|
||||||
MaxAge: 0,
|
|
||||||
Value: "",
|
|
||||||
})
|
|
||||||
http.Redirect(resp, req, "/", http.StatusFound)
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
|
||||||
userResponse, err := client.Get(uri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userStruct, err := createUserFromOauthResponse(userResponse)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
|
||||||
user := &database.User{}
|
|
||||||
err := json.NewDecoder(response.Body).Decode(user)
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.Username = strings.ToLower(user.Username)
|
|
||||||
user.Username = strings.Split(user.Username, "@")[0]
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
|
||||||
cookie, err := req.Cookie(stateCookieName)
|
|
||||||
if err != nil || cookie.Value != expectedState {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
||||||
if bearerToken == "" {
|
if bearerToken == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -259,15 +127,15 @@ func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User,
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := database.GetAPIKey(dbConn, parts[1])
|
apiKey, err := database.GetAPIKey(dbConn, parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if key == nil {
|
if apiKey == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := database.GetUser(dbConn, key.UserID)
|
user, err := database.GetUser(dbConn, apiKey.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -294,3 +162,124 @@ func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
authHeader := req.Header.Get("Authorization")
|
||||||
|
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
||||||
|
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err == nil && sessionCookie.Value != "" {
|
||||||
|
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userErr != nil || user == nil {
|
||||||
|
log.Println(userErr, user)
|
||||||
|
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
MaxAge: 0, // reset session cookie in case
|
||||||
|
})
|
||||||
|
|
||||||
|
context.User = nil
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
context.User = user
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "redirect",
|
||||||
|
Value: req.URL.Path,
|
||||||
|
Path: "/",
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/login", http.StatusFound)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err == nil && sessionCookie.Value != "" {
|
||||||
|
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/", http.StatusFound)
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
||||||
|
userResponse, err := client.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userStruct, err := createUserFromResponse(userResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUserFromResponse(response *http.Response) (*database.User, error) {
|
||||||
|
defer response.Body.Close()
|
||||||
|
user := &database.User{
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
err := json.NewDecoder(response.Body).Decode(user)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Username = strings.ToLower(user.Username)
|
||||||
|
user.Username = strings.Split(user.Username, "@")[0]
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
||||||
|
cookie, err := req.Cookie(stateCookieName)
|
||||||
|
if err != nil || cookie.Value != expectedState {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
|
@ -1,307 +0,0 @@
|
||||||
package auth_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
)
|
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FakedOauthServer() *httptest.Server {
|
|
||||||
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/auth" {
|
|
||||||
code := utils.RandomId()
|
|
||||||
|
|
||||||
state := r.URL.Query().Get("state")
|
|
||||||
redirectPath := r.URL.Query().Get("redirect_uri")
|
|
||||||
redirectPath += "?code=" + code + "&state=" + state
|
|
||||||
|
|
||||||
http.Redirect(w, r, redirectPath, http.StatusFound)
|
|
||||||
}
|
|
||||||
if r.URL.Path == "/token" {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
|
|
||||||
}
|
|
||||||
if r.URL.Path == "/user" {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
|
|
||||||
return oauthServer
|
|
||||||
}
|
|
||||||
|
|
||||||
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
resp.Write([]byte(context.User.Username))
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/protected-path" {
|
|
||||||
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.URL.Path == "/login" {
|
|
||||||
log.Println("login")
|
|
||||||
auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.URL.Path == "/callback" {
|
|
||||||
log.Println("callback")
|
|
||||||
auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.URL.Path == "/me" {
|
|
||||||
auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.URL.Path == "/logout" {
|
|
||||||
auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
return testServer
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
|
|
||||||
randomDb := utils.RandomId()
|
|
||||||
|
|
||||||
testDb := database.MakeConn(&randomDb)
|
|
||||||
database.Migrate(testDb)
|
|
||||||
|
|
||||||
context := &types.RequestContext{
|
|
||||||
DBConn: testDb,
|
|
||||||
Args: &args.Arguments{},
|
|
||||||
TemplateData: &(map[string]interface{}{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
oauthServer := FakedOauthServer()
|
|
||||||
testServer := MockUserEndpointServer(context)
|
|
||||||
|
|
||||||
return testDb, context, oauthServer, testServer, func() {
|
|
||||||
oauthServer.Close()
|
|
||||||
testServer.Close()
|
|
||||||
|
|
||||||
testDb.Close()
|
|
||||||
os.Remove(randomDb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
|
||||||
return &oauth2.Config{
|
|
||||||
ClientID: "test",
|
|
||||||
ClientSecret: "test",
|
|
||||||
Scopes: []string{"test"},
|
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
AuthURL: oauthServerURL + "/auth",
|
|
||||||
TokenURL: oauthServerURL + "/token",
|
|
||||||
},
|
|
||||||
RedirectURL: testServerURL + "/callback",
|
|
||||||
}, oauthServerURL + "/user"
|
|
||||||
}
|
|
||||||
|
|
||||||
func FollowAuthentication(
|
|
||||||
oauthServer *httptest.Server,
|
|
||||||
testServer *httptest.Server,
|
|
||||||
cookies map[string]*http.Cookie,
|
|
||||||
location string,
|
|
||||||
) (map[string]*http.Cookie, string) {
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
resp.Code = 0
|
|
||||||
|
|
||||||
for resp.Code == 0 || resp.Code == http.StatusFound {
|
|
||||||
req := httptest.NewRequest("GET", location, nil)
|
|
||||||
resp = httptest.NewRecorder()
|
|
||||||
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(location, oauthServer.URL) {
|
|
||||||
oauthServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
} else {
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
}
|
|
||||||
for _, cookie := range resp.Result().Cookies() {
|
|
||||||
cookies[cookie.Name] = cookie
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code == http.StatusFound {
|
|
||||||
location = resp.Header().Get("Location")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return cookies, location
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
|
||||||
db, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
user, _ := database.GetUser(db, "test")
|
|
||||||
if user != nil {
|
|
||||||
t.Errorf("expected no user, got user")
|
|
||||||
}
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
|
||||||
|
|
||||||
user, _ = database.GetUser(db, "test")
|
|
||||||
if user == nil {
|
|
||||||
t.Errorf("expected a user to be created, could not find user")
|
|
||||||
}
|
|
||||||
if user.Username != "test" {
|
|
||||||
t.Errorf("expected username to be test, got %s", user.Username)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
|
||||||
_, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/protected-path", nil)
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
location := resp.Header().Get("Location")
|
|
||||||
if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
|
|
||||||
t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
|
||||||
}
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
|
|
||||||
|
|
||||||
if !(strings.HasSuffix(location, "/protected-page")) {
|
|
||||||
t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOauthSetsUniqueSession(t *testing.T) {
|
|
||||||
db, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
|
||||||
|
|
||||||
cookiesAgain := make(map[string]*http.Cookie)
|
|
||||||
cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
|
|
||||||
|
|
||||||
sessionOne := cookies["session"].Value
|
|
||||||
sessionTwo := cookiesAgain["session"].Value
|
|
||||||
if sessionOne == sessionTwo {
|
|
||||||
t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
|
|
||||||
}
|
|
||||||
|
|
||||||
session, _ := database.GetSession(db, sessionOne)
|
|
||||||
if session.UserID != "test" {
|
|
||||||
t.Errorf("expected session to be associated with user test, got %s", session.UserID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogoutClearsSession(t *testing.T) {
|
|
||||||
db, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/logout", nil)
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
for _, cookie := range resp.Result().Cookies() {
|
|
||||||
cookies[cookie.Name] = cookie
|
|
||||||
}
|
|
||||||
|
|
||||||
req = httptest.NewRequest("GET", "/me", nil)
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
resp = httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
|
||||||
t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
|
||||||
}
|
|
||||||
|
|
||||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
|
||||||
if session != nil {
|
|
||||||
t.Errorf("expected session to be deleted, got session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRefreshUpdatesExpiration(t *testing.T) {
|
|
||||||
db, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
|
||||||
|
|
||||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/me", nil)
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
|
|
||||||
updatedSession, _ := database.GetSession(db, cookies["session"].Value)
|
|
||||||
|
|
||||||
if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
|
|
||||||
t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
|
||||||
db, context, oauthServer, testServer, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
|
||||||
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
|
||||||
|
|
||||||
session, _ := database.GetSession(db, cookies["session"].Value)
|
|
||||||
session.ExpireAt = time.Now().Add(-time.Hour)
|
|
||||||
database.SaveSession(db, session)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/me", nil)
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
req.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
|
||||||
|
|
||||||
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
|
||||||
t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,179 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MAX_USER_RECORDS = 65
|
||||||
|
|
||||||
|
type FormError struct {
|
||||||
|
Errors []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
|
||||||
|
ownedByUser := (user.ID == record.UserID)
|
||||||
|
if !ownedByUser {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.Internal {
|
||||||
|
userOwnedDomains := []string{
|
||||||
|
fmt.Sprintf("%s", user.Username),
|
||||||
|
fmt.Sprintf("%s.endpoints", user.Username),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range userOwnedDomains {
|
||||||
|
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
||||||
|
if domain == record.Name || isInSubDomain {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
userIsOwnerOfDomain := owner == user.ID
|
||||||
|
return ownedByUser && userIsOwnerOfDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
formErrors := FormError{
|
||||||
|
Errors: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
internal := req.FormValue("internal") == "on"
|
||||||
|
name := req.FormValue("name")
|
||||||
|
if internal && !strings.HasSuffix(name, ".") {
|
||||||
|
name += "."
|
||||||
|
}
|
||||||
|
|
||||||
|
recordType := req.FormValue("type")
|
||||||
|
recordType = strings.ToUpper(recordType)
|
||||||
|
|
||||||
|
recordContent := req.FormValue("content")
|
||||||
|
ttl := req.FormValue("ttl")
|
||||||
|
ttlNum, err := strconv.Atoi(ttl)
|
||||||
|
if err != nil {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
if dnsRecordCount >= MAX_USER_RECORDS {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecord := &database.DNSRecord{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Name: name,
|
||||||
|
Type: recordType,
|
||||||
|
Content: recordContent,
|
||||||
|
TTL: ttlNum,
|
||||||
|
Internal: internal,
|
||||||
|
}
|
||||||
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
if dnsRecord.Internal {
|
||||||
|
dnsRecord.ID = utils.RandomId()
|
||||||
|
} else {
|
||||||
|
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecord.ID = cloudflareRecordId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["FormError"] = &formErrors
|
||||||
|
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||||
|
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
recordId := req.FormValue("id")
|
||||||
|
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.Internal {
|
||||||
|
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/dns", http.StatusFound)
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
174
api/dns/dns.go
174
api/dns/dns.go
|
@ -1,174 +0,0 @@
|
||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
|
||||||
ownedByUser := (user.ID == record.UserID)
|
|
||||||
if !ownedByUser {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !record.Internal {
|
|
||||||
for _, format := range ownedInternalDomainFormats {
|
|
||||||
domain := fmt.Sprintf(format, user.Username)
|
|
||||||
|
|
||||||
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
|
||||||
if domain == record.Name || isInSubDomain {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
userIsOwnerOfDomain := owner == user.ID
|
|
||||||
return ownedByUser && userIsOwnerOfDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
formErrors := types.FormError{
|
|
||||||
Errors: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
|
||||||
name := req.FormValue("name")
|
|
||||||
if internal && !strings.HasSuffix(name, ".") {
|
|
||||||
name += "."
|
|
||||||
}
|
|
||||||
|
|
||||||
recordType := req.FormValue("type")
|
|
||||||
recordType = strings.ToUpper(recordType)
|
|
||||||
|
|
||||||
recordContent := req.FormValue("content")
|
|
||||||
ttl := req.FormValue("ttl")
|
|
||||||
ttlNum, err := strconv.Atoi(ttl)
|
|
||||||
if err != nil {
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
if dnsRecordCount >= maxUserRecords {
|
|
||||||
resp.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecord := &database.DNSRecord{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Name: name,
|
|
||||||
Type: recordType,
|
|
||||||
Content: recordContent,
|
|
||||||
TTL: ttlNum,
|
|
||||||
Internal: internal,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
if dnsRecord.Internal {
|
|
||||||
dnsRecord.ID = utils.RandomId()
|
|
||||||
} else {
|
|
||||||
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["FormError"] = &formErrors
|
|
||||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
recordId := req.FormValue("id")
|
|
||||||
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(record.UserID == context.User.ID) {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !record.Internal {
|
|
||||||
err = dnsAdapter.DeleteDNSRecord(recordId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
|
||||||
if err != nil {
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,442 +0,0 @@
|
||||||
package dns_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
const MAX_USER_RECORDS = 10
|
|
||||||
|
|
||||||
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
|
||||||
randomDb := utils.RandomId()
|
|
||||||
|
|
||||||
testDb := database.MakeConn(&randomDb)
|
|
||||||
database.Migrate(testDb)
|
|
||||||
|
|
||||||
user := &database.User{
|
|
||||||
ID: "test",
|
|
||||||
Username: "test",
|
|
||||||
Mail: "test@test.com",
|
|
||||||
DisplayName: "test",
|
|
||||||
}
|
|
||||||
database.FindOrSaveUser(testDb, user)
|
|
||||||
|
|
||||||
context := &types.RequestContext{
|
|
||||||
DBConn: testDb,
|
|
||||||
Args: &args.Arguments{},
|
|
||||||
TemplateData: &(map[string]interface{}{}),
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
|
|
||||||
return testDb, context, func() {
|
|
||||||
testDb.Close()
|
|
||||||
os.Remove(randomDb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type SignallingExternalDnsAdapter struct {
|
|
||||||
AddChannel chan *database.DNSRecord
|
|
||||||
RmChannel chan string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
|
||||||
id := utils.RandomId()
|
|
||||||
go func() { adapter.AddChannel <- record }()
|
|
||||||
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
|
|
||||||
go func() { adapter.RmChannel <- id }()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
domainOwner := &database.DomainOwner{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Domain: "test.domain.",
|
|
||||||
}
|
|
||||||
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
|
||||||
|
|
||||||
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(records) > 0 {
|
|
||||||
t.Errorf("expected no records, got records")
|
|
||||||
}
|
|
||||||
|
|
||||||
addChannel := make(chan *database.DNSRecord)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
AddChannel: addChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
validOwner.Form = map[string][]string{
|
|
||||||
"internal": {"on"},
|
|
||||||
"name": {"new.test.domain."},
|
|
||||||
"type": {"CNAME"},
|
|
||||||
"ttl": {"43000"},
|
|
||||||
"content": {"test.domain."},
|
|
||||||
}
|
|
||||||
|
|
||||||
validOwnerRecorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
|
||||||
if validOwnerRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
validOwnerNonInternalRecorder := httptest.NewRecorder()
|
|
||||||
validOwner.Form["internal"] = []string{"off"}
|
|
||||||
testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
|
|
||||||
if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
invalidOwnerRecorder := httptest.NewRecorder()
|
|
||||||
invalidOwner := validOwner
|
|
||||||
invalidOwner.Form["internal"] = []string{"on"}
|
|
||||||
invalidOwner.Form["name"] = []string{"new.invalid.domain."}
|
|
||||||
testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
|
|
||||||
if invalidOwnerRecorder.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
addChannel := make(chan *database.DNSRecord)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
AddChannel: addChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
responseRecorder := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
|
|
||||||
for _, format := range fmts {
|
|
||||||
name := fmt.Sprintf(format, context.User.Username)
|
|
||||||
|
|
||||||
req.Form = map[string][]string{
|
|
||||||
"internal": {"off"},
|
|
||||||
"name": {name},
|
|
||||||
"type": {"CNAME"},
|
|
||||||
"ttl": {"43000"},
|
|
||||||
"content": {"test.domain."},
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
|
||||||
if responseRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
|
|
||||||
if len(namedRecords) == 0 {
|
|
||||||
t.Errorf("saved record not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThatExternalDnsSaves(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
addChannel := make(chan *database.DNSRecord)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
AddChannel: addChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
responseRecorder := httptest.NewRecorder()
|
|
||||||
externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
|
|
||||||
name := "test." + context.User.Username
|
|
||||||
externalRequest.Form = map[string][]string{
|
|
||||||
"internal": {"off"},
|
|
||||||
"name": {name},
|
|
||||||
"type": {"CNAME"},
|
|
||||||
"ttl": {"43000"},
|
|
||||||
"content": {"test.domain."},
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
|
||||||
if responseRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case res := <-addChannel:
|
|
||||||
if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
|
|
||||||
t.Errorf("received the wrong external record")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Errorf("timed out in waiting for external addition")
|
|
||||||
}
|
|
||||||
|
|
||||||
domainOwner := &database.DomainOwner{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Domain: "test.domain.",
|
|
||||||
}
|
|
||||||
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
|
||||||
internalRequest := externalRequest
|
|
||||||
internalRequest.Form["internal"] = []string{"on"}
|
|
||||||
internalRequest.Form["name"] = []string{"test.domain."}
|
|
||||||
|
|
||||||
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
|
||||||
if responseRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case _ = <-addChannel:
|
|
||||||
t.Errorf("expected nothing in the add channel")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
rmChannel := make(chan string)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
RmChannel: rmChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
|
|
||||||
_, err := database.FindOrSaveUser(db, nonOwnerUser)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
record := &database.DNSRecord{
|
|
||||||
ID: "1",
|
|
||||||
Internal: false,
|
|
||||||
Name: "test",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "asdf",
|
|
||||||
TTL: 1000,
|
|
||||||
UserID: nonOwnerUser.ID,
|
|
||||||
}
|
|
||||||
_, err = database.SaveDNSRecord(db, record)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonOwnerRecorder := httptest.NewRecorder()
|
|
||||||
nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
nonOwner.Form = map[string][]string{
|
|
||||||
"id": {record.ID},
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
|
|
||||||
if nonOwnerRecorder.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
record.UserID = context.User.ID
|
|
||||||
record.ID = "2"
|
|
||||||
database.SaveDNSRecord(db, record)
|
|
||||||
|
|
||||||
owner := nonOwner
|
|
||||||
owner.Form["id"] = []string{"2"}
|
|
||||||
ownerRecorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
|
|
||||||
if ownerRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", ownerRecorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThatExternalDnsRemoves(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
record := &database.DNSRecord{
|
|
||||||
ID: "1",
|
|
||||||
Internal: false,
|
|
||||||
Name: "test",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "asdf",
|
|
||||||
TTL: 1000,
|
|
||||||
UserID: context.User.ID,
|
|
||||||
}
|
|
||||||
database.SaveDNSRecord(db, record)
|
|
||||||
|
|
||||||
rmChannel := make(chan string)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
RmChannel: rmChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
externalResponseRecorder := httptest.NewRecorder()
|
|
||||||
deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
|
|
||||||
deleteRequest.Form = map[string][]string{
|
|
||||||
"id": {record.ID},
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
|
|
||||||
if externalResponseRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case res := <-rmChannel:
|
|
||||||
if res != record.ID {
|
|
||||||
t.Errorf("received the wrong external record")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Errorf("timed out in waiting for external addition")
|
|
||||||
}
|
|
||||||
|
|
||||||
record.Internal = true
|
|
||||||
record.Name = "test.domain."
|
|
||||||
database.SaveDNSRecord(db, record)
|
|
||||||
domainOwner := &database.DomainOwner{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Domain: "test.domain.",
|
|
||||||
}
|
|
||||||
database.SaveDomainOwner(db, domainOwner)
|
|
||||||
|
|
||||||
internalResponseRecorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
|
|
||||||
if internalResponseRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case _ = <-rmChannel:
|
|
||||||
t.Errorf("expected nothing in the rmchannel")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecordCountCannotExceed(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
record := &database.DNSRecord{
|
|
||||||
Internal: false,
|
|
||||||
Name: context.User.Username,
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "asdf",
|
|
||||||
TTL: 1000,
|
|
||||||
UserID: context.User.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 1; i <= MAX_USER_RECORDS; i++ {
|
|
||||||
record.ID = strconv.Itoa(i)
|
|
||||||
record.Name = record.ID + "." + record.Name
|
|
||||||
database.SaveDNSRecord(db, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
addChannel := make(chan *database.DNSRecord)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
AddChannel: addChannel,
|
|
||||||
}
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
req.Form = map[string][]string{
|
|
||||||
"internal": {"off"},
|
|
||||||
"name": {record.Name},
|
|
||||||
"type": {record.Type},
|
|
||||||
"ttl": {"43000"},
|
|
||||||
"content": {record.Content},
|
|
||||||
}
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(recorder, req)
|
|
||||||
if recorder.Code != http.StatusTooManyRequests {
|
|
||||||
t.Errorf("expected too many requests code return, got %d", recorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
domainOwner := &database.DomainOwner{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Domain: "test.internal.",
|
|
||||||
}
|
|
||||||
database.SaveDomainOwner(db, domainOwner)
|
|
||||||
|
|
||||||
addChannel := make(chan *database.DNSRecord)
|
|
||||||
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
|
||||||
AddChannel: addChannel,
|
|
||||||
}
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
validOwner.Form = map[string][]string{
|
|
||||||
"internal": {"on"},
|
|
||||||
"name": {"test.internal"},
|
|
||||||
"type": {"CNAME"},
|
|
||||||
"ttl": {"43000"},
|
|
||||||
"content": {"asdf.internal"},
|
|
||||||
}
|
|
||||||
|
|
||||||
validOwnerRecorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
|
||||||
if validOwnerRecorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
|
|
||||||
recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
|
|
||||||
|
|
||||||
if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
|
|
||||||
t.Errorf("expected dot appended")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,141 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HcaptchaArgs struct {
|
||||||
|
SiteKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
||||||
|
errors := []string{}
|
||||||
|
|
||||||
|
if entry.Name == "" {
|
||||||
|
errors = append(errors, "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.Message == "" {
|
||||||
|
errors = append(errors, "message is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
messageLength := len(entry.Message)
|
||||||
|
if messageLength > 500 {
|
||||||
|
errors = append(errors, "message cannot be longer than 500 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
newLines := strings.Count(entry.Message, "\n")
|
||||||
|
if newLines > 10 {
|
||||||
|
errors = append(errors, "message cannot contain more than 10 new lines")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
name := req.FormValue("name")
|
||||||
|
message := req.FormValue("message")
|
||||||
|
hCaptchaResponse := req.FormValue("h-captcha-response")
|
||||||
|
|
||||||
|
formErrors := FormError{
|
||||||
|
Errors: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if hCaptchaResponse == "" {
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &database.GuestbookEntry{
|
||||||
|
ID: utils.RandomId(),
|
||||||
|
Name: name,
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
||||||
|
|
||||||
|
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
|
||||||
|
}
|
||||||
|
if len(formErrors.Errors) > 0 {
|
||||||
|
(*context.TemplateData)["FormError"] = formErrors
|
||||||
|
(*context.TemplateData)["EntryForm"] = entry
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
entries, err := database.GetGuestbookEntries(context.DBConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["GuestbookEntries"] = entries
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
|
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
||||||
|
SiteKey: context.Args.HcaptchaSiteKey,
|
||||||
|
}
|
||||||
|
log.Println(context.Args.HcaptchaSiteKey)
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyHCaptcha(secret, response string) error {
|
||||||
|
verifyURL := "https://hcaptcha.com/siteverify"
|
||||||
|
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", verifyURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse := struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}{}
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jsonResponse.Success {
|
||||||
|
return fmt.Errorf("hcaptcha verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,85 +0,0 @@
|
||||||
package guestbook
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
|
||||||
errors := []string{}
|
|
||||||
|
|
||||||
if entry.Name == "" {
|
|
||||||
errors = append(errors, "name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
if entry.Message == "" {
|
|
||||||
errors = append(errors, "message is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
messageLength := len(entry.Message)
|
|
||||||
if messageLength > 500 {
|
|
||||||
errors = append(errors, "message cannot be longer than 500 characters")
|
|
||||||
}
|
|
||||||
|
|
||||||
newLines := strings.Count(entry.Message, "\n")
|
|
||||||
if newLines > 10 {
|
|
||||||
errors = append(errors, "message cannot contain more than 10 new lines")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
name := req.FormValue("name")
|
|
||||||
message := req.FormValue("message")
|
|
||||||
|
|
||||||
formErrors := types.FormError{
|
|
||||||
Errors: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &database.GuestbookEntry{
|
|
||||||
ID: utils.RandomId(),
|
|
||||||
Name: name,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) > 0 {
|
|
||||||
(*context.TemplateData)["FormError"] = formErrors
|
|
||||||
(*context.TemplateData)["EntryForm"] = entry
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
entries, err := database.GetGuestbookEntries(context.DBConn)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["GuestbookEntries"] = entries
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,136 +0,0 @@
|
||||||
package guestbook_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
|
||||||
randomDb := utils.RandomId()
|
|
||||||
|
|
||||||
testDb := database.MakeConn(&randomDb)
|
|
||||||
database.Migrate(testDb)
|
|
||||||
|
|
||||||
context := &types.RequestContext{
|
|
||||||
DBConn: testDb,
|
|
||||||
Args: &args.Arguments{},
|
|
||||||
TemplateData: &(map[string]interface{}{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
return testDb, context, func() {
|
|
||||||
testDb.Close()
|
|
||||||
os.Remove(randomDb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidGuestbookPutsInDatabase(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
entries, err := database.GetGuestbookEntries(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(entries) > 0 {
|
|
||||||
t.Errorf("expected no entries, got entries")
|
|
||||||
}
|
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", ts.URL, nil)
|
|
||||||
req.Form = map[string][]string{
|
|
||||||
"name": {"test"},
|
|
||||||
"message": {"test"},
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
ts.Config.Handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected status code 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
entries, err = database.GetGuestbookEntries(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(entries) != 1 {
|
|
||||||
t.Errorf("expected 1 entry, got %d", len(entries))
|
|
||||||
}
|
|
||||||
|
|
||||||
if entries[0].Name != req.FormValue("name") {
|
|
||||||
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
|
|
||||||
db, context, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
entries, err := database.GetGuestbookEntries(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(entries) > 0 {
|
|
||||||
t.Errorf("expected no entries, got entries")
|
|
||||||
}
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
|
|
||||||
invalidRequests := []struct {
|
|
||||||
name string
|
|
||||||
message string
|
|
||||||
}{
|
|
||||||
{"", "test"},
|
|
||||||
{"test", ""},
|
|
||||||
{"", ""},
|
|
||||||
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, form := range invalidRequests {
|
|
||||||
req := httptest.NewRequest("POST", testServer.URL, nil)
|
|
||||||
req.Form = map[string][]string{
|
|
||||||
"name": {form.name},
|
|
||||||
"message": {form.message},
|
|
||||||
}
|
|
||||||
|
|
||||||
responseRecorder := httptest.NewRecorder()
|
|
||||||
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
|
||||||
|
|
||||||
if responseRecorder.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
entries, err = database.GetGuestbookEntries(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(entries) != 0 {
|
|
||||||
t.Errorf("expected 0 entries, got %d", len(entries))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,75 +0,0 @@
|
||||||
package hcaptcha
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
type HcaptchaArgs struct {
|
|
||||||
SiteKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyCaptcha(secret, response string) error {
|
|
||||||
verifyURL := "https://hcaptcha.com/siteverify"
|
|
||||||
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", verifyURL, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse := struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
}{}
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !jsonResponse.Success {
|
|
||||||
return fmt.Errorf("hcaptcha verification failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
|
||||||
SiteKey: context.Args.HcaptchaSiteKey,
|
|
||||||
}
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
||||||
hCaptchaResponse := req.FormValue("h-captcha-response")
|
|
||||||
secretKey := context.Args.HcaptchaSecret
|
|
||||||
|
|
||||||
err := verifyCaptcha(secretKey, hCaptchaResponse)
|
|
||||||
if err != nil {
|
|
||||||
(*context.TemplateData)["FormError"] = types.FormError{
|
|
||||||
Errors: []string{"hCaptcha verification failed"},
|
|
||||||
}
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
88
api/serve.go
88
api/serve.go
|
@ -7,20 +7,27 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
type RequestContext struct {
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
DBConn *sql.DB
|
||||||
|
Args *args.Arguments
|
||||||
|
|
||||||
|
Id string
|
||||||
|
Start time.Time
|
||||||
|
|
||||||
|
TemplateData *map[string]interface{}
|
||||||
|
User *database.User
|
||||||
|
}
|
||||||
|
|
||||||
|
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
||||||
|
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
||||||
|
|
||||||
|
func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
|
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||||
context.Start = time.Now()
|
context.Start = time.Now()
|
||||||
context.Id = utils.RandomId()
|
context.Id = utils.RandomId()
|
||||||
|
|
||||||
|
@ -29,8 +36,8 @@ func LogRequestContinuation(context *types.RequestContext, req *http.Request, re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||||
end := time.Now()
|
end := time.Now()
|
||||||
|
|
||||||
log.Println(context.Id, "took", end.Sub(context.Start))
|
log.Println(context.Id, "took", end.Sub(context.Start))
|
||||||
|
@ -39,22 +46,22 @@ func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Reque
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||||
resp.WriteHeader(200)
|
resp.WriteHeader(200)
|
||||||
resp.Write([]byte("healthy"))
|
resp.Write([]byte("healthy"))
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(_success Continuation, failure Continuation) ContinuationChain {
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, _failure Continuation) ContinuationChain {
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,90 +80,89 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
||||||
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
||||||
|
|
||||||
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
|
makeRequestContext := func() *RequestContext {
|
||||||
APIToken: argv.CloudflareToken,
|
return &RequestContext{
|
||||||
ZoneId: argv.CloudflareZone,
|
|
||||||
}
|
|
||||||
|
|
||||||
makeRequestContext := func() *types.RequestContext {
|
|
||||||
return &types.RequestContext{
|
|
||||||
DBConn: dbConn,
|
DBConn: dbConn,
|
||||||
Args: argv,
|
Args: argv,
|
||||||
|
|
||||||
TemplateData: &map[string]interface{}{},
|
TemplateData: &map[string]interface{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestContext := makeRequestContext()
|
||||||
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
const MAX_USER_RECORDS = 100
|
|
||||||
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
|
||||||
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
name := r.PathValue("name")
|
name := r.PathValue("name")
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package template
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -7,11 +7,9 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
||||||
templatePath := context.Args.TemplatePath
|
templatePath := context.Args.TemplatePath
|
||||||
basePath := templatePath + "/base_empty.html"
|
basePath := templatePath + "/base_empty.html"
|
||||||
if showBaseHtml {
|
if showBaseHtml {
|
||||||
|
@ -43,9 +41,9 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase
|
||||||
return buffer, nil
|
return buffer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TemplateContinuation(path string, showBase bool) types.Continuation {
|
func TemplateContinuation(path string, showBase bool) Continuation {
|
||||||
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||||
html, err := renderTemplate(context, path, true)
|
html, err := renderTemplate(context, path, true)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
resp.WriteHeader(404)
|
resp.WriteHeader(404)
|
||||||
|
@ -68,6 +66,7 @@ func TemplateContinuation(path string, showBase bool) types.Continuation {
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.WriteHeader(200)
|
||||||
resp.Header().Set("Content-Type", "text/html")
|
resp.Header().Set("Content-Type", "text/html")
|
||||||
resp.Write(html.Bytes())
|
resp.Write(html.Bytes())
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
|
@ -1,28 +0,0 @@
|
||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RequestContext struct {
|
|
||||||
DBConn *sql.DB
|
|
||||||
Args *args.Arguments
|
|
||||||
|
|
||||||
Id string
|
|
||||||
Start time.Time
|
|
||||||
|
|
||||||
TemplateData *map[string]interface{}
|
|
||||||
User *database.User
|
|
||||||
}
|
|
||||||
|
|
||||||
type FormError struct {
|
|
||||||
Errors []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
|
||||||
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
|
|
@ -23,6 +23,7 @@ type Arguments struct {
|
||||||
OauthUserInfoURI string
|
OauthUserInfoURI string
|
||||||
|
|
||||||
Dns bool
|
Dns bool
|
||||||
|
DnsRecursion []string
|
||||||
DnsPort int
|
DnsPort int
|
||||||
|
|
||||||
CloudflareToken string
|
CloudflareToken string
|
||||||
|
@ -44,6 +45,7 @@ func GetArgs() (*Arguments, error) {
|
||||||
server := flag.Bool("server", false, "Run the server")
|
server := flag.Bool("server", false, "Run the server")
|
||||||
|
|
||||||
dns := flag.Bool("dns", false, "Run DNS resolver")
|
dns := flag.Bool("dns", false, "Run DNS resolver")
|
||||||
|
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
|
||||||
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
|
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
@ -102,6 +104,7 @@ func GetArgs() (*Arguments, error) {
|
||||||
Migrate: *migrate,
|
Migrate: *migrate,
|
||||||
Scheduler: *scheduler,
|
Scheduler: *scheduler,
|
||||||
Dns: *dns,
|
Dns: *dns,
|
||||||
|
DnsRecursion: strings.Split(*dnsRecursion, ","),
|
||||||
DnsPort: *dnsPort,
|
DnsPort: *dnsPort,
|
||||||
|
|
||||||
OauthConfig: oauthConfig,
|
OauthConfig: oauthConfig,
|
||||||
|
|
|
@ -9,12 +9,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DomainOwner struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
Domain string `json:"domain"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DNSRecord struct {
|
type DNSRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
@ -63,10 +57,7 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
|
||||||
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
||||||
log.Println("saving dns record", record.ID)
|
log.Println("saving dns record", record.ID)
|
||||||
|
|
||||||
if (record.CreatedAt == time.Time{}) {
|
|
||||||
record.CreatedAt = time.Now()
|
record.CreatedAt = time.Now()
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -146,15 +137,3 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
|
|
||||||
log.Println("saving domain owner", domainOwner.Domain)
|
|
||||||
|
|
||||||
domainOwner.CreatedAt = time.Now()
|
|
||||||
_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return domainOwner, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -111,18 +111,6 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
|
|
||||||
log.Println("saving session", session.ID)
|
|
||||||
|
|
||||||
_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
||||||
newExpireAt := time.Now().Add(ExpiryDuration)
|
newExpireAt := time.Now().Add(ExpiryDuration)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package hcdns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -9,28 +9,27 @@ import (
|
||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_RECURSION = 15
|
const MAX_RECURSION = 10
|
||||||
|
|
||||||
|
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||||
|
if maxDepth == 0 {
|
||||||
|
return nil, fmt.Errorf("too much recursion")
|
||||||
|
}
|
||||||
|
|
||||||
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
|
||||||
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var answers []dns.RR
|
answers := []dns.RR{}
|
||||||
for _, record := range internalCnames {
|
for _, record := range internalCnames {
|
||||||
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
answers = append(answers, cname)
|
answers = append(answers, cname)
|
||||||
|
|
||||||
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
|
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
answers = append(answers, cnameRecursive...)
|
answers = append(answers, cnameRecursive...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,26 +43,36 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, record := range typeDnsRecords {
|
for _, record := range typeDnsRecords {
|
||||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
|
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
answers = append(answers, answer)
|
answers = append(answers, answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(answers) > 0 {
|
||||||
|
// base case; we found the answer
|
||||||
return answers, nil
|
return answers, nil
|
||||||
}
|
|
||||||
|
|
||||||
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
|
||||||
if maxDepth == 0 {
|
|
||||||
return nil, fmt.Errorf("too much recursion")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
|
message := new(dns.Msg)
|
||||||
if err != nil {
|
message.SetQuestion(dns.Fqdn(domain), qtype)
|
||||||
|
message.RecursionDesired = true
|
||||||
|
|
||||||
|
client := new(dns.Client)
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
in, _, err := client.Exchange(message, dnsResolvers[i])
|
||||||
|
for err != nil {
|
||||||
|
i += 1
|
||||||
|
if i == len(dnsResolvers) {
|
||||||
|
log.Println(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
in, _, err = client.Exchange(message, dnsResolvers[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
answers = append(answers, in.Answer...)
|
||||||
return answers, nil
|
return answers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,26 +87,21 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
msg.Authoritative = true
|
msg.Authoritative = true
|
||||||
|
|
||||||
for _, question := range r.Question {
|
for _, question := range r.Question {
|
||||||
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
|
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
msg.SetRcode(r, dns.RcodeServerFailure)
|
continue
|
||||||
w.WriteMsg(msg)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
msg.Answer = append(msg.Answer, answers...)
|
msg.Answer = append(msg.Answer, answers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(msg.Answer) == 0 {
|
|
||||||
msg.SetRcode(r, dns.RcodeNameError)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(msg.Answer)
|
log.Println(msg.Answer)
|
||||||
w.WriteMsg(msg)
|
w.WriteMsg(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
||||||
handler := &DnsHandler{
|
handler := &DnsHandler{
|
||||||
|
DnsResolvers: argv.DnsRecursion,
|
||||||
DbConn: dbConn,
|
DbConn: dbConn,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
|
@ -1,254 +0,0 @@
|
||||||
package hcdns_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
func randomPort() int {
|
|
||||||
return rand.Intn(3000) + 5192
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
|
|
||||||
randomDb := utils.RandomId()
|
|
||||||
dnsPort := randomPort()
|
|
||||||
|
|
||||||
testDb := database.MakeConn(&randomDb)
|
|
||||||
database.Migrate(testDb)
|
|
||||||
testUser := &database.User{
|
|
||||||
ID: "test",
|
|
||||||
}
|
|
||||||
database.FindOrSaveUser(testDb, testUser)
|
|
||||||
|
|
||||||
waitLock := &sync.Mutex{}
|
|
||||||
server := hcdns.MakeServer(&args.Arguments{
|
|
||||||
DnsPort: dnsPort,
|
|
||||||
}, testDb)
|
|
||||||
server.NotifyStartedFunc = func() {
|
|
||||||
waitLock.Unlock()
|
|
||||||
}
|
|
||||||
waitLock.Lock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
server.ListenAndServe()
|
|
||||||
}()
|
|
||||||
waitLock.Lock()
|
|
||||||
|
|
||||||
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
|
|
||||||
return testDb, server, &address, waitLock, func() {
|
|
||||||
server.Shutdown()
|
|
||||||
|
|
||||||
testDb.Close()
|
|
||||||
os.Remove(randomDb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWhenCNAMEIsResolved(t *testing.T) {
|
|
||||||
t.Log("TestWhenCNAMEIsResolved")
|
|
||||||
|
|
||||||
testDb, _, addr, lock, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
records := []*database.DNSRecord{
|
|
||||||
{
|
|
||||||
ID: "0",
|
|
||||||
UserID: "test",
|
|
||||||
Name: "cname.internal.example.com.",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "next.internal.example.com.",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
}, {
|
|
||||||
ID: "1",
|
|
||||||
UserID: "test",
|
|
||||||
Name: "next.internal.example.com.",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "res.example.com.",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "2",
|
|
||||||
UserID: "test",
|
|
||||||
Name: "res.example.com.",
|
|
||||||
Type: "A",
|
|
||||||
Content: "1.2.3.2",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, record := range records {
|
|
||||||
database.SaveDNSRecord(testDb, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
qtype := dns.TypeA
|
|
||||||
domain := dns.Fqdn("cname.internal.example.com.")
|
|
||||||
client := &dns.Client{}
|
|
||||||
message := &dns.Msg{}
|
|
||||||
message.SetQuestion(domain, qtype)
|
|
||||||
|
|
||||||
in, _, err := client.Exchange(message, *addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in.Answer) != 3 {
|
|
||||||
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, record := range records {
|
|
||||||
if in.Answer[i].Header().Name != record.Name {
|
|
||||||
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
|
|
||||||
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(in.Answer[i].Header().Ttl) != record.TTL {
|
|
||||||
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !in.Authoritative {
|
|
||||||
t.Fatalf("expected authoritative response")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
|
|
||||||
t.Fatalf("expected final record to be the A record with correct IP")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWhenNoRecordNxDomain(t *testing.T) {
|
|
||||||
t.Log("TestWhenNoRecordNxDomain")
|
|
||||||
|
|
||||||
_, _, addr, lock, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
qtype := dns.TypeA
|
|
||||||
domain := dns.Fqdn("nonexistant.example.com.")
|
|
||||||
client := &dns.Client{}
|
|
||||||
message := &dns.Msg{}
|
|
||||||
message.SetQuestion(domain, qtype)
|
|
||||||
|
|
||||||
in, _, err := client.Exchange(message, *addr)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in.Answer) != 0 {
|
|
||||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Rcode != dns.RcodeNameError {
|
|
||||||
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWhenUnresolvingCNAME(t *testing.T) {
|
|
||||||
t.Log("TestWhenUnresolvingCNAME")
|
|
||||||
|
|
||||||
testDb, _, addr, lock, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
cname := &database.DNSRecord{
|
|
||||||
ID: "1",
|
|
||||||
UserID: "test",
|
|
||||||
Name: "cname.internal.example.com.",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "nonexistant.example.com.",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
}
|
|
||||||
database.SaveDNSRecord(testDb, cname)
|
|
||||||
|
|
||||||
qtype := dns.TypeA
|
|
||||||
domain := dns.Fqdn(cname.Name)
|
|
||||||
client := &dns.Client{}
|
|
||||||
message := &dns.Msg{}
|
|
||||||
message.SetQuestion(domain, qtype)
|
|
||||||
|
|
||||||
in, _, err := client.Exchange(message, *addr)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in.Answer) != 1 {
|
|
||||||
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
if !in.Authoritative {
|
|
||||||
t.Fatalf("expected authoritative response")
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[0].Header().Name != cname.Name {
|
|
||||||
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
|
||||||
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
|
|
||||||
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Rcode == dns.RcodeNameError {
|
|
||||||
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
|
|
||||||
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
|
|
||||||
|
|
||||||
testDb, _, addr, lock, cleanup := setup()
|
|
||||||
defer cleanup()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
cname := &database.DNSRecord{
|
|
||||||
ID: "1",
|
|
||||||
UserID: "test",
|
|
||||||
Name: "cname.internal.example.com.",
|
|
||||||
Type: "CNAME",
|
|
||||||
Content: "cname.internal.example.com.",
|
|
||||||
TTL: 300,
|
|
||||||
Internal: true,
|
|
||||||
}
|
|
||||||
database.SaveDNSRecord(testDb, cname)
|
|
||||||
|
|
||||||
qtype := dns.TypeA
|
|
||||||
domain := dns.Fqdn(cname.Name)
|
|
||||||
client := &dns.Client{}
|
|
||||||
message := &dns.Msg{}
|
|
||||||
message.SetQuestion(domain, qtype)
|
|
||||||
|
|
||||||
in, _, err := client.Exchange(message, *addr)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in.Answer) > 0 {
|
|
||||||
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.Rcode != dns.RcodeServerFailure {
|
|
||||||
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
|
|
||||||
}
|
|
||||||
}
|
|
4
main.go
4
main.go
|
@ -6,7 +6,7 @@ import (
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
|
@ -52,7 +52,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if argv.Dns {
|
if argv.Dns {
|
||||||
server := hcdns.MakeServer(argv, dbConn)
|
server := dns.MakeServer(argv, dbConn)
|
||||||
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
||||||
go func() {
|
go func() {
|
||||||
err = server.ListenAndServe()
|
err = server.ListenAndServe()
|
||||||
|
|
|
@ -15,22 +15,6 @@
|
||||||
padding: 0;
|
padding: 0;
|
||||||
color: var(--text-color);
|
color: var(--text-color);
|
||||||
font-family: "ComicSans", sans-serif;
|
font-family: "ComicSans", sans-serif;
|
||||||
|
|
||||||
cursor: url("/static/img/cursor-1.png"), auto;
|
|
||||||
-webkit-animation: cursor 400ms infinite;
|
|
||||||
animation: cursor 400ms infinite;
|
|
||||||
}
|
|
||||||
|
|
||||||
@-webkit-keyframes cursor {
|
|
||||||
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
|
||||||
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
|
||||||
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes cursor {
|
|
||||||
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
|
||||||
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
|
||||||
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body {
|
body {
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 570 B |
Binary file not shown.
Before Width: | Height: | Size: 563 B |
Loading…
Reference in New Issue