diff --git a/README.md b/README.md index 1bf54d8..bff99ac 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,14 @@ use a cloud alternative such as Railway, fly.io, _etc._ DB_USER=your_db_username DB_PASSWORD=your_db_password DB_NAME=rentease + # Stripe configuration (optional until you enable automatic sync) + APP_STRIPE_SECRET_KEY=sk_test_your_key + APP_STRIPE_WEBHOOK_SECRET=whsec_your_webhook_secret + APP_STRIPE_CONNECT_ACCOUNT=acct_your_connect_account # optional ``` + Leave the Stripe variables blank to continue using manual cash entry only. When set, Rentease will pull payments from Stripe, process webhooks sent to `/webhooks/stripe`, and expose a manual sync endpoint at `POST /api/stripe/sync` (protected by the existing API key middleware). + 5. Start the application ```sh diff --git a/cmd/cron/main.go b/cmd/cron/main.go index aae9908..17d0d75 100644 --- a/cmd/cron/main.go +++ b/cmd/cron/main.go @@ -7,6 +7,8 @@ import ( "syscall" "github.com/rjNemo/rentease/pkg/cron" + + internalcron "github.com/rjNemo/rentease/internal/cron" ) func main() { @@ -15,8 +17,13 @@ func main() { scheduler.AddJob(cron.Job{ Name: "Monthly Booking Report", Schedule: "minute", - // Schedule: "monthly", - // Action: cron.JobMonthlyBookingReport, + Action: internalcron.JobMonthlyBookingReport, + }) + + scheduler.AddJob(cron.Job{ + Name: "Stripe Payment Sync", + Schedule: "daily", + Action: internalcron.JobStripePaymentSync, }) go scheduler.Start() diff --git a/go.mod b/go.mod index 9903582..663e5c1 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( require ( github.com/google/go-cmp v0.7.0 // indirect + github.com/stripe/stripe-go/v79 v79.12.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 1326c7f..e008838 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs= +github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -79,17 +81,23 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/config.go b/internal/config/config.go index 5dde253..1c6edae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,12 @@ type Config struct { SecretKey string `env:"SECRET_KEY, required"` // SessionSecret is the secret key used for session signing SessionSecret string `env:"SESSION_SECRET, required"` + // StripeSecretKey is the API key used to authenticate with Stripe + StripeSecretKey string `env:"STRIPE_SECRET_KEY"` + // StripeWebhookSecret is the signing secret for validating Stripe webhooks + StripeWebhookSecret string `env:"STRIPE_WEBHOOK_SECRET"` + // StripeConnectAccount is the connected account ID when using Stripe Connect (optional) + StripeConnectAccount string `env:"STRIPE_CONNECT_ACCOUNT"` } // New creates a [Config] struct. It first parses the environment variables. You can use a .env file. diff --git a/internal/cron/job_report.go b/internal/cron/job_report.go index 96aa146..71ac2c7 100644 --- a/internal/cron/job_report.go +++ b/internal/cron/job_report.go @@ -29,7 +29,7 @@ func JobMonthlyBookingReport() error { } store := booking.NewPgStore(db) - service, err := bookingService.NewService(nil, store, nil, ps) + service, err := bookingService.NewService(nil, store, nil, ps, nil) if err != nil { return fmt.Errorf("error creating booking service: %w", err) } diff --git a/internal/cron/job_stripe_sync.go b/internal/cron/job_stripe_sync.go new file mode 100644 index 0000000..0c75ddf --- /dev/null +++ b/internal/cron/job_stripe_sync.go @@ -0,0 +1,68 @@ +package cron + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/rjNemo/rentease/internal/config" + "github.com/rjNemo/rentease/internal/driver/database" + stripeclient "github.com/rjNemo/rentease/internal/driver/stripe" + "github.com/rjNemo/rentease/internal/repository/booking" + bookingservice "github.com/rjNemo/rentease/internal/service/booking" +) + +// JobStripePaymentSync synchronises Stripe payments for the last 24 hours. It is +// safe to run multiple times thanks to the repository upsert semantics. +func JobStripePaymentSync() error { + ctx := context.Background() + + cfg, err := config.New(ctx) + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + if cfg.StripeSecretKey == "" { + slog.Default().Warn("stripe secret key missing; skipping stripe sync job") + return nil + } + + db, err := database.New(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("error connecting to database: %w", err) + } + + // Auto-migrate payment schema if necessary. + if err := database.Migrate(db, &bookingservice.Booking{}, &bookingservice.Item{}, &bookingservice.Payment{}); err != nil { + return fmt.Errorf("error migrating database: %w", err) + } + + store := booking.NewPgStore(db) + + opts := []stripeclient.Option{} + if cfg.StripeConnectAccount != "" { + opts = append(opts, stripeclient.WithAccount(cfg.StripeConnectAccount)) + } + + client, err := stripeclient.New(cfg.StripeSecretKey, opts...) + if err != nil { + return fmt.Errorf("error creating stripe client: %w", err) + } + + logger := slog.Default() + service, err := bookingservice.NewService(logger, store, nil, nil, client) + if err != nil { + return fmt.Errorf("error creating booking service: %w", err) + } + + to := time.Now().UTC() + from := to.Add(-24 * time.Hour) + + if err := service.SyncStripePayments(ctx, from, to); err != nil { + return fmt.Errorf("error syncing stripe payments: %w", err) + } + + slog.Default().Info("stripe payment sync job completed", slog.Time("from", from), slog.Time("to", to)) + return nil +} diff --git a/internal/driver/stripe/client.go b/internal/driver/stripe/client.go new file mode 100644 index 0000000..6b768ed --- /dev/null +++ b/internal/driver/stripe/client.go @@ -0,0 +1,166 @@ +package stripe + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + stripe "github.com/stripe/stripe-go/v79" + stripeclient "github.com/stripe/stripe-go/v79/client" +) + +// Option configures a Client instance. +type Option func(*Client) + +// WithAccount sets the Stripe connected account identifier used for requests. +func WithAccount(account string) Option { + return func(c *Client) { + c.account = strings.TrimSpace(account) + } +} + +// Client wraps Stripe's SDK to expose the subset of functionality needed by the +// application while keeping the rest of the codebase decoupled from the SDK. +type Client struct { + api *stripeclient.API + account string +} + +// New constructs a Client using the provided secret key. The key must not be empty. +func New(secretKey string, opts ...Option) (*Client, error) { + trimmed := strings.TrimSpace(secretKey) + if trimmed == "" { + return nil, errors.New("stripe secret key is required") + } + + api := stripeclient.New(trimmed, nil) + client := &Client{api: api} + for _, opt := range opts { + opt(client) + } + return client, nil +} + +// Payment represents the subset of payment intent data consumed by the booking service. +type Payment struct { + ID string + Amount float64 + Currency string + Status string + PaymentMethod string + BookingID *uint + Created time.Time +} + +// NormalizePaymentIntent converts a Stripe payment intent into the simplified Payment structure used +// by the application. Fields that are absent default to their zero values. +func NormalizePaymentIntent(pi *stripe.PaymentIntent) Payment { + if pi == nil { + return Payment{} + } + + amount := float64(pi.AmountReceived) / 100.0 + if amount == 0 { + amount = float64(pi.Amount) / 100.0 + } + + return Payment{ + ID: pi.ID, + Amount: amount, + Currency: strings.ToUpper(string(pi.Currency)), + Status: string(pi.Status), + PaymentMethod: deriveMethod(pi), + BookingID: extractBookingID(pi.Metadata), + Created: time.Unix(int64(pi.Created), 0), + } +} + +// ListPaymentsParams defines the time boundaries used when fetching Stripe payments. +type ListPaymentsParams struct { + From time.Time + To time.Time +} + +// ListPayments fetches payment intents created within the provided time range. The +// results are normalised into Payment structs suitable for downstream processing. +func (c *Client) ListPayments(ctx context.Context, params ListPaymentsParams) ([]Payment, error) { + listParams := &stripe.PaymentIntentListParams{} + listParams.Context = ctx + listParams.AddExpand("data.latest_charge") + listParams.AddExpand("data.payment_method") + + if !params.From.IsZero() { + listParams.Filters.AddFilter("created", "gte", strconv.FormatInt(params.From.Unix(), 10)) + } + if !params.To.IsZero() { + listParams.Filters.AddFilter("created", "lte", strconv.FormatInt(params.To.Unix(), 10)) + } + + if c.account != "" { + listParams.SetStripeAccount(c.account) + } + + iter := c.api.PaymentIntents.List(listParams) + payments := make([]Payment, 0) + + for iter.Next() { + pi := iter.PaymentIntent() + if pi == nil { + continue + } + + payments = append(payments, NormalizePaymentIntent(pi)) + } + + if err := iter.Err(); err != nil { + return nil, fmt.Errorf("stripe payment intents iteration failed: %w", err) + } + + return payments, nil +} + +func deriveMethod(pi *stripe.PaymentIntent) string { + if pi == nil { + return "" + } + + if pi.LatestCharge != nil && pi.LatestCharge.PaymentMethodDetails != nil { + typ := pi.LatestCharge.PaymentMethodDetails.Type + if typ != "" { + return string(typ) + } + } + + if pi.PaymentMethod != nil && pi.PaymentMethod.Type != "" { + return string(pi.PaymentMethod.Type) + } + + if len(pi.PaymentMethodTypes) > 0 { + return pi.PaymentMethodTypes[0] + } + + return "" +} + +func extractBookingID(metadata map[string]string) *uint { + if len(metadata) == 0 { + return nil + } + + keys := []string{"booking_id", "bookingId", "bookingID"} + for _, key := range keys { + if raw, ok := metadata[key]; ok { + if raw == "" { + continue + } + if id, err := strconv.ParseUint(raw, 10, 32); err == nil { + value := uint(id) + return &value + } + } + } + return nil +} diff --git a/internal/repository/booking/pg_store.go b/internal/repository/booking/pg_store.go index 22818d8..c40f10b 100644 --- a/internal/repository/booking/pg_store.go +++ b/internal/repository/booking/pg_store.go @@ -1,6 +1,7 @@ package booking import ( + "errors" "fmt" "time" @@ -172,3 +173,49 @@ func (ps *PgStore) UpdatePayment(id int, amount float64, paymentMethod string) ( Error return p, err } + +func (ps *PgStore) UpsertStripePayment(p *booking.Payment) (*booking.Payment, error) { + if p.StripePaymentID == nil || *p.StripePaymentID == "" { + return nil, fmt.Errorf("stripe payment id is required") + } + + existing := new(booking.Payment) + stripeID := *p.StripePaymentID + if err := ps.db.Where("stripe_payment_id = ?", stripeID).First(existing).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + if err := ps.db.Create(p).Error; err != nil { + return nil, fmt.Errorf("failed to create stripe payment: %w", err) + } + return p, nil + } + return nil, fmt.Errorf("failed to lookup stripe payment: %w", err) + } + + updates := map[string]any{ + "amount": p.Amount, + "payment_method": p.PaymentMethod, + "stripe_status": p.StripeStatus, + "booking_id": p.BookingID, + } + + if err := ps.db.Model(existing). + Clauses(clause.Returning{}). + Updates(updates). + Error; err != nil { + return nil, fmt.Errorf("failed to update stripe payment: %w", err) + } + + return existing, nil +} + +func (ps *PgStore) FindStripePayment(stripePaymentID string) (*booking.Payment, error) { + if stripePaymentID == "" { + return nil, fmt.Errorf("stripe payment id is required") + } + + p := new(booking.Payment) + if err := ps.db.Where("stripe_payment_id = ?", stripePaymentID).First(p).Error; err != nil { + return nil, err + } + return p, nil +} diff --git a/internal/server/handle_bookings.go b/internal/server/handle_bookings.go index a9fa5d5..0594d98 100644 --- a/internal/server/handle_bookings.go +++ b/internal/server/handle_bookings.go @@ -55,6 +55,20 @@ func handleBookingListPage(bs *booking.Service, hc *config.Host) echo.HandlerFun } } +func paymentViewModelFromBookingPayment(p booking.Payment) *view.PaymentViewModel { + stripeStatus := "" + if p.StripeStatus != nil { + stripeStatus = *p.StripeStatus + } + + return &view.PaymentViewModel{ + Amount: strconv.FormatFloat(p.Amount, 'f', 2, 64), + PaymentMethod: string(p.PaymentMethod), + PaymentUrl: fmt.Sprintf("%s/%d", constant.RoutePayment, p.ID), + StripeStatus: stripeStatus, + } +} + func handleBookingList(bs *booking.Service) echo.HandlerFunc { return func(c echo.Context) error { search := c.FormValue("search") @@ -156,11 +170,7 @@ func handleBookingPage(bs *booking.Service, hc *config.Host) echo.HandlerFunc { } }), Payments: u.Map(b.Payments, func(p booking.Payment) view.PaymentViewModel { - return view.PaymentViewModel{ - Amount: strconv.FormatFloat(p.Amount, 'f', 2, 64), - PaymentMethod: string(p.PaymentMethod), - PaymentUrl: fmt.Sprintf("%s/%d", constant.RoutePayment, p.ID), - } + return *paymentViewModelFromBookingPayment(p) }), }, Total: strconv.FormatFloat(u.Reduce(b.Items, func(i booking.Item, sum float64) float64 { @@ -380,11 +390,7 @@ func handlePaymentUpdate(bs *booking.Service) echo.HandlerFunc { p := bs.UpdatePayment(up.Id, up.Amount, up.PaymentMethod) - return renderTempl(c, http.StatusOK, view.PaymentLine(&view.PaymentViewModel{ - Amount: strconv.FormatFloat(p.Amount, 'f', 2, 64), - PaymentMethod: string(p.PaymentMethod), - PaymentUrl: fmt.Sprintf("%s/%d", constant.RoutePayment, p.ID), - })) + return renderTempl(c, http.StatusOK, view.PaymentLine(paymentViewModelFromBookingPayment(*p))) } } diff --git a/internal/server/handle_payments.go b/internal/server/handle_payments.go index cfaad51..9fb19c4 100644 --- a/internal/server/handle_payments.go +++ b/internal/server/handle_payments.go @@ -1,14 +1,12 @@ package server import ( - "fmt" "net/http" "strconv" "github.com/labstack/echo/v4" u "github.com/rjNemo/underscore" - "github.com/rjNemo/rentease/internal/constant" "github.com/rjNemo/rentease/internal/service/booking" "github.com/rjNemo/rentease/internal/view" ) @@ -42,11 +40,7 @@ func handleCreatePayment(bs *booking.Service) echo.HandlerFunc { return renderTempl(c, http.StatusOK, view.PaymentList( u.Map(nb.Payments, func(p booking.Payment) *view.PaymentViewModel { - return &view.PaymentViewModel{ - Amount: strconv.FormatFloat(p.Amount, 'f', 2, 64), - PaymentMethod: string(p.PaymentMethod), - PaymentUrl: fmt.Sprintf("%s/%d", constant.RoutePayment, p.ID), - } + return paymentViewModelFromBookingPayment(p) }), )) } @@ -61,11 +55,7 @@ func handlePaymentForm(bs *booking.Service) echo.HandlerFunc { } p := bs.OnePayment(id) - form := view.PaymentForm(&view.PaymentViewModel{ - Amount: strconv.FormatFloat(p.Amount, 'f', 2, 64), - PaymentMethod: string(p.PaymentMethod), - PaymentUrl: fmt.Sprintf("%s/%d", constant.RoutePayment, p.ID), - }) + form := view.PaymentForm(paymentViewModelFromBookingPayment(*p)) return renderTempl(c, http.StatusOK, form) } } diff --git a/internal/server/handle_stripe_sync.go b/internal/server/handle_stripe_sync.go new file mode 100644 index 0000000..833dcda --- /dev/null +++ b/internal/server/handle_stripe_sync.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "time" + + "github.com/labstack/echo/v4" + + "github.com/rjNemo/rentease/internal/service/booking" +) + +type stripeSyncRequest struct { + From string `json:"from"` + To string `json:"to"` +} + +type stripeSyncer interface { + SyncStripePayments(ctx context.Context, from, to time.Time) error +} + +func handleStripeSync(bs stripeSyncer) echo.HandlerFunc { + return func(c echo.Context) error { + req := new(stripeSyncRequest) + if err := json.NewDecoder(c.Request().Body).Decode(req); err != nil { + if !errors.Is(err, io.EOF) { + return echo.NewHTTPError(http.StatusBadRequest, "invalid request payload") + } + } + + now := time.Now().UTC() + from := now.Add(-24 * time.Hour) + to := now + + if req.From != "" { + parsed, err := time.Parse(time.RFC3339, req.From) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid 'from' timestamp, expected RFC3339") + } + from = parsed + } + + if req.To != "" { + parsed, err := time.Parse(time.RFC3339, req.To) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid 'to' timestamp, expected RFC3339") + } + to = parsed + } + + if err := bs.SyncStripePayments(c.Request().Context(), from, to); err != nil { + if errors.Is(err, booking.ErrStripeClientNotConfigured) { + return echo.NewHTTPError(http.StatusServiceUnavailable, "stripe client not configured") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) + } +} diff --git a/internal/server/handle_stripe_sync_test.go b/internal/server/handle_stripe_sync_test.go new file mode 100644 index 0000000..c93f116 --- /dev/null +++ b/internal/server/handle_stripe_sync_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" +) + +type stubStripeSyncer struct { + from time.Time + to time.Time + err error +} + +func (s *stubStripeSyncer) SyncStripePayments(ctx context.Context, from, to time.Time) error { + s.from = from + s.to = to + return s.err +} + +func TestHandleStripeSyncSuccess(t *testing.T) { + syncer := &stubStripeSyncer{} + handler := handleStripeSync(syncer) + + now := time.Now().UTC().Truncate(time.Second) + from := now.Add(-2 * time.Hour).Format(time.RFC3339) + to := now.Format(time.RFC3339) + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/stripe/sync", strings.NewReader(`{"from":"`+from+`","to":"`+to+`"}`)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if err := handler(c); err != nil { + t.Fatalf("handler returned error: %v", err) + } + + if rec.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", rec.Code) + } + + if syncer.from.IsZero() || syncer.to.IsZero() { + t.Fatal("expected syncer to receive time bounds") + } +} + +func TestHandleStripeSyncInvalidTimestamp(t *testing.T) { + syncer := &stubStripeSyncer{} + handler := handleStripeSync(syncer) + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/stripe/sync", strings.NewReader(`{"from":"not-a-date"}`)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + if err == nil { + t.Fatal("expected error for invalid timestamp") + } + + httpErr, ok := err.(*echo.HTTPError) + if !ok { + t.Fatalf("expected HTTPError, got %T", err) + } + if httpErr.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", httpErr.Code) + } +} diff --git a/internal/server/handle_stripe_webhook.go b/internal/server/handle_stripe_webhook.go new file mode 100644 index 0000000..6f6d160 --- /dev/null +++ b/internal/server/handle_stripe_webhook.go @@ -0,0 +1,65 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "github.com/labstack/echo/v4" + stripe "github.com/stripe/stripe-go/v79" + "github.com/stripe/stripe-go/v79/webhook" +) + +type stripeEventService interface { + HandlePaymentIntentSucceeded(ctx context.Context, pi *stripe.PaymentIntent) error + HandleChargeRefunded(ctx context.Context, ch *stripe.Charge) error +} + +func handleStripeWebhook(bs stripeEventService, secret string) echo.HandlerFunc { + return func(c echo.Context) error { + if secret == "" { + return echo.NewHTTPError(http.StatusServiceUnavailable, "stripe webhook secret not configured") + } + + payload, err := io.ReadAll(c.Request().Body) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "unable to read request body") + } + + sig := c.Request().Header.Get("Stripe-Signature") + if sig == "" { + return echo.NewHTTPError(http.StatusBadRequest, "missing Stripe-Signature header") + } + + event, err := webhook.ConstructEvent(payload, sig, secret) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid webhook signature") + } + + switch event.Type { + case stripe.EventTypePaymentIntentSucceeded: + var pi stripe.PaymentIntent + if err := json.Unmarshal(event.Data.Raw, &pi); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid payment intent payload") + } + + if err := bs.HandlePaymentIntentSucceeded(c.Request().Context(), &pi); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + case stripe.EventTypeChargeRefunded: + var ch stripe.Charge + if err := json.Unmarshal(event.Data.Raw, &ch); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid charge payload") + } + + if err := bs.HandleChargeRefunded(c.Request().Context(), &ch); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + default: + // Acknowledge events we don't actively process. + } + + return c.NoContent(http.StatusOK) + } +} diff --git a/internal/server/handle_stripe_webhook_test.go b/internal/server/handle_stripe_webhook_test.go new file mode 100644 index 0000000..28153aa --- /dev/null +++ b/internal/server/handle_stripe_webhook_test.go @@ -0,0 +1,150 @@ +package server + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" + stripe "github.com/stripe/stripe-go/v79" + "github.com/stripe/stripe-go/v79/webhook" +) + +type stubStripeEventService struct { + intentCalled bool + chargeCalled bool + err error +} + +func (s *stubStripeEventService) HandlePaymentIntentSucceeded(ctx context.Context, pi *stripe.PaymentIntent) error { + s.intentCalled = true + return s.err +} + +func (s *stubStripeEventService) HandleChargeRefunded(ctx context.Context, ch *stripe.Charge) error { + s.chargeCalled = true + return s.err +} + +func TestHandleStripeWebhookPaymentIntent(t *testing.T) { + secret := "whsec_test" + payload := map[string]any{ + "id": "evt_test", + "type": "payment_intent.succeeded", + "api_version": stripe.APIVersion, + "data": map[string]any{ + "object": map[string]any{ + "id": "pi_123", + "amount": 10000, + "amount_received": 10000, + "currency": "eur", + "status": "succeeded", + "metadata": map[string]string{"booking_id": "42"}, + "payment_method_types": []string{"card"}, + }, + }, + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("failed to marshal payload: %v", err) + } + + ts := time.Now() + sig := webhook.ComputeSignature(ts, payloadBytes, secret) + sigHeader := fmt.Sprintf("t=%d,v1=%s", ts.Unix(), hex.EncodeToString(sig)) + if _, err := webhook.ConstructEvent(payloadBytes, sigHeader, secret); err != nil { + t.Fatalf("signature validation failed in setup: %v", err) + } + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/webhooks/stripe", strings.NewReader(string(payloadBytes))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set("Stripe-Signature", sigHeader) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + service := &stubStripeEventService{} + handler := handleStripeWebhook(service, secret) + + if err := handler(c); err != nil { + t.Fatalf("handler returned error: %v", err) + } + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + + if !service.intentCalled { + t.Fatalf("expected payment intent handler to be called") + } +} + +func TestHandleStripeWebhookChargeRefunded(t *testing.T) { + secret := "whsec_test" + payload := map[string]any{ + "id": "evt_charge", + "type": "charge.refunded", + "api_version": stripe.APIVersion, + "data": map[string]any{ + "object": map[string]any{ + "id": "ch_123", + "amount": 5000, + "amount_refunded": 5000, + "payment_intent": map[string]any{ + "id": "pi_123", + }, + }, + }, + } + + payloadBytes, _ := json.Marshal(payload) + ts := time.Now() + sig := webhook.ComputeSignature(ts, payloadBytes, secret) + sigHeader := fmt.Sprintf("t=%d,v1=%s", ts.Unix(), hex.EncodeToString(sig)) + if _, err := webhook.ConstructEvent(payloadBytes, sigHeader, secret); err != nil { + t.Fatalf("signature validation failed in setup: %v", err) + } + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/webhooks/stripe", strings.NewReader(string(payloadBytes))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set("Stripe-Signature", sigHeader) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + service := &stubStripeEventService{} + handler := handleStripeWebhook(service, secret) + + if err := handler(c); err != nil { + t.Fatalf("handler returned error: %v", err) + } + + if !service.chargeCalled { + t.Fatalf("expected charge handler to be called") + } +} + +func TestHandleStripeWebhookInvalidSignature(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/webhooks/stripe", strings.NewReader("{}")) + req.Header.Set("Stripe-Signature", "invalid") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := handleStripeWebhook(&stubStripeEventService{}, "secret") + err := handler(c) + if err == nil { + t.Fatal("expected error for invalid signature") + } + + if httpErr, ok := err.(*echo.HTTPError); !ok || httpErr.Code != http.StatusBadRequest { + t.Fatalf("expected 400 HTTP error, got %v", err) + } +} diff --git a/internal/server/option.go b/internal/server/option.go index a983bbf..050c34f 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -6,11 +6,12 @@ import ( ) type options struct { - port *int - fs *embed.FS - debug *bool - secretKey *string - origins []string + port *int + fs *embed.FS + debug *bool + secretKey *string + origins []string + stripeWebhookSecret *string } type Option func(*options) error @@ -52,3 +53,13 @@ func WithOrigins(origins []string) Option { return nil } } + +func WithStripeWebhookSecret(secret string) Option { + return func(o *options) error { + if secret == "" { + return nil + } + o.stripeWebhookSecret = &secret + return nil + } +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 8bb055d..b107b4e 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -9,6 +9,7 @@ func (s Server) MountHandlers() { s.Router.GET("/healthz", handleHealthCheck()) s.Router.GET("/", handleLoginPage()) s.Router.POST("/", handleLogin(s.as)) + s.Router.POST("/webhooks/stripe", handleStripeWebhook(s.bs, s.stripeWebhookSecret)) api := s.Router.Group("/api") api.Use(middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{ @@ -20,6 +21,7 @@ func (s Server) MountHandlers() { api.POST("/sync", handleSync(s.bs)) api.GET("/bookings", handleBookingList(s.bs)) api.POST("/bookings", handleCreateBooking(s.bs)) + api.POST("/stripe/sync", handleStripeSync(s.bs)) private := s.Router.Group("") private.Use(MakeAuthMiddleware(s.as)) diff --git a/internal/server/server.go b/internal/server/server.go index 2369dc6..7683592 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -23,11 +23,12 @@ import ( ) type Server struct { - Router *echo.Echo - bs *booking.Service - as *auth.Service - hc *config.Host - addr string + Router *echo.Echo + bs *booking.Service + as *auth.Service + hc *config.Host + addr string + stripeWebhookSecret string } func New(bs *booking.Service, as *auth.Service, hc *config.Host, opts ...Option) (*Server, error) { @@ -40,11 +41,16 @@ func New(bs *booking.Service, as *auth.Service, hc *config.Host, opts ...Option) } s := &Server{ - Router: NewRouter(*option.fs, *option.debug, *option.secretKey, option.origins), - bs: bs, - as: as, - hc: hc, - addr: fmt.Sprintf("0.0.0.0:%d", *option.port), + Router: NewRouter(*option.fs, *option.debug, *option.secretKey, option.origins), + bs: bs, + as: as, + hc: hc, + addr: fmt.Sprintf("0.0.0.0:%d", *option.port), + stripeWebhookSecret: "", + } + + if option.stripeWebhookSecret != nil { + s.stripeWebhookSecret = *option.stripeWebhookSecret } s.MountHandlers() diff --git a/internal/service/booking/models.go b/internal/service/booking/models.go index bf9b87e..b212448 100644 --- a/internal/service/booking/models.go +++ b/internal/service/booking/models.go @@ -126,8 +126,10 @@ func (i Item) ToFrench() string { type Payment struct { gorm.Model - BookingID uint `gorm:"not null;index"` - Booking Booking `gorm:"foreignKey:BookingID;constraint:OnDelete:CASCADE"` - Amount float64 - PaymentMethod config.PaymentMethod + BookingID uint `gorm:"not null;index"` + Booking Booking `gorm:"foreignKey:BookingID;constraint:OnDelete:CASCADE"` + Amount float64 + PaymentMethod config.PaymentMethod + StripePaymentID *string `gorm:"size:255;uniqueIndex"` + StripeStatus *string `gorm:"size:32"` } diff --git a/internal/service/booking/payment.go b/internal/service/booking/payment.go index 8cc006b..2529e04 100644 --- a/internal/service/booking/payment.go +++ b/internal/service/booking/payment.go @@ -34,3 +34,12 @@ func (bs Service) UpdatePayment(id int, amount float64, paymentMethod string) *P } return p } + +func (bs Service) UpsertStripePayment(p *Payment) (*Payment, error) { + sp, err := bs.store.UpsertStripePayment(p) + if err != nil { + log.Println(err) + return nil, err + } + return sp, nil +} diff --git a/internal/service/booking/service.go b/internal/service/booking/service.go index d0074d1..ccc7a4e 100644 --- a/internal/service/booking/service.go +++ b/internal/service/booking/service.go @@ -1,10 +1,12 @@ package booking import ( + "context" "log/slog" "time" "github.com/rjNemo/rentease/internal/config" + stripeclient "github.com/rjNemo/rentease/internal/driver/stripe" ) type Store interface { @@ -27,6 +29,12 @@ type Store interface { CreatePayment(p *Payment) (*Payment, error) GetPayment(id int) (*Payment, error) UpdatePayment(id int, amount float64, paymentMethod string) (*Payment, error) + UpsertStripePayment(p *Payment) (*Payment, error) + FindStripePayment(stripePaymentID string) (*Payment, error) +} + +type StripeClient interface { + ListPayments(ctx context.Context, params stripeclient.ListPaymentsParams) ([]stripeclient.Payment, error) } type PdfClient interface { @@ -47,14 +55,21 @@ type Service struct { parser parserClient pdf PdfClient logger *slog.Logger + stripe StripeClient } -func NewService(logger *slog.Logger, store Store, parser parserClient, pdf PdfClient) (*Service, error) { +func NewService(logger *slog.Logger, store Store, parser parserClient, pdf PdfClient, stripe StripeClient) (*Service, error) { + svcLogger := logger + if svcLogger == nil { + svcLogger = slog.Default() + } + return &Service{ - logger: logger.With(slog.String("component", "booking_service")), + logger: svcLogger.With(slog.String("component", "booking_service")), store: store, parser: parser, pdf: pdf, + stripe: stripe, }, nil } diff --git a/internal/service/booking/stripe_sync.go b/internal/service/booking/stripe_sync.go new file mode 100644 index 0000000..9b5a6e7 --- /dev/null +++ b/internal/service/booking/stripe_sync.go @@ -0,0 +1,70 @@ +package booking + +import ( + "context" + "errors" + "log/slog" + "strings" + "time" + + "github.com/rjNemo/rentease/internal/config" + stripeclient "github.com/rjNemo/rentease/internal/driver/stripe" +) + +// ErrStripeClientNotConfigured indicates the service was asked to run a Stripe operation without a configured client. +var ErrStripeClientNotConfigured = errors.New("stripe client not configured") + +// SyncStripePayments pulls Stripe payments within the provided time window and +// upserts them into the local datastore. Payments lacking booking metadata are +// skipped to avoid incorrect associations. +func (bs Service) SyncStripePayments(ctx context.Context, from, to time.Time) error { + if bs.stripe == nil { + return ErrStripeClientNotConfigured + } + + payments, err := bs.stripe.ListPayments(ctx, stripeclient.ListPaymentsParams{From: from, To: to}) + if err != nil { + return err + } + + var multi error + for _, payment := range payments { + if payment.BookingID == nil { + bs.logger.Warn("stripe payment missing booking metadata", slog.String("payment_id", payment.ID)) + continue + } + + bookingID := uint(*payment.BookingID) + stripeID := payment.ID + status := strings.ToLower(payment.Status) + + _, err = bs.store.UpsertStripePayment(&Payment{ + BookingID: bookingID, + Amount: payment.Amount, + PaymentMethod: mapStripeMethod(payment.PaymentMethod), + StripePaymentID: &stripeID, + StripeStatus: &status, + }) + if err != nil { + multi = errors.Join(multi, err) + bs.logger.Error("failed to upsert stripe payment", slog.String("payment_id", payment.ID), slog.Any("error", err)) + } + } + + return multi +} + +func mapStripeMethod(method string) config.PaymentMethod { + switch strings.ToLower(method) { + case "card", "link", "apple_pay", "google_pay", "cashapp": + return config.PaymentMethod("Card") + case "ach_credit_transfer", "ach_debit", "us_bank_account", "sepa_debit", "bank_transfer", "blik", "bancontact": + return config.PaymentMethod("Transfer") + case "cash": + return config.PaymentMethod("Cash") + case "check": + return config.PaymentMethod("Cheque") + default: + return config.PaymentMethod("Card") + } +} diff --git a/internal/service/booking/stripe_sync_test.go b/internal/service/booking/stripe_sync_test.go new file mode 100644 index 0000000..80f56d3 --- /dev/null +++ b/internal/service/booking/stripe_sync_test.go @@ -0,0 +1,165 @@ +package booking + +import ( + "context" + "errors" + "log/slog" + "testing" + "time" + + "gorm.io/gorm" + + "github.com/rjNemo/rentease/internal/config" + stripeclient "github.com/rjNemo/rentease/internal/driver/stripe" +) + +type fakeStripeClient struct { + payments []stripeclient.Payment + err error +} + +func (f *fakeStripeClient) ListPayments(ctx context.Context, params stripeclient.ListPaymentsParams) ([]stripeclient.Payment, error) { + return f.payments, f.err +} + +type mockStore struct { + upserts []*Payment + err error + byStripeID map[string]*Payment +} + +func (m *mockStore) record(p *Payment) (*Payment, error) { + cp := *p + m.upserts = append(m.upserts, &cp) + if cp.StripePaymentID != nil { + if m.byStripeID == nil { + m.byStripeID = make(map[string]*Payment) + } + clone := cp + m.byStripeID[*cp.StripePaymentID] = &clone + } + if m.err != nil { + return nil, m.err + } + return &cp, nil +} + +func (m *mockStore) All() []*Line { return nil } +func (m *mockStore) Search(string) []*Line { return nil } +func (m *mockStore) List(time.Time, time.Time) ([]*Line, error) { return nil, nil } +func (m *mockStore) CardTotal(time.Time, time.Time) (float64, error) { return 0, nil } +func (m *mockStore) Get(int) *Booking { return nil } +func (m *mockStore) Create(*Booking) error { return nil } +func (m *mockStore) Update(*Booking) error { return nil } +func (m *mockStore) Cancel(int) error { return nil } +func (m *mockStore) CreateItem(*Item) error { return nil } +func (m *mockStore) PayItem(int) (*Item, error) { return nil, nil } +func (m *mockStore) GetItem(int) (*Item, error) { return nil, nil } +func (m *mockStore) UpdateItem(int, string, string, string, int, float64) (*Item, error) { + return nil, nil +} +func (m *mockStore) CreatePayment(*Payment) (*Payment, error) { return nil, nil } +func (m *mockStore) GetPayment(int) (*Payment, error) { return nil, nil } +func (m *mockStore) UpdatePayment(int, float64, string) (*Payment, error) { return nil, nil } +func (m *mockStore) UpsertStripePayment(p *Payment) (*Payment, error) { return m.record(p) } +func (m *mockStore) FindStripePayment(id string) (*Payment, error) { + if m.byStripeID == nil { + return nil, gorm.ErrRecordNotFound + } + if p, ok := m.byStripeID[id]; ok { + clone := *p + return &clone, nil + } + return nil, gorm.ErrRecordNotFound +} + +func TestSyncStripePayments(t *testing.T) { + bookingID := uint(42) + stripePayments := []stripeclient.Payment{ + { + ID: "pi_123", + Amount: 120.50, + PaymentMethod: "card", + Status: "succeeded", + BookingID: &bookingID, + }, + } + + store := &mockStore{} + stripe := &fakeStripeClient{payments: stripePayments} + logger := slog.New(slog.DiscardHandler) + + svc, err := NewService(logger, store, nil, nil, stripe) + if err != nil { + t.Fatalf("NewService returned error: %v", err) + } + + if err := svc.SyncStripePayments(context.Background(), time.Now().Add(-time.Hour), time.Now()); err != nil { + t.Fatalf("SyncStripePayments returned error: %v", err) + } + + if len(store.upserts) != 1 { + t.Fatalf("expected 1 upsert, got %d", len(store.upserts)) + } + + upsert := store.upserts[0] + if upsert.Amount != 120.50 { + t.Errorf("unexpected amount: %v", upsert.Amount) + } + if upsert.PaymentMethod != config.PaymentMethod("Card") { + t.Errorf("unexpected payment method: %v", upsert.PaymentMethod) + } + if upsert.StripePaymentID == nil || *upsert.StripePaymentID != "pi_123" { + t.Errorf("stripe payment id not set correctly: %v", upsert.StripePaymentID) + } +} + +func TestSyncStripePaymentsSkipsMissingBooking(t *testing.T) { + stripePayments := []stripeclient.Payment{ + {ID: "pi_123", Amount: 10}, + } + + store := &mockStore{} + stripe := &fakeStripeClient{payments: stripePayments} + logger := slog.New(slog.DiscardHandler) + + svc, err := NewService(logger, store, nil, nil, stripe) + if err != nil { + t.Fatalf("NewService returned error: %v", err) + } + + if err := svc.SyncStripePayments(context.Background(), time.Now().Add(-time.Hour), time.Now()); err != nil { + t.Fatalf("SyncStripePayments returned error: %v", err) + } + + if len(store.upserts) != 0 { + t.Fatalf("expected 0 upserts, got %d", len(store.upserts)) + } +} + +func TestSyncStripePaymentsReturnsAggregatedError(t *testing.T) { + bookingID := uint(7) + stripePayments := []stripeclient.Payment{ + { + ID: "pi_err", + Amount: 50, + PaymentMethod: "card", + Status: "succeeded", + BookingID: &bookingID, + }, + } + + store := &mockStore{err: errors.New("db failure")} + stripe := &fakeStripeClient{payments: stripePayments} + logger := slog.New(slog.DiscardHandler) + + svc, err := NewService(logger, store, nil, nil, stripe) + if err != nil { + t.Fatalf("NewService returned error: %v", err) + } + + err = svc.SyncStripePayments(context.Background(), time.Now().Add(-time.Hour), time.Now()) + if err == nil { + t.Fatalf("expected error, got nil") + } +} diff --git a/internal/service/booking/stripe_webhook.go b/internal/service/booking/stripe_webhook.go new file mode 100644 index 0000000..e3d9dc4 --- /dev/null +++ b/internal/service/booking/stripe_webhook.go @@ -0,0 +1,93 @@ +package booking + +import ( + "context" + "errors" + "log/slog" + "math" + "strings" + + stripe "github.com/stripe/stripe-go/v79" + "gorm.io/gorm" + + stripeclient "github.com/rjNemo/rentease/internal/driver/stripe" +) + +// HandlePaymentIntentSucceeded persists successful Stripe payment intents received via webhook. +func (bs Service) HandlePaymentIntentSucceeded(ctx context.Context, pi *stripe.PaymentIntent) error { + if pi == nil { + return errors.New("payment intent payload is missing") + } + + normalized := stripeclient.NormalizePaymentIntent(pi) + if normalized.ID == "" { + return errors.New("payment intent missing id") + } + + if normalized.BookingID == nil { + bs.logger.Warn("stripe webhook payment missing booking metadata", slog.String("payment_intent", normalized.ID)) + return nil + } + + bookingID := uint(*normalized.BookingID) + stripeID := normalized.ID + status := strings.ToLower(normalized.Status) + + _, err := bs.store.UpsertStripePayment(&Payment{ + BookingID: bookingID, + Amount: normalized.Amount, + PaymentMethod: mapStripeMethod(normalized.PaymentMethod), + StripePaymentID: &stripeID, + StripeStatus: &status, + }) + if err != nil { + return err + } + + bs.logger.Info("stripe payment intent processed", slog.String("payment_intent", normalized.ID), slog.Int("booking_id", int(bookingID))) + return nil +} + +// HandleChargeRefunded updates an existing Stripe payment when a charge is refunded. +func (bs Service) HandleChargeRefunded(ctx context.Context, ch *stripe.Charge) error { + if ch == nil { + return errors.New("charge payload is missing") + } + + if ch.PaymentIntent == nil || ch.PaymentIntent.ID == "" { + bs.logger.Warn("stripe refund missing payment intent", slog.String("charge", ch.ID)) + return nil + } + + existing, err := bs.store.FindStripePayment(ch.PaymentIntent.ID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + bs.logger.Warn("stripe refund received for unknown payment", slog.String("payment_intent", ch.PaymentIntent.ID)) + return nil + } + return err + } + + amount := existing.Amount + if ch.AmountRefunded > 0 { + net := float64(ch.Amount-ch.AmountRefunded) / 100.0 + amount = math.Max(net, 0) + } + + status := "refunded" + stripeID := ch.PaymentIntent.ID + + _, err = bs.store.UpsertStripePayment(&Payment{ + BookingID: existing.BookingID, + Amount: amount, + PaymentMethod: existing.PaymentMethod, + StripePaymentID: &stripeID, + StripeStatus: &status, + }) + if err != nil { + return err + } + + bs.logger.Info("stripe charge refunded processed", slog.String("charge", ch.ID), slog.String("payment_intent", ch.PaymentIntent.ID)) + return nil +} diff --git a/internal/service/booking/stripe_webhook_test.go b/internal/service/booking/stripe_webhook_test.go new file mode 100644 index 0000000..3b75a00 --- /dev/null +++ b/internal/service/booking/stripe_webhook_test.go @@ -0,0 +1,85 @@ +package booking + +import ( + "context" + "errors" + "log/slog" + "testing" + + stripe "github.com/stripe/stripe-go/v79" + + "github.com/rjNemo/rentease/internal/config" +) + +func TestHandleChargeRefundedUpdatesAmount(t *testing.T) { + store := &mockStore{} + stripeID := "pi_123" + status := "succeeded" + _, _ = store.UpsertStripePayment(&Payment{ + BookingID: 42, + Amount: 100, + PaymentMethod: config.PaymentMethod("Card"), + StripePaymentID: &stripeID, + StripeStatus: &status, + }) + + svc, err := NewService(slog.New(slog.DiscardHandler), store, nil, nil, nil) + if err != nil { + t.Fatalf("NewService returned error: %v", err) + } + + charge := &stripe.Charge{ + ID: "ch_123", + Amount: 10000, + AmountRefunded: 2500, + PaymentIntent: &stripe.PaymentIntent{ID: stripeID}, + } + + if err := svc.HandleChargeRefunded(context.Background(), charge); err != nil { + t.Fatalf("HandleChargeRefunded returned error: %v", err) + } + + updated, err := store.FindStripePayment(stripeID) + if err != nil { + t.Fatalf("expected payment to be present: %v", err) + } + + if updated.Amount != 75 { + t.Fatalf("expected amount 75, got %v", updated.Amount) + } + if updated.StripeStatus == nil || *updated.StripeStatus != "refunded" { + t.Fatalf("expected status refunded, got %v", updated.StripeStatus) + } +} + +func TestHandleChargeRefundedUnknownPayment(t *testing.T) { + store := &mockStore{} + svc, _ := NewService(slog.New(slog.DiscardHandler), store, nil, nil, nil) + + charge := &stripe.Charge{PaymentIntent: &stripe.PaymentIntent{ID: "pi_missing"}} + + if err := svc.HandleChargeRefunded(context.Background(), charge); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestHandleChargeRefundedStoreError(t *testing.T) { + store := &mockStore{} + stripeID := "pi_321" + status := "succeeded" + _, _ = store.UpsertStripePayment(&Payment{ + BookingID: 1, + Amount: 10, + PaymentMethod: config.PaymentMethod("Card"), + StripePaymentID: &stripeID, + StripeStatus: &status, + }) + store.err = errors.New("db error") + + svc, _ := NewService(slog.New(slog.DiscardHandler), store, nil, nil, nil) + charge := &stripe.Charge{PaymentIntent: &stripe.PaymentIntent{ID: stripeID}} + + if err := svc.HandleChargeRefunded(context.Background(), charge); err == nil { + t.Fatalf("expected error when store fails") + } +} diff --git a/internal/view/booking_viewmodel.go b/internal/view/booking_viewmodel.go index 1c1bf82..b9dadf0 100644 --- a/internal/view/booking_viewmodel.go +++ b/internal/view/booking_viewmodel.go @@ -32,4 +32,5 @@ type PaymentViewModel struct { Amount string PaymentMethod string PaymentUrl string + StripeStatus string } diff --git a/internal/view/payment.templ b/internal/view/payment.templ index 593e89a..0843c20 100644 --- a/internal/view/payment.templ +++ b/internal/view/payment.templ @@ -5,7 +5,12 @@ templ PaymentLine(payment *PaymentViewModel) {