예제 #1
0
    def constr_wrapper(cls_: Type[ObjectT]) -> Type[ObjectT]:
        nonlocal constraints
        nonlocal values
        cdict = dict(cls_.__dict__)
        cdict.pop("__dict__", None)
        cdict.pop("__weakref__", None)
        constr_cls = _get_constraint_cls(cls_)
        if not constr_cls:
            raise TypeError(f"can't constrain type {cls_.__name__!r}")

        args: Tuple[Type, ...] = ()
        if values and constr_cls.type in {list, dict, set, tuple, frozenset}:  # type: ignore
            args = _handle_constraint_values(constraints, values, args)
        if constr_cls.type == dict:  # type: ignore
            args = _handle_constraint_keys(constraints, args)
        if args:
            cdict["__args__"] = args

        constraints_inst = constr_cls(**constraints)
        bases = inspect.getmro(cls_)

        def new(_new):
            @functools.wraps(_new)
            def __constrained_new(*args, **kwargs):
                result = _new(*args, **kwargs)
                return constraints_inst.validate(result)

            return __constrained_new

        def init(_init):
            @functools.wraps(_init)
            def __constrained_init(self, *args, **kwargs):
                _init(self, *args, **kwargs)
                constraints_inst.validate(self)

            return __constrained_init

        name = cls_.__name__
        if isbuiltintype(cls_):
            name = f"Constrained{cls_.__name__.capitalize()}"

        cdict.update(
            __constraints__=constraints_inst,
            __parent__=constraints_inst.type,
            __module__=cls_.__module__,
            **(
                {"__new__": new(cls_.__new__)}
                if constraints_inst.type in {str, bytes, int, float}
                else {"__init__": init(cls_.__init__)}
            ),
        )
        cls: Type[ObjectT] = cast(Type[ObjectT], type(name, bases, cdict))

        return cls
예제 #2
0
파일: typed.py 프로젝트: amilstead/typical
    def get_origin(cls, annotation: Any) -> Any:
        """Return the origin type behind ``annotation``, resolved recursively.

        Optional and ClassVar annotations are unwrapped by recursing into their
        first type argument. Anything that is not a builtin type is passed
        through ``cls._check_generics``.
        """
        # Resolve custom NewTypes first, then take the `__origin__` if present.
        resolved = checks.resolve_supertype(annotation)
        resolved = getattr(resolved, "__origin__", resolved)

        is_wrapper = checks.isoptionaltype(annotation) or checks.isclassvartype(
            annotation
        )
        if is_wrapper:
            inner = cls.get_args(annotation)
            if inner:
                return cls.get_origin(inner[0])
            return resolved

        if checks.isbuiltintype(resolved):
            return resolved
        # Provide defaults for generics.
        return cls._check_generics(resolved)
예제 #3
0
def _from_mapping_type(
    t: Type[Mapping], *, nullable: bool = False, name: str = None, cls: Type = None
) -> Union[MappingConstraints, DictConstraints]:
    """Build the constraints object for the mapping annotation ``t``.

    A builtin type yields a bare ``DictConstraints``. Otherwise the class is
    ``DictConstraints`` when the annotation's origin is ``dict`` and
    ``MappingConstraints`` for anything else. When ``t`` carries type
    arguments, they are resolved via ``_resolve_args`` and passed as the
    ``keys`` and ``values`` constraints.
    """
    common = {"nullable": nullable, "name": name}
    if isbuiltintype(t):
        return DictConstraints(**common)

    origin_type = getattr(t, "__origin__", t)
    factory: Union[Type[MappingConstraints], Type[DictConstraints]]
    factory = DictConstraints if origin_type is dict else MappingConstraints

    type_args = get_args(t)
    if not type_args:
        return factory(**common)

    key_arg, value_arg = type_args
    resolved_keys = _resolve_args(key_arg, cls=cls)
    resolved_values = _resolve_args(value_arg, cls=cls)
    return factory(  # type: ignore
        keys=resolved_keys, values=resolved_values, **common
    )
예제 #4
0
def origin(annotation: Any) -> Any:
    """Resolve ``annotation`` to its highest-order 'origin' type.

    Where the annotation can be resolved to a builtin type, the builtin is
    returned. Routines are reported as ``collections.abc.Callable``.

    Examples
    --------
    >>> import typic
    >>> from typing import Dict, Mapping, NewType, Optional
    >>> typic.origin(Dict)
    <class 'dict'>
    >>> typic.origin(Mapping)
    <class 'dict'>
    >>> Registry = NewType('Registry', Dict)
    >>> typic.origin(Registry)
    <class 'dict'>
    >>> class Foo: ...
    ...
    >>> typic.origin(Foo)
    <class 'typic.util.Foo'>
    """
    # Resolve custom NewTypes.
    resolved = resolve_supertype(annotation)

    # Unwrap a ClassVar down to its first argument, when it has one.
    if checks.isclassvartype(resolved):
        inner = get_args(resolved)
        if inner:
            resolved = inner[0]

    unwrapped = get_origin(resolved)
    if unwrapped:
        resolved = unwrapped

    # Provide defaults for generics.
    if not checks.isbuiltintype(resolved):
        resolved = _check_generics(resolved)

    return collections.abc.Callable if inspect.isroutine(resolved) else resolved
