add some auth test cases
	
		
			
	
		
	
	
		
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	
				
					
				
			
				
	
				continuous-integration/drone/pr Build is passing
				
					Details
				
			
		
	This commit is contained in:
		
							parent
							
								
									94984aa4b0
								
							
						
					
					
						commit
						ae640a253e
					
				|  | @ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req | |||
| 		reqContext := req.Context() | ||||
| 		token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			resp.WriteHeader(http.StatusInternalServerError) | ||||
| 			return failure(context, req, resp) | ||||
| 		} | ||||
|  | @ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h | |||
| 			_ = database.DeleteSession(context.DBConn, sessionCookie.Value) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(resp, req, "/", http.StatusFound) | ||||
| 		http.SetCookie(resp, &http.Cookie{ | ||||
| 			Name:   "session", | ||||
| 			MaxAge: 0, | ||||
| 			Value:  "", | ||||
| 		}) | ||||
| 		http.Redirect(resp, req, "/", http.StatusFound) | ||||
| 
 | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
|  | @ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us | |||
| } | ||||
| 
 | ||||
| func createUserFromOauthResponse(response *http.Response) (*database.User, error) { | ||||
| 	user := &database.User{ | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
| 
 | ||||
| 	user := &database.User{} | ||||
| 	err := json.NewDecoder(response.Body).Decode(user) | ||||
| 	defer response.Body.Close() | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,9 +2,11 @@ package auth_test | |||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" | ||||
|  | @ -12,6 +14,7 @@ import ( | |||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database" | ||||
| 	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
|  | @ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLoginSendsYouToRedirect(t *testing.T) { | ||||
| func FakedOauthServer() *httptest.Server { | ||||
| 	oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/auth" { | ||||
| 			code := utils.RandomId() | ||||
| 
 | ||||
| 			state := r.URL.Query().Get("state") | ||||
| 			redirectPath := r.URL.Query().Get("redirect_uri") | ||||
| 			redirectPath += "?code=" + code + "&state=" + state | ||||
| 
 | ||||
| 			http.Redirect(w, r, redirectPath, http.StatusFound) | ||||
| 		} | ||||
| 		if r.URL.Path == "/token" { | ||||
| 			w.Header().Set("Content-Type", "application/json") | ||||
| 			w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) | ||||
| 		} | ||||
| 		if r.URL.Path == "/user" { | ||||
| 			w.Header().Set("Content-Type", "application/json") | ||||
| 			w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) | ||||
| 		} | ||||
| 	})) | ||||
| 
 | ||||
| 	return oauthServer | ||||
| } | ||||
| 
 | ||||
| func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { | ||||
| 	return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { | ||||
| 		resp.Write([]byte(context.User.Username)) | ||||
| 		return success(context, req, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/protected-path" { | ||||
| 			auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/login" { | ||||
| 			log.Println("login") | ||||
| 			auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/callback" { | ||||
| 			log.Println("callback") | ||||
| 			auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/me" { | ||||
| 			auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 
 | ||||
| 		if r.URL.Path == "/logout" { | ||||
| 			auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) | ||||
| 		} | ||||
| 	})) | ||||
| 	return testServer | ||||
| } | ||||
| 
 | ||||
| func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { | ||||
| 	return &oauth2.Config{ | ||||
| 		ClientID:     "test", | ||||
| 		ClientSecret: "test", | ||||
| 		Scopes:       []string{"test"}, | ||||
| 		Endpoint: oauth2.Endpoint{ | ||||
| 			AuthURL:  oauthServerURL + "/auth", | ||||
| 			TokenURL: oauthServerURL + "/token", | ||||
| 		}, | ||||
| 		RedirectURL: testServerURL + "/callback", | ||||
| 	}, oauthServerURL + "/user" | ||||
| } | ||||
| 
 | ||||
