hatecomputers.club/api/auth/auth_test.go

309 lines
9.9 KiB
Go

package auth_test
import (
"database/sql"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"golang.org/x/oauth2"
)
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp)
}
}
// FakedOauthServer starts a stand-in OAuth provider on a local httptest
// server. It answers three paths:
//
//   - /auth redirects (302) to the redirect_uri query parameter with
//     "?code=<random id>&state=<state>" appended (values are not re-escaped);
//   - /token returns a fixed JSON token response (access token "test");
//   - /user returns fixed JSON user info (sub "test").
//
// Any other path receives an empty 200 response. The caller owns the
// returned server and must Close it.
func FakedOauthServer() *httptest.Server {
	provider := func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/auth":
			code := utils.RandomId()
			params := r.URL.Query()
			target := params.Get("redirect_uri") + "?code=" + code + "&state=" + params.Get("state")
			http.Redirect(w, r, target, http.StatusFound)
		case "/token":
			w.Header().Set("Content-Type", "application/json")
			w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
		case "/user":
			w.Header().Set("Content-Type", "application/json")
			w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
		}
	}
	return httptest.NewServer(http.HandlerFunc(provider))
}
// EchoUsernameContinuation writes the username of context.User to the
// response body and then hands control to the success continuation; the
// failure continuation is never used.
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
	return func(onSuccess types.Continuation, _ types.Continuation) types.ContinuationChain {
		body := []byte(context.User.Username)
		resp.Write(body)
		return onSuccess(context, req, resp)
	}
}
// MockUserEndpointServer starts an httptest server that wires the auth
// continuations into routes sharing the given request context:
//
//   - /protected-path: VerifySessionContinuation, with GoLoginContinuation as
//     the failure branch;
//   - /login: logs "login" and runs StartSessionContinuation;
//   - /callback: logs "callback" and runs InterceptOauthCodeContinuation;
//   - /me: VerifySessionContinuation, then RefreshSessionContinuation (or
//     GoLoginContinuation on failure), then EchoUsernameContinuation;
//   - /logout: LogoutContinuation.
//
// Other paths get an empty 200 response. The caller must Close the server.
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
	routes := func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/protected-path":
			auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
		case "/login":
			log.Println("login")
			auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
		case "/callback":
			log.Println("callback")
			auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
		case "/me":
			auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
		case "/logout":
			auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
		}
	}
	return httptest.NewServer(http.HandlerFunc(routes))
}
// setup builds the fixture shared by every test in this file.
//
// It opens a database via database.MakeConn under a random name, runs the
// migrations, and creates a request context with empty Arguments and
// TemplateData. It then starts the fake OAuth provider and the mock app
// server that uses that context.
//
// It returns the db handle, the context, the provider server, the app server,
// and a teardown func. The teardown closes both servers and the db, then
// removes the random database path. Args starts empty; the tests fill in the
// OAuth settings via GetOauthConfig.
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
	dbPath := utils.RandomId()
	conn := database.MakeConn(&dbPath)
	database.Migrate(conn)

	templateData := map[string]interface{}{}
	requestContext := &types.RequestContext{
		DBConn:       conn,
		Args:         &args.Arguments{},
		TemplateData: &templateData,
	}

	provider := FakedOauthServer()
	app := MockUserEndpointServer(requestContext)

	teardown := func() {
		provider.Close()
		app.Close()
		conn.Close()
		os.Remove(dbPath)
	}
	return conn, requestContext, provider, app, teardown
}
// GetOauthConfig returns an OAuth2 client config for the fake provider at
// oauthServerURL and the provider's user-info URL.
//
// The client credentials and scope are all "test". The authorization and
// token endpoints are oauthServerURL + "/auth" and "/token". The redirect
// goes to testServerURL + "/callback". The returned user-info URL is
// oauthServerURL + "/user".
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
	config := &oauth2.Config{
		ClientID:     "test",
		ClientSecret: "test",
		Scopes:       []string{"test"},
		RedirectURL:  testServerURL + "/callback",
		Endpoint: oauth2.Endpoint{
			AuthURL:  oauthServerURL + "/auth",
			TokenURL: oauthServerURL + "/token",
		},
	}
	userInfoURL := oauthServerURL + "/user"
	return config, userInfoURL
}
func FollowAuthentication(
oauthServer *httptest.Server,
testServer *httptest.Server,
cookies map[string]*http.Cookie,
location string,
) (map[string]*http.Cookie, string) {
resp := httptest.NewRecorder()
resp.Code = 0
for resp.Code == 0 || resp.Code == http.StatusFound {
req := httptest.NewRequest("GET", location, nil)
resp = httptest.NewRecorder()
for _, cookie := range cookies {
req.AddCookie(cookie)
}
if strings.HasPrefix(location, oauthServer.URL) {
oauthServer.Config.Handler.ServeHTTP(resp, req)
} else {
testServer.Config.Handler.ServeHTTP(resp, req)
}
for _, cookie := range resp.Result().Cookies() {
cookies[cookie.Name] = cookie
}
if resp.Code == http.StatusFound {
location = resp.Header().Get("Location")
}
}
return cookies, location
}
// TestOauthCreatesUserWithCorrectUsername checks that logging in through the
// fake OAuth flow creates a user. database.GetUser(db, "test") must find
// nothing before visiting /me, and afterwards must return a user whose
// Username is "test".
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
	db, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	user, _ := database.GetUser(db, "test")
	if user != nil {
		t.Errorf("expected no user, got user")
	}

	FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/me")

	user, _ = database.GetUser(db, "test")
	if user == nil {
		// Fatal rather than Error: the username check below dereferences user.
		t.Fatal("expected a user to be created, could not find user")
	}
	if user.Username != "test" {
		t.Errorf("expected username to be test, got %s", user.Username)
	}
}
// TestOauthRedirectsToPreviousLockedPage checks two things. First, an
// anonymous request to /protected-path is redirected (302) to /login. Second,
// completing the login flow that starts from that page ends with a redirect
// back to /protected-path.
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
	_, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	req := httptest.NewRequest("GET", "/protected-path", nil)
	resp := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(resp, req)
	location := resp.Header().Get("Location")
	// Either a wrong status or a wrong target is a failure, hence ||.
	if resp.Code != http.StatusFound || !strings.HasSuffix(location, "/login") {
		t.Errorf("expected redirect to /login, got %d and %s", resp.Code, location)
	}

	// Start from /protected-path: it is the locked route MockUserEndpointServer
	// serves. An unrouted path returns 200 at once and would make the check
	// below pass without any login happening.
	_, location = FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/protected-path")
	if !strings.HasSuffix(location, "/protected-path") {
		t.Errorf("expected to redirect back to /protected-path after login, got %s", location)
	}
}
// TestOauthSetsUniqueSession runs two independent logins, each starting from
// an empty cookie map. It checks that they receive different "session"
// cookie values and that the first session is stored for user ID "test".
func TestOauthSetsUniqueSession(t *testing.T) {
	db, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	cookies, _ := FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/me")
	cookiesAgain, _ := FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/me")

	// A missing cookie would otherwise panic on .Value instead of failing.
	firstCookie, secondCookie := cookies["session"], cookiesAgain["session"]
	if firstCookie == nil || secondCookie == nil {
		t.Fatal("expected both logins to set a session cookie")
	}
	sessionOne, sessionTwo := firstCookie.Value, secondCookie.Value
	if sessionOne == sessionTwo {
		t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
	}

	session, _ := database.GetSession(db, sessionOne)
	if session == nil {
		t.Fatalf("expected session %s to be stored, could not find it", sessionOne)
	}
	if session.UserID != "test" {
		t.Errorf("expected session to be associated with user test, got %s", session.UserID)
	}
}
// TestLogoutClearsSession logs in, records the session ID, and calls /logout
// with the login cookies. It then checks that /me redirects (302) to /login
// and that the recorded session no longer exists in the database.
func TestLogoutClearsSession(t *testing.T) {
	db, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	cookies, _ := FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/me")

	// Capture the ID before logout. The logout response may overwrite the
	// session cookie, and looking up an overwritten value would not prove
	// that the original session was removed.
	sessionCookie := cookies["session"]
	if sessionCookie == nil {
		t.Fatal("expected login to set a session cookie")
	}
	sessionID := sessionCookie.Value
	if session, _ := database.GetSession(db, sessionID); session == nil {
		t.Fatalf("expected session %s to exist before logout", sessionID)
	}

	req := httptest.NewRequest("GET", "/logout", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	resp := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(resp, req)
	for _, cookie := range resp.Result().Cookies() {
		cookies[cookie.Name] = cookie
	}

	req = httptest.NewRequest("GET", "/me", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	resp = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(resp, req)
	location := resp.Header().Get("Location")
	// Either a wrong status or a wrong target is a failure, hence ||.
	if resp.Code != http.StatusFound || !strings.HasSuffix(location, "/login") {
		t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, location)
	}

	session, _ := database.GetSession(db, sessionID)
	if session != nil {
		t.Errorf("expected session to be deleted, got session")
	}
}
// TestRefreshUpdatesExpiration logs in and reads the stored session. It then
// requests /me, whose route chains RefreshSessionContinuation after a
// successful session check, and expects the stored ExpireAt to be strictly
// later than before.
func TestRefreshUpdatesExpiration(t *testing.T) {
	db, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	cookies, _ := FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/protected-path")
	sessionCookie := cookies["session"]
	if sessionCookie == nil {
		t.Fatal("expected login to set a session cookie")
	}
	session, _ := database.GetSession(db, sessionCookie.Value)
	if session == nil {
		t.Fatalf("expected session %s to be stored, could not find it", sessionCookie.Value)
	}

	req := httptest.NewRequest("GET", "/me", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	resp := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(resp, req)

	updatedSession, _ := database.GetSession(db, sessionCookie.Value)
	if updatedSession == nil {
		t.Fatalf("expected session %s to still exist after refresh", sessionCookie.Value)
	}
	// Same check as "old >= new is a failure": the new expiry must be later.
	if !updatedSession.ExpireAt.After(session.ExpireAt) {
		t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
	}
}
// TestVerifySessionEnsuresNonExpired logs in, then moves the stored session's
// ExpireAt one hour into the past and saves it. It expects /me with the old
// cookies to redirect (302) to /login.
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
	db, context, oauthServer, testServer, cleanup := setup()
	defer cleanup()
	context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)

	cookies, _ := FollowAuthentication(oauthServer, testServer, make(map[string]*http.Cookie), "/protected-path")
	sessionCookie := cookies["session"]
	if sessionCookie == nil {
		t.Fatal("expected login to set a session cookie")
	}
	session, _ := database.GetSession(db, sessionCookie.Value)
	if session == nil {
		t.Fatalf("expected session %s to be stored, could not find it", sessionCookie.Value)
	}
	session.ExpireAt = time.Now().Add(-time.Hour)
	database.SaveSession(db, session)

	req := httptest.NewRequest("GET", "/me", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	resp := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(resp, req)
	location := resp.Header().Get("Location")
	// Either a wrong status or a wrong target is a failure, hence ||.
	if resp.Code != http.StatusFound || !strings.HasSuffix(location, "/login") {
		t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, location)
	}
}