mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 00:16:40 +00:00
feat: secure sessions and csrf
This commit is contained in:
parent
346678027f
commit
da1bf44d8f
11 changed files with 241 additions and 60 deletions
|
|
@ -20,7 +20,7 @@ Implement email/password authentication with secure password hashing, CSRF prote
|
|||
|
||||
## Coding Style & Naming Conventions
|
||||
|
||||
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
|
||||
Trust `gofmt`; avoid manual formatting. Use CamelCase for exported Go identifiers and snake_case for embedded assets. Keep handlers slim, factor shared logic into helpers, and add concise comments only when intent needs clarification. Promote named constants/variables over magic numbers or strings. Template IDs and Alpine component names should reflect their role (e.g., `login_form`).
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
|
||||
func (s *Server) loginHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state := sessionFromContext(r.Context())
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form submission", http.StatusBadRequest)
|
||||
return
|
||||
|
|
@ -21,20 +23,25 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
|||
email, err := auth.NewUserEmail(emailInput)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "index.html", newIndexData("", "Email and password are required."))
|
||||
s.render(w, "index.html", newIndexData("", "Email and password are required.", state.CSRFToken))
|
||||
return
|
||||
}
|
||||
|
||||
account, err := s.authService.Authenticate(r.Context(), email, password)
|
||||
switch {
|
||||
case err == nil:
|
||||
s.sessions.SetAuthenticated(account.Email.String())
|
||||
state.Authenticated = true
|
||||
state.Email = account.Email.String()
|
||||
if err := s.sessions.Save(w, state); err != nil {
|
||||
log.Printf("session: save failed: %v", err)
|
||||
}
|
||||
http.Redirect(w, r, "/in", http.StatusSeeOther)
|
||||
|
||||
case errors.Is(err, auth.ErrInvalidInput):
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required."))
|
||||
s.render(w, "index.html", newIndexData(email.String(), "Email and password are required.", state.CSRFToken))
|
||||
case errors.Is(err, auth.ErrInvalidCredentials):
|
||||
s.renderLoginFailure(w, email)
|
||||
s.renderLoginFailure(w, email, state.CSRFToken)
|
||||
default:
|
||||
log.Printf("auth: authenticate failed: %v", err)
|
||||
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||||
|
|
@ -44,12 +51,12 @@ func (s *Server) loginHandler() http.HandlerFunc {
|
|||
|
||||
func (s *Server) logoutHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
s.sessions.Clear()
|
||||
s.sessions.Clear(w)
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail) {
|
||||
func (s *Server) renderLoginFailure(w http.ResponseWriter, email auth.UserEmail, token string) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials."))
|
||||
s.render(w, "index.html", newIndexData(email.String(), "Invalid credentials.", token))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ func (s *Server) dashboardHandler() http.HandlerFunc {
|
|||
|
||||
if !state.Authenticated {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue."))
|
||||
s.render(w, "unauthorized.html", newUnauthorizedData("Sign in to continue.", state.CSRFToken))
|
||||
return
|
||||
}
|
||||
|
||||
s.render(w, "in.html", PageData{Email: state.Email})
|
||||
s.render(w, "in.html", PageData{Email: state.Email, CSRFToken: state.CSRFToken})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,6 @@ import "net/http"
|
|||
func (s *Server) indexHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state := sessionFromContext(r.Context())
|
||||
s.render(w, "index.html", newIndexData(state.Email, ""))
|
||||
s.render(w, "index.html", newIndexData(state.Email, "", state.CSRFToken))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,17 +2,61 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
state := s.sessions.Snapshot()
|
||||
state := s.sessions.Load(r)
|
||||
updated, err := ensureCSRFToken(state)
|
||||
if err != nil {
|
||||
log.Printf("session: csrf token generation failed: %v", err)
|
||||
http.Error(w, "session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
state = updated
|
||||
|
||||
if err := s.sessions.Save(w, state); err != nil {
|
||||
log.Printf("session: save failed: %v", err)
|
||||
}
|
||||
|
||||
ctx := withSession(r.Context(), state)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
state := sessionFromContext(r.Context())
|
||||
if state.CSRFToken == "" {
|
||||
http.Error(w, "missing csrf token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Header.Get("X-CSRF-Token")
|
||||
if token == "" {
|
||||
if err := r.ParseForm(); err == nil {
|
||||
token = r.Form.Get("_csrf")
|
||||
}
|
||||
}
|
||||
|
||||
if !validCSRFToken(token, state.CSRFToken) {
|
||||
http.Error(w, "invalid csrf token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type sessionContextKey struct{}
|
||||
|
||||
func withSession(ctx context.Context, state SessionState) context.Context {
|
||||
|
|
@ -28,3 +72,13 @@ func sessionFromContext(ctx context.Context) SessionState {
|
|||
}
|
||||
return SessionState{}
|
||||
}
|
||||
|
||||
func validCSRFToken(provided, expected string) bool {
|
||||
if provided == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
if len(provided) != len(expected) {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
|
@ -23,7 +24,7 @@ const (
|
|||
type Server struct {
|
||||
templates *template.Template
|
||||
authService *auth.Service
|
||||
sessions *SessionManager
|
||||
sessions *SessionStore
|
||||
}
|
||||
|
||||
// New constructs a Server with parsed templates and default state.
|
||||
|
|
@ -43,10 +44,20 @@ func New() (*Server, error) {
|
|||
return nil, fmt.Errorf("seed user: %w", err)
|
||||
}
|
||||
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, fmt.Errorf("session secret: %w", err)
|
||||
}
|
||||
|
||||
sessionStore, err := NewSessionStore(secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session store: %w", err)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
templates: tmpl,
|
||||
authService: auth.NewService(store),
|
||||
sessions: NewSessionManager(),
|
||||
sessions: sessionStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -59,6 +70,7 @@ func (s *Server) Router() http.Handler {
|
|||
middleware.Logger,
|
||||
middleware.Recoverer,
|
||||
s.sessionMiddleware,
|
||||
s.csrfMiddleware,
|
||||
)
|
||||
s.registerRoutes(r)
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -1,50 +1,102 @@
|
|||
package server
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionState represents the snapshot of session metadata for a request.
|
||||
const (
|
||||
sessionCookieName = "auth_session"
|
||||
csrfSessionKey = "csrf_token"
|
||||
sessionLifetime = 12 * time.Hour
|
||||
sessionSecretMinLength = 32
|
||||
csrfTokenByteLength int = 32
|
||||
)
|
||||
|
||||
// SessionStore persists session data using secure HTTP cookies.
|
||||
type SessionStore struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
// NewSessionStore creates a cookie-backed session store.
|
||||
func NewSessionStore(secret []byte) (*SessionStore, error) {
|
||||
if len(secret) < sessionSecretMinLength {
|
||||
return nil, errors.New("session secret must be at least 32 bytes")
|
||||
}
|
||||
// copy secret to avoid external mutation
|
||||
buf := make([]byte, len(secret))
|
||||
copy(buf, secret)
|
||||
return &SessionStore{secret: buf}, nil
|
||||
}
|
||||
|
||||
// SessionState holds per-request session data after loading.
|
||||
type SessionState struct {
|
||||
Authenticated bool
|
||||
Email string
|
||||
CSRFToken string
|
||||
}
|
||||
|
||||
// SessionManager is a placeholder for future session persistence.
|
||||
type SessionManager struct {
|
||||
mu sync.RWMutex
|
||||
authenticated bool
|
||||
currentAccount string
|
||||
}
|
||||
|
||||
// NewSessionManager constructs an empty session manager.
|
||||
func NewSessionManager() *SessionManager {
|
||||
return &SessionManager{}
|
||||
}
|
||||
|
||||
// SetAuthenticated marks the provided account as the active authenticated user.
|
||||
func (m *SessionManager) SetAuthenticated(email string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.authenticated = true
|
||||
m.currentAccount = email
|
||||
}
|
||||
|
||||
// Clear removes any active authentication data.
|
||||
func (m *SessionManager) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.authenticated = false
|
||||
m.currentAccount = ""
|
||||
}
|
||||
|
||||
// Snapshot captures the current session state for contextual use.
|
||||
func (m *SessionManager) Snapshot() SessionState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return SessionState{
|
||||
Authenticated: m.authenticated,
|
||||
Email: m.currentAccount,
|
||||
// Load extracts session data from the request cookies.
|
||||
func (s *SessionStore) Load(r *http.Request) SessionState {
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return SessionState{}
|
||||
}
|
||||
|
||||
payload, err := decodeSession(c.Value, s.secret)
|
||||
if err != nil {
|
||||
return SessionState{}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// Save persists the session state onto the response cookies.
|
||||
func (s *SessionStore) Save(w http.ResponseWriter, state SessionState) error {
|
||||
serialized, err := encodeSession(state, s.secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: serialized,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false, // TODO: in production, set to true
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(sessionLifetime),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes the session cookie from the client.
|
||||
func (s *SessionStore) Clear(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// ensureCSRFToken returns a session state with a CSRF token present.
|
||||
func ensureCSRFToken(state SessionState) (SessionState, error) {
|
||||
if state.CSRFToken != "" {
|
||||
return state, nil
|
||||
}
|
||||
token := make([]byte, csrfTokenByteLength)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return state, err
|
||||
}
|
||||
state.CSRFToken = base64.RawURLEncoding.EncodeToString(token)
|
||||
return state, nil
|
||||
}
|
||||
|
|
|
|||
53
internal/server/session_encoding.go
Normal file
53
internal/server/session_encoding.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func encodeSession(state SessionState, secret []byte) (string, error) {
|
||||
payload, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write(payload)
|
||||
sig := mac.Sum(nil)
|
||||
|
||||
combined := append(payload, sig...)
|
||||
return base64.RawURLEncoding.EncodeToString(combined), nil
|
||||
}
|
||||
|
||||
func decodeSession(raw string, secret []byte) (SessionState, error) {
|
||||
var state SessionState
|
||||
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return state, err
|
||||
}
|
||||
|
||||
if len(decoded) <= sha256.Size {
|
||||
return state, errors.New("session payload too small")
|
||||
}
|
||||
|
||||
payload := decoded[:len(decoded)-sha256.Size]
|
||||
providedSig := decoded[len(decoded)-sha256.Size:]
|
||||
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write(payload)
|
||||
expectedSig := mac.Sum(nil)
|
||||
|
||||
if !hmac.Equal(providedSig, expectedSig) {
|
||||
return state, errors.New("session signature mismatch")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payload, &state); err != nil {
|
||||
return state, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
@ -4,12 +4,13 @@ package server
|
|||
type PageData struct {
|
||||
Email string
|
||||
Error string
|
||||
CSRFToken string
|
||||
}
|
||||
|
||||
func newIndexData(email, errMsg string) PageData {
|
||||
return PageData{Email: email, Error: errMsg}
|
||||
func newIndexData(email, errMsg, token string) PageData {
|
||||
return PageData{Email: email, Error: errMsg, CSRFToken: token}
|
||||
}
|
||||
|
||||
func newUnauthorizedData(errMsg string) PageData {
|
||||
return PageData{Error: errMsg}
|
||||
func newUnauthorizedData(errMsg, token string) PageData {
|
||||
return PageData{Error: errMsg, CSRFToken: token}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
<p>You are signed in as <strong>{{.Email}}</strong>.</p>
|
||||
<p>This placeholder dashboard will evolve as we flesh out the auth flow.</p>
|
||||
<form method="post" action="/logout">
|
||||
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
|
||||
<button type="submit" class="secondary">Sign out</button>
|
||||
</form>
|
||||
</main>
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
</article>
|
||||
{{end}}
|
||||
<form method="post" action="/login">
|
||||
<input type="hidden" name="_csrf" value="{{.CSRFToken}}" />
|
||||
<label for="email">Email</label>
|
||||
<input
|
||||
type="email"
|
||||
|
|
|
|||
Loading…
Reference in a new issue