Exemplo n.º 1
0
    def _compile_translator(
        self, source: Type, target: Type, exclude: Tuple[str, ...] = ()
    ) -> TranslatorT:
        """Generate and compile a function which translates ``source`` into ``target``.

        The generated function takes a single parameter (``o``) and returns
        ``target(...)`` called with assignments built from the fields the two
        types share.

        Args:
            source: The type to translate from.
            target: The type to translate to.
            exclude: Source field names to leave out of the translation.

        Returns:
            The compiled translator function.

        Raises:
            TranslatorTypeError: If either type is a literal type, or the target's
                fields can't be determined and it isn't an iterable type.
            TranslatorValueError: If ``source`` is missing a field that
                ``target`` requires.
        """
        if isliteral(target):
            raise TranslatorTypeError(
                f"Cannot translate to literal type: {target!r}. "
            ) from None
        if isliteral(source):
            raise TranslatorTypeError(
                f"Cannot translate from literal type: {source!r}. "
            ) from None
        # Get the target fields for translation.
        target_fields = self.get_fields(target)
        if target_fields is None:
            # Without a field mapping, iterables are translated element-wise;
            # anything else has nothing to build the target from.
            if isiterabletype(target):
                return self._compile_iterable_translator(source, target)
            raise TranslatorTypeError(
                f"Cannot translate to type {target!r}. "
                f"Unable to determine target fields."
            ) from None

        # Ensure that the target's required fields are a subset of the source
        # fields. We treat the target fields as the parameters for the target,
        # so this must be true.
        fields = self.get_fields(source, as_source=True, exclude=exclude) or {}
        # Filter in source-field order instead of intersecting key views: a set
        # intersection has no guaranteed order, so the generated code could
        # differ from run to run.
        fields_to_pass = {
            fname: field for fname, field in fields.items() if fname in target_fields
        }
        required = self.required_fields(target_fields)
        if not required.issubset(fields_to_pass.keys()):
            # Report exactly what is missing from what we pass, in a stable order.
            diff = (*sorted(required - fields_to_pass.keys(), key=str),)
            raise TranslatorValueError(
                f"{source!r} can't be translated to {target!r}. "
                f"Source is missing required fields: {diff}."
            ) from None
        protocols = self.resolver.protocols(target)

        # Build the translator.
        anno_name = get_unique_name(source)
        target_name = get_unique_name(target)
        func_name = self._get_name(source, target)
        oname = "o"
        ctx: Dict[str, Any] = {target_name: target, anno_name: source}
        with Block(ctx) as main:
            with main.f(func_name, Block.p(oname)) as func:
                args = ", ".join(
                    self._iter_field_assigns(fields_to_pass, oname, protocols, ctx)
                )
                func.l(f"{Keyword.RET} {target_name}({args})")
        trans = main.compile(name=func_name, ns=ctx)
        return trans
Exemplo n.º 2
0
def get_tag_for_types(types: Tuple[Type, ...]) -> Optional[TaggedUnion]:
    """Find a field which can discriminate between the members of a union.

    A field qualifies as the union's "tag" when every type in ``types`` has it in
    its type hints and either:

    * every type defines a class-level value for it which is not a
      ``__slots__`` member descriptor (i.e. a constant), or
    * every type annotates it with a ``Literal[...]``.

    Candidate fields are tried in the first type's declaration order, the
    constant form before the ``Literal`` form. A candidate whose values would
    map more than one type to the same value is skipped, since it cannot tell
    those types apart.

    Parameters
    ----------
    types
        The members of the union.

    Returns
    -------
    A :py:class:`TaggedUnion` describing the tag, or ``None`` when no field
    qualifies, fewer than two types are given, or any member is ``None``,
    ``...``, not a class, or a standard-library type.
    """
    if len(types) < 2:
        return None
    if any(
        t in {None, ...} or not inspect.isclass(t) or checks.isstdlibtype(t)
        for t in types
    ):
        return None
    fields_by_type = {t: cached_type_hints(t) for t in types}
    # TODO: This won't support Generics in this state.
    #  We don't support generics yet (#119), but when we do,
    #  we need to add a branch for tagged unions from generics.
    # Walk the first type's fields in declaration order (rather than popping
    # from a set) so the chosen tag is deterministic.
    for tag in fields_by_type[types[0]]:
        # Every member must declare the field.
        if any(tag not in fields_by_type[t] for t in types):
            continue
        # Option 1: every member has a constant class-level value for the field.
        values = [getattr(t, tag, empty) for t in types]
        if all(
            v is not empty and not isinstance(v, MemberDescriptorType)
            for v in values
        ):
            tbv = (*zip(values, types),)
            if not _has_duplicate_values(tbv):
                return TaggedUnion(
                    tag=tag, types=types, isliteral=False, types_by_values=tbv
                )
        # Option 2: every member annotates the field with a Literal.
        hints = [fields_by_type[t][tag] for t in types]
        if all(checks.isliteral(h) for h in hints):
            tbv = (*((a, t) for h, t in zip(hints, types) for a in get_args(h)),)
            if not _has_duplicate_values(tbv):
                return TaggedUnion(
                    tag=tag, types=types, isliteral=True, types_by_values=tbv
                )
    return None


def _has_duplicate_values(types_by_values: Tuple[Tuple[object, Type], ...]) -> bool:
    """Return whether any tag value in ``types_by_values`` appears more than once."""
    # A list (compared with ``==``) rather than a set, so unhashable values
    # don't raise here.
    seen = []
    for value, _ in types_by_values:
        if value in seen:
            return True
        seen.append(value)
    return False
Exemplo n.º 3
0
 def _build_des(  # noqa: C901
     self,
     annotation: Annotation[Type[ObjectT]],
     func_name: str,
     namespace: Type = None,
 ) -> DeserializerT[ObjectT]:
     """Generate and compile the deserializer function for ``annotation``.

     Literal origins are delegated to ``_build_literal_des``. Otherwise the
     generated function is named ``func_name`` and takes one value parameter.
     For resolvable origins, the top-level checks are written first. Then the
     first entry in ``self._HANDLERS`` whose check accepts the origin and its
     args writes the body. A trailing ``return`` of the value is appended
     unless the chosen handler returned ``False``.
     """
     type_args = annotation.args
     # The resolved origin is:
     # - a builtin for natives and their typing.* equivalents,
     # - the un-subscripted special form for SpecialForms (mainly Union),
     # - the annotation itself for custom types and classes.
     resolved = annotation.resolved_origin
     type_name = get_unique_name(resolved)
     local_ns = {
         type_name: resolved,
         "parent": getattr(resolved, "__parent__", resolved),
         "issubclass": cached_issubclass,
         **annotation.serde.asdict(),
     }
     if checks.isliteral(resolved):
         return self._build_literal_des(annotation, func_name, namespace)
     with gen.Block(local_ns) as module:
         with module.f(func_name, module.param(f"{self.VNAME}")) as body:
             add_return = None
             build_ctx = BuildContext(
                 annotation, local_ns, type_name, body, namespace
             )
             if resolved not in self.UNRESOLVABLE:
                 self._set_checks(body, type_name, annotation)
                 # First handler whose check matches wins.
                 handler = next(
                     (
                         write
                         for accepts, write in self._HANDLERS.items()
                         if accepts(resolved, type_args)
                     ),
                     None,
                 )
                 if handler is not None:
                     add_return = handler(self, build_ctx)
             # Only skip the trailing return when the handler explicitly says so.
             if add_return is not False:
                 body.l(f"{gen.Keyword.RET} {self.VNAME}")
     return module.compile(ns=local_ns, name=func_name)
Exemplo n.º 4
0
    def annotation(
        self,
        annotation: Type[ObjectT],
        name: str = None,
        parameter: Optional[inspect.Parameter] = None,
        is_optional: bool = None,
        is_strict: StrictModeT = None,
        flags: "SerdeFlags" = None,
        default: Any = EMPTY,
        namespace: Type = None,
    ) -> AnnotationT:
        """Get a :py:class:`Annotation` for this type.

        Unlike a :py:class:`ResolvedAnnotation`, this does not provide access to a
        serializer/deserializer/validator protocol.

        Parameters
        ----------
        annotation
            The type hint to resolve.
        name
            The field/parameter name, used when synthesizing ``parameter``.
        parameter
            The :py:class:`inspect.Parameter` for this field. If not provided,
            one is synthesized from ``name``, ``annotation`` and ``default``.
        is_optional
            Force the annotation to be treated as optional.
        is_strict
            Force strict mode for this annotation.
        flags
            Serde flags to use when ``annotation`` doesn't define its own.
        default
            The default value for the field, if any.
        namespace
            The owner of this annotation (e.g. a class). It is used to resolve
            forward references and to detect direct self-references.

        Returns
        -------
        The returned annotation depends on the resolved type:

        * a :py:class:`ForwardDelayedAnnotation` for a forward reference;
        * a :py:class:`DelayedAnnotation` for a type which is recursive or part
          of a recursive loop;
        * an :py:class:`Annotation` otherwise.

        Raises
        ------
        TypeError
            If a ``Literal`` is parameterized with values not in
            ``self.LITERALS``.
        """
        # Flags defined on the annotation itself take precedence.
        flags = cast(
            "SerdeFlags",
            getattr(annotation, SERDE_FLAGS_ATTR, flags or SerdeFlags()))
        if parameter is None:
            parameter = inspect.Parameter(
                name or "_",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                # Unhashable defaults are replaced with `...` here.
                default=default if checks.ishashable(default) else ...,
            )
        # Check for the super-type
        non_super = util.resolve_supertype(annotation)
        # Note, this may be a generic, like Union.
        orig = util.origin(annotation)
        use = non_super
        # Get the unfiltered args
        args = getattr(non_super, "__args__", None)
        # Set whether this is optional/strict
        is_optional = (is_optional or checks.isoptionaltype(non_super)
                       or parameter.default in self.OPTIONALS)
        is_strict = is_strict or checks.isstrict(non_super) or self.STRICT
        is_static = util.origin(use) not in self._DYNAMIC
        is_literal = checks.isliteral(use)
        none_type = type(None)
        # Determine whether we should use the first arg of the annotation
        while checks.should_unwrap(use) and args:
            use_is_optional = checks.isoptionaltype(use)
            is_optional = is_optional or use_is_optional
            is_strict = is_strict or checks.isstrict(use)
            if use_is_optional:
                # `NoneType` may appear anywhere in the union, e.g.
                # `Union[None, int, str]`, so drop it explicitly rather than
                # assuming it's the last arg. Keep the args as-is if nothing
                # else would remain.
                args = tuple(a for a in args if a is not none_type) or args
                if len(args) > 1:
                    # We can't resolve this annotation to a single type.
                    is_static = False
                    use = Union[args]
                    break
            # Note that we don't re-assign `orig`.
            # This is intentional.
            # Special forms are needed for building the downstream validator.
            # Callers should be aware of this and perhaps use `util.origin` elsewhere.
            non_super = util.resolve_supertype(args[0])
            use = non_super
            args = util.get_args(use)
            is_static = util.origin(use) not in self._DYNAMIC
            is_literal = is_literal or checks.isliteral(use)

        # Only allow legal parameters at runtime, this has implementation implications.
        if is_literal:
            args = util.get_args(use)
            if any(not isinstance(a, self.LITERALS) for a in args):
                raise TypeError(
                    f"PEP 586: Unsupported parameters for 'Literal' type: {args}. "
                    "See https://www.python.org/dev/peps/pep-0586/"
                    "#legal-parameters-for-literal-at-type-check-time "
                    "for more information.")
        # The type definition doesn't exist yet.
        if use.__class__ is ForwardRef:
            module, localns = self.__module__, {}
            # Ideally we have a namespace from a parent class/function to the field
            if namespace:
                module = namespace.__module__
                localns = getattr(namespace, "__dict__", {})

            return ForwardDelayedAnnotation(
                ref=use,
                resolver=self,
                _name=name,
                parameter=parameter,
                is_optional=is_optional,
                is_strict=is_strict,
                flags=flags,
                default=default,
                module=module,
                localns=localns,
            )
        # The type definition is recursive or within a recursive loop.
        elif use is namespace or use in self.__stack:
            # If detected via stack, we can remove it now.
            # Otherwise we'll cause another recursive loop.
            if use in self.__stack:
                self.__stack.remove(use)
            return DelayedAnnotation(
                type=use,
                resolver=self,
                _name=name,
                parameter=parameter,
                is_optional=is_optional,
                is_strict=is_strict,
                flags=flags,
                default=default,
            )
        # Otherwise, add this type to the stack to prevent a recursive loop from elsewhere.
        if not checks.isstdlibtype(use):
            self.__stack.add(use)
        # Only static, non-literal types get a per-type configuration.
        serde = (self._get_configuration(util.origin(use), flags)
                 if is_static and not is_literal else SerdeConfig(flags))

        anno = Annotation(
            resolved=use,
            origin=orig,
            un_resolved=annotation,
            parameter=parameter,
            optional=is_optional,
            strict=is_strict,
            static=is_static,
            serde=serde,
        )
        anno.translator = functools.partial(self.translator.factory,
                                            anno)  # type: ignore
        return anno
Exemplo n.º 5
0
    def get_field(
        self,
        protocol: SerdeProtocol,
        *,
        ro: bool = None,
        wo: bool = None,
        name: str = None,
        parent: Type = None,
    ) -> "SchemaFieldT":
        """Get a field definition for a JSON Schema.

        Parameters
        ----------
        protocol
            The serde protocol whose annotation describes the field.
        ro
            Whether the field is read-only. It is recomputed when the annotation
            is ``ReadOnly``, ``WriteOnly`` or ``Final``.
        wo
            Whether the field is write-only. It is recomputed in the same cases
            as ``ro``.
        name
            The field's name.
        parent
            The type which owns the field; passed through to the builders.

        Returns
        -------
        The schema for the field. An annotation which is already being built
        higher up the call stack is returned as a ``$ref`` to its definition.
        Results are cached per annotation, except unions, which are delegated
        to ``_handle_union``.
        """
        anno = protocol.annotation
        # Already being built further up the stack: this is a recursive type,
        # so refer to its definition instead of inlining it again.
        if anno in self.__stack:
            name = self.defname(anno.resolved_origin, name)
            return self._check_optional(anno,
                                        Ref(f"#/definitions/{name}"), ro, wo,
                                        name)
        if anno in self.__cache:
            return self.__cache[anno]
        # Get the default value
        # `None` gets filtered out down the line. this is okay.
        # If a field isn't required an empty default is functionally the same
        # as a default to None for the JSON schema.
        default = anno.parameter.default if anno.has_default else None
        # `use` is the based annotation we will use for building the schema
        use = getattr(anno.origin, "__parent__", anno.origin)
        # This is a flat optional, handle it separately from the Union block.
        use = anno.resolved if isuniontype(use) and not anno.args else use
        # If there's not a static annotation, short-circuit the rest of the checks.
        schema: SchemaFieldT
        if use in {Any, anno.EMPTY}:
            schema = self._check_optional(anno, UndeclaredSchemaField(), ro,
                                          wo, name)
            self.__cache[anno] = schema
            return schema

        # Unions are `anyOf`, get a new field for each arg and return.
        # {'type': ['string', 'integer']} ==
        #   {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}
        # We don't care about syntactic sugar if it's functionally the same.
        if isuniontype(use):
            return self._handle_union(anno=anno,
                                      ro=ro,
                                      wo=wo,
                                      name=name,
                                      parent=parent)

        self.__stack.add(anno)
        try:
            # Check if this should be ro/wo
            if use in {ReadOnly, WriteOnly, Final}:
                ro = (use in {ReadOnly, Final}) or None
                wo = (use is WriteOnly) or None
                use = origin(anno.resolved)
                use = getattr(use, "__parent__", use)

            # Check for an enumeration
            enum_ = None
            # Functionally, literals are enumerations.
            if isliteral(use):
                enum_ = (*(a for a in anno.args if a is not None), )
                ts = {a.__class__ for a in enum_}
                use = Literal
                if len(ts) == 1:
                    use = ts.pop()

            elif issubclass(use, enum.Enum):
                use = cast(Type[enum.Enum], use)
                enum_ = tuple(x.value for x in use)
                use = getattr(use._member_type_, "__parent__",
                              use._member_type_)  # type: ignore

            # If this is ro with a default, we can consider this a const
            # Which is an enum with a single value -
            # we don't currently honor `{'const': <val>}` since it's just syntactic sugar.
            # Compare against None (not truthiness) so falsy defaults such as
            # 0, False or "" are still treated as consts.
            if ro and default is not None:
                enum_ = (default.value
                         if isinstance(default, enum.Enum) else default, )

            schema = self._build_field(
                use=use,
                protocol=protocol,
                parent=parent,
                enum_=enum_,
                default=default,
                ro=ro,
                wo=wo,
                name=name,
            )
        finally:
            # Pop only this annotation, even on error. Clearing the whole stack
            # would also forget the enclosing annotations still being built,
            # weakening recursion detection for the rest of their fields.
            self.__stack.discard(anno)
        self.__cache[anno] = schema
        return schema
Exemplo n.º 6
0
 def _build_des(  # noqa: C901
         self,
         annotation: "Annotation",
         func_name: str,
         namespace: Type = None) -> Callable:
     """Generate and compile the deserializer function for ``annotation``.

     Literal origins are delegated to ``_build_literal_des``. Otherwise the
     generated function is named ``func_name`` and takes one value parameter.
     For resolvable origins, it writes the top-level checks, then dispatches on
     the resolved origin to the matching ``_build_*_des`` writer. The generated
     function always ends by returning the value.
     """
     type_args = annotation.args
     # The resolved origin is:
     # - a builtin for natives and their typing.* equivalents,
     # - the un-subscripted special form for SpecialForms (mainly Union),
     # - the annotation itself for custom types and classes.
     resolved = annotation.resolved_origin
     type_name = get_unique_name(resolved)
     local_ns = {
         type_name: resolved,
         "parent": getattr(resolved, "__parent__", resolved),
         "issubclass": cached_issubclass,
         **annotation.serde.asdict(),
     }
     if checks.isliteral(resolved):
         return self._build_literal_des(annotation, func_name, namespace)

     def emit_body(body) -> None:
         # Checks are evaluated in priority order; the first match wins.
         if resolved is Union:
             self._build_union_des(body, annotation, namespace)
             return
         if checks.isdatetype(resolved):
             self._build_date_des(body, type_name, annotation)
             return
         if checks.istimetype(resolved):
             self._build_time_des(body, type_name, annotation)
             return
         if checks.istimedeltatype(resolved):
             self._build_timedelta_des(body, type_name, annotation)
             return
         if checks.isuuidtype(resolved):
             self._build_uuid_des(body, type_name, annotation)
             return
         if resolved in {Pattern, re.Pattern}:  # type: ignore
             self._build_pattern_des(body, type_name)
             return
         if issubclass(resolved, pathlib.Path):
             self._build_path_des(body, type_name)
             return
         if not type_args and checks.isbuiltintype(resolved):
             self._build_builtin_des(body, type_name, annotation)
             return
         if checks.isfromdictclass(resolved):
             self._build_fromdict_des(body, type_name)
             return
         if checks.isenumtype(resolved):
             self._build_builtin_des(body, type_name, annotation)
             return
         if checks.istypeddict(resolved):
             self._build_typeddict_des(
                 body,
                 type_name,
                 annotation,
                 total=resolved.__total__,  # type: ignore
                 namespace=namespace,
             )
             return
         if checks.istypedtuple(resolved) or checks.isnamedtuple(resolved):
             self._build_typedtuple_des(
                 body, type_name, annotation, namespace=namespace
             )
             return
         if not type_args and checks.isbuiltinsubtype(resolved):
             self._build_builtin_des(body, type_name, annotation)
             return
         if checks.ismappingtype(resolved):
             self._build_mapping_des(
                 body, type_name, annotation, namespace=namespace
             )
             return
         if checks.istupletype(resolved):
             self._build_tuple_des(
                 body, type_name, annotation, namespace=namespace
             )
             return
         if checks.iscollectiontype(resolved):
             self._build_collection_des(
                 body, type_name, annotation, namespace=namespace
             )
             return
         # Fallback for everything else.
         self._build_generic_des(
             body, type_name, annotation, namespace=namespace
         )

     with gen.Block(local_ns) as module:
         with module.f(func_name, module.param(f"{self.VNAME}")) as body:
             if resolved not in self.UNRESOLVABLE:
                 self._set_checks(body, type_name, annotation)
                 emit_body(body)
             body.l(f"{gen.Keyword.RET} {self.VNAME}")
     return module.compile(ns=local_ns, name=func_name)