diff --git a/internal/web/handlers.go b/internal/web/handlers.go index 04d2507..fc7d1a1 100644 --- a/internal/web/handlers.go +++ b/internal/web/handlers.go @@ -9,38 +9,40 @@ import ( "github.com/rjNemo/payit/internal/payments" ) -func (h *Handler) createCheckoutSession(w http.ResponseWriter, r *http.Request) { - var req payments.CheckoutSessionRequest +func (h *Handler) createCheckoutSession() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req payments.CheckoutSessionRequest - if r.Body != nil { - defer func(body io.ReadCloser) { - _ = body.Close() - }(r.Body) - dec := json.NewDecoder(r.Body) - dec.DisallowUnknownFields() + if r.Body != nil { + defer func(body io.ReadCloser) { + _ = body.Close() + }(r.Body) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() - if err := dec.Decode(&req); err != nil { - if errors.Is(err, io.EOF) { - // Empty body is acceptable; default quantity applies. - } else { - http.Error(w, "invalid request payload", http.StatusBadRequest) + if err := dec.Decode(&req); err != nil { + if errors.Is(err, io.EOF) { + // Empty body is acceptable; default quantity applies. + } else { + http.Error(w, "invalid request payload", http.StatusBadRequest) + return + } + } else if dec.More() { + http.Error(w, "unexpected data in request body", http.StatusBadRequest) return } - } else if dec.More() { - http.Error(w, "unexpected data in request body", http.StatusBadRequest) + } + + session, err := h.checkout.CreateSession(r.Context(), req) + if err != nil { + http.Error(w, "checkout session failed", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(session); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) return } } - - session, err := h.checkout.CreateSession(r.Context(), req) - if err != nil { - http.Error(w, "checkout session failed", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(session); err != nil { - http.Error(w, "failed to encode response", http.StatusInternalServerError) - return - } } diff --git a/internal/web/handlers_test.go b/internal/web/handlers_test.go index 37df129..46153d2 100644 --- a/internal/web/handlers_test.go +++ b/internal/web/handlers_test.go @@ -37,7 +37,7 @@ func TestCreateCheckoutSessionSuccess(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/checkout", bytes.NewReader(body)) rec := httptest.NewRecorder() - handler.createCheckoutSession(rec, req) + handler.createCheckoutSession()(rec, req) if rec.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", rec.Code) @@ -70,7 +70,7 @@ func TestCreateCheckoutSessionDefaultsQuantity(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/checkout", http.NoBody) rec := httptest.NewRecorder() - handler.createCheckoutSession(rec, req) + handler.createCheckoutSession()(rec, req) if rec.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", rec.Code) @@ -86,7 +86,7 @@ func TestCreateCheckoutSessionRejectsInvalidJSON(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/checkout", bytes.NewBufferString("{")) rec := httptest.NewRecorder() - handler.createCheckoutSession(rec, req) + handler.createCheckoutSession()(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("expected status 400, got %d", rec.Code) @@ -101,7 +101,7 @@ func TestCreateCheckoutSessionStripeFailure(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/checkout", http.NoBody) rec := httptest.NewRecorder() - handler.createCheckoutSession(rec, req) + handler.createCheckoutSession()(rec, req) if rec.Code != http.StatusInternalServerError { t.Fatalf("expected status 500, got %d", rec.Code) @@ -111,7 +111,7 @@ func TestCreateCheckoutSessionStripeFailure(t *testing.T) { func TestCreateCheckoutSessionMethodNotAllowed(t *testing.T) { handler := &Handler{checkout: &fakeCheckoutService{}} mux := http.NewServeMux() - mux.HandleFunc("POST /api/checkout", handler.createCheckoutSession) + mux.HandleFunc("POST /api/checkout", handler.createCheckoutSession()) req := httptest.NewRequest(http.MethodGet, "/api/checkout", http.NoBody) rec := httptest.NewRecorder() diff --git a/internal/web/middleware.go b/internal/web/middleware.go new file mode 100644 index 0000000..c9039be --- /dev/null +++ b/internal/web/middleware.go @@ -0,0 +1,26 @@ +package web + +import ( + "log" + "net/http" + "time" +) + +type WrappedWriter struct { + http.ResponseWriter + StatusCode int +} + +func (w *WrappedWriter) WriteHeader(statusCode int) { + w.StatusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func LoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + wrapped := &WrappedWriter{ResponseWriter: w, StatusCode: http.StatusOK} + next.ServeHTTP(wrapped, r) + log.Printf("%s %s %d %v", r.Method, r.URL.Path, wrapped.StatusCode, time.Since(start)) + }) +} diff --git a/internal/web/page.go b/internal/web/page.go index 9438b95..8a7b790 100644 --- a/internal/web/page.go +++ b/internal/web/page.go @@ -13,18 +13,20 @@ type checkoutPageData struct { Currency string } -func (h *Handler) renderCheckoutPage(w http.ResponseWriter, r *http.Request) { - price := float64(h.cfg.Product.PriceCents) / 100 - data := checkoutPageData{ - ProductName: h.cfg.Product.Name, - ProductDescription: h.cfg.Product.Description, - PriceDisplay: fmt.Sprintf("$%.2f", price), - Currency: strings.ToUpper(h.cfg.Product.Currency), - } +func (h *Handler) renderCheckoutPage() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + price := float64(h.cfg.Product.PriceCents) / 100 + data := checkoutPageData{ + ProductName: h.cfg.Product.Name, + ProductDescription: h.cfg.Product.Description, + PriceDisplay: fmt.Sprintf("$%.2f", price), + Currency: strings.ToUpper(h.cfg.Product.Currency), + } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if err := h.page.ExecuteTemplate(w, "index.html", data); err != nil { - http.Error(w, "failed to render page", http.StatusInternalServerError) - return + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := h.page.ExecuteTemplate(w, "index.html", data); err != nil { + http.Error(w, "failed to render page", http.StatusInternalServerError) + return + } } } diff --git a/internal/web/routes.go b/internal/web/routes.go index 2ea2ee6..751f2fc 100644 --- a/internal/web/routes.go +++ b/internal/web/routes.go @@ -5,7 +5,7 @@ import ( ) func (h *Handler) registerRoutes(mux *http.ServeMux) { - mux.HandleFunc("POST /api/checkout", h.createCheckoutSession) - mux.Handle("GET /", http.HandlerFunc(h.renderCheckoutPage)) + mux.Handle("POST /api/checkout", h.createCheckoutSession()) + mux.Handle("GET /", h.renderCheckoutPage()) mux.Handle("GET /static/", http.StripPrefix("/static/", http.FileServer(http.FS(h.fs)))) } diff --git a/internal/web/server.go b/internal/web/server.go index f122f98..957eb0c 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -41,5 +41,5 @@ func NewServer(cfg config.Config) http.Handler { mux := http.NewServeMux() h.registerRoutes(mux) - return mux + return LoggerMiddleware(mux) }