feat: add google oauth login

This commit is contained in:
Ruidy 2025-09-20 18:28:04 +02:00
parent a20953cfb4
commit 52f479206c
No known key found for this signature in database
GPG key ID: 705C24D202990805
17 changed files with 625 additions and 31 deletions

5
go.mod
View file

@ -3,3 +3,8 @@ module github.com/rjnemo/auth
go 1.25.1
require github.com/go-chi/chi/v5 v5.2.3
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
golang.org/x/oauth2 v0.31.0 // indirect
)

4
go.sum
View file

@ -1,2 +1,6 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=

View file

@ -11,10 +11,13 @@ import (
)
const (
envListenAddr = "AUTH_LISTEN_ADDR"
envLogMode = "AUTH_LOG_MODE"
envEnvironment = "AUTH_ENV"
envSessionSecret = "AUTH_SESSION_SECRET"
envListenAddr = "AUTH_LISTEN_ADDR"
envLogMode = "AUTH_LOG_MODE"
envEnvironment = "AUTH_ENV"
envSessionSecret = "AUTH_SESSION_SECRET"
envGoogleClientID = "AUTH_GOOGLE_CLIENT_ID"
envGoogleClientSecret = "AUTH_GOOGLE_CLIENT_SECRET"
envGoogleRedirectURL = "AUTH_GOOGLE_REDIRECT_URL"
defaultListenAddr = ":8000"
defaultEnvironment = "development"
@ -26,6 +29,19 @@ type Config struct {
LogMode logging.Mode
Environment string
SessionSecret []byte
GoogleOAuth GoogleOAuthConfig
}
// GoogleOAuthConfig holds configuration for Google OAuth2 login.
type GoogleOAuthConfig struct {
ClientID string
ClientSecret string
RedirectURL string
}
// Enabled reports whether Google OAuth2 is fully configured.
func (g GoogleOAuthConfig) Enabled() bool {
return g.ClientID != "" && g.ClientSecret != "" && g.RedirectURL != ""
}
// New loads configuration from environment variables, applying defaults and validation.
@ -47,12 +63,34 @@ func New() (*Config, error) {
return nil, fmt.Errorf("invalid %s: %w", envSessionSecret, err)
}
googleOAuth := GoogleOAuthConfig{
ClientID: strings.TrimSpace(os.Getenv(envGoogleClientID)),
ClientSecret: strings.TrimSpace(os.Getenv(envGoogleClientSecret)),
RedirectURL: strings.TrimSpace(os.Getenv(envGoogleRedirectURL)),
}
if partiallyConfigured(googleOAuth) {
return nil, fmt.Errorf("incomplete google oauth configuration: set %s, %s, and %s", envGoogleClientID, envGoogleClientSecret, envGoogleRedirectURL)
}
cfg := &Config{
ListenAddr: listenAddr,
LogMode: logMode,
Environment: environment,
SessionSecret: secret,
GoogleOAuth: googleOAuth,
}
return cfg, nil
}
func partiallyConfigured(cfg GoogleOAuthConfig) bool {
switch {
case cfg.ClientID == "" && cfg.ClientSecret == "" && cfg.RedirectURL == "":
return false
case cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.RedirectURL == "":
return true
default:
return false
}
}

View file

@ -73,6 +73,33 @@ func TestNewShortSecretAccepted(t *testing.T) {
}
}
func TestNewGoogleOAuthPartialConfiguration(t *testing.T) {
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "")
if _, err := New(); err == nil {
t.Fatalf("expected error for partial google oauth config")
}
}
func TestNewGoogleOAuthConfigured(t *testing.T) {
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "secret")
t.Setenv("AUTH_GOOGLE_REDIRECT_URL", "http://localhost:8000/login/google/callback")
cfg, err := New()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !cfg.GoogleOAuth.Enabled() {
t.Fatalf("expected google oauth to be enabled")
}
if cfg.GoogleOAuth.ClientID != "client" {
t.Fatalf("expected client id to match, got %q", cfg.GoogleOAuth.ClientID)
}
}
func bytesOfLength(n int) []byte {
b := make([]byte, n)
for i := range b {

View file

@ -15,7 +15,7 @@ func (s *Server) loginPageHandler() http.HandlerFunc {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
return
}
s.render(w, "login.html", newLoginData(state.Email, "", state.CSRFToken))
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, "", state.CSRFToken)))
}
}
@ -35,7 +35,7 @@ func (s *Server) loginHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailInput)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData("", credentialRequiredMsg, state.CSRFToken))
s.render(w, "login.html", s.applyOAuthOptions(newLoginData("", credentialRequiredMsg, state.CSRFToken)))
return
}
@ -51,10 +51,10 @@ func (s *Server) loginHandler() http.HandlerFunc {
case errors.Is(err, auth.ErrWeakPassword):
w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData(email.String(), weakPasswordMsg, state.CSRFToken))
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), weakPasswordMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken))
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidCredentials):
s.renderLoginFailure(w, email, state.CSRFToken)
default:
@ -66,5 +66,5 @@ func (s *Server) loginHandler() http.HandlerFunc {
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
w.WriteHeader(http.StatusUnauthorized)
s.render(w, "login.html", newLoginData(email.String(), invalidCredentialsMsg, token))
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), invalidCredentialsMsg, token)))
}

View file

@ -0,0 +1,200 @@
package server
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"golang.org/x/oauth2"
"github.com/rjnemo/auth/internal/service/auth"
)
const (
googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo"
googleAuthFailedMsg = "Unable to sign in with Google. Please try again."
googleAuthCanceledMsg = "Google sign-in was cancelled."
)
type googleUserInfo struct {
ID string `json:"sub"`
Email string `json:"email"`
VerifiedEmail bool `json:"email_verified"`
}
func (s *Server) googleLoginHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger := s.logger.With(slog.String("component", "google_oauth"))
if s.googleOAuth == nil {
http.NotFound(w, r)
return
}
state := sessionFromContext(r.Context())
if state.Authenticated {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
return
}
token, err := generateOAuthState()
if err != nil {
logger.Error("generate oauth state failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
state.OAuthState = token
if err := s.sessions.Save(w, state); err != nil {
logger.Error("persist oauth state failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
redirectURL := s.googleOAuth.AuthCodeURL(token, oauth2.AccessTypeOnline)
http.Redirect(w, r, redirectURL, http.StatusFound)
}
}
func (s *Server) googleCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger := s.logger.With(slog.String("component", "google_oauth"))
if s.googleOAuth == nil {
http.NotFound(w, r)
return
}
state := sessionFromContext(r.Context())
expectedState := state.OAuthState
providedState := r.URL.Query().Get("state")
state.OAuthState = ""
saveState := func() bool {
if err := s.sessions.Save(w, state); err != nil {
logger.Error("session save failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return false
}
return true
}
respondWithLogin := func(status int, message string) {
if status != 0 {
w.WriteHeader(status)
}
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, message, state.CSRFToken)))
}
if expectedState == "" || providedState == "" || providedState != expectedState {
if !saveState() {
return
}
http.Error(w, "invalid oauth state", http.StatusBadRequest)
return
}
if errParam := r.URL.Query().Get("error"); errParam != "" {
logger.Info("google oauth returned error", slog.String("google_error", errParam))
if !saveState() {
return
}
respondWithLogin(http.StatusBadRequest, googleAuthCanceledMsg)
return
}
authCode := r.URL.Query().Get("code")
if authCode == "" {
if !saveState() {
return
}
http.Error(w, "missing authorization code", http.StatusBadRequest)
return
}
token, err := s.googleOAuth.Exchange(r.Context(), authCode)
if err != nil {
logger.Error("oauth code exchange failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
info, err := s.fetchGoogleUserInfo(r.Context(), token)
if err != nil {
logger.Error("fetch google user info failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
if !info.VerifiedEmail || info.Email == "" {
logger.Warn("google returned unverified email", slog.Bool("verified", info.VerifiedEmail))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
email, err := auth.NewUserEmail(info.Email)
if err != nil {
logger.Error("normalize google email failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle)
if err != nil {
logger.Error("ensure external user failed", slog.Any("error", err))
if !saveState() {
return
}
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
state.Authenticated = true
state.Email = account.Email.String()
if err := s.sessions.Save(w, state); err != nil {
logger.Error("session save failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
}
}
func (s *Server) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token) (googleUserInfo, error) {
client := s.googleOAuth.Client(ctx, token)
resp, err := client.Get(googleUserInfoEndpoint)
if err != nil {
return googleUserInfo{}, fmt.Errorf("request google userinfo: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
s.logger.With(slog.String("component", "google_oauth")).Warn("close google userinfo body failed", slog.Any("error", cerr))
}
}()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return googleUserInfo{}, fmt.Errorf("google userinfo response %d: %s", resp.StatusCode, string(body))
}
var info googleUserInfo
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
return googleUserInfo{}, fmt.Errorf("decode google userinfo: %w", err)
}
return info, nil
}

View file

@ -24,7 +24,7 @@ func (s *Server) signupPageHandler() http.HandlerFunc {
return
}
s.render(w, "signup.html", newSignupData(state.Email, "", state.CSRFToken))
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(state.Email, "", state.CSRFToken)))
}
}
@ -44,7 +44,7 @@ func (s *Server) signupHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailValue)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData("", credentialRequiredMsg, state.CSRFToken))
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData("", credentialRequiredMsg, state.CSRFToken)))
return
}
@ -59,13 +59,13 @@ func (s *Server) signupHandler() http.HandlerFunc {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
case errors.Is(err, auth.ErrWeakPassword):
w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData(email.String(), weakPasswordMsg, state.CSRFToken))
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), weakPasswordMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken))
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrEmailExists):
w.WriteHeader(http.StatusConflict)
s.render(w, "signup.html", newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken))
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken)))
default:
logger.Error("register failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)

View file

@ -10,6 +10,8 @@ import (
func (s *Server) registerRoutes(r chi.Router) {
r.Get("/", s.loginPageHandler())
r.Post("/login", s.loginHandler())
r.Get("/login/google", s.googleLoginHandler())
r.Get("/login/google/callback", s.googleCallbackHandler())
r.Post("/logout", s.logoutHandler())
r.Get("/signup", s.signupPageHandler())
r.Post("/signup", s.signupHandler())

View file

@ -12,6 +12,8 @@ import (
"github.com/rjnemo/auth/internal/driver/logging"
"github.com/rjnemo/auth/internal/service/auth"
"github.com/rjnemo/auth/web"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const (
@ -26,6 +28,7 @@ type Server struct {
sessions *SessionStore
logger *slog.Logger
configuration config.Config
googleOAuth *oauth2.Config
}
// New constructs a Server with parsed templates and default state.
@ -57,12 +60,27 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
}
logger = logger.With(slog.String("service", "http"))
var googleOAuthConfig *oauth2.Config
if cfg.GoogleOAuth.Enabled() {
googleOAuthConfig = &oauth2.Config{
ClientID: cfg.GoogleOAuth.ClientID,
ClientSecret: cfg.GoogleOAuth.ClientSecret,
RedirectURL: cfg.GoogleOAuth.RedirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
}
return &Server{
templates: tmpl,
authService: auth.NewService(store),
sessions: sessionStore,
logger: logger,
configuration: cfg,
googleOAuth: googleOAuthConfig,
}, nil
}
@ -80,6 +98,7 @@ func seedUser(store auth.UserStore) error {
Email: email,
PasswordSalt: salt,
PasswordHash: hash,
Provider: auth.ProviderPassword,
CreatedAt: time.Now().UTC(),
})
}

View file

@ -32,6 +32,30 @@ func newTestServer(t *testing.T) *Server {
return srv
}
func newGoogleTestServer(t *testing.T) *Server {
t.Helper()
cfg := config.Config{
ListenAddr: ":0",
LogMode: logging.ModeText,
Environment: "test",
SessionSecret: bytes.Repeat([]byte("g"), 32),
GoogleOAuth: config.GoogleOAuthConfig{
ClientID: "client",
ClientSecret: "secret",
RedirectURL: "http://localhost/login/google/callback",
},
}
logger := logging.New(io.Discard, logging.ModeText, nil)
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("new google server: %v", err)
}
return srv
}
func attachSession(req *http.Request, state SessionState) *http.Request {
return req.WithContext(withSession(req.Context(), state))
}
@ -55,6 +79,26 @@ func TestLoginPageHandler(t *testing.T) {
}
}
func TestLoginPageHandlerIncludesGoogleLinkWhenConfigured(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = attachSession(req, SessionState{CSRFToken: "token"})
rr := httptest.NewRecorder()
srv.loginPageHandler()(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, "id=\"google_login_form\"") {
t.Fatalf("expected google login form in page, got %q", body)
}
}
func TestLoginHandlerSuccess(t *testing.T) {
t.Parallel()
@ -115,6 +159,112 @@ func TestLoginHandlerInvalidCredentials(t *testing.T) {
}
}
func TestGoogleLoginHandlerDisabled(t *testing.T) {
t.Parallel()
srv := newTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
rr := httptest.NewRecorder()
srv.googleLoginHandler()(rr, req)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected 404 when google oauth disabled, got %d", rr.Code)
}
}
func TestGoogleLoginHandlerRedirects(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
req = attachSession(req, SessionState{CSRFToken: "csrf"})
rr := httptest.NewRecorder()
srv.googleLoginHandler()(rr, req)
res := rr.Result()
if res.StatusCode != http.StatusFound {
t.Fatalf("expected 302 redirect, got %d", res.StatusCode)
}
location := res.Header.Get("Location")
if location == "" {
t.Fatal("expected redirect location header")
}
if !strings.Contains(location, "accounts.google.com") {
t.Fatalf("expected google authorization URL, got %q", location)
}
parsed, err := url.Parse(location)
if err != nil {
t.Fatalf("parse redirect url: %v", err)
}
stateParam := parsed.Query().Get("state")
if stateParam == "" {
t.Fatal("expected state parameter in redirect")
}
var sessionCookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == sessionCookieName {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("expected session cookie to be set")
}
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
if err != nil {
t.Fatalf("decode session: %v", err)
}
if savedState.OAuthState == "" {
t.Fatal("expected oauth state stored in session")
}
if savedState.OAuthState != stateParam {
t.Fatalf("expected oauth state %q to match redirect param %q", savedState.OAuthState, stateParam)
}
}
func TestGoogleCallbackHandlerStateMismatch(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google/callback?state=other&code=ignored", nil)
req = attachSession(req, SessionState{OAuthState: "expected", CSRFToken: "csrf"})
rr := httptest.NewRecorder()
srv.googleCallbackHandler()(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for state mismatch, got %d", rr.Code)
}
res := rr.Result()
var sessionCookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == sessionCookieName {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("expected session cookie to be set")
}
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
if err != nil {
t.Fatalf("decode session: %v", err)
}
if savedState.OAuthState != "" {
t.Fatal("expected oauth state to be cleared after mismatch")
}
}
func TestSignupHandlerSuccess(t *testing.T) {
t.Parallel()

View file

@ -13,6 +13,7 @@ const (
sessionLifetime = 12 * time.Hour
sessionSecretMinLength = 32
csrfTokenByteLength int = 32
oauthStateByteLength int = 32
)
// SessionStore persists session data using secure HTTP cookies.
@ -33,9 +34,10 @@ func NewSessionStore(secret []byte) (*SessionStore, error) {
// SessionState holds per-request session data after loading.
type SessionState struct {
Authenticated bool
Email string
CSRFToken string
Authenticated bool `json:"authenticated"`
Email string `json:"email"`
CSRFToken string `json:"csrf_token"`
OAuthState string `json:"oauth_state"`
}
// Load extracts session data from the request cookies.
@ -99,3 +101,11 @@ func ensureCSRFToken(state SessionState) (SessionState, error) {
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
return state, nil
}
func generateOAuthState() (string, error) {
buf := make([]byte, oauthStateByteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}

View file

@ -17,20 +17,30 @@ func (s *Server) render(w http.ResponseWriter, name string, data any) {
// PageData contains fields shared by the templates for now.
type PageData struct {
Title string
View string
Email string
Error string
Info string
CSRFToken string
CreatedAt string
CreatedAtISO string
Title string
View string
Email string
Error string
Info string
CSRFToken string
CreatedAt string
CreatedAtISO string
GoogleLoginURL string
GoogleLoginEnabled bool
}
func newLoginData(email, errMsg, token string) PageData {
return PageData{Title: "Sign in · Auth Demo", View: "login", Email: email, Error: errMsg, CSRFToken: token}
}
func (s *Server) applyOAuthOptions(data PageData) PageData {
if s.googleOAuth != nil {
data.GoogleLoginEnabled = true
data.GoogleLoginURL = "/login/google"
}
return data
}
func newUnauthorizedData(errMsg, token string) PageData {
return PageData{
Title: "Access denied · Auth Demo",

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"
)
@ -14,9 +15,17 @@ var (
ErrInvalidCredentials = errors.New("auth: invalid credentials")
// ErrEmailExists indicates an account already uses the provided email address.
ErrEmailExists = errors.New("auth: email already registered")
// ErrProviderRequired indicates the external provider identifier was missing.
ErrProviderRequired = errors.New("auth: provider required")
)
const userIDByteLength = 16
const (
userIDByteLength = 16
// ProviderPassword identifies accounts managed via email/password.
ProviderPassword = "password"
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
ProviderGoogle = "google"
)
// Service exposes authentication business operations to HTTP handlers.
type Service struct {
@ -91,6 +100,7 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
Email: email,
PasswordSalt: salt,
PasswordHash: hash,
Provider: ProviderPassword,
CreatedAt: time.Now().UTC(),
}
@ -100,3 +110,39 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
return &user, nil
}
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) {
if email.IsZero() {
return nil, ErrInvalidInput
}
if strings.TrimSpace(provider) == "" {
return nil, ErrProviderRequired
}
account, err := s.store.FindByEmail(ctx, email)
switch {
case err == nil:
return account, nil
case !errors.Is(err, ErrUserNotFound):
return nil, err
}
id, err := generateUserID()
if err != nil {
return nil, fmt.Errorf("generate user id: %w", err)
}
user := User{
ID: id,
Email: email,
Provider: provider,
CreatedAt: time.Now().UTC(),
}
if err := s.store.Create(ctx, user); err != nil {
return nil, err
}
return &user, nil
}

View file

@ -18,7 +18,7 @@ func TestServiceAuthenticate(t *testing.T) {
if err != nil {
t.Fatalf("hash password: %v", err)
}
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil {
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err)
}
@ -70,7 +70,7 @@ func TestServiceLookupByEmail(t *testing.T) {
service := NewService(store)
email := MustUserEmail("lookup@example.com")
if err := store.Create(ctx, User{Email: email}); err != nil {
if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err)
}
@ -116,7 +116,7 @@ func TestServiceRegister(t *testing.T) {
service := NewService(store)
email := MustUserEmail("taken@example.com")
if err := store.Create(ctx, User{Email: email}); err != nil {
if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err)
}
@ -166,3 +166,65 @@ func TestServiceRegister(t *testing.T) {
})
}
}
func TestServiceEnsureExternalUser(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := NewMemoryStore()
service := NewService(store)
googleEmail := MustUserEmail("google@example.com")
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil {
t.Fatalf("seed external user: %v", err)
}
tests := map[string]struct {
email UserEmail
provider string
wantErr error
wantNew bool
}{
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput},
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired},
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false},
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider)
if tc.wantErr != nil {
if !errors.Is(err, tc.wantErr) {
t.Fatalf("expected %v, got %v", tc.wantErr, err)
}
if user != nil {
t.Fatalf("expected nil user, got %#v", user)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user == nil {
t.Fatal("expected user")
}
if user.Email != tc.email {
t.Fatalf("expected email %q, got %q", tc.email, user.Email)
}
if user.Provider != tc.provider {
t.Fatalf("expected provider %q, got %q", tc.provider, user.Provider)
}
persisted, err := store.FindByEmail(ctx, tc.email)
if err != nil {
t.Fatalf("expected user persisted: %v", err)
}
if tc.wantNew && persisted.CreatedAt.IsZero() {
t.Fatal("expected created at timestamp for new user")
}
})
}
}

View file

@ -14,6 +14,7 @@ type User struct {
Email UserEmail
PasswordSalt string
PasswordHash string
Provider string
CreatedAt time.Time
}

View file

@ -52,8 +52,14 @@
</div>
<div class="auth-actions">
<button type="submit" class="primary">Log in</button>
{{if .GoogleLoginEnabled}}
<div class="auth-divider">or</div>
<button type="button" class="secondary outline auth-google">
<button
type="submit"
class="secondary outline auth-google"
form="google_login_form"
formnovalidate
>
<svg
width="18"
height="18"
@ -81,8 +87,12 @@
</svg>
Continue with Google
</button>
{{end}}
</div>
</form>
{{if .GoogleLoginEnabled}}
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
{{end}}
<p class="auth-footer">
Don't have an account? <a href="/signup">Sign up</a>
</p>

View file

@ -67,8 +67,14 @@
</label>
<div class="auth-actions">
<button type="submit" class="primary">Create account</button>
{{if .GoogleLoginEnabled}}
<div class="auth-divider">or</div>
<button type="button" class="secondary outline auth-google">
<button
type="submit"
class="secondary outline auth-google"
form="google_login_form"
formnovalidate
>
<svg
width="18"
height="18"
@ -96,8 +102,12 @@
</svg>
Sign up with Google
</button>
{{end}}
</div>
</form>
{{if .GoogleLoginEnabled}}
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
{{end}}
<p class="auth-footer">
Have an account? <a href="/">Log in</a>
</p>