testing | dont be recursive for external domains | finalize oauth #5
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
|
||||||
|
@ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setup() (*sql.DB, *types.RequestContext, func()) {
|
|
||||||
randomDb := utils.RandomId()
|
|
||||||
|
|
||||||
testDb := database.MakeConn(&randomDb)
|
|
||||||
database.Migrate(testDb)
|
|
||||||
|
|
||||||
context := &types.RequestContext{
|
|
||||||
DBConn: testDb,
|
|
||||||
Args: &args.Arguments{},
|
|
||||||
TemplateData: &(map[string]interface{}{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
return testDb, context, func() {
|
|
||||||
testDb.Close()
|
|
||||||
os.Remove(randomDb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FakedOauthServer() *httptest.Server {
|
func FakedOauthServer() *httptest.Server {
|
||||||
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/auth" {
|
if r.URL.Path == "/auth" {
|
||||||
|
@ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == "/me" {
|
if r.URL.Path == "/me" {
|
||||||
auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == "/logout" {
|
if r.URL.Path == "/logout" {
|
||||||
|
@ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
||||||
return testServer
|
return testServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
|
||||||
|
randomDb := utils.RandomId()
|
||||||
|
|
||||||
|
testDb := database.MakeConn(&randomDb)
|
||||||
|
database.Migrate(testDb)
|
||||||
|
|
||||||
|
context := &types.RequestContext{
|
||||||
|
DBConn: testDb,
|
||||||
|
Args: &args.Arguments{},
|
||||||
|
TemplateData: &(map[string]interface{}{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthServer := FakedOauthServer()
|
||||||
|
testServer := MockUserEndpointServer(context)
|
||||||
|
|
||||||
|
return testDb, context, oauthServer, testServer, func() {
|
||||||
|
oauthServer.Close()
|
||||||
|
testServer.Close()
|
||||||
|
|
||||||
|
testDb.Close()
|
||||||
|
os.Remove(randomDb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: "test",
|
ClientID: "test",
|
||||||
|
@ -146,14 +153,9 @@ func FollowAuthentication(
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
oauthServer := FakedOauthServer()
|
|
||||||
testServer := MockUserEndpointServer(context)
|
|
||||||
defer oauthServer.Close()
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
user, _ := database.GetUser(db, "test")
|
user, _ := database.GetUser(db, "test")
|
||||||
|
@ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
||||||
_, context, cleanup := setup()
|
_, context, oauthServer, testServer, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
oauthServer := FakedOauthServer()
|
|
||||||
testServer := MockUserEndpointServer(context)
|
|
||||||
defer oauthServer.Close()
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/protected-path", nil)
|
req := httptest.NewRequest("GET", "/protected-path", nil)
|
||||||
|
@ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthSetsUniqueSession(t *testing.T) {
|
func TestOauthSetsUniqueSession(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
oauthServer := FakedOauthServer()
|
|
||||||
testServer := MockUserEndpointServer(context)
|
|
||||||
defer oauthServer.Close()
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
@ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogoutClearsSession(t *testing.T) {
|
func TestLogoutClearsSession(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
oauthServer := FakedOauthServer()
|
|
||||||
testServer := MockUserEndpointServer(context)
|
|
||||||
defer oauthServer.Close()
|
|
||||||
defer testServer.Close()
|
|
||||||
|
|
||||||
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
cookies := make(map[string]*http.Cookie)
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
@ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRefreshUpdatesExpiration(t *testing.T) {
|
func TestRefreshUpdatesExpiration(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
updatedSession, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
|
||||||
|
// if session expiration is greater than or equal to updated session expiration
|
||||||
|
if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
|
||||||
|
t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
|
||||||
|
db, context, oauthServer, testServer, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
session.ExpireAt = time.Now().Add(-time.Hour)
|
||||||
|
database.SaveSession(db, session)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
}
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
func TestAPITokensAreEquivalentToSessions(t *testing.T) {
|
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||||
|
t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request,
|
||||||
|
|
||||||
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
key := req.FormValue("key")
|
apiKey := req.FormValue("key")
|
||||||
|
|
||||||
typesKey, err := database.GetAPIKey(context.DBConn, key)
|
key, err := database.GetAPIKey(context.DBConn, apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
if (typesKey == nil) || (typesKey.UserID != context.User.ID) {
|
if (key == nil) || (key.UserID != context.User.ID) {
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = database.DeleteAPIKey(context.DBConn, key)
|
err = database.DeleteAPIKey(context.DBConn, apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(resp, req, "/keys", http.StatusFound)
|
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
|
||||||
|
|
||||||
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestContext := makeRequestContext()
|
requestContext := makeRequestContext()
|
||||||
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
|
||||||
|
log.Println("saving session", session.ID)
|
||||||
|
|
||||||
|
_, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
|
||||||
newExpireAt := time.Now().Add(ExpiryDuration)
|
newExpireAt := time.Now().Add(ExpiryDuration)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue