hatecomputers.club/api/auth/auth.go

297 lines
7.7 KiB
Go

package auth
import (
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"golang.org/x/oauth2"
)
// StartSessionContinuation begins the OAuth2 authorization-code flow using a
// PKCE S256 challenge.
//
// It builds a verifier from two random IDs, derives the code challenge as the
// unpadded base64url encoding of the verifier's SHA-256 digest, and generates
// a separate random state value. The verifier and state are stored in
// cookies with a 200-second Max-Age so the callback handler can read them
// back. The client is then redirected (302) to the provider's authorization
// URL and the success continuation is invoked.
func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
		// flowCookie stores one short-lived value that is needed again when
		// the provider redirects back.
		flowCookie := func(name string, value string) {
			http.SetCookie(resp, &http.Cookie{
				Name:     name,
				Value:    value,
				Path:     "/",
				MaxAge:   200,
				Secure:   true,
				SameSite: http.SameSiteLaxMode,
			})
		}

		pkceVerifier := utils.RandomId() + utils.RandomId()
		digest := sha256.New()
		io.WriteString(digest, pkceVerifier)
		challenge := base64.RawURLEncoding.EncodeToString(digest.Sum(nil))

		csrfState := utils.RandomId()
		authURL := context.Args.OauthConfig.AuthCodeURL(
			csrfState,
			oauth2.SetAuthURLParam("code_challenge_method", "S256"),
			oauth2.SetAuthURLParam("code_challenge", challenge),
		)

		flowCookie("verifier", pkceVerifier)
		flowCookie("state", csrfState)

		http.Redirect(resp, req, authURL, http.StatusFound)
		return success(context, req, resp)
	}
}
// InterceptOauthCodeContinuation handles the OAuth2 callback request.
//
// Steps, in order:
//  1. Require non-empty "code" and "state" query parameters (400 otherwise).
//  2. Require the "state" cookie to match the state parameter (400 otherwise).
//  3. Require the "verifier" cookie (400 otherwise) and exchange the code for
//     a token, sending the verifier as "code_verifier" (500 on failure).
//  4. Fetch the user from the configured user-info URI and create a session
//     for them (500 on failure).
//  5. Set the "session" cookie, expire the one-shot "verifier", "state" and
//     "redirect" cookies, and redirect (302) to the stored redirect target or
//     "/" when there is none.
//
// On any failure the status code is written and the failure continuation is
// invoked; on success the success continuation is invoked after the redirect.
func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
		// fail writes the status code and hands off to the failure continuation.
		fail := func(status int) types.ContinuationChain {
			resp.WriteHeader(status)
			return failure(context, req, resp)
		}
		// expireCookie tells the browser to drop a cookie. net/http only emits
		// a deleting Max-Age for negative MaxAge (zero means "attribute not
		// set"), and the Path must match the one the cookie was created with
		// or the browser treats it as a different cookie.
		expireCookie := func(name string) {
			http.SetCookie(resp, &http.Cookie{
				Name:     name,
				Value:    "",
				Path:     "/",
				Secure:   true,
				SameSite: http.SameSiteLaxMode,
				MaxAge:   -1,
			})
		}

		query := req.URL.Query()
		state := query.Get("state")
		code := query.Get("code")
		if code == "" || state == "" {
			return fail(http.StatusBadRequest)
		}
		if !verifyState(req, "state", state) {
			return fail(http.StatusBadRequest)
		}

		verifierCookie, err := req.Cookie("verifier")
		if err != nil {
			return fail(http.StatusBadRequest)
		}

		reqContext := req.Context()
		token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
		if err != nil {
			// Logged like the other server-side failures below so a broken
			// exchange is diagnosable.
			log.Println(err)
			return fail(http.StatusInternalServerError)
		}

		client := context.Args.OauthConfig.Client(reqContext, token)
		user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI)
		if err != nil {
			log.Println(err)
			return fail(http.StatusInternalServerError)
		}

		session, err := database.MakeUserSessionFor(context.DBConn, user)
		if err != nil {
			log.Println(err)
			return fail(http.StatusInternalServerError)
		}

		// NOTE(review): the session cookie is readable from scripts; consider
		// HttpOnly if nothing client-side needs it — confirm before changing.
		http.SetCookie(resp, &http.Cookie{
			Name:     "session",
			Value:    session.ID,
			Path:     "/",
			SameSite: http.SameSiteLaxMode,
			Secure:   true,
		})
		expireCookie("verifier")
		expireCookie("state")

		redirect := "/"
		if redirectCookie, err := req.Cookie("redirect"); err == nil && redirectCookie.Value != "" {
			target := redirectCookie.Value
			// Only follow same-site paths. "//host" and "/\host" are treated
			// by browsers as pointing at another host, which would make this
			// an open redirect.
			if strings.HasPrefix(target, "/") && !strings.HasPrefix(target, "//") && !strings.HasPrefix(target, "/\\") {
				redirect = target
			}
			expireCookie("redirect")
		}

		http.Redirect(resp, req, redirect, http.StatusFound)
		return success(context, req, resp)
	}
}
// VerifySessionContinuation authenticates the request and stores the
// resulting user on the request context.
//
// The Authorization header is tried first. If a non-empty "session" cookie is
// present, the session lookup replaces the header result (including its
// error). When no user is resolved, or the lookup failed, the session cookie
// is expired, context.User is set to nil and the failure continuation is
// invoked; otherwise context.User is set and the success continuation runs.
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
		authHeader := req.Header.Get("Authorization")
		user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)

		sessionCookie, err := req.Cookie("session")
		if err == nil && sessionCookie.Value != "" {
			user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
		}

		if userErr != nil || user == nil {
			log.Println(userErr, user)
			// A negative MaxAge is what makes net/http emit a deleting
			// Max-Age (zero means "attribute not set"), and Path must match
			// the "/" the session cookie was created with so the browser
			// replaces that cookie rather than storing a second one.
			http.SetCookie(resp, &http.Cookie{
				Name:     "session",
				Value:    "",
				Path:     "/",
				Secure:   true,
				SameSite: http.SameSiteLaxMode,
				MaxAge:   -1,
			})
			context.User = nil
			return failure(context, req, resp)
		}

		context.User = user
		return success(context, req, resp)
	}
}
// GoLoginContinuation sends the client to the login page.
//
// The path of the current request (without its query string) is saved in a
// "redirect" cookie, the client is redirected (302) to "/login", and the
// failure continuation is invoked.
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success, failure types.Continuation) types.ContinuationChain {
		returnTo := &http.Cookie{
			Name:     "redirect",
			Value:    req.URL.Path,
			Path:     "/",
			SameSite: http.SameSiteLaxMode,
			Secure:   true,
		}
		http.SetCookie(resp, returnTo)

		http.Redirect(resp, req, "/login", http.StatusFound)
		return failure(context, req, resp)
	}
}
// RefreshSessionContinuation refreshes the session named by the "session"
// cookie via database.RefreshSession.
//
// The failure continuation is invoked when the cookie is missing or the
// refresh returns an error; otherwise the success continuation is invoked.
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success, failure types.Continuation) types.ContinuationChain {
		cookie, cookieErr := req.Cookie("session")
		if cookieErr != nil {
			return failure(context, req, resp)
		}

		if _, refreshErr := database.RefreshSession(context.DBConn, cookie.Value); refreshErr != nil {
			return failure(context, req, resp)
		}

		return success(context, req, resp)
	}
}
// LogoutContinuation ends the current session.
//
// If a non-empty "session" cookie is present the corresponding session is
// deleted from the database. The session cookie is then expired, the client
// is redirected (302) to "/", and the success continuation is invoked.
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
		if sessionCookie, err := req.Cookie("session"); err == nil && sessionCookie.Value != "" {
			// Best effort: the cookie is expired below regardless, so a failed
			// delete must not block the logout.
			_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
		}

		// A negative MaxAge is what makes net/http emit a deleting Max-Age
		// (zero means "attribute not set"), and Path must match the "/" the
		// session cookie was created with so the browser removes that cookie.
		http.SetCookie(resp, &http.Cookie{
			Name:     "session",
			Value:    "",
			Path:     "/",
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
			MaxAge:   -1,
		})

		http.Redirect(resp, req, "/", http.StatusFound)
		return success(context, req, resp)
	}
}
// getOauthUser fetches the user-info document at uri with the given client,
// converts the response into a database.User and passes it to
// database.FindOrSaveUser. The first error encountered is returned unchanged.
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
	infoResponse, fetchErr := client.Get(uri)
	if fetchErr != nil {
		return nil, fetchErr
	}

	profile, parseErr := createUserFromOauthResponse(infoResponse)
	if parseErr != nil {
		return nil, parseErr
	}

	stored, saveErr := database.FindOrSaveUser(dbConn, profile)
	if saveErr != nil {
		return nil, saveErr
	}
	return stored, nil
}
// createUserFromOauthResponse decodes the JSON body of a user-info response
// into a database.User and normalizes the username: it is lower-cased and
// everything from the first "@" onward is dropped.
//
// The response body is always closed. A non-2xx status is rejected before
// decoding so that a provider error document is never turned into an empty
// user. Errors are returned to the caller rather than logged here.
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
	defer response.Body.Close()

	if response.StatusCode < 200 || response.StatusCode > 299 {
		return nil, fmt.Errorf("oauth user info request returned status %d", response.StatusCode)
	}

	user := &database.User{}
	if err := json.NewDecoder(response.Body).Decode(user); err != nil {
		return nil, fmt.Errorf("decoding oauth user info: %w", err)
	}

	username := strings.ToLower(user.Username)
	user.Username = strings.SplitN(username, "@", 2)[0]
	return user, nil
}
// verifyState reports whether the request carries a cookie named
// stateCookieName whose value equals expectedState. A missing cookie counts
// as a mismatch.
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
	stored, err := req.Cookie(stateCookieName)
	return err == nil && stored.Value == expectedState
}
// getUserFromAuthHeader resolves the user that owns the API key carried in an
// Authorization header value.
//
// It returns (nil, nil) when the header is empty, when it is not exactly
// "Bearer" followed by a single space and a token, or when the key lookup
// yields no key. Errors from the key or user lookup are returned unchanged.
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
	fields := strings.Split(bearerToken, " ")
	if bearerToken == "" || len(fields) != 2 || fields[0] != "Bearer" {
		return nil, nil
	}

	apiKey, err := database.GetAPIKey(dbConn, fields[1])
	switch {
	case err != nil:
		return nil, err
	case apiKey == nil:
		return nil, nil
	}

	owner, err := database.GetUser(dbConn, apiKey.UserID)
	if err != nil {
		return nil, err
	}
	return owner, nil
}
// getUserFromSession resolves the user that owns the session with the given
// ID.
//
// It returns an error when the session lookup fails, when no session is
// returned, or when the session's ExpireAt is in the past; an expired session
// is also deleted. Otherwise the session's user is loaded and returned.
func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
	session, err := database.GetSession(dbConn, sessionId)
	if err != nil {
		return nil, err
	}
	// Guard against a lookup that reports "no row" as (nil, nil) instead of an
	// error; dereferencing it below would panic the request handler.
	if session == nil {
		return nil, fmt.Errorf("session not found")
	}

	if session.ExpireAt.Before(time.Now()) {
		// Best effort cleanup: the session is rejected either way, so a failed
		// delete is not reported to the caller.
		_ = database.DeleteSession(dbConn, sessionId)
		return nil, fmt.Errorf("session expired")
	}

	user, err := database.GetUser(dbConn, session.UserID)
	if err != nil {
		return nil, err
	}
	return user, nil
}