package auth_test import ( "database/sql" "log" "net/http" "net/http/httptest" "os" "strings" "testing" "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "golang.org/x/oauth2" ) func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { return success(context, req, resp) } } func FakedOauthServer() *httptest.Server { oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/auth" { code := utils.RandomId() state := r.URL.Query().Get("state") redirectPath := r.URL.Query().Get("redirect_uri") redirectPath += "?code=" + code + "&state=" + state http.Redirect(w, r, redirectPath, http.StatusFound) } if r.URL.Path == "/token" { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) } if r.URL.Path == "/user" { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) } })) return oauthServer } func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { resp.Write([]byte(context.User.Username)) return success(context, req, resp) } } func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/protected-path" { auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) } if r.URL.Path == "/login" { log.Println("login") auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) } if r.URL.Path == "/callback" { log.Println("callback") auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) } if r.URL.Path == "/me" { auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) } if r.URL.Path == "/logout" { auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) } })) return testServer } func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { randomDb := utils.RandomId() testDb := database.MakeConn(&randomDb) database.Migrate(testDb) context := &types.RequestContext{ DBConn: testDb, Args: &args.Arguments{}, TemplateData: &(map[string]interface{}{}), } oauthServer := FakedOauthServer() testServer := MockUserEndpointServer(context) return testDb, context, oauthServer, testServer, func() { oauthServer.Close() testServer.Close() testDb.Close() os.Remove(randomDb) } } func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { return &oauth2.Config{ ClientID: "test", ClientSecret: "test", Scopes: []string{"test"}, Endpoint: oauth2.Endpoint{ AuthURL: oauthServerURL + "/auth", TokenURL: oauthServerURL + "/token", }, RedirectURL: testServerURL + "/callback", }, oauthServerURL + "/user" } func FollowAuthentication( oauthServer *httptest.Server, testServer *httptest.Server, cookies map[string]*http.Cookie, location string, ) (map[string]*http.Cookie, string) { resp := httptest.NewRecorder() resp.Code = 0 for resp.Code == 0 || resp.Code == http.StatusFound { req := httptest.NewRequest("GET", location, nil) resp = httptest.NewRecorder() for _, cookie := range cookies { req.AddCookie(cookie) } if strings.HasPrefix(location, oauthServer.URL) { oauthServer.Config.Handler.ServeHTTP(resp, req) } else { testServer.Config.Handler.ServeHTTP(resp, req) } for _, cookie := range resp.Result().Cookies() { cookies[cookie.Name] = cookie } if resp.Code == http.StatusFound { location = resp.Header().Get("Location") } } return cookies, location } func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { db, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) user, _ := database.GetUser(db, "test") if user != nil { t.Errorf("expected no user, got user") } cookies := make(map[string]*http.Cookie) cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") user, _ = database.GetUser(db, "test") if user == nil { t.Errorf("expected a user to be created, could not find user") } if user.Username != "test" { t.Errorf("expected username to be test, got %s", user.Username) } } func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { _, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) req := httptest.NewRequest("GET", "/protected-path", nil) resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) location := resp.Header().Get("Location") if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) } cookies := make(map[string]*http.Cookie) cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") if !(strings.HasSuffix(location, "/protected-page")) { t.Errorf("expected to redirect back to /protected-page after login, got %s", location) } } func TestOauthSetsUniqueSession(t *testing.T) { db, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") cookiesAgain := make(map[string]*http.Cookie) cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") sessionOne := cookies["session"].Value sessionTwo := cookiesAgain["session"].Value if sessionOne == sessionTwo { t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) } session, _ := database.GetSession(db, sessionOne) if session.UserID != "test" { t.Errorf("expected session to be associated with user test, got %s", session.UserID) } } func TestLogoutClearsSession(t *testing.T) { db, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") req := httptest.NewRequest("GET", "/logout", nil) for _, cookie := range cookies { req.AddCookie(cookie) } resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) for _, cookie := range resp.Result().Cookies() { cookies[cookie.Name] = cookie } req = httptest.NewRequest("GET", "/me", nil) for _, cookie := range cookies { req.AddCookie(cookie) } resp = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) } session, _ := database.GetSession(db, cookies["session"].Value) if session != nil { t.Errorf("expected session to be deleted, got session") } } func TestRefreshUpdatesExpiration(t *testing.T) { db, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") session, _ := database.GetSession(db, cookies["session"].Value) req := httptest.NewRequest("GET", "/me", nil) for _, cookie := range cookies { req.AddCookie(cookie) } resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) updatedSession, _ := database.GetSession(db, cookies["session"].Value) if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) } } func TestVerifySessionEnsuresNonExpired(t *testing.T) { db, context, oauthServer, testServer, cleanup := setup() defer cleanup() context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") session, _ := database.GetSession(db, cookies["session"].Value) session.ExpireAt = time.Now().Add(-time.Hour) database.SaveSession(db, session) req := httptest.NewRequest("GET", "/me", nil) for _, cookie := range cookies { req.AddCookie(cookie) } resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) } }