Exemplo n.º 1
0
def sa_model_fields(
    Model: type,
    *,
    types: AttributeType = AttributeType.COLUMN,
    make_optional: FilterFunctionT,
    only_readable: bool = False,
    only_writable: bool = False,
    exclude: FilterT = (),
    can_omit_nullable: bool = True,
    naming: ModelNameMakerFunction,
) -> Dict[str, Tuple[type, Field]]:
    """ Build pydantic Field()s for the attributes of an SqlAlchemy model

    Attribute information is gathered with sa_model_info(), restricted by `types` and `exclude`.
    Attributes whose names start with an underscore are always skipped (consistent with Pydantic).
    Type annotations declared on the model take precedence over column types.

    Args:
        Model: the SqlAlchemy model to inspect
        types: which attribute types to include. See AttributeType
        make_optional: a function(name)->bool; fields for which it returns True are made Optional[]
        only_readable: when True, skip attributes that are not readable
        only_writable: when True, skip attributes that are not writable
        exclude: names of fields to skip, or a filter(name) that excludes fields dynamically.
            See also: sa2schema.filters for useful presets
        can_omit_nullable: `False` makes nullable fields and fields with defaults required.
        naming: a callable(Model) naming pattern generator, used to resolve relationship targets.
            This keyword argument is required; if relationships aren't used,
            pass a callable that just raises.
    Returns:
        a dict: attribute names => (type, Field)
    """
    # Annotations declared on the model override Column types
    annotations = resolve_annotations(getattr(Model, '__annotations__', {}),
                                      Model.__module__)

    # Pass 1: pick the attributes to expose, and decide whether each one is Optional
    selected = []
    for attr_name, attr_info in sa_model_info(Model, types=types, exclude=exclude).items():
        if only_readable and not attr_info.readable:
            continue
        if only_writable and not attr_info.writable:
            continue
        if attr_name.startswith('_'):
            # Private attribute (hardcoded for now): skipped, consistent with Pydantic
            continue
        selected.append((attr_name, attr_info, make_optional(attr_name)))

    # Pass 2: turn every selected attribute into a (type, Field()) pair
    fields = {}
    for attr_name, attr_info, is_optional in selected:
        field_type = pydantic_field_type(attr_name, attr_info, annotations,
                                         is_optional, naming)
        field_obj = make_field(attr_info, is_optional,
                               can_omit_nullable=can_omit_nullable)
        fields[attr_name] = (field_type, field_obj)
    return fields
Exemplo n.º 2
0
 def __new__(mcs, name, bases, namespaces, **kwargs):
     """
     Iterate through fields and wrap them with typing.Optional type.

     The class's own annotations (resolved with ``resolve_annotations``) are
     merged with the ``__annotations__`` of its bases, so inherited fields are
     made optional too.

     Precedence:
       * The class's own annotations win over inherited ones, so a subclass
         that re-declares a field's type keeps its own type.
       * Among the bases, an earlier base wins over a later one.

     Names starting with ``"__"`` are left untouched. ``ClassVar`` annotations
     are also left untouched, because ``typing`` rejects
     ``Optional[ClassVar[...]]`` with a ``TypeError``.
     """
     from typing import ClassVar, get_origin

     own_annotations = resolve_annotations(namespaces.get("__annotations__", {}), namespaces.get("__module__", None))

     # Merge lowest priority first: last base, ..., first base, then the class itself.
     # (The previous `{**annotations, **base_annotations}` let bases override the class.)
     annotations = {}
     for base in reversed(bases):
         annotations.update(getattr(base, "__annotations__", {}))
     annotations.update(own_annotations)

     for field, annotation in annotations.items():
         if field.startswith("__"):
             continue
         if annotation is ClassVar or get_origin(annotation) is ClassVar:
             # Class-level constants are not fields; Optional[ClassVar[...]] is invalid anyway
             continue
         # Re-assigning an existing key does not change the dict's size, so this is safe mid-iteration
         annotations[field] = Optional[annotation]
     namespaces["__annotations__"] = annotations
     return super().__new__(mcs, name, bases, namespaces, **kwargs)
