From 4ccdaa85b491bb7bb8cfff2c7aced556a06e67ae Mon Sep 17 00:00:00 2001 From: Ruidy Date: Mon, 22 Sep 2025 17:55:57 +0200 Subject: [PATCH] feat: add sql-backed user store --- cmd/server/main.go | 15 +- go.mod | 2 + internal/driver/db/queries/users.sql | 8 +- internal/driver/db/users.sql.go | 56 +++---- internal/server/handler_login_google.go | 12 +- internal/server/server.go | 43 +++--- internal/server/server_test.go | 9 +- internal/service/auth/service.go | 16 +- internal/service/auth/service_test.go | 18 ++- internal/service/auth/store.go | 5 +- internal/service/auth/store_sql.go | 176 ++++++++++++++++++++++ internal/service/auth/store_sql_test.go | 186 ++++++++++++++++++++++++ internal/service/auth/user.go | 25 ++-- 13 files changed, 474 insertions(+), 97 deletions(-) create mode 100644 internal/service/auth/store_sql.go create mode 100644 internal/service/auth/store_sql_test.go diff --git a/cmd/server/main.go b/cmd/server/main.go index f912773..8332c1e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,14 +1,17 @@ package main import ( + "context" "fmt" "log/slog" "net/http" "os" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/server" + "github.com/rjnemo/auth/internal/service/auth" ) func main() { @@ -30,7 +33,17 @@ func main() { } func run(cfg *config.Config, logger *slog.Logger) error { - srv, err := server.New(*cfg, logger) + ctx := context.Background() + pool, err := pgxpool.New(ctx, cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("connect database: %w", err) + } + defer pool.Close() + + store := auth.NewSQLStore(pool) + service := auth.NewService(store) + + srv, err := server.New(*cfg, service, logger) if err != nil { return fmt.Errorf("initialise server: %w", err) } diff --git a/go.mod b/go.mod index 1cbbd39..fc37dbe 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,9 @@ require ( cloud.google.com/go/compute/metadata v0.8.4 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.24.0 // indirect ) diff --git a/internal/driver/db/queries/users.sql b/internal/driver/db/queries/users.sql index 8214f5c..d411360 100644 --- a/internal/driver/db/queries/users.sql +++ b/internal/driver/db/queries/users.sql @@ -1,14 +1,14 @@ -- name: CreateUser :one -INSERT INTO users (email, display_name) +INSERT INTO users (id, email) VALUES ($1, $2) -RETURNING id, email, display_name, created_at; +RETURNING id, email, created_at; -- name: GetUserByID :one -SELECT id, email, display_name, created_at +SELECT id, email, created_at FROM users WHERE id = $1; -- name: GetUserByEmail :one -SELECT id, email, display_name, created_at +SELECT id, email, created_at FROM users WHERE email = $1; diff --git a/internal/driver/db/users.sql.go b/internal/driver/db/users.sql.go index e4804e3..c7e6d7b 100644 --- a/internal/driver/db/users.sql.go +++ b/internal/driver/db/users.sql.go @@ -13,81 +13,63 @@ import ( ) const createUser = `-- name: CreateUser :one -INSERT INTO users (email, display_name) +INSERT INTO users (id, email) VALUES ($1, $2) -RETURNING id, email, display_name, created_at +RETURNING id, email, created_at ` type CreateUserParams struct { - Email string `json:"email"` - DisplayName pgtype.Text `json:"display_name"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` } type CreateUserRow struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - DisplayName pgtype.Text `json:"display_name"` - CreatedAt pgtype.Timestamptz `json:"created_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) { - row := q.db.QueryRow(ctx, createUser, arg.Email, arg.DisplayName) + row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email) var i CreateUserRow - err := row.Scan( - &i.ID, - &i.Email, - &i.DisplayName, - &i.CreatedAt, - ) + err := row.Scan(&i.ID, &i.Email, &i.CreatedAt) return i, err } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, display_name, created_at +SELECT id, email, created_at FROM users WHERE email = $1 ` type GetUserByEmailRow struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - DisplayName pgtype.Text `json:"display_name"` - CreatedAt pgtype.Timestamptz `json:"created_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) { row := q.db.QueryRow(ctx, getUserByEmail, email) var i GetUserByEmailRow - err := row.Scan( - &i.ID, - &i.Email, - &i.DisplayName, - &i.CreatedAt, - ) + err := row.Scan(&i.ID, &i.Email, &i.CreatedAt) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, display_name, created_at +SELECT id, email, created_at FROM users WHERE id = $1 ` type GetUserByIDRow struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - DisplayName pgtype.Text `json:"display_name"` - CreatedAt pgtype.Timestamptz `json:"created_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) { row := q.db.QueryRow(ctx, getUserByID, id) var i GetUserByIDRow - err := row.Scan( - &i.ID, - &i.Email, - &i.DisplayName, - &i.CreatedAt, - ) + err := row.Scan(&i.ID, &i.Email, &i.CreatedAt) return i, err } diff --git a/internal/server/handler_login_google.go b/internal/server/handler_login_google.go index 721c6ac..c21e4e2 100644 --- a/internal/server/handler_login_google.go +++ b/internal/server/handler_login_google.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/http" + "strings" "golang.org/x/oauth2" @@ -152,7 +153,16 @@ func (s *Server) googleCallbackHandler() http.HandlerFunc { return } - account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle) + if strings.TrimSpace(info.ID) == "" { + logger.Warn("google returned empty subject") + if !saveState() { + return + } + respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg) + return + } + + account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle, info.ID, info.VerifiedEmail) if err != nil { logger.Error("ensure external user failed", slog.Any("error", err)) if !saveState() { diff --git a/internal/server/server.go b/internal/server/server.go index 20aab2e..3756bd4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,11 +2,11 @@ package server import ( "context" + "errors" "fmt" "html/template" "io" "log/slog" - "time" "github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/driver/logging" @@ -31,8 +31,16 @@ type Server struct { googleOAuth *oauth2.Config } -// New constructs a Server with parsed templates and default state. -func New(cfg config.Config, logger *slog.Logger) (*Server, error) { +// New constructs a Server with parsed templates and default state using the provided service. +func New(cfg config.Config, authService *auth.Service, logger *slog.Logger) (*Server, error) { + if authService == nil { + return nil, fmt.Errorf("auth service must be provided") + } + + if err := seedUser(context.Background(), authService); err != nil { + return nil, fmt.Errorf("seed user: %w", err) + } + tmpl, err := template.ParseFS( web.Templates, "templates/auth_base.html", @@ -45,11 +53,6 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) { return nil, fmt.Errorf("parse templates: %w", err) } - store := auth.NewMemoryStore() - if err := seedUser(store); err != nil { - return nil, fmt.Errorf("seed user: %w", err) - } - sessionStore, err := NewSessionStore(cfg.SessionSecret) if err != nil { return nil, fmt.Errorf("session store: %w", err) @@ -76,7 +79,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) { return &Server{ templates: tmpl, - authService: auth.NewService(store), + authService: authService, sessions: sessionStore, logger: logger, configuration: cfg, @@ -84,21 +87,13 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) { }, nil } -func seedUser(store auth.UserStore) error { - salt, hash, err := auth.HashPassword(seedPassword) - if err != nil { +func seedUser(ctx context.Context, service *auth.Service) error { + email := auth.MustUserEmail(seedEmail) + if _, err := service.Register(ctx, email, seedPassword); err != nil { + if errors.Is(err, auth.ErrEmailExists) { + return nil + } return err } - - email := auth.MustUserEmail(seedEmail) - - ctx := context.Background() - return store.Create(ctx, auth.User{ - ID: "seed-user", - Email: email, - PasswordSalt: salt, - PasswordHash: hash, - Provider: auth.ProviderPassword, - CreatedAt: time.Now().UTC(), - }) + return nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ecd9d84..f1f8fdf 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -11,6 +11,7 @@ import ( "github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/driver/logging" + "github.com/rjnemo/auth/internal/service/auth" ) func newTestServer(t *testing.T) *Server { @@ -26,7 +27,9 @@ func newTestServer(t *testing.T) *Server { logger := logging.New(io.Discard, logging.ModeText, nil) - srv, err := New(cfg, logger) + store := auth.NewMemoryStore() + service := auth.NewService(store) + srv, err := New(cfg, service, logger) if err != nil { t.Fatalf("new server: %v", err) } @@ -51,7 +54,9 @@ func newGoogleTestServer(t *testing.T) *Server { logger := logging.New(io.Discard, logging.ModeText, nil) - srv, err := New(cfg, logger) + store := auth.NewMemoryStore() + service := auth.NewService(store) + srv, err := New(cfg, service, logger) if err != nil { t.Fatalf("new google server: %v", err) } diff --git a/internal/service/auth/service.go b/internal/service/auth/service.go index 5832909..592de60 100644 --- a/internal/service/auth/service.go +++ b/internal/service/auth/service.go @@ -20,7 +20,6 @@ var ( ) const ( - userIDByteLength = 16 // ProviderPassword identifies accounts managed via email/password. ProviderPassword = "password" // ProviderGoogle identifies accounts authenticated via Google OAuth2. @@ -112,13 +111,16 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string } // EnsureExternalUser retrieves or provisions an account authenticated by an external provider. -func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) { +func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider, subject string, verified bool) (*User, error) { if email.IsZero() { return nil, ErrInvalidInput } if strings.TrimSpace(provider) == "" { return nil, ErrProviderRequired } + if strings.TrimSpace(subject) == "" { + return nil, ErrSubjectRequired + } account, err := s.store.FindByEmail(ctx, email) switch { @@ -134,10 +136,12 @@ func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provi } user := User{ - ID: id, - Email: email, - Provider: provider, - CreatedAt: time.Now().UTC(), + ID: id, + Email: email, + Provider: provider, + OAuthSubject: subject, + OAuthEmailVerified: verified, + CreatedAt: time.Now().UTC(), } if err := s.store.Create(ctx, user); err != nil { diff --git a/internal/service/auth/service_test.go b/internal/service/auth/service_test.go index b873b98..8e42d6d 100644 --- a/internal/service/auth/service_test.go +++ b/internal/service/auth/service_test.go @@ -175,20 +175,23 @@ func TestServiceEnsureExternalUser(t *testing.T) { service := NewService(store) googleEmail := MustUserEmail("google@example.com") - if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil { + if err := store.Create(ctx, User{ID: "existing-google", Email: googleEmail, Provider: ProviderGoogle, OAuthSubject: "existing-sub"}); err != nil { t.Fatalf("seed external user: %v", err) } tests := map[string]struct { email UserEmail provider string + subject string + verified bool wantErr error wantNew bool }{ - "missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput}, - "missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired}, - "existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false}, - "provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true}, + "missing email": {email: UserEmail(""), provider: ProviderGoogle, subject: "sub", wantErr: ErrInvalidInput}, + "missing provider": {email: MustUserEmail("new@example.com"), provider: "", subject: "sub", wantErr: ErrProviderRequired}, + "missing subject": {email: MustUserEmail("new@example.com"), provider: ProviderGoogle, subject: "", wantErr: ErrSubjectRequired}, + "existing": {email: googleEmail, provider: ProviderGoogle, subject: "existing-sub", wantErr: nil, wantNew: false}, + "provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, subject: "new-sub", verified: true, wantErr: nil, wantNew: true}, } for name, tc := range tests { @@ -196,7 +199,7 @@ func TestServiceEnsureExternalUser(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider) + user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider, tc.subject, tc.verified) if tc.wantErr != nil { if !errors.Is(err, tc.wantErr) { t.Fatalf("expected %v, got %v", tc.wantErr, err) @@ -222,6 +225,9 @@ func TestServiceEnsureExternalUser(t *testing.T) { if err != nil { t.Fatalf("expected user persisted: %v", err) } + if persisted.OAuthSubject != "" && persisted.OAuthSubject != tc.subject { + t.Fatalf("expected oauth subject %q, got %q", tc.subject, persisted.OAuthSubject) + } if tc.wantNew && persisted.CreatedAt.IsZero() { t.Fatal("expected created at timestamp for new user") } diff --git a/internal/service/auth/store.go b/internal/service/auth/store.go index 3f26d4e..81772a7 100644 --- a/internal/service/auth/store.go +++ b/internal/service/auth/store.go @@ -7,8 +7,9 @@ import ( // ErrUserNotFound signals no user exists for the provided lookup criteria. var ( - ErrUserNotFound = errors.New("auth: user not found") - ErrEmailRequired = errors.New("auth: email required") + ErrUserNotFound = errors.New("auth: user not found") + ErrEmailRequired = errors.New("auth: email required") + ErrSubjectRequired = errors.New("auth: oauth subject required") ) // UserStore defines persistence expectations for user lookups. diff --git a/internal/service/auth/store_sql.go b/internal/service/auth/store_sql.go new file mode 100644 index 0000000..da4c8b8 --- /dev/null +++ b/internal/service/auth/store_sql.go @@ -0,0 +1,176 @@ +package auth + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/rjnemo/auth/internal/driver/db" +) + +const passwordAlgorithm = "sha256" + +// SQLStore persists users in PostgreSQL via generated sqlc queries. +type SQLStore struct { + pool *pgxpool.Pool + queries *db.Queries +} + +// NewSQLStore builds a SQL-backed user store. +func NewSQLStore(pool *pgxpool.Pool) *SQLStore { + return &SQLStore{ + pool: pool, + queries: db.New(pool), + } +} + +// FindByEmail returns the stored user aggregate by canonical email address. +func (s *SQLStore) FindByEmail(ctx context.Context, email UserEmail) (*User, error) { + row, err := s.queries.GetUserByEmail(ctx, email.String()) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("lookup user: %w", err) + } + + normalizedEmail, err := NewUserEmail(row.Email) + if err != nil { + return nil, fmt.Errorf("normalize email: %w", err) + } + + user := &User{ + ID: row.ID.String(), + Email: normalizedEmail, + CreatedAt: timestamptzValue(row.CreatedAt), + } + + if pw, err := s.queries.GetUserPassword(ctx, row.ID); err == nil { + user.PasswordSalt = base64.StdEncoding.EncodeToString(pw.PasswordSalt) + user.PasswordHash = base64.StdEncoding.EncodeToString(pw.PasswordHash) + user.Provider = ProviderPassword + } else if !errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("load password: %w", err) + } + + oauthAccounts, err := s.queries.ListUserOAuthAccountsByUserID(ctx, row.ID) + if err != nil { + return nil, fmt.Errorf("load oauth accounts: %w", err) + } + + if len(oauthAccounts) > 0 { + acct := oauthAccounts[0] + if user.Provider == "" { + user.Provider = acct.Provider + } + user.OAuthSubject = acct.Subject + user.OAuthEmailVerified = acct.EmailVerified + } + + if user.Provider == "" { + user.Provider = ProviderPassword + } + + return user, nil +} + +// Create writes a new user aggregate to persistent storage. +func (s *SQLStore) Create(ctx context.Context, user User) error { + if user.Email.IsZero() { + return ErrEmailRequired + } + + id, err := uuid.Parse(user.ID) + if err != nil { + return fmt.Errorf("parse user id: %w", err) + } + + tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer func() { + if err != nil { + _ = tx.Rollback(ctx) + } + }() + + qtx := s.queries.WithTx(tx) + + if _, err = qtx.CreateUser(ctx, db.CreateUserParams{ID: id, Email: user.Email.String()}); err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return ErrEmailExists + } + return fmt.Errorf("insert user: %w", err) + } + + switch user.Provider { + case ProviderPassword: + if user.PasswordHash == "" || user.PasswordSalt == "" { + return fmt.Errorf("password credentials required") + } + hashBytes, err := base64.StdEncoding.DecodeString(user.PasswordHash) + if err != nil { + return fmt.Errorf("decode password hash: %w", err) + } + saltBytes, err := base64.StdEncoding.DecodeString(user.PasswordSalt) + if err != nil { + return fmt.Errorf("decode password salt: %w", err) + } + + if err := qtx.CreateUserPassword(ctx, db.CreateUserPasswordParams{ + UserID: id, + PasswordHash: hashBytes, + PasswordSalt: saltBytes, + Algorithm: passwordAlgorithm, + }); err != nil { + return fmt.Errorf("insert password: %w", err) + } + default: + if user.OAuthSubject == "" { + return ErrSubjectRequired + } + + var emailValue pgtype.Text + if !user.Email.IsZero() { + emailValue = pgtype.Text{String: user.Email.String(), Valid: true} + } + + if _, err := qtx.CreateUserOAuthAccount(ctx, db.CreateUserOAuthAccountParams{ + UserID: id, + Provider: user.Provider, + Subject: user.OAuthSubject, + Email: emailValue, + EmailVerified: user.OAuthEmailVerified, + Profile: nil, + }); err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return ErrEmailExists + } + return fmt.Errorf("insert oauth account: %w", err) + } + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + + return nil +} + +func timestamptzValue(ts pgtype.Timestamptz) time.Time { + if !ts.Valid { + return time.Time{} + } + return ts.Time +} diff --git a/internal/service/auth/store_sql_test.go b/internal/service/auth/store_sql_test.go new file mode 100644 index 0000000..445c2d9 --- /dev/null +++ b/internal/service/auth/store_sql_test.go @@ -0,0 +1,186 @@ +package auth + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + schemaUpSQL = ` +CREATE EXTENSION IF NOT EXISTS pgcrypto; +CREATE EXTENSION IF NOT EXISTS citext; + +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + email CITEXT NOT NULL UNIQUE, + display_name TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE user_passwords ( + user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + password_hash BYTEA NOT NULL, + password_salt BYTEA NOT NULL, + algorithm TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE user_oauth_accounts ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider TEXT NOT NULL, + subject TEXT NOT NULL, + email TEXT, + email_verified BOOLEAN NOT NULL DEFAULT false, + profile JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE UNIQUE INDEX user_oauth_accounts_provider_subject_idx + ON user_oauth_accounts (provider, subject); + +CREATE INDEX user_oauth_accounts_user_id_idx + ON user_oauth_accounts (user_id); + +CREATE TABLE login_events ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID REFERENCES users(id), + provider TEXT, + success BOOLEAN NOT NULL, + ip INET, + user_agent TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX login_events_user_id_idx ON login_events (user_id); +CREATE INDEX login_events_created_at_idx ON login_events (created_at); +` + + schemaDownSQL = ` +DROP TABLE IF EXISTS login_events; +DROP TABLE IF EXISTS user_oauth_accounts; +DROP TABLE IF EXISTS user_passwords; +DROP TABLE IF EXISTS users; +DROP EXTENSION IF EXISTS citext; +DROP EXTENSION IF EXISTS pgcrypto; +` +) + +func TestSQLStoreIntegration(t *testing.T) { + dsn := os.Getenv("AUTH_DATABASE_URL") + if strings.TrimSpace(dsn) == "" { + t.Skip("AUTH_DATABASE_URL not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Fatalf("connect database: %v", err) + } + t.Cleanup(func() { pool.Close() }) + + resetDatabase(t, ctx, pool) + + t.Run("register and authenticate", func(t *testing.T) { + resetDatabase(t, ctx, pool) + + store := NewSQLStore(pool) + service := NewService(store) + + email := MustUserEmail("sql-user@example.com") + + user, err := service.Register(ctx, email, "Password123") + if err != nil { + t.Fatalf("register user: %v", err) + } + if user.ID == "" { + t.Fatal("expected user id") + } + if user.Provider != ProviderPassword { + t.Fatalf("expected provider %q, got %q", ProviderPassword, user.Provider) + } + + authenticated, err := service.Authenticate(ctx, email, "Password123") + if err != nil { + t.Fatalf("authenticate user: %v", err) + } + if authenticated.ID != user.ID { + t.Fatalf("expected matching user id, got %q", authenticated.ID) + } + if authenticated.PasswordHash == "" || authenticated.PasswordSalt == "" { + t.Fatal("expected persisted password credentials") + } + }) + + t.Run("ensure external user", func(t *testing.T) { + resetDatabase(t, ctx, pool) + + store := NewSQLStore(pool) + service := NewService(store) + + email := MustUserEmail("sql-google@example.com") + subject := "google-subject-123" + + account, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true) + if err != nil { + t.Fatalf("ensure external user: %v", err) + } + if account.Provider != ProviderGoogle { + t.Fatalf("expected provider %q, got %q", ProviderGoogle, account.Provider) + } + if account.OAuthSubject != subject { + t.Fatalf("expected oauth subject %q, got %q", subject, account.OAuthSubject) + } + + again, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true) + if err != nil { + t.Fatalf("ensure existing external user: %v", err) + } + if again.ID != account.ID { + t.Fatalf("expected same user id, got %q vs %q", again.ID, account.ID) + } + }) +} + +func resetDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + t.Helper() + + execStatements := func(stmts []string) { + for _, stmt := range stmts { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, execErr := pool.Exec(ctx, stmt); execErr != nil { + t.Fatalf("exec statement %q: %v", stmt, execErr) + } + } + } + + execStatements(splitSQLStatements(schemaDownSQL)) + execStatements(splitSQLStatements(schemaUpSQL)) +} + +func splitSQLStatements(section string) []string { + section = strings.TrimSpace(section) + if section == "" { + return nil + } + + parts := strings.Split(section, ";") + statements := make([]string, 0, len(parts)) + for _, part := range parts { + stmt := strings.TrimSpace(part) + if stmt == "" { + continue + } + statements = append(statements, stmt+";") + } + + return statements +} diff --git a/internal/service/auth/user.go b/internal/service/auth/user.go index f6a6e7b..222b48e 100644 --- a/internal/service/auth/user.go +++ b/internal/service/auth/user.go @@ -1,21 +1,23 @@ package auth import ( - "crypto/rand" - "encoding/base64" "errors" "strings" "time" + + "github.com/google/uuid" ) // User represents authenticated account details. type User struct { - ID string - Email UserEmail - PasswordSalt string - PasswordHash string - Provider string - CreatedAt time.Time + ID string + Email UserEmail + PasswordSalt string + PasswordHash string + Provider string + OAuthSubject string + OAuthEmailVerified bool + CreatedAt time.Time } // UserEmail represents a canonical email string. @@ -49,11 +51,6 @@ func (e UserEmail) IsZero() bool { return e == "" } -// TODO: could be UUID. return a dedicated type func generateUserID() (string, error) { - buf := make([]byte, userIDByteLength) - if _, err := rand.Read(buf); err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(buf), nil + return uuid.NewString(), nil }