mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 00:16:40 +00:00
feat: add google oauth login
This commit is contained in:
parent
a20953cfb4
commit
52f479206c
17 changed files with 625 additions and 31 deletions
5
go.mod
5
go.mod
|
|
@ -3,3 +3,8 @@ module github.com/rjnemo/auth
|
|||
go 1.25.1
|
||||
|
||||
require github.com/go-chi/chi/v5 v5.2.3
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
golang.org/x/oauth2 v0.31.0 // indirect
|
||||
)
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -1,2 +1,6 @@
|
|||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
|
||||
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
|
|
|
|||
|
|
@ -11,10 +11,13 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
envListenAddr = "AUTH_LISTEN_ADDR"
|
||||
envLogMode = "AUTH_LOG_MODE"
|
||||
envEnvironment = "AUTH_ENV"
|
||||
envSessionSecret = "AUTH_SESSION_SECRET"
|
||||
envListenAddr = "AUTH_LISTEN_ADDR"
|
||||
envLogMode = "AUTH_LOG_MODE"
|
||||
envEnvironment = "AUTH_ENV"
|
||||
envSessionSecret = "AUTH_SESSION_SECRET"
|
||||
envGoogleClientID = "AUTH_GOOGLE_CLIENT_ID"
|
||||
envGoogleClientSecret = "AUTH_GOOGLE_CLIENT_SECRET"
|
||||
envGoogleRedirectURL = "AUTH_GOOGLE_REDIRECT_URL"
|
||||
|
||||
defaultListenAddr = ":8000"
|
||||
defaultEnvironment = "development"
|
||||
|
|
@ -26,6 +29,19 @@ type Config struct {
|
|||
LogMode logging.Mode
|
||||
Environment string
|
||||
SessionSecret []byte
|
||||
GoogleOAuth GoogleOAuthConfig
|
||||
}
|
||||
|
||||
// GoogleOAuthConfig holds configuration for Google OAuth2 login.
|
||||
type GoogleOAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURL string
|
||||
}
|
||||
|
||||
// Enabled reports whether Google OAuth2 is fully configured.
|
||||
func (g GoogleOAuthConfig) Enabled() bool {
|
||||
return g.ClientID != "" && g.ClientSecret != "" && g.RedirectURL != ""
|
||||
}
|
||||
|
||||
// New loads configuration from environment variables, applying defaults and validation.
|
||||
|
|
@ -47,12 +63,34 @@ func New() (*Config, error) {
|
|||
return nil, fmt.Errorf("invalid %s: %w", envSessionSecret, err)
|
||||
}
|
||||
|
||||
googleOAuth := GoogleOAuthConfig{
|
||||
ClientID: strings.TrimSpace(os.Getenv(envGoogleClientID)),
|
||||
ClientSecret: strings.TrimSpace(os.Getenv(envGoogleClientSecret)),
|
||||
RedirectURL: strings.TrimSpace(os.Getenv(envGoogleRedirectURL)),
|
||||
}
|
||||
|
||||
if partiallyConfigured(googleOAuth) {
|
||||
return nil, fmt.Errorf("incomplete google oauth configuration: set %s, %s, and %s", envGoogleClientID, envGoogleClientSecret, envGoogleRedirectURL)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
ListenAddr: listenAddr,
|
||||
LogMode: logMode,
|
||||
Environment: environment,
|
||||
SessionSecret: secret,
|
||||
GoogleOAuth: googleOAuth,
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func partiallyConfigured(cfg GoogleOAuthConfig) bool {
|
||||
switch {
|
||||
case cfg.ClientID == "" && cfg.ClientSecret == "" && cfg.RedirectURL == "":
|
||||
return false
|
||||
case cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.RedirectURL == "":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,6 +73,33 @@ func TestNewShortSecretAccepted(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewGoogleOAuthPartialConfiguration(t *testing.T) {
|
||||
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
|
||||
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
|
||||
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "")
|
||||
if _, err := New(); err == nil {
|
||||
t.Fatalf("expected error for partial google oauth config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGoogleOAuthConfigured(t *testing.T) {
|
||||
t.Setenv("AUTH_SESSION_SECRET", base64.StdEncoding.EncodeToString(bytesOfLength(32)))
|
||||
t.Setenv("AUTH_GOOGLE_CLIENT_ID", "client")
|
||||
t.Setenv("AUTH_GOOGLE_CLIENT_SECRET", "secret")
|
||||
t.Setenv("AUTH_GOOGLE_REDIRECT_URL", "http://localhost:8000/login/google/callback")
|
||||
|
||||
cfg, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !cfg.GoogleOAuth.Enabled() {
|
||||
t.Fatalf("expected google oauth to be enabled")
|
||||
}
|
||||
if cfg.GoogleOAuth.ClientID != "client" {
|
||||
t.Fatalf("expected client id to match, got %q", cfg.GoogleOAuth.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesOfLength(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ func (s *Server) loginPageHandler() http.HandlerFunc {
|
|||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
s.render(w, "login.html", newLoginData(state.Email, "", state.CSRFToken))
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, "", state.CSRFToken)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
|||
email, err := auth.NewUserEmail(emailInput)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "login.html", newLoginData("", credentialRequiredMsg, state.CSRFToken))
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData("", credentialRequiredMsg, state.CSRFToken)))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -51,10 +51,10 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
|||
|
||||
case errors.Is(err, auth.ErrWeakPassword):
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "login.html", newLoginData(email.String(), weakPasswordMsg, state.CSRFToken))
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), weakPasswordMsg, state.CSRFToken)))
|
||||
case errors.Is(err, auth.ErrInvalidInput):
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "login.html", newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken))
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), credentialRequiredMsg, state.CSRFToken)))
|
||||
case errors.Is(err, auth.ErrInvalidCredentials):
|
||||
s.renderLoginFailure(w, email, state.CSRFToken)
|
||||
default:
|
||||
|
|
@ -66,5 +66,5 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
|||
|
||||
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
s.render(w, "login.html", newLoginData(email.String(), invalidCredentialsMsg, token))
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(email.String(), invalidCredentialsMsg, token)))
|
||||
}
|
||||
|
|
|
|||
200
internal/server/handler_login_google.go
Normal file
200
internal/server/handler_login_google.go
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/rjnemo/auth/internal/service/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
googleAuthFailedMsg = "Unable to sign in with Google. Please try again."
|
||||
googleAuthCanceledMsg = "Google sign-in was cancelled."
|
||||
)
|
||||
|
||||
type googleUserInfo struct {
|
||||
ID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
VerifiedEmail bool `json:"email_verified"`
|
||||
}
|
||||
|
||||
func (s *Server) googleLoginHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
logger := s.logger.With(slog.String("component", "google_oauth"))
|
||||
if s.googleOAuth == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
state := sessionFromContext(r.Context())
|
||||
if state.Authenticated {
|
||||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := generateOAuthState()
|
||||
if err != nil {
|
||||
logger.Error("generate oauth state failed", slog.Any("error", err))
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
state.OAuthState = token
|
||||
if err := s.sessions.Save(w, state); err != nil {
|
||||
logger.Error("persist oauth state failed", slog.Any("error", err))
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL := s.googleOAuth.AuthCodeURL(token, oauth2.AccessTypeOnline)
|
||||
http.Redirect(w, r, redirectURL, http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) googleCallbackHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
logger := s.logger.With(slog.String("component", "google_oauth"))
|
||||
if s.googleOAuth == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
state := sessionFromContext(r.Context())
|
||||
expectedState := state.OAuthState
|
||||
providedState := r.URL.Query().Get("state")
|
||||
state.OAuthState = ""
|
||||
|
||||
saveState := func() bool {
|
||||
if err := s.sessions.Save(w, state); err != nil {
|
||||
logger.Error("session save failed", slog.Any("error", err))
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
respondWithLogin := func(status int, message string) {
|
||||
if status != 0 {
|
||||
w.WriteHeader(status)
|
||||
}
|
||||
s.render(w, "login.html", s.applyOAuthOptions(newLoginData(state.Email, message, state.CSRFToken)))
|
||||
}
|
||||
|
||||
if expectedState == "" || providedState == "" || providedState != expectedState {
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
http.Error(w, "invalid oauth state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||
logger.Info("google oauth returned error", slog.String("google_error", errParam))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusBadRequest, googleAuthCanceledMsg)
|
||||
return
|
||||
}
|
||||
|
||||
authCode := r.URL.Query().Get("code")
|
||||
if authCode == "" {
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
http.Error(w, "missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := s.googleOAuth.Exchange(r.Context(), authCode)
|
||||
if err != nil {
|
||||
logger.Error("oauth code exchange failed", slog.Any("error", err))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.fetchGoogleUserInfo(r.Context(), token)
|
||||
if err != nil {
|
||||
logger.Error("fetch google user info failed", slog.Any("error", err))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||
return
|
||||
}
|
||||
|
||||
if !info.VerifiedEmail || info.Email == "" {
|
||||
logger.Warn("google returned unverified email", slog.Bool("verified", info.VerifiedEmail))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||
return
|
||||
}
|
||||
|
||||
email, err := auth.NewUserEmail(info.Email)
|
||||
if err != nil {
|
||||
logger.Error("normalize google email failed", slog.Any("error", err))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
respondWithLogin(http.StatusUnauthorized, googleAuthFailedMsg)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := s.authService.EnsureExternalUser(r.Context(), email, auth.ProviderGoogle)
|
||||
if err != nil {
|
||||
logger.Error("ensure external user failed", slog.Any("error", err))
|
||||
if !saveState() {
|
||||
return
|
||||
}
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
state.Authenticated = true
|
||||
state.Email = account.Email.String()
|
||||
if err := s.sessions.Save(w, state); err != nil {
|
||||
logger.Error("session save failed", slog.Any("error", err))
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token) (googleUserInfo, error) {
|
||||
client := s.googleOAuth.Client(ctx, token)
|
||||
resp, err := client.Get(googleUserInfoEndpoint)
|
||||
if err != nil {
|
||||
return googleUserInfo{}, fmt.Errorf("request google userinfo: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := resp.Body.Close(); cerr != nil {
|
||||
s.logger.With(slog.String("component", "google_oauth")).Warn("close google userinfo body failed", slog.Any("error", cerr))
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return googleUserInfo{}, fmt.Errorf("google userinfo response %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var info googleUserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return googleUserInfo{}, fmt.Errorf("decode google userinfo: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
|
@ -24,7 +24,7 @@ func (s *Server) signupPageHandler() http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
s.render(w, "signup.html", newSignupData(state.Email, "", state.CSRFToken))
|
||||
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(state.Email, "", state.CSRFToken)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ func (s *Server) signupHandler() http.HandlerFunc {
|
|||
email, err := auth.NewUserEmail(emailValue)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "signup.html", newSignupData("", credentialRequiredMsg, state.CSRFToken))
|
||||
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData("", credentialRequiredMsg, state.CSRFToken)))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -59,13 +59,13 @@ func (s *Server) signupHandler() http.HandlerFunc {
|
|||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||||
case errors.Is(err, auth.ErrWeakPassword):
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "signup.html", newSignupData(email.String(), weakPasswordMsg, state.CSRFToken))
|
||||
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), weakPasswordMsg, state.CSRFToken)))
|
||||
case errors.Is(err, auth.ErrInvalidInput):
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "signup.html", newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken))
|
||||
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), credentialRequiredMsg, state.CSRFToken)))
|
||||
case errors.Is(err, auth.ErrEmailExists):
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
s.render(w, "signup.html", newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken))
|
||||
s.render(w, "signup.html", s.applyOAuthOptions(newSignupData(email.String(), duplicateEmailMsg, state.CSRFToken)))
|
||||
default:
|
||||
logger.Error("register failed", slog.Any("error", err))
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
func (s *Server) registerRoutes(r chi.Router) {
|
||||
r.Get("/", s.loginPageHandler())
|
||||
r.Post("/login", s.loginHandler())
|
||||
r.Get("/login/google", s.googleLoginHandler())
|
||||
r.Get("/login/google/callback", s.googleCallbackHandler())
|
||||
r.Post("/logout", s.logoutHandler())
|
||||
r.Get("/signup", s.signupPageHandler())
|
||||
r.Post("/signup", s.signupHandler())
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import (
|
|||
"github.com/rjnemo/auth/internal/driver/logging"
|
||||
"github.com/rjnemo/auth/internal/service/auth"
|
||||
"github.com/rjnemo/auth/web"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -26,6 +28,7 @@ type Server struct {
|
|||
sessions *SessionStore
|
||||
logger *slog.Logger
|
||||
configuration config.Config
|
||||
googleOAuth *oauth2.Config
|
||||
}
|
||||
|
||||
// New constructs a Server with parsed templates and default state.
|
||||
|
|
@ -57,12 +60,27 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
|||
}
|
||||
logger = logger.With(slog.String("service", "http"))
|
||||
|
||||
var googleOAuthConfig *oauth2.Config
|
||||
if cfg.GoogleOAuth.Enabled() {
|
||||
googleOAuthConfig = &oauth2.Config{
|
||||
ClientID: cfg.GoogleOAuth.ClientID,
|
||||
ClientSecret: cfg.GoogleOAuth.ClientSecret,
|
||||
RedirectURL: cfg.GoogleOAuth.RedirectURL,
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
templates: tmpl,
|
||||
authService: auth.NewService(store),
|
||||
sessions: sessionStore,
|
||||
logger: logger,
|
||||
configuration: cfg,
|
||||
googleOAuth: googleOAuthConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -80,6 +98,7 @@ func seedUser(store auth.UserStore) error {
|
|||
Email: email,
|
||||
PasswordSalt: salt,
|
||||
PasswordHash: hash,
|
||||
Provider: auth.ProviderPassword,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,30 @@ func newTestServer(t *testing.T) *Server {
|
|||
return srv
|
||||
}
|
||||
|
||||
func newGoogleTestServer(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.Config{
|
||||
ListenAddr: ":0",
|
||||
LogMode: logging.ModeText,
|
||||
Environment: "test",
|
||||
SessionSecret: bytes.Repeat([]byte("g"), 32),
|
||||
GoogleOAuth: config.GoogleOAuthConfig{
|
||||
ClientID: "client",
|
||||
ClientSecret: "secret",
|
||||
RedirectURL: "http://localhost/login/google/callback",
|
||||
},
|
||||
}
|
||||
|
||||
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("new google server: %v", err)
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func attachSession(req *http.Request, state SessionState) *http.Request {
|
||||
return req.WithContext(withSession(req.Context(), state))
|
||||
}
|
||||
|
|
@ -55,6 +79,26 @@ func TestLoginPageHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLoginPageHandlerIncludesGoogleLinkWhenConfigured(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newGoogleTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = attachSession(req, SessionState{CSRFToken: "token"})
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
srv.loginPageHandler()(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
body := rr.Body.String()
|
||||
if !strings.Contains(body, "id=\"google_login_form\"") {
|
||||
t.Fatalf("expected google login form in page, got %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
@ -115,6 +159,112 @@ func TestLoginHandlerInvalidCredentials(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGoogleLoginHandlerDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
srv.googleLoginHandler()(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 when google oauth disabled, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleLoginHandlerRedirects(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newGoogleTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/google", nil)
|
||||
req = attachSession(req, SessionState{CSRFToken: "csrf"})
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
srv.googleLoginHandler()(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
if res.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected 302 redirect, got %d", res.StatusCode)
|
||||
}
|
||||
location := res.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("expected redirect location header")
|
||||
}
|
||||
if !strings.Contains(location, "accounts.google.com") {
|
||||
t.Fatalf("expected google authorization URL, got %q", location)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("parse redirect url: %v", err)
|
||||
}
|
||||
stateParam := parsed.Query().Get("state")
|
||||
if stateParam == "" {
|
||||
t.Fatal("expected state parameter in redirect")
|
||||
}
|
||||
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range res.Cookies() {
|
||||
if c.Name == sessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("expected session cookie to be set")
|
||||
}
|
||||
|
||||
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("decode session: %v", err)
|
||||
}
|
||||
if savedState.OAuthState == "" {
|
||||
t.Fatal("expected oauth state stored in session")
|
||||
}
|
||||
if savedState.OAuthState != stateParam {
|
||||
t.Fatalf("expected oauth state %q to match redirect param %q", savedState.OAuthState, stateParam)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleCallbackHandlerStateMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newGoogleTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/google/callback?state=other&code=ignored", nil)
|
||||
req = attachSession(req, SessionState{OAuthState: "expected", CSRFToken: "csrf"})
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
srv.googleCallbackHandler()(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for state mismatch, got %d", rr.Code)
|
||||
}
|
||||
|
||||
res := rr.Result()
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range res.Cookies() {
|
||||
if c.Name == sessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("expected session cookie to be set")
|
||||
}
|
||||
|
||||
savedState, err := decodeSession(sessionCookie.Value, srv.configuration.SessionSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("decode session: %v", err)
|
||||
}
|
||||
if savedState.OAuthState != "" {
|
||||
t.Fatal("expected oauth state to be cleared after mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignupHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ const (
|
|||
sessionLifetime = 12 * time.Hour
|
||||
sessionSecretMinLength = 32
|
||||
csrfTokenByteLength int = 32
|
||||
oauthStateByteLength int = 32
|
||||
)
|
||||
|
||||
// SessionStore persists session data using secure HTTP cookies.
|
||||
|
|
@ -33,9 +34,10 @@ func NewSessionStore(secret []byte) (*SessionStore, error) {
|
|||
|
||||
// SessionState holds per-request session data after loading.
|
||||
type SessionState struct {
|
||||
Authenticated bool
|
||||
Email string
|
||||
CSRFToken string
|
||||
Authenticated bool `json:"authenticated"`
|
||||
Email string `json:"email"`
|
||||
CSRFToken string `json:"csrf_token"`
|
||||
OAuthState string `json:"oauth_state"`
|
||||
}
|
||||
|
||||
// Load extracts session data from the request cookies.
|
||||
|
|
@ -99,3 +101,11 @@ func ensureCSRFToken(state SessionState) (SessionState, error) {
|
|||
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func generateOAuthState() (string, error) {
|
||||
buf := make([]byte, oauthStateByteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,20 +17,30 @@ func (s *Server) render(w http.ResponseWriter, name string, data any) {
|
|||
|
||||
// PageData contains fields shared by the templates for now.
|
||||
type PageData struct {
|
||||
Title string
|
||||
View string
|
||||
Email string
|
||||
Error string
|
||||
Info string
|
||||
CSRFToken string
|
||||
CreatedAt string
|
||||
CreatedAtISO string
|
||||
Title string
|
||||
View string
|
||||
Email string
|
||||
Error string
|
||||
Info string
|
||||
CSRFToken string
|
||||
CreatedAt string
|
||||
CreatedAtISO string
|
||||
GoogleLoginURL string
|
||||
GoogleLoginEnabled bool
|
||||
}
|
||||
|
||||
func newLoginData(email, errMsg, token string) PageData {
|
||||
return PageData{Title: "Sign in · Auth Demo", View: "login", Email: email, Error: errMsg, CSRFToken: token}
|
||||
}
|
||||
|
||||
func (s *Server) applyOAuthOptions(data PageData) PageData {
|
||||
if s.googleOAuth != nil {
|
||||
data.GoogleLoginEnabled = true
|
||||
data.GoogleLoginURL = "/login/google"
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func newUnauthorizedData(errMsg, token string) PageData {
|
||||
return PageData{
|
||||
Title: "Access denied · Auth Demo",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -14,9 +15,17 @@ var (
|
|||
ErrInvalidCredentials = errors.New("auth: invalid credentials")
|
||||
// ErrEmailExists indicates an account already uses the provided email address.
|
||||
ErrEmailExists = errors.New("auth: email already registered")
|
||||
// ErrProviderRequired indicates the external provider identifier was missing.
|
||||
ErrProviderRequired = errors.New("auth: provider required")
|
||||
)
|
||||
|
||||
const userIDByteLength = 16
|
||||
const (
|
||||
userIDByteLength = 16
|
||||
// ProviderPassword identifies accounts managed via email/password.
|
||||
ProviderPassword = "password"
|
||||
// ProviderGoogle identifies accounts authenticated via Google OAuth2.
|
||||
ProviderGoogle = "google"
|
||||
)
|
||||
|
||||
// Service exposes authentication business operations to HTTP handlers.
|
||||
type Service struct {
|
||||
|
|
@ -91,6 +100,7 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
|
|||
Email: email,
|
||||
PasswordSalt: salt,
|
||||
PasswordHash: hash,
|
||||
Provider: ProviderPassword,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
|
|
@ -100,3 +110,39 @@ func (s *Service) Register(ctx context.Context, email UserEmail, password string
|
|||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// EnsureExternalUser retrieves or provisions an account authenticated by an external provider.
|
||||
func (s *Service) EnsureExternalUser(ctx context.Context, email UserEmail, provider string) (*User, error) {
|
||||
if email.IsZero() {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
if strings.TrimSpace(provider) == "" {
|
||||
return nil, ErrProviderRequired
|
||||
}
|
||||
|
||||
account, err := s.store.FindByEmail(ctx, email)
|
||||
switch {
|
||||
case err == nil:
|
||||
return account, nil
|
||||
case !errors.Is(err, ErrUserNotFound):
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := generateUserID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate user id: %w", err)
|
||||
}
|
||||
|
||||
user := User{
|
||||
ID: id,
|
||||
Email: email,
|
||||
Provider: provider,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.store.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ func TestServiceAuthenticate(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("hash password: %v", err)
|
||||
}
|
||||
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil {
|
||||
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash, Provider: ProviderPassword}); err != nil {
|
||||
t.Fatalf("seed user: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ func TestServiceLookupByEmail(t *testing.T) {
|
|||
service := NewService(store)
|
||||
|
||||
email := MustUserEmail("lookup@example.com")
|
||||
if err := store.Create(ctx, User{Email: email}); err != nil {
|
||||
if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
|
||||
t.Fatalf("seed user: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ func TestServiceRegister(t *testing.T) {
|
|||
service := NewService(store)
|
||||
|
||||
email := MustUserEmail("taken@example.com")
|
||||
if err := store.Create(ctx, User{Email: email}); err != nil {
|
||||
if err := store.Create(ctx, User{Email: email, Provider: ProviderPassword}); err != nil {
|
||||
t.Fatalf("seed user: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -166,3 +166,65 @@ func TestServiceRegister(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceEnsureExternalUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
store := NewMemoryStore()
|
||||
service := NewService(store)
|
||||
|
||||
googleEmail := MustUserEmail("google@example.com")
|
||||
if err := store.Create(ctx, User{Email: googleEmail, Provider: ProviderGoogle}); err != nil {
|
||||
t.Fatalf("seed external user: %v", err)
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
email UserEmail
|
||||
provider string
|
||||
wantErr error
|
||||
wantNew bool
|
||||
}{
|
||||
"missing email": {email: UserEmail(""), provider: ProviderGoogle, wantErr: ErrInvalidInput},
|
||||
"missing provider": {email: MustUserEmail("new@example.com"), provider: "", wantErr: ErrProviderRequired},
|
||||
"existing": {email: googleEmail, provider: ProviderGoogle, wantErr: nil, wantNew: false},
|
||||
"provision": {email: MustUserEmail("brandnew@example.com"), provider: ProviderGoogle, wantErr: nil, wantNew: true},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
user, err := service.EnsureExternalUser(ctx, tc.email, tc.provider)
|
||||
if tc.wantErr != nil {
|
||||
if !errors.Is(err, tc.wantErr) {
|
||||
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||
}
|
||||
if user != nil {
|
||||
t.Fatalf("expected nil user, got %#v", user)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
t.Fatal("expected user")
|
||||
}
|
||||
if user.Email != tc.email {
|
||||
t.Fatalf("expected email %q, got %q", tc.email, user.Email)
|
||||
}
|
||||
if user.Provider != tc.provider {
|
||||
t.Fatalf("expected provider %q, got %q", tc.provider, user.Provider)
|
||||
}
|
||||
persisted, err := store.FindByEmail(ctx, tc.email)
|
||||
if err != nil {
|
||||
t.Fatalf("expected user persisted: %v", err)
|
||||
}
|
||||
if tc.wantNew && persisted.CreatedAt.IsZero() {
|
||||
t.Fatal("expected created at timestamp for new user")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ type User struct {
|
|||
Email UserEmail
|
||||
PasswordSalt string
|
||||
PasswordHash string
|
||||
Provider string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,8 +52,14 @@
|
|||
</div>
|
||||
<div class="auth-actions">
|
||||
<button type="submit" class="primary">Log in</button>
|
||||
{{if .GoogleLoginEnabled}}
|
||||
<div class="auth-divider">or</div>
|
||||
<button type="button" class="secondary outline auth-google">
|
||||
<button
|
||||
type="submit"
|
||||
class="secondary outline auth-google"
|
||||
form="google_login_form"
|
||||
formnovalidate
|
||||
>
|
||||
<svg
|
||||
width="18"
|
||||
height="18"
|
||||
|
|
@ -81,8 +87,12 @@
|
|||
</svg>
|
||||
Continue with Google
|
||||
</button>
|
||||
{{end}}
|
||||
</div>
|
||||
</form>
|
||||
{{if .GoogleLoginEnabled}}
|
||||
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
|
||||
{{end}}
|
||||
<p class="auth-footer">
|
||||
Don't have an account? <a href="/signup">Sign up</a>
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -67,8 +67,14 @@
|
|||
</label>
|
||||
<div class="auth-actions">
|
||||
<button type="submit" class="primary">Create account</button>
|
||||
{{if .GoogleLoginEnabled}}
|
||||
<div class="auth-divider">or</div>
|
||||
<button type="button" class="secondary outline auth-google">
|
||||
<button
|
||||
type="submit"
|
||||
class="secondary outline auth-google"
|
||||
form="google_login_form"
|
||||
formnovalidate
|
||||
>
|
||||
<svg
|
||||
width="18"
|
||||
height="18"
|
||||
|
|
@ -96,8 +102,12 @@
|
|||
</svg>
|
||||
Sign up with Google
|
||||
</button>
|
||||
{{end}}
|
||||
</div>
|
||||
</form>
|
||||
{{if .GoogleLoginEnabled}}
|
||||
<form id="google_login_form" action="{{.GoogleLoginURL}}" method="get" hidden></form>
|
||||
{{end}}
|
||||
<p class="auth-footer">
|
||||
Have an account? <a href="/">Log in</a>
|
||||
</p>
|
||||
|
|
|
|||
Loading…
Reference in a new issue