mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 08:26:39 +00:00
feat: secure sessions and csrf
This commit is contained in:
parent
346678027f
commit
da1bf44d8f
11 changed files with 241 additions and 60 deletions
|
|
@ -20,7 +20,7 @@ Implement email/password authentication with secure password hashing, CSRF prote
|
||||||
|
|
||||||
## Coding Style & Naming Conventions
|
## Coding Style & Naming Conventions
|
||||||
|
|
||||||
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
|
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Promote named constants/variables over magic numbers or strings. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
|
||||||
|
|
||||||
## Testing Guidelines
|
## Testing Guidelines
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import (
|
||||||
|
|
||||||
func (s *Server) loginHandler() http.HandlerFunc {
|
func (s *Server) loginHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state := sessionFromContext(r.Context())
|
||||||
|
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
http.Error(w, "invalid form submission", http.StatusBadRequest)
|
http.Error(w, "invalid form submission", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|
@ -21,20 +23,25 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
||||||
email, err := auth.NewUserEmail(emailInput)
|
email, err := auth.NewUserEmail(emailInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
s.render(w, "index.html", newIndexData("", "Email and password are required."))
|
s.render(w, "index.html", newIndexData("", "Email and password are required.", state.CSRFToken))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.authService.Authenticate(r.Context(), email, password)
|
account, err := s.authService.Authenticate(r.Context(), email, password)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
s.sessions.SetAuthenticated(account.Email.String())
|
state.Authenticated = true
|
||||||
|
state.Email = account.Email.String()
|
||||||
|
if err := s.sessions.Save(w, state); err != nil {
|
||||||
|
log.Printf("session: save failed: %v", err)
|
||||||
|
}
|
||||||
http.Redirect(w, r, "/in", http.StatusSeeOther)
|
http.Redirect(w, r, "/in", http.StatusSeeOther)
|
||||||
|
|
||||||
case errors.Is(err, auth.ErrInvalidInput):
|
case errors.Is(err, auth.ErrInvalidInput):
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required."))
|
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.", state.CSRFToken))
|
||||||
case errors.Is(err, auth.ErrInvalidCredentials):
|
case errors.Is(err, auth.ErrInvalidCredentials):
|
||||||
s.renderLoginFailure(w, email)
|
s.renderLoginFailure(w, email, state.CSRFToken)
|
||||||
default:
|
default:
|
||||||
log.Printf("auth: authenticate failed: %v", err)
|
log.Printf("auth: authenticate failed: %v", err)
|
||||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||||
|
|
@ -44,12 +51,12 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
||||||
|
|
||||||
func (s *Server) logoutHandler() http.HandlerFunc {
|
func (s *Server) logoutHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
s.sessions.Clear()
|
s.sessions.Clear(w)
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail) {
|
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials."))
|
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.", token))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ func (s *Server) dashboardHandler() http.HandlerFunc {
|
||||||
|
|
||||||
if !state.Authenticated {
|
if !state.Authenticated {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue."))
|
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.", state.CSRFToken))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.render(w, "in.html", PageData{Email: state.Email})
|
s.render(w, "in.html", PageData{Email: state.Email, CSRFToken: state.CSRFToken})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,6 @@ import "net/http"
|
||||||
func (s *Server) indexHandler() http.HandlerFunc {
|
func (s *Server) indexHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
state := sessionFromContext(r.Context())
|
state := sessionFromContext(r.Context())
|
||||||
s.render(w, "index.html", newIndexData(state.Email, ""))
|
s.render(w, "index.html", newIndexData(state.Email, "", state.CSRFToken))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,61 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
|
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
state := s.sessions.Snapshot()
|
state := s.sessions.Load(r)
|
||||||
|
updated, err := ensureCSRFToken(state)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("session: csrf token generation failed: %v", err)
|
||||||
|
http.Error(w, "session error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state = updated
|
||||||
|
|
||||||
|
if err := s.sessions.Save(w, state); err != nil {
|
||||||
|
log.Printf("session: save failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ctx := withSession(r.Context(), state)
|
ctx := withSession(r.Context(), state)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := sessionFromContext(r.Context())
|
||||||
|
if state.CSRFToken == "" {
|
||||||
|
http.Error(w, "missing csrf token", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.Header.Get("X-CSRF-Token")
|
||||||
|
if token == "" {
|
||||||
|
if err := r.ParseForm(); err == nil {
|
||||||
|
token = r.Form.Get("_csrf")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validCSRFToken(token, state.CSRFToken) {
|
||||||
|
http.Error(w, "invalid csrf token", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type sessionContextKey struct{}
|
type sessionContextKey struct{}
|
||||||
|
|
||||||
func withSession(ctx context.Context, state SessionState) context.Context {
|
func withSession(ctx context.Context, state SessionState) context.Context {
|
||||||
|
|
@ -28,3 +72,13 @@ func sessionFromContext(ctx context.Context) SessionState {
|
||||||
}
|
}
|
||||||
return SessionState{}
|
return SessionState{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validCSRFToken(provided, expected string) bool {
|
||||||
|
if provided == "" || expected == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(provided) != len(expected) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -23,7 +24,7 @@ const (
|
||||||
type Server struct {
|
type Server struct {
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
authService *auth.Service
|
authService *auth.Service
|
||||||
sessions *SessionManager
|
sessions *SessionStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a Server with parsed templates and default state.
|
// New constructs a Server with parsed templates and default state.
|
||||||
|
|
@ -43,10 +44,20 @@ func New() (*Server, error) {
|
||||||
return nil, fmt.Errorf("seed user: %w", err)
|
return nil, fmt.Errorf("seed user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
secret := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(secret); err != nil {
|
||||||
|
return nil, fmt.Errorf("session secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionStore, err := NewSessionStore(secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
templates: tmpl,
|
templates: tmpl,
|
||||||
authService: auth.NewService(store),
|
authService: auth.NewService(store),
|
||||||
sessions: NewSessionManager(),
|
sessions: sessionStore,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -59,6 +70,7 @@ func (s *Server) Router() http.Handler {
|
||||||
middleware.Logger,
|
middleware.Logger,
|
||||||
middleware.Recoverer,
|
middleware.Recoverer,
|
||||||
s.sessionMiddleware,
|
s.sessionMiddleware,
|
||||||
|
s.csrfMiddleware,
|
||||||
)
|
)
|
||||||
s.registerRoutes(r)
|
s.registerRoutes(r)
|
||||||
return r
|
return r
|
||||||
|
|
|
||||||
|
|
@ -1,50 +1,102 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// SessionState represents the snapshot of session metadata for a request.
|
const (
|
||||||
|
sessionCookieName = "auth_session"
|
||||||
|
csrfSessionKey = "csrf_token"
|
||||||
|
sessionLifetime = 12 * time.Hour
|
||||||
|
sessionSecretMinLength = 32
|
||||||
|
csrfTokenByteLength int = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionStore persists session data using secure HTTP cookies.
|
||||||
|
type SessionStore struct {
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionStore creates a cookie-backed session store.
|
||||||
|
func NewSessionStore(secret []byte) (*SessionStore, error) {
|
||||||
|
if len(secret) < sessionSecretMinLength {
|
||||||
|
return nil, errors.New("session secret must be at least 32 bytes")
|
||||||
|
}
|
||||||
|
// copy secret to avoid external mutation
|
||||||
|
buf := make([]byte, len(secret))
|
||||||
|
copy(buf, secret)
|
||||||
|
return &SessionStore{secret: buf}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionState holds per-request session data after loading.
|
||||||
type SessionState struct {
|
type SessionState struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Email string
|
Email string
|
||||||
|
CSRFToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionManager is a placeholder for future session persistence.
|
// Load extracts session data from the request cookies.
|
||||||
type SessionManager struct {
|
func (s *SessionStore) Load(r *http.Request) SessionState {
|
||||||
mu sync.RWMutex
|
c, err := r.Cookie(sessionCookieName)
|
||||||
authenticated bool
|
if err != nil {
|
||||||
currentAccount string
|
return SessionState{}
|
||||||
}
|
|
||||||
|
|
||||||
// NewSessionManager constructs an empty session manager.
|
|
||||||
func NewSessionManager() *SessionManager {
|
|
||||||
return &SessionManager{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAuthenticated marks the provided account as the active authenticated user.
|
|
||||||
func (m *SessionManager) SetAuthenticated(email string) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
m.authenticated = true
|
|
||||||
m.currentAccount = email
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear removes any active authentication data.
|
|
||||||
func (m *SessionManager) Clear() {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
m.authenticated = false
|
|
||||||
m.currentAccount = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Snapshot captures the current session state for contextual use.
|
|
||||||
func (m *SessionManager) Snapshot() SessionState {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
|
|
||||||
return SessionState{
|
|
||||||
Authenticated: m.authenticated,
|
|
||||||
Email: m.currentAccount,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
payload, err := decodeSession(c.Value, s.secret)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save persists the session state onto the response cookies.
|
||||||
|
func (s *SessionStore) Save(w http.ResponseWriter, state SessionState) error {
|
||||||
|
serialized, err := encodeSession(state, s.secret)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: sessionCookieName,
|
||||||
|
Value: serialized,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: false, // TODO: in production, set to true
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Expires: time.Now().Add(sessionLifetime),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes the session cookie from the client.
|
||||||
|
func (s *SessionStore) Clear(w http.ResponseWriter) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: sessionCookieName,
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Unix(0, 0),
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: false,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureCSRFToken returns a session state with a CSRF token present.
|
||||||
|
func ensureCSRFToken(state SessionState) (SessionState, error) {
|
||||||
|
if state.CSRFToken != "" {
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
token := make([]byte, csrfTokenByteLength)
|
||||||
|
if _, err := rand.Read(token); err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
|
||||||
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
53
internal/server/session_encoding.go
Normal file
53
internal/server/session_encoding.go
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func encodeSession(state SessionState, secret []byte) (string, error) {
|
||||||
|
payload, err := json.Marshal(state)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
mac := hmac.New(sha256.New, secret)
|
||||||
|
mac.Write(payload)
|
||||||
|
sig := mac.Sum(nil)
|
||||||
|
|
||||||
|
combined := append(payload, sig...)
|
||||||
|
return base64.RawURLEncoding.EncodeToString(combined), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSession(raw string, secret []byte) (SessionState, error) {
|
||||||
|
var state SessionState
|
||||||
|
|
||||||
|
decoded, err := base64.RawURLEncoding.DecodeString(raw)
|
||||||
|
if err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded) <= sha256.Size {
|
||||||
|
return state, errors.New("session payload too small")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := decoded[:len(decoded)-sha256.Size]
|
||||||
|
providedSig := decoded[len(decoded)-sha256.Size:]
|
||||||
|
|
||||||
|
mac := hmac.New(sha256.New, secret)
|
||||||
|
mac.Write(payload)
|
||||||
|
expectedSig := mac.Sum(nil)
|
||||||
|
|
||||||
|
if !hmac.Equal(providedSig, expectedSig) {
|
||||||
|
return state, errors.New("session signature mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(payload, &state); err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
@ -2,14 +2,15 @@ package server
|
||||||
|
|
||||||
// PageData contains fields shared by the templates for now.
|
// PageData contains fields shared by the templates for now.
|
||||||
type PageData struct {
|
type PageData struct {
|
||||||
Email string
|
Email string
|
||||||
Error string
|
Error string
|
||||||
|
CSRFToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIndexData(email, errMsg string) PageData {
|
func newIndexData(email, errMsg, token string) PageData {
|
||||||
return PageData{Email: email, Error: errMsg}
|
return PageData{Email: email, Error: errMsg, CSRFToken: token}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUnauthorizedData(errMsg string) PageData {
|
func newUnauthorizedData(errMsg, token string) PageData {
|
||||||
return PageData{Error: errMsg}
|
return PageData{Error: errMsg, CSRFToken: token}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@
|
||||||
<p>You are signed in as <strong>{{.Email}}</strong>.</p>
|
<p>You are signed in as <strong>{{.Email}}</strong>.</p>
|
||||||
<p>This placeholder dashboard will evolve as we flesh out the auth flow.</p>
|
<p>This placeholder dashboard will evolve as we flesh out the auth flow.</p>
|
||||||
<form method="post" action="/logout">
|
<form method="post" action="/logout">
|
||||||
|
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
|
||||||
<button type="submit" class="secondary">Sign out</button>
|
<button type="submit" class="secondary">Sign out</button>
|
||||||
</form>
|
</form>
|
||||||
</main>
|
</main>
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
</article>
|
</article>
|
||||||
{{end}}
|
{{end}}
|
||||||
<form method="post" action="/login">
|
<form method="post" action="/login">
|
||||||
|
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
|
||||||
<label for="email">Email</label>
|
<label for="email">Email</label>
|
||||||
<input
|
<input
|
||||||
type="email"
|
type="email"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue