diff --git a/api/auth/auth.go b/api/auth/auth.go index dc348b2..3c633cd 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.SetCookie(resp, &http.Cookie{ Name: "state", @@ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.Redirect(resp, req, url, http.StatusFound) @@ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req SameSite: http.SameSiteLaxMode, Secure: true, }) + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: "", + MaxAge: 0, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: "", + MaxAge: 0, + }) redirect := "/" redirectCookie, err := req.Cookie("redirect") @@ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req http.SetCookie(resp, &http.Cookie{ Name: "redirect", MaxAge: 0, + Value: "", }) } @@ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req } } -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - typesKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if typesKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, typesKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { authHeader := req.Header.Get("Authorization") @@ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", + Value: "", MaxAge: 0, // reset session cookie in case }) @@ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { sessionCookie, err := req.Cookie("session") if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, + Value: "", }) return success(context, req, resp) } @@ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return nil, err } - userStruct, err := createUserFromResponse(userResponse) + userStruct, err := createUserFromOauthResponse(userResponse) if err != nil { return nil, err } @@ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return user, nil } -func createUserFromResponse(response *http.Response) (*database.User, error) { +func createUserFromOauthResponse(response *http.Response) (*database.User, error) { user := &database.User{ CreatedAt: time.Now(), } @@ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string return true } + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + key, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if key == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, key.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a6c2a45..caaedf1 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,14 +2,24 @@ package auth_test import ( "database/sql" + "net/http" + "net/http/httptest" "os" + "testing" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + func setup() (*sql.DB, *types.RequestContext, func()) { randomDb := utils.RandomId() @@ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -/* -todo: test types key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ +func TestLoginSendsYouToRedirect(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + user := &database.User{ + ID: "test", + Username: "test", + } + database.FindOrSaveUser(db, user) + + session, _ := database.MakeUserSessionFor(db, user) + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + protectedPath := testServer.URL + "/protected-path" + req := httptest.NewRequest("GET", protectedPath, nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && location != "/login" { + t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + } + + req.AddCookie(&http.Cookie{ + Name: "session", + Value: session.ID, + MaxAge: 60, + }) + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { +} + +func TestOauthFormatsUsername(t *testing.T) { + +} + +func TestSessionIsUnique(t *testing.T) {} + +func TestLogoutClearsCookie(t *testing.T) { + +} + +func TestRefreshUpdatesExpiration(t *testing.T) { + +} + +func TestVerifySessionEnsuresNonExpired(t *testing.T) { + +} + +func TestAPITokensAreEquivalentToSessions(t *testing.T) { + +}