mirror of
https://github.com/rjNemo/auth
synced 2026-06-06 00:16:40 +00:00
test: cover auth service and handlers
This commit is contained in:
parent
c02501329a
commit
a20953cfb4
2 changed files with 370 additions and 0 deletions
202
internal/server/server_test.go
Normal file
202
internal/server/server_test.go
Normal file
|
|
@ -0,0 +1,202 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rjnemo/auth/internal/config"
|
||||||
|
"github.com/rjnemo/auth/internal/driver/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestServer(t *testing.T) *Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cfg := config.Config{
|
||||||
|
ListenAddr: ":0",
|
||||||
|
LogMode: logging.ModeText,
|
||||||
|
Environment: "test",
|
||||||
|
SessionSecret: bytes.Repeat([]byte("s"), 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := logging.New(io.Discard, logging.ModeText, nil)
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new server: %v", err)
|
||||||
|
}
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func attachSession(req *http.Request, state SessionState) *http.Request {
|
||||||
|
return req.WithContext(withSession(req.Context(), state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginPageHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "token"})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
srv.loginPageHandler()(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "Welcome back to Nucleus") {
|
||||||
|
t.Fatalf("expected login copy in response, got %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginHandlerSuccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("email", "user@example.com")
|
||||||
|
form.Set("password", "Password123")
|
||||||
|
form.Set("_csrf", "csrf-token")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "csrf-token"})
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.loginHandler()(rr, req)
|
||||||
|
|
||||||
|
res := rr.Result()
|
||||||
|
if res.StatusCode != http.StatusSeeOther {
|
||||||
|
t.Fatalf("expected 303, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if loc := res.Header.Get("Location"); loc != "/dashboard" {
|
||||||
|
t.Fatalf("expected redirect to /dashboard, got %q", loc)
|
||||||
|
}
|
||||||
|
foundSession := false
|
||||||
|
for _, c := range res.Cookies() {
|
||||||
|
if c.Name == sessionCookieName && c.Value != "" {
|
||||||
|
foundSession = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundSession {
|
||||||
|
t.Fatal("expected session cookie to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginHandlerInvalidCredentials(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("email", "user@example.com")
|
||||||
|
form.Set("password", "Password999")
|
||||||
|
form.Set("_csrf", "csrf-token")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "csrf-token"})
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.loginHandler()(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected 401, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "Unable to sign in") {
|
||||||
|
t.Fatalf("expected failure message, got %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignupHandlerSuccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("email", "new-user@example.com")
|
||||||
|
form.Set("password", "Password123")
|
||||||
|
form.Set("_csrf", "csrf-token")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "csrf-token"})
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.signupHandler()(rr, req)
|
||||||
|
|
||||||
|
res := rr.Result()
|
||||||
|
if res.StatusCode != http.StatusSeeOther {
|
||||||
|
t.Fatalf("expected 303, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if loc := res.Header.Get("Location"); loc != "/dashboard" {
|
||||||
|
t.Fatalf("expected redirect to /dashboard, got %q", loc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignupHandlerDuplicate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("email", "user@example.com")
|
||||||
|
form.Set("password", "Password123")
|
||||||
|
form.Set("_csrf", "csrf-token")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "csrf-token"})
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.signupHandler()(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusConflict {
|
||||||
|
t.Fatalf("expected 409, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "account with that email") {
|
||||||
|
t.Fatalf("expected duplicate email message, got %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardPageHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := newTestServer(t)
|
||||||
|
|
||||||
|
t.Run("unauthenticated", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||||
|
req = attachSession(req, SessionState{CSRFToken: "csrf"})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.dashboardPageHandler()(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected 401, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "Access denied") {
|
||||||
|
t.Fatalf("expected unauthorized template, got %q", body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("authenticated", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||||
|
req = attachSession(req, SessionState{Authenticated: true, Email: "user@example.com", CSRFToken: "csrf"})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.dashboardPageHandler()(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
body := rr.Body.String()
|
||||||
|
if !strings.Contains(body, "Member since") {
|
||||||
|
t.Fatalf("expected membership text, got %q", body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
168
internal/service/auth/service_test.go
Normal file
168
internal/service/auth/service_test.go
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceAuthenticate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
store := NewMemoryStore()
|
||||||
|
service := NewService(store)
|
||||||
|
|
||||||
|
email := MustUserEmail("user@example.com")
|
||||||
|
salt, hash, err := HashPassword("Password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("hash password: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.Create(ctx, User{Email: email, PasswordSalt: salt, PasswordHash: hash}); err != nil {
|
||||||
|
t.Fatalf("seed user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]struct {
|
||||||
|
email UserEmail
|
||||||
|
password string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
"invalid input": {email: email, password: "", wantErr: ErrInvalidInput},
|
||||||
|
"weak password": {email: email, password: "short1", wantErr: ErrWeakPassword},
|
||||||
|
"unknown account": {email: MustUserEmail("missing@example.com"), password: "Password123", wantErr: ErrInvalidCredentials},
|
||||||
|
"wrong password": {email: email, password: "Password999", wantErr: ErrInvalidCredentials},
|
||||||
|
"success": {email: email, password: "Password123", wantErr: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
account, err := service.Authenticate(ctx, tc.email, tc.password)
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
if !errors.Is(err, tc.wantErr) {
|
||||||
|
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
t.Fatalf("expected no account, got %#v", account)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if account == nil {
|
||||||
|
t.Fatalf("expected account")
|
||||||
|
}
|
||||||
|
if account.Email != email {
|
||||||
|
t.Fatalf("expected email %q, got %q", email, account.Email)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServiceLookupByEmail(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
store := NewMemoryStore()
|
||||||
|
service := NewService(store)
|
||||||
|
|
||||||
|
email := MustUserEmail("lookup@example.com")
|
||||||
|
if err := store.Create(ctx, User{Email: email}); err != nil {
|
||||||
|
t.Fatalf("seed user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := map[string]struct {
|
||||||
|
email UserEmail
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
"zero": {email: UserEmail(""), wantErr: ErrInvalidInput},
|
||||||
|
"missing": {email: MustUserEmail("none@example.com"), wantErr: ErrUserNotFound},
|
||||||
|
"found": {email: email, wantErr: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
user, err := service.LookupByEmail(ctx, tc.email)
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
if !errors.Is(err, tc.wantErr) {
|
||||||
|
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
if user != nil {
|
||||||
|
t.Fatalf("expected nil user, got %#v", user)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if user == nil || user.Email != email {
|
||||||
|
t.Fatalf("expected user with email %q", email)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServiceRegister(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
store := NewMemoryStore()
|
||||||
|
service := NewService(store)
|
||||||
|
|
||||||
|
email := MustUserEmail("taken@example.com")
|
||||||
|
if err := store.Create(ctx, User{Email: email}); err != nil {
|
||||||
|
t.Fatalf("seed user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]struct {
|
||||||
|
email UserEmail
|
||||||
|
password string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
"invalid input": {email: UserEmail(""), password: "", wantErr: ErrInvalidInput},
|
||||||
|
"weak password": {email: MustUserEmail("weak@example.com"), password: "weak", wantErr: ErrWeakPassword},
|
||||||
|
"duplicate": {email: email, password: "Password123", wantErr: ErrEmailExists},
|
||||||
|
"success": {email: MustUserEmail("new@example.com"), password: "Password123", wantErr: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
user, err := service.Register(ctx, tc.email, tc.password)
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
if !errors.Is(err, tc.wantErr) {
|
||||||
|
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
if user != nil {
|
||||||
|
t.Fatalf("expected nil user, got %#v", user)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if user.Email != tc.email {
|
||||||
|
t.Fatalf("expected email %q, got %q", tc.email, user.Email)
|
||||||
|
}
|
||||||
|
if user.CreatedAt.IsZero() {
|
||||||
|
t.Fatal("expected CreatedAt to be set")
|
||||||
|
}
|
||||||
|
// Ensure the user is persisted with hashed credentials.
|
||||||
|
persisted, err := store.FindByEmail(ctx, tc.email)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected persisted user: %v", err)
|
||||||
|
}
|
||||||
|
if persisted.PasswordSalt == "" || persisted.PasswordHash == "" {
|
||||||
|
t.Fatal("expected password salt/hash to be stored")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue