From a20953cfb455c40b60b6a3e77e467265c8e59d03 Mon Sep 17 00:00:00 2001 From: Ruidy Date: Sat, 20 Sep 2025 17:43:20 +0200 Subject: [PATCH] test: cover auth service and handlers --- internal/server/server_test.go | 202 ++++++++++++++++++++++++++ internal/service/auth/service_test.go | 168 +++++++++++++++++++++ 2 files changed, 370 insertions(+) create mode 100644 internal/server/server_test.go create mode 100644 internal/service/auth/service_test.go diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..194ea52 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,202 @@ +package server + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/rjnemo/auth/internal/config" + "github.com/rjnemo/auth/internal/driver/logging" +) + +func newTestServer(t *testing.T) *Server { + t.Helper() + + cfg := config.Config{ + ListenAddr: ":0", + LogMode: logging.ModeText, + Environment: "test", + SessionSecret: bytes.Repeat([]byte("s"), 32), + } + + logger := logging.New(io.Discard, logging.ModeText, nil) + + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("new server: %v", err) + } + return srv +} + +func attachSession(req *http.Request, state SessionState) *http.Request { + return req.WithContext(withSession(req.Context(), state)) +} + +func TestLoginPageHandler(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = attachSession(req, SessionState{CSRFToken: "token"}) + rr := httptest.NewRecorder() + + srv.loginPageHandler()(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "Welcome back to Nucleus") { + t.Fatalf("expected login copy in response, got %q", body) + } +} + +func TestLoginHandlerSuccess(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + form := url.Values{} + form.Set("email", "user@example.com") + form.Set("password", "Password123") + form.Set("_csrf", "csrf-token") + + req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = attachSession(req, SessionState{CSRFToken: "csrf-token"}) + + rr := httptest.NewRecorder() + srv.loginHandler()(rr, req) + + res := rr.Result() + if res.StatusCode != http.StatusSeeOther { + t.Fatalf("expected 303, got %d", res.StatusCode) + } + if loc := res.Header.Get("Location"); loc != "/dashboard" { + t.Fatalf("expected redirect to /dashboard, got %q", loc) + } + foundSession := false + for _, c := range res.Cookies() { + if c.Name == sessionCookieName && c.Value != "" { + foundSession = true + } + } + if !foundSession { + t.Fatal("expected session cookie to be set") + } +} + +func TestLoginHandlerInvalidCredentials(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + form := url.Values{} + form.Set("email", "user@example.com") + form.Set("password", "Password999") + form.Set("_csrf", "csrf-token") + + req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = attachSession(req, SessionState{CSRFToken: "csrf-token"}) + + rr := httptest.NewRecorder() + srv.loginHandler()(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "Unable to sign in") { + t.Fatalf("expected failure message, got %q", body) + } +} + +func TestSignupHandlerSuccess(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + form := url.Values{} + form.Set("email", "new-user@example.com") + form.Set("password", "Password123") + form.Set("_csrf", "csrf-token") + + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = attachSession(req, SessionState{CSRFToken: "csrf-token"}) + + rr := httptest.NewRecorder() + srv.signupHandler()(rr, req) + + res := rr.Result() + if res.StatusCode != http.StatusSeeOther { + t.Fatalf("expected 303, got %d", res.StatusCode) + } + if loc := res.Header.Get("Location"); loc != "/dashboard" { + t.Fatalf("expected redirect to /dashboard, got %q", loc) + } +} + +func TestSignupHandlerDuplicate(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + form := url.Values{} + form.Set("email", "user@example.com") + form.Set("password", "Password123") + form.Set("_csrf", "csrf-token") + + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = attachSession(req, SessionState{CSRFToken: "csrf-token"}) + + rr := httptest.NewRecorder() + srv.signupHandler()(rr, req) + + if rr.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "account with that email") { + t.Fatalf("expected duplicate email message, got %q", body) + } +} + +func TestDashboardPageHandler(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + + t.Run("unauthenticated", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) + req = attachSession(req, SessionState{CSRFToken: "csrf"}) + rr := httptest.NewRecorder() + srv.dashboardPageHandler()(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "Access denied") { + t.Fatalf("expected unauthorized template, got %q", body) + } + }) + + t.Run("authenticated", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) + req = attachSession(req, SessionState{Authenticated: true, Email: "user@example.com", CSRFToken: "csrf"}) + rr := httptest.NewRecorder() + srv.dashboardPageHandler()(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + body := rr.Body.String() + if !strings.Contains(body, "Member since") { + t.Fatalf("expected membership text, got %q", body) + } + }) +} diff --git a/internal/service/auth/service_test.go b/internal/service/auth/service_test.go new file mode 100644 index 0000000..a9eb52f --- /dev/null +++ b/internal/service/auth/service_test.go @@ -0,0 +1,168 @@ +package auth + +import ( + "context" + "errors" + "testing" +) + +func TestServiceAuthenticate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := NewMemoryStore() + service := NewService(store) + + email := MustUserEmail("user@example.com") + salt, hash, err := HashPassword("Password123") + if err != nil { + t.Fatalf("hash password: %v", err) + } + if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil { + t.Fatalf("seed user: %v", err) + } + + tests := map[string]struct { + email UserEmail + password string + wantErr error + }{ + "invalid input": {email: email, password: "", wantErr: ErrInvalidInput}, + "weak password": {email: email, password: "short1", wantErr: ErrWeakPassword}, + "unknown account": {email: MustUserEmail("missing@example.com"), password: "Password123", wantErr: ErrInvalidCredentials}, + "wrong password": {email: email, password: "Password999", wantErr: ErrInvalidCredentials}, + "success": {email: email, password: "Password123", wantErr: nil}, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + account, err := service.Authenticate(ctx, tc.email, tc.password) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + if account != nil { + t.Fatalf("expected no account, got %#v", account) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if account == nil { + t.Fatalf("expected account") + } + if account.Email != email { + t.Fatalf("expected email %q, got %q", email, account.Email) + } + }) + } +} + +func TestServiceLookupByEmail(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := NewMemoryStore() + service := NewService(store) + + email := MustUserEmail("lookup@example.com") + if err := store.Create(ctx, User{Email: email}); err != nil { + t.Fatalf("seed user: %v", err) + } + + cases := map[string]struct { + email UserEmail + wantErr error + }{ + "zero": {email: UserEmail(""), wantErr: ErrInvalidInput}, + "missing": {email: MustUserEmail("none@example.com"), wantErr: ErrUserNotFound}, + "found": {email: email, wantErr: nil}, + } + + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + user, err := service.LookupByEmail(ctx, tc.email) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + if user != nil { + t.Fatalf("expected nil user, got %#v", user) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user == nil || user.Email != email { + t.Fatalf("expected user with email %q", email) + } + }) + } +} + +func TestServiceRegister(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := NewMemoryStore() + service := NewService(store) + + email := MustUserEmail("taken@example.com") + if err := store.Create(ctx, User{Email: email}); err != nil { + t.Fatalf("seed user: %v", err) + } + + tests := map[string]struct { + email UserEmail + password string + wantErr error + }{ + "invalid input": {email: UserEmail(""), password: "", wantErr: ErrInvalidInput}, + "weak password": {email: MustUserEmail("weak@example.com"), password: "weak", wantErr: ErrWeakPassword}, + "duplicate": {email: email, password: "Password123", wantErr: ErrEmailExists}, + "success": {email: MustUserEmail("new@example.com"), password: "Password123", wantErr: nil}, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + user, err := service.Register(ctx, tc.email, tc.password) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + if user != nil { + t.Fatalf("expected nil user, got %#v", user) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.Email != tc.email { + t.Fatalf("expected email %q, got %q", tc.email, user.Email) + } + if user.CreatedAt.IsZero() { + t.Fatal("expected CreatedAt to be set") + } + // Ensure the user is persisted with hashed credentials. + persisted, err := store.FindByEmail(ctx, tc.email) + if err != nil { + t.Fatalf("expected persisted user: %v", err) + } + if persisted.PasswordSalt == "" || persisted.PasswordHash == "" { + t.Fatal("expected password salt/hash to be stored") + } + }) + } +}