feat: secure sessions and csrf

This commit is contained in:
Ruidy 2025-09-20 01:08:03 +02:00
parent 346678027f
commit da1bf44d8f
No known key found for this signature in database
GPG key ID: 705C24D202990805
11 changed files with 241 additions and 60 deletions

View file

@ -20,7 +20,7 @@ Implement email/password authentication with secure password hashing, CSRF prote
## Coding Style & Naming Conventions ## Coding Style & Naming Conventions
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Template IDs and Alpine component names should reflect their role (e.g., `login_form`). Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Promote named constants/variables over magic numbers or strings. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
## Testing Guidelines ## Testing Guidelines

View file

@ -10,6 +10,8 @@ import (
func (s *Server) loginHandler() http.HandlerFunc { func (s *Server) loginHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
state := sessionFromContext(r.Context())
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form submission", http.StatusBadRequest) http.Error(w, "invalid form submission", http.StatusBadRequest)
return return
@ -21,20 +23,25 @@ func (s *Server) loginHandler() http.HandlerFunc {
email, err := auth.NewUserEmail(emailInput) email, err := auth.NewUserEmail(emailInput)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "index.html", newIndexData("", "Email and password are required.")) s.render(w, "index.html", newIndexData("", "Email and password are required.", state.CSRFToken))
return return
} }
account, err := s.authService.Authenticate(r.Context(), email, password) account, err := s.authService.Authenticate(r.Context(), email, password)
switch { switch {
case err == nil: case err == nil:
s.sessions.SetAuthenticated(account.Email.String()) state.Authenticated = true
state.Email = account.Email.String()
if err := s.sessions.Save(w, state); err != nil {
log.Printf("session: save failed: %v", err)
}
http.Redirect(w, r, "/in", http.StatusSeeOther) http.Redirect(w, r, "/in", http.StatusSeeOther)
case errors.Is(err, auth.ErrInvalidInput): case errors.Is(err, auth.ErrInvalidInput):
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.")) s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.", state.CSRFToken))
case errors.Is(err, auth.ErrInvalidCredentials): case errors.Is(err, auth.ErrInvalidCredentials):
s.renderLoginFailure(w, email) s.renderLoginFailure(w, email, state.CSRFToken)
default: default:
log.Printf("auth: authenticate failed: %v", err) log.Printf("auth: authenticate failed: %v", err)
http.Error(w, "unexpected error", http.StatusInternalServerError) http.Error(w, "unexpected error", http.StatusInternalServerError)
@ -44,12 +51,12 @@ func (s *Server) loginHandler() http.HandlerFunc {
func (s *Server) logoutHandler() http.HandlerFunc { func (s *Server) logoutHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
s.sessions.Clear() s.sessions.Clear(w)
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)
} }
} }
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail) { func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.")) s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.", token))
} }

View file

@ -8,10 +8,10 @@ func (s *Server) dashboardHandler() http.HandlerFunc {
if !state.Authenticated { if !state.Authenticated {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.")) s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.", state.CSRFToken))
return return
} }
s.render(w, "in.html", PageData{Email: state.Email}) s.render(w, "in.html", PageData{Email: state.Email, CSRFToken: state.CSRFToken})
} }
} }

View file

@ -5,6 +5,6 @@ import "net/http"
func (s *Server) indexHandler() http.HandlerFunc { func (s *Server) indexHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
state := sessionFromContext(r.Context()) state := sessionFromContext(r.Context())
s.render(w, "index.html", newIndexData(state.Email, "")) s.render(w, "index.html", newIndexData(state.Email, "", state.CSRFToken))
} }
} }

View file