Exemplo n.º 3
0
 def __new__(cls, name, bases, namespace, **kwargs):  # noqa C901
     """
     Convert ``Annotated[...]`` field annotations into field-info defaults.

     For every resolved annotation whose origin is ``Annotated``, the field's
     namespace entry is replaced by ``FieldAnnotation.create_field_info``. That
     call receives the annotation and the field's current value, or
     ``Undefined`` when the field has no value.

     When ``_pydantic_is_annotated_aware`` is false, the raw annotation in
     ``namespace["__annotations__"]`` is also unwrapped to its root type.
     """
     resolved = resolve_annotations(namespace.get("__annotations__", {}),
                                    namespace.get("__module__", None))
     for field, annotation in resolved.items():
         if get_origin(annotation) is Annotated:
             current_value = namespace.get(field, Undefined)
             namespace[field] = FieldAnnotation.create_field_info(
                 annotation=annotation, value=current_value)
             # Pydantic doesn't yet support Annotated annotations [1], so the root type is
             # unwrapped. That blocks later inspection with `get_type_hints`, so it is only
             # done when pydantic cannot cope with Annotated itself.
             #
             # 1: https://github.com/samuelcolvin/pydantic/pull/2147
             if not _pydantic_is_annotated_aware:
                 root_type = get_args(annotation)[0]
                 namespace["__annotations__"][field] = root_type
     return super().__new__(cls, name, bases, namespace, **kwargs)