| func FollowAuthentication( | ||||
| 	oauthServer *httptest.Server, | ||||
| 	testServer *httptest.Server, | ||||
| 	cookies map[string]*http.Cookie, | ||||
| 	location string, | ||||
| ) (map[string]*http.Cookie, string) { | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	resp.Code = 0 | ||||
| 
 | ||||
| 	for resp.Code == 0 || resp.Code == http.StatusFound { | ||||
| 		req := httptest.NewRequest("GET", location, nil) | ||||
| 		resp = httptest.NewRecorder() | ||||
| 
 | ||||
| 		for _, cookie := range cookies { | ||||
| 			req.AddCookie(cookie) | ||||
| 		} | ||||
| 		if strings.HasPrefix(location, oauthServer.URL) { | ||||
| 			oauthServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 		} else { | ||||
| 			testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 		} | ||||
| 		for _, cookie := range resp.Result().Cookies() { | ||||
| 			cookies[cookie.Name] = cookie | ||||
| 		} | ||||
| 
 | ||||
| 		if resp.Code == http.StatusFound { | ||||
| 			location = resp.Header().Get("Location") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return cookies, location | ||||
| } | ||||
| 
 | ||||
| func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	user := &database.User{ | ||||
| 		ID:       "test", | ||||
| 		Username: "test", | ||||
| 	} | ||||
| 	database.FindOrSaveUser(db, user) | ||||
| 
 | ||||
| 	session, _ := database.MakeUserSessionFor(db, user) | ||||
| 
 | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) | ||||
| 	})) | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	protectedPath := testServer.URL + "/protected-path" | ||||
| 	req := httptest.NewRequest("GET", protectedPath, nil) | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	location := resp.Header().Get("Location") | ||||
| 	if resp.Code != http.StatusFound && location != "/login" { | ||||
| 		t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) | ||||
| 	user, _ := database.GetUser(db, "test") | ||||
| 	if user != nil { | ||||
| 		t.Errorf("expected no user, got user") | ||||
| 	} | ||||
| 
 | ||||
| 	req.AddCookie(&http.Cookie{ | ||||
| 		Name:   "session", | ||||
| 		Value:  session.ID, | ||||
| 		MaxAge: 60, | ||||
| 	}) | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	user, _ = database.GetUser(db, "test") | ||||
| 	if user == nil { | ||||
| 		t.Errorf("expected a user to be created, could not find user") | ||||
| 	} | ||||
| 	if user.Username != "test" { | ||||
| 		t.Errorf("expected username to be test, got %s", user.Username) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { | ||||
| 	_, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/protected-path", nil) | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	location := resp.Header().Get("Location") | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { | ||||
| 		t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") | ||||
| 
 | ||||
| 	if !(strings.HasSuffix(location, "/protected-page")) { | ||||
| 		t.Errorf("expected to redirect back to /protected-page after login, got %s", location) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOauthSetsUniqueSession(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	cookiesAgain := make(map[string]*http.Cookie) | ||||
| 	cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") | ||||
| 
 | ||||
| 	sessionOne := cookies["session"].Value | ||||
| 	sessionTwo := cookiesAgain["session"].Value | ||||
| 	if sessionOne == sessionTwo { | ||||
| 		t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) | ||||
| 	} | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, sessionOne) | ||||
| 	if session.UserID != "test" { | ||||
| 		t.Errorf("expected session to be associated with user test, got %s", session.UserID) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLogoutClearsSession(t *testing.T) { | ||||
| 	db, context, cleanup := setup() | ||||
| 	defer cleanup() | ||||
| 
 | ||||
| 	oauthServer := FakedOauthServer() | ||||
| 	testServer := MockUserEndpointServer(context) | ||||
| 	defer oauthServer.Close() | ||||
| 	defer testServer.Close() | ||||
| 
 | ||||
| 	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) | ||||
| 
 | ||||
| 	cookies := make(map[string]*http.Cookie) | ||||
| 	cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") | ||||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/logout", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	for _, cookie := range resp.Result().Cookies() { | ||||
| 		cookies[cookie.Name] = cookie | ||||
| 	} | ||||
| 
 | ||||
| 	req = httptest.NewRequest("GET", "/me", nil) | ||||
| 	for _, cookie := range cookies { | ||||
| 		req.AddCookie(cookie) | ||||
| 	} | ||||
| 	resp = httptest.NewRecorder() | ||||
| 	testServer.Config.Handler.ServeHTTP(resp, req) | ||||
| 	if resp.Code != http.StatusOK { | ||||
| } | ||||
| 
 | ||||
| func TestOauthFormatsUsername(t *testing.T) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestSessionIsUnique(t *testing.T) {} | ||||
| 
 | ||||
| func TestLogoutClearsCookie(t *testing.T) { | ||||
| 	if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { | ||||
| 		t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) | ||||
| 	} | ||||
| 
 | ||||
| 	session, _ := database.GetSession(db, cookies["session"].Value) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected session to be deleted, got session") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRefreshUpdatesExpiration(t *testing.T) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue