mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 00:16:40 +00:00
feat: add sql-backed user store
This commit is contained in:
parent
29fb3054a5
commit
4ccdaa85b4
13 changed files with 474 additions and 97 deletions
|
|
@ -1,14 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rjnemo/auth/internal/config"
|
||||
"github.com/rjnemo/auth/internal/driver/logging"
|
||||
"github.com/rjnemo/auth/internal/server"
|
||||
"github.com/rjnemo/auth/internal/service/auth"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
@ -30,7 +33,17 @@ func main() {
|
|||
}
|
||||
|
||||
func run(cfg *config.Config, logger *slog.Logger) error {
|
||||
srv, err := server.New(*cfg, logger)
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect database: %w", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
store := auth.NewSQLStore(pool)
|
||||
service := auth.NewService(store)
|
||||
|
||||
srv, err := server.New(*cfg, service, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialise server: %w", err)
|
||||
}
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -13,7 +13,9 @@ require (
|
|||
cloud.google.com/go/compute/metadata v0.8.4 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
-- name: CreateUser :one
|
||||
INSERT INTO users (email, display_name)
|
||||
INSERT INTO users (id, email)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, email, display_name, created_at;
|
||||
RETURNING id, email, created_at;
|
||||
|
||||
-- name: GetUserByID :one
|
||||
SELECT id, email, display_name, created_at
|
||||
SELECT id, email, created_at
|
||||
FROM users
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: GetUserByEmail :one
|
||||
SELECT id, email, display_name, created_at
|
||||
SELECT id, email, created_at
|
||||
FROM users
|
||||
WHERE email = $1;
|
||||
|
|
|
|||
|
|
@ -13,81 +13,63 @@ import (
|
|||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (email, display_name)
|
||||
INSERT INTO users (id, email)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, email, display_name, created_at
|
||||
RETURNING id, email, created_at
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
Email string `json:"email"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type CreateUserRow struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) {
|
||||
row := q.db.QueryRow(ctx, createUser, arg.Email, arg.DisplayName)
|
||||
row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email)
|
||||
var i CreateUserRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.DisplayName,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||
SELECT id, email, display_name, created_at
|
||||
SELECT id, email, created_at
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
`
|
||||
|
||||
type GetUserByEmailRow struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
||||
var i GetUserByEmailRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.DisplayName,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, email, display_name, created_at
|
||||
SELECT id, email, created_at
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type GetUserByIDRow struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i GetUserByIDRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.DisplayName,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||
return i, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
|
@ -152,7 +153,16 @@ func (s *Server) googleCallbackHandler() http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle)
|
||||
if strings.TrimSpace(info.ID) == "" {
|
||||
logger.Warn("google returned empty subject")
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle, info.ID, info.VerifiedEmail)
|
||||
if err != nil {
|
||||
logger.Error("ensure external user failed", slog.Any("error", err))
|
||||
if !saveState() {
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/rjnemo/auth/internal/config"
|
||||
"github.com/rjnemo/auth/internal/driver/logging"
|
||||
|
|
@ -31,8 +31,16 @@ type Server struct {
|
|||
googleOAuth *oauth2.Config
|
||||
}
|
||||
|
||||
// New constructs a Server with parsed templates and default state.
|
||||
func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||
// New constructs a Server with parsed templates and default state using the provided service.
|
||||
func New(cfg config.Config, authService *auth.Service, logger *slog.Logger) (*Server, error) {
|
||||
if authService == nil {
|
||||
return nil, fmt.Errorf("auth service must be provided")
|
||||
}
|
||||
|
||||
if err := seedUser(context.Background(), authService); err != nil {
|
||||
return nil, fmt.Errorf("seed user: %w", err)
|
||||
}
|
||||
|
||||
tmpl, err := template.ParseFS(
|
||||
web.Templates,
|
||||
"templates/auth_base.html",
|
||||
|
|
@ -45,11 +53,6 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
|||
return nil, fmt.Errorf("parse templates: %w", err)
|
||||
}
|
||||
|
||||
store := auth.NewMemoryStore()
|
||||
if err := seedUser(store); err != nil {
|
||||
return nil, fmt.Errorf("seed user: %w", err)
|
||||
}
|
||||
|
||||
sessionStore, err := NewSessionStore(cfg.SessionSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session store: %w", err)
|
||||
|
|
@ -76,7 +79,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
|||
|
||||
return &Server{
|
||||
templates: tmpl,
|
||||
authService: auth.NewService(store),
|
||||
authService: authService,
|
||||
sessions: sessionStore,
|
||||
logger: logger,
|
||||
configuration: cfg,
|
||||
|
|
@ -84,21 +87,13 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func seedUser(store auth.UserStore) error {
|
||||
salt, hash, err := auth.HashPassword(seedPassword)
|
||||
if err != nil {
|
||||
func seedUser(ctx context.Context, service *auth.Service) error {
|
||||
email := auth.MustUserEmail(seedEmail)
|
||||
if _, err := service.Register(ctx, email, seedPassword); err != nil {
|
||||
if errors.Is(err, auth.ErrEmailExists) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
email := auth.MustUserEmail(seedEmail)
|
||||
|
||||
ctx := context.Background()
|
||||
return store.Create(ctx, auth.User{
|
||||
ID: "seed-user",
|
||||
Email: email,
|
||||
PasswordSalt: salt,
|
||||
PasswordHash: hash,
|
||||
Provider: auth.ProviderPassword,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/rjnemo/auth/internal/config"
|
||||
"github.com/rjnemo/auth/internal/driver/logging"
|
||||
"github.com/rjnemo/auth/internal/service/auth"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
|
|
@ -26,7 +27,9 @@ func newTestServer(t *testing.T) *Server {
|
|||
|
||||
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
store := auth.NewMemoryStore()
|
||||
service := auth.NewService(store)
|
||||
srv, err := New(cfg, service, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("new server: %v", err)
|
||||
}
|
||||
|
|
@ -51,7 +54,9 @@ func newGoogleTestServer(t *testing.T) *Server {
|
|||
|
||||
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
store := auth.NewMemoryStore()
|
||||
service := auth.NewService(store)
|
||||
srv, err := New(cfg, service, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("new google server: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ var (
|
|||
)
|
||||
|
||||
const (
|
||||
userIDByteLength = 16
|
||||
// ProviderPassword identifies accounts managed via email/password.
|
||||
ProviderPassword = "password"
|
||||
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
|
||||
|
|
@ -112,13 +111,16 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
|
|||
}
|
||||
|
||||
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
|
||||
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) {
|
||||
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider, subject string, verified bool) (*User, error) {
|
||||
if email.IsZero() {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
if strings.TrimSpace(provider) == "" {
|
||||
return nil, ErrProviderRequired
|
||||
}
|
||||
if strings.TrimSpace(subject) == "" {
|
||||
return nil, ErrSubjectRequired
|
||||
}
|
||||
|
||||
account, err := s.store.FindByEmail(ctx, email)
|
||||
switch {
|
||||
|
|
@ -134,10 +136,12 @@ func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provi
|
|||
}
|
||||
|
||||
user := User{
|
||||
ID: id,
|
||||
Email: email,
|
||||
Provider: provider,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ID: id,
|
||||
Email: email,
|
||||
Provider: provider,
|
||||
OAuthSubject: subject,
|
||||
OAuthEmailVerified: verified,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.store.Create(ctx, user); err != nil {
|
||||
|
|
|
|||
|
|
@ -175,20 +175,23 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
|||
service := NewService(store)
|
||||
|
||||
googleEmail := MustUserEmail("google@example.com")
|
||||
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil {
|
||||
if err := store.Create(ctx, User{ID: "existing-google", Email: googleEmail, Provider: ProviderGoogle, OAuthSubject: "existing-sub"}); err != nil {
|
||||
t.Fatalf("seed external user: %v", err)
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
email UserEmail
|
||||
provider string
|
||||
subject string
|
||||
verified bool
|
||||
wantErr error
|
||||
wantNew bool
|
||||
}{
|
||||
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput},
|
||||
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired},
|
||||
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false},
|
||||
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true},
|
||||
"missing email": {email: UserEmail(""), provider: ProviderGoogle, subject: "sub", wantErr: ErrInvalidInput},
|
||||
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", subject: "sub", wantErr: ErrProviderRequired},
|
||||
"missing subject": {email: MustUserEmail("new@example.com"), provider: ProviderGoogle, subject: "", wantErr: ErrSubjectRequired},
|
||||
"existing": {email: googleEmail, provider: ProviderGoogle, subject: "existing-sub", wantErr: nil, wantNew: false},
|
||||
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, subject: "new-sub", verified: true, wantErr: nil, wantNew: true},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
|
|
@ -196,7 +199,7 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider)
|
||||
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider, tc.subject, tc.verified)
|
||||
if tc.wantErr != nil {
|
||||
if !errors.Is(err, tc.wantErr) {
|
||||
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||
|
|
@ -222,6 +225,9 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("expected user persisted: %v", err)
|
||||
}
|
||||
if persisted.OAuthSubject != "" && persisted.OAuthSubject != tc.subject {
|
||||
t.Fatalf("expected oauth subject %q, got %q", tc.subject, persisted.OAuthSubject)
|
||||
}
|
||||
if tc.wantNew && persisted.CreatedAt.IsZero() {
|
||||
t.Fatal("expected created at timestamp for new user")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ import (
|
|||
|
||||
// ErrUserNotFound signals no user exists for the provided lookup criteria.
|
||||
var (
|
||||
ErrUserNotFound = errors.New("auth: user not found")
|
||||
ErrEmailRequired = errors.New("auth: email required")
|
||||
ErrUserNotFound = errors.New("auth: user not found")
|
||||
ErrEmailRequired = errors.New("auth: email required")
|
||||
ErrSubjectRequired = errors.New("auth: oauth subject required")
|
||||
)
|
||||
|
||||
// UserStore defines persistence expectations for user lookups.
|
||||
|
|
|
|||
176
internal/service/auth/store_sql.go
Normal file
176
internal/service/auth/store_sql.go
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/rjnemo/auth/internal/driver/db"
|
||||
)
|
||||
|
||||
const passwordAlgorithm = "sha256"
|
||||
|
||||
// SQLStore persists users in PostgreSQL via generated sqlc queries.
|
||||
type SQLStore struct {
|
||||
pool *pgxpool.Pool
|
||||
queries *db.Queries
|
||||
}
|
||||
|
||||
// NewSQLStore builds a SQL-backed user store.
|
||||
func NewSQLStore(pool *pgxpool.Pool) *SQLStore {
|
||||
return &SQLStore{
|
||||
pool: pool,
|
||||
queries: db.New(pool),
|
||||
}
|
||||
}
|
||||
|
||||
// FindByEmail returns the stored user aggregate by canonical email address.
|
||||
func (s *SQLStore) FindByEmail(ctx context.Context, email UserEmail) (*User, error) {
|
||||
row, err := s.queries.GetUserByEmail(ctx, email.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
normalizedEmail, err := NewUserEmail(row.Email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("normalize email: %w", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
ID: row.ID.String(),
|
||||
Email: normalizedEmail,
|
||||
CreatedAt: timestamptzValue(row.CreatedAt),
|
||||
}
|
||||
|
||||
if pw, err := s.queries.GetUserPassword(ctx, row.ID); err == nil {
|
||||
user.PasswordSalt = base64.StdEncoding.EncodeToString(pw.PasswordSalt)
|
||||
user.PasswordHash = base64.StdEncoding.EncodeToString(pw.PasswordHash)
|
||||
user.Provider = ProviderPassword
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("load password: %w", err)
|
||||
}
|
||||
|
||||
oauthAccounts, err := s.queries.ListUserOAuthAccountsByUserID(ctx, row.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load oauth accounts: %w", err)
|
||||
}
|
||||
|
||||
if len(oauthAccounts) > 0 {
|
||||
acct := oauthAccounts[0]
|
||||
if user.Provider == "" {
|
||||
user.Provider = acct.Provider
|
||||
}
|
||||
user.OAuthSubject = acct.Subject
|
||||
user.OAuthEmailVerified = acct.EmailVerified
|
||||
}
|
||||
|
||||
if user.Provider == "" {
|
||||
user.Provider = ProviderPassword
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Create writes a new user aggregate to persistent storage.
|
||||
func (s *SQLStore) Create(ctx context.Context, user User) error {
|
||||
if user.Email.IsZero() {
|
||||
return ErrEmailRequired
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(user.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse user id: %w", err)
|
||||
}
|
||||
|
||||
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
qtx := s.queries.WithTx(tx)
|
||||
|
||||
if _, err = qtx.CreateUser(ctx, db.CreateUserParams{ID: id, Email: user.Email.String()}); err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
|
||||
switch user.Provider {
|
||||
case ProviderPassword:
|
||||
if user.PasswordHash == "" || user.PasswordSalt == "" {
|
||||
return fmt.Errorf("password credentials required")
|
||||
}
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(user.PasswordHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode password hash: %w", err)
|
||||
}
|
||||
saltBytes, err := base64.StdEncoding.DecodeString(user.PasswordSalt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode password salt: %w", err)
|
||||
}
|
||||
|
||||
if err := qtx.CreateUserPassword(ctx, db.CreateUserPasswordParams{
|
||||
UserID: id,
|
||||
PasswordHash: hashBytes,
|
||||
PasswordSalt: saltBytes,
|
||||
Algorithm: passwordAlgorithm,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("insert password: %w", err)
|
||||
}
|
||||
default:
|
||||
if user.OAuthSubject == "" {
|
||||
return ErrSubjectRequired
|
||||
}
|
||||
|
||||
var emailValue pgtype.Text
|
||||
if !user.Email.IsZero() {
|
||||
emailValue = pgtype.Text{String: user.Email.String(), Valid: true}
|
||||
}
|
||||
|
||||
if _, err := qtx.CreateUserOAuthAccount(ctx, db.CreateUserOAuthAccountParams{
|
||||
UserID: id,
|
||||
Provider: user.Provider,
|
||||
Subject: user.OAuthSubject,
|
||||
Email: emailValue,
|
||||
EmailVerified: user.OAuthEmailVerified,
|
||||
Profile: nil,
|
||||
}); err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return fmt.Errorf("insert oauth account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func timestamptzValue(ts pgtype.Timestamptz) time.Time {
|
||||
if !ts.Valid {
|
||||
return time.Time{}
|
||||
}
|
||||
return ts.Time
|
||||
}
|
||||
186
internal/service/auth/store_sql_test.go
Normal file
186
internal/service/auth/store_sql_test.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const (
|
||||
schemaUpSQL = `
|
||||
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
CREATE EXTENSION IF NOT EXISTS citext;
|
||||
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email CITEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE user_passwords (
|
||||
user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
|
||||
password_hash BYTEA NOT NULL,
|
||||
password_salt BYTEA NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE user_oauth_accounts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
provider TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
email TEXT,
|
||||
email_verified BOOLEAN NOT NULL DEFAULT false,
|
||||
profile JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX user_oauth_accounts_provider_subject_idx
|
||||
ON user_oauth_accounts (provider, subject);
|
||||
|
||||
CREATE INDEX user_oauth_accounts_user_id_idx
|
||||
ON user_oauth_accounts (user_id);
|
||||
|
||||
CREATE TABLE login_events (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID REFERENCES users(id),
|
||||
provider TEXT,
|
||||
success BOOLEAN NOT NULL,
|
||||
ip INET,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX login_events_user_id_idx ON login_events (user_id);
|
||||
CREATE INDEX login_events_created_at_idx ON login_events (created_at);
|
||||
`
|
||||
|
||||
schemaDownSQL = `
|
||||
DROP TABLE IF EXISTS login_events;
|
||||
DROP TABLE IF EXISTS user_oauth_accounts;
|
||||
DROP TABLE IF EXISTS user_passwords;
|
||||
DROP TABLE IF EXISTS users;
|
||||
DROP EXTENSION IF EXISTS citext;
|
||||
DROP EXTENSION IF EXISTS pgcrypto;
|
||||
`
|
||||
)
|
||||
|
||||
func TestSQLStoreIntegration(t *testing.T) {
|
||||
dsn := os.Getenv("AUTH_DATABASE_URL")
|
||||
if strings.TrimSpace(dsn) == "" {
|
||||
t.Skip("AUTH_DATABASE_URL not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("connect database: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { pool.Close() })
|
||||
|
||||
resetDatabase(t, ctx, pool)
|
||||
|
||||
t.Run("register and authenticate", func(t *testing.T) {
|
||||
resetDatabase(t, ctx, pool)
|
||||
|
||||
store := NewSQLStore(pool)
|
||||
service := NewService(store)
|
||||
|
||||
email := MustUserEmail("sql-user@example.com")
|
||||
|
||||
user, err := service.Register(ctx, email, "Password123")
|
||||
if err != nil {
|
||||
t.Fatalf("register user: %v", err)
|
||||
}
|
||||
if user.ID == "" {
|
||||
t.Fatal("expected user id")
|
||||
}
|
||||
if user.Provider != ProviderPassword {
|
||||
t.Fatalf("expected provider %q, got %q", ProviderPassword, user.Provider)
|
||||
}
|
||||
|
||||
authenticated, err := service.Authenticate(ctx, email, "Password123")
|
||||
if err != nil {
|
||||
t.Fatalf("authenticate user: %v", err)
|
||||
}
|
||||
if authenticated.ID != user.ID {
|
||||
t.Fatalf("expected matching user id, got %q", authenticated.ID)
|
||||
}
|
||||
if authenticated.PasswordHash == "" || authenticated.PasswordSalt == "" {
|
||||
t.Fatal("expected persisted password credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ensure external user", func(t *testing.T) {
|
||||
resetDatabase(t, ctx, pool)
|
||||
|
||||
store := NewSQLStore(pool)
|
||||
service := NewService(store)
|
||||
|
||||
email := MustUserEmail("sql-google@example.com")
|
||||
subject := "google-subject-123"
|
||||
|
||||
account, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
|
||||
if err != nil {
|
||||
t.Fatalf("ensure external user: %v", err)
|
||||
}
|
||||
if account.Provider != ProviderGoogle {
|
||||
t.Fatalf("expected provider %q, got %q", ProviderGoogle, account.Provider)
|
||||
}
|
||||
if account.OAuthSubject != subject {
|
||||
t.Fatalf("expected oauth subject %q, got %q", subject, account.OAuthSubject)
|
||||
}
|
||||
|
||||
again, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
|
||||
if err != nil {
|
||||
t.Fatalf("ensure existing external user: %v", err)
|
||||
}
|
||||
if again.ID != account.ID {
|
||||
t.Fatalf("expected same user id, got %q vs %q", again.ID, account.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func resetDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
t.Helper()
|
||||
|
||||
execStatements := func(stmts []string) {
|
||||
for _, stmt := range stmts {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, execErr := pool.Exec(ctx, stmt); execErr != nil {
|
||||
t.Fatalf("exec statement %q: %v", stmt, execErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
execStatements(splitSQLStatements(schemaDownSQL))
|
||||
execStatements(splitSQLStatements(schemaUpSQL))
|
||||
}
|
||||
|
||||
func splitSQLStatements(section string) []string {
|
||||
section = strings.TrimSpace(section)
|
||||
if section == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(section, ";")
|
||||
statements := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
stmt := strings.TrimSpace(part)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
statements = append(statements, stmt+";")
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
|
@ -1,21 +1,23 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// User represents authenticated account details.
|
||||
type User struct {
|
||||
ID string
|
||||
Email UserEmail
|
||||
PasswordSalt string
|
||||
PasswordHash string
|
||||
Provider string
|
||||
CreatedAt time.Time
|
||||
ID string
|
||||
Email UserEmail
|
||||
PasswordSalt string
|
||||
PasswordHash string
|
||||
Provider string
|
||||
OAuthSubject string
|
||||
OAuthEmailVerified bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// UserEmail represents a canonical email string.
|
||||
|
|
@ -49,11 +51,6 @@ func (e UserEmail) IsZero() bool {
|
|||
return e == ""
|
||||
}
|
||||
|
||||
// TODO: could be UUID. return a dedicated type
|
||||
func generateUserID() (string, error) {
|
||||
buf := make([]byte, userIDByteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
return uuid.NewString(), nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue