Exemplo n.º 1
0
 def _add_type_check(self, func: gen.Function, annotation: Annotation):
     """Emit a guard that raises if ``o`` is not of the annotated type.

     Builtin subtypes are checked with ``isinstance`` against the value
     itself; every other type is checked with ``util.cached_issubclass``
     against the value's class. TypedDict targets are checked as ``dict``.
     The raised error is ``SerializationValueError`` and suggests widening
     the annotation to a ``Union``.
     """
     func.l(f"{self._FNAME} = name or {util.get_name(annotation.resolved)!r}")
     target = annotation.generic
     if checks.isbuiltinsubtype(target):
         guard = "if not tcheck(o, t):"
         checker: Callable[..., bool] = isinstance  # type: ignore
     else:
         guard = "if not tcheck(o.__class__, t):"
         checker = util.cached_issubclass
     # TypedDict classes cannot be used as isinstance/issubclass targets;
     # their runtime instances are plain dicts.
     if checks.istypeddict(target):
         target = dict
     with func.b(guard, tcheck=checker) as block:
         expected = util.get_qualname(annotation.generic)
         msg = (
             f"{{{self._FNAME}}}: type {{inst_tname!r}} "
             f"is not a subtype of type {expected!r}. "
             f"Perhaps this annotation should be "
             f"Union[{{inst_tname}}, {expected}]?"
         )
         block.l("inst_tname = qualname(o.__class__)")
         block.l(
             f"raise err(f{msg!r})",
             err=SerializationValueError,
             qualname=util.get_qualname,
             t=target,
         )
Exemplo n.º 2
0
def signature(obj: Union[Callable, Type]) -> inspect.Signature:
    """Return the signature of a type or callable.

    TypedDict subclasses are handled by ``typed_dict_signature``; any other
    object is passed straight to :func:`inspect.signature`.
    """
    if checks.istypeddict(obj):  # type: ignore
        return typed_dict_signature(obj)
    return inspect.signature(obj)
Exemplo n.º 3
0
def safe_get_params(obj: Type) -> Mapping[str, inspect.Parameter]:
    """Return the signature parameters of ``obj``, or ``{}`` when unavailable.

    Mapping types (TypedDicts excepted) always yield an empty mapping, as do
    objects whose signature lookup raises ``ValueError`` or ``TypeError``.
    """
    try:
        plain_mapping = checks.issubclass(obj, Mapping) and not checks.istypeddict(obj)
        if plain_mapping:
            return {}
        return cached_signature(obj).parameters
    except (ValueError, TypeError):  # pragma: nocover
        return {}
Exemplo n.º 4
0
def get_constraints(
    t: Type[VT],
    *,
    nullable: bool = False,
    name: Optional[str] = None,
    cls: Optional[Type] = ...,  # type: ignore
) -> ConstraintsProtocolT[VT]:
    """Build the constraints object for the annotation ``t``.

    Wrapper annotations are unwrapped first (``should_unwrap``); ``nullable``
    becomes true if any unwrapped layer is optional. Then, in order:

    - ``t`` is ``cls`` or is already being built further up the call stack:
      return a :class:`DelayedConstraints` to break the recursion.
    - ``t`` is a forward reference: return a
      :class:`ForwardDelayedConstraints` resolved against ``cls``'s module
      and a copy of its ``__dict__``.
    - ``t`` already carries ``__constraints__``: reuse them, replacing
      ``name``/``nullable`` if they differ.
    - enums and abstract types use dedicated builders.
    - otherwise a handler is chosen (``_from_class`` for NamedTuples and
      TypedDicts, ``_from_strict_type`` for ``type``/``Callable`` origins,
      else a registry lookup by the origin's parent with ``_from_class`` as
      fallback) and invoked.

    Raises:
        TypeError: if ``t`` is a forward reference and no ``cls`` was given.
    """
    while should_unwrap(t):
        nullable = nullable or isoptionaltype(t)
        t = get_args(t)[0]
    if t is cls or t in __stack:
        dc = DelayedConstraints(
            t, nullable=nullable, name=name, factory=get_constraints
        )
        return cast(ConstraintsProtocolT, dc)
    if isforwardref(t):
        if cls is ...:  # pragma: nocover
            raise TypeError(
                f"Cannot build constraints for {t} without an enclosing class."
            )
        fdc = ForwardDelayedConstraints(
            t,  # type: ignore
            cls.__module__,
            localns=getattr(cls, "__dict__", {}).copy(),
            nullable=nullable,
            name=name,
            factory=get_constraints,
        )
        return cast(ConstraintsProtocolT, fdc)
    if isconstrained(t):
        c: ConstraintsProtocolT = t.__constraints__  # type: ignore
        if (c.name, c.nullable) != (name, nullable):
            return dataclasses.replace(c, name=name, nullable=nullable)
        return c
    if isenumtype(t):
        ec = _from_enum_type(t, nullable=nullable, name=name)  # type: ignore
        return cast(ConstraintsProtocolT, ec)
    if isabstract(t):
        return cast(
            ConstraintsProtocolT, _from_strict_type(t, nullable=nullable, name=name)
        )
    if isnamedtuple(t) or istypeddict(t):
        handler = _from_class
    else:
        ot = origin(t)
        if ot in {type, abc.Callable}:
            handler = _from_strict_type  # type: ignore
            t = ot
        else:
            handler = _CONSTRAINT_BUILDER_HANDLERS.get_by_parent(ot, _from_class)  # type: ignore

    # Mark `t` as "in progress" while its handler runs, so recursive
    # references to it resolve to DelayedConstraints. Remove only the entry
    # this call added, and always remove it: `clear()` would wipe entries
    # still owned by outer calls, and skipping cleanup on error would leave
    # `t` permanently marked for every later call.
    owned = t not in __stack
    __stack.add(t)
    try:
        c = handler(t, nullable=nullable, name=name, cls=cls)
    finally:
        if owned:
            __stack.discard(t)
    return c
Exemplo n.º 5
0
 def _build_field(
     self,
     use: Type,
     protocol: SerdeProtocol,
     parent: Optional[Type],
     enum_: Optional[Tuple[Any, ...]],
     default: Optional[Any],
     ro: Optional[bool],
     wo: Optional[bool],
     name: Optional[str],
 ) -> SchemaFieldT:
     """Build the schema field describing a single annotation.

     ``object`` maps to an undeclared field. TypedDicts and NamedTuples get
     no base field, so they are built with ``build_schema``. Any other type's
     base field is looked up in ``SCHEMA_FIELD_FORMATS`` by parent type.

     With a base field, the protocol's schema constraints plus
     ``enum``/``default``/``readOnly``/``writeOnly`` are applied to it.
     Object and array fields get extra mapping/array handling. Without a base,
     ``build_schema(use)`` is tried; on ``ValueError``/``TypeError`` a warning
     is issued and an undeclared field is used instead. The result always
     passes through ``_check_optional``.
     """
     base: Optional[SchemaFieldT] = None
     if use is object:
         base = UndeclaredSchemaField()
     elif not (istypeddict(use) or isnamedtuple(use)):
         base = cast(SchemaFieldT, SCHEMA_FIELD_FORMATS.get_by_parent(use))

     if not base:
         try:
             schema = self.build_schema(use)
         except (ValueError, TypeError) as e:
             warnings.warn(f"Couldn't build schema for {use}: {e}")
             schema = UndeclaredSchemaField(
                 enum=enum_,
                 title=self.defname(use, name=name),
                 default=default,
                 readOnly=ro,
                 writeOnly=wo,
             )
     else:
         config: MutableMapping = {}
         if protocol.constraints:
             config = protocol.constraints.for_schema()
         config.update(enum=enum_, default=default, readOnly=ro, writeOnly=wo)
         # `use` should always be a dict if the annotation is a Mapping,
         # thanks to `origin()` & `resolve()`.
         if isinstance(base, ObjectSchemaField):
             config = self._handle_mapping(
                 protocol, parent=parent, name=name, **config
             )
         elif isinstance(base, ArraySchemaField):
             config = self._handle_array(protocol, parent=parent, **config)
         schema = dataclasses.replace(base, **config)

     return self._check_optional(protocol.annotation, schema, ro, wo, name)
Exemplo n.º 6
0
    def _compile_serializer(
            self, annotation: Annotation[Type[_T]]) -> SerializerT[_T]:
        """Compile a serializer for ``annotation``, reusing cached results.

        Branches, tested in order:

        1. dynamic origins, non-static annotations and unions use
           ``resolver.primitive``;
        2. routines and callables get a function that raises ``TypeError``;
        3. enums are handled by ``_compile_enum_serializer``;
        4. primitives get a null check and a type check, then return the
           value (or ``None`` when the origin is the type of one of the
           resolver's ``OPTIONALS``);
        5. pre-defined types, their subclasses, and primitive subclasses use
           their dedicated compilers;
        6. anything else gets a generated dict, list, or class serializer.

        NOTE(review): the enum and "defined" branches do not write to
        ``_serializer_cache`` here; presumably their helpers do - confirm.
        """
        # Cache key: the generated function name for this annotation.
        func_name = self._get_name(annotation)
        # We've been here before...
        if func_name in self._serializer_cache:
            return self._serializer_cache[func_name]

        serializer: SerializerT
        origin = annotation.resolved_origin
        # Lazy shortcut for messy paths (Union, Any, ...)
        if (origin in self._DYNAMIC or not annotation.static
                or checks.isuniontype(origin)):
            serializer = cast(SerializerT, self.resolver.primitive)
        # Routines (functions or methods) can't be serialized...
        # `isroutine` is tested first: `issubclass` raises TypeError for
        # non-class objects such as plain functions, so testing it first made
        # the routine check unreachable for exactly the objects it targets.
        elif inspect.isroutine(origin) or issubclass(
                origin, abc.Callable):  # type: ignore
            name = util.get_qualname(origin)
            with gen.Block() as main:
                with self._define(main, func_name) as func:
                    func.l(
                        f'raise TypeError("Routines are not serializable. ({name!r}).")'
                    )

            serializer = main.compile(name=func_name)
            self._serializer_cache[func_name] = serializer
        # Enums are special
        elif checks.isenumtype(annotation.resolved):
            serializer = self._compile_enum_serializer(annotation)
        # Primitives don't require further processing.
        # Just check for nullable and the correct type.
        elif origin in self._PRIMITIVES:
            ns: dict = {}
            with gen.Block(ns) as main:
                with self._define(main, func_name) as func:
                    self._check_add_null_check(func, annotation)
                    self._add_type_check(func, annotation)
                    line = "o"
                    if annotation.origin in (type(o)
                                             for o in self.resolver.OPTIONALS):
                        line = "None"
                    func.l(f"{gen.Keyword.RET} {line}")

            serializer = main.compile(name=func_name, ns=ns)
            self._serializer_cache[func_name] = serializer

        # Defined cases are pre-compiled, but we have to check for optionals.
        elif origin in self._DEFINED:
            serializer = self._compile_defined_serializer(
                annotation, self._DEFINED[origin])
        elif issubclass(origin, (*self._DEFINED, )):
            serializer = self._compile_defined_subclass_serializer(
                origin, annotation)
        elif issubclass(origin, self._PRIMITIVES):
            serializer = self._compile_primitive_subclass_serializer(
                origin, annotation)
        else:
            # Build the function namespace
            anno_name = f"{func_name}_anno"
            ns = {anno_name: origin, **annotation.serde.asdict()}
            with gen.Block(ns) as main:
                with self._define(main, func_name) as func:
                    # Mapping types need special nested processing as well
                    istypeddict = checks.istypeddict(origin)
                    istypedtuple = checks.istypedtuple(origin)
                    istypicklass = checks.istypicklass(origin)
                    if not istypeddict and issubclass(origin, self._DICTITER):
                        self._build_dict_serializer(func, annotation)
                    # Array types need nested processing.
                    elif (not istypedtuple and not istypeddict
                          and not istypicklass
                          and issubclass(origin, self._LISTITER)):
                        self._build_list_serializer(func, annotation)
                    # Build a serializer for a structured class.
                    else:
                        self._build_class_serializer(func, annotation)
            serializer = main.compile(name=func_name, ns=ns)
            self._serializer_cache[func_name] = serializer
        return serializer
Exemplo n.º 7
0
def test_transmute_simple(annotation, value, expected):
    result = transmute(annotation, value)
    # TypedDict instances are plain dicts at runtime.
    expected_type = dict if istypeddict(annotation) else annotation
    assert isinstance(result, expected_type)
    assert result == expected
