testing | dont be recursive for external domains | finalize oauth #5
|
@ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req
|
||||||
reqContext := req.Context()
|
reqContext := req.Context()
|
||||||
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
return failure(context, req, resp)
|
return failure(context, req, resp)
|
||||||
}
|
}
|
||||||
|
@ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h
|
||||||
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(resp, req, "/", http.StatusFound)
|
|
||||||
http.SetCookie(resp, &http.Cookie{
|
http.SetCookie(resp, &http.Cookie{
|
||||||
Name: "session",
|
Name: "session",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
Value: "",
|
Value: "",
|
||||||
})
|
})
|
||||||
|
http.Redirect(resp, req, "/", http.StatusFound)
|
||||||
|
|
||||||
return success(context, req, resp)
|
return success(context, req, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
|
||||||
user := &database.User{
|
user := &database.User{}
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.NewDecoder(response.Body).Decode(user)
|
err := json.NewDecoder(response.Body).Decode(user)
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,11 @@ package auth_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
|
||||||
|
@ -12,6 +14,7 @@ import (
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
|
||||||
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoginSendsYouToRedirect(t *testing.T) {
|
func FakedOauthServer() *httptest.Server {
|
||||||
|
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/auth" {
|
||||||
|
code := utils.RandomId()
|
||||||
|
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
redirectPath := r.URL.Query().Get("redirect_uri")
|
||||||
|
redirectPath += "?code=" + code + "&state=" + state
|
||||||
|
|
||||||
|
http.Redirect(w, r, redirectPath, http.StatusFound)
|
||||||
|
}
|
||||||
|
if r.URL.Path == "/token" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
|
||||||
|
}
|
||||||
|
if r.URL.Path == "/user" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return oauthServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
|
||||||
|
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
|
||||||
|
resp.Write([]byte(context.User.Username))
|
||||||
|
return success(context, req, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/protected-path" {
|
||||||
|
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/login" {
|
||||||
|
log.Println("login")
|
||||||
|
auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/callback" {
|
||||||
|
log.Println("callback")
|
||||||
|
auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/me" {
|
||||||
|
auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/logout" {
|
||||||
|
auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return testServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: "test",
|
||||||
|
ClientSecret: "test",
|
||||||
|
Scopes: []string{"test"},
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: oauthServerURL + "/auth",
|
||||||
|
TokenURL: oauthServerURL + "/token",
|
||||||
|
},
|
||||||
|
RedirectURL: testServerURL + "/callback",
|
||||||
|
}, oauthServerURL + "/user"
|
||||||
|
}
|
||||||
|
|
||||||
|
func FollowAuthentication(
|
||||||
|
oauthServer *httptest.Server,
|
||||||
|
testServer *httptest.Server,
|
||||||
|
cookies map[string]*http.Cookie,
|
||||||
|
location string,
|
||||||
|
) (map[string]*http.Cookie, string) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
resp.Code = 0
|
||||||
|
|
||||||
|
for resp.Code == 0 || resp.Code == http.StatusFound {
|
||||||
|
req := httptest.NewRequest("GET", location, nil)
|
||||||
|
resp = httptest.NewRecorder()
|
||||||
|
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(location, oauthServer.URL) {
|
||||||
|
oauthServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
} else {
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
}
|
||||||
|
for _, cookie := range resp.Result().Cookies() {
|
||||||
|
cookies[cookie.Name] = cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Code == http.StatusFound {
|
||||||
|
location = resp.Header().Get("Location")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cookies, location
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
|
||||||
db, context, cleanup := setup()
|
db, context, cleanup := setup()
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
user := &database.User{
|
oauthServer := FakedOauthServer()
|
||||||
ID: "test",
|
testServer := MockUserEndpointServer(context)
|
||||||
Username: "test",
|
defer oauthServer.Close()
|
||||||
}
|
|
||||||
database.FindOrSaveUser(db, user)
|
|
||||||
|
|
||||||
session, _ := database.MakeUserSessionFor(db, user)
|
|
||||||
|
|
||||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
|
|
||||||
}))
|
|
||||||
defer testServer.Close()
|
defer testServer.Close()
|
||||||
|
|
||||||
protectedPath := testServer.URL + "/protected-path"
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
req := httptest.NewRequest("GET", protectedPath, nil)
|
|
||||||
|
user, _ := database.GetUser(db, "test")
|
||||||
|
if user != nil {
|
||||||
|
t.Errorf("expected no user, got user")
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
user, _ = database.GetUser(db, "test")
|
||||||
|
if user == nil {
|
||||||
|
t.Errorf("expected a user to be created, could not find user")
|
||||||
|
}
|
||||||
|
if user.Username != "test" {
|
||||||
|
t.Errorf("expected username to be test, got %s", user.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
|
||||||
|
_, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
oauthServer := FakedOauthServer()
|
||||||
|
testServer := MockUserEndpointServer(context)
|
||||||
|
defer oauthServer.Close()
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/protected-path", nil)
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
location := resp.Header().Get("Location")
|
location := resp.Header().Get("Location")
|
||||||
if resp.Code != http.StatusFound && location != "/login" {
|
if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
|
||||||
t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location)
|
t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
req.AddCookie(&http.Cookie{
|
cookies := make(map[string]*http.Cookie)
|
||||||
Name: "session",
|
cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
|
||||||
Value: session.ID,
|
|
||||||
MaxAge: 60,
|
if !(strings.HasSuffix(location, "/protected-page")) {
|
||||||
})
|
t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthSetsUniqueSession(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
oauthServer := FakedOauthServer()
|
||||||
|
testServer := MockUserEndpointServer(context)
|
||||||
|
defer oauthServer.Close()
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
cookiesAgain := make(map[string]*http.Cookie)
|
||||||
|
cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
|
||||||
|
|
||||||
|
sessionOne := cookies["session"].Value
|
||||||
|
sessionTwo := cookiesAgain["session"].Value
|
||||||
|
if sessionOne == sessionTwo {
|
||||||
|
t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := database.GetSession(db, sessionOne)
|
||||||
|
if session.UserID != "test" {
|
||||||
|
t.Errorf("expected session to be associated with user test, got %s", session.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutClearsSession(t *testing.T) {
|
||||||
|
db, context, cleanup := setup()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
oauthServer := FakedOauthServer()
|
||||||
|
testServer := MockUserEndpointServer(context)
|
||||||
|
defer oauthServer.Close()
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
|
||||||
|
|
||||||
|
cookies := make(map[string]*http.Cookie)
|
||||||
|
cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/logout", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
|
for _, cookie := range resp.Result().Cookies() {
|
||||||
|
cookies[cookie.Name] = cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/me", nil)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
resp = httptest.NewRecorder()
|
resp = httptest.NewRecorder()
|
||||||
testServer.Config.Handler.ServeHTTP(resp, req)
|
testServer.Config.Handler.ServeHTTP(resp, req)
|
||||||
if resp.Code != http.StatusOK {
|
if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
|
||||||
|
t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthFormatsUsername(t *testing.T) {
|
session, _ := database.GetSession(db, cookies["session"].Value)
|
||||||
|
if session != nil {
|
||||||
|
t.Errorf("expected session to be deleted, got session")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionIsUnique(t *testing.T) {}
|
|
||||||
|
|
||||||
func TestLogoutClearsCookie(t *testing.T) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRefreshUpdatesExpiration(t *testing.T) {
|
func TestRefreshUpdatesExpiration(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue