mirror of
https://github.com/rjNemo/fastapi
synced 2026-06-11 13:06:43 +00:00
🐛 Check already cloned fields in create_cloned_field to support recursive models (#1164)
* FIX: #894 Include recursion check for create_cloned_field. Added test for recursive model. * ♻️ Refactor and format create_cloned_field() Co-authored-by: Lukas Voegtle <lukas.voegtle@sick.de> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
0d165d1efa
commit
0f152b4e97
2 changed files with 99 additions and 7 deletions
|
|
@ -131,17 +131,26 @@ def create_response_field(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_cloned_field(field: ModelField) -> ModelField:
|
def create_cloned_field(
|
||||||
|
field: ModelField, *, cloned_types: Dict[Type[BaseModel], Type[BaseModel]] = None,
|
||||||
|
) -> ModelField:
|
||||||
|
# _cloned_types has already cloned types, to support recursive models
|
||||||
|
if cloned_types is None:
|
||||||
|
cloned_types = dict()
|
||||||
original_type = field.type_
|
original_type = field.type_
|
||||||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||||
original_type = original_type.__pydantic_model__ # type: ignore
|
original_type = original_type.__pydantic_model__ # type: ignore
|
||||||
use_type = original_type
|
use_type = original_type
|
||||||
if lenient_issubclass(original_type, BaseModel):
|
if lenient_issubclass(original_type, BaseModel):
|
||||||
original_type = cast(Type[BaseModel], original_type)
|
original_type = cast(Type[BaseModel], original_type)
|
||||||
use_type = create_model(original_type.__name__, __base__=original_type)
|
use_type = cloned_types.get(original_type)
|
||||||
for f in original_type.__fields__.values():
|
if use_type is None:
|
||||||
use_type.__fields__[f.name] = create_cloned_field(f)
|
use_type = create_model(original_type.__name__, __base__=original_type)
|
||||||
|
cloned_types[original_type] = use_type
|
||||||
|
for f in original_type.__fields__.values():
|
||||||
|
use_type.__fields__[f.name] = create_cloned_field(
|
||||||
|
f, cloned_types=cloned_types
|
||||||
|
)
|
||||||
new_field = create_response_field(name=field.name, type_=use_type)
|
new_field = create_response_field(name=field.name, type_=use_type)
|
||||||
new_field.has_alias = field.has_alias
|
new_field.has_alias = field.has_alias
|
||||||
new_field.alias = field.alias
|
new_field.alias = field.alias
|
||||||
|
|
@ -157,10 +166,13 @@ def create_cloned_field(field: ModelField) -> ModelField:
|
||||||
new_field.validate_always = field.validate_always
|
new_field.validate_always = field.validate_always
|
||||||
if field.sub_fields:
|
if field.sub_fields:
|
||||||
new_field.sub_fields = [
|
new_field.sub_fields = [
|
||||||
create_cloned_field(sub_field) for sub_field in field.sub_fields
|
create_cloned_field(sub_field, cloned_types=cloned_types)
|
||||||
|
for sub_field in field.sub_fields
|
||||||
]
|
]
|
||||||
if field.key_field:
|
if field.key_field:
|
||||||
new_field.key_field = create_cloned_field(field.key_field)
|
new_field.key_field = create_cloned_field(
|
||||||
|
field.key_field, cloned_types=cloned_types
|
||||||
|
)
|
||||||
new_field.validators = field.validators
|
new_field.validators = field.validators
|
||||||
if PYDANTIC_1:
|
if PYDANTIC_1:
|
||||||
new_field.pre_validators = field.pre_validators
|
new_field.pre_validators = field.pre_validators
|
||||||
|
|
|
||||||
80
tests/test_validate_response_recursive.py
Normal file
80
tests/test_validate_response_recursive.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveItem(BaseModel):
|
||||||
|
sub_items: List["RecursiveItem"] = []
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
RecursiveItem.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveSubitemInSubmodel(BaseModel):
|
||||||
|
sub_items2: List["RecursiveItemViaSubmodel"] = []
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveItemViaSubmodel(BaseModel):
|
||||||
|
sub_items1: List[RecursiveSubitemInSubmodel] = []
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
RecursiveSubitemInSubmodel.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/recursive", response_model=RecursiveItem)
|
||||||
|
def get_recursive():
|
||||||
|
return {"name": "item", "sub_items": [{"name": "subitem", "sub_items": []}]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/recursive-submodel", response_model=RecursiveItemViaSubmodel)
|
||||||
|
def get_recursive_submodel():
|
||||||
|
return {
|
||||||
|
"name": "item",
|
||||||
|
"sub_items1": [
|
||||||
|
{
|
||||||
|
"name": "subitem",
|
||||||
|
"sub_items2": [
|
||||||
|
{
|
||||||
|
"name": "subsubitem",
|
||||||
|
"sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive():
|
||||||
|
response = client.get("/items/recursive")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
"sub_items": [{"name": "subitem", "sub_items": []}],
|
||||||
|
"name": "item",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.get("/items/recursive-submodel")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
"name": "item",
|
||||||
|
"sub_items1": [
|
||||||
|
{
|
||||||
|
"name": "subitem",
|
||||||
|
"sub_items2": [
|
||||||
|
{
|
||||||
|
"name": "subsubitem",
|
||||||
|
"sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue