testing | dont be recursive for external domains | finalize oauth #5
			
				
			
		
		
		
	
							
								
								
									
										115
									
								
								api/auth/auth.go
								
								
								
								
							
							
						
						
									
										115
									
								
								api/auth/auth.go
								
								
								
								
							|  | @ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, | |||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			MaxAge:   60, | ||||
| 			MaxAge:   200, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:     "state", | ||||
|  | @ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, | |||
| 			Path:     "/", | ||||
| 			Secure:   true, | ||||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			MaxAge:   60, | ||||
| 			MaxAge:   200, | ||||
| 		}) | ||||
| 
 | ||||
| 		http.Redirect(resp, req, url, http.StatusFound) | ||||
|  | @ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req | |||
| 			SameSite: http.SameSiteLaxMode, | ||||
| 			Secure:   true, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "verifier", | ||||
| 			Value:  "", | ||||
| 			MaxAge: 0, | ||||
| 		}) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "state", | ||||
| 			Value:  "", | ||||
| 			MaxAge: 0, | ||||
| 		}) | ||||
| 
 | ||||
| 		redirect := "/" | ||||
| 		redirectCookie, err := req.Cookie("redirect") | ||||
|  | @ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req | |||
| 			http.SetCookie(resp, &http.Cookie{ | ||||
| 				Name:   "redirect", | ||||
| 				MaxAge: 0, | ||||
| 				Value:  "", | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
|  | @ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { | ||||
| 	if bearerToken == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	parts := strings.Split(bearerToken, " ") | ||||
| 	if len(parts) != 2 || parts[0] != "Bearer" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	typesKey, err := database.GetAPIKey(dbConn, parts[1]) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if typesKey == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.GetUser(dbConn, typesKey.UserID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { | ||||
| 	session, err := database.GetSession(dbConn, sessionId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if session.ExpireAt.Before(time.Now()) { | ||||
| 		session = nil | ||||
| 		database.DeleteSession(dbConn, sessionId) | ||||
| 		return nil, fmt.Errorf("session expired") | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.GetUser(dbConn, session.UserID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		authHeader := req.Header.Get("Authorization") | ||||
|  | @ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, | |||
| 
 | ||||
| 			http.SetCookie(resp, &http.Cookie{ | ||||
| 				Name:   "session", | ||||
| 				Value:  "", | ||||
| 				MaxAge: 0, // reset session cookie in case
 | ||||
| 			}) | ||||
| 
 | ||||
|  | @ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request | |||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		sessionCookie, err := req.Cookie("session") | ||||
| 		if err != nil { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = database.RefreshSession(context.DBConn, sessionCookie.Value) | ||||
| 		if err != nil { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
|  | @ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h | |||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "session", | ||||
| 			MaxAge: 0, | ||||
| 			Value:  "", | ||||
| 		}) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
|  | @ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	userStruct, err := createUserFromResponse(userResponse) | ||||
| 	userStruct, err := createUserFromOauthResponse(userResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us | |||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func createUserFromResponse(response *http.Response) (*database.User, error) { | ||||
| func createUserFromOauthResponse(response *http.Response) (*database.User, error) { | ||||
| 	user := &database.User{ | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
|  | @ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string | |||
| 
 | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { | ||||
| 	if bearerToken == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	parts := strings.Split(bearerToken, " ") | ||||
| 	if len(parts) != 2 || parts[0] != "Bearer" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	key, err := database.GetAPIKey(dbConn, parts[1]) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if key == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.GetUser(dbConn, key.UserID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
| 
 | ||||
| func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { | ||||
| 	session, err := database.GetSession(dbConn, sessionId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if session.ExpireAt.Before(time.Now()) { | ||||
| 		session = nil | ||||
| 		database.DeleteSession(dbConn, sessionId) | ||||
| 		return nil, fmt.Errorf("session expired") | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := database.GetUser(dbConn, session.UserID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return user, nil | ||||
| } | ||||
|  |  | |||
|  | @ -2,14 +2,24 @@ package auth_test | |||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| ) | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
|  | @ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| todo: test types key creation | ||||
| + api key attached to user | ||||
| + user session is unique | ||||
| + goLogin goes to page in cookie | ||||
| */ | ||||
| func TestLoginSendsYouToRedirect(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	user := &database.User{ | ||||
| 		ID:       "test", | ||||
| 		Username: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(db, user) | ||||
| 
 | ||||
| 	session, _ := database.MakeUserSessionFor(db, user) | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	protectedPath := testServer.URL + "/protected-path" | ||||
| 	req := httptest.NewRequest("GET", protectedPath, nil) | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 
 | ||||
| 	location := resp.Header().Get("Location") | ||||
| 	if resp.Code != http.StatusFound && location != "/login" { | ||||
| 		t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) | ||||
| 	} | ||||
| 
 | ||||
| 	req.AddCookie(&http.Cookie{ | ||||
| 		Name:   "session", | ||||
| 		Value:  session.ID, | ||||
| 		MaxAge: 60, | ||||
| 	}) | ||||
| 	resp = httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	if resp.Code != http.StatusOK { | ||||
| } | ||||
| 
 | ||||
| func TestOauthFormatsUsername(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestSessionIsUnique(t *testing.T) {} | ||||
| 
 | ||||
| func TestLogoutClearsCookie(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestRefreshUpdatesExpiration(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestVerifySessionEnsuresNonExpired(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestAPITokensAreEquivalentToSessions(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue