testing | dont be recursive for external domains | finalize oauth #5
115
api/auth/auth.go
115
api/auth/auth.go
|
@ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 60,
|
MaxAge: 200,
|
||||||
})
|
})
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "state",
|
Name: "state",
|
||||||
|
@ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 60,
|
MaxAge: 200,
|
||||||
})
|
})
|
||||||
|
|
||||||
http.Redirect(resp, req, url, http.StatusFound)
|
http.Redirect(resp, req, url, http.StatusFound)
|
||||||
|
@ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
})
|
})
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "verifier",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
http.SetCookie(resp, &http.Cookie{
|
||||||
|
Name: "state",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: 0,
|
||||||
|
})
|
||||||
|
|
||||||
redirect := "/"
|
redirect := "/"
|
||||||
redirectCookie, err := req.Cookie("redirect")
|
redirectCookie, err := req.Cookie("redirect")
|
||||||
|
@ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "redirect",
|
Name: "redirect",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
|
Value: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
|
||||||
if bearerToken == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(bearerToken, " ")
|
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
typesKey, err := database.GetAPIKey(dbConn, parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if typesKey == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := database.GetUser(dbConn, typesKey.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
|
|
||||||
session, err := database.GetSession(dbConn, sessionId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.ExpireAt.Before(time.Now()) {
|
|
||||||
session = nil
|
|
||||||
database.DeleteSession(dbConn, sessionId)
|
|
||||||
return nil, fmt.Errorf("session expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := database.GetUser(dbConn, session.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
authHeader := req.Header.Get("Authorization")
|
authHeader := req.Header.Get("Authorization")
|
||||||
|
@ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request,
|
||||||
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "session",
|
Name: "session",
|
||||||
|
Value: "",
|
||||||
MaxAge: 0, // reset session cookie in case
|
MaxAge: 0, // reset session cookie in case
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
sessionCookie, err := req.Cookie("session")
|
sessionCookie, err := req.Cookie("session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "session",
|
Name: "session",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
|
Value: "",
|
||||||
})
|
})
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
|
@ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userStruct, err := createUserFromResponse(userResponse)
|
userStruct, err := createUserFromOauthResponse(userResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserFromResponse(response *http.Response) (*database.User, error) {
|
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
||||||
user := &database.User{
|
user := &database.User{
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
@ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
|
||||||
|
if bearerToken == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(bearerToken, " ")
|
||||||
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := database.GetAPIKey(dbConn, parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if key == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := database.GetUser(dbConn, key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
|
||||||
|
session, err := database.GetSession(dbConn, sessionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.ExpireAt.Before(time.Now()) {
|
||||||
|
session = nil
|
||||||
|
database.DeleteSession(dbConn, sessionId)
|
||||||
|
return nil, fmt.Errorf("session expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := database.GetUser(dbConn, session.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,14 +2,24 @@ package auth_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
randomDb := utils.RandomId()
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
@ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
func TestLoginSendsYouToRedirect(t *testing.T) {
|
||||||
todo: test types key creation
|
db, context, cleanup := setup()
|
||||||
+ api key attached to user
|
defer cleanup()
|
||||||
+ user session is unique
|
|
||||||
+ goLogin goes to page in cookie
|
user := &database.User{
|
||||||
*/
|
ID: "test",
|
||||||
|
Username: "test",
|
||||||
|
}
|
||||||
|
database.FindOrSaveUser(db, user)
|
||||||
|
|
||||||
|
session, _ := database.MakeUserSessionFor(db, user)
|
||||||
|
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
protectedPath := testServer.URL + "/protected-path"
|
||||||
|
req := httptest.NewRequest("GET", protectedPath, nil)
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
location := resp.Header().Get("Location")
|
||||||
|
if resp.Code != http.StatusFound && location != "/login" {
|
||||||
|
t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
Value: session.ID,
|
||||||
|
MaxAge: 60,
|
||||||
|
})
|
||||||
|
resp = httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthFormatsUsername(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionIsUnique(t *testing.T) {}
|
||||||
|
|
||||||
|
func TestLogoutClearsCookie(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshUpdatesExpiration(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPITokensAreEquivalentToSessions(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue