feat: add google oauth login

This commit is contained in:
Ruidy 2025-09-20 18:28:04 +02:00
parent a20953cfb4
commit 52f479206c
No known key found for this signature in database
GPG key ID: 705C24D202990805
17 changed files with 625 additions and 31 deletions

5
go.mod
View file

@ -3,3 +3,8 @@ module github.com/rjnemo/auth
go 1.25.1 go 1.25.1
require github.com/go-chi/chi/v5 v5.2.3 require github.com/go-chi/chi/v5 v5.2.3
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
golang.org/x/oauth2 v0.31.0 // indirect
)

4
go.sum
View file

@ -1,2 +1,6 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=

View file

@ -15,6 +15,9 @@ const (
envLogMode = "AUTH_LOG_MODE" envLogMode = "AUTH_LOG_MODE"
envEnvironment = "AUTH_ENV" envEnvironment = "AUTH_ENV"
envSessionSecret = "AUTH_SESSION_SECRET" envSessionSecret = "AUTH_SESSION_SECRET"
envGoogleClientID = "AUTH_GOOGLE_CLIENT_ID"
envGoogleClientSecret = "AUTH_GOOGLE_CLIENT_SECRET"
envGoogleRedirectURL = "AUTH_GOOGLE_REDIRECT_URL"
defaultListenAddr = ":8000" defaultListenAddr = ":8000"
defaultEnvironment = "development" defaultEnvironment = "development"
@ -26,6 +29,19 @@ type Config struct {
LogMode logging.Mode LogMode logging.Mode
Environment string Environment string
SessionSecret []byte SessionSecret []byte
GoogleOAuth GoogleOAuthConfig
}
// GoogleOAuthConfig holds configuration for Google OAuth2 login.
type GoogleOAuthConfig struct {
ClientID string
ClientSecret string
RedirectURL string
}
// Enabled reports whether Google OAuth2 is fully configured.
func (g GoogleOAuthConfig) Enabled() bool {
return g.ClientID != "" && g.ClientSecret != "" && g.RedirectURL != ""
} }
// New loads configuration from environment variables, applying defaults and validation. // New loads configuration from environment variables, applying defaults and validation.
@ -47,12 +63,34 @@ func New() (*Config, error) {
return nil, fmt.Errorf("invalid %s: %w", envSessionSecret, err) return nil, fmt.Errorf("invalid %s: %w", envSessionSecret, err)
} }
googleOAuth := GoogleOAuthConfig{
ClientID: strings.TrimSpace(os.Getenv(envGoogleClientID)),
ClientSecret: strings.TrimSpace(os.Getenv(envGoogleClientSecret)),
RedirectURL: strings.TrimSpace(os.Getenv(envGoogleRedirectURL)),
}
if partiallyConfigured(googleOAuth) {
return nil, fmt.Errorf("incomplete google oauth configuration: set %s, %s, and %s", envGoogleClientID, envGoogleClientSecret, envGoogleRedirectURL)
}
cfg := &Config{ cfg := &Config{
ListenAddr: listenAddr, ListenAddr: listenAddr,
LogMode: logMode, LogMode: logMode,
Environment: environment, Environment: environment,
SessionSecret: secret, SessionSecret: secret,
GoogleOAuth: googleOAuth,
} }
return cfg, nil return cfg, nil
} }
func partiallyConfigured(cfg GoogleOAuthConfig) bool {
switch {
case cfg.ClientID == "" && cfg.ClientSecret == "" && cfg.RedirectURL == "":
return false
case cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.RedirectURL == "":
return true
default:
return false
}
}

View file

@ -73,6 +73,33 @@ func TestNewShortSecretAccepted(t *testing.T) {
} }
} }
func TestNewGoogleOAuthPartialConfiguration(t *testing.T) {
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "")
if _, err := New(); err == nil {
t.Fatalf("expected error for partial google oauth config")
}
}
func TestNewGoogleOAuthConfigured(t *testing.T) {
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "secret")
t.Setenv("AUTH_GOOGLE_REDIRECT_URL", "http://localhost:8000/login/google/callback")
cfg, err := New()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !cfg.GoogleOAuth.Enabled() {
t.Fatalf("expected google oauth to be enabled")
}
if cfg.GoogleOAuth.ClientID != "client" {
t.Fatalf("expected client id to match, got %q", cfg.GoogleOAuth.ClientID)
}
}
func bytesOfLength(n int) []byte { func bytesOfLength(n int) []byte {
b := make([]byte, n) b := make([]byte, n)
for i := range b { for i := range b {

View file

@ -15,7 +15,7 @@ func (s *Server) loginPageHandler() http.HandlerFunc {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther) http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
return return
} }
s.render(w, "login.html", newLoginData(state.Email, "", state.CSRFToken)) s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, "", state.CSRFToken)))
} }
} }
@ -35,7 +35,7 @@ func (s *Server) loginHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailInput) email, err := auth.NewUserEmail(emailInput)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData("", credentialRequiredMsg, state.CSRFToken)) s.render(w, "login.html", s.applyOAuthOptions(newLoginData("", credentialRequiredMsg, state.CSRFToken)))
return return
} }
@ -51,10 +51,10 @@ func (s *Server) loginHandler() http.HandlerFunc {
case errors.Is(err, auth.ErrWeakPassword): case errors.Is(err, auth.ErrWeakPassword):
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData(email.String(), weakPasswordMsg, state.CSRFToken)) s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), weakPasswordMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidInput): case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "login.html", newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken)) s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidCredentials): case errors.Is(err, auth.ErrInvalidCredentials):
s.renderLoginFailure(w, email, state.CSRFToken) s.renderLoginFailure(w, email, state.CSRFToken)
default: default:
@ -66,5 +66,5 @@ func (s *Server) loginHandler() http.HandlerFunc {
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) { func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
s.render(w, "login.html", newLoginData(email.String(), invalidCredentialsMsg, token)) s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), invalidCredentialsMsg, token)))
} }

View file

@ -0,0 +1,200 @@
package server
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"golang.org/x/oauth2"
"github.com/rjnemo/auth/internal/service/auth"
)
const (
googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo"
googleAuthFailedMsg = "Unable to sign in with Google. Please try again."
googleAuthCanceledMsg = "Google sign-in was cancelled."
)
type googleUserInfo struct {
ID string `json:"sub"`
Email string `json:"email"`
VerifiedEmail bool `json:"email_verified"`
}
func (s *Server) googleLoginHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger := s.logger.With(slog.String("component", "google_oauth"))
if s.googleOAuth == nil {
http.NotFound(w, r)
return
}
state := sessionFromContext(r.Context())
if state.Authenticated {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
return
}
token, err := generateOAuthState()
if err != nil {
logger.Error("generate oauth state failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
state.OAuthState = token
if err := s.sessions.Save(w, state); err != nil {
logger.Error("persist oauth state failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
redirectURL := s.googleOAuth.AuthCodeURL(token, oauth2.AccessTypeOnline)
http.Redirect(w, r, redirectURL, http.StatusFound)
}
}
func (s *Server) googleCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger := s.logger.With(slog.String("component", "google_oauth"))
if s.googleOAuth == nil {
http.NotFound(w, r)
return
}
state := sessionFromContext(r.Context())
expectedState := state.OAuthState
providedState := r.URL.Query().Get("state")
state.OAuthState = ""
saveState := func() bool {
if err := s.sessions.Save(w, state); err != nil {
logger.Error("session save failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return false
}
return true
}
respondWithLogin := func(status int, message string) {
if status != 0 {
w.WriteHeader(status)
}
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, message, state.CSRFToken)))
}
if expectedState == "" || providedState == "" || providedState != expectedState {
if !saveState() {
return
}
http.Error(w, "invalid oauth state", http.StatusBadRequest)
return
}
if errParam := r.URL.Query().Get("error"); errParam != "" {
logger.Info("google oauth returned error", slog.String("google_error", errParam))
if !saveState() {
return
}
respondWithLogin(http.StatusBadRequest, googleAuthCanceledMsg)
return
}
authCode := r.URL.Query().Get("code")
if authCode == "" {
if !saveState() {
return
}
http.Error(w, "missing authorization code", http.StatusBadRequest)
return
}
token, err := s.googleOAuth.Exchange(r.Context(), authCode)
if err != nil {
logger.Error("oauth code exchange failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
info, err := s.fetchGoogleUserInfo(r.Context(), token)
if err != nil {
logger.Error("fetch google user info failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
if !info.VerifiedEmail || info.Email == "" {
logger.Warn("google returned unverified email", slog.Bool("verified", info.VerifiedEmail))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
email, err := auth.NewUserEmail(info.Email)
if err != nil {
logger.Error("normalize google email failed", slog.Any("error", err))
if !saveState() {
return
}
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
return
}
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle)
if err != nil {
logger.Error("ensure external user failed", slog.Any("error", err))
if !saveState() {
return
}
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
state.Authenticated = true
state.Email = account.Email.String()
if err := s.sessions.Save(w, state); err != nil {
logger.Error("session save failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
}
}
func (s *Server) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token) (googleUserInfo, error) {
client := s.googleOAuth.Client(ctx, token)
resp, err := client.Get(googleUserInfoEndpoint)
if err != nil {
return googleUserInfo{}, fmt.Errorf("request google userinfo: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
s.logger.With(slog.String("component", "google_oauth")).Warn("close google userinfo body failed", slog.Any("error", cerr))
}
}()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return googleUserInfo{}, fmt.Errorf("google userinfo response %d: %s", resp.StatusCode, string(body))
}
var info googleUserInfo
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
return googleUserInfo{}, fmt.Errorf("decode google userinfo: %w", err)
}
return info, nil
}

View file

@ -24,7 +24,7 @@ func (s *Server) signupPageHandler() http.HandlerFunc {
return return
} }
s.render(w, "signup.html", newSignupData(state.Email, "", state.CSRFToken)) s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(state.Email, "", state.CSRFToken)))
} }
} }
@ -44,7 +44,7 @@ func (s *Server) signupHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailValue) email, err := auth.NewUserEmail(emailValue)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData("", credentialRequiredMsg, state.CSRFToken)) s.render(w, "signup.html", s.applyOAuthOptions(newSignupData("", credentialRequiredMsg, state.CSRFToken)))
return return
} }
@ -59,13 +59,13 @@ func (s *Server) signupHandler() http.HandlerFunc {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther) http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
case errors.Is(err, auth.ErrWeakPassword): case errors.Is(err, auth.ErrWeakPassword):
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData(email.String(), weakPasswordMsg, state.CSRFToken)) s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), weakPasswordMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrInvalidInput): case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "signup.html", newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken)) s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken)))
case errors.Is(err, auth.ErrEmailExists): case errors.Is(err, auth.ErrEmailExists):
w.WriteHeader(http.StatusConflict) w.WriteHeader(http.StatusConflict)
s.render(w, "signup.html", newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken)) s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken)))
default: default:
logger.Error("register failed", slog.Any("error", err)) logger.Error("register failed", slog.Any("error", err))
http.Error(w, "unexpected error", http.StatusInternalServerError) http.Error(w, "unexpected error", http.StatusInternalServerError)

View file

@ -10,6 +10,8 @@ import (
func (s *Server) registerRoutes(r chi.Router) { func (s *Server) registerRoutes(r chi.Router) {
r.Get("/", s.loginPageHandler()) r.Get("/", s.loginPageHandler())
r.Post("/login", s.loginHandler()) r.Post("/login", s.loginHandler())
r.Get("/login/google", s.googleLoginHandler())
r.Get("/login/google/callback", s.googleCallbackHandler())
r.Post("/logout", s.logoutHandler()) r.Post("/logout", s.logoutHandler())
r.Get("/signup", s.signupPageHandler()) r.Get("/signup", s.signupPageHandler())
r.Post("/signup", s.signupHandler()) r.Post("/signup", s.signupHandler())

View file

@ -12,6 +12,8 @@ import (
"github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/driver/logging"
"github.com/rjnemo/auth/internal/service/auth" "github.com/rjnemo/auth/internal/service/auth"
"github.com/rjnemo/auth/web" "github.com/rjnemo/auth/web"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
) )
const ( const (
@ -26,6 +28,7 @@ type Server struct {
sessions *SessionStore sessions *SessionStore
logger *slog.Logger logger *slog.Logger
configuration config.Config configuration config.Config
googleOAuth *oauth2.Config
} }
// New constructs a Server with parsed templates and default state. // New constructs a Server with parsed templates and default state.
@ -57,12 +60,27 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
} }
logger = logger.With(slog.String("service", "http")) logger = logger.With(slog.String("service", "http"))
var googleOAuthConfig *oauth2.Config
if cfg.GoogleOAuth.Enabled() {
googleOAuthConfig = &oauth2.Config{
ClientID: cfg.GoogleOAuth.ClientID,
ClientSecret: cfg.GoogleOAuth.ClientSecret,
RedirectURL: cfg.GoogleOAuth.RedirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
}
return &Server{ return &Server{
templates: tmpl, templates: tmpl,
authService: auth.NewService(store), authService: auth.NewService(store),
sessions: sessionStore, sessions: sessionStore,
logger: logger, logger: logger,
configuration: cfg, configuration: cfg,
googleOAuth: googleOAuthConfig,
}, nil }, nil
} }
@ -80,6 +98,7 @@ func seedUser(store auth.UserStore) error {
Email: email, Email: email,
PasswordSalt: salt, PasswordSalt: salt,
PasswordHash: hash, PasswordHash: hash,
Provider: auth.ProviderPassword,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
}) })
} }

