示例#1
0
    def create_schema(
        self,
        model: Type[Model],
        *,
        name: str = "",
        depth: int = 0,
        fields: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ) -> Type[Schema]:
        """Build (or return a cached) ``Schema`` subclass describing ``model``.

        Args:
            model: model class whose fields become schema fields.
            name: name of the generated schema; defaults to ``model.__name__``.
            depth: forwarded to ``get_schema_field`` for every selected field.
            fields: names of fields to include. Mutually exclusive with
                ``exclude``.
            exclude: names of fields to leave out. Mutually exclusive with
                ``fields``.

        Returns:
            The generated schema class. If ``self.get_key(...)`` for these
            arguments is already present in ``self.schemas``, that cached class
            is returned instead of building a new one.

        Raises:
            ConfigError: if both ``fields`` and ``exclude`` are given.
        """
        name = name or model.__name__

        if fields and exclude:
            # The message must name the real parameters ('fields', not 'include').
            raise ConfigError(
                "Only one of 'fields' or 'exclude' should be set.")

        key = self.get_key(model, name, depth, fields, exclude)
        if key in self.schemas:
            return self.schemas[key]

        # Map each selected model field to the (type, FieldInfo) pair expected
        # by create_pydantic_model.
        definitions = {}
        for fld in self._selected_model_fields(model, fields, exclude):
            python_type, field_info = get_schema_field(fld, depth=depth)
            definitions[fld.name] = (python_type, field_info)

        schema = cast(
            Type[Schema],
            create_pydantic_model(name, __base__=Schema,
                                  **definitions),  # type: ignore
        )
        self.schemas[key] = schema
        return schema
示例#2
0
文件: utils.py 项目: uvby/bitcart
def get_pagination_model(display_model):
    """Create a pydantic model describing one page of ``display_model`` items.

    The model is named ``PaginationResponse_<display_model.__name__>`` and
    derives from ``BaseModel``. ``count`` and ``result`` are required;
    ``next`` and ``previous`` are optional strings defaulting to ``None``.
    """
    model_name = f"PaginationResponse_{display_model.__name__}"
    page_fields = {
        "count": (int, ...),
        "next": (Optional[str], None),
        "previous": (Optional[str], None),
        "result": (List[display_model], ...),
    }
    return create_pydantic_model(model_name, __base__=BaseModel, **page_fields)
示例#3
0
 def get_response_models(self) -> Dict[str, Type]:
     """Return the response model to declare for each generated endpoint.

     ``self.display_model`` is used when set, otherwise ``self.pydantic_model``.
     The ``get`` endpoint answers with a generated paginated wrapper,
     ``get_count`` with a plain ``int``, and ``get_one`` with no model when
     ``self.get_one_model`` is falsy.
     """
     shown = self.display_model if self.display_model else self.pydantic_model
     page_fields = {
         "count": (int, ...),
         "next": (Optional[str], None),
         "previous": (Optional[str], None),
         "result": (List[shown], ...),
     }
     page_model = create_pydantic_model(
         f"PaginationResponse_{shown.__name__}",
         __base__=BaseModel,
         **page_fields,
     )
     responses = {
         "get": page_model,
         "get_count": int,
         "get_one": shown if self.get_one_model else None,
     }
     # Every write endpoint echoes back the display model.
     responses.update(dict.fromkeys(("post", "put", "patch", "delete"), shown))
     return responses
示例#4
0
    def create_schema(
        self,
        model: Type[Model],
        *,
        name: str = "",
        depth: int = 0,
        fields: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        custom_fields: Optional[List[Tuple[str, Any, Any]]] = None,
        base_class: Type[Schema] = Schema,
    ) -> Type[Schema]:
        """Build (or return a cached) schema class describing ``model``.

        Args:
            model: model class whose fields become schema fields.
            name: requested schema name; defaults to ``model.__name__``. If the
                name is already in ``self.schema_names`` it is replaced by
                ``self._get_unique_name(name)``.
            depth: forwarded to ``get_schema_field`` for every selected field.
            fields: names of fields to include (exclusive with ``exclude``).
            exclude: names of fields to omit (exclusive with ``fields``).
            custom_fields: extra ``(name, type, field_info)`` triples added
                after the model fields; a matching name overrides a model field.
            base_class: base class of the generated schema; its ``__module__``
                is reused for the new class.

        Returns:
            The schema class, cached in ``self.schemas`` under ``self.get_key``.

        Raises:
            ConfigError: if both ``fields`` and ``exclude`` are given.
        """
        schema_name = name or model.__name__

        if fields and exclude:
            raise ConfigError("Only one of 'fields' or 'exclude' should be set.")

        # The cache key uses the requested name, before any de-duplication.
        cache_key = self.get_key(
            model, schema_name, depth, fields, exclude, custom_fields
        )
        if cache_key in self.schemas:
            return self.schemas[cache_key]

        definitions = {}
        for model_field in self._selected_model_fields(model, fields, exclude):
            field_type, info = get_schema_field(model_field, depth=depth)
            definitions[model_field.name] = (field_type, info)

        for extra_name, extra_type, extra_info in custom_fields or ():
            definitions[extra_name] = (extra_type, extra_info)

        # Keep generated class names unique among schemas built so far.
        if schema_name in self.schema_names:
            schema_name = self._get_unique_name(schema_name)

        schema: Type[Schema] = create_pydantic_model(
            schema_name,
            __config__=None,
            __base__=base_class,
            __module__=base_class.__module__,
            __validators__={},
            **definitions,
        )  # type: ignore
        self.schemas[cache_key] = schema
        self.schema_names.add(schema_name)
        return schema
示例#5
0
文件: utils.py 项目: uvby/bitcart
def model_view(
    router: APIRouter,
    path: str,
    orm_model,
    pydantic_model,
    get_data_source,
    create_model=None,
    display_model=None,
    allowed_methods: Union[List[str], None] = None,
    custom_methods: Union[Dict[str, Callable], None] = None,
    background_tasks_mapping: Union[Dict[str, Callable], None] = None,
    request_handlers: Union[Dict[str, Callable], None] = None,
    auth=True,
    get_one_auth=True,
    post_auth=True,
    get_one_model=True,
    scopes=None,
):
    """Register CRUD routes for ``orm_model`` on ``router``.

    One route is added per entry of ``allowed_methods``. Each entry is
    lower-cased to select one of the endpoints ``get`` (paginated list),
    ``get_count``, ``get_one``, ``post``, ``put``, ``patch`` or ``delete``.

    Args:
        router: router the routes are added to.
        path: base path. Item routes use ``<path>/{model_id}`` and the count
            route uses ``<path>/count``.
        orm_model: ORM model that is queried, created, updated and deleted.
        pydantic_model: request body model for ``put`` and ``patch``. Also the
            fallback for ``create_model`` and ``display_model``.
        get_data_source: zero-argument callable. Its result is passed to
            ``select_from`` and to the pagination helper.
        create_model: request body model for ``post``.
        display_model: response model for single-object endpoints and for the
            ``result`` items of the paginated response.
        allowed_methods: method names to register. Defaults to
            ``["GET_COUNT", "GET_ONE"] + HTTP_METHODS``. NOTE(review): the
            count route presumably must be registered before the
            ``{model_id}`` route so it is not captured by it; confirm.
        custom_methods: per-endpoint overrides, keyed by lower-cased name.
            They are called in place of the default ORM operation.
        background_tasks_mapping: only the ``"post"`` entry is used. Its
            ``send(obj.id)`` is called after a successful create.
        request_handlers: replacement route handlers keyed by lower-cased
            name. A truthy entry is registered instead of the generated one.
        auth: passed to ``AuthDependency``.
        get_one_auth: if false, ``get_one`` continues with ``user=None`` when
            authentication raises ``HTTPException``.
        post_auth: same as ``get_one_auth``, but for ``post``.
        get_one_model: if false, ``get_one`` is registered without a response
            model.
        scopes: a dict keyed by endpoint name, or a single list applied to
            every entry of ``ENDPOINTS``. Missing keys default to ``[]``.
    """
    from . import schemes

    # Resolve defaults per call rather than sharing mutable default objects.
    if allowed_methods is None:
        allowed_methods = ["GET_COUNT", "GET_ONE"] + HTTP_METHODS
    if custom_methods is None:
        custom_methods = {}
    if background_tasks_mapping is None:
        background_tasks_mapping = {}
    if request_handlers is None:
        request_handlers = {}
    if scopes is None:
        scopes = {i: [] for i in ENDPOINTS}
    crud_models.append((path, orm_model, get_data_source))

    display_model = pydantic_model if not display_model else display_model
    if isinstance(scopes, list):
        scopes_list = scopes.copy()
        scopes = {i: scopes_list for i in ENDPOINTS}
    # Endpoints without configured scopes fall back to an empty scope list.
    scopes = defaultdict(list, **scopes)

    PaginationResponse = create_pydantic_model(
        f"PaginationResponse_{display_model.__name__}",
        count=(int, ...),
        next=(Optional[str], None),
        previous=(Optional[str], None),
        result=(List[display_model], ...),
        __base__=BaseModel,
    )

    if not create_model:
        create_model = pydantic_model  # pragma: no cover
    response_models: Dict[str, Type] = {
        "get": PaginationResponse,
        "get_count": int,
        "get_one": display_model if get_one_model else None,
        "post": display_model,
        "put": display_model,
        "patch": display_model,
        "delete": display_model,
    }

    item_path = path_join(path, "{model_id}")
    count_path = path_join(path, "count")
    paths: Dict[str, str] = {
        "get": path,
        "get_count": count_path,
        "get_one": item_path,
        "post": path,
        "put": item_path,
        "patch": item_path,
        "delete": item_path,
    }

    auth_dependency = AuthDependency(auth)

    # Integrity errors raised by asyncpg on writes; reported to clients as 422.
    db_errors = (
        asyncpg.exceptions.UniqueViolationError,
        asyncpg.exceptions.NotNullViolationError,
        asyncpg.exceptions.ForeignKeyViolationError,
    )

    async def _get_one(model_id: int,
                       user: schemes.User,
                       internal: bool = False):
        # Fetch one object, filtered by owner when a user is known; 404 if missing.
        if orm_model != models.User:
            query = orm_model.query.select_from(get_data_source())
            if user:
                query = query.where(models.User.id == user.id)
        else:
            query = orm_model.query
        item = await query.where(orm_model.id == model_id).gino.first()
        if custom_methods.get("get_one"):
            item = await custom_methods["get_one"](model_id, user, item,
                                                   internal)
        if not item:
            raise HTTPException(
                status_code=404,
                detail=f"Object with id {model_id} does not exist!")
        return item

    async def get(
        pagination: pagination.Pagination = Depends(),
        user: Union[None, schemes.User] = Security(auth_dependency,
                                                   scopes=scopes["get_all"]),
    ):
        if custom_methods.get("get"):
            return await custom_methods["get"](pagination, user,
                                               get_data_source())
        else:
            # NOTE(review): assumes a user is present; unlike _get_one this
            # path does not handle user=None. Confirm paginate's contract
            # before relaxing.
            return await pagination.paginate(orm_model, get_data_source(),
                                             user.id)

    async def get_count(user: Union[None, schemes.User] = Security(
        auth_dependency, scopes=scopes["get_count"])):
        # Same ownership filter as _get_one. Only filter when a user is known,
        # so a None user does not crash with AttributeError.
        if orm_model != models.User:
            query = orm_model.query.select_from(get_data_source())
            if user:
                query = query.where(models.User.id == user.id)
        else:
            query = orm_model.query
        count = await query.with_only_columns(
            [db.db.func.count(distinct(orm_model.id))]
        ).order_by(None).gino.scalar()
        return count or 0

    async def get_one(model_id: int, request: Request):
        try:
            user = await auth_dependency(request,
                                         SecurityScopes(scopes["get_one"]))
        except HTTPException:
            if get_one_auth:
                raise
            user = None
        return await _get_one(model_id, user)

    async def post(
        model: create_model,  # type: ignore
        request: Request,
    ):
        try:
            user = await auth_dependency(request,
                                         SecurityScopes(scopes["post"]))
        except HTTPException:
            if post_auth:
                raise
            user = None
        try:
            if custom_methods.get("post"):
                obj = await custom_methods["post"](model, user)
            else:
                obj = await orm_model.create(**model.dict())  # type: ignore
        except db_errors as e:
            raise HTTPException(422, e.message) from e
        if background_tasks_mapping.get("post"):
            background_tasks_mapping["post"].send(obj.id)
        return obj

    async def put(
        model_id: int,
        model: pydantic_model,
        user: Union[None, schemes.User] = Security(auth_dependency,
                                                   scopes=scopes["put"]),
    ):  # type: ignore
        item = await _get_one(model_id, user, True)
        try:
            if custom_methods.get("put"):
                await custom_methods["put"](item, model,
                                            user)  # pragma: no cover
            else:
                await item.update(**model.dict()).apply()  # type: ignore
        except db_errors as e:
            raise HTTPException(422, e.message) from e
        return item

    async def patch(
        model_id: int,
        model: pydantic_model,
        user: Union[None, schemes.User] = Security(auth_dependency,
                                                   scopes=scopes["patch"]),
    ):  # type: ignore
        item = await _get_one(model_id, user, True)
        try:
            if custom_methods.get("patch"):
                await custom_methods["patch"](item, model,
                                              user)  # pragma: no cover
            else:
                # Partial update: only fields explicitly sent by the client.
                await item.update(
                    **model.dict(exclude_unset=True)  # type: ignore
                ).apply()
        except db_errors as e:  # pragma: no cover
            raise HTTPException(422, e.message) from e  # pragma: no cover
        return item

    async def delete(
        model_id: int,
        user: Union[None, schemes.User] = Security(auth_dependency,
                                                   scopes=scopes["delete"]),
    ):
        item = await _get_one(model_id, user, True)
        if custom_methods.get("delete"):
            await custom_methods["delete"](item, user)
        else:
            await item.delete()
        return item

    # Explicit name -> handler table; safer than looking handlers up in locals().
    endpoints: Dict[str, Callable] = {
        "get": get,
        "get_count": get_count,
        "get_one": get_one,
        "post": post,
        "put": put,
        "patch": patch,
        "delete": delete,
    }

    for method in allowed_methods:
        method_name = method.lower()
        router.add_api_route(
            paths.get(method_name),  # type: ignore
            request_handlers.get(method_name) or endpoints[method_name],
            # Non-HTTP verbs such as GET_COUNT / GET_ONE are served via GET.
            methods=[method_name if method in HTTP_METHODS else "get"],
            response_model=response_models.get(method_name),
        )