feat: secure sessions and csrf

This commit is contained in:
Ruidy 2025-09-20 01:08:03 +02:00
parent 346678027f
commit da1bf44d8f
No known key found for this signature in database
GPG key ID: 705C24D202990805
11 changed files with 241 additions and 60 deletions

View file

@ -20,7 +20,7 @@ Implement email/password authentication with secure password hashing, CSRF prote
## Coding Style & Naming Conventions
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Promote named constants/variables over magic numbers or strings. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
## Testing Guidelines

View file

@ -10,6 +10,8 @@ import (
func (s *Server) loginHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state := sessionFromContext(r.Context())
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form submission", http.StatusBadRequest)
return
@ -21,20 +23,25 @@ func (s *Server) loginHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailInput)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
s.render(w, "index.html", newIndexData("", "Email and password are required."))
s.render(w, "index.html", newIndexData("", "Email and password are required.", state.CSRFToken))
return
}
account, err := s.authService.Authenticate(r.Context(), email, password)
switch {
case err == nil:
s.sessions.SetAuthenticated(account.Email.String())
state.Authenticated = true
state.Email = account.Email.String()
if err := s.sessions.Save(w, state); err != nil {
log.Printf("session: save failed: %v", err)
}
http.Redirect(w, r, "/in", http.StatusSeeOther)
case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest)
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required."))
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.", state.CSRFToken))
case errors.Is(err, auth.ErrInvalidCredentials):
s.renderLoginFailure(w, email)
s.renderLoginFailure(w, email, state.CSRFToken)
default:
log.Printf("auth: authenticate failed: %v", err)
http.Error(w, "unexpected error", http.StatusInternalServerError)
@ -44,12 +51,12 @@ func (s *Server) loginHandler() http.HandlerFunc {
func (s *Server) logoutHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.sessions.Clear()
s.sessions.Clear(w)
http.Redirect(w, r, "/", http.StatusSeeOther)
}
}
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail) {
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
w.WriteHeader(http.StatusUnauthorized)
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials."))
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.", token))
}

View file

@ -8,10 +8,10 @@ func (s *Server) dashboardHandler() http.HandlerFunc {
if !state.Authenticated {
w.WriteHeader(http.StatusUnauthorized)
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue."))
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.", state.CSRFToken))
return
}
s.render(w, "in.html", PageData{Email: state.Email})
s.render(w, "in.html", PageData{Email: state.Email, CSRFToken: state.CSRFToken})
}
}

View file

@ -5,6 +5,6 @@ import "net/http"
func (s *Server) indexHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state := sessionFromContext(r.Context())
s.render(w, "index.html", newIndexData(state.Email, ""))
s.render(w, "index.html", newIndexData(state.Email, "", state.CSRFToken))
}
}

View file

@ -2,17 +2,61 @@ package server
import (
"context"
"crypto/subtle"
"log"
"net/http"
)
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state := s.sessions.Snapshot()
state := s.sessions.Load(r)
updated, err := ensureCSRFToken(state)
if err != nil {
log.Printf("session: csrf token generation failed: %v", err)
http.Error(w, "session error", http.StatusInternalServerError)
return
}
state = updated
if err := s.sessions.Save(w, state); err != nil {
log.Printf("session: save failed: %v", err)
}
ctx := withSession(r.Context(), state)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
next.ServeHTTP(w, r)
return
}
state := sessionFromContext(r.Context())
if state.CSRFToken == "" {
http.Error(w, "missing csrf token", http.StatusForbidden)
return
}
token := r.Header.Get("X-CSRF-Token")
if token == "" {
if err := r.ParseForm(); err == nil {
token = r.Form.Get("_csrf")
}
}
if !validCSRFToken(token, state.CSRFToken) {
http.Error(w, "invalid csrf token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
type sessionContextKey struct{}
func withSession(ctx context.Context, state SessionState) context.Context {
@ -28,3 +72,13 @@ func sessionFromContext(ctx context.Context) SessionState {
}
return SessionState{}
}
func validCSRFToken(provided, expected string) bool {
if provided == "" || expected == "" {
return false
}
if len(provided) != len(expected) {
return false
}
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
}

View file

@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/rand"
"fmt"
"html/template"
"net/http"
@ -23,7 +24,7 @@ const (
type Server struct {
templates *template.Template
authService *auth.Service
sessions *SessionManager
sessions *SessionStore
}
// New constructs a Server with parsed templates and default state.
@ -43,10 +44,20 @@ func New() (*Server, error) {
return nil, fmt.Errorf("seed user: %w", err)
}
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, fmt.Errorf("session secret: %w", err)
}
sessionStore, err := NewSessionStore(secret)
if err != nil {
return nil, fmt.Errorf("session store: %w", err)
}
return &Server{
templates: tmpl,
authService: auth.NewService(store),
sessions: NewSessionManager(),
sessions: sessionStore,
}, nil
}
@ -59,6 +70,7 @@ func (s *Server) Router() http.Handler {
middleware.Logger,
middleware.Recoverer,
s.sessionMiddleware,
s.csrfMiddleware,
)
s.registerRoutes(r)
return r

View file

@ -1,50 +1,102 @@
package server
import "sync"
import (
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"time"
)
// SessionState represents the snapshot of session metadata for a request.
const (
sessionCookieName = "auth_session"
csrfSessionKey = "csrf_token"
sessionLifetime = 12 * time.Hour
sessionSecretMinLength = 32
csrfTokenByteLength int = 32
)
// SessionStore persists session data using secure HTTP cookies.
type SessionStore struct {
secret []byte
}
// NewSessionStore creates a cookie-backed session store.
func NewSessionStore(secret []byte) (*SessionStore, error) {
if len(secret) < sessionSecretMinLength {
return nil, errors.New("session secret must be at least 32 bytes")
}
// copy secret to avoid external mutation
buf := make([]byte, len(secret))
copy(buf, secret)
return &SessionStore{secret: buf}, nil
}
// SessionState holds per-request session data after loading.
type SessionState struct {
Authenticated bool
Email string
CSRFToken string
}
// SessionManager is a placeholder for future session persistence.
type SessionManager struct {
mu sync.RWMutex
authenticated bool
currentAccount string
}
// NewSessionManager constructs an empty session manager.
func NewSessionManager() *SessionManager {
return &SessionManager{}
}
// SetAuthenticated marks the provided account as the active authenticated user.
func (m *SessionManager) SetAuthenticated(email string) {
m.mu.Lock()
defer m.mu.Unlock()
m.authenticated = true
m.currentAccount = email
}
// Clear removes any active authentication data.
func (m *SessionManager) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.authenticated = false
m.currentAccount = ""
}
// Snapshot captures the current session state for contextual use.
func (m *SessionManager) Snapshot() SessionState {
m.mu.RLock()
defer m.mu.RUnlock()
return SessionState{
Authenticated: m.authenticated,
Email: m.currentAccount,
// Load extracts session data from the request cookies.
func (s *SessionStore) Load(r *http.Request) SessionState {
c, err := r.Cookie(sessionCookieName)
if err != nil {
return SessionState{}
}
payload, err := decodeSession(c.Value, s.secret)
if err != nil {
return SessionState{}
}
return payload
}
// Save persists the session state onto the response cookies.
func (s *SessionStore) Save(w http.ResponseWriter, state SessionState) error {
serialized, err := encodeSession(state, s.secret)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: serialized,
Path: "/",
HttpOnly: true,
Secure: false, // TODO: in production, set to true
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(sessionLifetime),
})
return nil
}
// Clear removes the session cookie from the client.
func (s *SessionStore) Clear(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
})
}
// ensureCSRFToken returns a session state with a CSRF token present.
func ensureCSRFToken(state SessionState) (SessionState, error) {
if state.CSRFToken != "" {
return state, nil
}
token := make([]byte, csrfTokenByteLength)
if _, err := rand.Read(token); err != nil {
return state, err
}
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
return state, nil
}

View file

@ -0,0 +1,53 @@
package server
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
)
func encodeSession(state SessionState, secret []byte) (string, error) {
payload, err := json.Marshal(state)
if err != nil {
return "", err
}
mac := hmac.New(sha256.New, secret)
mac.Write(payload)
sig := mac.Sum(nil)
combined := append(payload, sig...)
return base64.RawURLEncoding.EncodeToString(combined), nil
}
func decodeSession(raw string, secret []byte) (SessionState, error) {
var state SessionState
decoded, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return state, err
}
if len(decoded) <= sha256.Size {
return state, errors.New("session payload too small")
}
payload := decoded[:len(decoded)-sha256.Size]
providedSig := decoded[len(decoded)-sha256.Size:]
mac := hmac.New(sha256.New, secret)
mac.Write(payload)
expectedSig := mac.Sum(nil)
if !hmac.Equal(providedSig, expectedSig) {
return state, errors.New("session signature mismatch")
}
if err := json.Unmarshal(payload, &state); err != nil {
return state, err
}
return state, nil
}

View file

@ -4,12 +4,13 @@ package server
type PageData struct {
Email string
Error string
CSRFToken string
}
func newIndexData(email, errMsg string) PageData {
return PageData{Email: email, Error: errMsg}
func newIndexData(email, errMsg, token string) PageData {
return PageData{Email: email, Error: errMsg, CSRFToken: token}
}
func newUnauthorizedData(errMsg string) PageData {
return PageData{Error: errMsg}
func newUnauthorizedData(errMsg, token string) PageData {
return PageData{Error: errMsg, CSRFToken: token}
}

View file

@ -15,6 +15,7 @@
<p>You are signed in as <strong>{{.Email}}</strong>.</p>
<p>This placeholder dashboard will evolve as we flesh out the auth flow.</p>
<form method="post" action="/logout">
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
<button type="submit" class="secondary">Sign out</button>
</form>
</main>

View file

@ -23,6 +23,7 @@
</article>
{{end}}
<form method="post" action="/login">
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
<label for="email">Email</label>
<input
type="email"