Beispiel #1
0
 def inject_model(self):
     """Create a ``JSONBlob`` model with a single ``dict`` field.

     The field is named after ``self.attribute`` and uses ``...`` (Ellipsis)
     as its default.
     """
     field_definitions = {self.attribute: (dict, ...)}
     return pydantic.create_model("JSONBlob", **field_definitions)
Beispiel #2
0
from pydantic import ValidationError, create_model
from pydantic.error_wrappers import ErrorList
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request
from starlette.websockets import WebSocket


class HTTPException(StarletteHTTPException):
    """Starlette ``HTTPException`` extended with a ``headers`` attribute.

    Only ``status_code`` and ``detail`` are forwarded to the parent
    initializer; ``headers`` is stored directly on the instance afterwards.
    """
    def __init__(
        self, status_code: int, detail: Any = None, headers: dict = None
    ) -> None:
        super().__init__(status_code=status_code, detail=detail)
        # Assigned after the parent initializer has run.
        self.headers = headers


# Field-less models named "Request" and "WebSocket". RequestErrorModel is the
# model handed to ValidationError by RequestValidationError when PYDANTIC_1 is set.
RequestErrorModel = create_model("Request")
WebSocketErrorModel = create_model("WebSocket")


class RequestValidationError(ValidationError):
    """Validation error for a request; also keeps the request ``body``."""

    def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
        self.body = body
        # With PYDANTIC_1 set the placeholder error model is used; otherwise
        # the Request class itself is passed as the model argument.
        error_model = RequestErrorModel if PYDANTIC_1 else Request
        super().__init__(errors, error_model)  # type: ignore


class WebSocketRequestValidationError(ValidationError):
    def __init__(self, errors: Sequence[ErrorList]) -> None:
        if PYDANTIC_1:
    validate_assignment = True


# Dataclass created with the Config class as its configuration; it declares
# one `str` field and two `Optional[str]` fields, none with a default.
@dataclass(config=Config)
class AddProject:
    name: str
    slug: Optional[str]
    description: Optional[str]


# Instantiate AddProject with all three fields supplied as keyword arguments.
p = AddProject(name='x', slug='y', description='z')


class TypeAliasAsAttribute(BaseModel):
    # Dunder-named class attribute assigned a typing Union; it has no annotation.
    __type_alias_attribute__ = Union[str, bytes]


class NestedModel(BaseModel):
    # Model class declared inside NestedModel; used as the type of `model` below.
    class Model(BaseModel):
        id: str

    model: Model


# Access the nested class through its outer class.
_ = NestedModel.Model

# Build a model dynamically with `Model` as its base. The name is resolved at
# module scope; its definition is not visible in this excerpt.
DynamicModel = create_model('DynamicModel', __base__=Model)

# Instantiate the dynamic model, then reassign one attribute on the instance.
dynamic_model = DynamicModel(x=1, y='y')
dynamic_model.x = 2
        log.debug(f"Found queue with ID {queue_id}")

    queue = _get_queue_session(q_name, q_path)
    queue.put(payload)
    log.info(f"Added payload to queue with queue ID {queue_id}")
    log.debug(f"Payload values: {payload}")
    return queue.size


@router.post(
    "/{queue_id}/data",
    response_model=create_model(
        "GenericDataProxyIngestResponse",
        status=(
            str,
            ...,
        ),
        sh256_hash=(str, ...),
        position_in_queue=(int, ...),
    ),
    description=
    ("Saves the request's body and returns the hash of the body for future content verification. "
     "Also returns the position in the queue for synchronous calls (-1 for async calls)."
     ),
)
async def send_data(
        background_tasks: BackgroundTasks,
        body: Any = Body(...),
        queue_id: str = Path(
            ...,
            title="Queue ID",
def test_field_wrong_tuple():
    """A three-element tuple as a field definition must raise ``ConfigError``."""
    bad_definition = (1, 2, 3)
    with pytest.raises(errors.ConfigError):
        create_model('FooModel', foo=bad_definition)
def create_cloned_field(field: ModelField) -> ModelField:
    """Return a new ``ModelField`` that mirrors ``field``.

    When the field's type is a ``BaseModel`` subclass (or a dataclass exposing
    ``__pydantic_model__``), a new model with the same name and config is
    created and used as the clone's type; the original model's field objects
    and validators are attached to that new model. ``sub_fields`` and
    ``key_field`` are cloned recursively.

    The ``PYDANTIC_1`` branches set the same information under the attribute
    names each pydantic generation uses (``field_info`` vs ``schema``,
    ``pre_validators`` vs ``whole_pre_validators``, ...).
    """
    original_type = field.type_
    # A dataclass carrying a pydantic model is replaced by that model.
    if is_dataclass(original_type) and hasattr(original_type,
                                               "__pydantic_model__"):
        original_type = original_type.__pydantic_model__  # type: ignore
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        # New model class with the same name and config as the original.
        use_type = create_model(original_type.__name__,
                                __config__=original_type.__config__)
        # The original field objects are attached as-is (they are not cloned).
        for f in original_type.__fields__.values():
            use_type.__fields__[f.name] = f
        use_type.__validators__ = original_type.__validators__
    # Build the clone with placeholder values; the real ones are copied below.
    if PYDANTIC_1:
        new_field = ModelField(
            name=field.name,
            type_=use_type,
            class_validators={},
            default=None,
            required=False,
            model_config=BaseConfig,
            field_info=FieldInfo(None),
        )
    else:  # pragma: nocover
        new_field = ModelField(  # type: ignore
            name=field.name,
            type_=use_type,
            class_validators={},
            default=None,
            required=False,
            model_config=BaseConfig,
            schema=FieldInfo(None),
        )
    # Copy the source field's attributes onto the clone.
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    if PYDANTIC_1:
        new_field.field_info = field.field_info
    else:  # pragma: nocover
        new_field.schema = field.schema  # type: ignore
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    # Nested fields are cloned recursively rather than shared.
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(sub_field) for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(field.key_field)
    new_field.validators = field.validators
    if PYDANTIC_1:
        new_field.pre_validators = field.pre_validators
        new_field.post_validators = field.post_validators
    else:  # pragma: nocover
        new_field.whole_pre_validators = field.whole_pre_validators  # type: ignore
        new_field.whole_post_validators = field.whole_post_validators  # type: ignore
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    # Fall back to the private method name when the public one is missing.
    try:
        new_field.populate_validators()
    except AttributeError:  # pragma: nocover
        # TODO: remove when removing support for Pydantic < 1.0.0
        new_field._populate_validators()  # type: ignore
    return new_field
from pydantic import BaseModel, create_model


class FooModel(BaseModel):
    # `foo` has no default; `bar` defaults to 123.
    foo: str
    bar: int = 123


# Derive BarModel from FooModel, adding two fields given only by their default values.
BarModel = create_model('BarModel', apple='russet', banana='yellow', __base__=FooModel)
print(BarModel)
#> <class 'pydantic.main.BarModel'>
# The inherited fields are listed before the newly added ones.
print(', '.join(BarModel.__fields__.keys()))
#> foo, bar, apple, banana
Beispiel #8
0
    def register_resource(self, resource: Resource) -> None:
        """Build a router for ``resource`` and include it in this application.

        The resource's ``schema`` and ``response_model`` are first passed
        through ``_embed_data_relation`` (and reassigned on the resource), and
        the schema's config is set to forbid extra fields. Routes are then
        added for every collection method, every item method and every
        sub-resource, each with a dynamically created response model, and the
        router is included under a tag named after the resource.
        """
        print(f"Registering Resource: {resource.name}")
        # process resource_settings
        # add name to api
        router = APIRouter()

        # check models for data relations
        resource.schema = self._embed_data_relation(resource.schema)
        # TODO: this should be on the InSchema
        resource.schema.__config__.extra = "forbid"  # type: ignore
        resource.response_model = self._embed_data_relation(
            resource.response_model, response=True
        )

        response_model = self._prepare_response_model(
            resource.response_model, resource.name
        )

        def data_field(item_type):
            # Field definition for a list of `item_type` exposed under the
            # "_data" alias; a fresh Field object is built on every call.
            return (List[item_type], Field(..., alias="_data"))  # type: ignore

        def item_path():
            # Route of a single item, e.g. "/<name>/{<item_name>_id}".
            return f"/{resource.name}/{{{str(resource.item_name) + '_id'}}}"

        Response = create_model(
            f"ResponseSchema_{resource.name}",
            data=data_field(response_model),
            __base__=BaseResponseSchema,
        )

        PostResponse = create_model(
            f"PostResponseSchema_{resource.name}",
            data=data_field(response_model),
            __base__=BaseSchema,  # No meta or links
        )

        # Collection-level routes.
        for method in resource.resource_methods:
            collection_path = f"/{resource.name}"
            route_options = {
                "endpoint": collections_endpoint_factory(resource, method),
                "methods": [method],
            }
            if method == "POST":
                route_options["response_model"] = PostResponse
                route_options["response_model_exclude_unset"] = True
                route_options["status_code"] = 201
            elif method == "DELETE":
                route_options["status_code"] = 204
            else:
                route_options["response_model"] = Response
                route_options["response_model_exclude_unset"] = True
            router.add_api_route(collection_path, **route_options)

        ItemResponse = create_model(
            f"ItemResponseSchema_{resource.name}",
            data=data_field(response_model),
            __base__=ItemBaseResponseSchema,
        )

        # Item-level routes: mutating methods answer 204 without a body model.
        for method in resource.item_methods:
            route_path = item_path()
            route_options = {
                "endpoint": item_endpoint_factory(resource, method),
                "methods": [method],
            }
            if method in ("PUT", "DELETE", "PATCH"):
                route_options["status_code"] = 204
            else:
                route_options["response_model"] = ItemResponse
                route_options["response_model_exclude_unset"] = True
            router.add_api_route(route_path, **route_options)

        # Sub-resource routes: one GET route below the item route each.
        for sub_resource in resource.sub_resources:
            SubResponse = create_model(
                f"ResponseSchema_{resource.name}_sub_resource_{sub_resource.name}",
                data=data_field(sub_resource.resource.response_model),
                __base__=BaseResponseSchema,
            )

            print(f"Registering Sub Resource: {sub_resource.name}")
            router.add_api_route(
                f"{item_path()}/{sub_resource.name}",
                endpoint=subresource_endpoint_factory(
                    resource, "GET", sub_resource
                ),
                response_model=SubResponse,
                response_model_exclude_unset=True,
                methods=["GET"],
            )

        # TODO: api versioning
        self.include_router(
            router,
            tags=[str(resource.name)],
        )
Beispiel #9
0
def _parse_and_bound_params(handler: HTTPHandler) -> HTTPHandler:
    sig = signature(handler)

    __parameters__ = {}
    __exclusive_models__ = {}
    path: Dict[str, Any] = {}
    query: Dict[str, Any] = {}
    header: Dict[str, Any] = {}
    cookie: Dict[str, Any] = {}
    body: Dict[str, Any] = {}

    for name, param in sig.parameters.items():
        default = param.default
        annotation = param.annotation

        if isinstance(default, QueryInfo):
            _type_ = query
        elif isinstance(default, HeaderInfo):
            _type_ = header
        elif isinstance(default, CookieInfo):
            _type_ = cookie
        elif isinstance(default, BodyInfo):
            _type_ = body
        elif isinstance(default, PathInfo):
            _type_ = path
        elif isinstance(default, ExclusiveInfo):
            if isclass(annotation) and issubclass(annotation, BaseModel):
                model = annotation
            else:
                model = create_model(
                    "temporary_exclusive_model",
                    __config__=create_model_config(default.title,
                                                   default.description),
                    __root__=(annotation, ...),
                )
            __parameters__[default.name] = model
            __exclusive_models__[model] = name
            continue
        else:
            continue

        if annotation != param.empty:
            _type_[name] = (annotation, default)
        else:
            _type_[name] = default

    __locals__ = locals()
    for key in filter(lambda key: bool(__locals__[key]),
                      ("path", "query", "header", "cookie", "body")):
        if key in __parameters__:
            raise RuntimeError(
                f'Exclusive("{key}") and {key.capitalize()} cannot be used at the same time'
            )
        __parameters__[key] = create_model("temporary_model",
                                           **locals()[key])  # type: ignore

    if "body" in __parameters__:
        setattr(handler, "__request_body__", __parameters__.pop("body"))

    if __parameters__:
        setattr(handler, "__parameters__", __parameters__)

    if __exclusive_models__:
        setattr(handler, "__exclusive_models__", __exclusive_models__)

    return handler
Beispiel #10
0
 def generate_metric_model(cls, metric_list: Iterable[str]) -> CatalogModel:
     """Create a ``MetricsModel`` derived from ``cls`` with a ``metric`` field.

     The ``metric`` field defaults to ``None`` and is typed as a
     ``MetricObjModel`` (based on ``CatalogModel``) that has one ``str`` field
     with default ``None`` for every name in ``metric_list``.
     """
     metric_fields = {}
     for metric_name in metric_list:
         metric_fields[metric_name] = (str, None)
     metrics_obj_model = create_model(
         "MetricObjModel",
         __base__=CatalogModel,
         **metric_fields,
     )
     return create_model(
         "MetricsModel",
         __base__=cls,
         metric=(metrics_obj_model, None),
     )
Beispiel #11
0
 def add_fields(cls, **fields):
     """Return a model named ``NewModel`` that extends ``cls`` with ``fields``."""
     extended_model = create_model('NewModel', __base__=cls, **fields)
     return extended_model
Beispiel #12
0
def get_models_from_rpc_methods(methods: Dict[str, Callable]) -> TypeModelSet:
    """Collect the models describing every RPC method plus a ``Health`` model.

    For each function a request-params model, a request model and a response
    model are produced. Models attached to the function as
    ``__rpc_request_model__`` / ``__rpc_response_model__`` take precedence
    over ones derived from the signature, and ``__rpc_name__`` overrides the
    dictionary key as the method name. The collected models are flattened
    with ``get_flat_models_from_models`` before being returned.
    """
    health_model = create_model(
        "Health",
        is_sick=(bool, False),
        checks=(Dict[str, str], Field(..., example={"srv": "ok"})),
        version=(str, Field(..., example="1.0.0")),
        start_time=(datetime, ...),
        up_time=(str, Field(..., example="0:00:12.850850")),
    )
    collected: List[Type[BaseModel]] = [health_model]

    for method, func in methods.items():
        sig = inspect.signature(func)

        rpc_name = getattr(func, "__rpc_name__", method)
        camel_name = snake_to_camel(rpc_name)

        request_name = f"{camel_name}Request"
        params_name = f"{camel_name}RequestParams"
        response_name = f"{camel_name}Response"
        result_name = f"{camel_name}ResponseResult"

        # Request params: explicit model if present, else built from the signature.
        params_model = getattr(func, "__rpc_request_model__", None)
        if not params_model:
            params_model = create_model(
                params_name,
                **get_field_definitions(sig.parameters),
            )

        # Rename a params model whose name would clash with the request model.
        if getattr(params_model, "__name__", "") == request_name:
            fix_model_name(params_model, params_name)

        request_model: Type[BaseModel] = create_model(
            request_name,
            method=(str, Field(..., example=rpc_name)),
            params=(params_model, ...),
        )

        response_fields: Dict[str, Any] = {
            "code": (int, Field(..., example=0)),
            "message": (str, Field(..., example="OK")),
        }

        # NOTE(review): for a function without a return annotation this is
        # inspect.Signature.empty, which is not None -- confirm that such
        # functions are handled as intended.
        result_model = (
            getattr(func, "__rpc_response_model__", None)
            or sig.return_annotation
        )
        if result_model is not None:
            # Same clash handling as for the request side.
            if getattr(result_model, "__name__", "") == response_name:
                fix_model_name(result_model, result_name)
            response_fields["result"] = (result_model, None)

        response_model: Type[BaseModel] = create_model(
            response_name, **response_fields
        )

        collected += [params_model, request_model, response_model]

    return get_flat_models_from_models(collected)
Beispiel #13
0
 def inject_update_model(self):
     """Create the ``<schema name>_MutableJSONBlob`` model.

     It has a single field named after ``self.attribute``, typed as
     ``mutable(self.schema)`` with ``...`` (Ellipsis) as its default.
     """
     model_name = f"{self.schema.__name__}_MutableJSONBlob"
     field_definitions = {self.attribute: (mutable(self.schema), ...)}
     return pydantic.create_model(model_name, **field_definitions)
Beispiel #14
0
 def inject_model(self):
     """Create the ``<schema name>_JSONBlob`` model.

     It has a single field named after ``self.attribute``, typed as
     ``self.schema`` with ``...`` (Ellipsis) as its default.
     """
     model_name = f"{self.schema.__name__}_JSONBlob"
     field_definitions = {self.attribute: (self.schema, ...)}
     return pydantic.create_model(model_name, **field_definitions)
Beispiel #15
0
def create_fast_models(odoo_model_name, abstract=False):
    '''
    Creates a database model and a pydantic model dynamically by reading field
    details from the odoo fields metadata (``crud.read_aims_model_fields``).

    Fields that are not stored, and the mail_message fields, are skipped.
    many2one fields produce an ``<name>_id`` foreign-key column plus a
    relationship attribute, and are skipped (with a warning) when the related
    model has not been created yet. Both generated models are also stored in
    the ``db_models`` / ``pydantic_models`` caches under the fast model name.

    :param odoo_model_name: odoo model name like 'srcm.abhyasi'
    :param abstract: when true, 'Abstract' is appended to the generated model name
    :return: tuple ``(db_model, pydantic_model)`` of the dynamically created models
    '''
    fast_model_name = get_fast_model_name(odoo_model_name) + ('Abstract' if
                                                              abstract else '')
    # read the model fields from odoo fields metadata
    model_fields = crud.read_aims_model_fields(odoo_model_name)

    mail_message_fields_to_ignore = {
        'message_follower_ids', 'message_ids', 'message_main_attachment_id',
        'website_message_ids'
    }
    many2one_fields_to_ignore = {'create_uid', 'write_uid', 'address_view_id'}

    # odoo field types that all follow the same required/optional rule:
    # odoo type -> (database column type, python type)
    # TODO check if we can use same for text and char
    simple_types = {
        'char': (String, str),
        'text': (String, str),
        'boolean': (Boolean, bool),
        'datetime': (DateTime, datetime),
        'date': (Date, date),
    }

    # initialize fastapi model fields
    db_fields = {'__tablename__': get_odoo_table_name(odoo_model_name)}
    pydantic_fields = {}
    # collected but not attached to either model yet (see the commented-out
    # assignment after the loop)
    readonly_fields = []

    for model_field in model_fields:
        field_name = model_field.name
        # ignore odoo computed fields and mail_message fields
        if not model_field.store or field_name in mail_message_fields_to_ignore:
            continue

        if model_field.readonly:
            readonly_fields.append(field_name)

        required = model_field.required
        ttype = model_field.ttype

        if ttype in simple_types:
            column_type, python_type = simple_types[ttype]
            db_fields[field_name] = Column(column_type, nullable=(not required))
            if required:
                # A bare type would be taken by create_model as a default
                # *value*; (type, ...) declares a required, typed field.
                pydantic_fields[field_name] = (python_type, ...)
            else:
                pydantic_fields[field_name] = (Optional[python_type], None)

        elif ttype == 'integer':
            if field_name == 'id':
                db_fields[field_name] = Column(Integer, primary_key=True)
            else:
                db_fields[field_name] = Column(
                    Integer, nullable=(not required))
            pydantic_fields[field_name] = (int, 0)

        elif ttype == 'many2one':
            if field_name in many2one_fields_to_ignore:
                continue

            rel_model_name = get_fast_model_name(model_field.relation)
            if rel_model_name not in pydantic_models:
                logger.warning(
                    f'skipping {odoo_model_name}.{model_field.name} as {rel_model_name} is not configured'
                )
                continue

            rel_table_name = get_odoo_table_name(model_field.relation)
            rel_table_col_name = f'{rel_table_name}.id'

            if field_name.endswith('_id'):
                id_field_name = field_name
                # Strip only the trailing '_id'; str.replace would also remove
                # occurrences inside the name (e.g. 'parent_idea_id').
                obj_field_name = field_name[:-len('_id')]
            else:
                id_field_name = field_name + '_id'
                obj_field_name = field_name

            db_fields[id_field_name] = Column(
                Integer,
                ForeignKey(rel_table_col_name),
                nullable=(not required))
            db_fields[obj_field_name] = relationship(
                rel_model_name + 'Db', foreign_keys=[db_fields[id_field_name]])

            if required:
                pydantic_fields[id_field_name] = (int, 0)
                pydantic_fields[obj_field_name] = (
                    pydantic_models[rel_model_name], None)
            else:
                pydantic_fields[id_field_name] = (Optional[int], 0)
                pydantic_fields[obj_field_name] = (
                    Optional[pydantic_models[rel_model_name]], None)

        # if model_field.ttype == 'many2many':
        #     to_model = get_django_model_name(model_field.relation)

        #     # Create a through model which contains the mapping
        #     rel_model = create_aims_manytomany_rel_model(model_field)

        #     # related name for reverse relationships. Needed if the models have more than one relation
        #     related_name = AIMS_RELATED_NAMES.get((model_field.model, model_field.name))

        #     # add the field
        #     fields[model_field.name] = models.ManyToManyField(to=to_model, through=rel_model, related_name=related_name,
        #                                                       null=(not model_field.required),
        #                                                       through_fields=(model_field.column1, model_field.column2))

    # fields['readonly_fields'] = readonly_fields

    db_model = type(fast_model_name + 'Db', (Base, ), db_fields)

    pydantic_model = pydantic.create_model(fast_model_name + 'Pyd',
                                           **pydantic_fields)
    # NOTE(review): no Config is passed to create_model, so `.Config` here may
    # resolve to a config class shared with other models -- confirm that
    # enabling orm_mode this way is meant to be model-specific.
    pydantic_model.Config.orm_mode = True

    # add to cache
    db_models[fast_model_name] = db_model
    pydantic_models[fast_model_name] = pydantic_model
    return db_model, pydantic_model
Beispiel #16
0
from pydantic import BaseModel, ValidationError, create_model
from pydantic.error_wrappers import ErrorList
from starlette.exceptions import HTTPException as StarletteHTTPException


class HTTPException(StarletteHTTPException):
    """Starlette ``HTTPException`` with a typed ``headers`` argument.

    ``status_code``, ``detail`` and ``headers`` are all forwarded to the
    parent initializer as keyword arguments.
    """
    def __init__(
        self,
        status_code: int,
        detail: Any = None,
        headers: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(status_code=status_code, detail=detail, headers=headers)


# Field-less models named "Request" and "WebSocket". RequestErrorModel is the
# model handed to ValidationError by RequestValidationError.
RequestErrorModel: Type[BaseModel] = create_model("Request")
WebSocketErrorModel: Type[BaseModel] = create_model("WebSocket")


class FastAPIError(RuntimeError):
    """
    A generic, FastAPI-specific error.

    Subclasses ``RuntimeError`` and adds no behaviour of its own.
    """


class RequestValidationError(ValidationError):
    """Validation error for a request; also keeps the request ``body``."""
    def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
        self.body = body
        # The placeholder "Request" model is passed as the model argument.
        super().__init__(errors, RequestErrorModel)

Beispiel #17
0
# Each case pairs a Python value with the exact JSON text expected from
# json.dumps when pydantic_encoder is used as the `default` hook.
@pytest.mark.parametrize('input,output', [
    (UUID('ebcdab58-6eb8-46fb-a190-d07a33e9eac8'),
     '"ebcdab58-6eb8-46fb-a190-d07a33e9eac8"'),
    (datetime.datetime(2032, 1, 1, 1, 1), '"2032-01-01T01:01:00"'),
    (datetime.datetime(2032, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
     '"2032-01-01T01:01:00+00:00"'),
    (datetime.datetime(2032, 1, 1), '"2032-01-01T00:00:00"'),
    (datetime.time(12, 34, 56), '"12:34:56"'),
    # timedelta is expected as a number of seconds
    (datetime.timedelta(days=12, seconds=34,
                        microseconds=56), '1036834.000056'),
    # sets, frozensets and generators are expected as JSON arrays
    ({1, 2, 3}, '[1, 2, 3]'),
    (frozenset([1, 2, 3]), '[1, 2, 3]'),
    ((v for v in range(4)), '[0, 1, 2, 3]'),
    (b'this is bytes', '"this is bytes"'),
    (Decimal('12.34'), '12.34'),
    # a model instance is expected as a JSON object of its fields
    (create_model('BarModel', a='b', c='d')(), '{"a": "b", "c": "d"}'),
    (MyEnum.foo, '"bar"'),
])
def test_encoding(input, output):
    """Check that ``pydantic_encoder`` serialises ``input`` to ``output``."""
    assert output == json.dumps(input, default=pydantic_encoder)


def test_model_encoding():
    class ModelA(BaseModel):
        x: int
        y: str

    class Model(BaseModel):
        a: float
        b: bytes
        c: Decimal
Beispiel #18
0
async def open_fsp_sidepane(
    linked: LinkedSplits,
    conf: dict[str, dict[str, str]],

) -> FieldsForm:
    """Build the FSP settings sidepane form and yield it.

    ``conf`` maps an FSP name to its config and must (for now) contain exactly
    one entry. The form schema gets a select widget for the FSP itself and one
    edit widget per entry of the config's ``'params'`` mapping. A ``FspConfig``
    model is created from the FSP name and its params, and an instance of it
    is stored on ``sidepane.model``. The sidepane is yielded while form input
    handling is open.
    """
    schema = {}

    assert len(conf) == 1  # for now

    # add (single) selection widget
    for fsp_name, fsp_config in conf.items():
        schema[fsp_name] = {
                'label': '**fsp**:',
                'type': 'select',
                'default_value': [fsp_name],
            }

        # add parameters for selection "options"
        params = fsp_config.get('params', {})
        # Distinct loop names: reusing the outer ones would overwrite the FSP
        # name that is needed again when the model is created below.
        for param_name, param_config in params.items():

            default = param_config['default_value']
            kwargs = param_config.get('widget_kwargs', {})

            # add to ORM schema
            schema.update({
                param_name: {
                    'label': f'**{param_name}**:',
                    'type': 'edit',
                    'default_value': default,
                    'kwargs': kwargs,
                },
            })

    sidepane: FieldsForm = mk_form(
        parent=linked.godwidget,
        fields_schema=schema,
    )

    # https://pydantic-docs.helpmanual.io/usage/models/#dynamic-model-creation
    FspConfig = create_model(
        'FspConfig',
        name=fsp_name,
        **params,
    )
    sidepane.model = FspConfig()

    # just a logger for now until we get fsp configs up and running.
    async def settings_change(
        key: str,
        value: str

    ) -> bool:
        print(f'{key}: {value}')
        return True

    # TODO:
    async with (
        open_form_input_handling(
            sidepane,
            focus_next=linked.godwidget,
            on_value_change=settings_change,
        )
    ):
        yield sidepane
Beispiel #19
0
def get_model(name, *args, **kwargs):
    """Return the model cached under ``name`` in ``model_map``.

    When nothing (or ``None``) is cached, the model is created with
    ``create_model(name, *args, **kwargs)`` and stored before being returned.
    """
    cached = model_map.get(name)
    if cached is not None:
        return cached
    created = create_model(name, *args, **kwargs)
    model_map[name] = created
    return created
Beispiel #20
0
    _ignore_ = "MotorName _current_axis"
    MotorName = vars()
    for _current_axis in types.Axis:
        MotorName[_current_axis.name] = _current_axis.name.lower()


class EngagedMotor(BaseModel):
    """Engaged motor"""
    # Declared with `...` as the default and a description for the schema.
    enabled: bool = Field(..., description="Is engine enabled")


# Dynamically create the EngagedMotors model: one EngagedMotor field per member
# of MotorName, named after the member's value and declared with `...` as its
# default. The remaining create_model options are passed explicitly as None.
EngagedMotors = create_model(
    "EngagedMotors",
    __config__=None,
    __base__=None,
    __module__=None,
    __validators__=None,
    **{motor.value: (EngagedMotor, ...)
       for motor in MotorName})
# The generated class gets its docstring assigned after creation.
EngagedMotors.__doc__ = "Which motors are engaged"


class Axes(BaseModel):
    """A list of motor axes to disengage"""
    axes: typing.List[MotorName]

    # Registered for `axes` with pre=True; returns a list with `.lower()`
    # applied to every entry of the incoming value.
    @validator('axes', pre=True)
    def lower_case_motor_name(cls, v):
        return [m.lower() for m in v]
Beispiel #21
0
from pydantic import create_model, ValidationError, validator


def username_alphanumeric(cls, v):
    """Return ``v`` unchanged when it is alphanumeric.

    Fails with an ``AssertionError`` ("must be alphanumeric") otherwise.
    ``cls`` is accepted but not used.
    """
    is_alphanumeric = v.isalnum()
    assert is_alphanumeric, "must be alphanumeric"
    return v


# Validators handed to create_model: username_alphanumeric is registered as a
# validator for the "username" field.
validators = {
    "username_validator": validator("username")(username_alphanumeric)
}

UserModel = create_model(
    "UserModel",
    username=(str, ...),
    __validators__=validators,
)

# An alphanumeric username passes validation. (The previous masked literal
# "******" is not alphanumeric, so this unguarded call raised before the
# failure case below was ever reached.)
user = UserModel(username="scolvin")
print(user)

# A username containing a non-alphanumeric character is rejected.
try:
    UserModel(username="scolvi%n")
except ValidationError as e:
    print(e)
Beispiel #22
0
    def create_stripped_model_type(
        cls,
        stripped_fields: Optional[List[str]] = None,
        stripped_fields_aliases: Optional[List[str]] = None
    ) -> Type['OscalBaseModel']:
        """Create a pydantic model derived from this one with some fields left out.

        OSCAL mandates a strict schema (no additional fields unless stated
        otherwise) with certain mandatory fields, and the corresponding data
        classes are equally strict. Trestle workflows sometimes need models
        that lack mandatory fields; this builds such a derivative model.

        The new model carries the current class name, is based on
        ``OscalBaseModel`` and re-declares every remaining field with its
        original alias; fields that were not required become ``Optional``
        with a ``None`` default.

        Args:
            stripped_fields: Names of the fields to remove.
            stripped_fields_aliases: Aliases of the fields to remove.

        Returns:
            Pydantic model class that can be used to instantiate a model.

        Raises:
            TrestleError: If both or neither of stripped_fields and
                stripped_fields_aliases are provided.
            TrestleError: If an alias does not exist in the model.
        """
        by_name = stripped_fields is not None
        by_alias = stripped_fields_aliases is not None
        if by_name and by_alias:
            raise err.TrestleError(
                'Either "stripped_fields" or "stripped_fields_aliases" need to be passed, not both.'
            )
        if not (by_name or by_alias):
            raise err.TrestleError(
                'Exactly one of "stripped_fields" or "stripped_fields_aliases" must be provided'
            )

        # Resolve the names of the fields to drop (aliases are mapped to names).
        if by_name:
            excluded_fields = stripped_fields
        else:
            alias_to_field = cls.alias_to_field_map()
            try:
                excluded_fields = [
                    alias_to_field[alias].name
                    for alias in stripped_fields_aliases  # type: ignore
                ]
            except KeyError as e:
                raise err.TrestleError(
                    f'Field {str(e)} does not exist in the model')

        # Re-declare every field that is kept.
        new_fields_for_model = {}
        for mfield in cls.__fields__.values():
            if mfield.name in excluded_fields:  # type: ignore
                continue
            if mfield.required:
                annotation, default = mfield.outer_type_, ...
            else:
                annotation, default = Optional[mfield.outer_type_], None
            new_fields_for_model[mfield.name] = (
                annotation,
                Field(default, title=mfield.name, alias=mfield.alias),
            )

        stripped_model = create_model(
            cls.__name__,
            __base__=OscalBaseModel,
            **new_fields_for_model,
        )  # type: ignore
        # TODO: This typing cast should NOT be necessary. Potentially fixable with a fix to pydantic. Issue #175
        return cast(Type[OscalBaseModel], stripped_model)
def test_invalid_name():
    """A field name with a leading underscore warns and yields no model field."""
    with pytest.warns(RuntimeWarning):
        model = create_model('FooModel', _foo=(str, ...))
    assert not model.__fields__
Beispiel #24
0
def protobuf_to_pydantic_model(
    protobuf_model: Union[Descriptor,
                          GeneratedProtocolMessageType]) -> BaseModel:
    """
    Converts Protobuf messages to Pydantic model for jsonschema creation/validation

    ..note:: Model gets assigned in the global Namespace :data:PROTO_TO_PYDANTIC_MODELS

    :param protobuf_model: message from jina.proto file
    :type protobuf_model: Union[Descriptor, GeneratedProtocolMessageType]
    :return: Pydantic model
    :raises TypeError: if ``protobuf_model`` is neither a ``Descriptor`` nor a
        ``GeneratedProtocolMessageType``
    """

    all_fields = {}
    camel_case_fields = {}  # {"random_string": {"alias": "randomString"}}
    oneof_fields = defaultdict(list)
    oneof_field_validators = {}

    if isinstance(protobuf_model, Descriptor):
        model_name = protobuf_model.name
        protobuf_fields = protobuf_model.fields
    elif isinstance(protobuf_model, GeneratedProtocolMessageType):
        model_name = protobuf_model.DESCRIPTOR.name
        protobuf_fields = protobuf_model.DESCRIPTOR.fields
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise TypeError(
            f'Unsupported protobuf model type: {type(protobuf_model).__name__}')

    # Models are cached by name; reuse an already converted one.
    if model_name in vars(PROTO_TO_PYDANTIC_MODELS):
        return getattr(PROTO_TO_PYDANTIC_MODELS, model_name)

    for f in protobuf_fields:
        field_name = f.name
        camel_case_fields[field_name] = {'alias': f.camelcase_name}

        field_type = PROTOBUF_TO_PYTHON_TYPE[f.type]
        default_value = f.default_value
        default_factory = None

        if f.containing_oneof:
            # Proto Field type: oneof
            # NOTE: oneof fields are handled as a post-processing step
            oneof_fields[f.containing_oneof.name].append(field_name)

        if field_type is Enum:
            # Proto Field Type: enum
            enum_dict = {}
            for enum_field in f.enum_type.values:
                enum_dict[enum_field.name] = enum_field.number
            field_type = Enum(f.enum_type.name, enum_dict)

        if f.message_type:
            if f.message_type.name == 'Struct':
                # Proto Field Type: google.protobuf.Struct
                field_type = Dict
                default_factory = dict
            elif f.message_type.name == 'Timestamp':
                # Proto Field Type: google.protobuf.Timestamp
                field_type = datetime
                default_factory = datetime.now
            else:
                # Proto field type: Proto message defined in jina.proto
                if f.message_type.name == model_name:
                    # Self-referencing models: use a forward reference by name,
                    # resolved by update_forward_refs() below.
                    field_type = model_name
                else:
                    # This field_type itself a Pydantic model. The recursive call
                    # registers it in PROTO_TO_PYDANTIC_MODELS under its own name,
                    # so no extra assignment is needed here (the former
                    # ``PROTO_TO_PYDANTIC_MODELS.model_name = ...`` only created a
                    # stray attribute literally called "model_name").
                    field_type = protobuf_to_pydantic_model(f.message_type)

        if f.label == FieldDescriptor.LABEL_REPEATED:
            field_type = List[field_type]

        all_fields[field_name] = (
            field_type,
            Field(default_factory=default_factory)
            if default_factory else Field(default=default_value),
        )

        # some fixes on Doc.scores and Doc.evaluations
        if field_name in ('scores', 'evaluations'):
            all_fields[field_name] = (
                Dict[str, PROTO_TO_PYDANTIC_MODELS.NamedScoreProto],
                Field(default={}),
            )

    # Post-processing (Handle oneof fields)
    for oneof_k, oneof_v_list in oneof_fields.items():
        oneof_field_validators[
            f'oneof_validator_{oneof_k}'] = _get_oneof_validator(
                oneof_fields=oneof_v_list, oneof_key=oneof_k)
        oneof_field_validators[f'oneof_setter_{oneof_k}'] = _get_oneof_setter(
            oneof_fields=oneof_v_list, oneof_key=oneof_k)

    if model_name == 'DocumentProto':
        oneof_field_validators['tags_validator'] = _get_tags_updater()

    # The camelCase aliases are handed to pydantic through the shared config class.
    CustomConfig.fields = camel_case_fields
    model = create_model(
        model_name,
        **all_fields,
        __config__=CustomConfig,
        __validators__=oneof_field_validators,
    )
    model.update_forward_refs()
    setattr(PROTO_TO_PYDANTIC_MODELS, model_name, model)
    return model
def test_config_and_base():
    """Supplying both ``__config__`` and ``__base__`` to ``create_model`` raises ``ConfigError``."""
    conflicting_options = {
        '__config__': BaseModel.Config,
        '__base__': BaseModel,
    }
    with pytest.raises(errors.ConfigError):
        create_model('FooModel', **conflicting_options)
Beispiel #26
0
    def __init__(
        self,
        store: Store,
        model: Type[BaseModel],
        post_query_operators: List[QueryOperator],
        get_query_operators: List[QueryOperator],
        tags: Optional[List[str]] = None,
        include_in_schema: Optional[bool] = True,
        duplicate_fields_check: Optional[List[str]] = None,
        enable_default_search: Optional[bool] = True,
        state_enum: Optional[Enum] = None,
        default_state: Optional[Any] = None,
        calculate_submission_id: Optional[bool] = False,
        get_sub_path: Optional[str] = "/",
        post_sub_path: Optional[str] = "/",
    ):
        """
        Args:
            store: The Maggma Store to get data from
            model: The pydantic model this resource represents
            tags: List of tags for the Endpoint
            post_query_operators: Operators for the query language for post data
            get_query_operators: Operators for the query language for get data
            include_in_schema: Whether to include the submission resource in the documented schema
            duplicate_fields_check: Fields in model used to check for duplicates for POST data
            enable_default_search: Enable default endpoint search behavior.
            state_enum: State Enum defining possible data states
            default_state: Default state value in provided state Enum
            calculate_submission_id: Whether to calculate and use a submission ID as primary data key.
                If False, the store key is used instead.
            get_sub_path: GET sub-URL path for the resource.
            post_sub_path: POST sub-URL path for the resource.

        Raises:
            RuntimeError: If a state Enum class is given and default_state is not
                one of its values.
        """

        # state_enum is an Enum *class* (it is iterated here and used as
        # List[state_enum] below). isinstance(state_enum, Enum) is only true for
        # Enum members, so the previous check never rejected an invalid default.
        is_enum_class = isinstance(state_enum, type) and issubclass(
            state_enum, Enum)
        if is_enum_class and default_state not in [
                entry.value for entry in state_enum  # type: ignore
        ]:
            raise RuntimeError(
                "If data is stateful a state enum and valid default value must be provided"
            )

        self.state_enum = state_enum
        self.default_state = default_state
        self.store = store
        self.tags = tags or []
        self.post_query_operators = post_query_operators
        if state_enum is not None:
            # Stateful data: drop missing operators and add the state-aware query.
            self.get_query_operators = [
                op for op in get_query_operators if op is not None
            ] + [SubmissionQuery(state_enum)]  # type: ignore
        else:
            self.get_query_operators = get_query_operators
        self.include_in_schema = include_in_schema
        self.duplicate_fields_check = duplicate_fields_check
        self.enable_default_search = enable_default_search
        self.calculate_submission_id = calculate_submission_id
        self.get_sub_path = get_sub_path
        self.post_sub_path = post_sub_path

        # Extra fields added on top of the user supplied model.
        new_fields = {}  # type: dict
        if self.calculate_submission_id:
            new_fields["submission_id"] = (
                str,
                Field(..., description="Unique submission ID"),
            )

        if state_enum is not None:
            new_fields["state"] = (
                List[state_enum],  # type: ignore
                Field(..., description="List of data status descriptions"),
            )

            new_fields["updated"] = (
                List[datetime],
                Field(..., description="List of status update datetimes"),
            )

        if new_fields:
            # Derive a model of the same name that extends the original one.
            model = create_model(model.__name__, __base__=model, **new_fields)

        self.response_model = Response[model]  # type: ignore

        super().__init__(model)
Beispiel #27
0
def validate_type(fq_type_name: "string",
                  value: "any",
                  validation_parameters: "dict" = None) -> "bool":
    """
        Report whether `value` is accepted by the type named `fq_type_name`. Only names starting with
        `pydantic.`, `datetime.`, `ipaddress.` or `uuid.` are considered; any other name yields False.
        For types that are configured through arguments, pass them via `validation_parameters`.

        Types that take validation_parameters:

            * pydantic.condecimal:
                gt: Decimal = None
                ge: Decimal = None
                lt: Decimal = None
                le: Decimal = None
                max_digits: int = None
                decimal_places: int = None
                multiple_of: Decimal = None
            * pydantic.confloat and pydantic.conint:
                gt: float = None
                ge: float = None
                lt: float = None
                le: float = None
                multiple_of: float = None,
            * pydantic.constr:
                min_length: int = None
                max_length: int = None
                curtail_length: int = None (Only verify the regex on the first curtail_length characters)
                regex: str = None          (The regex is verified via Pattern.match())
            * pydantic.stricturl:
                min_length: int = 1
                max_length: int = 2 ** 16
                tld_required: bool = True
                allowed_schemes: Optional[Set[str]] = None

        Example:

            * A vlan_id type representing a valid vlan ID (0-4,095):

              typedef vlan_id as number matching std::validate_type("pydantic.conint", self, {"ge": 0, "le": 4095})


    """
    allowed_prefixes = ("pydantic.", "datetime.", "ipaddress.", "uuid.")
    if not fq_type_name.startswith(allowed_prefixes):
        return False

    # The prefix check guarantees a dot, so this splits module from attribute name.
    module_name, _, type_name = fq_type_name.partition(".")
    resolved = getattr(importlib.import_module(module_name), type_name)

    # With parameters the resolved name is called to produce the field type.
    if validation_parameters is None:
        field_type = resolved
    else:
        field_type = resolved(**validation_parameters)
    checker = pydantic.create_model(fq_type_name, value=(field_type, ...))

    # Instantiating the model performs the validation.
    try:
        checker(value=value)
    except pydantic.ValidationError:
        return False
    else:
        return True
Beispiel #28
0
def _get_method(func: Callable, model_name_map: ModelDict,
                schemas: Dict[str, Dict]) -> Method:
    """Assemble the ``Method`` description for an RPC function.

    Metadata is read from the ``__rpc_*__`` attributes set on ``func``; a
    missing summary, description, result text or parameter text is taken from
    the function's docstring. Parameter and result schemas come from
    ``_get_model_definition``, which receives ``model_name_map`` and ``schemas``.
    """
    sig = inspect.signature(func)
    docstr = inspect.getdoc(func)

    method_name: str = getattr(func, "__rpc_name__")
    request_params_model: Optional[Type[BaseModel]] = getattr(
        func, "__rpc_request_model__", None)
    response_result_model: Optional[Type[BaseModel]] = getattr(
        func, "__rpc_response_model__", None)
    deprecated: bool = getattr(func, "__rpc_deprecated__", False)
    summary: str = getattr(func, "__rpc_summary__", "")
    description: str = getattr(func, "__rpc_description__", "")

    kwargs: Dict[str, Any] = {
        'paramStructure': ParamStructure.BY_NAME,
        'name': method_name,
    }
    if summary:
        kwargs['summary'] = summary
    if description:
        kwargs['description'] = description
    if deprecated:
        kwargs['deprecated'] = True

    # Fall back to the docstring for anything the decorator did not supply.
    params_docs: Dict[str, str] = {}
    result_doc: Optional[str] = None
    if docstr:
        parsed_doc = docstring_parser.parse(docstr)
        if 'summary' not in kwargs and parsed_doc.short_description:
            kwargs['summary'] = parsed_doc.short_description
        if 'description' not in kwargs and parsed_doc.long_description:
            kwargs['description'] = parsed_doc.long_description
        if parsed_doc.returns:
            result_doc = parsed_doc.returns.description
        params_docs = {p.arg_name: p.description for p in parsed_doc.params}

    camel_name = _snake_to_camel(method_name)
    request_model_name = f"{camel_name}Request"
    request_params_model_name = f"{camel_name}RequestParams"
    response_model_name = f"{camel_name}Response"
    response_result_model_name = f"{camel_name}ResponseResult"

    defs = _get_field_definitions(sig.parameters)

    # Use the explicit request model when given, otherwise derive one from the signature.
    RequestParamsModel: Type[BaseModel] = request_params_model or create_model(
        request_params_model_name, **dict(defs))

    if getattr(RequestParamsModel, "__name__", "") == request_model_name:
        _fix_model_name(RequestParamsModel, request_params_model_name)

    params_def = _get_model_definition(RequestParamsModel, model_name_map,
                                       schemas)

    kwargs['params'] = []
    for name, _ in defs:
        param_schema = params_def['properties'][name]
        descriptor_kwargs = {
            'name': name,
            'required': ('required' in params_def
                         and name in params_def['required']),
            'schema': Schema(**param_schema),
        }
        if name in params_docs:
            descriptor_kwargs['summary'] = params_docs[name]
        kwargs['params'].append(ContentDescriptor(**descriptor_kwargs))

    # Result type: explicit response model, else the return annotation (Any if absent).
    if response_result_model:
        ResponseResultModel = response_result_model
    elif sig.return_annotation is sig.empty:
        ResponseResultModel = Any
    else:
        ResponseResultModel = sig.return_annotation

    if ResponseResultModel is None:
        result_schema = {}
    else:
        if getattr(ResponseResultModel, "__name__", "") == response_model_name:
            _fix_model_name(ResponseResultModel, response_result_model_name)

        # Wrap the result type in a model so its schema can be extracted.
        ResponseModel: Type[BaseModel] = create_model(
            response_model_name,
            result=(ResponseResultModel, None),  # type: ignore
        )
        response_def = _get_model_definition(ResponseModel, model_name_map,
                                             schemas)
        result_schema = response_def['properties']['result']

    result_kwargs = {
        'name': 'result',
        'required': True,
        'schema': Schema(**result_schema),
    }
    if result_doc:
        result_kwargs['summary'] = result_doc
    kwargs['result'] = ContentDescriptor(**result_kwargs)

    examples = _get_examples(func)
    if examples:
        kwargs['examples'] = examples

    errors = _get_errors(func)
    if errors:
        kwargs['errors'] = errors

    return Method(**kwargs)
Beispiel #29
0
    def get_schema(self, allow_nan=False):
        """Return a JSON schema (dict) describing the options of the base and model parsers.

        Option values are collected from the base parser and from one parser per
        model name, turned into a pydantic model and exported with ``.schema()``.
        Each property gets the argparse help text as description, prefixed with
        the names of the model parsers that declare it. ``self.opt`` and
        ``self.parser`` are overwritten as a side effect.

        Args:
            allow_nan: When False, infinite values are replaced by +/-1e100 and
                NaN by 0 before the schema is built.

        Returns:
            The schema dictionary produced by pydantic.
        """
        # Always build the base parser. It used to be created only when
        # ``self.initialized`` was false, which left ``parser`` unbound (and the
        # method crashing) on any call made after initialization.
        # NOTE(review): assumes initialize() only registers arguments on the
        # parser it is given - confirm.
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        opt = argparse.Namespace()
        self._json_parse_known_args(parser, opt, {})

        # The empty name denotes the base parser; it never produces a tag.
        named_parsers = {"": parser}
        for model_name in models.get_models_names():
            if self.isTrain and model_name in ["test"]:
                continue
            setter = models.get_option_setter(model_name)
            model_parser = argparse.ArgumentParser()
            setter(model_parser)
            self._json_parse_known_args(model_parser, opt, {})
            named_parsers[model_name] = model_parser

        self.opt = opt
        self.parser = parser
        json_vals = self.to_json()

        # Replace values that cannot be used as schema defaults.
        for key, value in json_vals.items():
            if value is None:
                json_vals[key] = "None"
            elif not allow_nan:
                if value == float("inf"):
                    json_vals[key] = 1e100
                elif value == float("-inf"):
                    json_vals[key] = -1e100
                elif isinstance(value, float) and math.isnan(value):
                    json_vals[key] = 0

        from pydantic import create_model
        schema = create_model(type(self).__name__, **json_vals).schema()
        properties = schema["properties"]

        option_tags = defaultdict(list)

        for parser_name, current_parser in named_parsers.items():
            for action_group in current_parser._action_groups:
                for action in action_group._group_actions:
                    if isinstance(action, _HelpAction):
                        continue

                    # Options without a schema property cannot be described or
                    # tagged; tagging them used to raise KeyError further down.
                    if action.dest not in properties:
                        continue

                    if len(parser_name) > 0:
                        option_tags[action.dest].append(parser_name)

                    field = properties[action.dest]
                    description = action.help if action.help is not None else ""
                    # Backslash-escape characters listed here in the help text.
                    for c in "#*<>":
                        description = description.replace(c, "\\" + c)
                    field["description"] = description
                    if "title" in field:
                        del field["title"]

        # Prefix each tagged option with the model parsers that declare it.
        for tagged, parser_names in option_tags.items():
            tags = " | ".join(parser_names)
            properties[tagged]["description"] = (
                "[" + tags + "]\n\n" + properties[tagged]["description"])

        return schema
Beispiel #30
0
def notes_list_model():
    """Create an unnamed pydantic model with one required ``notes`` field of type ``List[dict]``."""
    return create_model('', notes=(List[dict], ...))