@ -2,17 +2,61 @@ package server
import ( import (
"context" "context"
"crypto/subtle"
"log"
"net/http" "net/http"
) )
func (s *Server) sessionMiddleware(next http.Handler) http.Handler { func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state := s.sessions.Snapshot() state := s.sessions.Load(r)
updated, err := ensureCSRFToken(state)
if err != nil {
log.Printf("session: csrf token generation failed: %v", err)
http.Error(w, "session error", http.StatusInternalServerError)
return
}
state = updated
if err := s.sessions.Save(w, state); err != nil {
log.Printf("session: save failed: %v", err)
}
ctx := withSession(r.Context(), state) ctx := withSession(r.Context(), state)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
next.ServeHTTP(w, r)
return
}
state := sessionFromContext(r.Context())
if state.CSRFToken == "" {
http.Error(w, "missing csrf token", http.StatusForbidden)
return
}
token := r.Header.Get("X-CSRF-Token")
if token == "" {
if err := r.ParseForm(); err == nil {
token = r.Form.Get("_csrf")
}
}
if !validCSRFToken(token, state.CSRFToken) {
http.Error(w, "invalid csrf token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
type sessionContextKey struct{} type sessionContextKey struct{}
func withSession(ctx context.Context, state SessionState) context.Context { func withSession(ctx context.Context, state SessionState) context.Context {
@ -28,3 +72,13 @@ func sessionFromContext(ctx context.Context) SessionState {
} }
return SessionState{} return SessionState{}
} }
func validCSRFToken(provided, expected string) bool {
if provided == "" || expected == "" {
return false
}
if len(provided) != len(expected) {
return false
}
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
}

View file

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"crypto/rand"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@ -23,7 +24,7 @@ const (
type Server struct { type Server struct {
templates *template.Template templates *template.Template
authService *auth.Service authService *auth.Service
sessions *SessionManager sessions *SessionStore
} }
// New constructs a Server with parsed templates and default state. // New constructs a Server with parsed templates and default state.
@ -43,10 +44,20 @@ func New() (*Server, error) {
return nil, fmt.Errorf("seed user: %w", err) return nil, fmt.Errorf("seed user: %w", err)
} }
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, fmt.Errorf("session secret: %w", err)
}
sessionStore, err := NewSessionStore(secret)
if err != nil {
return nil, fmt.Errorf("session store: %w", err)
}
return &Server{ return &Server{
templates: tmpl, templates: tmpl,
authService: auth.NewService(store), authService: auth.NewService(store),
sessions: NewSessionManager(), sessions: sessionStore,
}, nil }, nil
} }
@ -59,6 +70,7 @@ func (s *Server) Router() http.Handler {
middleware.Logger, middleware.Logger,
middleware.Recoverer, middleware.Recoverer,
s.sessionMiddleware, s.sessionMiddleware,
s.csrfMiddleware,
) )
s.registerRoutes(r) s.registerRoutes(r)
return r return r

View file

@ -1,50 +1,102 @@
package server package server
import "sync" import (
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"time"
)
// SessionState represents the snapshot of session metadata for a request. const (
sessionCookieName = "auth_session"
csrfSessionKey = "csrf_token"
sessionLifetime = 12 * time.Hour
sessionSecretMinLength = 32
csrfTokenByteLength int = 32
)
// SessionStore persists session data using secure HTTP cookies.
type SessionStore struct {
secret []byte
}
// NewSessionStore creates a cookie-backed session store.
func NewSessionStore(secret []byte) (*SessionStore, error) {
if len(secret) < sessionSecretMinLength {
return nil, errors.New("session secret must be at least 32 bytes")
}
// copy secret to avoid external mutation
buf := make([]byte, len(secret))
copy(buf, secret)
return &SessionStore{secret: buf}, nil
}
// SessionState holds per-request session data after loading.
type SessionState struct { type SessionState struct {
Authenticated bool Authenticated bool
Email string Email string
CSRFToken string
} }
// SessionManager is a placeholder for future session persistence. // Load extracts session data from the request cookies.
type SessionManager struct { func (s *SessionStore) Load(r *http.Request) SessionState {
mu sync.RWMutex c, err := r.Cookie(sessionCookieName)
authenticated bool if err != nil {
currentAccount string return SessionState{}
}
// NewSessionManager constructs an empty session manager.
func NewSessionManager() *SessionManager {
return &SessionManager{}
}
// SetAuthenticated marks the provided account as the active authenticated user.
func (m *SessionManager) SetAuthenticated(email string) {
m.mu.Lock()
defer m.mu.Unlock()
m.authenticated = true
m.currentAccount = email
}
// Clear removes any active authentication data.
func (m *SessionManager) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.authenticated = false
m.currentAccount = ""
}
// Snapshot captures the current session state for contextual use.
func (m *SessionManager) Snapshot() SessionState {
m.mu.RLock()
defer m.mu.RUnlock()
return SessionState{
Authenticated: m.authenticated,
Email: m.currentAccount,
} }
payload, err := decodeSession(c.Value, s.secret)
if err != nil {
return SessionState{}
}
return payload
}
// Save persists the session state onto the response cookies.
func (s *SessionStore) Save(w http.ResponseWriter, state SessionState) error {
serialized, err := encodeSession(state, s.secret)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: serialized,
Path: "/",
HttpOnly: true,
Secure: false, // TODO: in production, set to true
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(sessionLifetime),
})
return nil
}
// Clear removes the session cookie from the client.
func (s *SessionStore) Clear(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
})
}
// ensureCSRFToken returns a session state with a CSRF token present.
func ensureCSRFToken(state SessionState) (SessionState, error) {
if state.CSRFToken != "" {
return state, nil
}
token := make([]byte, csrfTokenByteLength)
if _, err := rand.Read(token); err != nil {
return state, err
}
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
return state, nil
} }

View file

@ -0,0 +1,53 @@
package server
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
)
func encodeSession(state SessionState, secret []byte) (string, error) {
payload, err := json.Marshal(state)
if err != nil {
return "", err
}
mac := hmac.New(sha256.New, secret)
mac.Write(payload)
sig := mac.Sum(nil)
combined := append(payload, sig...)
return base64.RawURLEncoding.EncodeToString(combined), nil
}
func decodeSession(raw string, secret []byte) (SessionState, error) {
var state SessionState
decoded, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return state, err
}
if len(decoded) <= sha256.Size {
return state, errors.New("session payload too small")
}
payload := decoded[:len(decoded)-sha256.Size]
providedSig := decoded[len(decoded)-sha256.Size:]
mac := hmac.New(sha256.New, secret)
mac.Write(payload)
expectedSig := mac.Sum(nil)
if !hmac.Equal(providedSig, expectedSig) {
return state, errors.New("session signature mismatch")
}
if err := json.Unmarshal(payload, &state); err != nil {
return state, err
}
return state, nil
}

View file

@ -2,14 +2,15 @@ package server
// PageData contains fields shared by the templates for now. // PageData contains fields shared by the templates for now.
type PageData struct { type PageData struct {
Email string Email string
Error string Error string
CSRFToken string
} }
func newIndexData(email, errMsg string) PageData { func newIndexData(email, errMsg, token string) PageData {
return PageData{Email: email, Error: errMsg} return PageData{Email: email, Error: errMsg, CSRFToken: token}
} }
func newUnauthorizedData(errMsg string) PageData { func newUnauthorizedData(errMsg, token string) PageData {
return PageData{Error: errMsg} return PageData{Error: errMsg, CSRFToken: token}
} }

View file

@ -15,6 +15,7 @@
<p>You are signed in as <strong>{{.Email}}</strong>.</p> <p>You are signed in as <strong>{{.Email}}</strong>.</p>
<p>This placeholder dashboard will evolve as we flesh out the auth flow.</p> <p>This placeholder dashboard will evolve as we flesh out the auth flow.</p>
<form method="post" action="/logout"> <form method="post" action="/logout">
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
<button type="submit" class="secondary">Sign out</button> <button type="submit" class="secondary">Sign out</button>
</form> </form>
</main> </main>

View file

@ -23,6 +23,7 @@
</article> </article>
{{end}} {{end}}
<form method="post" action="/login"> <form method="post" action="/login">
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
<label for="email">Email</label> <label for="email">Email</label>
<input <input
type="email" type="email"