mirror of
https://github.com/rjNemo/fastapi
synced 2026-06-12 05:26:45 +00:00
This commit is contained in:
parent
219d299426
commit
b087246f26
12 changed files with 305 additions and 17 deletions
0
docs/src/websockets/__init__.py
Normal file
0
docs/src/websockets/__init__.py
Normal file
|
|
@ -44,10 +44,9 @@ async def get():
|
||||||
return HTMLResponse(html)
|
return HTMLResponse(html)
|
||||||
|
|
||||||
|
|
||||||
@app.websocket_route("/ws")
|
@app.websocket("/ws")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_text()
|
data = await websocket.receive_text()
|
||||||
await websocket.send_text(f"Message text was: {data}")
|
await websocket.send_text(f"Message text was: {data}")
|
||||||
await websocket.close()
|
|
||||||
|
|
|
||||||
78
docs/src/websockets/tutorial002.py
Normal file
78
docs/src/websockets/tutorial002.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
from fastapi import Cookie, Depends, FastAPI, Header
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
from starlette.status import WS_1008_POLICY_VIOLATION
|
||||||
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
html = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Chat</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>WebSocket Chat</h1>
|
||||||
|
<form action="" onsubmit="sendMessage(event)">
|
||||||
|
<label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
|
||||||
|
<button onclick="connect(event)">Connect</button>
|
||||||
|
<br>
|
||||||
|
<label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
|
||||||
|
<button>Send</button>
|
||||||
|
</form>
|
||||||
|
<ul id='messages'>
|
||||||
|
</ul>
|
||||||
|
<script>
|
||||||
|
var ws = null;
|
||||||
|
function connect(event) {
|
||||||
|
var input = document.getElementById("itemId")
|
||||||
|
ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
|
||||||
|
ws.onmessage = function(event) {
|
||||||
|
var messages = document.getElementById('messages')
|
||||||
|
var message = document.createElement('li')
|
||||||
|
var content = document.createTextNode(event.data)
|
||||||
|
message.appendChild(content)
|
||||||
|
messages.appendChild(message)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
function sendMessage(event) {
|
||||||
|
var input = document.getElementById("messageText")
|
||||||
|
ws.send(input.value)
|
||||||
|
input.value = ''
|
||||||
|
event.preventDefault()
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def get():
|
||||||
|
return HTMLResponse(html)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cookie_or_client(
|
||||||
|
websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
|
||||||
|
):
|
||||||
|
if session is None and x_client is None:
|
||||||
|
await websocket.close(code=WS_1008_POLICY_VIOLATION)
|
||||||
|
return session or x_client
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/items/{item_id}/ws")
|
||||||
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
item_id: int,
|
||||||
|
q: str = None,
|
||||||
|
cookie_or_client: str = Depends(get_cookie_or_client),
|
||||||
|
):
|
||||||
|
await websocket.accept()
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
await websocket.send_text(
|
||||||
|
f"Session Cookie or X-Client Header value is: {cookie_or_client}"
|
||||||
|
)
|
||||||
|
if q is not None:
|
||||||
|
await websocket.send_text(f"Query parameter q is: {q}")
|
||||||
|
await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
|
||||||
|
|
@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
|
||||||
{!./src/websockets/tutorial001.py!}
|
{!./src/websockets/tutorial001.py!}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Create a `websocket_route`
|
## Create a `websocket`
|
||||||
|
|
||||||
In your **FastAPI** application, create a `websocket_route`:
|
In your **FastAPI** application, create a `websocket`:
|
||||||
|
|
||||||
```Python hl_lines="3 47 48"
|
```Python hl_lines="3 47 48"
|
||||||
{!./src/websockets/tutorial001.py!}
|
{!./src/websockets/tutorial001.py!}
|
||||||
|
|
@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
|
||||||
!!! tip
|
!!! tip
|
||||||
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
|
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
|
||||||
|
|
||||||
That is not required, but it's recommended as it will provide you completion and checks inside the function.
|
|
||||||
|
|
||||||
|
|
||||||
!!! info
|
|
||||||
This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>.
|
|
||||||
|
|
||||||
That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
|
|
||||||
|
|
||||||
|
|
||||||
## Await for messages and send messages
|
## Await for messages and send messages
|
||||||
|
|
||||||
In your WebSocket route you can `await` for messages and send messages.
|
In your WebSocket route you can `await` for messages and send messages.
|
||||||
|
|
@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
|
||||||
|
|
||||||
You can receive and send binary, text, and JSON data.
|
You can receive and send binary, text, and JSON data.
|
||||||
|
|
||||||
|
## Using `Depends` and others
|
||||||
|
|
||||||
|
In WebSocket endpoints you can import from `fastapi` and use:
|
||||||
|
|
||||||
|
* `Depends`
|
||||||
|
* `Security`
|
||||||
|
* `Cookie`
|
||||||
|
* `Header`
|
||||||
|
* `Path`
|
||||||
|
* `Query`
|
||||||
|
|
||||||
|
They work the same way as for other FastAPI endpoints/*path operations*:
|
||||||
|
|
||||||
|
```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
|
||||||
|
{!./src/websockets/tutorial002.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
|
||||||
|
|
||||||
|
You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
|
||||||
|
|
||||||
|
In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
|
||||||
|
|
||||||
|
## More info
|
||||||
|
|
||||||
To learn more about the options, check Starlette's documentation for:
|
To learn more about the options, check Starlette's documentation for:
|
||||||
|
|
||||||
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.
|
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,18 @@ class FastAPI(Starlette):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def add_api_websocket_route(
|
||||||
|
self, path: str, endpoint: Callable, name: str = None
|
||||||
|
) -> None:
|
||||||
|
self.router.add_api_websocket_route(path, endpoint, name=name)
|
||||||
|
|
||||||
|
def websocket(self, path: str, name: str = None) -> Callable:
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
self.add_api_websocket_route(path, func, name=name)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
def include_router(
|
def include_router(
|
||||||
self,
|
self,
|
||||||
router: routing.APIRouter,
|
router: routing.APIRouter,
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class Dependant:
|
||||||
name: str = None,
|
name: str = None,
|
||||||
call: Callable = None,
|
call: Callable = None,
|
||||||
request_param_name: str = None,
|
request_param_name: str = None,
|
||||||
|
websocket_param_name: str = None,
|
||||||
background_tasks_param_name: str = None,
|
background_tasks_param_name: str = None,
|
||||||
security_scopes_param_name: str = None,
|
security_scopes_param_name: str = None,
|
||||||
security_scopes: List[str] = None,
|
security_scopes: List[str] = None,
|
||||||
|
|
@ -38,6 +39,7 @@ class Dependant:
|
||||||
self.dependencies = dependencies or []
|
self.dependencies = dependencies or []
|
||||||
self.security_requirements = security_schemes or []
|
self.security_requirements = security_schemes or []
|
||||||
self.request_param_name = request_param_name
|
self.request_param_name = request_param_name
|
||||||
|
self.websocket_param_name = websocket_param_name
|
||||||
self.background_tasks_param_name = background_tasks_param_name
|
self.background_tasks_param_name = background_tasks_param_name
|
||||||
self.security_scopes = security_scopes
|
self.security_scopes = security_scopes
|
||||||
self.security_scopes_param_name = security_scopes_param_name
|
self.security_scopes_param_name = security_scopes_param_name
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
param_supported_types = (
|
param_supported_types = (
|
||||||
str,
|
str,
|
||||||
|
|
@ -184,6 +185,8 @@ def get_dependant(
|
||||||
)
|
)
|
||||||
elif lenient_issubclass(param.annotation, Request):
|
elif lenient_issubclass(param.annotation, Request):
|
||||||
dependant.request_param_name = param_name
|
dependant.request_param_name = param_name
|
||||||
|
elif lenient_issubclass(param.annotation, WebSocket):
|
||||||
|
dependant.websocket_param_name = param_name
|
||||||
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
||||||
dependant.background_tasks_param_name = param_name
|
dependant.background_tasks_param_name = param_name
|
||||||
elif lenient_issubclass(param.annotation, SecurityScopes):
|
elif lenient_issubclass(param.annotation, SecurityScopes):
|
||||||
|
|
@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
|
||||||
|
|
||||||
async def solve_dependencies(
|
async def solve_dependencies(
|
||||||
*,
|
*,
|
||||||
request: Request,
|
request: Union[Request, WebSocket],
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
body: Dict[str, Any] = None,
|
body: Dict[str, Any] = None,
|
||||||
background_tasks: BackgroundTasks = None,
|
background_tasks: BackgroundTasks = None,
|
||||||
|
|
@ -326,8 +329,10 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
values.update(body_values)
|
values.update(body_values)
|
||||||
errors.extend(body_errors)
|
errors.extend(body_errors)
|
||||||
if dependant.request_param_name:
|
if dependant.request_param_name and isinstance(request, Request):
|
||||||
values[dependant.request_param_name] = request
|
values[dependant.request_param_name] = request
|
||||||
|
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||||
|
values[dependant.websocket_param_name] = request
|
||||||
if dependant.background_tasks_param_name:
|
if dependant.background_tasks_param_name:
|
||||||
if background_tasks is None:
|
if background_tasks is None:
|
||||||
background_tasks = BackgroundTasks()
|
background_tasks = BackgroundTasks()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from fastapi import params
|
from fastapi import params
|
||||||
|
|
@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from starlette.routing import compile_path, get_name, request_response
|
from starlette.routing import (
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
compile_path,
|
||||||
|
get_name,
|
||||||
|
request_response,
|
||||||
|
websocket_session,
|
||||||
|
)
|
||||||
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
|
||||||
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
|
||||||
def serialize_response(*, field: Field = None, response: Response) -> Any:
|
def serialize_response(*, field: Field = None, response: Response) -> Any:
|
||||||
|
|
@ -97,6 +104,35 @@ def get_app(
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def get_websocket_app(dependant: Dependant) -> Callable:
|
||||||
|
async def app(websocket: WebSocket) -> None:
|
||||||
|
values, errors, _ = await solve_dependencies(
|
||||||
|
request=websocket, dependant=dependant
|
||||||
|
)
|
||||||
|
if errors:
|
||||||
|
await websocket.close(code=WS_1008_POLICY_VIOLATION)
|
||||||
|
errors_out = ValidationError(errors)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
|
||||||
|
)
|
||||||
|
assert dependant.call is not None, "dependant.call must me a function"
|
||||||
|
await dependant.call(**values)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
|
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
|
||||||
|
self.path = path
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
self.dependant = get_dependant(path=path, call=self.endpoint)
|
||||||
|
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
|
||||||
|
regex = "^" + path + "$"
|
||||||
|
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
|
||||||
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
|
|
||||||
|
|
||||||
class APIRoute(routing.Route):
|
class APIRoute(routing.Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -281,6 +317,19 @@ class APIRouter(routing.Router):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def add_api_websocket_route(
|
||||||
|
self, path: str, endpoint: Callable, name: str = None
|
||||||
|
) -> None:
|
||||||
|
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
|
||||||
|
self.routes.append(route)
|
||||||
|
|
||||||
|
def websocket(self, path: str, name: str = None) -> Callable:
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
self.add_api_websocket_route(path, func, name=name)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
def include_router(
|
def include_router(
|
||||||
self,
|
self,
|
||||||
router: "APIRouter",
|
router: "APIRouter",
|
||||||
|
|
@ -326,6 +375,10 @@ class APIRouter(routing.Router):
|
||||||
include_in_schema=route.include_in_schema,
|
include_in_schema=route.include_in_schema,
|
||||||
name=route.name,
|
name=route.name,
|
||||||
)
|
)
|
||||||
|
elif isinstance(route, APIWebSocketRoute):
|
||||||
|
self.add_api_websocket_route(
|
||||||
|
prefix + route.path, route.endpoint, name=route.name
|
||||||
|
)
|
||||||
elif isinstance(route, routing.WebSocketRoute):
|
elif isinstance(route, routing.WebSocketRoute):
|
||||||
self.add_websocket_route(
|
self.add_websocket_route(
|
||||||
prefix + route.path, route.endpoint, name=route.name
|
prefix + route.path, route.endpoint, name=route.name
|
||||||
|
|
|
||||||
0
tests/test_tutorial/test_websockets/__init__.py
Normal file
0
tests/test_tutorial/test_websockets/__init__.py
Normal file
25
tests/test_tutorial/test_websockets/test_tutorial001.py
Normal file
25
tests/test_tutorial/test_websockets/test_tutorial001.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
from websockets.tutorial001 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_main():
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"<!DOCTYPE html>" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
with client.websocket_connect("/ws") as websocket:
|
||||||
|
message = "Message one"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}"
|
||||||
|
message = "Message two"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}"
|
||||||
83
tests/test_tutorial/test_websockets/test_tutorial002.py
Normal file
83
tests/test_tutorial/test_websockets/test_tutorial002.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
from websockets.tutorial002 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_main():
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"<!DOCTYPE html>" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_with_cookie():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/items/1/ws", cookies={"session": "fakesession"}
|
||||||
|
) as websocket:
|
||||||
|
message = "Message one"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: fakesession"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 1"
|
||||||
|
message = "Message two"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: fakesession"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_with_header():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/items/2/ws", headers={"X-Client": "xmen"}
|
||||||
|
) as websocket:
|
||||||
|
message = "Message one"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 2"
|
||||||
|
message = "Message two"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_with_header_and_query():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/items/2/ws?q=baz", headers={"X-Client": "xmen"}
|
||||||
|
) as websocket:
|
||||||
|
message = "Message one"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Query parameter q is: baz"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 2"
|
||||||
|
message = "Message two"
|
||||||
|
websocket.send_text(message)
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Query parameter q is: baz"
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == f"Message text was: {message}, for item ID: 2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_no_credentials():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
client.websocket_connect("/items/2/ws")
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_invalid_data():
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})
|
||||||
|
|
@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket):
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/router2")
|
||||||
|
async def routerindex(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_text("Hello, router!")
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
app.include_router(prefix_router, prefix="/prefix")
|
app.include_router(prefix_router, prefix="/prefix")
|
||||||
|
|
||||||
|
|
@ -51,3 +58,10 @@ def test_prefix_router():
|
||||||
with client.websocket_connect("/prefix/") as websocket:
|
with client.websocket_connect("/prefix/") as websocket:
|
||||||
data = websocket.receive_text()
|
data = websocket.receive_text()
|
||||||
assert data == "Hello, router with prefix!"
|
assert data == "Hello, router with prefix!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_router2():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client.websocket_connect("/router2") as websocket:
|
||||||
|
data = websocket.receive_text()
|
||||||
|
assert data == "Hello, router!"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue