package api import ( "crypto/sha256" "database/sql" "encoding/base64" "encoding/json" "fmt" "io" "log" "net/http" "strings" "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "golang.org/x/oauth2" ) func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { verifier := utils.RandomId() + utils.RandomId() sha2 := sha256.New() io.WriteString(sha2, verifier) codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) state := utils.RandomId() url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) http.SetCookie(resp, &http.Cookie{ Name: "verifier", Value: verifier, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, MaxAge: 60, }) http.SetCookie(resp, &http.Cookie{ Name: "state", Value: state, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, MaxAge: 60, }) http.Redirect(resp, req, url, http.StatusFound) return success(context, req, resp) } } func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { state := req.URL.Query().Get("state") code := req.URL.Query().Get("code") if code == "" || state == "" { resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } if !verifyState(req, "state", state) { resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } verifierCookie, err := req.Cookie("verifier") if err != nil { resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } reqContext := req.Context() token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } client := context.Args.OauthConfig.Client(reqContext, token) user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } session, err := database.MakeUserSessionFor(context.DBConn, user) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } http.SetCookie(resp, &http.Cookie{ Name: "session", Value: session.ID, Path: "/", SameSite: http.SameSiteLaxMode, Secure: true, }) redirect := "/" redirectCookie, err := req.Cookie("redirect") if err == nil && redirectCookie.Value != "" { redirect = redirectCookie.Value http.SetCookie(resp, &http.Cookie{ Name: "redirect", MaxAge: 0, }) } http.Redirect(resp, req, redirect, http.StatusFound) return success(context, req, resp) } } func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { if bearerToken == "" { return nil, nil } parts := strings.Split(bearerToken, " ") if len(parts) != 2 || parts[0] != "Bearer" { return nil, nil } apiKey, err := database.GetAPIKey(dbConn, parts[1]) if err != nil { return nil, err } if apiKey == nil { return nil, nil } user, err := database.GetUser(dbConn, apiKey.UserID) if err != nil { return nil, err } return user, nil } func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { session, err := database.GetSession(dbConn, sessionId) if err != nil { return nil, err } if session.ExpireAt.Before(time.Now()) { session = nil database.DeleteSession(dbConn, sessionId) return nil, fmt.Errorf("session expired") } user, err := database.GetUser(dbConn, session.UserID) if err != nil { return nil, err } return user, nil } func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { authHeader := req.Header.Get("Authorization") user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) sessionCookie, err := req.Cookie("session") if err == nil { user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) } if userErr != nil || user == nil { log.Println(userErr, user) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, // reset session cookie in case }) return failure(context, req, resp) } context.User = user return success(context, req, resp) } } func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { http.SetCookie(resp, &http.Cookie{ Name: "redirect", Value: req.URL.Path, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, }) http.Redirect(resp, req, "/login", http.StatusFound) return failure(context, req, resp) } } func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { sessionCookie, err := req.Cookie("session") if err != nil { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) if err != nil { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } return success(context, req, resp) } } func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { sessionCookie, err := req.Cookie("session") if err == nil && sessionCookie.Value != "" { _ = database.DeleteSession(context.DBConn, sessionCookie.Value) } http.Redirect(resp, req, "/", http.StatusFound) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, }) return success(context, req, resp) } } func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { userResponse, err := client.Get(uri) if err != nil { return nil, err } userStruct, err := createUserFromResponse(userResponse) if err != nil { return nil, err } user, err := database.FindOrSaveUser(dbConn, userStruct) if err != nil { return nil, err } return user, nil } func createUserFromResponse(response *http.Response) (*database.User, error) { defer response.Body.Close() user := &database.User{ CreatedAt: time.Now(), } err := json.NewDecoder(response.Body).Decode(user) if err != nil { log.Println(err) return nil, err } user.Username = strings.ToLower(user.Username) user.Username = strings.Split(user.Username, "@")[0] return user, nil } func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { cookie, err := req.Cookie(stateCookieName) if err != nil || cookie.Value != expectedState { return false } return true }