Exemplo n.º 4
0
    def __validate_cls_namespace__(name: str,
                                   namespace: Dict) -> None:  # noqa C901
        """Validate the class name space in place.

        Resolves the namespace annotations and checks that every field is
        annotated. It then substitutes each field type with the result of
        ``validate_type`` and builds the ODM field descriptors.

        Args:
            name: name of the class being created. It is used in error
                messages and passed to ``validate_config``.
            namespace: the class namespace; mutated in place. On success it holds:
                * ``__annotations__`` with the substituted types;
                * ``__odm_fields__``, ``__references__``,
                  ``__bson_serialized_fields__`` and ``__mutable_fields__``;
                * the validated ``Config``.
                ``ODMFieldInfo`` values are replaced by their
                ``pydantic_field_info``, and reference default values are deleted.

        Raises:
            TypeError: in any of these cases:
                * a field is defined without a type annotation;
                * a field uses ``pydantic.Field``;
                * a reference field has no ``ODMReferenceInfo`` value;
                * a plain default value does not parse as its annotated type;
                * two fields share a key_name.
                ``validate_config``, ``validate_type`` and
                ``raise_on_invalid_key_name`` may raise as well; their
                exception types are defined elsewhere.
        """
        # Resolve the declared annotations. `annotations` is updated in the
        # type-substitution loop below and written back to the namespace at the end.
        annotations = resolve_annotations(namespace.get("__annotations__", {}),
                                          namespace.get("__module__"))
        config = validate_config(namespace.get("Config", BaseODMConfig), name)
        # Accumulators; they are only stored in the namespace after every check has passed
        odm_fields: Dict[str, ODMBaseField] = {}
        references: List[str] = []
        bson_serialized_fields: Set[str] = set()
        mutable_fields: Set[str] = set()

        # Make sure all fields are defined with type annotation
        for field_name, value in namespace.items():
            if (should_touch_field(value=value) and not is_dunder(field_name)
                    and field_name not in annotations):
                raise TypeError(
                    f"field {field_name} is defined without type annotation")

        # Validate fields types and substitute bson fields
        # (only existing keys are re-assigned, so mutating `annotations` while iterating is safe)
        for (field_name, field_type) in annotations.items():
            if not is_dunder(field_name) and should_touch_field(
                    type_=field_type):
                substituted_type = validate_type(field_type)
                # Handle BSON serialized fields after substitution to allow some
                # builtin substitution
                bson_serialization_method = getattr(substituted_type,
                                                    "__bson__", None)
                if bson_serialization_method is not None:
                    bson_serialized_fields.add(field_name)
                annotations[field_name] = substituted_type

        # Validate fields
        # Each touched field gets an ODMEmbedded, ODMReference or ODMField descriptor
        for (field_name, field_type) in annotations.items():
            value = namespace.get(field_name, Undefined)

            if is_dunder(field_name) or not should_touch_field(
                    value, field_type):
                continue  # pragma: no cover
                # https://github.com/nedbat/coveragepy/issues/198

            if isinstance(value, PDFieldInfo):
                raise TypeError(
                    "please use odmantic.Field instead of pydantic.Field")

            if is_type_mutable(field_type):
                mutable_fields.add(field_name)

            if lenient_issubclass(field_type, EmbeddedModel):
                # Embedded document: key_name/primary_field come from the ODMFieldInfo if one was given.
                # NOTE(review): unlike the reference and plain-field branches,
                # raise_on_invalid_key_name is not called here — confirm this is intentional.
                if isinstance(value, ODMFieldInfo):
                    namespace[field_name] = value.pydantic_field_info
                    key_name = (value.key_name
                                if value.key_name is not None else field_name)
                    primary_field = value.primary_field
                else:
                    key_name = field_name
                    primary_field = False

                odm_fields[field_name] = ODMEmbedded(
                    primary_field=primary_field,
                    model=field_type,
                    key_name=key_name,
                    model_config=config,
                )
            elif lenient_issubclass(field_type, Model):
                # Reference to another Model: an ODMReferenceInfo value is mandatory
                if not isinstance(value, ODMReferenceInfo):
                    raise TypeError(
                        f"cannot define a reference {field_name} (in {name}) without"
                        " a Reference assigned to it")
                key_name = value.key_name if value.key_name is not None else field_name
                raise_on_invalid_key_name(key_name)
                odm_fields[field_name] = ODMReference(model=field_type,
                                                      key_name=key_name,
                                                      model_config=config)
                references.append(field_name)
                del namespace[
                    field_name]  # Remove default ODMReferenceInfo value
            else:
                # Plain field: either an ODMFieldInfo, no value at all, or a literal default
                if isinstance(value, ODMFieldInfo):
                    key_name = (value.key_name
                                if value.key_name is not None else field_name)
                    raise_on_invalid_key_name(key_name)
                    odm_fields[field_name] = ODMField(
                        primary_field=value.primary_field,
                        key_name=key_name,
                        model_config=config,
                    )
                    namespace[field_name] = value.pydantic_field_info

                elif value is Undefined:
                    odm_fields[field_name] = ODMField(primary_field=False,
                                                      key_name=field_name,
                                                      model_config=config)

                else:
                    # A literal default must be parseable as the annotated type.
                    # NOTE(review): the message interpolates the class `name`, not `field_name`.
                    try:
                        parse_obj_as(field_type, value)
                    except ValidationError:
                        raise TypeError(
                            f"Unhandled field definition {name}: {repr(field_type)}"
                            f" = {repr(value)}")
                    odm_fields[field_name] = ODMField(primary_field=False,
                                                      key_name=field_name,
                                                      model_config=config)

        duplicate_key = find_duplicate_key(odm_fields.values())
        if duplicate_key is not None:
            raise TypeError(f"Duplicated key_name: {duplicate_key} in {name}")
        # NOTE: Duplicate key detection make sur that at most one primary key is
        # defined
        # All checks passed: publish the results into the class namespace
        namespace["__annotations__"] = annotations
        namespace["__odm_fields__"] = odm_fields
        namespace["__references__"] = tuple(references)
        namespace["__bson_serialized_fields__"] = frozenset(
            bson_serialized_fields)
        namespace["__mutable_fields__"] = frozenset(mutable_fields)
        namespace["Config"] = config
Exemplo n.º 5
0
def test_resolve_annotations_no_module():
    """With no module name, a ForwardRef annotation comes back unchanged."""
    # TODO: is there a better test for this, can this case really happen?
    expected = {'Foo': ForwardRef('Foo')}
    resolved = resolve_annotations({'Foo': ForwardRef('Foo')}, None)
    assert resolved == expected
Exemplo n.º 6
0
        def __new__(mcs, name, bases, namespace, **kwargs):
            """Create a checked-session class and collect its fields.

            This reuses pydantic v1 helpers (``pydantic.main``, ``pydantic.fields``).

            Inheritance: config, fields (deep-copied) and validators come from
            every base that subclasses ``AbstractCheckedSession``, excluding
            ``AbstractCheckedSession`` itself. The class's own ``Config`` and
            validators are layered on top.

            Field collection: fields are inferred from annotations and from
            un-annotated public attributes. This step is skipped for
            ``larray.core.checked.CheckedSession`` itself.

            Returns:
                the class built from a namespace with ``__config__``,
                ``__fields__``, ``__field_defaults__`` and ``__validators__``,
                plus every original namespace entry that did not become a field.

            Raises:
                TypeError: an un-annotated attribute redefines an existing field
                    with a value whose inferred type differs from the field's type.
                    ``validate_field_name`` (pydantic) may also raise.
            """
            from pydantic.fields import Undefined
            from pydantic.class_validators import extract_validators, inherit_validators
            from pydantic.types import PyObject
            from pydantic.typing import is_classvar, resolve_annotations
            from pydantic.utils import lenient_issubclass, validate_field_name
            from pydantic.main import inherit_config, prepare_config, UNTOUCHED_TYPES

            fields: Dict[str, ModelField] = {}
            config = BaseConfig
            validators: Dict[str, List[Validator]] = {}

            # Inherit from checked-session bases, iterating bases last-to-first;
            # fields of earlier bases overwrite those of later ones via update().
            for base in reversed(bases):
                if issubclass(base, AbstractCheckedSession) and base != AbstractCheckedSession:
                    config = inherit_config(base.__config__, config)
                    fields.update(deepcopy(base.__fields__))
                    validators = inherit_validators(base.__validators__, validators)

            # The class's own Config and validators go on top of the inherited ones
            config = inherit_config(namespace.get('Config'), config)
            validators = inherit_validators(extract_validators(namespace), validators)

            # update fields inherited from base classes
            for field in fields.values():
                field.set_config(config)
                extra_validators = validators.get(field.name, [])
                if extra_validators:
                    field.class_validators.update(extra_validators)
                    # re-run prepare to add extra validators
                    field.populate_validators()

            prepare_config(config, name)

            # extract and build fields
            class_vars = set()
            # Skip field extraction for the CheckedSession class itself
            if (namespace.get('__module__'), namespace.get('__qualname__')) != \
                    ('larray.core.checked', 'CheckedSession'):
                untouched_types = UNTOUCHED_TYPES + config.keep_untouched

                # annotation only fields need to come first in fields
                annotations = resolve_annotations(namespace.get('__annotations__', {}),
                                                  namespace.get('__module__', None))
                for ann_name, ann_type in annotations.items():
                    if is_classvar(ann_type):
                        class_vars.add(ann_name)
                    elif not ann_name.startswith('_'):
                        validate_field_name(bases, ann_name)
                        value = namespace.get(ann_name, Undefined)
                        # Values of "untouched" types are not fields, unless annotated as PyObject or Type[...]
                        if (isinstance(value, untouched_types) and ann_type != PyObject
                                and not lenient_issubclass(getattr(ann_type, '__origin__', None), Type)):
                            continue
                        fields[ann_name] = ModelField.infer(name=ann_name, value=value, annotation=ann_type,
                                                            class_validators=validators.get(ann_name, []),
                                                            config=config)

                for var_name, value in namespace.items():
                    # 'var_name not in annotations' because namespace.items() contains annotated fields
                    # with default values
                    # 'var_name not in class_vars' to avoid to update a field if it was redeclared (by mistake)
                    if (var_name not in annotations and not var_name.startswith('_')
                            and not isinstance(value, untouched_types) and var_name not in class_vars):
                        validate_field_name(bases, var_name)
                        # the method ModelField.infer() fails to infer the type of Group objects
                        # (which are interpreted as ndarray objects)
                        # NOTE(review): this branch only runs for names not in `annotations`,
                        # so `annotations.get(var_name)` always evaluates to None here.
                        annotation = type(value) if isinstance(value, Group) else annotations.get(var_name)
                        inferred = ModelField.infer(name=var_name, value=value, annotation=annotation,
                                                    class_validators=validators.get(var_name, []), config=config)
                        if var_name in fields and inferred.type_ != fields[var_name].type_:
                            raise TypeError(f'The type of {name}.{var_name} differs from the new default value; '
                                            f'if you wish to change the type of this field, please use a type '
                                            f'annotation')
                        fields[var_name] = inferred

            # Field values are removed from the namespace; they live in __fields__ / __field_defaults__
            new_namespace = {
                '__config__': config,
                '__fields__': fields,
                '__field_defaults__': {n: f.default for n, f in fields.items() if not f.required},
                '__validators__': validators,
                **{n: v for n, v in namespace.items() if n not in fields},
            }
            return super().__new__(mcs, name, bases, new_namespace, **kwargs)