Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/push Build is passing
				
					Details
				
			
		
	Reviewed-on: #5
This commit is contained in:
		
						commit
						83cc6267fd
					
				|  | @ -2,3 +2,4 @@ | |||
| hatecomputers.club | ||||
| Dockerfile | ||||
| *.db | ||||
| .drone.yml | ||||
|  |  | |||
							
								
								
									
										31
									
								
								.drone.yml
								
								
								
								
							
							
						
						
									
										31
									
								
								.drone.yml
								
								
								
								
							|  | @ -1,9 +1,30 @@ | |||
| --- | ||||
| kind: pipeline | ||||
| type: docker | ||||
| name: build, publish docker image, deploy | ||||
| name: build | ||||
| 
 | ||||
| steps: | ||||
|   - name: run tests | ||||
|     image: golang | ||||
|     commands: | ||||
|       - go build | ||||
|       - go test -p 1 -v ./... | ||||
| 
 | ||||
| trigger: | ||||
|   event: | ||||
|     - pull_request | ||||
| 
 | ||||
| --- | ||||
| kind: pipeline | ||||
| type: docker | ||||
| name: deploy | ||||
| 
 | ||||
| steps: | ||||
|   - name: run tests | ||||
|     image: golang | ||||
|     commands: | ||||
|       - go build | ||||
|       - go test -p 1 -v ./... | ||||
|   - name: docker | ||||
|     image: plugins/docker | ||||
|     settings: | ||||
|  | @ -13,9 +34,6 @@ steps: | |||
|         from_secret: gitea_packpub_password | ||||
|       registry: git.hatecomputers.club | ||||
|       repo: git.hatecomputers.club/hatecomputers/hatecomputers.club | ||||
|       tags: | ||||
|       - latest | ||||
|       - main | ||||
|   - name: ssh | ||||
|     image: appleboy/drone-ssh | ||||
|     settings: | ||||
|  | @ -27,6 +45,9 @@ steps: | |||
|       command_timeout: 2m | ||||
|       script: | ||||
|         - systemctl restart docker-compose@hatecomputers-club | ||||
| 
 | ||||
| trigger: | ||||
|   branch: | ||||
|   - main | ||||
|     - main | ||||
|   event: | ||||
|     - push | ||||
|  |  | |||
|  | @ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers | |||
| 
 | ||||
| EXPOSE 8080 | ||||
| 
 | ||||
| CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] | ||||
| CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"] | ||||
|  |  | |||
|  | @ -14,15 +14,20 @@ type CloudflareDNSResponse struct { | |||
| 	Result database.DNSRecord `json:"result"` | ||||
| } | ||||
| 
 | ||||
| func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) | ||||
| type CloudflareExternalDNSAdapter struct { | ||||
| 	ZoneId   string | ||||
| 	APIToken string | ||||
| } | ||||
| 
 | ||||
| func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId) | ||||
| 
 | ||||
| 	reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) | ||||
| 	payload := strings.NewReader(reqBody) | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("POST", url, payload) | ||||
| 
 | ||||
| 	req.Header.Add("Authorization", "Bearer "+apiToken) | ||||
| 	req.Header.Add("Authorization", "Bearer "+adapter.APIToken) | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 
 | ||||
| 	res, err := http.DefaultClient.Do(req) | ||||
|  | @ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) | |||
| 	return result.ID, nil | ||||
| } | ||||
| 
 | ||||
| func DeleteDNSRecord(zoneId string, apiToken string, id string) error { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) | ||||
| func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { | ||||
| 	url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id) | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("DELETE", url, nil) | ||||
| 
 | ||||
| 	req.Header.Add("Authorization", "Bearer "+apiToken) | ||||
| 	req.Header.Add("Authorization", "Bearer "+adapter.APIToken) | ||||
| 
 | ||||
| 	res, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -0,0 +1,8 @@ | |||
| package external_dns | ||||
| 
 | ||||
| import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 
 | ||||
| type ExternalDNSAdapter interface { | ||||
| 	CreateDNSRecord(record *database.DNSRecord) (string, error) | ||||
| 	DeleteDNSRecord(id string) error | ||||
| } | ||||
|  | @ -1,4 +1,4 @@ | |||
| package api | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/sha256" | ||||
|  | @ -12,13 +12,14 @@ import ( | |||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		verifier := utils.RandomId() + utils.RandomId() | ||||
| 
 | ||||
| 		sha2 := sha256.New() | ||||
|  | @ -34,7 +35,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h | |||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			MaxAge:   60, | ||||
| 			MaxAge:   200, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:     "state", | ||||
|  | @ -42,7 +43,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h | |||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			MaxAge:   60, | ||||
| 			MaxAge:   200, | ||||
| 		}) | ||||
| 
 | ||||
| 		http.Redirect(resp, req, url, http.StatusFound) | ||||
|  | @ -50,8 +51,8 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		state := req.URL.Query().Get("state") | ||||
| 		code := req.URL.Query().Get("code") | ||||
| 
 | ||||
|  | @ -73,7 +74,6 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp | |||
| 		reqContext := req.Context() | ||||
| 		token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
|  | @ -101,6 +101,16 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp | |||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			Secure:   true, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "verifier", | ||||
| 			Value:  "", | ||||
| 			MaxAge: 0, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "state", | ||||
| 			Value:  "", | ||||
| 			MaxAge: 0, | ||||
| 		}) | ||||
| 
 | ||||
| 		redirect := "/" | ||||
| 		redirectCookie, err := req.Cookie("redirect") | ||||
|  | @ -109,6 +119,7 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp | |||
| 			http.SetCookie(resp, &http.Cookie{ | ||||
| 				Name:   "redirect", | ||||
| 				MaxAge: 0, | ||||
| 				Value:  "", | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
|  | @ -117,6 +128,127 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		authHeader := req.Header.Get("Authorization") | ||||
| 		user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) | ||||
| 
 | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err == nil && sessionCookie.Value != "" { | ||||
| 			user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) | ||||
| 		} | ||||
| 
 | ||||