View file

@ -32,6 +32,30 @@ func newTestServer(t *testing.T) *Server {
return srv return srv
} }
func newGoogleTestServer(t *testing.T) *Server {
t.Helper()
cfg := config.Config{
ListenAddr: ":0",
LogMode: logging.ModeText,
Environment: "test",
SessionSecret: bytes.Repeat([]byte("g"), 32),
GoogleOAuth: config.GoogleOAuthConfig{
ClientID: "client",
ClientSecret: "secret",
RedirectURL: "http://localhost/login/google/callback",
},
}
logger := logging.New(io.Discard, logging.ModeText, nil)
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("new google server: %v", err)
}
return srv
}
func attachSession(req *http.Request, state SessionState) *http.Request { func attachSession(req *http.Request, state SessionState) *http.Request {
return req.WithContext(withSession(req.Context(), state)) return req.WithContext(withSession(req.Context(), state))
} }
@ -55,6 +79,26 @@ func TestLoginPageHandler(t *testing.T) {
} }
} }
func TestLoginPageHandlerIncludesGoogleLinkWhenConfigured(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = attachSession(req, SessionState{CSRFToken: "token"})
rr := httptest.NewRecorder()
srv.loginPageHandler()(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, "id=\"google_login_form\"") {
t.Fatalf("expected google login form in page, got %q", body)
}
}
func TestLoginHandlerSuccess(t *testing.T) { func TestLoginHandlerSuccess(t *testing.T) {
t.Parallel() t.Parallel()
@ -115,6 +159,112 @@ func TestLoginHandlerInvalidCredentials(t *testing.T) {
} }
} }
func TestGoogleLoginHandlerDisabled(t *testing.T) {
t.Parallel()
srv := newTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
rr := httptest.NewRecorder()
srv.googleLoginHandler()(rr, req)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected 404 when google oauth disabled, got %d", rr.Code)
}
}
func TestGoogleLoginHandlerRedirects(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
req = attachSession(req, SessionState{CSRFToken: "csrf"})
rr := httptest.NewRecorder()
srv.googleLoginHandler()(rr, req)
res := rr.Result()
if res.StatusCode != http.StatusFound {
t.Fatalf("expected 302 redirect, got %d", res.StatusCode)
}
location := res.Header.Get("Location")
if location == "" {
t.Fatal("expected redirect location header")
}
if !strings.Contains(location, "accounts.google.com") {
t.Fatalf("expected google authorization URL, got %q", location)
}
parsed, err := url.Parse(location)
if err != nil {
t.Fatalf("parse redirect url: %v", err)
}
stateParam := parsed.Query().Get("state")
if stateParam == "" {
t.Fatal("expected state parameter in redirect")
}
var sessionCookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == sessionCookieName {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("expected session cookie to be set")
}
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
if err != nil {
t.Fatalf("decode session: %v", err)
}
if savedState.OAuthState == "" {
t.Fatal("expected oauth state stored in session")
}
if savedState.OAuthState != stateParam {
t.Fatalf("expected oauth state %q to match redirect param %q", savedState.OAuthState, stateParam)
}
}
func TestGoogleCallbackHandlerStateMismatch(t *testing.T) {
t.Parallel()
srv := newGoogleTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/login/google/callback?state=other&code=ignored", nil)
req = attachSession(req, SessionState{OAuthState: "expected", CSRFToken: "csrf"})
rr := httptest.NewRecorder()
srv.googleCallbackHandler()(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for state mismatch, got %d", rr.Code)
}
res := rr.Result()
var sessionCookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == sessionCookieName {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("expected session cookie to be set")
}
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
if err != nil {
t.Fatalf("decode session: %v", err)
}
if savedState.OAuthState != "" {
t.Fatal("expected oauth state to be cleared after mismatch")
}
}
func TestSignupHandlerSuccess(t *testing.T) { func TestSignupHandlerSuccess(t *testing.T) {
t.Parallel() t.Parallel()

View file

@ -13,6 +13,7 @@ const (
sessionLifetime = 12 * time.Hour sessionLifetime = 12 * time.Hour
sessionSecretMinLength = 32 sessionSecretMinLength = 32
csrfTokenByteLength int = 32 csrfTokenByteLength int = 32
oauthStateByteLength int = 32
) )
// SessionStore persists session data using secure HTTP cookies. // SessionStore persists session data using secure HTTP cookies.
@ -33,9 +34,10 @@ func NewSessionStore(secret []byte) (*SessionStore, error) {
// SessionState holds per-request session data after loading. // SessionState holds per-request session data after loading.
type SessionState struct { type SessionState struct {
Authenticated bool Authenticated bool `json:"authenticated"`
Email string Email string `json:"email"`
CSRFToken string CSRFToken string `json:"csrf_token"`
OAuthState string `json:"oauth_state"`
} }
// Load extracts session data from the request cookies. // Load extracts session data from the request cookies.
@ -99,3 +101,11 @@ func ensureCSRFToken(state SessionState) (SessionState, error) {
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token) state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
return state, nil return state, nil
} }
func generateOAuthState() (string, error) {
buf := make([]byte, oauthStateByteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}

View file

@ -25,12 +25,22 @@ type PageData struct {
CSRFToken string CSRFToken string
CreatedAt string CreatedAt string
CreatedAtISO string CreatedAtISO string
GoogleLoginURL string
GoogleLoginEnabled bool
} }
func newLoginData(email, errMsg, token string) PageData { func newLoginData(email, errMsg, token string) PageData {
return PageData{Title: "Sign in · Auth Demo", View: "login", Email: email, Error: errMsg, CSRFToken: token} return PageData{Title: "Sign in · Auth Demo", View: "login", Email: email, Error: errMsg, CSRFToken: token}
} }
func (s *Server) applyOAuthOptions(data PageData) PageData {
if s.googleOAuth != nil {
data.GoogleLoginEnabled = true
data.GoogleLoginURL = "/login/google"
}
return data
}
func newUnauthorizedData(errMsg, token string) PageData { func newUnauthorizedData(errMsg, token string) PageData {
return PageData{ return PageData{
Title: "Access denied · Auth Demo", Title: "Access denied · Auth Demo",

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strings"
"time" "time"
) )
@ -14,9 +15,17 @@ var (
ErrInvalidCredentials = errors.New("auth: invalid credentials") ErrInvalidCredentials = errors.New("auth: invalid credentials")
// ErrEmailExists indicates an account already uses the provided email address. // ErrEmailExists indicates an account already uses the provided email address.
ErrEmailExists = errors.New("auth: email already registered") ErrEmailExists = errors.New("auth: email already registered")
// ErrProviderRequired indicates the external provider identifier was missing.
ErrProviderRequired = errors.New("auth: provider required")
) )
const userIDByteLength = 16 const (
userIDByteLength = 16
// ProviderPassword identifies accounts managed via email/password.
ProviderPassword = "password"
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
ProviderGoogle = "google"
)
// Service exposes authentication business operations to HTTP handlers. // Service exposes authentication business operations to HTTP handlers.
type Service struct { type Service struct {
@ -91,6 +100,43 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
Email: email, Email: email,
PasswordSalt: salt, PasswordSalt: salt,
PasswordHash: hash, PasswordHash: hash,
Provider: ProviderPassword,
CreatedAt: time.Now().UTC(),
}
if err := s.store.Create(ctx, user); err != nil {
return nil, err
}
return &user, nil
}
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) {
if email.IsZero() {
return nil, ErrInvalidInput
}
if strings.TrimSpace(provider) == "" {
return nil, ErrProviderRequired
}
account, err := s.store.FindByEmail(ctx, email)
switch {
case err == nil:
return account, nil
case !errors.Is(err, ErrUserNotFound):
return nil, err
}
id, err := generateUserID()
if err != nil {
return nil, fmt.Errorf("generate user id: %w", err)
}
user := User{
ID: id,
Email: email,
Provider: provider,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }

View file

@ -18,7 +18,7 @@ func TestServiceAuthenticate(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("hash password: %v", err) t.Fatalf("hash password: %v", err)
} }
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil { if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err) t.Fatalf("seed user: %v", err)
} }
@ -70,7 +70,7 @@ func TestServiceLookupByEmail(t *testing.T) {
service := NewService(store) service := NewService(store)
email := MustUserEmail("lookup@example.com") email := MustUserEmail("lookup@example.com")
if err := store.Create(ctx, User{Email: email}); err != nil { if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err) t.Fatalf("seed user: %v", err)
} }
@ -116,7 +116,7 @@ func TestServiceRegister(t *testing.T) {
service := NewService(store) service := NewService(store)
email := MustUserEmail("taken@example.com") email := MustUserEmail("taken@example.com")
if err := store.Create(ctx, User{Email: email}); err != nil { if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
t.Fatalf("seed user: %v", err) t.Fatalf("seed user: %v", err)
} }
@ -166,3 +166,65 @@ func TestServiceRegister(t *testing.T) {
}) })
} }
} }
func TestServiceEnsureExternalUser(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := NewMemoryStore()
service := NewService(store)
googleEmail := MustUserEmail("google@example.com")
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil {
t.Fatalf("seed external user: %v", err)
}
tests := map[string]struct {
email UserEmail
provider string
wantErr error
wantNew bool
}{
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput},
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired},
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false},
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider)
if tc.wantErr != nil {
if !errors.Is(err, tc.wantErr) {
t.Fatalf("expected %v, got %v", tc.wantErr, err)
}
if user != nil {
t.Fatalf("expected nil user, got %#v", user)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user == nil {
t.Fatal("expected user")
}
if user.Email != tc.email {
t.Fatalf("expected email %q, got %q", tc.email, user.Email)
}
if user.Provider != tc.provider {
t.Fatalf("expected provider %q, got %q", tc.provider, user.Provider)
}
persisted, err := store.FindByEmail(ctx, tc.email)
if err != nil {
t.Fatalf("expected user persisted: %v", err)
}
if tc.wantNew && persisted.CreatedAt.IsZero() {
t.Fatal("expected created at timestamp for new user")
}
})
}
}

View file

@ -14,6 +14,7 @@ type User struct {
Email UserEmail Email UserEmail
PasswordSalt string PasswordSalt string
PasswordHash string PasswordHash string
Provider string
CreatedAt time.Time CreatedAt time.Time
} }

View file

@ -52,8 +52,14 @@
</div> </div>
<div class="auth-actions"> <div class="auth-actions">
<button type="submit" class="primary">Log in</button> <button type="submit" class="primary">Log in</button>
{{if .GoogleLoginEnabled}}
<div class="auth-divider">or</div> <div class="auth-divider">or</div>
<button type="button" class="secondary outline auth-google"> <button
type="submit"
class="secondary outline auth-google"
form="google_login_form"
formnovalidate
>
<svg <svg
width="18" width="18"
height="18" height="18"
@ -81,8 +87,12 @@
</svg> </svg>
Continue with Google Continue with Google
</button> </button>
{{end}}
</div> </div>
</form> </form>
{{if .GoogleLoginEnabled}}
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
{{end}}
<p class="auth-footer"> <p class="auth-footer">
Don't have an account? <a href="/signup">Sign up</a> Don't have an account? <a href="/signup">Sign up</a>
</p> </p>

View file

@ -67,8 +67,14 @@
</label> </label>
<div class="auth-actions"> <div class="auth-actions">
<button type="submit" class="primary">Create account</button> <button type="submit" class="primary">Create account</button>
{{if .GoogleLoginEnabled}}
<div class="auth-divider">or</div> <div class="auth-divider">or</div>
<button type="button" class="secondary outline auth-google"> <button
type="submit"
class="secondary outline auth-google"
form="google_login_form"
formnovalidate
>
<svg <svg
width="18" width="18"
height="18" height="18"
@ -96,8 +102,12 @@
</svg> </svg>
Sign up with Google Sign up with Google
</button> </button>
{{end}}
</div> </div>
</form> </form>
{{if .GoogleLoginEnabled}}
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
{{end}}
<p class="auth-footer"> <p class="auth-footer">
Have an account? <a href="/">Log in</a> Have an account? <a href="/">Log in</a>
</p> </p>