Include route in scope to allow middleware and other tools to extract its information (#4603)

This commit is contained in:
Sebastián Ramírez 2022-02-21 16:51:26 +01:00 committed by GitHub
parent 1ce16c2f40
commit f5d7df3c6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 2 deletions

View file

@ -13,6 +13,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
Type, Type,
Union, Union,
) )
@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute from starlette.routing import BaseRoute, Match
from starlette.routing import Mount as Mount # noqa from starlette.routing import Mount as Mount # noqa
from starlette.routing import ( from starlette.routing import (
compile_path, compile_path,
@ -53,7 +54,7 @@ from starlette.routing import (
websocket_session, websocket_session,
) )
from starlette.status import WS_1008_POLICY_VIOLATION from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
) )
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRoute(routing.Route): class APIRoute(routing.Route):
def __init__( def __init__(
@ -432,6 +439,12 @@ class APIRoute(routing.Route):
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
) )
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRouter(routing.Router): class APIRouter(routing.Router):
def __init__( def __init__(

50
tests/test_route_scope.py Normal file
View file

@ -0,0 +1,50 @@
import pytest
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.routing import APIRoute, APIWebSocketRoute
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/users/{user_id}")
async def get_user(user_id: str, request: Request):
route: APIRoute = request.scope["route"]
return {"user_id": user_id, "path": route.path}
@app.websocket("/items/{item_id}")
async def websocket_item(item_id: str, websocket: WebSocket):
route: APIWebSocketRoute = websocket.scope["route"]
await websocket.accept()
await websocket.send_json({"item_id": item_id, "path": route.path})
client = TestClient(app)
def test_get():
response = client.get("/users/rick")
assert response.status_code == 200, response.text
assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"}
def test_invalid_method_doesnt_match():
response = client.post("/users/rick")
assert response.status_code == 405, response.text
def test_invalid_path_doesnt_match():
response = client.post("/usersx/rick")
assert response.status_code == 404, response.text
def test_websocket():
with client.websocket_connect("/items/portal-gun") as websocket:
data = websocket.receive_json()
assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"}
def test_websocket_invalid_path_doesnt_match():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/itemsx/portal-gun") as websocket:
websocket.receive_json()