mirror of
https://github.com/rjNemo/fastapi
synced 2026-06-11 21:16:45 +00:00
🐛 Fix removing body from status codes that do not support it (#5145)
This commit is contained in:
parent
a0fd613527
commit
c43120258f
5 changed files with 38 additions and 24 deletions
|
|
@ -1,3 +1,2 @@
|
||||||
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
|
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
|
||||||
STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304}
|
|
||||||
REF_PREFIX = "#/components/schemas/"
|
REF_PREFIX = "#/components/schemas/"
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,7 @@ from fastapi.datastructures import DefaultPlaceholder
|
||||||
from fastapi.dependencies.models import Dependant
|
from fastapi.dependencies.models import Dependant
|
||||||
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.openapi.constants import (
|
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
|
||||||
METHODS_WITH_BODY,
|
|
||||||
REF_PREFIX,
|
|
||||||
STATUS_CODES_WITH_NO_BODY,
|
|
||||||
)
|
|
||||||
from fastapi.openapi.models import OpenAPI
|
from fastapi.openapi.models import OpenAPI
|
||||||
from fastapi.params import Body, Param
|
from fastapi.params import Body, Param
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
|
|
@ -21,6 +17,7 @@ from fastapi.utils import (
|
||||||
deep_dict_update,
|
deep_dict_update,
|
||||||
generate_operation_id_for_path,
|
generate_operation_id_for_path,
|
||||||
get_model_definitions,
|
get_model_definitions,
|
||||||
|
is_body_allowed_for_status_code,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import ModelField, Undefined
|
from pydantic.fields import ModelField, Undefined
|
||||||
|
|
@ -265,9 +262,8 @@ def get_openapi_path(
|
||||||
operation.setdefault("responses", {}).setdefault(status_code, {})[
|
operation.setdefault("responses", {}).setdefault(status_code, {})[
|
||||||
"description"
|
"description"
|
||||||
] = route.response_description
|
] = route.response_description
|
||||||
if (
|
if route_response_media_type and is_body_allowed_for_status_code(
|
||||||
route_response_media_type
|
route.status_code
|
||||||
and route.status_code not in STATUS_CODES_WITH_NO_BODY
|
|
||||||
):
|
):
|
||||||
response_schema = {"type": "string"}
|
response_schema = {"type": "string"}
|
||||||
if lenient_issubclass(current_response_class, JSONResponse):
|
if lenient_issubclass(current_response_class, JSONResponse):
|
||||||
|
|
|
||||||
|
|
@ -29,13 +29,13 @@ from fastapi.dependencies.utils import (
|
||||||
)
|
)
|
||||||
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
|
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||||
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
|
|
||||||
from fastapi.types import DecoratedCallable
|
from fastapi.types import DecoratedCallable
|
||||||
from fastapi.utils import (
|
from fastapi.utils import (
|
||||||
create_cloned_field,
|
create_cloned_field,
|
||||||
create_response_field,
|
create_response_field,
|
||||||
generate_unique_id,
|
generate_unique_id,
|
||||||
get_value_or_default,
|
get_value_or_default,
|
||||||
|
is_body_allowed_for_status_code,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||||
|
|
@ -232,7 +232,17 @@ def get_request_handler(
|
||||||
if raw_response.background is None:
|
if raw_response.background is None:
|
||||||
raw_response.background = background_tasks
|
raw_response.background = background_tasks
|
||||||
return raw_response
|
return raw_response
|
||||||
response_data = await serialize_response(
|
response_args: Dict[str, Any] = {"background": background_tasks}
|
||||||
|
# If status_code was set, use it, otherwise use the default from the
|
||||||
|
# response class, in the case of redirect it's 307
|
||||||
|
current_status_code = (
|
||||||
|
status_code if status_code else sub_response.status_code
|
||||||
|
)
|
||||||
|
if current_status_code is not None:
|
||||||
|
response_args["status_code"] = current_status_code
|
||||||
|
if sub_response.status_code:
|
||||||
|
response_args["status_code"] = sub_response.status_code
|
||||||
|
content = await serialize_response(
|
||||||
field=response_field,
|
field=response_field,
|
||||||
response_content=raw_response,
|
response_content=raw_response,
|
||||||
include=response_model_include,
|
include=response_model_include,
|
||||||
|
|
@ -243,15 +253,10 @@ def get_request_handler(
|
||||||
exclude_none=response_model_exclude_none,
|
exclude_none=response_model_exclude_none,
|
||||||
is_coroutine=is_coroutine,
|
is_coroutine=is_coroutine,
|
||||||
)
|
)
|
||||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
response = actual_response_class(content, **response_args)
|
||||||
# If status_code was set, use it, otherwise use the default from the
|
if not is_body_allowed_for_status_code(status_code):
|
||||||
# response class, in the case of redirect it's 307
|
response.body = b""
|
||||||
if status_code is not None:
|
|
||||||
response_args["status_code"] = status_code
|
|
||||||
response = actual_response_class(response_data, **response_args)
|
|
||||||
response.headers.raw.extend(sub_response.headers.raw)
|
response.headers.raw.extend(sub_response.headers.raw)
|
||||||
if sub_response.status_code:
|
|
||||||
response.status_code = sub_response.status_code
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
@ -377,8 +382,8 @@ class APIRoute(routing.Route):
|
||||||
status_code = int(status_code)
|
status_code = int(status_code)
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
if self.response_model:
|
if self.response_model:
|
||||||
assert (
|
assert is_body_allowed_for_status_code(
|
||||||
status_code not in STATUS_CODES_WITH_NO_BODY
|
status_code
|
||||||
), f"Status code {status_code} must not have a response body"
|
), f"Status code {status_code} must not have a response body"
|
||||||
response_name = "Response_" + self.unique_id
|
response_name = "Response_" + self.unique_id
|
||||||
self.response_field = create_response_field(
|
self.response_field = create_response_field(
|
||||||
|
|
@ -410,8 +415,8 @@ class APIRoute(routing.Route):
|
||||||
assert isinstance(response, dict), "An additional response must be a dict"
|
assert isinstance(response, dict), "An additional response must be a dict"
|
||||||
model = response.get("model")
|
model = response.get("model")
|
||||||
if model:
|
if model:
|
||||||
assert (
|
assert is_body_allowed_for_status_code(
|
||||||
additional_status_code not in STATUS_CODES_WITH_NO_BODY
|
additional_status_code
|
||||||
), f"Status code {additional_status_code} must not have a response body"
|
), f"Status code {additional_status_code} must not have a response body"
|
||||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||||
response_field = create_response_field(name=response_name, type_=model)
|
response_field = create_response_field(name=response_name, type_=model)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,13 @@ if TYPE_CHECKING: # pragma: nocover
|
||||||
from .routing import APIRoute
|
from .routing import APIRoute
|
||||||
|
|
||||||
|
|
||||||
|
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||||
|
if status_code is None:
|
||||||
|
return True
|
||||||
|
current_status_code = int(status_code)
|
||||||
|
return not (current_status_code < 200 or current_status_code in {204, 304})
|
||||||
|
|
||||||
|
|
||||||
def get_model_definitions(
|
def get_model_definitions(
|
||||||
*,
|
*,
|
||||||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class JsonApiError(BaseModel):
|
||||||
responses={500: {"description": "Error", "model": JsonApiError}},
|
responses={500: {"description": "Error", "model": JsonApiError}},
|
||||||
)
|
)
|
||||||
async def a():
|
async def a():
|
||||||
pass # pragma: no cover
|
pass
|
||||||
|
|
||||||
|
|
||||||
@app.get("/b", responses={204: {"description": "No Content"}})
|
@app.get("/b", responses={204: {"description": "No Content"}})
|
||||||
|
|
@ -106,3 +106,10 @@ def test_openapi_schema():
|
||||||
response = client.get("/openapi.json")
|
response = client.get("/openapi.json")
|
||||||
assert response.status_code == 200, response.text
|
assert response.status_code == 200, response.text
|
||||||
assert response.json() == openapi_schema
|
assert response.json() == openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_response():
|
||||||
|
response = client.get("/a")
|
||||||
|
assert response.status_code == 204, response.text
|
||||||
|
assert "content-length" not in response.headers
|
||||||
|
assert response.content == b""
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue