310 lines
8.1 KiB
Go
310 lines
8.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
func ListUsersContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
users, err := database.ListUsers(context.DBConn)
|
|
if err != nil {
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
(*context.TemplateData)["Users"] = users
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
verifier := utils.RandomId() + utils.RandomId()
|
|
|
|
sha2 := sha256.New()
|
|
io.WriteString(sha2, verifier)
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
|
|
|
|
state := utils.RandomId()
|
|
url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge))
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "verifier",
|
|
Value: verifier,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 200,
|
|
})
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "state",
|
|
Value: state,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 200,
|
|
})
|
|
|
|
http.Redirect(resp, req, url, http.StatusFound)
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
state := req.URL.Query().Get("state")
|
|
code := req.URL.Query().Get("code")
|
|
|
|
if code == "" || state == "" {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
if !verifyState(req, "state", state) {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
verifierCookie, err := req.Cookie("verifier")
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
reqContext := req.Context()
|
|
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
client := context.Args.OauthConfig.Client(reqContext, token)
|
|
user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI)
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
session, err := database.MakeUserSessionFor(context.DBConn, user)
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
Value: session.ID,
|
|
Path: "/",
|
|
SameSite: http.SameSiteLaxMode,
|
|
Secure: true,
|
|
})
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "verifier",
|
|
Value: "",
|
|
MaxAge: 0,
|
|
})
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "state",
|
|
Value: "",
|
|
MaxAge: 0,
|
|
})
|
|
|
|
redirect := "/"
|
|
redirectCookie, err := req.Cookie("redirect")
|
|
if err == nil && redirectCookie.Value != "" {
|
|
redirect = redirectCookie.Value
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "redirect",
|
|
MaxAge: 0,
|
|
Value: "",
|
|
})
|
|
}
|
|
|
|
http.Redirect(resp, req, redirect, http.StatusFound)
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
authHeader := req.Header.Get("Authorization")
|
|
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
|
|
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err == nil && sessionCookie.Value != "" {
|
|
user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
|
|
}
|
|
|
|
if userErr != nil || user == nil {
|
|
log.Println(userErr, user)
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
Value: "",
|
|
MaxAge: 0,
|
|
})
|
|
|
|
context.User = nil
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
context.User = user
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
log.Println("GoLoginContinuation")
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "redirect",
|
|
Value: req.URL.Path,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
http.Redirect(resp, req, "/login", http.StatusFound)
|
|
return failure(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err != nil {
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
|
if err != nil {
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err == nil && sessionCookie.Value != "" {
|
|
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
|
}
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
MaxAge: 0,
|
|
Value: "",
|
|
})
|
|
http.Redirect(resp, req, "/", http.StatusFound)
|
|
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
|
userResponse, err := client.Get(uri)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userStruct, err := createUserFromOauthResponse(userResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err := database.FindOrSaveBaseUser(dbConn, userStruct)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
|
user := &database.User{}
|
|
err := json.NewDecoder(response.Body).Decode(user)
|
|
defer response.Body.Close()
|
|
|
|
if err != nil {
|
|
log.Println(err)
|
|
return nil, err
|
|
}
|
|
|
|
user.Username = strings.ToLower(user.Username)
|
|
user.Username = strings.Split(user.Username, "@")[0]
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// verifyState reports whether req carries a cookie named stateCookieName
// whose value equals expectedState.
//
// An empty expectedState never verifies. Without this guard, a request
// carrying an empty cookie (or a caller passing an empty state) would
// trivially pass the CSRF check.
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
	if expectedState == "" {
		return false
	}

	cookie, err := req.Cookie(stateCookieName)
	return err == nil && cookie.Value == expectedState
}
|
|
|
|
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
|
if bearerToken == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
parts := strings.Split(bearerToken, " ")
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
return nil, nil
|
|
}
|
|
|
|
key, err := database.GetAPIKey(dbConn, parts[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if key == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
user, err := database.GetUser(dbConn, key.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
|
|
session, err := database.GetSession(dbConn, sessionId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if session.ExpireAt.Before(time.Now()) {
|
|
session = nil
|
|
database.DeleteSession(dbConn, sessionId)
|
|
return nil, fmt.Errorf("session expired")
|
|
}
|
|
|
|
user, err := database.GetUser(dbConn, session.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|