Exemplo n.º 8
0
class DesFactory:
    """A callable class for ``des``erialzing values.

    Checks for:

            - builtin types
            - :py:mod:`typing` type annotations
            - :py:class:`datetime.date`
            - :py:class:`datetime.datetime`
            - :py:class:`typing.TypedDict`
            - :py:class:`typing.NamedTuple`
            - :py:func:`collections.namedtuple`
            - User-defined classes (limited)

    Examples
    --------
    >>> import typic
    >>> typic.transmute(bytes, "foo")
    b'foo'
    >>> typic.transmute(dict, "{'foo': 'bar'}")
    {'foo': 'bar'}
    """

    STRICT = STRICT_MODE
    DEFAULT_BYTE_ENCODING = "utf-8"
    UNRESOLVABLE = frozenset((
        Any,
        Match,
        re.Match,  # type: ignore
        type(None),
        _empty,
        Callable,
        abc.Callable,
    ))
    VNAME = "val"
    VTYPE = "vtype"
    __DES_CACHE: Dict[str, DeserializerT] = {}
    __USER_DESS: DeserializerRegistryT = deque()

    def __init__(self, resolver: Resolver):
        # The resolver supplies nested deserializers and the runtime helpers
        # (translate, iterate, bind, ...) used by the generated code.
        self.resolver = resolver

    def register(self, deserializer: DeserializerT, check: DeserializerT):
        """Register a user-defined deserializer together with its predicate.

        For the rare annotation typic cannot coerce on its own, ``check``
        returns a boolean saying whether ``deserializer`` applies to a given
        annotation. The ``(check, deserializer)`` pair is pushed onto the
        front of the registry, so newer registrations presumably take
        precedence (depends on the consumer iterating front-to-back).

        NOTE(review): the registry is a class-level attribute, so it is
        shared by every ``DesFactory`` instance.
        """
        entry = (check, deserializer)
        self.__USER_DESS.appendleft(entry)

    def _set_checks(self, func: gen.Block, anno_name: str,
                    annotation: Annotation):
        """Emit the prelude shared by generated deserializers.

        Records the input's class in ``VTYPE``. Text input is first passed
        through ``safe_eval``, unless the target itself is ``str``/``bytes``
        or a decimal type. Then emits an early ``return`` when the input is
        one of the resolver's ``OPTIONALS`` (optional annotations) or equals
        the parameter default. When the target defines ``equals``, that
        method is used for the comparisons whenever the runtime value also
        has it.

        Args:
            func: The generated function body to write into.
            anno_name: Name bound to the target type in the generated namespace.
            annotation: The resolved annotation being deserialized.
        """
        _ctx: Dict[str, Any] = {}
        origin = annotation.resolved_origin
        # Text and decimal targets take the raw input; everything else gets
        # a safe eval of text input first.
        if inspect.isclass(origin) and (issubclass(origin, (str, bytes))
                                        or checks.isdecimaltype(origin)):
            self._add_vtype(func)
        else:
            self._add_eval(func)
        # Equality checks for defaults and optionals
        custom_equality = hasattr(origin, "equals")
        if custom_equality and (annotation.optional or annotation.has_default):
            func.l(f"custom_equality = hasattr({self.VNAME}, 'equals')")
        null = ""
        if annotation.optional:
            null = f"{self.VNAME} in {self.resolver.OPTIONALS}"
            if custom_equality:
                null = (
                    f"(any({self.VNAME}.equals(o) for o in {self.resolver.OPTIONALS}) "
                    "if custom_equality "
                    f"else {null})")
        eq = ""
        if (annotation.has_default and annotation.parameter.default
                not in self.resolver.OPTIONALS):
            eq = f"{self.VNAME} == __default"
            if custom_equality:
                if hasattr(annotation.parameter.default, "equals"):
                    eq = f"__default.equals({self.VNAME})"
                # Parenthesized because a conditional expression binds more
                # loosely than `and`/`or`. Unwrapped, the type check and the
                # null check joined onto `eq` below would end up inside the
                # `if` condition, and be skipped entirely whenever the runtime
                # value lacks `equals` (e.g. None).
                eq = (f"({self.VNAME}.equals(__default) "
                      f"if custom_equality else {eq})")
            _ctx["__default"] = annotation.parameter.default
        if eq or null:
            # Add a type-check for anything that isn't a builtin.
            if eq and not checks.isbuiltintype(origin):
                eq = f"{self.VTYPE} is {anno_name} and {eq}"
            check = " or ".join(c for c in (null, eq) if c)
            with func.b(f"if {check}:", **_ctx) as b:  # type: ignore
                b.l(f"return {self.VNAME}")

    @staticmethod
    def _get_name(annotation: Annotation) -> str:
        """Return the generated function name for ``annotation``'s deserializer."""
        kind = "deserializer"
        return get_defname(kind, annotation)

    def _build_date_des(self, context: BuildContext):
        """Emit code coercing the input value into a date/datetime target.

        Numbers are converted with ``<anno>.fromtimestamp`` and text with
        ``dateparse``. For ``datetime`` targets, existing ``datetime`` values
        are rebuilt as the target type (via ``instance`` when the target is
        ``DateTime``), and plain ``date`` values are widened to the target.
        For ``date`` targets, ``datetime`` values are reduced with ``.date()``.

        Args:
            context: Build context supplying the generated function, the
                annotation, and the target's name in the generated namespace.
        """
        func, annotation, anno_name = (
            context.func,
            context.annotation,
            context.anno_name,
        )
        origin = annotation.resolved_origin
        # From an int
        with func.b(f"if isinstance({self.VNAME}, (int, float)):") as b:
            b.l(f"{self.VNAME} = {anno_name}.fromtimestamp({self.VNAME})")
        # From a string
        with func.b(f"elif isinstance({self.VNAME}, (str, bytes)):") as b:
            line = f"{self.VNAME} = dateparse({self.VNAME})"
            b.l(line, dateparse=dateparse)
        # Second, independent `if` chain: normalize date-like values
        # (including those just produced above) to the target type.
        if issubclass(origin, datetime.datetime):
            with func.b(f"if isinstance({self.VNAME}, datetime):",
                        datetime=datetime.datetime) as b:
                # Use pendulum's helper if possible.
                if origin is DateTime:
                    b.l(f"{self.VNAME} = instance({self.VNAME})",
                        instance=instance)
                else:
                    b.l(
                        f"{self.VNAME} = "
                        f"{anno_name}("
                        f"{self.VNAME}.year, "
                        f"{self.VNAME}.month, "
                        f"{self.VNAME}.day, "
                        f"{self.VNAME}.hour, "
                        f"{self.VNAME}.minute, "
                        f"{self.VNAME}.second, "
                        f"{self.VNAME}.microsecond, "
                        f"{self.VNAME}.tzinfo"
                        f")", )
            # `datetime` is a subclass of `date`, so it must be tested first.
            with func.b(f"elif isinstance({self.VNAME}, date):",
                        date=datetime.date) as b:
                b.l(
                    f"{self.VNAME} = "
                    f"{anno_name}("
                    f"{self.VNAME}.year, "
                    f"{self.VNAME}.month, "
                    f"{self.VNAME}.day"
                    f")", )
        elif issubclass(origin, datetime.date):
            with func.b(f"if isinstance({self.VNAME}, datetime):",
                        datetime=datetime.datetime) as b:
                b.l(f"{self.VNAME} = {self.VNAME}.date()")
        # NOTE(review): these `elif`s extend whichever `if` chain was emitted
        # just above. They only fire if the value is still a number or text
        # after the first chain converted it, which appears unreachable -
        # confirm whether they are intentional.
        with func.b(f"elif isinstance({self.VNAME}, (int, float)):") as b:
            b.l(f"{self.VNAME} = {anno_name}.fromtimestamp({self.VNAME})")
        with func.b(f"elif isinstance({self.VNAME}, (str, bytes)):") as b:
            line = f"{self.VNAME} = dateparse({self.VNAME}, exact=True)"
            b.l(line, dateparse=dateparse)

    def _build_time_des(self, context: BuildContext):
        """Emit code coercing the input value into a time target.

        - ``int``/``float``: truncated with ``int()`` and passed positionally
          to the target (for ``datetime.time`` that is the *hour* - NOTE(review):
          confirm this is intended).
        - ``str``/``bytes``: parsed with ``dateparse(..., exact=True)``.
        - ``datetime``: reduced with ``.time()``.
        - ``date``: replaced with ``<anno>(0)``.
        """
        func = context.func
        anno_name = context.anno_name
        v = self.VNAME
        with func.b(f"if isinstance({v}, (int, float)):") as branch:
            branch.l(f"{v} = {anno_name}(int({v}))")
        with func.b(f"elif isinstance({v}, (str, bytes)):") as branch:
            branch.l(f"{v} = dateparse({v}, exact=True)", dateparse=dateparse)
        # Separate `if`s: the datetime conversion runs first so that a
        # datetime (a `date` subclass) is no longer a `date` by the time the
        # last check is reached.
        with func.b(f"if isinstance({v}, datetime):",
                    datetime=datetime.datetime) as branch:
            branch.l(f"{v} = {v}.time()")
        with func.b(f"if isinstance({v}, date):",
                    date=datetime.date) as branch:
            branch.l(f"{v} = {anno_name}(0)")

    def _build_timedelta_des(self, context: BuildContext):
        """Emit code coercing numbers and text into a duration target.

        Numbers are truncated with ``int()`` and passed positionally to the
        target. NOTE(review): for ``datetime.timedelta`` the first positional
        argument is *days* - confirm that is the intended unit. Text is
        parsed with ``dateparse(..., exact=True)``.
        """
        func = context.func
        anno_name = context.anno_name
        v = self.VNAME
        with func.b(f"if isinstance({v}, (int, float)):") as branch:
            branch.l(f"{v} = {anno_name}(int({v}))")
        with func.b(f"elif isinstance({v}, (str, bytes)):") as branch:
            branch.l(f"{v} = dateparse({v}, exact=True)", dateparse=dateparse)

    def _build_uuid_des(self, context: BuildContext):
        """Emit code coercing the input value into a UUID target.

        Values already of the exact target type are returned as-is. Other
        ``UUID`` instances are rebuilt from ``.int``. ``str`` is parsed, and
        ``bytes``, ``int`` and ``tuple`` are passed as the ``bytes``/``int``/
        ``fields`` keyword respectively. Anything else is left untouched.
        """
        func = context.func
        anno_name = context.anno_name
        v = self.VNAME
        self._add_type_check(func, anno_name)
        with func.b(f"if issubclass({self.VTYPE}, UUID):",
                    UUID=uuid.UUID) as branch:
            branch.l(f"{v} = {anno_name}(int={v}.int)")
        conversions = (
            ("str", f"{anno_name}({v})"),
            ("bytes", f"{anno_name}(bytes={v})"),
            ("int", f"{anno_name}(int={v})"),
            ("tuple", f"{anno_name}(fields={v})"),
        )
        for source_type, constructor in conversions:
            with func.b(f"elif isinstance({v}, {source_type}):") as branch:
                branch.l(f"{v} = {constructor}")

    def _add_eval(self, func: gen.Block):
        """Emit a ``safe_eval`` of text input, then record the value's class."""
        v = self.VNAME
        statement = (
            f"_, {v} = __eval({v}) "
            f"if isinstance({v}, (str, bytes)) "
            f"else (False, {v})"
        )
        func.l(statement, __eval=safe_eval)
        self._add_vtype(func)

    def _add_type_check(self, func: gen.Block, anno_name: str):
        """Emit an early return when the input's class is exactly ``anno_name``."""
        guard = f"if {self.VTYPE} is {anno_name}:"
        with func.b(guard) as block:
            block.l(f"{gen.Keyword.RET} {self.VNAME}")

    def _add_vtype(self, func: gen.Block):
        """Emit an assignment storing the input's class in ``VTYPE``."""
        target, source = self.VTYPE, f"{self.VNAME}.__class__"
        func.l(f"{target} = {source}")

    def _get_default_factory(
        self, annotation: "AnnotationT"
    ) -> Union[Type, Callable[..., Any], None]:
        """Derive a default factory from an annotation's last type argument.

        Returns ``None`` when the annotation has no arguments, or when the
        last argument is still an unresolved forward reference. Otherwise:

        - a ``defaultdict`` argument yields a factory building a
          ``defaultdict`` whose own factory is derived recursively;
        - a non-builtin type whose signature parameters all have defaults is
          wrapped in a zero-argument factory that instantiates it;
        - any other type is returned as the factory itself.
        """
        factory: Union[Type, Callable[..., Any], None] = None
        args: Tuple = annotation.args if isinstance(annotation,
                                                    Annotation) else tuple()
        if args:
            # The last argument is the value type (e.g. `V` in `Dict[K, V]`).
            factory_anno = self.resolver.annotation(args[-1])
            # Not resolvable yet; leave the factory unset.
            if isinstance(factory_anno, ForwardDelayedAnnotation):
                return factory
            elif isinstance(factory_anno, DelayedAnnotation):
                use = factory_anno.type
                raw = use
            else:
                use = factory_anno.resolved_origin
                raw = factory_anno.un_resolved
            factory = use
            if issubclass(use, defaultdict):
                # Nested defaultdict: recurse to find the inner factory.
                factory_nested = self._get_default_factory(factory_anno)

                def factory():
                    return defaultdict(factory_nested)

                factory.__qualname__ = f"factory({repr(raw)})"  # type: ignore

            if not checks.isbuiltinsubtype(use):  # type: ignore

                params: Mapping[str, inspect.Parameter] = cached_signature(
                    use).parameters
                # Only wrap types that can be called with no arguments.
                if not any(p.default is p.empty for p in params.values()):

                    # `use` is bound as a keyword default so the closure
                    # captures the current value.
                    def factory(*, __origin=use):
                        return __origin()

                    factory.__qualname__ = f"factory({repr(raw)})"  # type: ignore

        return factory

    def _build_text_des(self, context: BuildContext):
        """Emit code coercing the input into a ``str``/``bytes`` target.

        ``str`` input to a bytes target is encoded, and ``bytes``/``bytearray``
        input to a str target is decoded, both with ``DEFAULT_ENCODING``. The
        value is then passed through the target constructor.
        """
        func = context.func
        anno_name = context.anno_name
        origin = context.annotation.resolved_origin
        v = self.VNAME
        encoding = repr(DEFAULT_ENCODING)
        if issubclass(origin, bytes):
            with func.b(f"if isinstance({v}, str):") as branch:
                branch.l(f"{v} = {anno_name}({v}, encoding={encoding})")
        elif issubclass(origin, str):
            with func.b(f"if isinstance({v}, (bytes, bytearray)):") as branch:
                branch.l(f"{v} = {v}.decode({encoding})")
        func.l(f"{v} = {anno_name}({v})")

    def _build_builtin_des(self, context: BuildContext):
        """Dispatch deserializer generation for a builtin target type.

        Text types, mappings and collections have dedicated builders. Any
        other builtin (bool, int, float, ...) is coerced by calling the target
        type on the value.
        """
        origin = context.annotation.resolved_origin
        if issubclass(origin, (str, bytes)):
            self._build_text_des(context)
            return
        if checks.ismappingtype(origin):
            self._build_mapping_des(context)
            return
        if checks.iscollectiontype(origin):
            self._build_collection_des(context)
            return
        context.func.l(f"{self.VNAME} = {context.anno_name}({self.VNAME})")

    def _build_pattern_des(self, context: BuildContext):
        """Emit code keeping compiled patterns and compiling anything else."""
        anno_name = context.anno_name
        v = self.VNAME
        expression = (
            f"{v} "
            f"if issubclass({self.VTYPE}, {anno_name}) "
            f"else __re_compile({v})"
        )
        context.func.l(f"{v} = {expression}", __re_compile=re.compile)

    def _build_fromdict_des(self, func: gen.Block, anno_name: str):
        """Emit deserialization through the target's ``from_dict`` constructor."""
        self._add_type_check(func, anno_name)
        call = f"{anno_name}.from_dict({self.VNAME})"
        func.l(f"{self.VNAME} = {call}")

    def _build_typeddict_des(self, context: BuildContext):
        """Emit deserialization for a TypedDict-like target.

        Mapping input is rebuilt key by key. Each input key is translated
        through ``fields_in`` (input name -> attribute name), and each value is
        run through that attribute's deserializer when field annotations
        exist. Non-total targets only take keys present in the input. Any
        other input is handed to ``resolver.translate``.

        NOTE(review): the generated code reads ``fields_in`` from the
        generated function's namespace, presumably ``serde.fields_in``.
        """
        func, annotation, namespace, anno_name = (
            context.func,
            context.annotation,
            context.namespace,
            context.anno_name,
        )
        total = getattr(annotation.resolved_origin, "__total__", True)
        with func.b(f"if issubclass({self.VTYPE}, Mapping):",
                    Mapping=abc.Mapping) as b:
            # Keyed by attribute name (the values of `fields_in`).
            fields_deser = {
                x: self.resolver._resolve_from_annotation(
                    y, namespace=namespace).transmute
                for x, y in annotation.serde.fields.items()
            }
            x = "fields_in[x]"
            # Look the deserializer up by the translated attribute name, not
            # the raw input key, so renamed fields don't raise KeyError.
            y = (f"fields_deser[{x}]({self.VNAME}[x])"
                 if fields_deser else f"{self.VNAME}[x]")
            line = f"{{{x}: {y} for x in fields_in.keys()"
            tail = "}" if total else f"& {self.VNAME}.keys()}}"
            b.l(f"{self.VNAME} = {anno_name}(**{line}{tail})",
                fields_deser=fields_deser)
        with func.b("else:") as b:
            b.l(
                f"{self.VNAME} = translate({self.VNAME}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_typedtuple_des(self, context: BuildContext):
        """Emit deserialization for a NamedTuple-like target.

        Mapping input is handled like a TypedDict when the target has field
        annotations, else splatted as keyword arguments. Sequence/set input
        is bound positionally (via ``resolver.bind`` when fields are known).
        Anything else goes through ``resolver.translate``.
        """
        func = context.func
        annotation = context.annotation
        anno_name = context.anno_name
        v = self.VNAME
        with func.b(f"if issubclass({self.VTYPE}, Mapping):",
                    Mapping=abc.Mapping) as branch:
            if not annotation.serde.fields:
                branch.l(f"{v} = {anno_name}(**{v})")
            else:
                # Reuse the TypedDict logic, emitting into this branch.
                self._build_typeddict_des(dataclasses.replace(context, func=branch))
        with func.b(
                f"elif isinstance({v}, (list, set, frozenset, tuple)):"
        ) as branch:
            if not annotation.serde.fields:
                branch.l(f"{v} = {anno_name}(*{v})")
            else:
                branch.l(
                    f"{v} = __bind({anno_name}, *{v}).eval()",
                    __bind=self.resolver.bind,
                )
        with func.b("else:") as branch:
            branch.l(
                f"{v} = translate({v}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_mapping_des(self, context: BuildContext):
        """Emit deserialization for a mapping target.

        With type arguments, key and value deserializers are resolved and
        applied to every entry. With a field mapping (``serde.fields_in``),
        input keys are translated through it. With neither, the generated
        code returns early when the input already has the exact target type.
        ``defaultdict`` targets are rebound in the namespace to
        ``partial(defaultdict, factory)`` so the default factory is kept.

        The generated code builds the result from ``iterate(val)`` when the
        input is a mapping. Otherwise it tries ``iterate(val, values=True)``
        and falls back to ``iterate(val)`` on ``TypeError``/``ValueError``.
        """
        func, annotation, namespace, anno_name = (
            context.func,
            context.annotation,
            context.namespace,
            context.anno_name,
        )
        key_des, item_des = None, None
        args = annotation.args
        if args:
            args = cast(Tuple[Type, Type], args)
            key_type, item_type = args
            key_des = self.resolver.resolve(key_type,
                                            flags=annotation.serde.flags,
                                            namespace=namespace)
            item_des = self.resolver.resolve(item_type,
                                             flags=annotation.serde.flags,
                                             namespace=namespace)
        if issubclass(annotation.resolved_origin, defaultdict):
            factory = self._get_default_factory(annotation)
            func.namespace[anno_name] = functools.partial(defaultdict, factory)
        kd_name = f"{anno_name}_key_des"
        it_name = f"{anno_name}_item_des"
        iterate = f"iterate({self.VNAME})"
        iterate_values = f"iterate({self.VNAME}, values=True)"
        line = f"{anno_name}({iterate})"
        line_values = f"{anno_name}({iterate_values})"
        if args or annotation.serde.fields_in:
            x, y = "x", "y"
            # If there are args & field mapping, get the correct field name
            # AND deserialize both the key and the value.
            if args and annotation.serde.fields_in:
                x = f"{kd_name}(fields_in.get(x, x))"
                # Values need `item_des` here too, exactly as in the
                # args-only branch; otherwise `V` in `Dict[K, V]` is skipped.
                y = f"{it_name}(y)"
            # If there is only a field mapping, get the correct name for the field.
            elif annotation.serde.fields_in:
                x = "fields_in.get(x, x)"
            # If there are only deserializers, get the deserialized values.
            elif args:
                x = f"{kd_name}(x)"
                y = f"{it_name}(y)"
            line = f"{anno_name}({{{x}: {y} for x, y in {iterate}}})"
            line_values = f"{anno_name}({{{x}: {y} for x, y in {iterate_values}}})"
        # If we don't have nested annotations, we can short-circuit on valid inputs
        else:
            self._add_type_check(func, anno_name)
        # Write the lines.
        with func.b(f"if ismappingtype({self.VTYPE}):") as b:
            b.l(f"{self.VNAME} = {line}")
        with func.b("else:") as ob:
            with ob.b("try:") as b:
                b.l(f"{self.VNAME} = {line_values}")
            with ob.b("except (TypeError, ValueError):") as b:
                b.l(f"{self.VNAME} = {line}")
        func.namespace.update({
            kd_name: key_des,
            it_name: item_des,
            "Mapping": abc.Mapping,
            "iterate": self.resolver.iterate,
            "ismappingtype": checks.ismappingtype,
        })

    def _build_tuple_des(self, context: BuildContext):
        """Emit deserialization for a tuple target.

        Fixed-length tuples (``Tuple[A, B, ...]`` without a trailing
        ``...``) deserialize each position with its own deserializer; input
        positions beyond the annotated ones are dropped. Bare and variadic
        tuples are handled as generic collections.
        """
        func = context.func
        annotation = context.annotation
        namespace = context.namespace
        anno_name = context.anno_name
        if not annotation.args or annotation.args[-1] is ...:
            self._build_collection_des(context)
            return
        des_name = "item_des"
        positional = {
            index: self.resolver.resolve(arg,
                                         flags=annotation.serde.flags,
                                         namespace=namespace)
            for index, arg in enumerate(annotation.args)
        }
        source = f"iterate({self.VNAME}, values=True)"
        expression = (
            f"{anno_name}({des_name}[ix](v) for ix, v in enumerate({source})"
            f"if ix in {des_name})"
        )
        func.l(
            f"{self.VNAME} = {expression}",
            level=None,
            **{des_name: positional, "iterate": self.resolver.iterate},
        )

    def _build_collection_des(self, context: BuildContext):
        """Emit deserialization for a generic collection target.

        With an item annotation, every element is run through the item
        deserializer. Without one, the generated code returns early for the
        exact target type, and otherwise rebuilds the value from
        ``iterate(val, values=True)``.

        NOTE(review): the item path calls ``parent(...)``, which this method
        does not add to the namespace; it must come from the enclosing
        generated function's namespace - confirm. ``Collection`` is added to
        the namespace but not referenced by the emitted line.
        """
        func = context.func
        annotation = context.annotation
        anno_name = context.anno_name
        it_name = f"{anno_name}_item_des"
        source = f"iterate({self.VNAME}, values=True)"
        if not annotation.args:
            item_des = None
            self._add_type_check(func, anno_name)
            statement = f"{self.VNAME} = {anno_name}({source})"
        else:
            item_des = self.resolver.resolve(annotation.args[0],
                                             flags=annotation.serde.flags,
                                             namespace=context.namespace)
            statement = (f"{self.VNAME} = "
                         f"{anno_name}({it_name}(x) for x in parent({source}))")
        ns_entries = {
            it_name: item_des,
            "Collection": abc.Collection,
            "iterate": self.resolver.iterate,
        }
        func.l(statement, level=None, **ns_entries)

    def _build_decimal_des(self, context: BuildContext):
        """Emit decimal coercion: exact-type input passes, else call the target."""
        target = context.anno_name
        self._add_type_check(context.func, target)
        context.func.l(f"{self.VNAME} = {target}({self.VNAME})")

    def _build_path_des(self, context: BuildContext):
        """Emit path coercion: exact-type input passes, else call the target."""
        target = context.anno_name
        self._add_type_check(context.func, target)
        context.func.l(f"{self.VNAME} = {target}({self.VNAME})")

    def _build_user_type_des(self, context: BuildContext):
        """Emit deserialization for a user-defined class.

        Generated behavior, in order:

        - input already of the exact target type is returned unchanged;
        - mapping input is turned into keyword arguments. Keys are the
          intersection of ``fields_in`` and the input keys, translated through
          ``fields_in`` when names differ. Values are passed through as-is
          when the resolver already ``known``/``delayed`` the type, otherwise
          through per-field deserializers (from ``serde.fields`` when every
          input field matches an annotation, else from
          ``resolver.protocols``);
        - builtin-subtype input (other than NamedTuples) is passed to the
          target constructor;
        - anything else goes through ``resolver.translate``.
        """
        func, annotation, namespace, anno_name = (
            context.func,
            context.annotation,
            context.namespace,
            context.anno_name,
        )
        serde = annotation.serde
        resolved = annotation.resolved
        self._add_type_check(func, anno_name)
        # Main branch - we have a mapping for a user-defined class.
        # This is where the serde configuration comes in.
        # WINDY PATH AHEAD
        func.l("# Happy path - deserialize a mapping into the object.")
        with func.b(f"if ismappingtype({self.VTYPE}):",
                    ismappingtype=checks.ismappingtype) as b:
            # Universal line - transform input to known keys/values.
            # Specific values may change.
            def mainline(k, v):
                return f"{{{k}: {v} for x in fields_in.keys() & {self.VNAME}.keys()}}"

            # The "happy path" - e.g., no guesswork needed.
            def happypath(k, v, **ns):
                b.l(f"{self.VNAME} = {anno_name}(**{mainline(k, v)})", **ns)

            # Default X - translate given `x` to known input `x`
            x = "fields_in[x]"
            # Translation can be skipped only when every input name maps to
            # itself. Equal key/value *sets* are not enough: a permutation
            # such as {"a": "b", "b": "a"} still renames fields.
            if all(k == v for k, v in serde.fields_in.items()):
                x = "x"

            # Default Y - get the given `y` with the given `x`
            y = f"{self.VNAME}[x]"
            # Get the intersection of known input fields and annotations.
            matched = {*serde.fields_in.values()} & serde.fields.keys()
            # Happy path! This is a `@typic.al` wrapped class.
            if self.resolver.known(resolved) or self.resolver.delayed(
                    resolved):
                happypath(x, y)
            # Secondary happy path! We know how to deserialize already.
            else:
                fields_in = serde.fields_in
                fnamespace = namespace or resolved
                if serde.fields and len(matched) == len(serde.fields_in):
                    desers = {
                        f: self.resolver._resolve_from_annotation(
                            serde.fields[f], namespace=fnamespace).transmute
                        for f in matched
                    }
                else:
                    protocols = self.resolver.protocols(
                        annotation.resolved_origin)
                    fields_in = {x: x for x in protocols}
                    desers = {f: p.transmute for f, p in protocols.items()}
                y = f"desers[{x}]({self.VNAME}[x])"
                happypath(x, y, desers=desers, fields_in=fields_in)

        # Secondary branch - we have some other input for a user-defined class
        func.l("# Unknown path, just try casting it directly.")
        with func.b(
                f"elif isbuiltinsubtype({self.VTYPE}) and not isnamedtuple({self.VTYPE}):",
                isbuiltinsubtype=checks.isbuiltinsubtype,
                isnamedtuple=checks.isnamedtuple,
        ) as b:
            b.l(f"{self.VNAME} = {anno_name}({self.VNAME})")
        # Final branch - user-defined class for another user-defined class
        func.l("# Two user-defined types, "
               "try to translate the input into the desired output.")
        with func.b("else:") as b:
            b.l(
                f"{self.VNAME} = translate({self.VNAME}, {anno_name})",
                translate=self.resolver.translate,
            )

    def _build_literal_des(self,
                           annotation: Annotation,
                           func_name: str,
                           namespace: Type = None):
        """Build a deserializer for a ``Literal[...]`` annotation.

        The literal's values are reduced to the set of their classes. A single
        class is used directly; several classes become a ``Union`` of them. A
        new annotation for that type is resolved, carrying over the original
        parameter's name and default, optionality, strictness and serde flags.
        A deserializer is then built for it with ``_build_des``.
        """
        value_classes: Set[Type] = {value.__class__ for value in annotation.args}
        if len(value_classes) == 1:
            target = value_classes.pop()
        else:
            target = Union[tuple(value_classes)]
        resolved_anno = self.resolver.annotation(
            target,  # type: ignore
            name=annotation.parameter.name,
            is_optional=annotation.optional,
            is_strict=annotation.strict,
            flags=annotation.serde.flags,
            default=annotation.parameter.default,
        )
        return self._build_des(
            cast(Annotation, resolved_anno), func_name, namespace
        )

    def _build_union_des(self, context: BuildContext):
        """Write the body of a deserializer for a ``Union[...]`` annotation.

        ``None``, ``Ellipsis`` and ``NoneType`` members are ignored. If
        ``self.VTYPE`` (presumably the input's type) is one of the remaining
        members, excluding ``str``/``bytes``, the input is returned unchanged.
        When every member is a non-stdlib type and ``get_tag_for_types`` finds
        a discriminator field, a tagged-union dispatch is generated. The tag is
        read from the input (``.get`` for mappings, ``getattr`` otherwise) and
        the matching member's deserializer is applied. An unknown or missing
        tag raises ``ValueError``. In every other case this delegates to
        ``_build_generic_union_des``.

        Returns:
            ``None`` when this method wrote the body itself (so ``_build_des``
            appends the trailing ``return``); otherwise whatever
            ``_build_generic_union_des`` returned.
        """
        func, annotation, namespace = (
            context.func,
            context.annotation,
            context.namespace,
        )
        # Get all types which we may coerce to.
        args = (*(a for a in annotation.args
                  if a not in {None, Ellipsis, type(None)}), )
        if not args:
            return
        # Add a type-check, but exclude str|bytes, since those are too permissive.
        types = {a for a in args if a not in {str, bytes}}
        if types:
            with func.b(f"if {self.VTYPE} in types:", types=types) as b:
                b.l(f"return {self.VNAME}")
        # Get all custom types, which may have discriminators
        targets = (*(a for a in args if not checks.isstdlibtype(a)), )
        # We can only build a tagged union deserializer if all args are valid
        if args != targets:
            return self._build_generic_union_des(context)

        # Try to collect the field which will be the discriminator.
        # First, get a mapping of Type -> Proto & Type -> Fields
        tagged = get_tag_for_types(targets)
        # Just bail out if we can't find a key.
        if not tagged:
            return self._build_generic_union_des(context)
        # If we got a key, re-map the protocols to the value for each type.
        deserializers = {
            value: self.resolver.resolve(t, namespace=namespace).transmute
            for value, t in tagged.types_by_values
        }
        # Finally, build the deserializer
        func.namespace.update(
            tag=tagged.tag,
            desers=deserializers,
            empty=_empty,
        )
        with func.b(f"if issubclass({self.VTYPE}, Mapping):",
                    Mapping=abc.Mapping) as b:
            b.l(f"tag_value = {self.VNAME}.get(tag, empty)")
        with func.b("else:") as b:
            b.l(f"tag_value = getattr({self.VNAME}, tag, empty)")
        with func.b("if tag_value in desers:") as b:
            b.l(f"{self.VNAME} = desers[tag_value]({self.VNAME})")
        with func.b("else:") as b:
            # The generated message references the generated function's own
            # parameter name (``self.VNAME``) rather than a hard-coded ``val``.
            b.l("raise ValueError("
                'f"Value is missing field {tag!r} with one of '
                '{(*desers,)}: '
                f'{{{self.VNAME}!r}}"'
                ")")

    def _build_generic_union_des(self, context: BuildContext):
        """Write a best-effort deserializer body for an untagged ``Union``.

        Every member except ``None``, ``Ellipsis`` and ``NoneType`` is
        resolved. The generated code first runs one ``issubclass`` check per
        member, skipping ``Literal``/``Final``, which cannot be subclass-checked.
        On a match it returns that member's deserializer result. Otherwise it
        tries each member's deserializer in turn, ignoring ``TypeError``,
        ``ValueError`` and ``KeyError``. If none succeeds, it raises
        ``ValueError``.

        Returns:
            ``False`` when a body was written. Every generated path returns or
            raises, so ``_build_des`` must not append a trailing ``return``.
            ``None`` when there were no members to handle.
        """
        annotation, namespace, func = (
            context.annotation,
            context.namespace,
            context.func,
        )
        annos = {
            get_name(a): self.resolver.resolve(a, namespace=namespace)
            for a in annotation.args
            if a not in {None, Ellipsis, type(None)}
        }
        if annos:
            desers = {f"{n}_des": p.transmute for n, p in annos.items()}
            types = {n: p.annotation.resolved_origin for n, p in annos.items()}
            ctx: Mapping[str, Union[Type, DeserializerT]] = {**types, **desers}
            for name in annos:
                # Can't do subclass checks with these...
                if name in {"Literal", "Final"}:
                    continue
                # NOTE(review): this checks whether the union member is a
                # subclass of `self.VTYPE`. Confirm the argument order is intended.
                with func.b(f"if issubclass({name}, {self.VTYPE}):") as b:
                    b.l(f"return {name}_des({self.VNAME})")
            # Otherwise try each member's deserializer in order.
            for name in desers:
                with func.b("try:") as b:
                    b.l(f"return {name}({self.VNAME})")
                with func.b("except (TypeError, ValueError, KeyError):") as b:
                    b.l("pass")
            func.namespace.update(ctx)
            # The generated message references the generated function's own
            # parameter name (``self.VNAME``) rather than a hard-coded ``val``.
            func.l(
                "raise ValueError("
                f'f"Value could not be deserialized into one of {(*annos,)}: '
                f'{{{self.VNAME}!r}}"'
                ")",
            )
            return False

    def _build_des(  # noqa: C901
        self,
        annotation: Annotation[Type[ObjectT]],
        func_name: str,
        namespace: Type = None,
    ) -> DeserializerT[ObjectT]:
        """Compile a deserializer function named ``func_name`` for ``annotation``.

        ``Literal`` annotations are delegated to ``_build_literal_des``. For
        anything else, a function taking one argument named ``self.VNAME`` is
        generated. Unless the origin is in ``self.UNRESOLVABLE``, sanity checks
        are written with ``_set_checks``. The first ``_HANDLERS`` entry whose
        predicate accepts ``(origin, args)`` then writes the body. A trailing
        ``return`` of the value is appended unless that handler returned
        ``False``.
        """
        type_args = annotation.args
        # The "origin" is a builtin type for natives and their typing.*
        # equivalents, the un-subscripted form for SpecialForms (mainly Union),
        # and the annotation itself for custom classes.
        origin = annotation.resolved_origin
        anno_name = get_unique_name(origin)
        ns = {
            anno_name: origin,
            "parent": getattr(origin, "__parent__", origin),
            "issubclass": cached_issubclass,
            **annotation.serde.asdict(),
        }
        if checks.isliteral(origin):
            return self._build_literal_des(annotation, func_name, namespace)
        with gen.Block(ns) as main:
            with main.f(func_name, main.param(f"{self.VNAME}")) as func:
                needs_return = None
                context = BuildContext(annotation, ns, anno_name, func,
                                       namespace)
                if origin not in self.UNRESOLVABLE:
                    self._set_checks(func, anno_name, annotation)
                    # First matching predicate wins; the table is ordered.
                    handler = next(
                        (
                            build
                            for accepts, build in self._HANDLERS.items()
                            if accepts(origin, type_args)
                        ),
                        None,
                    )
                    if handler is not None:
                        needs_return = handler(self, context)
                # Handlers that emit their own returns signal it with False.
                if needs_return is not False:
                    func.l(f"{gen.Keyword.RET} {self.VNAME}")
        return main.compile(ns=ns, name=func_name)

    # Order is IMPORTANT! This is a FIFO queue.
    # Each key is a predicate `(origin, args) -> bool`. Each value is the
    # (unbound) builder that `_build_des` calls as `handler(self, context)`
    # for the FIRST predicate that matches. This relies on dict insertion order.
    # A builder that returns `False` tells `_build_des` not to append the
    # trailing `return` line.
    _HANDLERS: Mapping[HandlerCheckT, BuildHandlerT] = {
        # Special handler for Unions...
        lambda origin, args: checks.isuniontype(origin):
        _build_union_des,
        # Non-intersecting types (order doesn't matter here).
        lambda origin, args: checks.isdatetype(origin):
        _build_date_des,
        lambda origin, args: checks.istimetype(origin):
        _build_time_des,
        lambda origin, args: checks.istimedeltatype(origin):
        _build_timedelta_des,
        lambda origin, args: checks.isuuidtype(origin):
        _build_uuid_des,
        lambda origin, args: origin in {Pattern, re.Pattern}:
        _build_pattern_des,
        # NOTE(review): the builtin `issubclass` raises TypeError when `origin`
        # is not a class. Presumably earlier entries and `UNRESOLVABLE`
        # filter such origins out. Confirm.
        lambda origin, args: issubclass(origin, pathlib.Path):
        _build_path_des,
        lambda origin, args: checks.isdecimaltype(origin):
        _build_decimal_des,
        # MUST come before subtype check.
        lambda origin, args: (not args and checks.isbuiltintype(origin)):
        _build_builtin_des,
        # Pseudo-structured containers, should check before generics.
        lambda origin, args: checks.istypeddict(origin):
        _build_typeddict_des,
        lambda origin, args: checks.istypedtuple(origin):
        _build_typedtuple_des,
        # NamedTuples share the typed-tuple builder.
        lambda origin, args: checks.isnamedtuple(origin):
        _build_typedtuple_des,
        lambda origin, args: (not args and checks.isbuiltinsubtype(origin)):
        _build_builtin_des,
        # A mapping is a collection so must come before that check.
        lambda origin, args: checks.ismappingtype(origin):
        _build_mapping_des,
        # A tuple is a collection so must come before that check.
        lambda origin, args: checks.istupletype(origin):
        _build_tuple_des,
        # Generic collection handler
        lambda origin, args: checks.iscollectiontype(origin):
        _build_collection_des,
        # Catch-all for custom user types (user-defined classes).
        # Always matches, so it must stay last.
        lambda origin, args: True:
        _build_user_type_des,
    }

    def factory(
        self,
        annotation: Annotation[Type[ObjectT]],
        namespace: Type = None,
    ) -> DeserializerT[ObjectT]:
        """Return a deserializer for ``annotation``, building and caching it on a miss.

        ``annotation.serde`` is filled in with a fresh ``SerdeConfig`` if unset.
        The cache key is ``self._get_name(annotation)``. On a cache miss, the
        first registered user deserializer whose check accepts
        ``annotation.resolved`` is used. If there is none, one is compiled with
        ``_build_des``. Either way the result is stored under the key.
        """
        annotation.serde = annotation.serde or SerdeConfig()
        cache_key = self._get_name(annotation)
        cache = self.__DES_CACHE
        if cache_key in cache:
            return cache[cache_key]
        deserializer: Optional[DeserializerT] = next(
            (
                user_des
                for accepts, user_des in self.__USER_DESS
                if accepts(annotation.resolved)
            ),
            None,
        )
        if not deserializer:
            deserializer = self._build_des(annotation, cache_key, namespace)
        cache[cache_key] = deserializer
        return deserializer
Exemplo n.º 9
0
def _from_class(
    t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ConstraintsProtocolT[VT]:
    """Build a constraints object describing the class ``t``.

    Builtin subtypes other than TypedDicts and NamedTuples fall back to
    ``_from_strict_type``. So do classes whose signature or type hints raise
    ``ValueError``/``TypeError`` during inspection. Otherwise the class's
    signature parameters are re-annotated with the resolved type hints. Any
    ``ClassVar`` hints are added as keyword-only pseudo-parameters so they can
    be validated. All of these are then resolved into per-field constraints.
    TypedDicts produce ``TypedDictConstraints``; everything else produces
    ``ObjectConstraints``.

    Note: the ``cls`` argument is accepted but never read.
    """

    def strict_fallback():
        # Plain type-based constraints, for classes we won't introspect.
        return cast(
            ConstraintsProtocolT, _from_strict_type(t, nullable=nullable, name=name)
        )

    if not (istypeddict(t) or isnamedtuple(t)) and isbuiltinsubtype(t):
        return strict_fallback()
    try:
        params: Dict[str, inspect.Parameter] = dict(cached_signature(t).parameters)
        hints = cached_type_hints(t)
        # Prefer the resolved type hints over the raw signature annotations.
        for shared in hints.keys() & params.keys():
            params[shared] = params[shared].replace(annotation=hints[shared])
        # Hack in the classvars as "parameters" to allow for validation.
        for extra in hints.keys() - params.keys():
            hint = hints[extra]
            if isclassvartype(hint):
                default = getattr(t, extra, empty)
                if not get_args(hint):
                    # Bare `ClassVar`: infer the inner type from the class value.
                    hint = ClassVar[default.__class__]  # type: ignore
                params[extra] = inspect.Parameter(
                    extra,
                    inspect.Parameter.KEYWORD_ONLY,
                    default=default,
                    annotation=hint,
                )
    except (ValueError, TypeError):
        return strict_fallback()
    name = name or get_name(t)
    items: Optional[frozendict.FrozenDict[Hashable, ConstraintsT]] = (
        frozendict.FrozenDict(_resolve_params(t, **params)) or None
    )
    variadic = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    # A field is required when it is not *args/**kwargs and has no default.
    required = frozenset(
        pname
        for pname, param in params.items()
        if param.kind not in variadic and param.default is param.empty
    )
    has_varargs = any(param.kind in variadic for param in params.values())
    kwargs = {
        "type": t,
        "nullable": nullable,
        "name": name,
        "required_keys": required,
        "items": items,
        "total": not has_varargs,
    }
    constraints_cls = ObjectConstraints
    if istypeddict(t):
        constraints_cls = TypedDictConstraints
        kwargs.update(type=dict, ttype=t, total=getattr(t, "__total__", bool(required)))
    return cast(ConstraintsProtocolT, constraints_cls(**kwargs))  # type: ignore
Exemplo n.º 10
0
 def _build_des(  # noqa: C901
         self,
         annotation: "Annotation",
         func_name: str,
         namespace: Type = None) -> Callable:
     """Compile a deserializer function named ``func_name`` for ``annotation``.

     ``Literal`` annotations are delegated to ``_build_literal_des``. For
     anything else, a function taking one argument (``self.VNAME``) is
     generated. Unless the origin is in ``self.UNRESOLVABLE``, sanity checks
     are written, and then the first matching branch of the ``if``/``elif``
     chain below writes the body. A trailing ``return`` of the value is always
     appended. Branch order is significant: for example, plain builtin types
     are matched before TypedDict/NamedTuple and before the builtin-subtype
     fallback.
     """
     args = annotation.args
     # Get the "origin" of the annotation.
     # For natives and their typing.* equivs, this will be a builtin type.
     # For SpecialForms (Union, mainly) this will be the un-subscripted type.
     # For custom types or classes, this will be the same as the annotation.
     origin = annotation.resolved_origin
     anno_name = get_unique_name(origin)
     # Namespace the generated code is compiled against.
     ns = {
         anno_name: origin,
         "parent": getattr(origin, "__parent__", origin),
         "issubclass": cached_issubclass,
         **annotation.serde.asdict(),
     }
     if checks.isliteral(origin):
         return self._build_literal_des(annotation, func_name, namespace)
     with gen.Block(ns) as main:
         with main.f(func_name, main.param(f"{self.VNAME}")) as func:
             if origin not in self.UNRESOLVABLE:
                 self._set_checks(func, anno_name, annotation)
                 # Ordered dispatch: the first matching branch writes the body.
                 if origin is Union:
                     self._build_union_des(func, annotation, namespace)
                 elif checks.isdatetype(origin):
                     self._build_date_des(func, anno_name, annotation)
                 elif checks.istimetype(origin):
                     self._build_time_des(func, anno_name, annotation)
                 elif checks.istimedeltatype(origin):
                     self._build_timedelta_des(func, anno_name, annotation)
                 elif checks.isuuidtype(origin):
                     self._build_uuid_des(func, anno_name, annotation)
                 elif origin in {Pattern, re.Pattern}:  # type: ignore
                     self._build_pattern_des(func, anno_name)
                 # NOTE(review): bare `issubclass` raises TypeError if `origin`
                 # is not a class. Confirm such origins cannot reach this point.
                 elif issubclass(origin, pathlib.Path):
                     self._build_path_des(func, anno_name)
                 elif not args and checks.isbuiltintype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 elif checks.isfromdictclass(origin):
                     self._build_fromdict_des(func, anno_name)
                 # Enums reuse the builtin deserializer.
                 elif checks.isenumtype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 elif checks.istypeddict(origin):
                     self._build_typeddict_des(
                         func,
                         anno_name,
                         annotation,
                         total=origin.__total__,  # type: ignore
                         namespace=namespace,
                     )
                 elif checks.istypedtuple(origin) or checks.isnamedtuple(
                         origin):
                     self._build_typedtuple_des(func,
                                                anno_name,
                                                annotation,
                                                namespace=namespace)
                 elif not args and checks.isbuiltinsubtype(origin):
                     self._build_builtin_des(func, anno_name, annotation)
                 # Mappings and tuples are collections, so test them first.
                 elif checks.ismappingtype(origin):
                     self._build_mapping_des(func,
                                             anno_name,
                                             annotation,
                                             namespace=namespace)
                 elif checks.istupletype(origin):
                     self._build_tuple_des(func,
                                           anno_name,
                                           annotation,
                                           namespace=namespace)
                 elif checks.iscollectiontype(origin):
                     self._build_collection_des(func,
                                                anno_name,
                                                annotation,
                                                namespace=namespace)
                 # Catch-all for user-defined classes.
                 else:
                     self._build_generic_des(func,
                                             anno_name,
                                             annotation,
                                             namespace=namespace)
             func.l(f"{gen.Keyword.RET} {self.VNAME}")
     deserializer = main.compile(ns=ns, name=func_name)
     return deserializer
Exemplo n.º 11
0
    def _compile_serializer(self, annotation: "Annotation") -> SerializerT:
        """Return a serializer for ``annotation``, compiling one on a cache miss.

        Serializers are cached in ``self._serializer_cache`` under
        ``self._get_name(annotation)``. Dispatch order:

        * dynamic origins or non-static annotations -> ``self.resolver.primitive``
        * enums -> ``_compile_enum_serializer``
        * origins in ``_PRIMITIVES`` -> a generated null-/type-checked
          passthrough (cached here)
        * origins in ``_DEFINED`` -> ``_compile_defined_serializer``
        * subclasses of a ``_DEFINED`` type -> ``_compile_defined_subclass_serializer``
        * subclasses of a primitive -> ``_compile_primitive_subclass_serializer``
        * anything else -> a generated mapping/list/class serializer (cached here)
        """
        func_name = self._get_name(annotation)
        if func_name in self._serializer_cache:
            return self._serializer_cache[func_name]

        def open_function(block):
            # Every generated serializer shares the signature (o, lazy=False, name=None).
            return block.f(
                func_name,
                block.param("o"),
                block.param("lazy", default=False),
                block.param("name", default=None),
            )

        origin = annotation.resolved_origin
        # Lazy shortcut for messy paths (Union, Any, ...).
        if origin in self._DYNAMIC or not annotation.static:
            return self.resolver.primitive
        if checks.isenumtype(annotation.resolved):
            return self._compile_enum_serializer(annotation)
        if origin in self._PRIMITIVES:
            # Primitives only need the null check and the type check.
            ns: dict = {}
            with gen.Block(ns) as main:
                with open_function(main) as func:
                    self._check_add_null_check(func, annotation)
                    self._add_type_check(func, annotation)
                    optional_types = (type(o) for o in self.resolver.OPTIONALS)
                    returned = "None" if annotation.origin in optional_types else "o"
                    func.l(f"{gen.Keyword.RET} {returned}")
            compiled = main.compile(name=func_name, ns=ns)
            self._serializer_cache[func_name] = compiled
            return compiled
        if origin in self._DEFINED:
            return self._compile_defined_serializer(
                annotation, self._DEFINED[origin])
        if issubclass(origin, tuple(self._DEFINED)):
            return self._compile_defined_subclass_serializer(origin, annotation)
        if issubclass(origin, self._PRIMITIVES):
            return self._compile_primitive_subclass_serializer(origin, annotation)

        # Structured types: build a dedicated serializer function.
        anno_name = f"{func_name}_anno"
        ns = {anno_name: origin, **annotation.serde.asdict()}
        with gen.Block(ns) as main:
            with open_function(main) as func:
                is_plain_mapping = not checks.istypeddict(origin) and issubclass(
                    origin, self._DICTITER)
                if is_plain_mapping:
                    self._build_dict_serializer(func, annotation)
                elif not checks.istypedtuple(origin) and issubclass(
                        origin, self._LISTITER):
                    self._build_list_serializer(func, annotation)
                else:
                    self._build_class_serializer(func, annotation)
        compiled = main.compile(name=func_name, ns=ns)
        self._serializer_cache[func_name] = compiled
        return compiled