package database import ( "database/sql" "log" "time" _ "github.com/mattn/go-sqlite3" ) const ( ExpiryDuration = time.Hour * 24 ) type UserApiKey struct { Key string `json:"key"` UserID string `json:"user_id"` CreatedAt time.Time `json:"created_at"` } type User struct { ID string `json:"sub"` Mail string `json:"email"` Username string `json:"preferred_username"` DisplayName string `json:"name"` CreatedAt time.Time `json:"created_at"` } type UserSession struct { ID string `json:"id"` UserID string `json:"user_id"` ExpireAt time.Time `json:"expire_at"` } func GetUser(dbConn *sql.DB, id string) (*User, error) { log.Println("getting user", id) row := dbConn.QueryRow(`SELECT id, mail, username, display_name, created_at FROM users WHERE id = ?;`, id) var user User err := row.Scan(&user.ID, &user.Mail, &user.Username, &user.DisplayName, &user.CreatedAt) if err != nil { log.Println(err) return nil, err } return &user, nil } func FindOrSaveUser(dbConn *sql.DB, user *User) (*User, error) { log.Println("finding or saving user", user.ID) _, err := dbConn.Exec(`INSERT INTO users (id, mail, username, display_name) VALUES (?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET mail = excluded.mail, username = excluded.username, display_name = excluded.display_name;`, user.ID, user.Mail, user.Username, user.DisplayName) if err != nil { return nil, err } return user, nil } func MakeUserSessionFor(dbConn *sql.DB, user *User) (*UserSession, error) { log.Println("making session for user", user.ID) expireAt := time.Now().Add(time.Hour * 12) _, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, user.ID, user.ID, time.Now().Add(ExpiryDuration)) if err != nil { log.Println(err) return nil, err } return &UserSession{ ID: user.ID, UserID: user.ID, ExpireAt: expireAt, }, nil } func GetSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { log.Println("getting session", sessionId) row := dbConn.QueryRow(`SELECT id, user_id, expire_at FROM user_sessions WHERE id = ?;`, sessionId) var id, userId string var expireAt time.Time err := row.Scan(&id, &userId, &expireAt) if err != nil { log.Println(err) return nil, err } return &UserSession{ ID: id, UserID: userId, ExpireAt: expireAt, }, nil } func DeleteSession(dbConn *sql.DB, sessionId string) error { log.Println("deleting session", sessionId) _, err := dbConn.Exec(`DELETE FROM user_sessions WHERE id = ?;`, sessionId) if err != nil { log.Println(err) return err } return nil } func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { newExpireAt := time.Now().Add(ExpiryDuration) _, err := dbConn.Exec(`UPDATE user_sessions SET expire_at = ? WHERE id = ?;`, newExpireAt, sessionId) if err != nil { log.Println(err) return nil, err } session, err := GetSession(dbConn, sessionId) if err != nil { log.Println(err) return nil, err } return session, nil } func DeleteExpiredSessions(dbConn *sql.DB) error { _, err := dbConn.Exec(`DELETE FROM user_sessions WHERE expire_at < ?;`, time.Now()) if err != nil { log.Println(err) return err } return nil } func CountUserAPIKeys(dbConn *sql.DB, userId string) (int, error) { log.Println("counting api keys for user", userId) row := dbConn.QueryRow(`SELECT COUNT(*) FROM api_keys WHERE user_id = ?;`, userId) var count int err := row.Scan(&count) if err != nil { log.Println(err) return 0, err } return count, nil } func ListUserAPIKeys(dbConn *sql.DB, userId string) ([]*UserApiKey, error) { log.Println("listing api keys for user", userId) rows, err := dbConn.Query(`SELECT key, user_id, created_at FROM api_keys WHERE user_id = ?;`, userId) if err != nil { log.Println(err) return nil, err } defer rows.Close() var apiKeys []*UserApiKey for rows.Next() { var apiKey UserApiKey err := rows.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt) if err != nil { log.Println(err) return nil, err } apiKeys = append(apiKeys, &apiKey) } return apiKeys, nil } func SaveAPIKey(dbConn *sql.DB, apiKey *UserApiKey) (*UserApiKey, error) { log.Println("saving api key", apiKey.Key) _, err := dbConn.Exec(`INSERT OR REPLACE INTO api_keys (key, user_id) VALUES (?, ?);`, apiKey.Key, apiKey.UserID) if err != nil { log.Println(err) return nil, err } apiKey.CreatedAt = time.Now() return apiKey, nil } func GetAPIKey(dbConn *sql.DB, key string) (*UserApiKey, error) { log.Println("getting api key", key) row := dbConn.QueryRow(`SELECT key, user_id, created_at FROM api_keys WHERE key = ?;`, key) var apiKey UserApiKey err := row.Scan(&apiKey.Key, &apiKey.UserID, &apiKey.CreatedAt) if err != nil { log.Println(err) return nil, err } return &apiKey, nil } func DeleteAPIKey(dbConn *sql.DB, key string) error { log.Println("deleting api key", key) _, err := dbConn.Exec(`DELETE FROM api_keys WHERE key = ?;`, key) if err != nil { log.Println(err) return err } return nil }