From da1bf44d8f6f22cf0b69fac9a06c5e7b59fd6b7f Mon Sep 17 00:00:00 2001 From: Ruidy Date: Sat, 20 Sep 2025 01:08:03 +0200 Subject: [PATCH] feat: secure sessions and csrf --- AGENTS.md | 2 +- internal/server/handler_auth.go | 21 +++-- internal/server/handler_dashboard.go | 4 +- internal/server/handler_public.go | 2 +- internal/server/middleware.go | 56 +++++++++++- internal/server/server.go | 16 +++- internal/server/session.go | 132 +++++++++++++++++++-------- internal/server/session_encoding.go | 53 +++++++++++ internal/server/views.go | 13 +-- web/templates/in.html | 1 + web/templates/index.html | 1 + 11 files changed, 241 insertions(+), 60 deletions(-) create mode 100644 internal/server/session_encoding.go diff --git a/AGENTS.md b/AGENTS.md index ce42681..c6f2c2f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,7 +20,7 @@ Implement email/password authentication with secure password hashing, CSRF prote ## Coding Style & Naming Conventions -Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Template IDs and Alpine component names should reflect their role (e.g., `login_form`). +Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Promote named constants/variables over magic numbers or strings. Template IDs and Alpine component names should reflect their role (e.g., `login_form`). ## Testing Guidelines diff --git a/internal/server/handler_auth.go b/internal/server/handler_auth.go index 1d784ef..4be897f 100644 --- a/internal/server/handler_auth.go +++ b/internal/server/handler_auth.go @@ -10,6 +10,8 @@ import ( func (s *Server) loginHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + state := sessionFromContext(r.Context()) + if err := r.ParseForm(); err != nil { http.Error(w, "invalid form submission", http.StatusBadRequest) return @@ -21,20 +23,25 @@ func (s *Server) loginHandler() http.HandlerFunc { email, err := auth.NewUserEmail(emailInput) if err != nil { w.WriteHeader(http.StatusBadRequest) - s.render(w, "index.html", newIndexData("", "Email and password are required.")) + s.render(w, "index.html", newIndexData("", "Email and password are required.", state.CSRFToken)) return } account, err := s.authService.Authenticate(r.Context(), email, password) switch { case err == nil: - s.sessions.SetAuthenticated(account.Email.String()) + state.Authenticated = true + state.Email = account.Email.String() + if err := s.sessions.Save(w, state); err != nil { + log.Printf("session: save failed: %v", err) + } http.Redirect(w, r, "/in", http.StatusSeeOther) + case errors.Is(err, auth.ErrInvalidInput): w.WriteHeader(http.StatusBadRequest) - s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.")) + s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.", state.CSRFToken)) case errors.Is(err, auth.ErrInvalidCredentials): - s.renderLoginFailure(w, email) + s.renderLoginFailure(w, email, state.CSRFToken) default: log.Printf("auth: authenticate failed: %v", err) http.Error(w, "unexpected error", http.StatusInternalServerError) @@ -44,12 +51,12 @@ func (s *Server) loginHandler() http.HandlerFunc { func (s *Server) logoutHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - s.sessions.Clear() + s.sessions.Clear(w) http.Redirect(w, r, "/", http.StatusSeeOther) } } -func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail) { +func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) { w.WriteHeader(http.StatusUnauthorized) - s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.")) + s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.", token)) } diff --git a/internal/server/handler_dashboard.go b/internal/server/handler_dashboard.go index 167944b..02ee0bc 100644 --- a/internal/server/handler_dashboard.go +++ b/internal/server/handler_dashboard.go @@ -8,10 +8,10 @@ func (s *Server) dashboardHandler() http.HandlerFunc { if !state.Authenticated { w.WriteHeader(http.StatusUnauthorized) - s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.")) + s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.", state.CSRFToken)) return } - s.render(w, "in.html", PageData{Email: state.Email}) + s.render(w, "in.html", PageData{Email: state.Email, CSRFToken: state.CSRFToken}) } } diff --git a/internal/server/handler_public.go b/internal/server/handler_public.go index 7794f07..0c0fe51 100644 --- a/internal/server/handler_public.go +++ b/internal/server/handler_public.go @@ -5,6 +5,6 @@ import "net/http" func (s *Server) indexHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state := sessionFromContext(r.Context()) - s.render(w, "index.html", newIndexData(state.Email, "")) + s.render(w, "index.html", newIndexData(state.Email, "", state.CSRFToken)) } } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index d68b90b..ccfe88f 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -2,17 +2,61 @@ package server import ( "context" + "crypto/subtle" + "log" "net/http" ) func (s *Server) sessionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - state := s.sessions.Snapshot() + state := s.sessions.Load(r) + updated, err := ensureCSRFToken(state) + if err != nil { + log.Printf("session: csrf token generation failed: %v", err) + http.Error(w, "session error", http.StatusInternalServerError) + return + } + state = updated + + if err := s.sessions.Save(w, state); err != nil { + log.Printf("session: save failed: %v", err) + } + ctx := withSession(r.Context(), state) next.ServeHTTP(w, r.WithContext(ctx)) }) } +func (s *Server) csrfMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + next.ServeHTTP(w, r) + return + } + + state := sessionFromContext(r.Context()) + if state.CSRFToken == "" { + http.Error(w, "missing csrf token", http.StatusForbidden) + return + } + + token := r.Header.Get("X-CSRF-Token") + if token == "" { + if err := r.ParseForm(); err == nil { + token = r.Form.Get("_csrf") + } + } + + if !validCSRFToken(token, state.CSRFToken) { + http.Error(w, "invalid csrf token", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + type sessionContextKey struct{} func withSession(ctx context.Context, state SessionState) context.Context { @@ -28,3 +72,13 @@ func sessionFromContext(ctx context.Context) SessionState { } return SessionState{} } + +func validCSRFToken(provided, expected string) bool { + if provided == "" || expected == "" { + return false + } + if len(provided) != len(expected) { + return false + } + return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1 +} diff --git a/internal/server/server.go b/internal/server/server.go index 14eec47..9c60a53 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/rand" "fmt" "html/template" "net/http" @@ -23,7 +24,7 @@ const ( type Server struct { templates *template.Template authService *auth.Service - sessions *SessionManager + sessions *SessionStore } // New constructs a Server with parsed templates and default state. @@ -43,10 +44,20 @@ func New() (*Server, error) { return nil, fmt.Errorf("seed user: %w", err) } + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + return nil, fmt.Errorf("session secret: %w", err) + } + + sessionStore, err := NewSessionStore(secret) + if err != nil { + return nil, fmt.Errorf("session store: %w", err) + } + return &Server{ templates: tmpl, authService: auth.NewService(store), - sessions: NewSessionManager(), + sessions: sessionStore, }, nil } @@ -59,6 +70,7 @@ func (s *Server) Router() http.Handler { middleware.Logger, middleware.Recoverer, s.sessionMiddleware, + s.csrfMiddleware, ) s.registerRoutes(r) return r diff --git a/internal/server/session.go b/internal/server/session.go index 8a7f569..76f30d3 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -1,50 +1,102 @@ package server -import "sync" +import ( + "crypto/rand" + "encoding/base64" + "errors" + "net/http" + "time" +) -// SessionState represents the snapshot of session metadata for a request. +const ( + sessionCookieName = "auth_session" + csrfSessionKey = "csrf_token" + sessionLifetime = 12 * time.Hour + sessionSecretMinLength = 32 + csrfTokenByteLength int = 32 +) + +// SessionStore persists session data using secure HTTP cookies. +type SessionStore struct { + secret []byte +} + +// NewSessionStore creates a cookie-backed session store. +func NewSessionStore(secret []byte) (*SessionStore, error) { + if len(secret) < sessionSecretMinLength { + return nil, errors.New("session secret must be at least 32 bytes") + } + // copy secret to avoid external mutation + buf := make([]byte, len(secret)) + copy(buf, secret) + return &SessionStore{secret: buf}, nil +} + +// SessionState holds per-request session data after loading. type SessionState struct { Authenticated bool Email string + CSRFToken string } -// SessionManager is a placeholder for future session persistence. -type SessionManager struct { - mu sync.RWMutex - authenticated bool - currentAccount string -} - -// NewSessionManager constructs an empty session manager. -func NewSessionManager() *SessionManager { - return &SessionManager{} -} - -// SetAuthenticated marks the provided account as the active authenticated user. -func (m *SessionManager) SetAuthenticated(email string) { - m.mu.Lock() - defer m.mu.Unlock() - - m.authenticated = true - m.currentAccount = email -} - -// Clear removes any active authentication data. -func (m *SessionManager) Clear() { - m.mu.Lock() - defer m.mu.Unlock() - - m.authenticated = false - m.currentAccount = "" -} - -// Snapshot captures the current session state for contextual use. -func (m *SessionManager) Snapshot() SessionState { - m.mu.RLock() - defer m.mu.RUnlock() - - return SessionState{ - Authenticated: m.authenticated, - Email: m.currentAccount, +// Load extracts session data from the request cookies. +func (s *SessionStore) Load(r *http.Request) SessionState { + c, err := r.Cookie(sessionCookieName) + if err != nil { + return SessionState{} } + + payload, err := decodeSession(c.Value, s.secret) + if err != nil { + return SessionState{} + } + + return payload +} + +// Save persists the session state onto the response cookies. +func (s *SessionStore) Save(w http.ResponseWriter, state SessionState) error { + serialized, err := encodeSession(state, s.secret) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: serialized, + Path: "/", + HttpOnly: true, + Secure: false, // TODO: in production, set to true + SameSite: http.SameSiteLaxMode, + Expires: time.Now().Add(sessionLifetime), + }) + + return nil +} + +// Clear removes the session cookie from the client. +func (s *SessionStore) Clear(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HttpOnly: true, + Secure: false, + SameSite: http.SameSiteLaxMode, + }) +} + +// ensureCSRFToken returns a session state with a CSRF token present. +func ensureCSRFToken(state SessionState) (SessionState, error) { + if state.CSRFToken != "" { + return state, nil + } + token := make([]byte, csrfTokenByteLength) + if _, err := rand.Read(token); err != nil { + return state, err + } + state.CSRFToken = base64.RawURLEncoding.EncodeToString(token) + return state, nil } diff --git a/internal/server/session_encoding.go b/internal/server/session_encoding.go new file mode 100644 index 0000000..6753aee --- /dev/null +++ b/internal/server/session_encoding.go @@ -0,0 +1,53 @@ +package server + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" +) + +func encodeSession(state SessionState, secret []byte) (string, error) { + payload, err := json.Marshal(state) + if err != nil { + return "", err + } + + mac := hmac.New(sha256.New, secret) + mac.Write(payload) + sig := mac.Sum(nil) + + combined := append(payload, sig...) + return base64.RawURLEncoding.EncodeToString(combined), nil +} + +func decodeSession(raw string, secret []byte) (SessionState, error) { + var state SessionState + + decoded, err := base64.RawURLEncoding.DecodeString(raw) + if err != nil { + return state, err + } + + if len(decoded) <= sha256.Size { + return state, errors.New("session payload too small") + } + + payload := decoded[:len(decoded)-sha256.Size] + providedSig := decoded[len(decoded)-sha256.Size:] + + mac := hmac.New(sha256.New, secret) + mac.Write(payload) + expectedSig := mac.Sum(nil) + + if !hmac.Equal(providedSig, expectedSig) { + return state, errors.New("session signature mismatch") + } + + if err := json.Unmarshal(payload, &state); err != nil { + return state, err + } + + return state, nil +} diff --git a/internal/server/views.go b/internal/server/views.go index b4309de..072fa26 100644 --- a/internal/server/views.go +++ b/internal/server/views.go @@ -2,14 +2,15 @@ package server // PageData contains fields shared by the templates for now. type PageData struct { - Email string - Error string + Email string + Error string + CSRFToken string } -func newIndexData(email, errMsg string) PageData { - return PageData{Email: email, Error: errMsg} +func newIndexData(email, errMsg, token string) PageData { + return PageData{Email: email, Error: errMsg, CSRFToken: token} } -func newUnauthorizedData(errMsg string) PageData { - return PageData{Error: errMsg} +func newUnauthorizedData(errMsg, token string) PageData { + return PageData{Error: errMsg, CSRFToken: token} } diff --git a/web/templates/in.html b/web/templates/in.html index 7823fe3..3fce694 100644 --- a/web/templates/in.html +++ b/web/templates/in.html @@ -15,6 +15,7 @@

You are signed in as {{.Email}}.

This placeholder dashboard will evolve as we flesh out the auth flow.

+
diff --git a/web/templates/index.html b/web/templates/index.html index ab96f61..3456526 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -23,6 +23,7 @@ {{end}}
+