| 		if userErr != nil || user == nil { | ||||
| 			log.Println(userErr, user) | ||||
| 
 | ||||
| 			http.SetCookie(resp, &http.Cookie{ | ||||
| 				Name:   "session", | ||||
| 				Value:  "", | ||||
| 				MaxAge: 0, | ||||
| 			}) | ||||
| 
 | ||||
| 			context.User = nil | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		context.User = user | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:     "redirect", | ||||
| 			Value:    req.URL.Path, | ||||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 		}) | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/login", http.StatusFound) | ||||
| 		return failure(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err != nil { | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = database.RefreshSession(context.DBConn, sessionCookie.Value) | ||||
| 		if err != nil { | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err == nil && sessionCookie.Value != "" { | ||||
| 			_ = database.DeleteSession(context.DBConn, sessionCookie.Value) | ||||
| 		} | ||||
| 
 | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "session", | ||||
| 			MaxAge: 0, | ||||
| 			Value:  "", | ||||
| 		}) | ||||
| 		http.Redirect(resp, req, "/", http.StatusFound) | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { | ||||
| 	userResponse, err := client.Get(uri) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	userStruct, err := createUserFromOauthResponse(userResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.FindOrSaveUser(dbConn, userStruct) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func createUserFromOauthResponse(response *http.Response) (*database.User, error) { | ||||
| 	user := &database.User{} | ||||
| 	err := json.NewDecoder(response.Body).Decode(user) | ||||
| 	defer response.Body.Close() | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	user.Username = strings.ToLower(user.Username) | ||||
| 	user.Username = strings.Split(user.Username, "@")[0] | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { | ||||
| 	cookie, err := req.Cookie(stateCookieName) | ||||
| 	if err != nil || cookie.Value != expectedState { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { | ||||
| 	if bearerToken == "" { | ||||
| 		return nil, nil | ||||
|  | @ -127,15 +259,15 @@ func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, | |||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	apiKey, err := database.GetAPIKey(dbConn, parts[1]) | ||||
| 	key, err := database.GetAPIKey(dbConn, parts[1]) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if apiKey == nil { | ||||
| 	if key == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.GetUser(dbConn, apiKey.UserID) | ||||
| 	user, err := database.GetUser(dbConn, key.UserID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -162,124 +294,3 @@ func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error | |||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		authHeader := req.Header.Get("Authorization") | ||||
| 		user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) | ||||
| 
 | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err == nil && sessionCookie.Value != "" { | ||||
| 			user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) | ||||
| 		} | ||||
| 
 | ||||
| 		if userErr != nil || user == nil { | ||||
| 			log.Println(userErr, user) | ||||
| 
 | ||||
| 			http.SetCookie(resp, &http.Cookie{ | ||||
| 				Name:   "session", | ||||
| 				MaxAge: 0, // reset session cookie in case
 | ||||
| 			}) | ||||
| 
 | ||||
| 			context.User = nil | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		context.User = user | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:     "redirect", | ||||
| 			Value:    req.URL.Path, | ||||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 		}) | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/login", http.StatusFound) | ||||
| 		return failure(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err != nil { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = database.RefreshSession(context.DBConn, sessionCookie.Value) | ||||
| 		if err != nil { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err == nil && sessionCookie.Value != "" { | ||||
| 			_ = database.DeleteSession(context.DBConn, sessionCookie.Value) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/", http.StatusFound) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "session", | ||||
| 			MaxAge: 0, | ||||
| 		}) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { | ||||
| 	userResponse, err := client.Get(uri) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	userStruct, err := createUserFromResponse(userResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.FindOrSaveUser(dbConn, userStruct) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func createUserFromResponse(response *http.Response) (*database.User, error) { | ||||
| 	defer response.Body.Close() | ||||
| 	user := &database.User{ | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
| 	err := json.NewDecoder(response.Body).Decode(user) | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	user.Username = strings.ToLower(user.Username) | ||||
| 	user.Username = strings.Split(user.Username, "@")[0] | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { | ||||
| 	cookie, err := req.Cookie(stateCookieName) | ||||
| 	if err != nil || cookie.Value != expectedState { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| } | ||||
|  | @ -0,0 +1,307 @@ | |||
| package auth_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func FakedOauthServer() *httptest.Server { | ||||
| 	oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/auth" { | ||||
| 			code := utils.RandomId() | ||||
| 
 | ||||
| 			state := r.URL.Query().Get("state") | ||||
| 			redirectPath := r.URL.Query().Get("redirect_uri") | ||||
| 			redirectPath += "?code=" + code + "&state=" + state | ||||
| 
 | ||||
| 			http.Redirect(w, r, redirectPath, http.StatusFound) | ||||
| 		} | ||||
| 		if r.URL.Path == "/token" { | ||||
| 			w.Header().Set("Content-Type", "application/json") | ||||
| 			w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) | ||||
| 		} | ||||
| 		if r.URL.Path == "/user" { | ||||
| 			w.Header().Set("Content-Type", "application/json") | ||||
| 			w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) | ||||
| 		} | ||||
| 	})) | ||||
| 
 | ||||
| 	return oauthServer | ||||
| } | ||||
| 
 | ||||
| func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		resp.Write([]byte(context.User.Username)) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/protected-path" { | ||||
| 			auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/login" { | ||||
| 			log.Println("login") | ||||
| 			auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/callback" { | ||||
| 			log.Println("callback") | ||||
| 			auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/me" { | ||||
| 			auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/logout" { | ||||
| 			auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 	})) | ||||
| 	return testServer | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 	} | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 
 | ||||
| 	return testDb, context, oauthServer, testServer, func() { | ||||
| 		oauthServer.Close() | ||||
| 		testServer.Close() | ||||
| 
 | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { | ||||
| 	return &oauth2.Config{ | ||||
| 		ClientID:     "test", | ||||
| 		ClientSecret: "test", | ||||
| 		Scopes:       []string{"test"}, | ||||
| 		Endpoint: oauth2.Endpoint{ | ||||
| 			AuthURL:  oauthServerURL + "/auth", | ||||
| 			TokenURL: oauthServerURL + "/token", | ||||
| 		}, | ||||
| 		RedirectURL: testServerURL + "/callback", | ||||
| 	}, oauthServerURL + "/user" | ||||
| } | ||||
| 
 | ||||
| func FollowAuthentication( | ||||
| 	oauthServer *httptest.Server, | ||||
| 	testServer *httptest.Server, | ||||
| 	cookies map[string]*http.Cookie, | ||||
| 	location string, | ||||
| ) (map[string]*http.Cookie, string) { | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	resp.Code = 0 | ||||
| 
 | ||||
| 	for resp.Code == 0 || resp.Code == http.StatusFound { | ||||
| 		req := httptest.NewRequest("GET", location, nil) | ||||
| 		resp = httptest.NewRecorder() | ||||
| 
 | ||||
| 		for _, cookie := range cookies { | ||||
| 			req.AddCookie(cookie) | ||||
| 		} | ||||
| 		if strings.HasPrefix(location, oauthServer.URL) { | ||||
| 			oauthServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 		} else { | ||||
| 			testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 		} | ||||
| 		for _, cookie := range resp.Result().Cookies() { | ||||
| 			cookies[cookie.Name] = cookie | ||||
| 		} | ||||
| 
 | ||||
| 		if resp.Code == http.StatusFound { | ||||
| 			location = resp.Header().Get("Location") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return cookies, location | ||||
| } | ||||
| 
 | ||||
| func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	user, _ := database.GetUser(db, "test") | ||||
| 	if user != nil { | ||||
| 		t.Errorf("expected no user, got user") | ||||
| 	} | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	user, _ = database.GetUser(db, "test") | ||||
| 	if user == nil { | ||||
| 		t.Errorf("expected a user to be created, could not find user") | ||||
| 	} | ||||
| 	if user.Username != "test" { | ||||
| 		t.Errorf("expected username to be test, got %s", user.Username) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { | ||||
| 	_, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/protected-path", nil) | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	location := resp.Header().Get("Location") | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { | ||||
| 		t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") | ||||
| 
 | ||||
| 	if !(strings.HasSuffix(location, "/protected-page")) { | ||||
| 		t.Errorf("expected to redirect back to /protected-page after login, got %s", location) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOauthSetsUniqueSession(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	cookiesAgain := make(map[string]*http.Cookie) | ||||
| 	cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") | ||||
| 
 | ||||
| 	sessionOne := cookies["session"].Value | ||||
| 	sessionTwo := cookiesAgain["session"].Value | ||||
| 	if sessionOne == sessionTwo { | ||||
| 		t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) | ||||
| 	} | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, sessionOne) | ||||
| 	if session.UserID != "test" { | ||||
| 		t.Errorf("expected session to be associated with user test, got %s", session.UserID) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLogoutClearsSession(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/logout", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	for _, cookie := range resp.Result().Cookies() { | ||||
| 		cookies[cookie.Name] = cookie | ||||
| 	} | ||||
| 
 | ||||
| 	req = httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp = httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { | ||||
| 		t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected session to be deleted, got session") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRefreshUpdatesExpiration(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 
 | ||||
| 	updatedSession, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 
 | ||||
| 	if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { | ||||
| 		t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestVerifySessionEnsuresNonExpired(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 	session.ExpireAt = time.Now().Add(-time.Hour) | ||||
| 	database.SaveSession(db, session) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 
 | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { | ||||
| 		t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										179
									
								
								api/dns.go
								
								
								
								
							
							
						
						
									
										179
									
								
								api/dns.go
								
								
								
								
							|  | @ -1,179 +0,0 @@ | |||
| package api | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| const MAX_USER_RECORDS = 65 | ||||
| 
 | ||||
| type FormError struct { | ||||
| 	Errors []string | ||||
| } | ||||
| 
 | ||||
| func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool { | ||||
| 	ownedByUser := (user.ID == record.UserID) | ||||
| 	if !ownedByUser { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if !record.Internal { | ||||
| 		userOwnedDomains := []string{ | ||||
| 			fmt.Sprintf("%s", user.Username), | ||||
| 			fmt.Sprintf("%s.endpoints", user.Username), | ||||
| 		} | ||||
| 
 | ||||
| 		for _, domain := range userOwnedDomains { | ||||
| 			isInSubDomain := strings.HasSuffix(record.Name, "."+domain) | ||||
| 			if domain == record.Name || isInSubDomain { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	userIsOwnerOfDomain := owner == user.ID | ||||
| 	return ownedByUser && userIsOwnerOfDomain | ||||
| } | ||||
| 
 | ||||
| func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["DNSRecords"] = dnsRecords | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		formErrors := FormError{ | ||||
| 			Errors: []string{}, | ||||
| 		} | ||||
| 
 | ||||
| 		internal := req.FormValue("internal") == "on" | ||||
| 		name := req.FormValue("name") | ||||
| 		if internal && !strings.HasSuffix(name, ".") { | ||||
| 			name += "." | ||||
| 		} | ||||
| 
 | ||||
| 		recordType := req.FormValue("type") | ||||
| 		recordType = strings.ToUpper(recordType) | ||||
| 
 | ||||
| 		recordContent := req.FormValue("content") | ||||
| 		ttl := req.FormValue("ttl") | ||||
| 		ttlNum, err := strconv.Atoi(ttl) | ||||
| 		if err != nil { | ||||
| 			formErrors.Errors = append(formErrors.Errors, "invalid ttl") | ||||
| 		} | ||||
| 
 | ||||
| 		dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 		if dnsRecordCount >= MAX_USER_RECORDS { | ||||
| 			formErrors.Errors = append(formErrors.Errors, "max records reached") | ||||
| 		} | ||||
| 
 | ||||
| 		dnsRecord := &database.DNSRecord{ | ||||
| 			UserID:   context.User.ID, | ||||
| 			Name:     name, | ||||
| 			Type:     recordType, | ||||
| 			Content:  recordContent, | ||||
| 			TTL:      ttlNum, | ||||
| 			Internal: internal, | ||||
| 		} | ||||
| 		if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { | ||||
| 			formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") | ||||
| 		} | ||||
| 
 | ||||
| 		if len(formErrors.Errors) == 0 { | ||||
| 			if dnsRecord.Internal { | ||||
| 				dnsRecord.ID = utils.RandomId() | ||||
| 			} else { | ||||
| 				cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) | ||||
| 				if err != nil { | ||||
| 					log.Println(err) | ||||
| 					formErrors.Errors = append(formErrors.Errors, err.Error()) | ||||
| 				} | ||||
| 
 | ||||
| 				dnsRecord.ID = cloudflareRecordId | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(formErrors.Errors) == 0 { | ||||
| 			_, err := database.SaveDNSRecord(context.DBConn, dnsRecord) | ||||
| 			if err != nil { | ||||
| 				log.Println(err) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "error saving record") | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(formErrors.Errors) == 0 { | ||||
| 			http.Redirect(resp, req, "/dns", http.StatusFound) | ||||
| 			return success(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["FormError"] = &formErrors | ||||
| 		(*context.TemplateData)["RecordForm"] = dnsRecord | ||||
| 
 | ||||
| 		resp.WriteHeader(http.StatusBadRequest) | ||||
| 		return failure(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		recordId := req.FormValue("id") | ||||
| 		record, err := database.GetDNSRecord(context.DBConn, recordId) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		if !record.Internal { | ||||
| 			err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) | ||||
| 			if err != nil { | ||||
| 				log.Println(err) | ||||
| 				resp.WriteHeader(http.StatusInternalServerError) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		err = database.DeleteDNSRecord(context.DBConn, recordId) | ||||
| 		if err != nil { | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/dns", http.StatusFound) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,174 @@ | |||
| package dns | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { | ||||
| 	ownedByUser := (user.ID == record.UserID) | ||||
| 	if !ownedByUser { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if !record.Internal { | ||||
| 		for _, format := range ownedInternalDomainFormats { | ||||
| 			domain := fmt.Sprintf(format, user.Username) | ||||
| 
 | ||||
| 			isInSubDomain := strings.HasSuffix(record.Name, "."+domain) | ||||
| 			if domain == record.Name || isInSubDomain { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	userIsOwnerOfDomain := owner == user.ID | ||||
| 	return ownedByUser && userIsOwnerOfDomain | ||||
| } | ||||
| 
 | ||||
| func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["DNSRecords"] = dnsRecords | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			formErrors := types.FormError{ | ||||
| 				Errors: []string{}, | ||||
| 			} | ||||
| 
 | ||||
| 			internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" | ||||
| 			name := req.FormValue("name") | ||||
| 			if internal && !strings.HasSuffix(name, ".") { | ||||
| 				name += "." | ||||
| 			} | ||||
| 
 | ||||
| 			recordType := req.FormValue("type") | ||||
| 			recordType = strings.ToUpper(recordType) | ||||
| 
 | ||||
| 			recordContent := req.FormValue("content") | ||||
| 			ttl := req.FormValue("ttl") | ||||
| 			ttlNum, err := strconv.Atoi(ttl) | ||||
| 			if err != nil { | ||||
| 				resp.WriteHeader(http.StatusBadRequest) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "invalid ttl") | ||||
| 			} | ||||
| 
 | ||||
| 			dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) | ||||
| 			if err != nil { | ||||
| 				log.Println(err) | ||||
| 				resp.WriteHeader(http.StatusInternalServerError) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 			if dnsRecordCount >= maxUserRecords { | ||||
| 				resp.WriteHeader(http.StatusTooManyRequests) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "max records reached") | ||||
| 			} | ||||
| 
 | ||||
| 			dnsRecord := &database.DNSRecord{ | ||||
| 				UserID:   context.User.ID, | ||||
| 				Name:     name, | ||||
| 				Type:     recordType, | ||||
| 				Content:  recordContent, | ||||
| 				TTL:      ttlNum, | ||||
| 				Internal: internal, | ||||
| 			} | ||||
| 
 | ||||
| 			if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { | ||||
| 				resp.WriteHeader(http.StatusUnauthorized) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") | ||||
| 			} | ||||
| 
 | ||||
| 			if len(formErrors.Errors) == 0 { | ||||
| 				if dnsRecord.Internal { | ||||
| 					dnsRecord.ID = utils.RandomId() | ||||
| 				} else { | ||||
| 					dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) | ||||
| 					if err != nil { | ||||
| 						log.Println(err) | ||||
| 						resp.WriteHeader(http.StatusInternalServerError) | ||||
| 						formErrors.Errors = append(formErrors.Errors, err.Error()) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if len(formErrors.Errors) == 0 { | ||||
| 				_, err := database.SaveDNSRecord(context.DBConn, dnsRecord) | ||||
| 				if err != nil { | ||||
| 					log.Println(err) | ||||
| 					formErrors.Errors = append(formErrors.Errors, "error saving record") | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if len(formErrors.Errors) == 0 { | ||||
| 				return success(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			(*context.TemplateData)["FormError"] = &formErrors | ||||
| 			(*context.TemplateData)["RecordForm"] = dnsRecord | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			recordId := req.FormValue("id") | ||||
| 			record, err := database.GetDNSRecord(context.DBConn, recordId) | ||||
| 			if err != nil { | ||||
| 				log.Println(err) | ||||
| 				resp.WriteHeader(http.StatusInternalServerError) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			if !(record.UserID == context.User.ID) { | ||||
| 				resp.WriteHeader(http.StatusUnauthorized) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			if !record.Internal { | ||||
| 				err = dnsAdapter.DeleteDNSRecord(recordId) | ||||
| 				if err != nil { | ||||
| 					log.Println(err) | ||||
| 					resp.WriteHeader(http.StatusInternalServerError) | ||||
| 					return failure(context, req, resp) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			err = database.DeleteDNSRecord(context.DBConn, recordId) | ||||
| 			if err != nil { | ||||
| 				resp.WriteHeader(http.StatusInternalServerError) | ||||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			return success(context, req, resp) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,442 @@ | |||
| package dns_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| const MAX_USER_RECORDS = 10 | ||||
| 
 | ||||
| var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	user := &database.User{ | ||||
| 		ID:          "test", | ||||
| 		Username:    "test", | ||||
| 		Mail:        "test@test.com", | ||||
| 		DisplayName: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(testDb, user) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 		User:         user, | ||||
| 	} | ||||
| 
 | ||||
| 	return testDb, context, func() { | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type SignallingExternalDnsAdapter struct { | ||||
| 	AddChannel chan *database.DNSRecord | ||||
| 	RmChannel  chan string | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { | ||||
| 	id := utils.RandomId() | ||||
| 	go func() { adapter.AddChannel <- record }() | ||||
| 
 | ||||
| 	return id, nil | ||||
| } | ||||
| 
 | ||||
| func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { | ||||
| 	go func() { adapter.RmChannel <- id }() | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestThatOwnerCanPutRecordInDomain(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	domainOwner, _ = database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	records, err := database.GetUserDNSRecords(db, context.User.ID) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if len(records) > 0 { | ||||
| 		t.Errorf("expected no records, got records") | ||||
| 	} | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	validOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	validOwner.Form = map[string][]string{ | ||||
| 		"internal": {"on"}, | ||||
| 		"name":     {"new.test.domain."}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"test.domain."}, | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) | ||||
| 	if validOwnerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerNonInternalRecorder := httptest.NewRecorder() | ||||
| 	validOwner.Form["internal"] = []string{"off"} | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner) | ||||
| 	if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	invalidOwnerRecorder := httptest.NewRecorder() | ||||
| 	invalidOwner := validOwner | ||||
| 	invalidOwner.Form["internal"] = []string{"on"} | ||||
| 	invalidOwner.Form["name"] = []string{"new.invalid.domain."} | ||||
| 	testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner) | ||||
| 	if invalidOwnerRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatUserCanAddToPublicEndpoints(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	responseRecorder := httptest.NewRecorder() | ||||
| 	req := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	fmts := USER_OWNED_INTERNAL_FMT_DOMAINS | ||||
| 	for _, format := range fmts { | ||||
| 		name := fmt.Sprintf(format, context.User.Username) | ||||
| 
 | ||||
| 		req.Form = map[string][]string{ | ||||
| 			"internal": {"off"}, | ||||
| 			"name":     {name}, | ||||
| 			"type":     {"CNAME"}, | ||||
| 			"ttl":      {"43000"}, | ||||
| 			"content":  {"test.domain."}, | ||||
| 		} | ||||
| 
 | ||||
| 		testServer.Config.Handler.ServeHTTP(responseRecorder, req) | ||||
| 		if responseRecorder.Code != http.StatusOK { | ||||
| 			t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 		} | ||||
| 
 | ||||
| 		namedRecords, _ := database.FindDNSRecords(db, name, "CNAME") | ||||
| 		if len(namedRecords) == 0 { | ||||
| 			t.Errorf("saved record not found") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatExternalDnsSaves(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	responseRecorder := httptest.NewRecorder() | ||||
| 	externalRequest := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 
 | ||||
| 	name := "test." + context.User.Username | ||||
| 	externalRequest.Form = map[string][]string{ | ||||
| 		"internal": {"off"}, | ||||
| 		"name":     {name}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"test.domain."}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) | ||||
| 	if responseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case res := <-addChannel: | ||||
| 		if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." { | ||||
| 			t.Errorf("received the wrong external record") | ||||
| 		} | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 		t.Errorf("timed out in waiting for external addition") | ||||
| 	} | ||||
| 
 | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	domainOwner, _ = database.SaveDomainOwner(db, domainOwner) | ||||
| 	internalRequest := externalRequest | ||||
| 	internalRequest.Form["internal"] = []string{"on"} | ||||
| 	internalRequest.Form["name"] = []string{"test.domain."} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) | ||||
| 	if responseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", responseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case _ = <-addChannel: | ||||
| 		t.Errorf("expected nothing in the add channel") | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatUserMustOwnRecordToRemove(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	rmChannel := make(chan string) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		RmChannel: rmChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"} | ||||
| 	_, err := database.FindOrSaveUser(db, nonOwnerUser) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		Internal: false, | ||||
| 		Name:     "test", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   nonOwnerUser.ID, | ||||
| 	} | ||||
| 	_, err = database.SaveDNSRecord(db, record) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	nonOwnerRecorder := httptest.NewRecorder() | ||||
| 	nonOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	nonOwner.Form = map[string][]string{ | ||||
| 		"id": {record.ID}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner) | ||||
| 	if nonOwnerRecorder.Code != http.StatusUnauthorized { | ||||
| 		t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	record.UserID = context.User.ID | ||||
| 	record.ID = "2" | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 
 | ||||
| 	owner := nonOwner | ||||
| 	owner.Form["id"] = []string{"2"} | ||||
| 	ownerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(ownerRecorder, owner) | ||||
| 	if ownerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", ownerRecorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestThatExternalDnsRemoves(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		Internal: false, | ||||
| 		Name:     "test", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   context.User.ID, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 
 | ||||
| 	rmChannel := make(chan string) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		RmChannel: rmChannel, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	externalResponseRecorder := httptest.NewRecorder() | ||||
| 	deleteRequest := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 
 | ||||
| 	deleteRequest.Form = map[string][]string{ | ||||
| 		"id": {record.ID}, | ||||
| 	} | ||||
| 
 | ||||
| 	testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest) | ||||
| 	if externalResponseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", externalResponseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case res := <-rmChannel: | ||||
| 		if res != record.ID { | ||||
| 			t.Errorf("received the wrong external record") | ||||
| 		} | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 		t.Errorf("timed out in waiting for external addition") | ||||
| 	} | ||||
| 
 | ||||
| 	record.Internal = true | ||||
| 	record.Name = "test.domain." | ||||
| 	database.SaveDNSRecord(db, record) | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.domain.", | ||||
| 	} | ||||
| 	database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	internalResponseRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest) | ||||
| 	if internalResponseRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", internalResponseRecorder.Code) | ||||
| 	} | ||||
| 	select { | ||||
| 	case _ = <-rmChannel: | ||||
| 		t.Errorf("expected nothing in the rmchannel") | ||||
| 	case <-time.After(100 * time.Millisecond): | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRecordCountCannotExceed(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	record := &database.DNSRecord{ | ||||
| 		Internal: false, | ||||
| 		Name:     context.User.Username, | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "asdf", | ||||
| 		TTL:      1000, | ||||
| 		UserID:   context.User.ID, | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 1; i <= MAX_USER_RECORDS; i++ { | ||||
| 		record.ID = strconv.Itoa(i) | ||||
| 		record.Name = record.ID + "." + record.Name | ||||
| 		database.SaveDNSRecord(db, record) | ||||
| 	} | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	req := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	req.Form = map[string][]string{ | ||||
| 		"internal": {"off"}, | ||||
| 		"name":     {record.Name}, | ||||
| 		"type":     {record.Type}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {record.Content}, | ||||
| 	} | ||||
| 
 | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(recorder, req) | ||||
| 	if recorder.Code != http.StatusTooManyRequests { | ||||
| 		t.Errorf("expected too many requests code return, got %d", recorder.Code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestInternalRecordAppendsTopLevelDot(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	domainOwner := &database.DomainOwner{ | ||||
| 		UserID: context.User.ID, | ||||
| 		Domain: "test.internal.", | ||||
| 	} | ||||
| 	database.SaveDomainOwner(db, domainOwner) | ||||
| 
 | ||||
| 	addChannel := make(chan *database.DNSRecord) | ||||
| 	signallingDnsAdapter := &SignallingExternalDnsAdapter{ | ||||
| 		AddChannel: addChannel, | ||||
| 	} | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	validOwner := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 	validOwner.Form = map[string][]string{ | ||||
| 		"internal": {"on"}, | ||||
| 		"name":     {"test.internal"}, | ||||
| 		"type":     {"CNAME"}, | ||||
| 		"ttl":      {"43000"}, | ||||
| 		"content":  {"asdf.internal"}, | ||||
| 	} | ||||
| 
 | ||||
| 	validOwnerRecorder := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) | ||||
| 	if validOwnerRecorder.Code != http.StatusOK { | ||||
| 		t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME") | ||||
| 	recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME") | ||||
| 
 | ||||
| 	if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 { | ||||
| 		t.Errorf("expected dot appended") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										141
									
								
								api/guestbook.go
								
								
								
								
							
							
						
						
									
										141
									
								
								api/guestbook.go
								
								
								
								
							|  | @ -1,141 +0,0 @@ | |||
| package api | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| type HcaptchaArgs struct { | ||||
| 	SiteKey string | ||||
| } | ||||
| 
 | ||||
| func validateGuestbookEntry(entry *database.GuestbookEntry) []string { | ||||
| 	errors := []string{} | ||||
| 
 | ||||
| 	if entry.Name == "" { | ||||
| 		errors = append(errors, "name is required") | ||||
| 	} | ||||
| 
 | ||||
| 	if entry.Message == "" { | ||||
| 		errors = append(errors, "message is required") | ||||
| 	} | ||||
| 
 | ||||
| 	messageLength := len(entry.Message) | ||||
| 	if messageLength > 500 { | ||||
| 		errors = append(errors, "message cannot be longer than 500 characters") | ||||
| 	} | ||||
| 
 | ||||
| 	newLines := strings.Count(entry.Message, "\n") | ||||
| 	if newLines > 10 { | ||||
| 		errors = append(errors, "message cannot contain more than 10 new lines") | ||||
| 	} | ||||
| 
 | ||||
| 	return errors | ||||
| } | ||||
| 
 | ||||
| func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		name := req.FormValue("name") | ||||
| 		message := req.FormValue("message") | ||||
| 		hCaptchaResponse := req.FormValue("h-captcha-response") | ||||
| 
 | ||||
| 		formErrors := FormError{ | ||||
| 			Errors: []string{}, | ||||
| 		} | ||||
| 
 | ||||
| 		if hCaptchaResponse == "" { | ||||
| 			formErrors.Errors = append(formErrors.Errors, "hCaptcha is required") | ||||
| 		} | ||||
| 
 | ||||
| 		entry := &database.GuestbookEntry{ | ||||
| 			ID:      utils.RandomId(), | ||||
| 			Name:    name, | ||||
| 			Message: message, | ||||
| 		} | ||||
| 		formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) | ||||
| 
 | ||||
| 		err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 
 | ||||
| 			formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed") | ||||
| 		} | ||||
| 		if len(formErrors.Errors) > 0 { | ||||
| 			(*context.TemplateData)["FormError"] = formErrors | ||||
| 			(*context.TemplateData)["EntryForm"] = entry | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = database.SaveGuestbookEntry(context.DBConn, entry) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		entries, err := database.GetGuestbookEntries(context.DBConn) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["GuestbookEntries"] = entries | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ | ||||
| 			SiteKey: context.Args.HcaptchaSiteKey, | ||||
| 		} | ||||
| 		log.Println(context.Args.HcaptchaSiteKey) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func verifyHCaptcha(secret, response string) error { | ||||
| 	verifyURL := "https://hcaptcha.com/siteverify" | ||||
| 	body := strings.NewReader("secret=" + secret + "&response=" + response) | ||||
| 
 | ||||
| 	req, err := http.NewRequest("POST", verifyURL, body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	jsonResponse := struct { | ||||
| 		Success bool `json:"success"` | ||||
| 	}{} | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if !jsonResponse.Success { | ||||
| 		return fmt.Errorf("hcaptcha verification failed") | ||||
| 	} | ||||
| 
 | ||||
| 	defer resp.Body.Close() | ||||
| 	return nil | ||||
| } | ||||
|  | @ -0,0 +1,85 @@ | |||
| package guestbook | ||||
| 
 | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| func validateGuestbookEntry(entry *database.GuestbookEntry) []string { | ||||
| 	errors := []string{} | ||||
| 
 | ||||
| 	if entry.Name == "" { | ||||
| 		errors = append(errors, "name is required") | ||||
| 	} | ||||
| 
 | ||||
| 	if entry.Message == "" { | ||||
| 		errors = append(errors, "message is required") | ||||
| 	} | ||||
| 
 | ||||
| 	messageLength := len(entry.Message) | ||||
| 	if messageLength > 500 { | ||||
| 		errors = append(errors, "message cannot be longer than 500 characters") | ||||
| 	} | ||||
| 
 | ||||
| 	newLines := strings.Count(entry.Message, "\n") | ||||
| 	if newLines > 10 { | ||||
| 		errors = append(errors, "message cannot contain more than 10 new lines") | ||||
| 	} | ||||
| 
 | ||||
| 	return errors | ||||
| } | ||||
| 
 | ||||
| func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		name := req.FormValue("name") | ||||
| 		message := req.FormValue("message") | ||||
| 
 | ||||
| 		formErrors := types.FormError{ | ||||
| 			Errors: []string{}, | ||||
| 		} | ||||
| 
 | ||||
| 		entry := &database.GuestbookEntry{ | ||||
| 			ID:      utils.RandomId(), | ||||
| 			Name:    name, | ||||
| 			Message: message, | ||||
| 		} | ||||
| 		formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) | ||||
| 
 | ||||
| 		if len(formErrors.Errors) == 0 { | ||||
| 			_, err := database.SaveGuestbookEntry(context.DBConn, entry) | ||||
| 			if err != nil { | ||||
| 				log.Println(err) | ||||
| 				formErrors.Errors = append(formErrors.Errors, "failed to save entry") | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(formErrors.Errors) > 0 { | ||||
| 			(*context.TemplateData)["FormError"] = formErrors | ||||
| 			(*context.TemplateData)["EntryForm"] = entry | ||||
| 			resp.WriteHeader(http.StatusBadRequest) | ||||
| 
 | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		entries, err := database.GetGuestbookEntries(context.DBConn) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["GuestbookEntries"] = entries | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,136 @@ | |||
| package guestbook_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 	} | ||||
| 
 | ||||
| 	return testDb, context, func() { | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestValidGuestbookPutsInDatabase(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	entries, err := database.GetGuestbookEntries(db) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if len(entries) > 0 { | ||||
| 		t.Errorf("expected no entries, got entries") | ||||
| 	} | ||||
| 
 | ||||
| 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer ts.Close() | ||||
| 
 | ||||
| 	req := httptest.NewRequest("POST", ts.URL, nil) | ||||
| 	req.Form = map[string][]string{ | ||||
| 		"name":    {"test"}, | ||||
| 		"message": {"test"}, | ||||
| 	} | ||||
| 
 | ||||
| 	w := httptest.NewRecorder() | ||||
| 	ts.Config.Handler.ServeHTTP(w, req) | ||||
| 
 | ||||
| 	if w.Code != http.StatusOK { | ||||
| 		t.Errorf("expected status code 200, got %d", w.Code) | ||||
| 	} | ||||
| 
 | ||||
| 	entries, err = database.GetGuestbookEntries(db) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(entries) != 1 { | ||||
| 		t.Errorf("expected 1 entry, got %d", len(entries)) | ||||
| 	} | ||||
| 
 | ||||
| 	if entries[0].Name != req.FormValue("name") { | ||||
| 		t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	entries, err := database.GetGuestbookEntries(db) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if len(entries) > 0 { | ||||
| 		t.Errorf("expected no entries, got entries") | ||||
| 	} | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" | ||||
| 	invalidRequests := []struct { | ||||
| 		name    string | ||||
| 		message string | ||||
| 	}{ | ||||
| 		{"", "test"}, | ||||
| 		{"test", ""}, | ||||
| 		{"", ""}, | ||||
| 		{"test", reallyLongStringThatWouldTakeTooMuchSpace}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, form := range invalidRequests { | ||||
| 		req := httptest.NewRequest("POST", testServer.URL, nil) | ||||
| 		req.Form = map[string][]string{ | ||||
| 			"name":    {form.name}, | ||||
| 			"message": {form.message}, | ||||
| 		} | ||||
| 
 | ||||
| 		responseRecorder := httptest.NewRecorder() | ||||
| 		testServer.Config.Handler.ServeHTTP(responseRecorder, req) | ||||
| 
 | ||||
| 		if responseRecorder.Code != http.StatusBadRequest { | ||||
| 			t.Errorf("expected status code 400, got %d", responseRecorder.Code) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	entries, err = database.GetGuestbookEntries(db) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(entries) != 0 { | ||||
| 		t.Errorf("expected 0 entries, got %d", len(entries)) | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,75 @@ | |||
| package hcaptcha | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| ) | ||||
| 
 | ||||
| type HcaptchaArgs struct { | ||||
| 	SiteKey string | ||||
| } | ||||
| 
 | ||||
| func verifyCaptcha(secret, response string) error { | ||||
| 	verifyURL := "https://hcaptcha.com/siteverify" | ||||
| 	body := strings.NewReader("secret=" + secret + "&response=" + response) | ||||
| 
 | ||||
| 	req, err := http.NewRequest("POST", verifyURL, body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	jsonResponse := struct { | ||||
| 		Success bool `json:"success"` | ||||
| 	}{} | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if !jsonResponse.Success { | ||||
| 		return fmt.Errorf("hcaptcha verification failed") | ||||
| 	} | ||||
| 
 | ||||
| 	defer resp.Body.Close() | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		(*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ | ||||
| 			SiteKey: context.Args.HcaptchaSiteKey, | ||||
| 		} | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		hCaptchaResponse := req.FormValue("h-captcha-response") | ||||
| 		secretKey := context.Args.HcaptchaSecret | ||||
| 
 | ||||
| 		err := verifyCaptcha(secretKey, hCaptchaResponse) | ||||
| 		if err != nil { | ||||
| 			(*context.TemplateData)["FormError"] = types.FormError{ | ||||
| 				Errors: []string{"hCaptcha verification failed"}, | ||||
| 			} | ||||
| 			resp.WriteHeader(http.StatusBadRequest) | ||||
| 
 | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  | @ -1,32 +1,33 @@ | |||
| package api | ||||
| package keys | ||||
| 
 | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| const MAX_USER_API_KEYS = 5 | ||||
| 
 | ||||
| func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) | ||||
| func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		(*context.TemplateData)["APIKeys"] = apiKeys | ||||
| 		(*context.TemplateData)["APIKeys"] = typesKeys | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		formErrors := FormError{ | ||||
| func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		formErrors := types.FormError{ | ||||
| 			Errors: []string{}, | ||||
| 		} | ||||
| 
 | ||||
|  | @ -38,7 +39,7 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h | |||
| 		} | ||||
| 
 | ||||
| 		if numKeys >= MAX_USER_API_KEYS { | ||||
| 			formErrors.Errors = append(formErrors.Errors, "max api keys reached") | ||||
| 			formErrors.Errors = append(formErrors.Errors, "max types keys reached") | ||||
| 		} | ||||
| 
 | ||||
| 		if len(formErrors.Errors) > 0 { | ||||
|  | @ -59,29 +60,28 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| 		key := req.FormValue("key") | ||||
| func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		apiKey := req.FormValue("key") | ||||
| 
 | ||||
| 		apiKey, err := database.GetAPIKey(context.DBConn, key) | ||||
| 		key, err := database.GetAPIKey(context.DBConn, apiKey) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 		if (apiKey == nil) || (apiKey.UserID != context.User.ID) { | ||||
| 		if (key == nil) || (key.UserID != context.User.ID) { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		err = database.DeleteAPIKey(context.DBConn, key) | ||||
| 		err = database.DeleteAPIKey(context.DBConn, apiKey) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/keys", http.StatusFound) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										90
									
								
								api/serve.go
								
								
								
								
							
							
						
						
									
										90
									
								
								api/serve.go
								
								
								
								
							|  | @ -7,27 +7,20 @@ import ( | |||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| type RequestContext struct { | ||||
| 	DBConn *sql.DB | ||||
| 	Args   *args.Arguments | ||||
| 
 | ||||
| 	Id    string | ||||
| 	Start time.Time | ||||
| 
 | ||||
| 	TemplateData *map[string]interface{} | ||||
| 	User         *database.User | ||||
| } | ||||
| 
 | ||||
| type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain | ||||
| type ContinuationChain func(Continuation, Continuation) ContinuationChain | ||||
| 
 | ||||
| func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, _failure Continuation) ContinuationChain { | ||||
| func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		context.Start = time.Now() | ||||
| 		context.Id = utils.RandomId() | ||||
| 
 | ||||
|  | @ -36,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, _failure Continuation) ContinuationChain { | ||||
| func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		end := time.Now() | ||||
| 
 | ||||
| 		log.Println(context.Id, "took", end.Sub(context.Start)) | ||||
|  | @ -46,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, _failure Continuation) ContinuationChain { | ||||
| func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		resp.WriteHeader(200) | ||||
| 		resp.Write([]byte("healthy")) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(_success Continuation, failure Continuation) ContinuationChain { | ||||
| func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		return failure(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 	return func(success Continuation, _failure Continuation) ContinuationChain { | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  | @ -80,89 +73,90 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { | |||
| 	fileServer := http.FileServer(http.Dir(argv.StaticPath)) | ||||
| 	mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) | ||||
| 
 | ||||
| 	makeRequestContext := func() *RequestContext { | ||||
| 		return &RequestContext{ | ||||
| 			DBConn: dbConn, | ||||
| 			Args:   argv, | ||||
| 	cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{ | ||||
| 		APIToken: argv.CloudflareToken, | ||||
| 		ZoneId:   argv.CloudflareZone, | ||||
| 	} | ||||
| 
 | ||||
| 	makeRequestContext := func() *types.RequestContext { | ||||
| 		return &types.RequestContext{ | ||||
| 			DBConn:       dbConn, | ||||
| 			Args:         argv, | ||||
| 			TemplateData: &map[string]interface{}{}, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) { | ||||
| 	mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	const MAX_USER_RECORDS = 100 | ||||
| 	var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} | ||||
| 	mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		name := r.PathValue("name") | ||||
| 		LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	return &http.Server{ | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| package api | ||||
| package template | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
|  | @ -7,9 +7,11 @@ import ( | |||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| ) | ||||
| 
 | ||||
| func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { | ||||
| func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { | ||||
| 	templatePath := context.Args.TemplatePath | ||||
| 	basePath := templatePath + "/base_empty.html" | ||||
| 	if showBaseHtml { | ||||
|  | @ -41,9 +43,9 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b | |||
| 	return buffer, nil | ||||
| } | ||||
| 
 | ||||
| func TemplateContinuation(path string, showBase bool) Continuation { | ||||
| 	return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { | ||||
| 		return func(success Continuation, failure Continuation) ContinuationChain { | ||||
| func TemplateContinuation(path string, showBase bool) types.Continuation { | ||||
| 	return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 		return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 			html, err := renderTemplate(context, path, true) | ||||
| 			if errors.Is(err, os.ErrNotExist) { | ||||
| 				resp.WriteHeader(404) | ||||
|  | @ -66,7 +68,6 @@ func TemplateContinuation(path string, showBase bool) Continuation { | |||
| 				return failure(context, req, resp) | ||||
| 			} | ||||
| 
 | ||||
| 			resp.WriteHeader(200) | ||||
| 			resp.Header().Set("Content-Type", "text/html") | ||||
| 			resp.Write(html.Bytes()) | ||||
| 			return success(context, req, resp) | ||||
|  | @ -0,0 +1,28 @@ | |||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| ) | ||||
| 
 | ||||
| type RequestContext struct { | ||||
| 	DBConn *sql.DB | ||||
| 	Args   *args.Arguments | ||||
| 
 | ||||
| 	Id    string | ||||
| 	Start time.Time | ||||
| 
 | ||||
| 	TemplateData *map[string]interface{} | ||||
| 	User         *database.User | ||||
| } | ||||
| 
 | ||||
| type FormError struct { | ||||
| 	Errors []string | ||||
| } | ||||
| 
 | ||||
| type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain | ||||
| type ContinuationChain func(Continuation, Continuation) ContinuationChain | ||||
|  | @ -22,9 +22,8 @@ type Arguments struct { | |||
| 	OauthConfig      *oauth2.Config | ||||
| 	OauthUserInfoURI string | ||||
| 
 | ||||
| 	Dns          bool | ||||
| 	DnsRecursion []string | ||||
| 	DnsPort      int | ||||
| 	Dns     bool | ||||
| 	DnsPort int | ||||
| 
 | ||||
| 	CloudflareToken string | ||||
| 	CloudflareZone  string | ||||
|  | @ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) { | |||
| 	server := flag.Bool("server", false, "Run the server") | ||||
| 
 | ||||
| 	dns := flag.Bool("dns", false, "Run DNS resolver") | ||||
| 	dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers") | ||||
| 	dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") | ||||
| 
 | ||||
| 	flag.Parse() | ||||
|  | @ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) { | |||
| 		Migrate:         *migrate, | ||||
| 		Scheduler:       *scheduler, | ||||
| 		Dns:             *dns, | ||||
| 		DnsRecursion:    strings.Split(*dnsRecursion, ","), | ||||
| 		DnsPort:         *dnsPort, | ||||
| 
 | ||||
| 		OauthConfig:      oauthConfig, | ||||
|  |  | |||
|  | @ -9,6 +9,12 @@ import ( | |||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type DomainOwner struct { | ||||
| 	UserID    string    `json:"user_id"` | ||||
| 	Domain    string    `json:"domain"` | ||||
| 	CreatedAt time.Time `json:"created_at"` | ||||
| } | ||||
| 
 | ||||
| type DNSRecord struct { | ||||
| 	ID        string    `json:"id"` | ||||
| 	UserID    string    `json:"user_id"` | ||||
|  | @ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { | |||
| func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { | ||||
| 	log.Println("saving dns record", record.ID) | ||||
| 
 | ||||
| 	record.CreatedAt = time.Now() | ||||
| 	if (record.CreatedAt == time.Time{}) { | ||||
| 		record.CreatedAt = time.Now() | ||||
| 	} | ||||
| 
 | ||||
| 	_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) | ||||
| 
 | ||||
| 	if err != nil { | ||||
|  | @ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err | |||
| 
 | ||||
| 	return records, nil | ||||
| } | ||||
| 
 | ||||
| func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) { | ||||
| 	log.Println("saving domain owner", domainOwner.Domain) | ||||
| 
 | ||||
| 	domainOwner.CreatedAt = time.Now() | ||||
| 	_, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return domainOwner, nil | ||||
| } | ||||
|  |  | |||
|  | @ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) { | ||||
| 	log.Println("saving session", session.ID) | ||||
| 
 | ||||
| 	_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt) | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { | ||||
| 	newExpireAt := time.Now().Add(ExpiryDuration) | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| package dns | ||||
| package hcdns | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
|  | @ -9,27 +9,28 @@ import ( | |||
| 	"log" | ||||
| ) | ||||
| 
 | ||||
| const MAX_RECURSION = 10 | ||||
| 
 | ||||
| func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | ||||
| 	if maxDepth == 0 { | ||||
| 		return nil, fmt.Errorf("too much recursion") | ||||
| 	} | ||||
| const MAX_RECURSION = 15 | ||||
| 
 | ||||
| func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | ||||
| 	internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	answers := []dns.RR{} | ||||
| 	var answers []dns.RR | ||||
| 	for _, record := range internalCnames { | ||||
| 		cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		answers = append(answers, cname) | ||||
| 
 | ||||
| 		cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) | ||||
| 		cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		answers = append(answers, cnameRecursive...) | ||||
| 	} | ||||
| 
 | ||||
|  | @ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp | |||
| 		return nil, err | ||||
| 	} | ||||
| 	for _, record := range typeDnsRecords { | ||||
| 		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) | ||||
| 		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		answers = append(answers, answer) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(answers) > 0 { | ||||
| 		// base case; we found the answer
 | ||||
| 		return answers, nil | ||||
| 	return answers, nil | ||||
| } | ||||
| 
 | ||||
| func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { | ||||
| 	if maxDepth == 0 { | ||||
| 		return nil, fmt.Errorf("too much recursion") | ||||
| 	} | ||||
| 
 | ||||
| 	message := new(dns.Msg) | ||||
| 	message.SetQuestion(dns.Fqdn(domain), qtype) | ||||
| 	message.RecursionDesired = true | ||||
| 
 | ||||
| 	client := new(dns.Client) | ||||
| 
 | ||||
| 	i := 0 | ||||
| 	in, _, err := client.Exchange(message, dnsResolvers[i]) | ||||
| 	for err != nil { | ||||
| 		i += 1 | ||||
| 		if i == len(dnsResolvers) { | ||||
| 			log.Println(err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		in, _, err = client.Exchange(message, dnsResolvers[i]) | ||||
| 	answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	answers = append(answers, in.Answer...) | ||||
| 	return answers, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -87,22 +78,27 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||
| 	msg.Authoritative = true | ||||
| 
 | ||||
| 	for _, question := range r.Question { | ||||
| 		answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) | ||||
| 		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) | ||||
| 		if err != nil { | ||||
| 			fmt.Println(err) | ||||
| 			continue | ||||
| 			msg.SetRcode(r, dns.RcodeServerFailure) | ||||
| 			w.WriteMsg(msg) | ||||
| 			return | ||||
| 		} | ||||
| 		msg.Answer = append(msg.Answer, answers...) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(msg.Answer) == 0 { | ||||
| 		msg.SetRcode(r, dns.RcodeNameError) | ||||
| 	} | ||||
| 
 | ||||
| 	log.Println(msg.Answer) | ||||
| 	w.WriteMsg(msg) | ||||
| } | ||||
| 
 | ||||
| func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { | ||||
| 	handler := &DnsHandler{ | ||||
| 		DnsResolvers: argv.DnsRecursion, | ||||
| 		DbConn:       dbConn, | ||||
| 		DbConn: dbConn, | ||||
| 	} | ||||
| 	addr := fmt.Sprintf(":%d", argv.DnsPort) | ||||
| 
 | ||||
|  | @ -0,0 +1,254 @@ | |||
| package hcdns_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
| 
 | ||||
| func randomPort() int { | ||||
| 	return rand.Intn(3000) + 5192 | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 	dnsPort := randomPort() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 	testUser := &database.User{ | ||||
| 		ID: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(testDb, testUser) | ||||
| 
 | ||||
| 	waitLock := &sync.Mutex{} | ||||
| 	server := hcdns.MakeServer(&args.Arguments{ | ||||
| 		DnsPort: dnsPort, | ||||
| 	}, testDb) | ||||
| 	server.NotifyStartedFunc = func() { | ||||
| 		waitLock.Unlock() | ||||
| 	} | ||||
| 	waitLock.Lock() | ||||
| 
 | ||||
| 	go func() { | ||||
| 		server.ListenAndServe() | ||||
| 	}() | ||||
| 	waitLock.Lock() | ||||
| 
 | ||||
| 	address := fmt.Sprintf("127.0.0.1:%d", dnsPort) | ||||
| 	return testDb, server, &address, waitLock, func() { | ||||
| 		server.Shutdown() | ||||
| 
 | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenCNAMEIsResolved(t *testing.T) { | ||||
| 	t.Log("TestWhenCNAMEIsResolved") | ||||
| 
 | ||||
| 	testDb, _, addr, lock, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 	defer lock.Unlock() | ||||
| 
 | ||||
| 	records := []*database.DNSRecord{ | ||||
| 		{ | ||||
| 			ID:       "0", | ||||
| 			UserID:   "test", | ||||
| 			Name:     "cname.internal.example.com.", | ||||
| 			Type:     "CNAME", | ||||
| 			Content:  "next.internal.example.com.", | ||||
| 			TTL:      300, | ||||
| 			Internal: true, | ||||
| 		}, { | ||||
| 			ID:       "1", | ||||
| 			UserID:   "test", | ||||
| 			Name:     "next.internal.example.com.", | ||||
| 			Type:     "CNAME", | ||||
| 			Content:  "res.example.com.", | ||||
| 			TTL:      300, | ||||
| 			Internal: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ID:       "2", | ||||
| 			UserID:   "test", | ||||
| 			Name:     "res.example.com.", | ||||
| 			Type:     "A", | ||||
| 			Content:  "1.2.3.2", | ||||
| 			TTL:      300, | ||||
| 			Internal: true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, record := range records { | ||||
| 		database.SaveDNSRecord(testDb, record) | ||||
| 	} | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn("cname.internal.example.com.") | ||||
| 	client := &dns.Client{} | ||||
| 	message := &dns.Msg{} | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, *addr) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 3 { | ||||
| 		t.Fatalf("expected 3 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	for i, record := range records { | ||||
| 		if in.Answer[i].Header().Name != record.Name { | ||||
| 			t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) | ||||
| 		} | ||||
| 
 | ||||
| 		if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { | ||||
| 			t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) | ||||
| 		} | ||||
| 
 | ||||
| 		if int(in.Answer[i].Header().Ttl) != record.TTL { | ||||
| 			t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) | ||||
| 		} | ||||
| 
 | ||||
| 		if !in.Authoritative { | ||||
| 			t.Fatalf("expected authoritative response") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { | ||||
| 		t.Fatalf("expected final record to be the A record with correct IP") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenNoRecordNxDomain(t *testing.T) { | ||||
| 	t.Log("TestWhenNoRecordNxDomain") | ||||
| 
 | ||||
| 	_, _, addr, lock, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 	defer lock.Unlock() | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn("nonexistant.example.com.") | ||||
| 	client := &dns.Client{} | ||||
| 	message := &dns.Msg{} | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, *addr) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 0 { | ||||
| 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Rcode != dns.RcodeNameError { | ||||
| 		t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenUnresolvingCNAME(t *testing.T) { | ||||
| 	t.Log("TestWhenUnresolvingCNAME") | ||||
| 
 | ||||
| 	testDb, _, addr, lock, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 	defer lock.Unlock() | ||||
| 
 | ||||
| 	cname := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "cname.internal.example.com.", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "nonexistant.example.com.", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(testDb, cname) | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn(cname.Name) | ||||
| 	client := &dns.Client{} | ||||
| 	message := &dns.Msg{} | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, *addr) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) != 1 { | ||||
| 		t.Fatalf("expected 1 answer, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if !in.Authoritative { | ||||
| 		t.Fatalf("expected authoritative response") | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Name != cname.Name { | ||||
| 		t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].Header().Rrtype != dns.TypeCNAME { | ||||
| 		t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Answer[0].(*dns.CNAME).Target != cname.Content { | ||||
| 		t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Rcode == dns.RcodeNameError { | ||||
| 		t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { | ||||
| 	t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") | ||||
| 
 | ||||
| 	testDb, _, addr, lock, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 	defer lock.Unlock() | ||||
| 
 | ||||
| 	cname := &database.DNSRecord{ | ||||
| 		ID:       "1", | ||||
| 		UserID:   "test", | ||||
| 		Name:     "cname.internal.example.com.", | ||||
| 		Type:     "CNAME", | ||||
| 		Content:  "cname.internal.example.com.", | ||||
| 		TTL:      300, | ||||
| 		Internal: true, | ||||
| 	} | ||||
| 	database.SaveDNSRecord(testDb, cname) | ||||
| 
 | ||||
| 	qtype := dns.TypeA | ||||
| 	domain := dns.Fqdn(cname.Name) | ||||
| 	client := &dns.Client{} | ||||
| 	message := &dns.Msg{} | ||||
| 	message.SetQuestion(domain, qtype) | ||||
| 
 | ||||
| 	in, _, err := client.Exchange(message, *addr) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(in.Answer) > 0 { | ||||
| 		t.Fatalf("expected 0 answers, got %d", len(in.Answer)) | ||||
| 	} | ||||
| 
 | ||||
| 	if in.Rcode != dns.RcodeServerFailure { | ||||
| 		t.Fatalf("expected SERVFAIL, got %d", in.Rcode) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										4
									
								
								main.go
								
								
								
								
							
							
						
						
									
										4
									
								
								main.go
								
								
								
								
							|  | @ -6,7 +6,7 @@ import ( | |||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" | ||||
| 	"github.com/joho/godotenv" | ||||
| ) | ||||
|  | @ -52,7 +52,7 @@ func main() { | |||
| 	} | ||||
| 
 | ||||
| 	if argv.Dns { | ||||
| 		server := dns.MakeServer(argv, dbConn) | ||||
| 		server := hcdns.MakeServer(argv, dbConn) | ||||
| 		log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) | ||||
| 		go func() { | ||||
| 			err = server.ListenAndServe() | ||||
|  |  | |||
|  | @ -15,6 +15,22 @@ | |||
|   padding: 0; | ||||
|   color: var(--text-color); | ||||
|   font-family: "ComicSans", sans-serif; | ||||
| 
 | ||||
|   cursor: url("/static/img/cursor-1.png"), auto; | ||||
|   -webkit-animation: cursor 400ms infinite; | ||||
|   animation: cursor 400ms infinite; | ||||
| } | ||||
| 
 | ||||
| @-webkit-keyframes cursor { | ||||
|   0% {cursor: url("/static/img/cursor-2.png"), auto;} | ||||
|   50% {cursor: url("/static/img/cursor-1.png"), auto;} | ||||
|   100% {cursor: url("/static/img/cursor-2.png"), auto;} | ||||
| } | ||||
| 
 | ||||
| @keyframes cursor { | ||||
|   0% {cursor: url("/static/img/cursor-2.png"), auto;} | ||||
|   50% {cursor: url("/static/img/cursor-1.png"), auto;} | ||||
|   100% {cursor: url("/static/img/cursor-2.png"), auto;} | ||||
| } | ||||
| 
 | ||||
| body { | ||||
|  |  | |||
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 570 B | 
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 563 B | 
		Loading…
	
		Reference in New Issue