feat: add sql-backed user store

This commit is contained in:
Ruidy 2025-09-22 17:55:57 +02:00
parent 29fb3054a5
commit 4ccdaa85b4
No known key found for this signature in database
GPG key ID: 705C24D202990805
13 changed files with 474 additions and 97 deletions

View file

@ -1,14 +1,17 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/config"
"github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/driver/logging"
"github.com/rjnemo/auth/internal/server" "github.com/rjnemo/auth/internal/server"
"github.com/rjnemo/auth/internal/service/auth"
) )
func main() { func main() {
@ -30,7 +33,17 @@ func main() {
} }
func run(cfg *config.Config, logger *slog.Logger) error { func run(cfg *config.Config, logger *slog.Logger) error {
srv, err := server.New(*cfg, logger) ctx := context.Background()
pool, err := pgxpool.New(ctx, cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("connect database: %w", err)
}
defer pool.Close()
store := auth.NewSQLStore(pool)
service := auth.NewService(store)
srv, err := server.New(*cfg, service, logger)
if err != nil { if err != nil {
return fmt.Errorf("initialise server: %w", err) return fmt.Errorf("initialise server: %w", err)
} }

2
go.mod
View file

@ -13,7 +13,9 @@ require (
cloud.google.com/go/compute/metadata v0.8.4 // indirect cloud.google.com/go/compute/metadata v0.8.4 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/crypto v0.37.0 // indirect golang.org/x/crypto v0.37.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.24.0 // indirect
) )

View file

@ -1,14 +1,14 @@
-- name: CreateUser :one -- name: CreateUser :one
INSERT INTO users (email, display_name) INSERT INTO users (id, email)
VALUES ($1, $2) VALUES ($1, $2)
RETURNING id, email, display_name, created_at; RETURNING id, email, created_at;
-- name: GetUserByID :one -- name: GetUserByID :one
SELECT id, email, display_name, created_at SELECT id, email, created_at
FROM users FROM users
WHERE id = $1; WHERE id = $1;
-- name: GetUserByEmail :one -- name: GetUserByEmail :one
SELECT id, email, display_name, created_at SELECT id, email, created_at
FROM users FROM users
WHERE email = $1; WHERE email = $1;

View file

@ -13,37 +13,31 @@ import (
) )
const createUser = `-- name: CreateUser :one const createUser = `-- name: CreateUser :one
INSERT INTO users (email, display_name) INSERT INTO users (id, email)
VALUES ($1, $2) VALUES ($1, $2)
RETURNING id, email, display_name, created_at RETURNING id, email, created_at
` `
type CreateUserParams struct { type CreateUserParams struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
DisplayName pgtype.Text `json:"display_name"`
} }
type CreateUserRow struct { type CreateUserRow struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
DisplayName pgtype.Text `json:"display_name"`
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
} }
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) { func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) {
row := q.db.QueryRow(ctx, createUser, arg.Email, arg.DisplayName) row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email)
var i CreateUserRow var i CreateUserRow
err := row.Scan( err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
&i.ID,
&i.Email,
&i.DisplayName,
&i.CreatedAt,
)
return i, err return i, err
} }
const getUserByEmail = `-- name: GetUserByEmail :one const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, display_name, created_at SELECT id, email, created_at
FROM users FROM users
WHERE email = $1 WHERE email = $1
` `
@ -51,24 +45,18 @@ WHERE email = $1
type GetUserByEmailRow struct { type GetUserByEmailRow struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
DisplayName pgtype.Text `json:"display_name"`
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
} }
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) { func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) {
row := q.db.QueryRow(ctx, getUserByEmail, email) row := q.db.QueryRow(ctx, getUserByEmail, email)
var i GetUserByEmailRow var i GetUserByEmailRow
err := row.Scan( err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
&i.ID,
&i.Email,
&i.DisplayName,
&i.CreatedAt,
)
return i, err return i, err
} }
const getUserByID = `-- name: GetUserByID :one const getUserByID = `-- name: GetUserByID :one
SELECT id, email, display_name, created_at SELECT id, email, created_at
FROM users FROM users
WHERE id = $1 WHERE id = $1
` `
@ -76,18 +64,12 @@ WHERE id = $1
type GetUserByIDRow struct { type GetUserByIDRow struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
DisplayName pgtype.Text `json:"display_name"`
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
} }
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) { func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) {
row := q.db.QueryRow(ctx, getUserByID, id) row := q.db.QueryRow(ctx, getUserByID, id)
var i GetUserByIDRow var i GetUserByIDRow
err := row.Scan( err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
&i.ID,
&i.Email,
&i.DisplayName,
&i.CreatedAt,
)
return i, err return i, err
} }

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -152,7 +153,16 @@ func (s *Server) googleCallbackHandler() http.HandlerFunc {
return return
} }
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle) if strings.TrimSpace(info.ID) == "" {
logger.Warn("google returned empty subject")
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle, info.ID, info.VerifiedEmail)
if err != nil { if err != nil {
logger.Error("ensure external user failed", slog.Any("error", err)) logger.Error("ensure external user failed", slog.Any("error", err))
if !saveState() { if !saveState() {

View file

@ -2,11 +2,11 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"html/template" "html/template"
"io" "io"
"log/slog" "log/slog"
"time"
"github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/config"
"github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/driver/logging"
@ -31,8 +31,16 @@ type Server struct {
googleOAuth *oauth2.Config googleOAuth *oauth2.Config
} }
// New constructs a Server with parsed templates and default state. // New constructs a Server with parsed templates and default state using the provided service.
func New(cfg config.Config, logger *slog.Logger) (*Server, error) { func New(cfg config.Config, authService *auth.Service, logger *slog.Logger) (*Server, error) {
if authService == nil {
return nil, fmt.Errorf("auth service must be provided")
}
if err := seedUser(context.Background(), authService); err != nil {
return nil, fmt.Errorf("seed user: %w", err)
}
tmpl, err := template.ParseFS( tmpl, err := template.ParseFS(
web.Templates, web.Templates,
"templates/auth_base.html", "templates/auth_base.html",
@ -45,11 +53,6 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
return nil, fmt.Errorf("parse templates: %w", err) return nil, fmt.Errorf("parse templates: %w", err)
} }
store := auth.NewMemoryStore()
if err := seedUser(store); err != nil {
return nil, fmt.Errorf("seed user: %w", err)
}
sessionStore, err := NewSessionStore(cfg.SessionSecret) sessionStore, err := NewSessionStore(cfg.SessionSecret)
if err != nil { if err != nil {
return nil, fmt.Errorf("session store: %w", err) return nil, fmt.Errorf("session store: %w", err)
@ -76,7 +79,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
return &Server{ return &Server{
templates: tmpl, templates: tmpl,
authService: auth.NewService(store), authService: authService,
sessions: sessionStore, sessions: sessionStore,
logger: logger, logger: logger,
configuration: cfg, configuration: cfg,
@ -84,21 +87,13 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
}, nil }, nil
} }
func seedUser(store auth.UserStore) error { func seedUser(ctx context.Context, service *auth.Service) error {
salt, hash, err := auth.HashPassword(seedPassword) email := auth.MustUserEmail(seedEmail)
if err != nil { if _, err := service.Register(ctx, email, seedPassword); err != nil {
if errors.Is(err, auth.ErrEmailExists) {
return nil
}
return err return err
} }
return nil
email := auth.MustUserEmail(seedEmail)
ctx := context.Background()
return store.Create(ctx, auth.User{
ID: "seed-user",
Email: email,
PasswordSalt: salt,
PasswordHash: hash,
Provider: auth.ProviderPassword,
CreatedAt: time.Now().UTC(),
})
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/rjnemo/auth/internal/config" "github.com/rjnemo/auth/internal/config"
"github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/driver/logging"
"github.com/rjnemo/auth/internal/service/auth"
) )
func newTestServer(t *testing.T) *Server { func newTestServer(t *testing.T) *Server {
@ -26,7 +27,9 @@ func newTestServer(t *testing.T) *Server {
logger := logging.New(io.Discard, logging.ModeText, nil) logger := logging.New(io.Discard, logging.ModeText, nil)
srv, err := New(cfg, logger) store := auth.NewMemoryStore()
service := auth.NewService(store)
srv, err := New(cfg, service, logger)
if err != nil { if err != nil {
t.Fatalf("new server: %v", err) t.Fatalf("new server: %v", err)
} }
@ -51,7 +54,9 @@ func newGoogleTestServer(t *testing.T) *Server {
logger := logging.New(io.Discard, logging.ModeText, nil) logger := logging.New(io.Discard, logging.ModeText, nil)
srv, err := New(cfg, logger) store := auth.NewMemoryStore()
service := auth.NewService(store)
srv, err := New(cfg, service, logger)
if err != nil { if err != nil {
t.Fatalf("new google server: %v", err) t.Fatalf("new google server: %v", err)
} }

View file

@ -20,7 +20,6 @@ var (
) )
const ( const (
userIDByteLength = 16
// ProviderPassword identifies accounts managed via email/password. // ProviderPassword identifies accounts managed via email/password.
ProviderPassword = "password" ProviderPassword = "password"
// ProviderGoogle identifies accounts authenticated via Google OAuth2. // ProviderGoogle identifies accounts authenticated via Google OAuth2.
@ -112,13 +111,16 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
} }
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider. // EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) { func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider, subject string, verified bool) (*User, error) {
if email.IsZero() { if email.IsZero() {
return nil, ErrInvalidInput return nil, ErrInvalidInput
} }
if strings.TrimSpace(provider) == "" { if strings.TrimSpace(provider) == "" {
return nil, ErrProviderRequired return nil, ErrProviderRequired
} }
if strings.TrimSpace(subject) == "" {
return nil, ErrSubjectRequired
}
account, err := s.store.FindByEmail(ctx, email) account, err := s.store.FindByEmail(ctx, email)
switch { switch {
@ -137,6 +139,8 @@ func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provi
ID: id, ID: id,
Email: email, Email: email,
Provider: provider, Provider: provider,
OAuthSubject: subject,
OAuthEmailVerified: verified,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }

View file

@ -175,20 +175,23 @@ func TestServiceEnsureExternalUser(t *testing.T) {
service := NewService(store) service := NewService(store)
googleEmail := MustUserEmail("google@example.com") googleEmail := MustUserEmail("google@example.com")
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil { if err := store.Create(ctx, User{ID: "existing-google", Email: googleEmail, Provider: ProviderGoogle, OAuthSubject: "existing-sub"}); err != nil {
t.Fatalf("seed external user: %v", err) t.Fatalf("seed external user: %v", err)
} }
tests := map[string]struct { tests := map[string]struct {
email UserEmail email UserEmail
provider string provider string
subject string
verified bool
wantErr error wantErr error
wantNew bool wantNew bool
}{ }{
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput}, "missing email": {email: UserEmail(""), provider: ProviderGoogle, subject: "sub", wantErr: ErrInvalidInput},
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired}, "missing provider": {email: MustUserEmail("new@example.com"), provider: "", subject: "sub", wantErr: ErrProviderRequired},
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false}, "missing subject": {email: MustUserEmail("new@example.com"), provider: ProviderGoogle, subject: "", wantErr: ErrSubjectRequired},
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true}, "existing": {email: googleEmail, provider: ProviderGoogle, subject: "existing-sub", wantErr: nil, wantNew: false},
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, subject: "new-sub", verified: true, wantErr: nil, wantNew: true},
} }
for name, tc := range tests { for name, tc := range tests {
@ -196,7 +199,7 @@ func TestServiceEnsureExternalUser(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider) user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider, tc.subject, tc.verified)
if tc.wantErr != nil { if tc.wantErr != nil {
if !errors.Is(err, tc.wantErr) { if !errors.Is(err, tc.wantErr) {
t.Fatalf("expected %v, got %v", tc.wantErr, err) t.Fatalf("expected %v, got %v", tc.wantErr, err)
@ -222,6 +225,9 @@ func TestServiceEnsureExternalUser(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("expected user persisted: %v", err) t.Fatalf("expected user persisted: %v", err)
} }
if persisted.OAuthSubject != "" && persisted.OAuthSubject != tc.subject {
t.Fatalf("expected oauth subject %q, got %q", tc.subject, persisted.OAuthSubject)
}
if tc.wantNew && persisted.CreatedAt.IsZero() { if tc.wantNew && persisted.CreatedAt.IsZero() {
t.Fatal("expected created at timestamp for new user") t.Fatal("expected created at timestamp for new user")
} }

View file

@ -9,6 +9,7 @@ import (
var ( var (
ErrUserNotFound = errors.New("auth: user not found") ErrUserNotFound = errors.New("auth: user not found")
ErrEmailRequired = errors.New("auth: email required") ErrEmailRequired = errors.New("auth: email required")
ErrSubjectRequired = errors.New("auth: oauth subject required")
) )
// UserStore defines persistence expectations for user lookups. // UserStore defines persistence expectations for user lookups.

View file

@ -0,0 +1,176 @@
package auth
import (
"context"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rjnemo/auth/internal/driver/db"
)
const passwordAlgorithm = "sha256"
// SQLStore persists users in PostgreSQL via generated sqlc queries.
type SQLStore struct {
pool *pgxpool.Pool
queries *db.Queries
}
// NewSQLStore builds a SQL-backed user store.
func NewSQLStore(pool *pgxpool.Pool) *SQLStore {
return &SQLStore{
pool: pool,
queries: db.New(pool),
}
}
// FindByEmail returns the stored user aggregate by canonical email address.
func (s *SQLStore) FindByEmail(ctx context.Context, email UserEmail) (*User, error) {
row, err := s.queries.GetUserByEmail(ctx, email.String())
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("lookup user: %w", err)
}
normalizedEmail, err := NewUserEmail(row.Email)
if err != nil {
return nil, fmt.Errorf("normalize email: %w", err)
}
user := &User{
ID: row.ID.String(),
Email: normalizedEmail,
CreatedAt: timestamptzValue(row.CreatedAt),
}
if pw, err := s.queries.GetUserPassword(ctx, row.ID); err == nil {
user.PasswordSalt = base64.StdEncoding.EncodeToString(pw.PasswordSalt)
user.PasswordHash = base64.StdEncoding.EncodeToString(pw.PasswordHash)
user.Provider = ProviderPassword
} else if !errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("load password: %w", err)
}
oauthAccounts, err := s.queries.ListUserOAuthAccountsByUserID(ctx, row.ID)
if err != nil {
return nil, fmt.Errorf("load oauth accounts: %w", err)
}
if len(oauthAccounts) > 0 {
acct := oauthAccounts[0]
if user.Provider == "" {
user.Provider = acct.Provider
}
user.OAuthSubject = acct.Subject
user.OAuthEmailVerified = acct.EmailVerified
}
if user.Provider == "" {
user.Provider = ProviderPassword
}
return user, nil
}
// Create writes a new user aggregate to persistent storage.
func (s *SQLStore) Create(ctx context.Context, user User) error {
if user.Email.IsZero() {
return ErrEmailRequired
}
id, err := uuid.Parse(user.ID)
if err != nil {
return fmt.Errorf("parse user id: %w", err)
}
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer func() {
if err != nil {
_ = tx.Rollback(ctx)
}
}()
qtx := s.queries.WithTx(tx)
if _, err = qtx.CreateUser(ctx, db.CreateUserParams{ID: id, Email: user.Email.String()}); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return ErrEmailExists
}
return fmt.Errorf("insert user: %w", err)
}
switch user.Provider {
case ProviderPassword:
if user.PasswordHash == "" || user.PasswordSalt == "" {
return fmt.Errorf("password credentials required")
}
hashBytes, err := base64.StdEncoding.DecodeString(user.PasswordHash)
if err != nil {
return fmt.Errorf("decode password hash: %w", err)
}
saltBytes, err := base64.StdEncoding.DecodeString(user.PasswordSalt)
if err != nil {
return fmt.Errorf("decode password salt: %w", err)
}
if err := qtx.CreateUserPassword(ctx, db.CreateUserPasswordParams{
UserID: id,
PasswordHash: hashBytes,
PasswordSalt: saltBytes,
Algorithm: passwordAlgorithm,
}); err != nil {
return fmt.Errorf("insert password: %w", err)
}
default:
if user.OAuthSubject == "" {
return ErrSubjectRequired
}
var emailValue pgtype.Text
if !user.Email.IsZero() {
emailValue = pgtype.Text{String: user.Email.String(), Valid: true}
}
if _, err := qtx.CreateUserOAuthAccount(ctx, db.CreateUserOAuthAccountParams{
UserID: id,
Provider: user.Provider,
Subject: user.OAuthSubject,
Email: emailValue,
EmailVerified: user.OAuthEmailVerified,
Profile: nil,
}); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return ErrEmailExists
}
return fmt.Errorf("insert oauth account: %w", err)
}
}
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}
return nil
}
func timestamptzValue(ts pgtype.Timestamptz) time.Time {
if !ts.Valid {
return time.Time{}
}
return ts.Time
}

View file

@ -0,0 +1,186 @@
package auth
import (
"context"
"os"
"strings"
"testing"
"github.com/jackc/pgx/v5/pgxpool"
)
const (
schemaUpSQL = `
CREATE EXTENSION IF NOT EXISTS pgcrypto;
CREATE EXTENSION IF NOT EXISTS citext;
CREATE TABLE users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
email CITEXT NOT NULL UNIQUE,
display_name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE user_passwords (
user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
password_hash BYTEA NOT NULL,
password_salt BYTEA NOT NULL,
algorithm TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE user_oauth_accounts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider TEXT NOT NULL,
subject TEXT NOT NULL,
email TEXT,
email_verified BOOLEAN NOT NULL DEFAULT false,
profile JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE UNIQUE INDEX user_oauth_accounts_provider_subject_idx
ON user_oauth_accounts (provider, subject);
CREATE INDEX user_oauth_accounts_user_id_idx
ON user_oauth_accounts (user_id);
CREATE TABLE login_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID REFERENCES users(id),
provider TEXT,
success BOOLEAN NOT NULL,
ip INET,
user_agent TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX login_events_user_id_idx ON login_events (user_id);
CREATE INDEX login_events_created_at_idx ON login_events (created_at);
`
schemaDownSQL = `
DROP TABLE IF EXISTS login_events;
DROP TABLE IF EXISTS user_oauth_accounts;
DROP TABLE IF EXISTS user_passwords;
DROP TABLE IF EXISTS users;
DROP EXTENSION IF EXISTS citext;
DROP EXTENSION IF EXISTS pgcrypto;
`
)
func TestSQLStoreIntegration(t *testing.T) {
dsn := os.Getenv("AUTH_DATABASE_URL")
if strings.TrimSpace(dsn) == "" {
t.Skip("AUTH_DATABASE_URL not set")
}
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatalf("connect database: %v", err)
}
t.Cleanup(func() { pool.Close() })
resetDatabase(t, ctx, pool)
t.Run("register and authenticate", func(t *testing.T) {
resetDatabase(t, ctx, pool)
store := NewSQLStore(pool)
service := NewService(store)
email := MustUserEmail("sql-user@example.com")
user, err := service.Register(ctx, email, "Password123")
if err != nil {
t.Fatalf("register user: %v", err)
}
if user.ID == "" {
t.Fatal("expected user id")
}
if user.Provider != ProviderPassword {
t.Fatalf("expected provider %q, got %q", ProviderPassword, user.Provider)
}
authenticated, err := service.Authenticate(ctx, email, "Password123")
if err != nil {
t.Fatalf("authenticate user: %v", err)
}
if authenticated.ID != user.ID {
t.Fatalf("expected matching user id, got %q", authenticated.ID)
}
if authenticated.PasswordHash == "" || authenticated.PasswordSalt == "" {
t.Fatal("expected persisted password credentials")
}
})
t.Run("ensure external user", func(t *testing.T) {
resetDatabase(t, ctx, pool)
store := NewSQLStore(pool)
service := NewService(store)
email := MustUserEmail("sql-google@example.com")
subject := "google-subject-123"
account, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
if err != nil {
t.Fatalf("ensure external user: %v", err)
}
if account.Provider != ProviderGoogle {
t.Fatalf("expected provider %q, got %q", ProviderGoogle, account.Provider)
}
if account.OAuthSubject != subject {
t.Fatalf("expected oauth subject %q, got %q", subject, account.OAuthSubject)
}
again, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
if err != nil {
t.Fatalf("ensure existing external user: %v", err)
}
if again.ID != account.ID {
t.Fatalf("expected same user id, got %q vs %q", again.ID, account.ID)
}
})
}
func resetDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
t.Helper()
execStatements := func(stmts []string) {
for _, stmt := range stmts {
if strings.TrimSpace(stmt) == "" {
continue
}
if _, execErr := pool.Exec(ctx, stmt); execErr != nil {
t.Fatalf("exec statement %q: %v", stmt, execErr)
}
}
}
execStatements(splitSQLStatements(schemaDownSQL))
execStatements(splitSQLStatements(schemaUpSQL))
}
func splitSQLStatements(section string) []string {
section = strings.TrimSpace(section)
if section == "" {
return nil
}
parts := strings.Split(section, ";")
statements := make([]string, 0, len(parts))
for _, part := range parts {
stmt := strings.TrimSpace(part)
if stmt == "" {
continue
}
statements = append(statements, stmt+";")
}
return statements
}

View file

@ -1,11 +1,11 @@
package auth package auth
import ( import (
"crypto/rand"
"encoding/base64"
"errors" "errors"
"strings" "strings"
"time" "time"
"github.com/google/uuid"
) )
// User represents authenticated account details. // User represents authenticated account details.
@ -15,6 +15,8 @@ type User struct {
PasswordSalt string PasswordSalt string
PasswordHash string PasswordHash string
Provider string Provider string
OAuthSubject string
OAuthEmailVerified bool
CreatedAt time.Time CreatedAt time.Time
} }
@ -49,11 +51,6 @@ func (e UserEmail) IsZero() bool {
return e == "" return e == ""
} }
// TODO: could be UUID. return a dedicated type
func generateUserID() (string, error) { func generateUserID() (string, error) {
buf := make([]byte, userIDByteLength) return uuid.NewString(), nil
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
} }