package api_test

import (
	"database/sql"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"

	"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)

// setup opens a database connection named by a random id, runs the
// migrations on it, and builds a RequestContext bound to that connection with
// empty Arguments and TemplateData.
//
// The returned cleanup func closes the connection and removes the file named
// by the random id (presumably the database file MakeConn creates — confirm).
// Callers should defer it.
func setup() (*sql.DB, *api.RequestContext, func()) {
	randomDb := utils.RandomId()

	testDb := database.MakeConn(&randomDb)
	// NOTE(review): if Migrate returns an error it is discarded here — confirm
	// its signature and surface failures if so.
	database.Migrate(testDb)

	reqCtx := &api.RequestContext{
		DBConn:       testDb,
		Args:         &args.Arguments{},
		TemplateData: &(map[string]interface{}{}),
	}

	cleanup := func() {
		testDb.Close()
		// Best-effort removal; errors are intentionally ignored.
		os.Remove(randomDb)
	}
	return testDb, reqCtx, cleanup
}

// guestbookHandler adapts SignGuestbookContinuation to an http.Handler, passing
// api.IdContinuation as both continuation arguments. The tests call it
// directly, so no listening test server is needed.
func guestbookHandler(reqCtx *api.RequestContext) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		api.SignGuestbookContinuation(reqCtx, r, w)(api.IdContinuation, api.IdContinuation)
	})
}

// postGuestbook sends a POST carrying the given name and message to handler
// and returns the recorded response. req.Form is set directly, so
// FormValue reads it without parsing the (empty) request body.
func postGuestbook(handler http.Handler, name, message string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Form = map[string][]string{
		"name":    {name},
		"message": {message},
	}
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	return rec
}

// guestbookEntryCount returns how many guestbook entries are stored in db,
// failing the test immediately if they cannot be read.
func guestbookEntryCount(t *testing.T, db *sql.DB) int {
	t.Helper()
	entries, err := database.GetGuestbookEntries(db)
	if err != nil {
		t.Fatal(err)
	}
	return len(entries)
}

// TestValidGuestbookPutsInDatabase checks that a well-formed submission is
// answered with 200 OK and stored as exactly one entry with the submitted name.
func TestValidGuestbookPutsInDatabase(t *testing.T) {
	db, reqCtx, cleanup := setup()
	defer cleanup()

	// The rest of the test is meaningless on a non-empty database.
	if n := guestbookEntryCount(t, db); n != 0 {
		t.Fatalf("expected no entries, got %d", n)
	}

	const name, message = "test", "test"
	rec := postGuestbook(guestbookHandler(reqCtx), name, message)
	if rec.Code != http.StatusOK {
		t.Errorf("expected status code 200, got %d", rec.Code)
	}

	entries, err := database.GetGuestbookEntries(db)
	if err != nil {
		t.Fatal(err)
	}
	// Fatal rather than Error: indexing entries[0] below would panic if the
	// entry was not inserted.
	if len(entries) != 1 {
		t.Fatalf("expected 1 entry, got %d", len(entries))
	}
	if entries[0].Name != name {
		t.Errorf("expected name %s, got %s", name, entries[0].Name)
	}
}

// TestInvalidGuestbookNotFoundInDatabase checks that each malformed submission
// is rejected with 400 Bad Request and that none of them is stored.
func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
	db, reqCtx, cleanup := setup()
	defer cleanup()

	if n := guestbookEntryCount(t, db); n != 0 {
		t.Fatalf("expected no entries, got %d", n)
	}

	handler := guestbookHandler(reqCtx)

	// A message made of many short lines; presumably it trips the handler's
	// message-size/line validation — confirm against the rule in package api.
	manyLines := "a\na\na\na\na\na\na\na\na\na\na\n"

	invalidRequests := []struct {
		desc    string
		name    string
		message string
	}{
		{"empty name", "", "test"},
		{"empty message", "test", ""},
		{"empty name and message", "", ""},
		{"message with many lines", "test", manyLines},
	}
	for _, tc := range invalidRequests {
		// Subtests run synchronously (no t.Parallel), so capturing tc is safe.
		t.Run(tc.desc, func(t *testing.T) {
			rec := postGuestbook(handler, tc.name, tc.message)
			if rec.Code != http.StatusBadRequest {
				t.Errorf("expected status code 400, got %d", rec.Code)
			}
		})
	}

	if n := guestbookEntryCount(t, db); n != 0 {
		t.Errorf("expected 0 entries, got %d", n)
	}
}