add api keys and route
This commit is contained in:
parent
243bb8e35b
commit
48f124f272
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -19,7 +18,6 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
|
|||
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
|
||||
|
||||
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
|
||||
log.Println(reqBody)
|
||||
payload := strings.NewReader(reqBody)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, payload)
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||
)
|
||||
|
||||
const MAX_USER_API_KEYS = 5
|
||||
|
||||
func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["APIKeys"] = apiKeys
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
formErrors := FormError{
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
if len(apiKeys) >= MAX_USER_API_KEYS {
|
||||
formErrors.Errors = append(formErrors.Errors, "max api keys reached")
|
||||
}
|
||||
|
||||
_, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{
|
||||
UserID: context.User.ID,
|
||||
Key: utils.RandomId(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
http.Redirect(resp, req, "/keys", http.StatusFound)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
key := req.FormValue("key")
|
||||
|
||||
apiKey, err := database.GetAPIKey(context.DBConn, key)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
err = database.DeleteAPIKey(context.DBConn, key)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
http.Redirect(resp, req, "/keys", http.StatusFound)
|
||||
return success(context, req, resp)
|
||||
}
|
||||
}
|
74
api/auth.go
74
api/auth.go
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
|
|||
}
|
||||
}
|
||||
|
||||
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
||||
if bearerToken == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(bearerToken, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
apiKey, err := database.GetAPIKey(dbConn, parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
user, err := database.GetUser(dbConn, apiKey.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
|
||||
session, err := database.GetSession(dbConn, sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.ExpireAt.Before(time.Now()) {
|
||||
session = nil
|
||||
database.DeleteSession(dbConn, sessionId)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
|
||||
user, err := database.GetUser(dbConn, session.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
||||
return func(success Continuation, failure Continuation) ContinuationChain {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
||||
|
||||
sessionCookie, err := req.Cookie("session")
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
if err == nil {
|
||||
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
|
||||
session, err := database.GetSession(context.DBConn, sessionCookie.Value)
|
||||
if err == nil && session.ExpireAt.Before(time.Now()) {
|
||||
session = nil
|
||||
database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||
}
|
||||
if err != nil || session == nil {
|
||||
if userErr != nil || user == nil {
|
||||
log.Println(userErr, user)
|
||||
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: "session",
|
||||
MaxAge: 0,
|
||||
MaxAge: 0, // reset session cookie in case
|
||||
})
|
||||
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
user, err := database.GetUser(context.DBConn, session.UserID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
|
|
20
api/dns.go
20
api/dns.go
|
@ -10,6 +10,8 @@ import (
|
|||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||
)
|
||||
|
||||
const MAX_USER_RECORDS = 20
|
||||
|
||||
type FormError struct {
|
||||
Errors []string
|
||||
}
|
||||
|
@ -43,6 +45,9 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
|
|||
recordContent := req.FormValue("content")
|
||||
ttl := req.FormValue("ttl")
|
||||
ttlNum, err := strconv.Atoi(ttl)
|
||||
if err != nil {
|
||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||
}
|
||||
|
||||
dnsRecord := &database.DNSRecord{
|
||||
UserID: context.User.ID,
|
||||
|
@ -52,8 +57,14 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
|
|||
TTL: ttlNum,
|
||||
}
|
||||
|
||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
formErrors.Errors = append(formErrors.Errors, "invalid ttl")
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
if len(dnsRecords) >= MAX_USER_RECORDS {
|
||||
formErrors.Errors = append(formErrors.Errors, "max records reached")
|
||||
}
|
||||
|
||||
if !userCanFuckWithDNSRecord(context.User, dnsRecord) {
|
||||
|
@ -83,13 +94,6 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res
|
|||
return success(context, req, resp)
|
||||
}
|
||||
|
||||
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
return failure(context, req, resp)
|
||||
}
|
||||
|
||||
(*context.TemplateData)["DNSRecords"] = dnsRecords
|
||||
(*context.TemplateData)["FormError"] = &formErrors
|
||||
(*context.TemplateData)["RecordForm"] = dnsRecord
|
||||
|
|
15
api/serve.go
15
api/serve.go
|
@ -126,6 +126,21 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
|||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
requestContext := makeRequestContext()
|
||||
name := r.PathValue("name")
|
||||
|
|
|
@ -12,6 +12,12 @@ const (
|
|||
ExpiryDuration = time.Hour * 24
|
||||
)
|
||||
|
||||
type UserApiKey struct {
|
||||
Key string `json:"key"`
|
||||
UserID string `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"sub"`
|
||||
Mail string `json:"email"`
|
||||
|
@ -119,3 +125,60 @@ func DeleteExpiredSessions(dbConn *sql.DB) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListUserAPIKeys(dbConn *sql.DB, userId string) ([]*UserApiKey, error) {
|
||||
rows, err := dbConn.Query(`SELECT key, user_id, created_at FROM api_keys WHERE user_id = ?;`, userId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var apiKeys []*UserApiKey
|
||||
for rows.Next() {
|
||||
var apiKey UserApiKey
|
||||
err := rows.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiKeys = append(apiKeys, &apiKey)
|
||||
}
|
||||
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
func SaveAPIKey(dbConn *sql.DB, apiKey *UserApiKey) (*UserApiKey, error) {
|
||||
_, err := dbConn.Exec(`INSERT OR REPLACE INTO api_keys (key, user_id) VALUES (?, ?);`, apiKey.Key, apiKey.UserID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiKey.CreatedAt = time.Now()
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
func GetAPIKey(dbConn *sql.DB, key string) (*UserApiKey, error) {
|
||||
row := dbConn.QueryRow(`SELECT key, user_id, created_at FROM api_keys WHERE key = ?;`, key)
|
||||
|
||||
var apiKey UserApiKey
|
||||
err := row.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func DeleteAPIKey(dbConn *sql.DB, key string) error {
|
||||
_, err := dbConn.Exec(`DELETE FROM api_keys WHERE key = ?;`, key)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
{{ define "content" }}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Key</th>
|
||||
<th>Created At</th>
|
||||
<th>Revoke</th>
|
||||
</tr>
|
||||
{{ if (eq (len .APIKeys) 0) }}
|
||||
<tr>
|
||||
<td colspan="5"><span class="blinky">No API Keys Found</span></td>
|
||||
</tr>
|
||||
{{ end }}
|
||||
{{ range $key := .APIKeys }}
|
||||
<tr>
|
||||
<td>{{ $key.Key }}</td>
|
||||
<td>{{ $key.CreatedAt }}</td>
|
||||
<td>
|
||||
<form method="POST" action="/keys/delete">
|
||||
<input type="hidden" name="key" value="{{ $key.Key }}" />
|
||||
<input type="submit" value="Revoke" />
|
||||
</form>
|
||||
</td>
|
||||
</tr>
|
||||
{{ end }}
|
||||
</table>
|
||||
<br>
|
||||
<form method="POST" action="/keys" class="form">
|
||||
<h2>Add An API Key</h2>
|
||||
<hr>
|
||||
<input type="submit" value="Generate" />
|
||||
</form>
|
||||
|
||||
{{ if .FormError }}
|
||||
{{ if (len .FormError.Errors) }}
|
||||
{{ range $error := .FormError.Errors }}
|
||||
<div class="error">{{ $error }}</div>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ end }}
|
|
@ -35,7 +35,10 @@
|
|||
{{ if .User }}
|
||||
<a href="/dns">dns.</a>
|
||||
<span> | </span>
|
||||
<a href="/keys">api keys.</a>
|
||||
<span> | </span>
|
||||
<a href="/logout">logout, {{ .User.DisplayName }}.</a>
|
||||
|
||||
{{ else }}
|
||||
<a href="/login">login.</a>
|
||||
{{ end }}
|
||||
|
|
|
@ -6,14 +6,11 @@ import (
|
|||
)
|
||||
|
||||
func RandomId() string {
|
||||
uuid := make([]byte, 16)
|
||||
_, err := rand.Read(uuid)
|
||||
id := make([]byte, 16)
|
||||
_, err := rand.Read(id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
uuid[8] = uuid[8]&^0xc0 | 0x80
|
||||
uuid[6] = uuid[6]&^0xf0 | 0x40
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
|
||||
return fmt.Sprintf("%x", id)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue