246 lines
6.6 KiB
Go
246 lines
6.6 KiB
Go
package api
|
|
|
|
import (
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"time"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
	"golang.org/x/oauth2"
)
|
|
|
|
func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
verifier := utils.RandomId() + utils.RandomId()
|
|
|
|
sha2 := sha256.New()
|
|
io.WriteString(sha2, verifier)
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
|
|
|
|
state := utils.RandomId()
|
|
url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge))
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "verifier",
|
|
Value: verifier,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 60,
|
|
})
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "state",
|
|
Value: state,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 60,
|
|
})
|
|
|
|
http.Redirect(resp, req, url, http.StatusFound)
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
state := req.URL.Query().Get("state")
|
|
code := req.URL.Query().Get("code")
|
|
|
|
if code == "" || state == "" {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
if !verifyState(req, "state", state) {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
verifierCookie, err := req.Cookie("verifier")
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusBadRequest)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
reqContext := req.Context()
|
|
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
client := context.Args.OauthConfig.Client(reqContext, token)
|
|
user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI)
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
session, err := database.MakeUserSessionFor(context.DBConn, user)
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusInternalServerError)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
Value: session.ID,
|
|
Path: "/",
|
|
SameSite: http.SameSiteLaxMode,
|
|
Secure: true,
|
|
})
|
|
|
|
redirect := "/"
|
|
redirectCookie, err := req.Cookie("redirect")
|
|
if err == nil && redirectCookie.Value != "" {
|
|
redirect = redirectCookie.Value
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "redirect",
|
|
MaxAge: 0,
|
|
})
|
|
}
|
|
|
|
http.Redirect(resp, req, redirect, http.StatusFound)
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusUnauthorized)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
session, err := database.GetSession(context.DBConn, sessionCookie.Value)
|
|
if err == nil && session.ExpireAt.Before(time.Now()) {
|
|
session = nil
|
|
database.DeleteSession(context.DBConn, sessionCookie.Value)
|
|
}
|
|
if err != nil || session == nil {
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
MaxAge: 0,
|
|
})
|
|
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
user, err := database.GetUser(context.DBConn, session.UserID)
|
|
if err != nil {
|
|
log.Println(err)
|
|
resp.WriteHeader(http.StatusUnauthorized)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
context.User = user
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "redirect",
|
|
Value: req.URL.Path,
|
|
Path: "/",
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
http.Redirect(resp, req, "/login", http.StatusFound)
|
|
return failure(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusUnauthorized)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
|
if err != nil {
|
|
resp.WriteHeader(http.StatusUnauthorized)
|
|
return failure(context, req, resp)
|
|
}
|
|
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
|
|
return func(success Continuation, failure Continuation) ContinuationChain {
|
|
sessionCookie, err := req.Cookie("session")
|
|
if err == nil && sessionCookie.Value != "" {
|
|
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
|
}
|
|
|
|
http.Redirect(resp, req, "/", http.StatusFound)
|
|
http.SetCookie(resp, &http.Cookie{
|
|
Name: "session",
|
|
MaxAge: 0,
|
|
})
|
|
return success(context, req, resp)
|
|
}
|
|
}
|
|
|
|
func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
|
|
userResponse, err := client.Get(uri)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userStruct, err := createUserFromResponse(userResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err := database.FindOrSaveUser(dbConn, userStruct)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func createUserFromResponse(response *http.Response) (*database.User, error) {
|
|
defer response.Body.Close()
|
|
user := &database.User{
|
|
CreatedAt: time.Now(),
|
|
}
|
|
err := json.NewDecoder(response.Body).Decode(user)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return nil, err
|
|
}
|
|
|
|
user.Username = strings.ToLower(user.Username)
|
|
user.Username = strings.Split(user.Username, "@")[0]
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// verifyState reports whether req carries a cookie named stateCookieName
// whose value equals expectedState.
//
// An empty expectedState never verifies. Otherwise a request with no state
// parameter and a blank state cookie would pass the check.
func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
	if expectedState == "" {
		return false
	}
	cookie, err := req.Cookie(stateCookieName)
	return err == nil && cookie.Value == expectedState
}
|