예제 #5
0
파일: des.py 프로젝트: seandstewart/typical
 def _set_checks(self, func: gen.Block, anno_name: str,
                 annotation: Annotation):
     """Emit the preamble and early-return checks of a generated deserializer.

     First writes either just the value-type line (``_add_vtype``) or the
     eval-then-type lines (``_add_eval``). Then, when the annotation is
     optional and/or has a default, writes an ``if <check>: return <value>``
     block so such inputs are returned untouched.

     Parameters
     ----------
     func
         The code block being written to.
     anno_name
         The name by which the target type is referenced in generated code.
     annotation
         The annotation the deserializer is being built for.
     """
     # Extra names handed to the generated `if` block (only `__default`).
     _ctx = {}
     # run a safe eval if input is text and anno isn't
     if inspect.isclass(annotation.resolved_origin) and (issubclass(
             annotation.resolved_origin,
         (str, bytes)) or checks.isdecimaltype(annotation.resolved_origin)):
         self._add_vtype(func)
     else:
         self._add_eval(func)
     # Equality checks for defaults and optionals
     # The target type exposes `.equals`; the generated code re-checks the
     # actual input value at runtime via its own `custom_equality` variable.
     custom_equality = hasattr(annotation.resolved_origin, "equals")
     if custom_equality and (annotation.optional or annotation.has_default):
         func.l(f"custom_equality = hasattr({self.VNAME}, 'equals')")
     # Expression text that is true when the input is one of the "null" values.
     null = ""
     if annotation.optional:
         null = f"{self.VNAME} in {self.resolver.OPTIONALS}"
         if custom_equality:
             null = (
                 f"(any({self.VNAME}.equals(o) for o in {self.resolver.OPTIONALS}) "
                 "if custom_equality "
                 f"else {null})")
     # Expression text that is true when the input equals the default.
     eq = ""
     if (annotation.has_default and annotation.parameter.default
             not in self.resolver.OPTIONALS):
         eq = f"{self.VNAME} == __default"
         if custom_equality:
             # Fallback used by the generated code when the input itself has
             # no `.equals`: ask the default, if the default has one.
             if hasattr(annotation.parameter.default, "equals"):
                 eq = f"__default.equals({self.VNAME})"
             eq = f"{self.VNAME}.equals(__default) if custom_equality else {eq}"
         _ctx["__default"] = annotation.parameter.default
     if eq or null:
         # Add a type-check for anything that isn't a builtin.
         if eq and not checks.isbuiltintype(annotation.resolved_origin):
             eq = f"{self.VTYPE} is {anno_name} and {eq}"
         check = " or ".join(c for c in (null, eq) if c)
         with func.b(f"if {check}:", **_ctx) as b:  # type: ignore
             b.l(f"return {self.VNAME}")
예제 #6
0
def test_isbuiltintype(obj: typing.Any):
    """The supplied ``obj`` must be recognised as a builtin type."""
    recognised = isbuiltintype(obj)
    assert recognised
예제 #7
0
파일: des.py 프로젝트: seandstewart/typical
class DesFactory:
    """A callable class for ``des``erializing values.

    Checks for:

            - builtin types
            - :py:mod:`typing` type annotations
            - :py:class:`datetime.date`
            - :py:class:`datetime.datetime`
            - :py:class:`typing.TypedDict`
            - :py:class:`typing.NamedTuple`
            - :py:func:`collections.namedtuple`
            - User-defined classes (limited)

    Examples
    --------
    >>> import typic
    >>> typic.transmute(bytes, "foo")
    b'foo'
    >>> typic.transmute(dict, "{'foo': 'bar'}")
    {'foo': 'bar'}
    """

    STRICT = STRICT_MODE
    DEFAULT_BYTE_ENCODING = "utf-8"
    UNRESOLVABLE = frozenset((
        Any,
        Match,
        re.Match,  # type: ignore
        type(None),
        _empty,
        Callable,
        abc.Callable,
    ))
    VNAME = "val"
    VTYPE = "vtype"
    __DES_CACHE: Dict[str, DeserializerT] = {}
    __USER_DESS: DeserializerRegistryT = deque()

    def __init__(self, resolver: Resolver):
        """Keep a reference to the ``resolver`` the build methods call into."""
        self.resolver = resolver

    def register(self, deserializer: DeserializerT, check: DeserializerT):
        """Register a user-defined coercer.

        In the rare case where typic can't figure out how to coerce your annotation
        correctly, a custom coercer may be registered alongside a check function which
        returns a simple boolean indicating whether this is the correct coercer for an
        annotation.

        Parameters
        ----------
        deserializer
            The custom coercer to register.
        check
            The function deciding whether ``deserializer`` applies to an annotation.

        Notes
        -----
        The ``(check, deserializer)`` pair is pushed onto the *front* of the
        registry deque. That deque is declared at class level, so registrations
        appear to be shared between instances.
        """
        self.__USER_DESS.appendleft((check, deserializer))

    def _set_checks(self, func: gen.Block, anno_name: str,
                    annotation: Annotation):
        """Emit the preamble and early-return checks of a generated deserializer.

        First writes either just the value-type line (``_add_vtype``) or the
        eval-then-type lines (``_add_eval``). Then, when the annotation is
        optional and/or has a default, writes an ``if <check>: return <value>``
        block so such inputs are returned untouched.

        Parameters
        ----------
        func
            The code block being written to.
        anno_name
            The name by which the target type is referenced in generated code.
        annotation
            The annotation the deserializer is being built for.
        """
        # Extra names handed to the generated `if` block (only `__default`).
        _ctx = {}
        # run a safe eval if input is text and anno isn't
        if inspect.isclass(annotation.resolved_origin) and (issubclass(
                annotation.resolved_origin,
            (str, bytes)) or checks.isdecimaltype(annotation.resolved_origin)):
            self._add_vtype(func)
        else:
            self._add_eval(func)
        # Equality checks for defaults and optionals
        # The target type exposes `.equals`; the generated code re-checks the
        # actual input value at runtime via its own `custom_equality` variable.
        custom_equality = hasattr(annotation.resolved_origin, "equals")
        if custom_equality and (annotation.optional or annotation.has_default):
            func.l(f"custom_equality = hasattr({self.VNAME}, 'equals')")
        # Expression text that is true when the input is one of the "null" values.
        null = ""
        if annotation.optional:
            null = f"{self.VNAME} in {self.resolver.OPTIONALS}"
            if custom_equality:
                null = (
                    f"(any({self.VNAME}.equals(o) for o in {self.resolver.OPTIONALS}) "
                    "if custom_equality "
                    f"else {null})")
        # Expression text that is true when the input equals the default.
        eq = ""
        if (annotation.has_default and annotation.parameter.default
                not in self.resolver.OPTIONALS):
            eq = f"{self.VNAME} == __default"
            if custom_equality:
                # Fallback used by the generated code when the input itself has
                # no `.equals`: ask the default, if the default has one.
                if hasattr(annotation.parameter.default, "equals"):
                    eq = f"__default.equals({self.VNAME})"
                eq = f"{self.VNAME}.equals(__default) if custom_equality else {eq}"
            _ctx["__default"] = annotation.parameter.default
        if eq or null:
            # Add a type-check for anything that isn't a builtin.
            if eq and not checks.isbuiltintype(annotation.resolved_origin):
                eq = f"{self.VTYPE} is {anno_name} and {eq}"
            check = " or ".join(c for c in (null, eq) if c)
            with func.b(f"if {check}:", **_ctx) as b:  # type: ignore
                b.l(f"return {self.VNAME}")

    @staticmethod
    def _get_name(annotation: Annotation) -> str:
        """Return the ``get_defname`` name for ``annotation`` using the "deserializer" prefix."""
        return get_defname("deserializer", annotation)

    def _build_date_des(self, context: BuildContext):
        """Write the ``date``/``datetime`` coercion branches into ``context.func``.

        Branches are emitted in this order: a numeric-timestamp branch and a
        text-parsing branch; then branches which rebuild an existing
        ``datetime``/``date`` value as the target type; then a second pair of
        numeric/text ``elif`` branches.
        """
        func = context.func
        annotation = context.annotation
        anno_name = context.anno_name
        val = self.VNAME
        target = annotation.resolved_origin

        # From an int
        with func.b(f"if isinstance({val}, (int, float)):") as block:
            block.l(f"{val} = {anno_name}.fromtimestamp({val})")
        # From a string
        with func.b(f"elif isinstance({val}, (str, bytes)):") as block:
            block.l(f"{val} = dateparse({val})", dateparse=dateparse)

        if issubclass(target, datetime.datetime):
            with func.b(
                f"if isinstance({val}, datetime):", datetime=datetime.datetime
            ) as block:
                # Use pendulum's helper if possible.
                if target is DateTime:
                    block.l(f"{val} = instance({val})", instance=instance)
                else:
                    parts = (
                        "year",
                        "month",
                        "day",
                        "hour",
                        "minute",
                        "second",
                        "microsecond",
                        "tzinfo",
                    )
                    ctor_args = ", ".join(f"{val}.{part}" for part in parts)
                    block.l(f"{val} = {anno_name}({ctor_args})")
            with func.b(
                f"elif isinstance({val}, date):", date=datetime.date
            ) as block:
                ctor_args = ", ".join(
                    f"{val}.{part}" for part in ("year", "month", "day")
                )
                block.l(f"{val} = {anno_name}({ctor_args})")
        elif issubclass(target, datetime.date):
            with func.b(
                f"if isinstance({val}, datetime):", datetime=datetime.datetime
            ) as block:
                block.l(f"{val} = {val}.date()")

        with func.b(f"elif isinstance({val}, (int, float)):") as block:
            block.l(f"{val} = {anno_name}.fromtimestamp({val})")
        with func.b(f"elif isinstance({val}, (str, bytes)):") as block:
            block.l(f"{val} = dateparse({val}, exact=True)", dateparse=dateparse)

    def _build_time_des(self, context: BuildContext):
        """Write the ``time`` coercion branches into ``context.func``.

        Emits, in order: a numeric branch, a text-parsing ``elif`` branch, and
        two independent ``if`` branches for ``datetime`` and ``date`` inputs.
        """
        func = context.func
        anno_name = context.anno_name
        val = self.VNAME

        # From an int
        with func.b(f"if isinstance({val}, (int, float)):") as block:
            block.l(f"{val} = {anno_name}(int({val}))")
        # From a string
        with func.b(f"elif isinstance({val}, (str, bytes)):") as block:
            block.l(f"{val} = dateparse({val}, exact=True)", dateparse=dateparse)
        # From a datetime
        with func.b(
            f"if isinstance({val}, datetime):", datetime=datetime.datetime
        ) as block:
            block.l(f"{val} = {val}.time()")
        # From a date
        with func.b(f"if isinstance({val}, date):", date=datetime.date) as block:
            block.l(f"{val} = {anno_name}(0)")

    def _build_timedelta_des(self, context: BuildContext):
        """Write the ``timedelta`` coercion branches into ``context.func``.

        Numeric input is passed to the target type as ``int(value)``; text
        input is handed to ``dateparse`` with ``exact=True``.
        """
        func = context.func
        anno_name = context.anno_name
        val = self.VNAME

        # From an int
        with func.b(f"if isinstance({val}, (int, float)):") as block:
            block.l(f"{val} = {anno_name}(int({val}))")
        # From a string
        with func.b(f"elif isinstance({val}, (str, bytes)):") as block:
            block.l(f"{val} = dateparse({val}, exact=True)", dateparse=dateparse)

    def _build_uuid_des(self, context: BuildContext):
        """Write the UUID coercion branches into ``context.func``.

        After the exact-type short-circuit, one branch is emitted per accepted
        input kind: another UUID (copied via its ``int``), ``str``, ``bytes``,
        ``int`` and ``tuple`` (passed as ``fields``).
        """
        func = context.func
        anno_name = context.anno_name
        self._add_type_check(func, anno_name)
        val = self.VNAME

        with func.b(f"if issubclass({self.VTYPE}, UUID):", UUID=uuid.UUID) as block:
            block.l(f"{val} = {anno_name}(int={val}.int)")

        # (input type name, constructor call) pairs, in emission order.
        conversions = (
            ("str", f"{anno_name}({val})"),
            ("bytes", f"{anno_name}(bytes={val})"),
            ("int", f"{anno_name}(int={val})"),
            ("tuple", f"{anno_name}(fields={val})"),
        )
        for type_name, constructor_call in conversions:
            with func.b(f"elif isinstance({val}, {type_name}):") as block:
                block.l(f"{val} = {constructor_call}")

    def _add_eval(self, func: gen.Block):
        """Emit a line that runs ``safe_eval`` on text input, then the type line.

        The generated line only calls ``__eval`` when the value is ``str`` or
        ``bytes``; any other value is passed through unchanged.
        """
        val = self.VNAME
        line = (
            f"_, {val} = __eval({val}) "
            f"if isinstance({val}, (str, bytes)) "
            f"else (False, {val})"
        )
        func.l(line, __eval=safe_eval)
        self._add_vtype(func)

    def _add_type_check(self, func: gen.Block, anno_name: str):
        """Emit a short-circuit that returns the value when its type is ``anno_name``."""
        guard = f"if {self.VTYPE} is {anno_name}:"
        with func.b(guard) as block:
            return_line = f"{gen.Keyword.RET} {self.VNAME}"
            block.l(return_line)

    def _add_vtype(self, func: gen.Block):
        """Emit the line binding the value-type variable to the value's ``__class__``."""
        func.l(f"{self.VTYPE} = {self.VNAME}.__class__")

    def _get_default_factory(self, annotation: "AnnotationT"):
        """Work out a default factory from the last type argument of ``annotation``.

        Returns ``None`` when ``annotation`` is not an ``Annotation``, has no
        arguments, or its last argument resolves to a
        ``ForwardDelayedAnnotation``. Otherwise returns the resolved type, or a
        small wrapper function around it (see the inline comments).
        """
        factory: Union[Type, Callable[..., Any], None] = None
        args: Tuple = annotation.args if isinstance(annotation,
                                                    Annotation) else tuple()
        if args:
            # Only the last type argument is consulted.
            factory_anno = self.resolver.annotation(args[-1])
            if isinstance(factory_anno, ForwardDelayedAnnotation):
                # `factory` is still None at this point.
                return factory
            elif isinstance(factory_anno, DelayedAnnotation):
                use = factory_anno.type
                raw = use
            else:
                use = factory_anno.resolved_origin
                raw = factory_anno.un_resolved
            # By default the resolved type itself is the factory.
            factory = use
            if issubclass(use, defaultdict):
                # Nested defaultdict: recurse to get the inner factory.
                factory_nested = self._get_default_factory(factory_anno)

                def factory():
                    return defaultdict(factory_nested)

                factory.__qualname__ = f"factory({repr(raw)})"  # type: ignore

            if not checks.isbuiltinsubtype(use):  # type: ignore

                params: Mapping[str, inspect.Parameter] = cached_signature(
                    use).parameters
                # Only wrap when every parameter has a default, i.e. the type
                # can be called with no arguments.
                if not any(p.default is p.empty for p in params.values()):

                    def factory(*, __origin=use):
                        return __origin()

                    factory.__qualname__ = f"factory({repr(raw)})"  # type: ignore

        return factory

    def _build_text_des(self, context: BuildContext):
        """Write the ``str``/``bytes`` coercion lines into ``context.func``.

        For a ``bytes`` target, ``str`` input is encoded; for a ``str`` target,
        ``bytes``/``bytearray`` input is decoded. Both use ``DEFAULT_ENCODING``.
        The value is then always passed through the target type.
        """
        func = context.func
        annotation = context.annotation
        anno_name = context.anno_name
        val = self.VNAME
        target = annotation.resolved_origin

        if issubclass(target, bytes):
            # Encode for bytes
            with func.b(f"if isinstance({val}, str):") as block:
                block.l(f"{val} = {anno_name}({val}, encoding={DEFAULT_ENCODING!r})")
        elif issubclass(target, str):
            # Decode for str
            with func.b(f"if isinstance({val}, (bytes, bytearray)):") as block:
                block.l(f"{val} = {val}.decode({DEFAULT_ENCODING!r})")

        func.l(f"{val} = {anno_name}({val})")

    def _build_builtin_des(self, context: BuildContext):
        """Dispatch to the right builder for a builtin target type.

        Text, mapping and collection targets are delegated to their dedicated
        builders; anything else gets a plain ``value = Type(value)`` line.
        """
        func = context.func
        anno_name = context.anno_name
        target = context.annotation.resolved_origin

        if issubclass(target, (str, bytes)):
            self._build_text_des(context)
            return
        if checks.ismappingtype(target):
            self._build_mapping_des(context)
            return
        if checks.iscollectiontype(target):
            self._build_collection_des(context)
            return
        # bool, int, float...
        func.l(f"{self.VNAME} = {anno_name}({self.VNAME})")

    def _build_pattern_des(self, context: BuildContext):
        """Emit a line that compiles the value with ``re.compile`` unless it already is the target type."""
        func = context.func
        anno_name = context.anno_name
        val = self.VNAME
        line = (
            f"{val} = {val} "
            f"if issubclass({self.VTYPE}, {anno_name}) "
            f"else __re_compile({val})"
        )
        func.l(line, __re_compile=re.compile)

    def _build_fromdict_des(self, func: gen.Block, anno_name: str):
        """Emit the exact-type short-circuit followed by a ``from_dict`` call on the target type."""
        self._add_type_check(func, anno_name)
        val = self.VNAME
        conversion = f"{val} = {anno_name}.from_dict({val})"
        func.l(conversion)

    def _build_typeddict_des(self, context: BuildContext):
        """Write the TypedDict-style coercion branches into ``context.func``.

        Mapping input is rebuilt key by key, running each value through its
        field deserializer when any exist. If the target's ``__total__`` is
        false, only keys present in the input are used. Any other input is
        handed to the resolver's ``translate``.
        """
        func = context.func
        annotation = context.annotation
        namespace = context.namespace
        anno_name = context.anno_name
        total = getattr(context.annotation.resolved_origin, "__total__", True)
        val = self.VNAME

        with func.b(
            f"if issubclass({self.VTYPE}, Mapping):", Mapping=abc.Mapping
        ) as block:
            fields_deser = {}
            for field_name, field_anno in annotation.serde.fields.items():
                protocol = self.resolver._resolve_from_annotation(
                    field_anno, namespace=namespace
                )
                fields_deser[field_name] = protocol.transmute

            key_expr = "fields_in[x]"
            if fields_deser:
                value_expr = f"fields_deser[x]({val}[x])"
            else:
                value_expr = f"{val}[x]"
            head = f"{{{key_expr}: {value_expr} for x in fields_in.keys()"
            # Partial dicts restrict the keys to those present in the input.
            tail = "}" if total else f"& {val}.keys()}}"
            block.l(
                f"{val} = {anno_name}(**{head}{tail})", fields_deser=fields_deser
            )

        with func.b("else:") as block:
            block.l(
                f"{val} = translate({val}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_typedtuple_des(self, context: BuildContext):
        """Write the typed-tuple coercion branches into ``context.func``.

        Mapping input is unpacked as keyword arguments, via the TypedDict
        builder when field annotations exist. ``list``/``set``/``frozenset``/
        ``tuple`` input is unpacked positionally, via the resolver's ``bind``
        when field annotations exist. Everything else goes to ``translate``.
        """
        func = context.func
        annotation = context.annotation
        anno_name = context.anno_name
        val = self.VNAME

        with func.b(
            f"if issubclass({self.VTYPE}, Mapping):", Mapping=abc.Mapping
        ) as block:
            if annotation.serde.fields:
                # Re-use the TypedDict builder, writing into this branch.
                self._build_typeddict_des(dataclasses.replace(context, func=block))
            else:
                block.l(f"{val} = {anno_name}(**{val})")

        with func.b(
            f"elif isinstance({val}, (list, set, frozenset, tuple)):"
        ) as block:
            if annotation.serde.fields:
                block.l(
                    f"{val} = __bind({anno_name}, *{val}).eval()",
                    __bind=self.resolver.bind,
                )
            else:
                block.l(f"{val} = {anno_name}(*{val})")

        with func.b("else:") as block:
            block.l(
                f"{val} = translate({val}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_mapping_des(self, context: BuildContext):
        """Write the mapping coercion branches into ``context.func``.

        Builds two candidate expressions: one iterating the input as key/value
        pairs and one iterating it with ``values=True``. It then emits code
        which uses the former for mapping inputs and, for anything else, tries
        the latter first and falls back to the former on ``TypeError`` or
        ``ValueError``.
        """
        func, annotation, namespace, anno_name = (
            context.func,
            context.annotation,
            context.namespace,
            context.anno_name,
        )
        # Deserializers for keys and items; stay None without type arguments.
        key_des, item_des = None, None
        args = annotation.args
        if args:
            args = cast(Tuple[Type, Type], args)
            key_type, item_type = args
            key_des = self.resolver.resolve(key_type,
                                            flags=annotation.serde.flags,
                                            namespace=namespace)
            item_des = self.resolver.resolve(item_type,
                                             flags=annotation.serde.flags,
                                             namespace=namespace)
        if issubclass(annotation.resolved_origin, defaultdict):
            # Rebind the target name so generated calls build a defaultdict
            # with the computed default factory.
            factory = self._get_default_factory(annotation)
            func.namespace[anno_name] = functools.partial(defaultdict, factory)
        # Names under which the key/item deserializers are exposed to generated code.
        kd_name = f"{anno_name}_key_des"
        it_name = f"{anno_name}_item_des"
        iterate = f"iterate({self.VNAME})"
        iterate_values = f"iterate({self.VNAME}, values=True)"
        line = f"{anno_name}({iterate})"
        line_values = f"{anno_name}({iterate_values})"
        if args or annotation.serde.fields_in:
            x, y = "x", "y"
            # If there are args & field mapping, get the correct field name
            # AND serialize the key.
            if args and annotation.serde.fields_in:
                x = f"{kd_name}(fields_in.get(x, x))"
            # If there is only a field mapping, get the correct name for the field.
            elif annotation.serde.fields_in:
                x = "fields_in.get(x, x)"
            # If there are only serializers, get the serialized value
            elif args:
                x = f"{kd_name}(x)"
                y = f"{it_name}(y)"
            line = f"{anno_name}({{{x}: {y} for x, y in {iterate}}})"
            line_values = f"{anno_name}({{{x}: {y} for x, y in {iterate_values}}})"
        # If we don't have nested annotations, we can short-circuit on valid inputs
        else:
            self._add_type_check(func, anno_name)
        # Write the lines.
        with func.b(f"if ismappingtype({self.VTYPE}):") as b:
            b.l(f"{self.VNAME} = {line}")
        with func.b("else:") as ob:
            with ob.b("try:") as b:
                b.l(f"{self.VNAME} = {line_values}")
            with ob.b("except (TypeError, ValueError):") as b:
                b.l(f"{self.VNAME} = {line}")
        # Expose every helper referenced by the generated lines.
        func.namespace.update({
            kd_name: key_des,
            it_name: item_des,
            "Mapping": abc.Mapping,
            "iterate": self.resolver.iterate,
            "ismappingtype": checks.ismappingtype,
        })

    def _build_tuple_des(self, context: BuildContext):
        """Write the tuple coercion line into ``context.func``.

        A tuple with arguments not ending in ``...`` gets one deserializer per
        position, and positions without a deserializer are dropped. Every other
        tuple annotation is handled by the generic collection builder.
        """
        func = context.func
        annotation = context.annotation
        namespace = context.namespace
        anno_name = context.anno_name

        if not annotation.args or annotation.args[-1] is ...:
            self._build_collection_des(context)
            return

        # One deserializer per tuple position.
        item_des = {}
        for ix, item_type in enumerate(annotation.args):
            item_des[ix] = self.resolver.resolve(
                item_type, flags=annotation.serde.flags, namespace=namespace
            )

        source = f"iterate({self.VNAME}, values=True)"
        expression = (
            f"{anno_name}"
            f"(item_des[ix](v) for ix, v in enumerate({source})"
            f"if ix in item_des)"
        )
        func.l(
            f"{self.VNAME} = {expression}",
            level=None,
            item_des=item_des,
            iterate=self.resolver.iterate,
        )

    def _build_collection_des(self, context: BuildContext):
        """Write the collection coercion line into ``context.func``.

        With a type argument, each item is run through the deserializer
        resolved for the first argument. Without one, an exact-type
        short-circuit is emitted and the values are passed to the target type
        as-is.
        """
        func = context.func
        annotation = context.annotation
        namespace = context.namespace
        anno_name = context.anno_name
        val = self.VNAME

        des_name = f"{anno_name}_item_des"
        source = f"iterate({val}, values=True)"
        if annotation.args:
            item_des = self.resolver.resolve(
                annotation.args[0],
                flags=annotation.serde.flags,
                namespace=namespace,
            )
            line = f"{val} = {anno_name}({des_name}(x) for x in parent({source}))"
        else:
            item_des = None
            self._add_type_check(func, anno_name)
            line = f"{val} = {anno_name}({source})"

        # Names made available to the generated line.
        extras = {
            des_name: item_des,
            "Collection": abc.Collection,
            "iterate": self.resolver.iterate,
        }
        func.l(line, level=None, **extras)

    def _build_decimal_des(self, context: BuildContext):
        """Emit the exact-type short-circuit, then a plain ``value = Type(value)`` line."""
        func = context.func
        anno_name = context.anno_name
        self._add_type_check(func, anno_name)
        conversion = f"{self.VNAME} = {anno_name}({self.VNAME})"
        func.l(conversion)

    def _build_path_des(self, context: BuildContext):
        """Emit the exact-type short-circuit, then a plain ``value = Type(value)`` line."""
        func = context.func
        anno_name = context.anno_name
        self._add_type_check(func, anno_name)
        conversion = f"{self.VNAME} = {anno_name}({self.VNAME})"
        func.l(conversion)

    def _build_user_type_des(self, context: BuildContext):
        """Write the coercion branches for a user-defined class into ``context.func``.

        Three branches are emitted after the exact-type short-circuit:

        1. mapping input: build the object from the known input fields;
        2. builtin-subtype input that is not a named tuple: call the type on it;
        3. anything else: hand it to the resolver's ``translate``.
        """
        func, annotation, namespace, anno_name = (
            context.func,
            context.annotation,
            context.namespace,
            context.anno_name,
        )
        serde = annotation.serde
        resolved = annotation.resolved
        self._add_type_check(func, anno_name)
        # Main branch - we have a mapping for a user-defined class.
        # This is where the serde configuration comes in.
        # WINDY PATH AHEAD
        func.l("# Happy path - deserialize a mapping into the object.")
        with func.b(f"if ismappingtype({self.VTYPE}):",
                    ismappingtype=checks.ismappingtype) as b:
            # Universal line - transform input to known keys/values.
            # Specific values may change.
            def mainline(k, v):
                return f"{{{k}: {v} for x in fields_in.keys() & {self.VNAME}.keys()}}"

            # The "happy path" - e.g., no guesswork needed.
            def happypath(k, v, **ns):
                b.l(f"{self.VNAME} = {anno_name}(**{mainline(k, v)})", **ns)

            # Default X - translate given `x` to known input `x`
            x = "fields_in[x]"
            # No field name translation needs to happen.
            if {*serde.fields_in.keys()} == {*serde.fields_in.values()}:
                x = "x"

            # Default Y - get the given `y` with the given `x`
            y = f"{self.VNAME}[x]"
            # Get the intersection of known input fields and annotations.
            matched = {*serde.fields_in.values()} & serde.fields.keys()
            # Happy path! This is a `@typic.al` wrapped class.
            if self.resolver.known(resolved) or self.resolver.delayed(
                    resolved):
                happypath(x, y)
            # Secondary happy path! We know how to deserialize already.
            else:
                fields_in = serde.fields_in
                fnamespace = namespace or resolved
                # Use the serde field annotations only when every known input
                # field has one; otherwise fall back to the resolver's
                # protocols for the class.
                if serde.fields and len(matched) == len(serde.fields_in):
                    desers = {
                        f: self.resolver._resolve_from_annotation(
                            serde.fields[f], namespace=fnamespace).transmute
                        for f in matched
                    }
                else:
                    protocols = self.resolver.protocols(
                        annotation.resolved_origin)
                    fields_in = {x: x for x in protocols}
                    desers = {f: p.transmute for f, p in protocols.items()}
                # Run each value through the deserializer for its field.
                y = f"desers[{x}]({self.VNAME}[x])"
                happypath(x, y, desers=desers, fields_in=fields_in)

        # Secondary branch - we have some other input for a user-defined class
        func.l("# Unknown path, just try casting it directly.")
        with func.b(
                f"elif isbuiltinsubtype({self.VTYPE}) and not isnamedtuple({self.VTYPE}):",
                isbuiltinsubtype=checks.isbuiltinsubtype,
                isnamedtuple=checks.isnamedtuple,
        ) as b:
            b.l(f"{self.VNAME} = {anno_name}({self.VNAME})")
        # Final branch - user-defined class for another user-defined class
        func.l("# Two user-defined types, "
               "try to translate the input into the desired output.")
        with func.b("else:") as b:
            b.l(
                f"{self.VNAME} = translate({self.VNAME}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_literal_des(
        self, annotation: Annotation, func_name: str, namespace: Type = None
    ):
        """Build a deserializer for a Literal annotation from its values' types.

        The classes of the literal's arguments are collected. A single class is
        used directly; several are combined into a ``Union``. An annotation for
        that type, carrying over the name, optionality, strictness, flags and
        default of ``annotation``, is then handed to ``_build_des``.
        """
        literal_types: Set[Type] = {arg.__class__ for arg in annotation.args}
        if len(literal_types) == 1:
            target = literal_types.pop()
        else:
            target = Union[tuple(literal_types)]

        target_anno = cast(
            Annotation,
            self.resolver.annotation(
                target,  # type: ignore
                name=annotation.parameter.name,
                is_optional=annotation.optional,
                is_strict=annotation.strict,
                flags=annotation.serde.flags,
                default=annotation.parameter.default,
            ),
        )
        return self._build_des(target_anno, func_name, namespace)

    def _build_union_des(self, context: BuildContext):
        """Write the coercion code for a Union annotation into ``context.func``.

        Emits a short-circuit for inputs whose type is already a member of the
        union (``str`` and ``bytes`` excluded). It then builds a tagged-union
        deserializer when ``get_tag_for_types`` finds a tag for the members and
        none of them is a stdlib type; otherwise it delegates to
        ``_build_generic_union_des``. Nothing is emitted when the union has no
        members other than ``None``/``Ellipsis``/``NoneType``.
        """
        func, annotation, namespace = (
            context.func,
            context.annotation,
            context.namespace,
        )
        # Get all types which we may coerce to.
        args = (*(a for a in annotation.args
                  if a not in {None, Ellipsis, type(None)}), )
        if not args:
            return
        # Add a type-check, but exclude str|bytes, since those are too permissive.
        types = {a for a in args if a not in {str, bytes}}
        if types:
            with func.b(f"if {self.VTYPE} in types:", types=types) as b:
                b.l(f"return {self.VNAME}")
        # Get all custom types, which may have discriminators
        targets = (*(a for a in args if not checks.isstdlibtype(a)), )
        # We can only build a tagged union deserializer if all args are valid
        if args != targets:
            return self._build_generic_union_des(context)

        # Try to collect the field which will be the discriminator.
        # First, get a mapping of Type -> Proto & Type -> Fields
        tagged = get_tag_for_types(targets)
        # Just bail out if we can't find a key.
        if not tagged:
            return self._build_generic_union_des(context)
        # If we got a key, re-map the protocols to the value for each type.
        deserializers = {
            value: self.resolver.resolve(t, namespace=namespace).transmute
            for value, t in tagged.types_by_values
        }
        # Finally, build the deserializer
        func.namespace.update(
            tag=tagged.tag,
            desers=deserializers,
            empty=_empty,
        )
        # Read the tag from a mapping key, or from an attribute otherwise.
        with func.b(f"if issubclass({self.VTYPE}, Mapping):",
                    Mapping=abc.Mapping) as b:
            b.l(f"tag_value = {self.VNAME}.get(tag, empty)")
        with func.b("else:") as b:
            b.l(f"tag_value = getattr({self.VNAME}, tag, empty)")
        # Dispatch on the tag value; unknown or missing tags are an error.
        with func.b("if tag_value in desers:") as b:
            b.l(f"{self.VNAME} = desers[tag_value]({self.VNAME})")
        with func.b("else:") as b:
            b.l("raise ValueError("
                'f"Value is missing field {tag!r} with one of '
                '{(*desers,)}: {val!r}"'
                ")")

    def _build_generic_union_des(self, context: BuildContext):
        """Emit a best-effort deserializer body for an untagged ``Union``.

        Each member of the union (``None``, ``Ellipsis`` and ``NoneType``
        excluded) is resolved and keyed by ``get_name``. The emitted code
        first tries a subclass check per member (skipping ``Literal`` and
        ``Final``), then tries every member deserializer in turn, swallowing
        ``TypeError``, ``ValueError`` and ``KeyError``, and finally raises
        ``ValueError`` if nothing succeeded.

        Args:
            context: Build context supplying the code block being written
                (``func``), the annotation and the resolution namespace.

        Returns:
            ``False`` when code was emitted (the emitted body ends in a
            ``raise``, so the caller must not append a return statement);
            ``None`` when the union has no usable members.
        """
        annotation, namespace, func = (
            context.annotation,
            context.namespace,
            context.func,
        )
        ignored = {None, Ellipsis, type(None)}
        annos = {
            get_name(a): self.resolver.resolve(a, namespace=namespace)
            for a in annotation.args
            if a not in ignored
        }
        if not annos:
            return None

        desers = {f"{n}_des": p.transmute for n, p in annos.items()}
        types = {n: p.annotation.resolved_origin for n, p in annos.items()}
        ctx: Mapping[str, Union[Type, DeserializerT]] = {**types, **desers}
        for name in annos:
            # Can't do subclass checks with these...
            if name in {"Literal", "Final"}:
                continue
            # NOTE(review): the emitted check is
            # issubclass(<member type>, <value type>) -- confirm this
            # argument order is the intended one.
            with func.b(f"if issubclass({name}, {self.VTYPE}):") as b:
                b.l(f"return {name}_des({self.VNAME})")
        for name in desers:
            with func.b("try:") as b:
                b.l(f"return {name}({self.VNAME})")
            with func.b("except (TypeError, ValueError, KeyError):") as b:
                b.l("pass")
        func.namespace.update(ctx)
        # Interpolate the value's variable name instead of hard-coding it,
        # so the emitted message refers to the same name as every other
        # emitted line.
        func.l(
            "raise ValueError("
            f'f"Value could not be deserialized into one of {(*annos,)}: '
            f'{{{self.VNAME}!r}}"'
            ")", )
        return False

    def _build_des(  # noqa: C901
        self,
        annotation: Annotation[Type[ObjectT]],
        func_name: str,
        namespace: Type = None,
    ) -> DeserializerT[ObjectT]:
        """Generate and compile the deserializer function for ``annotation``.

        ``Literal`` annotations are delegated to ``_build_literal_des``. For
        everything else a function named ``func_name`` is generated; unless
        the origin is listed in ``UNRESOLVABLE``, the common checks are
        emitted and the first entry of ``_HANDLERS`` whose predicate accepts
        ``(origin, args)`` writes the body. A trailing return statement is
        appended unless that handler returned ``False``.

        Args:
            annotation: The annotation to build a deserializer for.
            func_name: Name of the generated function.
            namespace: Passed through to the handlers via the build context.

        Returns:
            The compiled deserializer.
        """
        args = annotation.args
        # The "origin" is a builtin type for natives and their typing.*
        # equivalents, the un-subscripted form for SpecialForms (mainly
        # Union), and the annotation itself for custom types or classes.
        origin = annotation.resolved_origin
        anno_name = get_unique_name(origin)
        ns = {
            anno_name: origin,
            "parent": getattr(origin, "__parent__", origin),
            "issubclass": cached_issubclass,
            **annotation.serde.asdict(),
        }
        if checks.isliteral(origin):
            return self._build_literal_des(annotation, func_name, namespace)
        with gen.Block(ns) as block:
            with block.f(func_name, block.param(f"{self.VNAME}")) as body:
                context = BuildContext(annotation, ns, anno_name, body,
                                       namespace)
                needs_return = None
                if origin not in self.UNRESOLVABLE:
                    # Set our top-level sanity checks.
                    self._set_checks(body, anno_name, annotation)
                    # First handler whose predicate accepts this type wins.
                    handler = next(
                        (build for accepts, build in self._HANDLERS.items()
                         if accepts(origin, args)),
                        None,
                    )
                    if handler is not None:
                        needs_return = handler(self, context)
                # A handler signals that it already terminated the body by
                # returning False; otherwise close it with a return.
                if needs_return is not False:
                    body.l(f"{gen.Keyword.RET} {self.VNAME}")
        return block.compile(ns=ns, name=func_name)

    # Dispatch table mapping a predicate over ``(origin, args)`` to the
    # builder method that writes the deserializer body for matching types.
    # Order is IMPORTANT! ``_build_des`` walks the entries in insertion order
    # and stops at the first predicate that returns a truthy value.
    _HANDLERS: Mapping[HandlerCheckT, BuildHandlerT] = {
        # Special handler for Unions...
        lambda origin, args: checks.isuniontype(origin):
        _build_union_des,
        # Non-intersecting types (order doesn't matter here).
        lambda origin, args: checks.isdatetype(origin):
        _build_date_des,
        lambda origin, args: checks.istimetype(origin):
        _build_time_des,
        lambda origin, args: checks.istimedeltatype(origin):
        _build_timedelta_des,
        lambda origin, args: checks.isuuidtype(origin):
        _build_uuid_des,
        lambda origin, args: origin in {Pattern, re.Pattern}:
        _build_pattern_des,
        lambda origin, args: issubclass(origin, pathlib.Path):
        _build_path_des,
        lambda origin, args: checks.isdecimaltype(origin):
        _build_decimal_des,
        # MUST come before subtype check.
        lambda origin, args: (not args and checks.isbuiltintype(origin)):
        _build_builtin_des,
        # Pseudo-structured containers, should check before generics.
        lambda origin, args: checks.istypeddict(origin):
        _build_typeddict_des,
        lambda origin, args: checks.istypedtuple(origin):
        _build_typedtuple_des,
        lambda origin, args: checks.isnamedtuple(origin):
        _build_typedtuple_des,
        # Un-subscripted subclasses of builtins share the builtin handler.
        lambda origin, args: (not args and checks.isbuiltinsubtype(origin)):
        _build_builtin_des,
        # A mapping is a collection so must come before that check.
        lambda origin, args: checks.ismappingtype(origin):
        _build_mapping_des,
        # A tuple is a collection so must come before that check.
        lambda origin, args: checks.istupletype(origin):
        _build_tuple_des,
        # Generic collection handler
        lambda origin, args: checks.iscollectiontype(origin):
        _build_collection_des,
        # Catch-all for custom user types (user-defined classes).
        lambda origin, args: True:
        _build_user_type_des,
    }

    def factory(
        self,
        annotation: Annotation[Type[ObjectT]],
        namespace: Type = None,
    ) -> DeserializerT[ObjectT]:
        """Return the deserializer for ``annotation``, creating it on demand.

        ``annotation.serde`` is replaced with a fresh ``SerdeConfig`` when it
        is falsy. The deserializer is looked up in the cache under the name
        produced by ``_get_name``; on a miss, the first registered user
        deserializer whose check accepts ``annotation.resolved`` is used, and
        if there is none a new one is built with ``_build_des``. The result
        is stored in the cache before being returned.

        Args:
            annotation: The annotation to fetch or build a deserializer for.
            namespace: Forwarded to ``_build_des`` when building.

        Returns:
            The cached, user-supplied or newly built deserializer.
        """
        annotation.serde = annotation.serde or SerdeConfig()
        key = self._get_name(annotation)
        if key in self.__DES_CACHE:
            return self.__DES_CACHE[key]
        # User-registered deserializers take precedence over generated ones.
        user_des: Optional[DeserializerT] = next(
            (des for check, des in self.__USER_DESS
             if check(annotation.resolved)),
            None,
        )
        deserializer = user_des or self._build_des(annotation, key, namespace)
        self.__DES_CACHE[key] = deserializer
        return deserializer
예제 #8
0
 def _build_des(  # noqa: C901
         self,
         annotation: "Annotation",
         func_name: str,
         namespace: Type = None) -> Callable:
     """Generate and compile a deserializer function for ``annotation``.

     ``Literal`` annotations are delegated to ``_build_literal_des``. For
     everything else a function named ``func_name`` taking a single
     parameter (``self.VNAME``) is generated. Unless the origin is listed in
     ``self.UNRESOLVABLE``, the common checks are emitted and the first
     matching branch below writes the type-specific body. The generated
     function always ends by returning ``self.VNAME``.

     Args:
         annotation: The annotation to build a deserializer for.
         func_name: Name of the generated function.
         namespace: Forwarded to the builders that accept it.

     Returns:
         The compiled deserializer.
     """
     args = annotation.args
     # Get the "origin" of the annotation.
     # For natives and their typing.* equivs, this will be a builtin type.
     # For SpecialForms (Union, mainly) this will be the un-subscripted type.
     # For custom types or classes, this will be the same as the annotation.
     origin = annotation.resolved_origin
     anno_name = get_unique_name(origin)
     # Namespace handed to the code generator and to ``compile`` below.
     ns = {
         anno_name: origin,
         "parent": getattr(origin, "__parent__", origin),
         "issubclass": cached_issubclass,
         **annotation.serde.asdict(),
     }
     # Literals are delegated wholesale; no function is generated here.
     if checks.isliteral(origin):
         return self._build_literal_des(annotation, func_name, namespace)
     with gen.Block(ns) as main:
         with main.f(func_name, main.param(f"{self.VNAME}")) as func:
             if origin not in self.UNRESOLVABLE:
                 self._set_checks(func, anno_name, annotation)
                 # The first matching branch wins.
                 if origin is Union:
                     self._build_union_des(func, annotation, namespace)
                 elif checks.isdatetype(origin):
                     self._build_date_des(func, anno_name, annotation)
                 elif checks.istimetype(origin):
                     self._build_time_des(func, anno_name, annotation)
                 elif checks.istimedeltatype(origin):
                     self._build_timedelta_des(func, anno_name, annotation)
                 elif checks.isuuidtype(origin):
                     self._build_uuid_des(func, anno_name, annotation)
                 elif origin in {Pattern, re.Pattern}:  # type: ignore
                     self._build_pattern_des(func, anno_name)
                 elif issubclass(origin, pathlib.Path):
                     self._build_path_des(func, anno_name)
                 # Only un-subscripted builtins take the builtin path.
                 elif not args and checks.isbuiltintype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 elif checks.isfromdictclass(origin):
                     self._build_fromdict_des(func, anno_name)
                 # Enums share the builtin deserializer.
                 elif checks.isenumtype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 elif checks.istypeddict(origin):
                     self._build_typeddict_des(
                         func,
                         anno_name,
                         annotation,
                         total=origin.__total__,  # type: ignore
                         namespace=namespace,
                     )
                 elif checks.istypedtuple(origin) or checks.isnamedtuple(
                         origin):
                     self._build_typedtuple_des(func,
                                                anno_name,
                                                annotation,
                                                namespace=namespace)
                 elif not args and checks.isbuiltinsubtype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 elif checks.ismappingtype(origin):
                     self._build_mapping_des(func,
                                             anno_name,
                                             annotation,
                                             namespace=namespace)
                 elif checks.istupletype(origin):
                     self._build_tuple_des(func,
                                           anno_name,
                                           annotation,
                                           namespace=namespace)
                 elif checks.iscollectiontype(origin):
                     self._build_collection_des(func,
                                                anno_name,
                                                annotation,
                                                namespace=namespace)
                 # Anything left over goes to the generic builder.
                 else:
                     self._build_generic_des(func,
                                             anno_name,
                                             annotation,
                                             namespace=namespace)
             # Every generated function ends by returning the value variable.
             func.l(f"{gen.Keyword.RET} {self.VNAME}")
     deserializer = main.compile(ns=ns, name=func_name)
     return deserializer