testing | dont be recursive for external domains | finalize oauth #5
|
@ -2,3 +2,4 @@
|
||||||
hatecomputers.club
|
hatecomputers.club
|
||||||
Dockerfile
|
Dockerfile
|
||||||
*.db
|
*.db
|
||||||
|
.drone.yml
|
||||||
|
|
29
.drone.yml
29
.drone.yml
|
@ -1,9 +1,30 @@
|
||||||
---
|
---
|
||||||
kind: pipeline
|
kind: pipeline
|
||||||
type: docker
|
type: docker
|
||||||
name: build, publish docker image, deploy
|
name: build
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
- name: run tests
|
||||||
|
image: golang
|
||||||
|
commands:
|
||||||
|
- go build
|
||||||
|
- go test -p 1 -v ./...
|
||||||
|
|
||||||
|
trigger:
|
||||||
|
event:
|
||||||
|
- pull_request
|
||||||
|
|
||||||
|
---
|
||||||
|
kind: pipeline
|
||||||
|
type: docker
|
||||||
|
name: deploy
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: run tests
|
||||||
|
image: golang
|
||||||
|
commands:
|
||||||
|
- go build
|
||||||
|
- go test -p 1 -v ./...
|
||||||
- name: docker
|
- name: docker
|
||||||
image: plugins/docker
|
image: plugins/docker
|
||||||
settings:
|
settings:
|
||||||
|
@ -13,9 +34,6 @@ steps:
|
||||||
from_secret: gitea_packpub_password
|
from_secret: gitea_packpub_password
|
||||||
registry: git.hatecomputers.club
|
registry: git.hatecomputers.club
|
||||||
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
|
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
|
||||||
tags:
|
|
||||||
- latest
|
|
||||||
- main
|
|
||||||
- name: ssh
|
- name: ssh
|
||||||
image: appleboy/drone-ssh
|
image: appleboy/drone-ssh
|
||||||
settings:
|
settings:
|
||||||
|
@ -27,6 +45,9 @@ steps:
|
||||||
command_timeout: 2m
|
command_timeout: 2m
|
||||||
script:
|
script:
|
||||||
- systemctl restart docker-compose@hatecomputers-club
|
- systemctl restart docker-compose@hatecomputers-club
|
||||||
|
|
||||||
trigger:
|
trigger:
|
||||||
branch:
|
branch:
|
||||||
- main
|
- main
|
||||||
|
event:
|
||||||
|
- push
|
||||||
|
|
|
@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
|
CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
|
||||||
|
|
|
@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
|
||||||
Result database.DNSRecord `json:"result"`
|
Result database.DNSRecord `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) {
|
type CloudflareExternalDNSAdapter struct {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
|
ZoneId string
|
||||||
|
APIToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||||
payload := strings.NewReader(reqBody)
|
payload := strings.NewReader(reqBody)
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST", url, payload)
|
req, _ := http.NewRequest("POST", url, payload)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
|
||||||
return result.ID, nil
|
return result.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDNSRecord(zoneId string, apiToken string, id string) error {
|
func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
|
||||||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id)
|
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
|
||||||
|
|
||||||
req, _ := http.NewRequest("DELETE", url, nil)
|
req, _ := http.NewRequest("DELETE", url, nil)
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+apiToken)
|
req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
package external_dns
|
||||||
|
|
||||||
|
import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
|
||||||
|
type ExternalDNSAdapter interface {
|
||||||
|
CreateDNSRecord(record *database.DNSRecord) (string, error)
|
||||||
|
DeleteDNSRecord(id string) error
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
@ -12,13 +12,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
verifier := utils.RandomId() + utils.RandomId()
|
verifier := utils.RandomId() + utils.RandomId()
|
||||||
|
|
||||||
sha2 := sha256.New()
|
sha2 := sha256.New()
|
||||||
|
@ -34,7 +35,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 60,
|
MaxAge: 200,
|
||||||
})
|
})
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "state",
|
Name: "state",
|
||||||
|
@ -42,7 +43,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 60,
|
MaxAge: 200,
|
||||||
})
|
})
|
||||||
|
|
||||||
http.Redirect(resp, req, url, http.StatusFound)
|
http.Redirect(resp, req, url, http.StatusFound)
|
||||||
|
@ -50,8 +51,8 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
state := req.URL.Query().Get("state")
|
state := req.URL.Query().Get("state")
|
||||||
code := req.URL.Query().Get("code")
|
code := req.URL.Query().Get("code")
|
||||||
|
|
||||||
|
@ -73,7 +74,6 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
||||||
reqContext := req.Context()
|
reqContext := req.Context()
|
||||||
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
@ -101,6 +101,16 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
})
|
})
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "verifier",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "state",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
|
||||||
redirect := "/"
|
redirect := "/"
|
||||||
redirectCookie, err := req.Cookie("redirect")
|
redirectCookie, err := req.Cookie("redirect")
|
||||||
|
@ -109,6 +119,7 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "redirect",
|
Name: "redirect",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
|
Value: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,6 +128,127 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
authHeader := req.Header.Get("Authorization")
|
||||||
|
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
||||||
|
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err == nil && sessionCookie.Value != "" {
|
||||||
|
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userErr != nil || user == nil {
|
||||||
|
log.Println(userErr, user)
|
||||||
|
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
context.User = nil
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
context.User = user
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "redirect",
|
||||||
|
Value: req.URL.Path,
|
||||||
|
Path: "/",
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(resp, req, "/login", http.StatusFound)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err != nil {
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
sessionCookie, err := req.Cookie("session")
|
||||||
|
if err == nil && sessionCookie.Value != "" {
|
||||||
|
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
MaxAge: 0,
|
||||||
|
Value: "",
|
||||||
|
})
|
||||||
|
http.Redirect(resp, req, "/", http.StatusFound)
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
||||||
|
userResponse, err := client.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userStruct, err := createUserFromOauthResponse(userResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
||||||
|
user := &database.User{}
|
||||||
|
err := json.NewDecoder(response.Body).Decode(user)
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Username = strings.ToLower(user.Username)
|
||||||
|
user.Username = strings.Split(user.Username, "@")[0]
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
||||||
|
cookie, err := req.Cookie(stateCookieName)
|
||||||
|
if err != nil || cookie.Value != expectedState {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
||||||
if bearerToken == "" {
|
if bearerToken == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -127,15 +259,15 @@ func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User,
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey, err := database.GetAPIKey(dbConn, parts[1])
|
key, err := database.GetAPIKey(dbConn, parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if apiKey == nil {
|
if key == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := database.GetUser(dbConn, apiKey.UserID)
|
user, err := database.GetUser(dbConn, key.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -162,124 +294,3 @@ func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
authHeader := req.Header.Get("Authorization")
|
|
||||||
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
|
||||||
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err == nil && sessionCookie.Value != "" {
|
|
||||||
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
if userErr != nil || user == nil {
|
|
||||||
log.Println(userErr, user)
|
|
||||||
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "session",
|
|
||||||
MaxAge: 0, // reset session cookie in case
|
|
||||||
})
|
|
||||||
|
|
||||||
context.User = nil
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
context.User = user
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "redirect",
|
|
||||||
Value: req.URL.Path,
|
|
||||||
Path: "/",
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
http.Redirect(resp, req, "/login", http.StatusFound)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err != nil {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
|
||||||
if err != nil {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
sessionCookie, err := req.Cookie("session")
|
|
||||||
if err == nil && sessionCookie.Value != "" {
|
|
||||||
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(resp, req, "/", http.StatusFound)
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
|
||||||
Name: "session",
|
|
||||||
MaxAge: 0,
|
|
||||||
})
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
|
||||||
userResponse, err := client.Get(uri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userStruct, err := createUserFromResponse(userResponse)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserFromResponse(response *http.Response) (*database.User, error) {
|
|
||||||
defer response.Body.Close()
|
|
||||||
user := &database.User{
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
err := json.NewDecoder(response.Body).Decode(user)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.Username = strings.ToLower(user.Username)
|
|
||||||
user.Username = strings.Split(user.Username, "@")[0]
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
|
|
||||||
cookie, err := req.Cookie(stateCookieName)
|
|
||||||
if err != nil || cookie.Value != expectedState {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
|
@ -0,0 +1,307 @@
|
||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FakedOauthServer() *httptest.Server {
|
||||||
|
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/auth" {
|
||||||
|
code := utils.RandomId()
|
||||||
|
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
redirectPath := r.URL.Query().Get("redirect_uri")
|
||||||
|
redirectPath += "?code=" + code + "&state=" + state
|
||||||
|
|
||||||
|
http.Redirect(w, r, redirectPath, http.StatusFound)
|
||||||
|
}
|
||||||
|
if r.URL.Path == "/token" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
|
||||||
|
}
|
||||||
|
if r.URL.Path == "/user" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return oauthServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
resp.Write([]byte(context.User.Username))
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/protected-path" {
|
||||||
|
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/login" {
|
||||||
|
log.Println("login")
|
||||||
|
auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/callback" {
|
||||||
|
log.Println("callback")
|
||||||
|
auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/me" {
|
||||||
|
auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/logout" {
|
||||||
|
auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return testServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &types.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthServer := FakedOauthServer()
|
||||||
|
testServer := MockUserEndpointServer(context)
|
||||||
|
|
||||||
|
return testDb, context, oauthServer, testServer, func() {
|
||||||
|
oauthServer.Close()
|
||||||
|
testServer.Close()
|
||||||
|
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: "test",
|
||||||
|
ClientSecret: "test",
|
||||||
|
Scopes: []string{"test"},
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: oauthServerURL + "/auth",
|
||||||
|
TokenURL: oauthServerURL + "/token",
|
||||||
|
},
|
||||||
|
RedirectURL: testServerURL + "/callback",
|
||||||
|
}, oauthServerURL + "/user"
|
||||||
|
}
|
||||||
|
|
||||||
|
func FollowAuthentication(
|
||||||
|
oauthServer *httptest.Server,
|
||||||
|
testServer *httptest.Server,
|
||||||
|
cookies map[string]*http.Cookie,
|
||||||
|
location string,
|
||||||
|
) (map[string]*http.Cookie, string) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
resp.Code = 0
|
||||||
|
|
||||||
|
for resp.Code == 0 || resp.Code == http.StatusFound {
|
||||||
|
req := httptest.NewRequest("GET", location, nil)
|
||||||
|
resp = httptest.NewRecorder()
|
||||||
|
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(location, oauthServer.URL) {
|
||||||
|
oauthServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
} else {
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
}
|
||||||
|
for _, cookie := range resp.Result().Cookies() {
|
||||||
|
cookies[cookie.Name] = cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Code == http.StatusFound {
|
||||||
|
location = resp.Header().Get("Location")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cookies, location
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
user, _ := database.GetUser(db, "test")
|
||||||
|
if user != nil {
|
||||||
|
t.Errorf("expected no user, got user")
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
user, _ = database.GetUser(db, "test")
|
||||||
|
if user == nil {
|
||||||
|
t.Errorf("expected a user to be created, could not find user")
|
||||||
|
}
|
||||||
|
if user.Username != "test" {
|
||||||
|
t.Errorf("expected username to be test, got %s", user.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
||||||
|
_, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/protected-path", nil)
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
location := resp.Header().Get("Location")
|
||||||
|
if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
|
||||||
|
t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
|
||||||
|
|
||||||
|
if !(strings.HasSuffix(location, "/protected-page")) {
|
||||||
|
t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthSetsUniqueSession(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
cookiesAgain := make(map[string]*http.Cookie)
|
||||||
|
cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
|
||||||
|
|
||||||
|
sessionOne := cookies["session"].Value
|
||||||
|
sessionTwo := cookiesAgain["session"].Value
|
||||||
|
if sessionOne == sessionTwo {
|
||||||
|
t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, sessionOne)
|
||||||
|
if session.UserID != "test" {
|
||||||
|
t.Errorf("expected session to be associated with user test, got %s", session.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutClearsSession(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/logout", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
for _, cookie := range resp.Result().Cookies() {
|
||||||
|
cookies[cookie.Name] = cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp = httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||||
|
t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
if session != nil {
|
||||||
|
t.Errorf("expected session to be deleted, got session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshUpdatesExpiration(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
updatedSession, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
|
||||||
|
if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
|
||||||
|
t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
session.ExpireAt = time.Now().Add(-time.Hour)
|
||||||
|
database.SaveSession(db, session)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||||
|
t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
|
}
|
||||||
|
}
|
179
api/dns.go
179
api/dns.go
|
@ -1,179 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
const MAX_USER_RECORDS = 65
|
|
||||||
|
|
||||||
type FormError struct {
|
|
||||||
Errors []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
|
|
||||||
ownedByUser := (user.ID == record.UserID)
|
|
||||||
if !ownedByUser {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !record.Internal {
|
|
||||||
userOwnedDomains := []string{
|
|
||||||
fmt.Sprintf("%s", user.Username),
|
|
||||||
fmt.Sprintf("%s.endpoints", user.Username),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range userOwnedDomains {
|
|
||||||
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
|
||||||
if domain == record.Name || isInSubDomain {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
userIsOwnerOfDomain := owner == user.ID
|
|
||||||
return ownedByUser && userIsOwnerOfDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
formErrors := FormError{
|
|
||||||
Errors: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
internal := req.FormValue("internal") == "on"
|
|
||||||
name := req.FormValue("name")
|
|
||||||
if internal && !strings.HasSuffix(name, ".") {
|
|
||||||
name += "."
|
|
||||||
}
|
|
||||||
|
|
||||||
recordType := req.FormValue("type")
|
|
||||||
recordType = strings.ToUpper(recordType)
|
|
||||||
|
|
||||||
recordContent := req.FormValue("content")
|
|
||||||
ttl := req.FormValue("ttl")
|
|
||||||
ttlNum, err := strconv.Atoi(ttl)
|
|
||||||
if err != nil {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
if dnsRecordCount >= MAX_USER_RECORDS {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecord := &database.DNSRecord{
|
|
||||||
UserID: context.User.ID,
|
|
||||||
Name: name,
|
|
||||||
Type: recordType,
|
|
||||||
Content: recordContent,
|
|
||||||
TTL: ttlNum,
|
|
||||||
Internal: internal,
|
|
||||||
}
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
if dnsRecord.Internal {
|
|
||||||
dnsRecord.ID = utils.RandomId()
|
|
||||||
} else {
|
|
||||||
cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsRecord.ID = cloudflareRecordId
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(formErrors.Errors) == 0 {
|
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["FormError"] = &formErrors
|
|
||||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
|
||||||
|
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
recordId := req.FormValue("id")
|
|
||||||
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
|
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !record.Internal {
|
|
||||||
err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
|
||||||
if err != nil {
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(resp, req, "/dns", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,174 @@
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
|
||||||
|
ownedByUser := (user.ID == record.UserID)
|
||||||
|
if !ownedByUser {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.Internal {
|
||||||
|
for _, format := range ownedInternalDomainFormats {
|
||||||
|
domain := fmt.Sprintf(format, user.Username)
|
||||||
|
|
||||||
|
isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
|
||||||
|
if domain == record.Name || isInSubDomain {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
userIsOwnerOfDomain := owner == user.ID
|
||||||
|
return ownedByUser && userIsOwnerOfDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
formErrors := types.FormError{
|
||||||
|
Errors: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
|
||||||
|
name := req.FormValue("name")
|
||||||
|
if internal && !strings.HasSuffix(name, ".") {
|
||||||
|
name += "."
|
||||||
|
}
|
||||||
|
|
||||||
|
recordType := req.FormValue("type")
|
||||||
|
recordType = strings.ToUpper(recordType)
|
||||||
|
|
||||||
|
recordContent := req.FormValue("content")
|
||||||
|
ttl := req.FormValue("ttl")
|
||||||
|
ttlNum, err := strconv.Atoi(ttl)
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
if dnsRecordCount >= maxUserRecords {
|
||||||
|
resp.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRecord := &database.DNSRecord{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Name: name,
|
||||||
|
Type: recordType,
|
||||||
|
Content: recordContent,
|
||||||
|
TTL: ttlNum,
|
||||||
|
Internal: internal,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
if dnsRecord.Internal {
|
||||||
|
dnsRecord.ID = utils.RandomId()
|
||||||
|
} else {
|
||||||
|
dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "error saving record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["FormError"] = &formErrors
|
||||||
|
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
recordId := req.FormValue("id")
|
||||||
|
record, err := database.GetDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(record.UserID == context.User.ID) {
|
||||||
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.Internal {
|
||||||
|
err = dnsAdapter.DeleteDNSRecord(recordId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = database.DeleteDNSRecord(context.DBConn, recordId)
|
||||||
|
if err != nil {
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,442 @@
|
||||||
|
package dns_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MAX_USER_RECORDS = 10
|
||||||
|
|
||||||
|
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||||
|
|
||||||
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
user := &database.User{
|
||||||
|
ID: "test",
|
||||||
|
Username: "test",
|
||||||
|
Mail: "test@test.com",
|
||||||
|
DisplayName: "test",
|
||||||
|
}
|
||||||
|
database.FindOrSaveUser(testDb, user)
|
||||||
|
|
||||||
|
context := &types.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
User: user,
|
||||||
|
}
|
||||||
|
|
||||||
|
return testDb, context, func() {
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SignallingExternalDnsAdapter struct {
|
||||||
|
AddChannel chan *database.DNSRecord
|
||||||
|
RmChannel chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
|
||||||
|
id := utils.RandomId()
|
||||||
|
go func() { adapter.AddChannel <- record }()
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
|
||||||
|
go func() { adapter.RmChannel <- id }()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.domain.",
|
||||||
|
}
|
||||||
|
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
|
records, err := database.GetUserDNSRecords(db, context.User.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(records) > 0 {
|
||||||
|
t.Errorf("expected no records, got records")
|
||||||
|
}
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
validOwner.Form = map[string][]string{
|
||||||
|
"internal": {"on"},
|
||||||
|
"name": {"new.test.domain."},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||||
|
if validOwnerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerNonInternalRecorder := httptest.NewRecorder()
|
||||||
|
validOwner.Form["internal"] = []string{"off"}
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
|
||||||
|
if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidOwnerRecorder := httptest.NewRecorder()
|
||||||
|
invalidOwner := validOwner
|
||||||
|
invalidOwner.Form["internal"] = []string{"on"}
|
||||||
|
invalidOwner.Form["name"] = []string{"new.invalid.domain."}
|
||||||
|
testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
|
||||||
|
if invalidOwnerRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
|
||||||
|
for _, format := range fmts {
|
||||||
|
name := fmt.Sprintf(format, context.User.Username)
|
||||||
|
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {name},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
|
||||||
|
if len(namedRecords) == 0 {
|
||||||
|
t.Errorf("saved record not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatExternalDnsSaves(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
|
||||||
|
name := "test." + context.User.Username
|
||||||
|
externalRequest.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {name},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"test.domain."},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case res := <-addChannel:
|
||||||
|
if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
|
||||||
|
t.Errorf("received the wrong external record")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("timed out in waiting for external addition")
|
||||||
|
}
|
||||||
|
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.domain.",
|
||||||
|
}
|
||||||
|
domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
|
||||||
|
internalRequest := externalRequest
|
||||||
|
internalRequest.Form["internal"] = []string{"on"}
|
||||||
|
internalRequest.Form["name"] = []string{"test.domain."}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
|
||||||
|
if responseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case _ = <-addChannel:
|
||||||
|
t.Errorf("expected nothing in the add channel")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatUserMustOwnRecordToRemove(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
rmChannel := make(chan string)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
RmChannel: rmChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
|
||||||
|
_, err := database.FindOrSaveUser(db, nonOwnerUser)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
Internal: false,
|
||||||
|
Name: "test",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: nonOwnerUser.ID,
|
||||||
|
}
|
||||||
|
_, err = database.SaveDNSRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonOwnerRecorder := httptest.NewRecorder()
|
||||||
|
nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
nonOwner.Form = map[string][]string{
|
||||||
|
"id": {record.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
|
||||||
|
if nonOwnerRecorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
record.UserID = context.User.ID
|
||||||
|
record.ID = "2"
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
|
||||||
|
owner := nonOwner
|
||||||
|
owner.Form["id"] = []string{"2"}
|
||||||
|
ownerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
|
||||||
|
if ownerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", ownerRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThatExternalDnsRemoves(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
Internal: false,
|
||||||
|
Name: "test",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: context.User.ID,
|
||||||
|
}
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
|
||||||
|
rmChannel := make(chan string)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
RmChannel: rmChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
externalResponseRecorder := httptest.NewRecorder()
|
||||||
|
deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
|
||||||
|
deleteRequest.Form = map[string][]string{
|
||||||
|
"id": {record.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
|
||||||
|
if externalResponseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case res := <-rmChannel:
|
||||||
|
if res != record.ID {
|
||||||
|
t.Errorf("received the wrong external record")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("timed out in waiting for external addition")
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Internal = true
|
||||||
|
record.Name = "test.domain."
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.domain.",
|
||||||
|
}
|
||||||
|
database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
|
internalResponseRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
|
||||||
|
if internalResponseRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case _ = <-rmChannel:
|
||||||
|
t.Errorf("expected nothing in the rmchannel")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordCountCannotExceed(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
record := &database.DNSRecord{
|
||||||
|
Internal: false,
|
||||||
|
Name: context.User.Username,
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "asdf",
|
||||||
|
TTL: 1000,
|
||||||
|
UserID: context.User.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= MAX_USER_RECORDS; i++ {
|
||||||
|
record.ID = strconv.Itoa(i)
|
||||||
|
record.Name = record.ID + "." + record.Name
|
||||||
|
database.SaveDNSRecord(db, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"internal": {"off"},
|
||||||
|
"name": {record.Name},
|
||||||
|
"type": {record.Type},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {record.Content},
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(recorder, req)
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected too many requests code return, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
domainOwner := &database.DomainOwner{
|
||||||
|
UserID: context.User.ID,
|
||||||
|
Domain: "test.internal.",
|
||||||
|
}
|
||||||
|
database.SaveDomainOwner(db, domainOwner)
|
||||||
|
|
||||||
|
addChannel := make(chan *database.DNSRecord)
|
||||||
|
signallingDnsAdapter := &SignallingExternalDnsAdapter{
|
||||||
|
AddChannel: addChannel,
|
||||||
|
}
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
validOwner := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
validOwner.Form = map[string][]string{
|
||||||
|
"internal": {"on"},
|
||||||
|
"name": {"test.internal"},
|
||||||
|
"type": {"CNAME"},
|
||||||
|
"ttl": {"43000"},
|
||||||
|
"content": {"asdf.internal"},
|
||||||
|
}
|
||||||
|
|
||||||
|
validOwnerRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
|
||||||
|
if validOwnerRecorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
|
||||||
|
recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
|
||||||
|
|
||||||
|
if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
|
||||||
|
t.Errorf("expected dot appended")
|
||||||
|
}
|
||||||
|
}
|
141
api/guestbook.go
141
api/guestbook.go
|
@ -1,141 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type HcaptchaArgs struct {
|
|
||||||
SiteKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
|
||||||
errors := []string{}
|
|
||||||
|
|
||||||
if entry.Name == "" {
|
|
||||||
errors = append(errors, "name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
if entry.Message == "" {
|
|
||||||
errors = append(errors, "message is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
messageLength := len(entry.Message)
|
|
||||||
if messageLength > 500 {
|
|
||||||
errors = append(errors, "message cannot be longer than 500 characters")
|
|
||||||
}
|
|
||||||
|
|
||||||
newLines := strings.Count(entry.Message, "\n")
|
|
||||||
if newLines > 10 {
|
|
||||||
errors = append(errors, "message cannot contain more than 10 new lines")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
name := req.FormValue("name")
|
|
||||||
message := req.FormValue("message")
|
|
||||||
hCaptchaResponse := req.FormValue("h-captcha-response")
|
|
||||||
|
|
||||||
formErrors := FormError{
|
|
||||||
Errors: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if hCaptchaResponse == "" {
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &database.GuestbookEntry{
|
|
||||||
ID: utils.RandomId(),
|
|
||||||
Name: name,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
|
||||||
|
|
||||||
err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
|
|
||||||
formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
|
|
||||||
}
|
|
||||||
if len(formErrors.Errors) > 0 {
|
|
||||||
(*context.TemplateData)["FormError"] = formErrors
|
|
||||||
(*context.TemplateData)["EntryForm"] = entry
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = database.SaveGuestbookEntry(context.DBConn, entry)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
entries, err := database.GetGuestbookEntries(context.DBConn)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return failure(context, req, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
(*context.TemplateData)["GuestbookEntries"] = entries
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
||||||
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
|
||||||
SiteKey: context.Args.HcaptchaSiteKey,
|
|
||||||
}
|
|
||||||
log.Println(context.Args.HcaptchaSiteKey)
|
|
||||||
return success(context, req, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyHCaptcha(secret, response string) error {
|
|
||||||
verifyURL := "https://hcaptcha.com/siteverify"
|
|
||||||
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", verifyURL, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse := struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
}{}
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !jsonResponse.Success {
|
|
||||||
return fmt.Errorf("hcaptcha verification failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
package guestbook
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
|
||||||
|
errors := []string{}
|
||||||
|
|
||||||
|
if entry.Name == "" {
|
||||||
|
errors = append(errors, "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.Message == "" {
|
||||||
|
errors = append(errors, "message is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
messageLength := len(entry.Message)
|
||||||
|
if messageLength > 500 {
|
||||||
|
errors = append(errors, "message cannot be longer than 500 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
newLines := strings.Count(entry.Message, "\n")
|
||||||
|
if newLines > 10 {
|
||||||
|
errors = append(errors, "message cannot contain more than 10 new lines")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
name := req.FormValue("name")
|
||||||
|
message := req.FormValue("message")
|
||||||
|
|
||||||
|
formErrors := types.FormError{
|
||||||
|
Errors: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &database.GuestbookEntry{
|
||||||
|
ID: utils.RandomId(),
|
||||||
|
Name: name,
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
|
||||||
|
|
||||||
|
if len(formErrors.Errors) == 0 {
|
||||||
|
_, err := database.SaveGuestbookEntry(context.DBConn, entry)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
formErrors.Errors = append(formErrors.Errors, "failed to save entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formErrors.Errors) > 0 {
|
||||||
|
(*context.TemplateData)["FormError"] = formErrors
|
||||||
|
(*context.TemplateData)["EntryForm"] = entry
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
entries, err := database.GetGuestbookEntries(context.DBConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
(*context.TemplateData)["GuestbookEntries"] = entries
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
package guestbook_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &types.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return testDb, context, func() {
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidGuestbookPutsInDatabase(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
entries, err := database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(entries) > 0 {
|
||||||
|
t.Errorf("expected no entries, got entries")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", ts.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"name": {"test"},
|
||||||
|
"message": {"test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ts.Config.Handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status code 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err = database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Errorf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
if entries[0].Name != req.FormValue("name") {
|
||||||
|
t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
entries, err := database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(entries) > 0 {
|
||||||
|
t.Errorf("expected no entries, got entries")
|
||||||
|
}
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
|
||||||
|
invalidRequests := []struct {
|
||||||
|
name string
|
||||||
|
message string
|
||||||
|
}{
|
||||||
|
{"", "test"},
|
||||||
|
{"test", ""},
|
||||||
|
{"", ""},
|
||||||
|
{"test", reallyLongStringThatWouldTakeTooMuchSpace},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, form := range invalidRequests {
|
||||||
|
req := httptest.NewRequest("POST", testServer.URL, nil)
|
||||||
|
req.Form = map[string][]string{
|
||||||
|
"name": {form.name},
|
||||||
|
"message": {form.message},
|
||||||
|
}
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(responseRecorder, req)
|
||||||
|
|
||||||
|
if responseRecorder.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status code 400, got %d", responseRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err = database.GetGuestbookEntries(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("expected 0 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package hcaptcha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HcaptchaArgs struct {
|
||||||
|
SiteKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyCaptcha(secret, response string) error {
|
||||||
|
verifyURL := "https://hcaptcha.com/siteverify"
|
||||||
|
body := strings.NewReader("secret=" + secret + "&response=" + response)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", verifyURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse := struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}{}
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jsonResponse.Success {
|
||||||
|
return fmt.Errorf("hcaptcha verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
|
||||||
|
SiteKey: context.Args.HcaptchaSiteKey,
|
||||||
|
}
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
hCaptchaResponse := req.FormValue("h-captcha-response")
|
||||||
|
secretKey := context.Args.HcaptchaSecret
|
||||||
|
|
||||||
|
err := verifyCaptcha(secretKey, hCaptchaResponse)
|
||||||
|
if err != nil {
|
||||||
|
(*context.TemplateData)["FormError"] = types.FormError{
|
||||||
|
Errors: []string{"hCaptcha verification failed"},
|
||||||
|
}
|
||||||
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
|
||||||
|
return failure(context, req, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,32 +1,33 @@
|
||||||
package api
|
package keys
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_USER_API_KEYS = 5
|
const MAX_USER_API_KEYS = 5
|
||||||
|
|
||||||
func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
(*context.TemplateData)["APIKeys"] = apiKeys
|
(*context.TemplateData)["APIKeys"] = typesKeys
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
formErrors := FormError{
|
formErrors := types.FormError{
|
||||||
Errors: []string{},
|
Errors: []string{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
}
|
}
|
||||||
|
|
||||||
if numKeys >= MAX_USER_API_KEYS {
|
if numKeys >= MAX_USER_API_KEYS {
|
||||||
formErrors.Errors = append(formErrors.Errors, "max api keys reached")
|
formErrors.Errors = append(formErrors.Errors, "max types keys reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(formErrors.Errors) > 0 {
|
if len(formErrors.Errors) > 0 {
|
||||||
|
@ -59,29 +60,28 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
key := req.FormValue("key")
|
apiKey := req.FormValue("key")
|
||||||
|
|
||||||
apiKey, err := database.GetAPIKey(context.DBConn, key)
|
key, err := database.GetAPIKey(context.DBConn, apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
|
if (key == nil) || (key.UserID != context.User.ID) {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = database.DeleteAPIKey(context.DBConn, key)
|
err = database.DeleteAPIKey(context.DBConn, apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(resp, req, "/keys", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
88
api/serve.go
88
api/serve.go
|
@ -7,27 +7,20 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestContext struct {
|
func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
DBConn *sql.DB
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
Args *args.Arguments
|
|
||||||
|
|
||||||
Id string
|
|
||||||
Start time.Time
|
|
||||||
|
|
||||||
TemplateData *map[string]interface{}
|
|
||||||
User *database.User
|
|
||||||
}
|
|
||||||
|
|
||||||
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
|
||||||
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
|
||||||
|
|
||||||
func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
||||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
|
||||||
context.Start = time.Now()
|
context.Start = time.Now()
|
||||||
context.Id = utils.RandomId()
|
context.Id = utils.RandomId()
|
||||||
|
|
||||||
|
@ -36,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
end := time.Now()
|
end := time.Now()
|
||||||
|
|
||||||
log.Println(context.Id, "took", end.Sub(context.Start))
|
log.Println(context.Id, "took", end.Sub(context.Start))
|
||||||
|
@ -46,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
resp.WriteHeader(200)
|
resp.WriteHeader(200)
|
||||||
resp.Write([]byte("healthy"))
|
resp.Write([]byte("healthy"))
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(_success Continuation, failure Continuation) ContinuationChain {
|
return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, _failure Continuation) ContinuationChain {
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,89 +73,90 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
fileServer := http.FileServer(http.Dir(argv.StaticPath))
|
||||||
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
|
||||||
|
|
||||||
makeRequestContext := func() *RequestContext {
|
cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
|
||||||
return &RequestContext{
|
APIToken: argv.CloudflareToken,
|
||||||
|
ZoneId: argv.CloudflareZone,
|
||||||
|
}
|
||||||
|
|
||||||
|
makeRequestContext := func() *types.RequestContext {
|
||||||
|
return &types.RequestContext{
|
||||||
DBConn: dbConn,
|
DBConn: dbConn,
|
||||||
Args: argv,
|
Args: argv,
|
||||||
|
|
||||||
TemplateData: &map[string]interface{}{},
|
TemplateData: &map[string]interface{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
requestContext := makeRequestContext()
|
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const MAX_USER_RECORDS = 100
|
||||||
|
var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
|
||||||
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
name := r.PathValue("name")
|
name := r.PathValue("name")
|
||||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package template
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -7,9 +7,11 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
|
||||||
templatePath := context.Args.TemplatePath
|
templatePath := context.Args.TemplatePath
|
||||||
basePath := templatePath + "/base_empty.html"
|
basePath := templatePath + "/base_empty.html"
|
||||||
if showBaseHtml {
|
if showBaseHtml {
|
||||||
|
@ -41,9 +43,9 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b
|
||||||
return buffer, nil
|
return buffer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TemplateContinuation(path string, showBase bool) Continuation {
|
func TemplateContinuation(path string, showBase bool) types.Continuation {
|
||||||
return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
html, err := renderTemplate(context, path, true)
|
html, err := renderTemplate(context, path, true)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
resp.WriteHeader(404)
|
resp.WriteHeader(404)
|
||||||
|
@ -66,7 +68,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.WriteHeader(200)
|
|
||||||
resp.Header().Set("Content-Type", "text/html")
|
resp.Header().Set("Content-Type", "text/html")
|
||||||
resp.Write(html.Bytes())
|
resp.Write(html.Bytes())
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
|
@ -0,0 +1,28 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestContext struct {
|
||||||
|
DBConn *sql.DB
|
||||||
|
Args *args.Arguments
|
||||||
|
|
||||||
|
Id string
|
||||||
|
Start time.Time
|
||||||
|
|
||||||
|
TemplateData *map[string]interface{}
|
||||||
|
User *database.User
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormError struct {
|
||||||
|
Errors []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
|
||||||
|
type ContinuationChain func(Continuation, Continuation) ContinuationChain
|
|
@ -23,7 +23,6 @@ type Arguments struct {
|
||||||
OauthUserInfoURI string
|
OauthUserInfoURI string
|
||||||
|
|
||||||
Dns bool
|
Dns bool
|
||||||
DnsRecursion []string
|
|
||||||
DnsPort int
|
DnsPort int
|
||||||
|
|
||||||
CloudflareToken string
|
CloudflareToken string
|
||||||
|
@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
|
||||||
server := flag.Bool("server", false, "Run the server")
|
server := flag.Bool("server", false, "Run the server")
|
||||||
|
|
||||||
dns := flag.Bool("dns", false, "Run DNS resolver")
|
dns := flag.Bool("dns", false, "Run DNS resolver")
|
||||||
dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
|
|
||||||
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
|
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
|
||||||
Migrate: *migrate,
|
Migrate: *migrate,
|
||||||
Scheduler: *scheduler,
|
Scheduler: *scheduler,
|
||||||
Dns: *dns,
|
Dns: *dns,
|
||||||
DnsRecursion: strings.Split(*dnsRecursion, ","),
|
|
||||||
DnsPort: *dnsPort,
|
DnsPort: *dnsPort,
|
||||||
|
|
||||||
OauthConfig: oauthConfig,
|
OauthConfig: oauthConfig,
|
||||||
|
|
|
@ -9,6 +9,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DomainOwner struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type DNSRecord struct {
|
type DNSRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
|
||||||
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
|
||||||
log.Println("saving dns record", record.ID)
|
log.Println("saving dns record", record.ID)
|
||||||
|
|
||||||
|
if (record.CreatedAt == time.Time{}) {
|
||||||
record.CreatedAt = time.Now()
|
record.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
|
||||||
|
log.Println("saving domain owner", domainOwner.Domain)
|
||||||
|
|
||||||
|
domainOwner.CreatedAt = time.Now()
|
||||||
|
_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return domainOwner, nil
|
||||||
|
}
|
||||||
|
|
|
@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
|
||||||
|
log.Println("saving session", session.ID)
|
||||||
|
|
||||||
|
_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
||||||
newExpireAt := time.Now().Add(ExpiryDuration)
|
newExpireAt := time.Now().Add(ExpiryDuration)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package dns
|
package hcdns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -9,27 +9,28 @@ import (
|
||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_RECURSION = 10
|
const MAX_RECURSION = 15
|
||||||
|
|
||||||
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
|
||||||
if maxDepth == 0 {
|
|
||||||
return nil, fmt.Errorf("too much recursion")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||||
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
answers := []dns.RR{}
|
var answers []dns.RR
|
||||||
for _, record := range internalCnames {
|
for _, record := range internalCnames {
|
||||||
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
answers = append(answers, cname)
|
answers = append(answers, cname)
|
||||||
|
|
||||||
cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
|
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
answers = append(answers, cnameRecursive...)
|
answers = append(answers, cnameRecursive...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, record := range typeDnsRecords {
|
for _, record := range typeDnsRecords {
|
||||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
|
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
answers = append(answers, answer)
|
answers = append(answers, answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(answers) > 0 {
|
|
||||||
// base case; we found the answer
|
|
||||||
return answers, nil
|
return answers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
|
||||||
|
if maxDepth == 0 {
|
||||||
|
return nil, fmt.Errorf("too much recursion")
|
||||||
}
|
}
|
||||||
|
|
||||||
message := new(dns.Msg)
|
answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
|
||||||
message.SetQuestion(dns.Fqdn(domain), qtype)
|
if err != nil {
|
||||||
message.RecursionDesired = true
|
|
||||||
|
|
||||||
client := new(dns.Client)
|
|
||||||
|
|
||||||
i := 0
|
|
||||||
in, _, err := client.Exchange(message, dnsResolvers[i])
|
|
||||||
for err != nil {
|
|
||||||
i += 1
|
|
||||||
if i == len(dnsResolvers) {
|
|
||||||
log.Println(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
in, _, err = client.Exchange(message, dnsResolvers[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
answers = append(answers, in.Answer...)
|
|
||||||
return answers, nil
|
return answers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,21 +78,26 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
msg.Authoritative = true
|
msg.Authoritative = true
|
||||||
|
|
||||||
for _, question := range r.Question {
|
for _, question := range r.Question {
|
||||||
answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
|
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
continue
|
msg.SetRcode(r, dns.RcodeServerFailure)
|
||||||
|
w.WriteMsg(msg)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
msg.Answer = append(msg.Answer, answers...)
|
msg.Answer = append(msg.Answer, answers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(msg.Answer) == 0 {
|
||||||
|
msg.SetRcode(r, dns.RcodeNameError)
|
||||||
|
}
|
||||||
|
|
||||||
log.Println(msg.Answer)
|
log.Println(msg.Answer)
|
||||||
w.WriteMsg(msg)
|
w.WriteMsg(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
|
||||||
handler := &DnsHandler{
|
handler := &DnsHandler{
|
||||||
DnsResolvers: argv.DnsRecursion,
|
|
||||||
DbConn: dbConn,
|
DbConn: dbConn,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
addr := fmt.Sprintf(":%d", argv.DnsPort)
|
|
@ -0,0 +1,254 @@
|
||||||
|
package hcdns_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomPort() int {
|
||||||
|
return rand.Intn(3000) + 5192
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
dnsPort := randomPort()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
testUser := &database.User{
|
||||||
|
ID: "test",
|
||||||
|
}
|
||||||
|
database.FindOrSaveUser(testDb, testUser)
|
||||||
|
|
||||||
|
waitLock := &sync.Mutex{}
|
||||||
|
server := hcdns.MakeServer(&args.Arguments{
|
||||||
|
DnsPort: dnsPort,
|
||||||
|
}, testDb)
|
||||||
|
server.NotifyStartedFunc = func() {
|
||||||
|
waitLock.Unlock()
|
||||||
|
}
|
||||||
|
waitLock.Lock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server.ListenAndServe()
|
||||||
|
}()
|
||||||
|
waitLock.Lock()
|
||||||
|
|
||||||
|
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
|
||||||
|
return testDb, server, &address, waitLock, func() {
|
||||||
|
server.Shutdown()
|
||||||
|
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhenCNAMEIsResolved(t *testing.T) {
|
||||||
|
t.Log("TestWhenCNAMEIsResolved")
|
||||||
|
|
||||||
|
testDb, _, addr, lock, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
records := []*database.DNSRecord{
|
||||||
|
{
|
||||||
|
ID: "0",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "cname.internal.example.com.",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "next.internal.example.com.",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
}, {
|
||||||
|
ID: "1",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "next.internal.example.com.",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "res.example.com.",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "2",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "res.example.com.",
|
||||||
|
Type: "A",
|
||||||
|
Content: "1.2.3.2",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
database.SaveDNSRecord(testDb, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
qtype := dns.TypeA
|
||||||
|
domain := dns.Fqdn("cname.internal.example.com.")
|
||||||
|
client := &dns.Client{}
|
||||||
|
message := &dns.Msg{}
|
||||||
|
message.SetQuestion(domain, qtype)
|
||||||
|
|
||||||
|
in, _, err := client.Exchange(message, *addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) != 3 {
|
||||||
|
t.Fatalf("expected 3 answers, got %d", len(in.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, record := range records {
|
||||||
|
if in.Answer[i].Header().Name != record.Name {
|
||||||
|
t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
|
||||||
|
t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(in.Answer[i].Header().Ttl) != record.TTL {
|
||||||
|
t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !in.Authoritative {
|
||||||
|
t.Fatalf("expected authoritative response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
|
||||||
|
t.Fatalf("expected final record to be the A record with correct IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhenNoRecordNxDomain(t *testing.T) {
|
||||||
|
t.Log("TestWhenNoRecordNxDomain")
|
||||||
|
|
||||||
|
_, _, addr, lock, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
qtype := dns.TypeA
|
||||||
|
domain := dns.Fqdn("nonexistant.example.com.")
|
||||||
|
client := &dns.Client{}
|
||||||
|
message := &dns.Msg{}
|
||||||
|
message.SetQuestion(domain, qtype)
|
||||||
|
|
||||||
|
in, _, err := client.Exchange(message, *addr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) != 0 {
|
||||||
|
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Rcode != dns.RcodeNameError {
|
||||||
|
t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhenUnresolvingCNAME(t *testing.T) {
|
||||||
|
t.Log("TestWhenUnresolvingCNAME")
|
||||||
|
|
||||||
|
testDb, _, addr, lock, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
cname := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "cname.internal.example.com.",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "nonexistant.example.com.",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
}
|
||||||
|
database.SaveDNSRecord(testDb, cname)
|
||||||
|
|
||||||
|
qtype := dns.TypeA
|
||||||
|
domain := dns.Fqdn(cname.Name)
|
||||||
|
client := &dns.Client{}
|
||||||
|
message := &dns.Msg{}
|
||||||
|
message.SetQuestion(domain, qtype)
|
||||||
|
|
||||||
|
in, _, err := client.Exchange(message, *addr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) != 1 {
|
||||||
|
t.Fatalf("expected 1 answer, got %d", len(in.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !in.Authoritative {
|
||||||
|
t.Fatalf("expected authoritative response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[0].Header().Name != cname.Name {
|
||||||
|
t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
|
||||||
|
t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Answer[0].(*dns.CNAME).Target != cname.Content {
|
||||||
|
t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Rcode == dns.RcodeNameError {
|
||||||
|
t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
|
||||||
|
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
|
||||||
|
|
||||||
|
testDb, _, addr, lock, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
cname := &database.DNSRecord{
|
||||||
|
ID: "1",
|
||||||
|
UserID: "test",
|
||||||
|
Name: "cname.internal.example.com.",
|
||||||
|
Type: "CNAME",
|
||||||
|
Content: "cname.internal.example.com.",
|
||||||
|
TTL: 300,
|
||||||
|
Internal: true,
|
||||||
|
}
|
||||||
|
database.SaveDNSRecord(testDb, cname)
|
||||||
|
|
||||||
|
qtype := dns.TypeA
|
||||||
|
domain := dns.Fqdn(cname.Name)
|
||||||
|
client := &dns.Client{}
|
||||||
|
message := &dns.Msg{}
|
||||||
|
message.SetQuestion(domain, qtype)
|
||||||
|
|
||||||
|
in, _, err := client.Exchange(message, *addr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) > 0 {
|
||||||
|
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Rcode != dns.RcodeServerFailure {
|
||||||
|
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
|
||||||
|
}
|
||||||
|
}
|
4
main.go
4
main.go
|
@ -6,7 +6,7 @@ import (
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
|
@ -52,7 +52,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if argv.Dns {
|
if argv.Dns {
|
||||||
server := dns.MakeServer(argv, dbConn)
|
server := hcdns.MakeServer(argv, dbConn)
|
||||||
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
|
||||||
go func() {
|
go func() {
|
||||||
err = server.ListenAndServe()
|
err = server.ListenAndServe()
|
||||||
|
|
|
@ -15,6 +15,22 @@
|
||||||
padding: 0;
|
padding: 0;
|
||||||
color: var(--text-color);
|
color: var(--text-color);
|
||||||
font-family: "ComicSans", sans-serif;
|
font-family: "ComicSans", sans-serif;
|
||||||
|
|
||||||
|
cursor: url("/static/img/cursor-1.png"), auto;
|
||||||
|
-webkit-animation: cursor 400ms infinite;
|
||||||
|
animation: cursor 400ms infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@-webkit-keyframes cursor {
|
||||||
|
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||||
|
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
||||||
|
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes cursor {
|
||||||
|
0% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||||
|
50% {cursor: url("/static/img/cursor-1.png"), auto;}
|
||||||
|
100% {cursor: url("/static/img/cursor-2.png"), auto;}
|
||||||
}
|
}
|
||||||
|
|
||||||
body {
|
body {
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 570 B |
Binary file not shown.
After Width: | Height: | Size: 563 B |
Loading…
Reference in New Issue