mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 00:16:40 +00:00
feat: add sql-backed user store
This commit is contained in:
parent
29fb3054a5
commit
4ccdaa85b4
13 changed files with 474 additions and 97 deletions
|
|
@ -1,14 +1,17 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"github.com/rjnemo/auth/internal/config"
|
"github.com/rjnemo/auth/internal/config"
|
||||||
"github.com/rjnemo/auth/internal/driver/logging"
|
"github.com/rjnemo/auth/internal/driver/logging"
|
||||||
"github.com/rjnemo/auth/internal/server"
|
"github.com/rjnemo/auth/internal/server"
|
||||||
|
"github.com/rjnemo/auth/internal/service/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -30,7 +33,17 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func run(cfg *config.Config, logger *slog.Logger) error {
|
func run(cfg *config.Config, logger *slog.Logger) error {
|
||||||
srv, err := server.New(*cfg, logger)
|
ctx := context.Background()
|
||||||
|
pool, err := pgxpool.New(ctx, cfg.DatabaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect database: %w", err)
|
||||||
|
}
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
store := auth.NewSQLStore(pool)
|
||||||
|
service := auth.NewService(store)
|
||||||
|
|
||||||
|
srv, err := server.New(*cfg, service, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initialise server: %w", err)
|
return fmt.Errorf("initialise server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -13,7 +13,9 @@ require (
|
||||||
cloud.google.com/go/compute/metadata v0.8.4 // indirect
|
cloud.google.com/go/compute/metadata v0.8.4 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
golang.org/x/crypto v0.37.0 // indirect
|
golang.org/x/crypto v0.37.0 // indirect
|
||||||
|
golang.org/x/sync v0.13.0 // indirect
|
||||||
golang.org/x/sys v0.36.0 // indirect
|
golang.org/x/sys v0.36.0 // indirect
|
||||||
golang.org/x/text v0.24.0 // indirect
|
golang.org/x/text v0.24.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
-- name: CreateUser :one
|
-- name: CreateUser :one
|
||||||
INSERT INTO users (email, display_name)
|
INSERT INTO users (id, email)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
RETURNING id, email, display_name, created_at;
|
RETURNING id, email, created_at;
|
||||||
|
|
||||||
-- name: GetUserByID :one
|
-- name: GetUserByID :one
|
||||||
SELECT id, email, display_name, created_at
|
SELECT id, email, created_at
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1;
|
WHERE id = $1;
|
||||||
|
|
||||||
-- name: GetUserByEmail :one
|
-- name: GetUserByEmail :one
|
||||||
SELECT id, email, display_name, created_at
|
SELECT id, email, created_at
|
||||||
FROM users
|
FROM users
|
||||||
WHERE email = $1;
|
WHERE email = $1;
|
||||||
|
|
|
||||||
|
|
@ -13,37 +13,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const createUser = `-- name: CreateUser :one
|
const createUser = `-- name: CreateUser :one
|
||||||
INSERT INTO users (email, display_name)
|
INSERT INTO users (id, email)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
RETURNING id, email, display_name, created_at
|
RETURNING id, email, created_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateUserParams struct {
|
type CreateUserParams struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
DisplayName pgtype.Text `json:"display_name"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateUserRow struct {
|
type CreateUserRow struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
DisplayName pgtype.Text `json:"display_name"`
|
|
||||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) {
|
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) {
|
||||||
row := q.db.QueryRow(ctx, createUser, arg.Email, arg.DisplayName)
|
row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email)
|
||||||
var i CreateUserRow
|
var i CreateUserRow
|
||||||
err := row.Scan(
|
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||||
&i.ID,
|
|
||||||
&i.Email,
|
|
||||||
&i.DisplayName,
|
|
||||||
&i.CreatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||||
SELECT id, email, display_name, created_at
|
SELECT id, email, created_at
|
||||||
FROM users
|
FROM users
|
||||||
WHERE email = $1
|
WHERE email = $1
|
||||||
`
|
`
|
||||||
|
|
@ -51,24 +45,18 @@ WHERE email = $1
|
||||||
type GetUserByEmailRow struct {
|
type GetUserByEmailRow struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
DisplayName pgtype.Text `json:"display_name"`
|
|
||||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) {
|
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (GetUserByEmailRow, error) {
|
||||||
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
||||||
var i GetUserByEmailRow
|
var i GetUserByEmailRow
|
||||||
err := row.Scan(
|
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||||
&i.ID,
|
|
||||||
&i.Email,
|
|
||||||
&i.DisplayName,
|
|
||||||
&i.CreatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getUserByID = `-- name: GetUserByID :one
|
const getUserByID = `-- name: GetUserByID :one
|
||||||
SELECT id, email, display_name, created_at
|
SELECT id, email, created_at
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
@ -76,18 +64,12 @@ WHERE id = $1
|
||||||
type GetUserByIDRow struct {
|
type GetUserByIDRow struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
DisplayName pgtype.Text `json:"display_name"`
|
|
||||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) {
|
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (GetUserByIDRow, error) {
|
||||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||||
var i GetUserByIDRow
|
var i GetUserByIDRow
|
||||||
err := row.Scan(
|
err := row.Scan(&i.ID, &i.Email, &i.CreatedAt)
|
||||||
&i.ID,
|
|
||||||
&i.Email,
|
|
||||||
&i.DisplayName,
|
|
||||||
&i.CreatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
|
@ -152,7 +153,16 @@ func (s *Server) googleCallbackHandler() http.HandlerFunc {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle)
|
if strings.TrimSpace(info.ID) == "" {
|
||||||
|
logger.Warn("google returned empty subject")
|
||||||
|
if !saveState() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle, info.ID, info.VerifiedEmail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("ensure external user failed", slog.Any("error", err))
|
logger.Error("ensure external user failed", slog.Any("error", err))
|
||||||
if !saveState() {
|
if !saveState() {
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,11 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rjnemo/auth/internal/config"
|
"github.com/rjnemo/auth/internal/config"
|
||||||
"github.com/rjnemo/auth/internal/driver/logging"
|
"github.com/rjnemo/auth/internal/driver/logging"
|
||||||
|
|
@ -31,8 +31,16 @@ type Server struct {
|
||||||
googleOAuth *oauth2.Config
|
googleOAuth *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a Server with parsed templates and default state.
|
// New constructs a Server with parsed templates and default state using the provided service.
|
||||||
func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
func New(cfg config.Config, authService *auth.Service, logger *slog.Logger) (*Server, error) {
|
||||||
|
if authService == nil {
|
||||||
|
return nil, fmt.Errorf("auth service must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := seedUser(context.Background(), authService); err != nil {
|
||||||
|
return nil, fmt.Errorf("seed user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
tmpl, err := template.ParseFS(
|
tmpl, err := template.ParseFS(
|
||||||
web.Templates,
|
web.Templates,
|
||||||
"templates/auth_base.html",
|
"templates/auth_base.html",
|
||||||
|
|
@ -45,11 +53,6 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||||
return nil, fmt.Errorf("parse templates: %w", err)
|
return nil, fmt.Errorf("parse templates: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := auth.NewMemoryStore()
|
|
||||||
if err := seedUser(store); err != nil {
|
|
||||||
return nil, fmt.Errorf("seed user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionStore, err := NewSessionStore(cfg.SessionSecret)
|
sessionStore, err := NewSessionStore(cfg.SessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("session store: %w", err)
|
return nil, fmt.Errorf("session store: %w", err)
|
||||||
|
|
@ -76,7 +79,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
templates: tmpl,
|
templates: tmpl,
|
||||||
authService: auth.NewService(store),
|
authService: authService,
|
||||||
sessions: sessionStore,
|
sessions: sessionStore,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
configuration: cfg,
|
configuration: cfg,
|
||||||
|
|
@ -84,21 +87,13 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func seedUser(store auth.UserStore) error {
|
func seedUser(ctx context.Context, service *auth.Service) error {
|
||||||
salt, hash, err := auth.HashPassword(seedPassword)
|
email := auth.MustUserEmail(seedEmail)
|
||||||
if err != nil {
|
if _, err := service.Register(ctx, email, seedPassword); err != nil {
|
||||||
|
if errors.Is(err, auth.ErrEmailExists) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
email := auth.MustUserEmail(seedEmail)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
return store.Create(ctx, auth.User{
|
|
||||||
ID: "seed-user",
|
|
||||||
Email: email,
|
|
||||||
PasswordSalt: salt,
|
|
||||||
PasswordHash: hash,
|
|
||||||
Provider: auth.ProviderPassword,
|
|
||||||
CreatedAt: time.Now().UTC(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/rjnemo/auth/internal/config"
|
"github.com/rjnemo/auth/internal/config"
|
||||||
"github.com/rjnemo/auth/internal/driver/logging"
|
"github.com/rjnemo/auth/internal/driver/logging"
|
||||||
|
"github.com/rjnemo/auth/internal/service/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestServer(t *testing.T) *Server {
|
func newTestServer(t *testing.T) *Server {
|
||||||
|
|
@ -26,7 +27,9 @@ func newTestServer(t *testing.T) *Server {
|
||||||
|
|
||||||
logger := logging.New(io.Discard, logging.ModeText, nil)
|
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||||
|
|
||||||
srv, err := New(cfg, logger)
|
store := auth.NewMemoryStore()
|
||||||
|
service := auth.NewService(store)
|
||||||
|
srv, err := New(cfg, service, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new server: %v", err)
|
t.Fatalf("new server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -51,7 +54,9 @@ func newGoogleTestServer(t *testing.T) *Server {
|
||||||
|
|
||||||
logger := logging.New(io.Discard, logging.ModeText, nil)
|
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||||
|
|
||||||
srv, err := New(cfg, logger)
|
store := auth.NewMemoryStore()
|
||||||
|
service := auth.NewService(store)
|
||||||
|
srv, err := New(cfg, service, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new google server: %v", err)
|
t.Fatalf("new google server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
userIDByteLength = 16
|
|
||||||
// ProviderPassword identifies accounts managed via email/password.
|
// ProviderPassword identifies accounts managed via email/password.
|
||||||
ProviderPassword = "password"
|
ProviderPassword = "password"
|
||||||
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
|
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
|
||||||
|
|
@ -112,13 +111,16 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
|
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
|
||||||
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) {
|
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider, subject string, verified bool) (*User, error) {
|
||||||
if email.IsZero() {
|
if email.IsZero() {
|
||||||
return nil, ErrInvalidInput
|
return nil, ErrInvalidInput
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(provider) == "" {
|
if strings.TrimSpace(provider) == "" {
|
||||||
return nil, ErrProviderRequired
|
return nil, ErrProviderRequired
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(subject) == "" {
|
||||||
|
return nil, ErrSubjectRequired
|
||||||
|
}
|
||||||
|
|
||||||
account, err := s.store.FindByEmail(ctx, email)
|
account, err := s.store.FindByEmail(ctx, email)
|
||||||
switch {
|
switch {
|
||||||
|
|
@ -137,6 +139,8 @@ func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provi
|
||||||
ID: id,
|
ID: id,
|
||||||
Email: email,
|
Email: email,
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
|
OAuthSubject: subject,
|
||||||
|
OAuthEmailVerified: verified,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -175,20 +175,23 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
||||||
service := NewService(store)
|
service := NewService(store)
|
||||||
|
|
||||||
googleEmail := MustUserEmail("google@example.com")
|
googleEmail := MustUserEmail("google@example.com")
|
||||||
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil {
|
if err := store.Create(ctx, User{ID: "existing-google", Email: googleEmail, Provider: ProviderGoogle, OAuthSubject: "existing-sub"}); err != nil {
|
||||||
t.Fatalf("seed external user: %v", err)
|
t.Fatalf("seed external user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
email UserEmail
|
email UserEmail
|
||||||
provider string
|
provider string
|
||||||
|
subject string
|
||||||
|
verified bool
|
||||||
wantErr error
|
wantErr error
|
||||||
wantNew bool
|
wantNew bool
|
||||||
}{
|
}{
|
||||||
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput},
|
"missing email": {email: UserEmail(""), provider: ProviderGoogle, subject: "sub", wantErr: ErrInvalidInput},
|
||||||
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired},
|
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", subject: "sub", wantErr: ErrProviderRequired},
|
||||||
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false},
|
"missing subject": {email: MustUserEmail("new@example.com"), provider: ProviderGoogle, subject: "", wantErr: ErrSubjectRequired},
|
||||||
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true},
|
"existing": {email: googleEmail, provider: ProviderGoogle, subject: "existing-sub", wantErr: nil, wantNew: false},
|
||||||
|
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, subject: "new-sub", verified: true, wantErr: nil, wantNew: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
|
|
@ -196,7 +199,7 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider)
|
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider, tc.subject, tc.verified)
|
||||||
if tc.wantErr != nil {
|
if tc.wantErr != nil {
|
||||||
if !errors.Is(err, tc.wantErr) {
|
if !errors.Is(err, tc.wantErr) {
|
||||||
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||||
|
|
@ -222,6 +225,9 @@ func TestServiceEnsureExternalUser(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected user persisted: %v", err)
|
t.Fatalf("expected user persisted: %v", err)
|
||||||
}
|
}
|
||||||
|
if persisted.OAuthSubject != "" && persisted.OAuthSubject != tc.subject {
|
||||||
|
t.Fatalf("expected oauth subject %q, got %q", tc.subject, persisted.OAuthSubject)
|
||||||
|
}
|
||||||
if tc.wantNew && persisted.CreatedAt.IsZero() {
|
if tc.wantNew && persisted.CreatedAt.IsZero() {
|
||||||
t.Fatal("expected created at timestamp for new user")
|
t.Fatal("expected created at timestamp for new user")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("auth: user not found")
|
ErrUserNotFound = errors.New("auth: user not found")
|
||||||
ErrEmailRequired = errors.New("auth: email required")
|
ErrEmailRequired = errors.New("auth: email required")
|
||||||
|
ErrSubjectRequired = errors.New("auth: oauth subject required")
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserStore defines persistence expectations for user lookups.
|
// UserStore defines persistence expectations for user lookups.
|
||||||
|
|
|
||||||
176
internal/service/auth/store_sql.go
Normal file
176
internal/service/auth/store_sql.go
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
|
"github.com/rjnemo/auth/internal/driver/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
const passwordAlgorithm = "sha256"
|
||||||
|
|
||||||
|
// SQLStore persists users in PostgreSQL via generated sqlc queries.
|
||||||
|
type SQLStore struct {
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
queries *db.Queries
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLStore builds a SQL-backed user store.
|
||||||
|
func NewSQLStore(pool *pgxpool.Pool) *SQLStore {
|
||||||
|
return &SQLStore{
|
||||||
|
pool: pool,
|
||||||
|
queries: db.New(pool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByEmail returns the stored user aggregate by canonical email address.
|
||||||
|
func (s *SQLStore) FindByEmail(ctx context.Context, email UserEmail) (*User, error) {
|
||||||
|
row, err := s.queries.GetUserByEmail(ctx, email.String())
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("lookup user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedEmail, err := NewUserEmail(row.Email)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("normalize email: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &User{
|
||||||
|
ID: row.ID.String(),
|
||||||
|
Email: normalizedEmail,
|
||||||
|
CreatedAt: timestamptzValue(row.CreatedAt),
|
||||||
|
}
|
||||||
|
|
||||||
|
if pw, err := s.queries.GetUserPassword(ctx, row.ID); err == nil {
|
||||||
|
user.PasswordSalt = base64.StdEncoding.EncodeToString(pw.PasswordSalt)
|
||||||
|
user.PasswordHash = base64.StdEncoding.EncodeToString(pw.PasswordHash)
|
||||||
|
user.Provider = ProviderPassword
|
||||||
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return nil, fmt.Errorf("load password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthAccounts, err := s.queries.ListUserOAuthAccountsByUserID(ctx, row.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load oauth accounts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(oauthAccounts) > 0 {
|
||||||
|
acct := oauthAccounts[0]
|
||||||
|
if user.Provider == "" {
|
||||||
|
user.Provider = acct.Provider
|
||||||
|
}
|
||||||
|
user.OAuthSubject = acct.Subject
|
||||||
|
user.OAuthEmailVerified = acct.EmailVerified
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Provider == "" {
|
||||||
|
user.Provider = ProviderPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writes a new user aggregate to persistent storage.
|
||||||
|
func (s *SQLStore) Create(ctx context.Context, user User) error {
|
||||||
|
if user.Email.IsZero() {
|
||||||
|
return ErrEmailRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := uuid.Parse(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse user id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
qtx := s.queries.WithTx(tx)
|
||||||
|
|
||||||
|
if _, err = qtx.CreateUser(ctx, db.CreateUserParams{ID: id, Email: user.Email.String()}); err != nil {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||||
|
return ErrEmailExists
|
||||||
|
}
|
||||||
|
return fmt.Errorf("insert user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch user.Provider {
|
||||||
|
case ProviderPassword:
|
||||||
|
if user.PasswordHash == "" || user.PasswordSalt == "" {
|
||||||
|
return fmt.Errorf("password credentials required")
|
||||||
|
}
|
||||||
|
hashBytes, err := base64.StdEncoding.DecodeString(user.PasswordHash)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode password hash: %w", err)
|
||||||
|
}
|
||||||
|
saltBytes, err := base64.StdEncoding.DecodeString(user.PasswordSalt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode password salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := qtx.CreateUserPassword(ctx, db.CreateUserPasswordParams{
|
||||||
|
UserID: id,
|
||||||
|
PasswordHash: hashBytes,
|
||||||
|
PasswordSalt: saltBytes,
|
||||||
|
Algorithm: passwordAlgorithm,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("insert password: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if user.OAuthSubject == "" {
|
||||||
|
return ErrSubjectRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
var emailValue pgtype.Text
|
||||||
|
if !user.Email.IsZero() {
|
||||||
|
emailValue = pgtype.Text{String: user.Email.String(), Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := qtx.CreateUserOAuthAccount(ctx, db.CreateUserOAuthAccountParams{
|
||||||
|
UserID: id,
|
||||||
|
Provider: user.Provider,
|
||||||
|
Subject: user.OAuthSubject,
|
||||||
|
Email: emailValue,
|
||||||
|
EmailVerified: user.OAuthEmailVerified,
|
||||||
|
Profile: nil,
|
||||||
|
}); err != nil {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||||
|
return ErrEmailExists
|
||||||
|
}
|
||||||
|
return fmt.Errorf("insert oauth account: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return fmt.Errorf("commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func timestamptzValue(ts pgtype.Timestamptz) time.Time {
|
||||||
|
if !ts.Valid {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return ts.Time
|
||||||
|
}
|
||||||
186
internal/service/auth/store_sql_test.go
Normal file
186
internal/service/auth/store_sql_test.go
Normal file
|
|
@ -0,0 +1,186 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
schemaUpSQL = `
|
||||||
|
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||||
|
CREATE EXTENSION IF NOT EXISTS citext;
|
||||||
|
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
email CITEXT NOT NULL UNIQUE,
|
||||||
|
display_name TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE user_passwords (
|
||||||
|
user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
password_hash BYTEA NOT NULL,
|
||||||
|
password_salt BYTEA NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE user_oauth_accounts (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
subject TEXT NOT NULL,
|
||||||
|
email TEXT,
|
||||||
|
email_verified BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
profile JSONB,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX user_oauth_accounts_provider_subject_idx
|
||||||
|
ON user_oauth_accounts (provider, subject);
|
||||||
|
|
||||||
|
CREATE INDEX user_oauth_accounts_user_id_idx
|
||||||
|
ON user_oauth_accounts (user_id);
|
||||||
|
|
||||||
|
CREATE TABLE login_events (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID REFERENCES users(id),
|
||||||
|
provider TEXT,
|
||||||
|
success BOOLEAN NOT NULL,
|
||||||
|
ip INET,
|
||||||
|
user_agent TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX login_events_user_id_idx ON login_events (user_id);
|
||||||
|
CREATE INDEX login_events_created_at_idx ON login_events (created_at);
|
||||||
|
`
|
||||||
|
|
||||||
|
schemaDownSQL = `
|
||||||
|
DROP TABLE IF EXISTS login_events;
|
||||||
|
DROP TABLE IF EXISTS user_oauth_accounts;
|
||||||
|
DROP TABLE IF EXISTS user_passwords;
|
||||||
|
DROP TABLE IF EXISTS users;
|
||||||
|
DROP EXTENSION IF EXISTS citext;
|
||||||
|
DROP EXTENSION IF EXISTS pgcrypto;
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSQLStoreIntegration(t *testing.T) {
|
||||||
|
dsn := os.Getenv("AUTH_DATABASE_URL")
|
||||||
|
if strings.TrimSpace(dsn) == "" {
|
||||||
|
t.Skip("AUTH_DATABASE_URL not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
pool, err := pgxpool.New(ctx, dsn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect database: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { pool.Close() })
|
||||||
|
|
||||||
|
resetDatabase(t, ctx, pool)
|
||||||
|
|
||||||
|
t.Run("register and authenticate", func(t *testing.T) {
|
||||||
|
resetDatabase(t, ctx, pool)
|
||||||
|
|
||||||
|
store := NewSQLStore(pool)
|
||||||
|
service := NewService(store)
|
||||||
|
|
||||||
|
email := MustUserEmail("sql-user@example.com")
|
||||||
|
|
||||||
|
user, err := service.Register(ctx, email, "Password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("register user: %v", err)
|
||||||
|
}
|
||||||
|
if user.ID == "" {
|
||||||
|
t.Fatal("expected user id")
|
||||||
|
}
|
||||||
|
if user.Provider != ProviderPassword {
|
||||||
|
t.Fatalf("expected provider %q, got %q", ProviderPassword, user.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticated, err := service.Authenticate(ctx, email, "Password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("authenticate user: %v", err)
|
||||||
|
}
|
||||||
|
if authenticated.ID != user.ID {
|
||||||
|
t.Fatalf("expected matching user id, got %q", authenticated.ID)
|
||||||
|
}
|
||||||
|
if authenticated.PasswordHash == "" || authenticated.PasswordSalt == "" {
|
||||||
|
t.Fatal("expected persisted password credentials")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ensure external user", func(t *testing.T) {
|
||||||
|
resetDatabase(t, ctx, pool)
|
||||||
|
|
||||||
|
store := NewSQLStore(pool)
|
||||||
|
service := NewService(store)
|
||||||
|
|
||||||
|
email := MustUserEmail("sql-google@example.com")
|
||||||
|
subject := "google-subject-123"
|
||||||
|
|
||||||
|
account, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensure external user: %v", err)
|
||||||
|
}
|
||||||
|
if account.Provider != ProviderGoogle {
|
||||||
|
t.Fatalf("expected provider %q, got %q", ProviderGoogle, account.Provider)
|
||||||
|
}
|
||||||
|
if account.OAuthSubject != subject {
|
||||||
|
t.Fatalf("expected oauth subject %q, got %q", subject, account.OAuthSubject)
|
||||||
|
}
|
||||||
|
|
||||||
|
again, err := service.EnsureExternalUser(ctx, email, ProviderGoogle, subject, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensure existing external user: %v", err)
|
||||||
|
}
|
||||||
|
if again.ID != account.ID {
|
||||||
|
t.Fatalf("expected same user id, got %q vs %q", again.ID, account.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
execStatements := func(stmts []string) {
|
||||||
|
for _, stmt := range stmts {
|
||||||
|
if strings.TrimSpace(stmt) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, execErr := pool.Exec(ctx, stmt); execErr != nil {
|
||||||
|
t.Fatalf("exec statement %q: %v", stmt, execErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
execStatements(splitSQLStatements(schemaDownSQL))
|
||||||
|
execStatements(splitSQLStatements(schemaUpSQL))
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitSQLStatements(section string) []string {
|
||||||
|
section = strings.TrimSpace(section)
|
||||||
|
if section == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(section, ";")
|
||||||
|
statements := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
stmt := strings.TrimSpace(part)
|
||||||
|
if stmt == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt+";")
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements
|
||||||
|
}
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// User represents authenticated account details.
|
// User represents authenticated account details.
|
||||||
|
|
@ -15,6 +15,8 @@ type User struct {
|
||||||
PasswordSalt string
|
PasswordSalt string
|
||||||
PasswordHash string
|
PasswordHash string
|
||||||
Provider string
|
Provider string
|
||||||
|
OAuthSubject string
|
||||||
|
OAuthEmailVerified bool
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -49,11 +51,6 @@ func (e UserEmail) IsZero() bool {
|
||||||
return e == ""
|
return e == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: could be UUID. return a dedicated type
|
|
||||||
func generateUserID() (string, error) {
|
func generateUserID() (string, error) {
|
||||||
buf := make([]byte, userIDByteLength)
|
return uuid.NewString(), nil
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue