diff --git a/api/auth/auth.go b/api/auth/auth.go index 3c633cd..becce24 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req reqContext := req.Context() token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) if err != nil { - log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } @@ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h _ = database.DeleteSession(context.DBConn, sessionCookie.Value) } - http.Redirect(resp, req, "/", http.StatusFound) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, Value: "", }) + http.Redirect(resp, req, "/", http.StatusFound) + return success(context, req, resp) } } @@ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us } func createUserFromOauthResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - + user := &database.User{} err := json.NewDecoder(response.Body).Decode(user) defer response.Body.Close() diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index caaedf1..1e54099 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,9 +2,11 @@ package auth_test import ( "database/sql" + "log" "net/http" "net/http/httptest" "os" + "strings" "testing" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" @@ -12,6 +14,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" ) func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { @@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -func TestLoginSendsYouToRedirect(t *testing.T) { +func FakedOauthServer() *httptest.Server { + oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/auth" { + code := utils.RandomId() + + state := r.URL.Query().Get("state") + redirectPath := r.URL.Query().Get("redirect_uri") + redirectPath += "?code=" + code + "&state=" + state + + http.Redirect(w, r, redirectPath, http.StatusFound) + } + if r.URL.Path == "/token" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) + } + if r.URL.Path == "/user" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) + } + })) + + return oauthServer +} + +func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + resp.Write([]byte(context.User.Username)) + return success(context, req, resp) + } +} + +func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/protected-path" { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/login" { + log.Println("login") + auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/callback" { + log.Println("callback") + auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/me" { + auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/logout" { + auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) + } + })) + return testServer +} + +func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { + return &oauth2.Config{ + ClientID: "test", + ClientSecret: "test", + Scopes: []string{"test"}, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthServerURL + "/auth", + TokenURL: oauthServerURL + "/token", + }, + RedirectURL: testServerURL + "/callback", + }, oauthServerURL + "/user" +} + +func FollowAuthentication( + oauthServer *httptest.Server, + testServer *httptest.Server, + cookies map[string]*http.Cookie, + location string, +) (map[string]*http.Cookie, string) { + resp := httptest.NewRecorder() + resp.Code = 0 + + for resp.Code == 0 || resp.Code == http.StatusFound { + req := httptest.NewRequest("GET", location, nil) + resp = httptest.NewRecorder() + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + if strings.HasPrefix(location, oauthServer.URL) { + oauthServer.Config.Handler.ServeHTTP(resp, req) + } else { + testServer.Config.Handler.ServeHTTP(resp, req) + } + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + if resp.Code == http.StatusFound { + location = resp.Header().Get("Location") + } + } + + return cookies, location +} + +func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { db, context, cleanup := setup() defer cleanup() - user := &database.User{ - ID: "test", - Username: "test", - } - database.FindOrSaveUser(db, user) - - session, _ := database.MakeUserSessionFor(db, user) - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) - })) + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() defer testServer.Close() - protectedPath := testServer.URL + "/protected-path" - req := httptest.NewRequest("GET", protectedPath, nil) - resp := httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(resp, req) + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) - location := resp.Header().Get("Location") - if resp.Code != http.StatusFound && location != "/login" { - t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + user, _ := database.GetUser(db, "test") + if user != nil { + t.Errorf("expected no user, got user") } - req.AddCookie(&http.Cookie{ - Name: "session", - Value: session.ID, - MaxAge: 60, - }) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + user, _ = database.GetUser(db, "test") + if user == nil { + t.Errorf("expected a user to be created, could not find user") + } + if user.Username != "test" { + t.Errorf("expected username to be test, got %s", user.Username) + } +} + +func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { + _, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + req := httptest.NewRequest("GET", "/protected-path", nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { + t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + cookies := make(map[string]*http.Cookie) + cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") + + if !(strings.HasSuffix(location, "/protected-page")) { + t.Errorf("expected to redirect back to /protected-page after login, got %s", location) + } +} + +func TestOauthSetsUniqueSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + cookiesAgain := make(map[string]*http.Cookie) + cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") + + sessionOne := cookies["session"].Value + sessionTwo := cookiesAgain["session"].Value + if sessionOne == sessionTwo { + t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) + } + + session, _ := database.GetSession(db, sessionOne) + if session.UserID != "test" { + t.Errorf("expected session to be associated with user test, got %s", session.UserID) + } +} + +func TestLogoutClearsSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + req := httptest.NewRequest("GET", "/logout", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + req = httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } resp = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) - if resp.Code != http.StatusOK { -} - -func TestOauthFormatsUsername(t *testing.T) { - -} - -func TestSessionIsUnique(t *testing.T) {} - -func TestLogoutClearsCookie(t *testing.T) { + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + session, _ := database.GetSession(db, cookies["session"].Value) + if session != nil { + t.Errorf("expected session to be deleted, got session") + } } func TestRefreshUpdatesExpiration(t *testing.T) {