From 48f124f2723617f0b4eb570ff987b1abd8d41bbe Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Thu, 28 Mar 2024 10:53:30 -0600 Subject: [PATCH] add api keys and route --- adapters/cloudflare/cloudflare.go | 2 - api/api_keys.go | 84 +++++++++++++++++++++++++++++++ api/auth.go | 74 ++++++++++++++++++++------- api/dns.go | 20 +++++--- api/serve.go | 15 ++++++ database/users.go | 63 +++++++++++++++++++++++ templates/api_keys.html | 40 +++++++++++++++ templates/base.html | 3 ++ utils/RandomId.go | 9 ++-- 9 files changed, 276 insertions(+), 34 deletions(-) create mode 100644 api/api_keys.go create mode 100644 templates/api_keys.html diff --git a/adapters/cloudflare/cloudflare.go b/adapters/cloudflare/cloudflare.go index bfcbea6..40b04a5 100644 --- a/adapters/cloudflare/cloudflare.go +++ b/adapters/cloudflare/cloudflare.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" @@ -19,7 +18,6 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) - log.Println(reqBody) payload := strings.NewReader(reqBody) req, _ := http.NewRequest("POST", url, payload) diff --git a/api/api_keys.go b/api/api_keys.go new file mode 100644 index 0000000..17ed6c9 --- /dev/null +++ b/api/api_keys.go @@ -0,0 +1,84 @@ +package api + +import ( + "log" + "net/http" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_API_KEYS = 5 + +func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["APIKeys"] = apiKeys + return success(context, req, resp) + } +} + +func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, + } + + apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if len(apiKeys) >= MAX_USER_API_KEYS { + formErrors.Errors = append(formErrors.Errors, "max api keys reached") + } + + _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ + UserID: context.User.ID, + Key: utils.RandomId(), + }) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} + +func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + key := req.FormValue("key") + + apiKey, err := database.GetAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if (apiKey == nil) || (apiKey.UserID != context.User.ID) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = database.DeleteAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} diff --git a/api/auth.go b/api/auth.go index 4733971..dcddf5a 100644 --- a/api/auth.go +++ b/api/auth.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "fmt" "io" "log" "net/http" @@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp } } +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + apiKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if apiKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, apiKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) + if err == nil { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) } - session, err := database.GetSession(context.DBConn, sessionCookie.Value) - if err == nil && session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(context.DBConn, sessionCookie.Value) - } - if err != nil || session == nil { + if userErr != nil || user == nil { + log.Println(userErr, user) + http.SetCookie(resp, &http.Cookie{ Name: "session", - MaxAge: 0, + MaxAge: 0, // reset session cookie in case }) - - return failure(context, req, resp) - } - - user, err := database.GetUser(context.DBConn, session.UserID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } diff --git a/api/dns.go b/api/dns.go index 0822fbc..5123acc 100644 --- a/api/dns.go +++ b/api/dns.go @@ -10,6 +10,8 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" ) +const MAX_USER_RECORDS = 20 + type FormError struct { Errors []string } @@ -43,6 +45,9 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res recordContent := req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } dnsRecord := &database.DNSRecord{ UserID: context.User.ID, @@ -52,8 +57,14 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res TTL: ttlNum, } + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) if err != nil { - formErrors.Errors = append(formErrors.Errors, "invalid ttl") + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if len(dnsRecords) >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") } if !userCanFuckWithDNSRecord(context.User, dnsRecord) { @@ -83,13 +94,6 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res return success(context, req, resp) } - dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - (*context.TemplateData)["DNSRecords"] = dnsRecords (*context.TemplateData)["FormError"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord diff --git a/api/serve.go b/api/serve.go index 09e2072..d16ea99 100644 --- a/api/serve.go +++ b/api/serve.go @@ -126,6 +126,21 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") diff --git a/database/users.go b/database/users.go index d2b4f20..f9e4436 100644 --- a/database/users.go +++ b/database/users.go @@ -12,6 +12,12 @@ const ( ExpiryDuration = time.Hour * 24 ) +type UserApiKey struct { + Key string `json:"key"` + UserID string `json:"user_id"` + CreatedAt time.Time `json:"created_at"` +} + type User struct { ID string `json:"sub"` Mail string `json:"email"` @@ -119,3 +125,60 @@ func DeleteExpiredSessions(dbConn *sql.DB) error { } return nil } + +func ListUserAPIKeys(dbConn *sql.DB, userId string) ([]*UserApiKey, error) { + rows, err := dbConn.Query(`SELECT key, user_id, created_at FROM api_keys WHERE user_id = ?;`, userId) + if err != nil { + log.Println(err) + return nil, err + } + defer rows.Close() + + var apiKeys []*UserApiKey + for rows.Next() { + var apiKey UserApiKey + err := rows.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt) + if err != nil { + log.Println(err) + return nil, err + } + + apiKeys = append(apiKeys, &apiKey) + } + + return apiKeys, nil +} + +func SaveAPIKey(dbConn *sql.DB, apiKey *UserApiKey) (*UserApiKey, error) { + _, err := dbConn.Exec(`INSERT OR REPLACE INTO api_keys (key, user_id) VALUES (?, ?);`, apiKey.Key, apiKey.UserID) + if err != nil { + log.Println(err) + return nil, err + } + + apiKey.CreatedAt = time.Now() + return apiKey, nil +} + +func GetAPIKey(dbConn *sql.DB, key string) (*UserApiKey, error) { + row := dbConn.QueryRow(`SELECT key, user_id, created_at FROM api_keys WHERE key = ?;`, key) + + var apiKey UserApiKey + err := row.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt) + if err != nil { + log.Println(err) + return nil, err + } + + return &apiKey, nil +} + +func DeleteAPIKey(dbConn *sql.DB, key string) error { + _, err := dbConn.Exec(`DELETE FROM api_keys WHERE key = ?;`, key) + if err != nil { + log.Println(err) + return err + } + + return nil +} diff --git a/templates/api_keys.html b/templates/api_keys.html new file mode 100644 index 0000000..93eebd5 --- /dev/null +++ b/templates/api_keys.html @@ -0,0 +1,40 @@ +{{ define "content" }} + + + + + + + {{ if (eq (len .APIKeys) 0) }} + + + + {{ end }} + {{ range $key := .APIKeys }} + + + + + + {{ end }} +
KeyCreated AtRevoke
No API Keys Found
{{ $key.Key }}{{ $key.CreatedAt }} +
+ + +
+
+
+
+

Add An API Key

+
+ +
+ + {{ if .FormError }} + {{ if (len .FormError.Errors) }} + {{ range $error := .FormError.Errors }} +
{{ $error }}
+ {{ end }} + {{ end }} + {{ end }} +{{ end }} diff --git a/templates/base.html b/templates/base.html index 79e0d12..d0f97c7 100644 --- a/templates/base.html +++ b/templates/base.html @@ -35,7 +35,10 @@ {{ if .User }} dns. | + api keys. + | logout, {{ .User.DisplayName }}. + {{ else }} login. {{ end }} diff --git a/utils/RandomId.go b/utils/RandomId.go index 09f089d..1b03ec8 100644 --- a/utils/RandomId.go +++ b/utils/RandomId.go @@ -6,14 +6,11 @@ import ( ) func RandomId() string { - uuid := make([]byte, 16) - _, err := rand.Read(uuid) + id := make([]byte, 16) + _, err := rand.Read(id) if err != nil { panic(err) } - uuid[8] = uuid[8]&^0xc0 | 0x80 - uuid[6] = uuid[6]&^0xf0 | 0x40 - - return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) + return fmt.Sprintf("%x", id) }