finish auth tests
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	This commit is contained in:
		
							parent
							
								
									ae640a253e
								
							
						
					
					
						commit
						5177735b83
					
				|  | @ -8,6 +8,7 @@ import ( | |||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" | ||||
|  | @ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http. | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 	} | ||||
| 
 | ||||
| 	return testDb, context, func() { | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func FakedOauthServer() *httptest.Server { | ||||
| 	oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/auth" { | ||||
|  | @ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { | |||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/me" { | ||||
| 			auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 			auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/logout" { | ||||
|  | @ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { | |||
| 	return testServer | ||||
| } | ||||
| 
 | ||||
| func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { | ||||
| 	randomDb := utils.RandomId() | ||||
| 
 | ||||
| 	testDb := database.MakeConn(&randomDb) | ||||
| 	database.Migrate(testDb) | ||||
| 
 | ||||
| 	context := &types.RequestContext{ | ||||
| 		DBConn:       testDb, | ||||
| 		Args:         &args.Arguments{}, | ||||
| 		TemplateData: &(map[string]interface{}{}), | ||||
| 	} | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 
 | ||||
| 	return testDb, context, oauthServer, testServer, func() { | ||||
| 		oauthServer.Close() | ||||
| 		testServer.Close() | ||||
| 
 | ||||
| 		testDb.Close() | ||||
| 		os.Remove(randomDb) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { | ||||
| 	return &oauth2.Config{ | ||||
| 		ClientID:     "test", | ||||
|  | @ -146,14 +153,9 @@ func FollowAuthentication( | |||
| } | ||||
| 
 | ||||
| func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	user, _ := database.GetUser(db, "test") | ||||
|  | @ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { | ||||
| 	_, context, cleanup := setup() | ||||
| 	_, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/protected-path", nil) | ||||
|  | @ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestOauthSetsUniqueSession(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
|  | @ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestLogoutClearsSession(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
|  | @ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestRefreshUpdatesExpiration(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 
 | ||||
| 	updatedSession, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 
 | ||||
| 	// if session expiration is greater than or equal to updated session expiration
 | ||||
| 	if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { | ||||
| 		t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestVerifySessionEnsuresNonExpired(t *testing.T) { | ||||
| 	db, context, oauthServer, testServer, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestAPITokensAreEquivalentToSessions(t *testing.T) { | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 	session.ExpireAt = time.Now().Add(-time.Hour) | ||||
| 	database.SaveSession(db, session) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 
 | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { | ||||
| 		t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, | |||
| 
 | ||||
| func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		key := req.FormValue("key") | ||||
| 		apiKey := req.FormValue("key") | ||||
| 
 | ||||
| 		typesKey, err := database.GetAPIKey(context.DBConn, key) | ||||
| 		key, err := database.GetAPIKey(context.DBConn, apiKey) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 		if (typesKey == nil) || (typesKey.UserID != context.User.ID) { | ||||
| 		if (key == nil) || (key.UserID != context.User.ID) { | ||||
| 			resp.WriteHeader(http.StatusUnauthorized) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		err = database.DeleteAPIKey(context.DBConn, key) | ||||
| 		err = database.DeleteAPIKey(context.DBConn, apiKey) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/keys", http.StatusFound) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { | |||
| 
 | ||||
| 	mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		requestContext := makeRequestContext() | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 		LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) | ||||
| 	}) | ||||
| 
 | ||||
| 	mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { | ||||
|  |  | |||
|  | @ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) { | ||||
| 	log.Println("saving session", session.ID) | ||||
| 
 | ||||
| 	_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt) | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { | ||||
| 	newExpireAt := time.Now().Add(ExpiryDuration) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue