diff --git a/go.mod b/go.mod index f42e4a7..91cdafb 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,8 @@ module github.com/rjnemo/auth go 1.25.1 require github.com/go-chi/chi/v5 v5.2.3 + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + golang.org/x/oauth2 v0.31.0 // indirect +) diff --git a/go.sum b/go.sum index 5bd7be3..412f8d5 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= +golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= diff --git a/internal/config/config.go b/internal/config/config.go index fb4d599..473a587 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,10 +11,13 @@ import ( ) const ( - envListenAddr = "AUTH_LISTEN_ADDR" - envLogMode = "AUTH_LOG_MODE" - envEnvironment = "AUTH_ENV" - envSessionSecret = "AUTH_SESSION_SECRET" + envListenAddr = "AUTH_LISTEN_ADDR" + envLogMode = "AUTH_LOG_MODE" + envEnvironment = "AUTH_ENV" + envSessionSecret = "AUTH_SESSION_SECRET" + envGoogleClientID = "AUTH_GOOGLE_CLIENT_ID" + envGoogleClientSecret = "AUTH_GOOGLE_CLIENT_SECRET" + envGoogleRedirectURL = "AUTH_GOOGLE_REDIRECT_URL" defaultListenAddr = ":8000" defaultEnvironment = "development" @@ -26,6 +29,19 @@ type Config struct { LogMode logging.Mode Environment string SessionSecret []byte + GoogleOAuth GoogleOAuthConfig +} + +// GoogleOAuthConfig holds configuration for Google OAuth2 login. +type GoogleOAuthConfig struct { + ClientID string + ClientSecret string + RedirectURL string +} + +// Enabled reports whether Google OAuth2 is fully configured. +func (g GoogleOAuthConfig) Enabled() bool { + return g.ClientID != "" && g.ClientSecret != "" && g.RedirectURL != "" } // New loads configuration from environment variables, applying defaults and validation. @@ -47,12 +63,34 @@ func New() (*Config, error) { return nil, fmt.Errorf("invalid %s: %w", envSessionSecret, err) } + googleOAuth := GoogleOAuthConfig{ + ClientID: strings.TrimSpace(os.Getenv(envGoogleClientID)), + ClientSecret: strings.TrimSpace(os.Getenv(envGoogleClientSecret)), + RedirectURL: strings.TrimSpace(os.Getenv(envGoogleRedirectURL)), + } + + if partiallyConfigured(googleOAuth) { + return nil, fmt.Errorf("incomplete google oauth configuration: set %s, %s, and %s", envGoogleClientID, envGoogleClientSecret, envGoogleRedirectURL) + } + cfg := &Config{ ListenAddr: listenAddr, LogMode: logMode, Environment: environment, SessionSecret: secret, + GoogleOAuth: googleOAuth, } return cfg, nil } + +func partiallyConfigured(cfg GoogleOAuthConfig) bool { + switch { + case cfg.ClientID == "" && cfg.ClientSecret == "" && cfg.RedirectURL == "": + return false + case cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.RedirectURL == "": + return true + default: + return false + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b2ba771..9305cf0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -73,6 +73,33 @@ func TestNewShortSecretAccepted(t *testing.T) { } } +func TestNewGoogleOAuthPartialConfiguration(t *testing.T) { + t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32))) + t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client") + t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "") + if _, err := New(); err == nil { + t.Fatalf("expected error for partial google oauth config") + } +} + +func TestNewGoogleOAuthConfigured(t *testing.T) { + t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32))) + t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client") + t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "secret") + t.Setenv("AUTH_GOOGLE_REDIRECT_URL", "http://localhost:8000/login/google/callback") + + cfg, err := New() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cfg.GoogleOAuth.Enabled() { + t.Fatalf("expected google oauth to be enabled") + } + if cfg.GoogleOAuth.ClientID != "client" { + t.Fatalf("expected client id to match, got %q", cfg.GoogleOAuth.ClientID) + } +} + func bytesOfLength(n int) []byte { b := make([]byte, n) for i := range b { diff --git a/internal/server/handler_login.go b/internal/server/handler_login.go index 0dea2b0..48cb5c3 100644 --- a/internal/server/handler_login.go +++ b/internal/server/handler_login.go @@ -15,7 +15,7 @@ func (s *Server) loginPageHandler() http.HandlerFunc { http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } - s.render(w, "login.html", newLoginData(state.Email, "", state.CSRFToken)) + s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, "", state.CSRFToken))) } } @@ -35,7 +35,7 @@ func (s *Server) loginHandler() http.HandlerFunc { email, err := auth.NewUserEmail(emailInput) if err != nil { w.WriteHeader(http.StatusBadRequest) - s.render(w, "login.html", newLoginData("", credentialRequiredMsg, state.CSRFToken)) + s.render(w, "login.html", s.applyOAuthOptions(newLoginData("", credentialRequiredMsg, state.CSRFToken))) return } @@ -51,10 +51,10 @@ func (s *Server) loginHandler() http.HandlerFunc { case errors.Is(err, auth.ErrWeakPassword): w.WriteHeader(http.StatusBadRequest) - s.render(w, "login.html", newLoginData(email.String(), weakPasswordMsg, state.CSRFToken)) + s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), weakPasswordMsg, state.CSRFToken))) case errors.Is(err, auth.ErrInvalidInput): w.WriteHeader(http.StatusBadRequest) - s.render(w, "login.html", newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken)) + s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken))) case errors.Is(err, auth.ErrInvalidCredentials): s.renderLoginFailure(w, email, state.CSRFToken) default: @@ -66,5 +66,5 @@ func (s *Server) loginHandler() http.HandlerFunc { func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) { w.WriteHeader(http.StatusUnauthorized) - s.render(w, "login.html", newLoginData(email.String(), invalidCredentialsMsg, token)) + s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), invalidCredentialsMsg, token))) } diff --git a/internal/server/handler_login_google.go b/internal/server/handler_login_google.go new file mode 100644 index 0000000..721c6ac --- /dev/null +++ b/internal/server/handler_login_google.go @@ -0,0 +1,200 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + + "golang.org/x/oauth2" + + "github.com/rjnemo/auth/internal/service/auth" +) + +const ( + googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo" + googleAuthFailedMsg = "Unable to sign in with Google. Please try again." + googleAuthCanceledMsg = "Google sign-in was cancelled." +) + +type googleUserInfo struct { + ID string `json:"sub"` + Email string `json:"email"` + VerifiedEmail bool `json:"email_verified"` +} + +func (s *Server) googleLoginHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger := s.logger.With(slog.String("component", "google_oauth")) + if s.googleOAuth == nil { + http.NotFound(w, r) + return + } + + state := sessionFromContext(r.Context()) + if state.Authenticated { + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) + return + } + + token, err := generateOAuthState() + if err != nil { + logger.Error("generate oauth state failed", slog.Any("error", err)) + http.Error(w, "unexpected error", http.StatusInternalServerError) + return + } + + state.OAuthState = token + if err := s.sessions.Save(w, state); err != nil { + logger.Error("persist oauth state failed", slog.Any("error", err)) + http.Error(w, "unexpected error", http.StatusInternalServerError) + return + } + + redirectURL := s.googleOAuth.AuthCodeURL(token, oauth2.AccessTypeOnline) + http.Redirect(w, r, redirectURL, http.StatusFound) + } +} + +func (s *Server) googleCallbackHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger := s.logger.With(slog.String("component", "google_oauth")) + if s.googleOAuth == nil { + http.NotFound(w, r) + return + } + + state := sessionFromContext(r.Context()) + expectedState := state.OAuthState + providedState := r.URL.Query().Get("state") + state.OAuthState = "" + + saveState := func() bool { + if err := s.sessions.Save(w, state); err != nil { + logger.Error("session save failed", slog.Any("error", err)) + http.Error(w, "unexpected error", http.StatusInternalServerError) + return false + } + return true + } + + respondWithLogin := func(status int, message string) { + if status != 0 { + w.WriteHeader(status) + } + s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, message, state.CSRFToken))) + } + + if expectedState == "" || providedState == "" || providedState != expectedState { + if !saveState() { + return + } + http.Error(w, "invalid oauth state", http.StatusBadRequest) + return + } + + if errParam := r.URL.Query().Get("error"); errParam != "" { + logger.Info("google oauth returned error", slog.String("google_error", errParam)) + if !saveState() { + return + } + respondWithLogin(http.StatusBadRequest, googleAuthCanceledMsg) + return + } + + authCode := r.URL.Query().Get("code") + if authCode == "" { + if !saveState() { + return + } + http.Error(w, "missing authorization code", http.StatusBadRequest) + return + } + + token, err := s.googleOAuth.Exchange(r.Context(), authCode) + if err != nil { + logger.Error("oauth code exchange failed", slog.Any("error", err)) + if !saveState() { + return + } + respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg) + return + } + + info, err := s.fetchGoogleUserInfo(r.Context(), token) + if err != nil { + logger.Error("fetch google user info failed", slog.Any("error", err)) + if !saveState() { + return + } + respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg) + return + } + + if !info.VerifiedEmail || info.Email == "" { + logger.Warn("google returned unverified email", slog.Bool("verified", info.VerifiedEmail)) + if !saveState() { + return + } + respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg) + return + } + + email, err := auth.NewUserEmail(info.Email) + if err != nil { + logger.Error("normalize google email failed", slog.Any("error", err)) + if !saveState() { + return + } + respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg) + return + } + + account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle) + if err != nil { + logger.Error("ensure external user failed", slog.Any("error", err)) + if !saveState() { + return + } + http.Error(w, "unexpected error", http.StatusInternalServerError) + return + } + + state.Authenticated = true + state.Email = account.Email.String() + if err := s.sessions.Save(w, state); err != nil { + logger.Error("session save failed", slog.Any("error", err)) + http.Error(w, "unexpected error", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) + } +} + +func (s *Server) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token) (googleUserInfo, error) { + client := s.googleOAuth.Client(ctx, token) + resp, err := client.Get(googleUserInfoEndpoint) + if err != nil { + return googleUserInfo{}, fmt.Errorf("request google userinfo: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + s.logger.With(slog.String("component", "google_oauth")).Warn("close google userinfo body failed", slog.Any("error", cerr)) + } + }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return googleUserInfo{}, fmt.Errorf("google userinfo response %d: %s", resp.StatusCode, string(body)) + } + + var info googleUserInfo + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return googleUserInfo{}, fmt.Errorf("decode google userinfo: %w", err) + } + + return info, nil +} diff --git a/internal/server/handler_signup.go b/internal/server/handler_signup.go index 43db116..2d68043 100644 --- a/internal/server/handler_signup.go +++ b/internal/server/handler_signup.go @@ -24,7 +24,7 @@ func (s *Server) signupPageHandler() http.HandlerFunc { return } - s.render(w, "signup.html", newSignupData(state.Email, "", state.CSRFToken)) + s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(state.Email, "", state.CSRFToken))) } } @@ -44,7 +44,7 @@ func (s *Server) signupHandler() http.HandlerFunc { email, err := auth.NewUserEmail(emailValue) if err != nil { w.WriteHeader(http.StatusBadRequest) - s.render(w, "signup.html", newSignupData("", credentialRequiredMsg, state.CSRFToken)) + s.render(w, "signup.html", s.applyOAuthOptions(newSignupData("", credentialRequiredMsg, state.CSRFToken))) return } @@ -59,13 +59,13 @@ func (s *Server) signupHandler() http.HandlerFunc { http.Redirect(w, r, "/dashboard", http.StatusSeeOther) case errors.Is(err, auth.ErrWeakPassword): w.WriteHeader(http.StatusBadRequest) - s.render(w, "signup.html", newSignupData(email.String(), weakPasswordMsg, state.CSRFToken)) + s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), weakPasswordMsg, state.CSRFToken))) case errors.Is(err, auth.ErrInvalidInput): w.WriteHeader(http.StatusBadRequest) - s.render(w, "signup.html", newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken)) + s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken))) case errors.Is(err, auth.ErrEmailExists): w.WriteHeader(http.StatusConflict) - s.render(w, "signup.html", newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken)) + s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken))) default: logger.Error("register failed", slog.Any("error", err)) http.Error(w, "unexpected error", http.StatusInternalServerError) diff --git a/internal/server/routes.go b/internal/server/routes.go index dbce1cf..c5b6a96 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -10,6 +10,8 @@ import ( func (s *Server) registerRoutes(r chi.Router) { r.Get("/", s.loginPageHandler()) r.Post("/login", s.loginHandler()) + r.Get("/login/google", s.googleLoginHandler()) + r.Get("/login/google/callback", s.googleCallbackHandler()) r.Post("/logout", s.logoutHandler()) r.Get("/signup", s.signupPageHandler()) r.Post("/signup", s.signupHandler()) diff --git a/internal/server/server.go b/internal/server/server.go index a270a10..20aab2e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -12,6 +12,8 @@ import ( "github.com/rjnemo/auth/internal/driver/logging" "github.com/rjnemo/auth/internal/service/auth" "github.com/rjnemo/auth/web" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" ) const ( @@ -26,6 +28,7 @@ type Server struct { sessions *SessionStore logger *slog.Logger configuration config.Config + googleOAuth *oauth2.Config } // New constructs a Server with parsed templates and default state. @@ -57,12 +60,27 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) { } logger = logger.With(slog.String("service", "http")) + var googleOAuthConfig *oauth2.Config + if cfg.GoogleOAuth.Enabled() { + googleOAuthConfig = &oauth2.Config{ + ClientID: cfg.GoogleOAuth.ClientID, + ClientSecret: cfg.GoogleOAuth.ClientSecret, + RedirectURL: cfg.GoogleOAuth.RedirectURL, + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + }, + Endpoint: google.Endpoint, + } + } + return &Server{ templates: tmpl, authService: auth.NewService(store), sessions: sessionStore, logger: logger, configuration: cfg, + googleOAuth: googleOAuthConfig, }, nil } @@ -80,6 +98,7 @@ func seedUser(store auth.UserStore) error { Email: email, PasswordSalt: salt, PasswordHash: hash, + Provider: auth.ProviderPassword, CreatedAt: time.Now().UTC(), }) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 194ea52..8bedde6 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -32,6 +32,30 @@ func newTestServer(t *testing.T) *Server { return srv } +func newGoogleTestServer(t *testing.T) *Server { + t.Helper() + + cfg := config.Config{ + ListenAddr: ":0", + LogMode: logging.ModeText, + Environment: "test", + SessionSecret: bytes.Repeat([]byte("g"), 32), + GoogleOAuth: config.GoogleOAuthConfig{ + ClientID: "client", + ClientSecret: "secret", + RedirectURL: "http://localhost/login/google/callback", + }, + } + + logger := logging.New(io.Discard, logging.ModeText, nil) + + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("new google server: %v", err) + } + return srv +} + func attachSession(req *http.Request, state SessionState) *http.Request { return req.WithContext(withSession(req.Context(), state)) } @@ -55,6 +79,26 @@ func TestLoginPageHandler(t *testing.T) { } } +func TestLoginPageHandlerIncludesGoogleLinkWhenConfigured(t *testing.T) { + t.Parallel() + + srv := newGoogleTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = attachSession(req, SessionState{CSRFToken: "token"}) + rr := httptest.NewRecorder() + + srv.loginPageHandler()(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + body := rr.Body.String() + if !strings.Contains(body, "id=\"google_login_form\"") { + t.Fatalf("expected google login form in page, got %q", body) + } +} + func TestLoginHandlerSuccess(t *testing.T) { t.Parallel() @@ -115,6 +159,112 @@ func TestLoginHandlerInvalidCredentials(t *testing.T) { } } +func TestGoogleLoginHandlerDisabled(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/login/google", nil) + rr := httptest.NewRecorder() + + srv.googleLoginHandler()(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("expected 404 when google oauth disabled, got %d", rr.Code) + } +} + +func TestGoogleLoginHandlerRedirects(t *testing.T) { + t.Parallel() + + srv := newGoogleTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/login/google", nil) + req = attachSession(req, SessionState{CSRFToken: "csrf"}) + rr := httptest.NewRecorder() + + srv.googleLoginHandler()(rr, req) + + res := rr.Result() + if res.StatusCode != http.StatusFound { + t.Fatalf("expected 302 redirect, got %d", res.StatusCode) + } + location := res.Header.Get("Location") + if location == "" { + t.Fatal("expected redirect location header") + } + if !strings.Contains(location, "accounts.google.com") { + t.Fatalf("expected google authorization URL, got %q", location) + } + + parsed, err := url.Parse(location) + if err != nil { + t.Fatalf("parse redirect url: %v", err) + } + stateParam := parsed.Query().Get("state") + if stateParam == "" { + t.Fatal("expected state parameter in redirect") + } + + var sessionCookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name == sessionCookieName { + sessionCookie = c + break + } + } + if sessionCookie == nil { + t.Fatal("expected session cookie to be set") + } + + savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret) + if err != nil { + t.Fatalf("decode session: %v", err) + } + if savedState.OAuthState == "" { + t.Fatal("expected oauth state stored in session") + } + if savedState.OAuthState != stateParam { + t.Fatalf("expected oauth state %q to match redirect param %q", savedState.OAuthState, stateParam) + } +} + +func TestGoogleCallbackHandlerStateMismatch(t *testing.T) { + t.Parallel() + + srv := newGoogleTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/login/google/callback?state=other&code=ignored", nil) + req = attachSession(req, SessionState{OAuthState: "expected", CSRFToken: "csrf"}) + rr := httptest.NewRecorder() + + srv.googleCallbackHandler()(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for state mismatch, got %d", rr.Code) + } + + res := rr.Result() + var sessionCookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name == sessionCookieName { + sessionCookie = c + break + } + } + if sessionCookie == nil { + t.Fatal("expected session cookie to be set") + } + + savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret) + if err != nil { + t.Fatalf("decode session: %v", err) + } + if savedState.OAuthState != "" { + t.Fatal("expected oauth state to be cleared after mismatch") + } +} + func TestSignupHandlerSuccess(t *testing.T) { t.Parallel() diff --git a/internal/server/session.go b/internal/server/session.go index a2e8796..acbaa76 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -13,6 +13,7 @@ const ( sessionLifetime = 12 * time.Hour sessionSecretMinLength = 32 csrfTokenByteLength int = 32 + oauthStateByteLength int = 32 ) // SessionStore persists session data using secure HTTP cookies. @@ -33,9 +34,10 @@ func NewSessionStore(secret []byte) (*SessionStore, error) { // SessionState holds per-request session data after loading. type SessionState struct { - Authenticated bool - Email string - CSRFToken string + Authenticated bool `json:"authenticated"` + Email string `json:"email"` + CSRFToken string `json:"csrf_token"` + OAuthState string `json:"oauth_state"` } // Load extracts session data from the request cookies. @@ -99,3 +101,11 @@ func ensureCSRFToken(state SessionState) (SessionState, error) { state.CSRFToken = base64.RawURLEncoding.EncodeToString(token) return state, nil } + +func generateOAuthState() (string, error) { + buf := make([]byte, oauthStateByteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} diff --git a/internal/server/views.go b/internal/server/views.go index a7ea11d..47747c9 100644 --- a/internal/server/views.go +++ b/internal/server/views.go @@ -17,20 +17,30 @@ func (s *Server) render(w http.ResponseWriter, name string, data any) { // PageData contains fields shared by the templates for now. type PageData struct { - Title string - View string - Email string - Error string - Info string - CSRFToken string - CreatedAt string - CreatedAtISO string + Title string + View string + Email string + Error string + Info string + CSRFToken string + CreatedAt string + CreatedAtISO string + GoogleLoginURL string + GoogleLoginEnabled bool } func newLoginData(email, errMsg, token string) PageData { return PageData{Title: "Sign in · Auth Demo", View: "login", Email: email, Error: errMsg, CSRFToken: token} } +func (s *Server) applyOAuthOptions(data PageData) PageData { + if s.googleOAuth != nil { + data.GoogleLoginEnabled = true + data.GoogleLoginURL = "/login/google" + } + return data +} + func newUnauthorizedData(errMsg, token string) PageData { return PageData{ Title: "Access denied · Auth Demo", diff --git a/internal/service/auth/service.go b/internal/service/auth/service.go index 0c0a88e..5832909 100644 --- a/internal/service/auth/service.go +++ b/internal/service/auth/service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" ) @@ -14,9 +15,17 @@ var ( ErrInvalidCredentials = errors.New("auth: invalid credentials") // ErrEmailExists indicates an account already uses the provided email address. ErrEmailExists = errors.New("auth: email already registered") + // ErrProviderRequired indicates the external provider identifier was missing. + ErrProviderRequired = errors.New("auth: provider required") ) -const userIDByteLength = 16 +const ( + userIDByteLength = 16 + // ProviderPassword identifies accounts managed via email/password. + ProviderPassword = "password" + // ProviderGoogle identifies accounts authenticated via Google OAuth2. + ProviderGoogle = "google" +) // Service exposes authentication business operations to HTTP handlers. type Service struct { @@ -91,6 +100,7 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string Email: email, PasswordSalt: salt, PasswordHash: hash, + Provider: ProviderPassword, CreatedAt: time.Now().UTC(), } @@ -100,3 +110,39 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string return &user, nil } + +// EnsureExternalUser retrieves or provisions an account authenticated by an external provider. +func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) { + if email.IsZero() { + return nil, ErrInvalidInput + } + if strings.TrimSpace(provider) == "" { + return nil, ErrProviderRequired + } + + account, err := s.store.FindByEmail(ctx, email) + switch { + case err == nil: + return account, nil + case !errors.Is(err, ErrUserNotFound): + return nil, err + } + + id, err := generateUserID() + if err != nil { + return nil, fmt.Errorf("generate user id: %w", err) + } + + user := User{ + ID: id, + Email: email, + Provider: provider, + CreatedAt: time.Now().UTC(), + } + + if err := s.store.Create(ctx, user); err != nil { + return nil, err + } + + return &user, nil +} diff --git a/internal/service/auth/service_test.go b/internal/service/auth/service_test.go index a9eb52f..b873b98 100644 --- a/internal/service/auth/service_test.go +++ b/internal/service/auth/service_test.go @@ -18,7 +18,7 @@ func TestServiceAuthenticate(t *testing.T) { if err != nil { t.Fatalf("hash password: %v", err) } - if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil { + if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash, Provider: ProviderPassword}); err != nil { t.Fatalf("seed user: %v", err) } @@ -70,7 +70,7 @@ func TestServiceLookupByEmail(t *testing.T) { service := NewService(store) email := MustUserEmail("lookup@example.com") - if err := store.Create(ctx, User{Email: email}); err != nil { + if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil { t.Fatalf("seed user: %v", err) } @@ -116,7 +116,7 @@ func TestServiceRegister(t *testing.T) { service := NewService(store) email := MustUserEmail("taken@example.com") - if err := store.Create(ctx, User{Email: email}); err != nil { + if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil { t.Fatalf("seed user: %v", err) } @@ -166,3 +166,65 @@ func TestServiceRegister(t *testing.T) { }) } } + +func TestServiceEnsureExternalUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := NewMemoryStore() + service := NewService(store) + + googleEmail := MustUserEmail("google@example.com") + if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil { + t.Fatalf("seed external user: %v", err) + } + + tests := map[string]struct { + email UserEmail + provider string + wantErr error + wantNew bool + }{ + "missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput}, + "missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired}, + "existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false}, + "provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true}, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + if user != nil { + t.Fatalf("expected nil user, got %#v", user) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user == nil { + t.Fatal("expected user") + } + if user.Email != tc.email { + t.Fatalf("expected email %q, got %q", tc.email, user.Email) + } + if user.Provider != tc.provider { + t.Fatalf("expected provider %q, got %q", tc.provider, user.Provider) + } + persisted, err := store.FindByEmail(ctx, tc.email) + if err != nil { + t.Fatalf("expected user persisted: %v", err) + } + if tc.wantNew && persisted.CreatedAt.IsZero() { + t.Fatal("expected created at timestamp for new user") + } + }) + } +} diff --git a/internal/service/auth/user.go b/internal/service/auth/user.go index 05042b0..f6a6e7b 100644 --- a/internal/service/auth/user.go +++ b/internal/service/auth/user.go @@ -14,6 +14,7 @@ type User struct { Email UserEmail PasswordSalt string PasswordHash string + Provider string CreatedAt time.Time } diff --git a/web/templates/login.html b/web/templates/login.html index d0d0b07..6853487 100644 --- a/web/templates/login.html +++ b/web/templates/login.html @@ -52,8 +52,14 @@
+ {{if .GoogleLoginEnabled}}
or
- + {{end}}
+ {{if .GoogleLoginEnabled}} + + {{end}} diff --git a/web/templates/signup.html b/web/templates/signup.html index 963f9a7..c5d7280 100644 --- a/web/templates/signup.html +++ b/web/templates/signup.html @@ -67,8 +67,14 @@
+ {{if .GoogleLoginEnabled}}
or
- + {{end}}
+ {{if .GoogleLoginEnabled}} + + {{end}}