Exemplo n.º 1
0
def _from_union(
    t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ConstraintsProtocolT:
    """Build constraints for a ``Union`` annotation.

    If ``t`` is an optional type (as reported by ``isoptionaltype``), the
    resulting constraints are nullable and the ``None`` member is removed
    before the remaining members are turned into constraints.

    Args:
        t: The union annotation.
        nullable: Force the resulting constraints to accept ``None``.
        name: Optional name for the resulting constraints.
        cls: The enclosing class, forwarded to ``get_constraints`` so
            forward/recursive references can be resolved.

    Returns:
        The constraints of the single remaining member when only one is left,
        otherwise a ``MultiConstraints`` over all remaining members.
    """
    _nullable: bool = isoptionaltype(t)
    nullable = nullable or _nullable
    _args = get_args(t)
    if _nullable:
        # Drop the ``None`` member by identity rather than assuming it is the
        # last argument: ``Union[None, int]`` lists ``NoneType`` first.
        _args = tuple(a for a in _args if a is not None and a is not type(None))
    if len(_args) == 1:
        return get_constraints(_args[0], nullable=nullable, name=name, cls=cls)
    c = MultiConstraints(
        (*(get_constraints(a, nullable=nullable, cls=cls) for a in _args),),
        name=name,
        tag=get_tag_for_types(_args),
    )
    return cast(ConstraintsProtocolT, c)
Exemplo n.º 2
0
 def __post_init__(self):
     """Derive cached metadata from the annotation fields.

     Sets ``has_default``, ``args``, ``resolved_origin``, ``generic`` and
     ``is_class_var`` on the instance.
     """
     default = self.parameter.default
     self.has_default = default is not self.EMPTY
     resolved = self.resolved
     self.args = util.get_args(resolved)
     computed_origin = util.origin(resolved)
     self.resolved_origin = computed_origin
     # Prefer the raw ``__origin__`` attribute; fall back to ``util.origin``.
     self.generic = getattr(resolved, "__origin__", computed_origin)
     self.is_class_var = isclassvartype(self.un_resolved)
Exemplo n.º 3
0
 def _handle_union(
     self,
     anno: Annotation,
     ro: Optional[bool],
     wo: Optional[bool],
     name: Optional[str],
     parent: Optional[Type],
 ):
     """Build (and cache) a JSON-schema ``anyOf`` field for a union annotation.

     Each member of ``anno.un_resolved`` becomes one entry of ``anyOf``.
     Forward references and direct references to ``parent`` become ``Ref``
     entries pointing into ``#/definitions/``. Every other member is resolved
     against ``parent`` and converted with ``get_field``. The result goes
     through ``_check_optional`` and is stored in the schema cache under
     ``anno``.

     NOTE(review): when ``name`` is given, every self/forward reference in the
     union points at ``#/definitions/{name}``, whichever member it is.
     Confirm this is intended.
     """
     members: List[SchemaFieldT] = []
     for member in get_args(anno.un_resolved):
         is_reference = member.__class__ is ForwardRef or member is parent
         if is_reference:
             ref_name = name or get_name(member)
             members.append(Ref(f"#/definitions/{ref_name}"))
         else:
             resolved = resolver.resolve(member, namespace=parent)
             members.append(self.get_field(resolved, parent=parent))
     multi = MultiSchemaField(
         title=name and self.defname(anno.resolved, name=name),
         anyOf=tuple(members),
     )
     schema = self._check_optional(anno, multi, ro, wo, name)
     self.__cache[anno] = schema
     return schema
Exemplo n.º 4
0
 def _evaluate_contraints(self):
     """Build constraints for ``self.t`` via ``self.factory``.

     An optional ``self.t`` is unwrapped first. Its ``None`` member is
     removed, the remaining member (or a ``Union`` of the remaining members)
     replaces ``self.t``, and ``self.nullable`` is set to ``True``.
     """
     resolved = self.t
     if checks.isoptionaltype(resolved):
         # Remove ``None`` by identity instead of assuming it is the last
         # argument of the union.
         args = tuple(
             a
             for a in util.get_args(resolved)
             if a is not None and a is not type(None)
         )
         resolved = args[0] if len(args) == 1 else Union[args]
         self.nullable = True
         self.t = resolved
     c = self.factory(resolved, nullable=self.nullable, name=self.name)
     return c
Exemplo n.º 5
0
def get_constraints(
    t: Type[VT],
    *,
    nullable: bool = False,
    name: str = None,
    cls: Optional[Type] = ...,  # type: ignore
) -> ConstraintsProtocolT[VT]:
    """Build the constraints object for the type ``t``.

    Resolution order:

    1. While ``should_unwrap(t)`` holds, ``t`` is replaced by its first type
       argument; an optional wrapper marks the result nullable.
    2. If ``t`` is ``cls`` or is already being built (on the module-level
       stack), ``DelayedConstraints`` are returned to break the recursion.
    3. Forward references yield ``ForwardDelayedConstraints`` evaluated in
       the module and ``__dict__`` of ``cls``.
    4. Types carrying ``__constraints__`` reuse them, copied with
       ``dataclasses.replace`` when ``name``/``nullable`` differ.
    5. Enums, abstract types, named tuples / typed dicts and ``type`` /
       ``Callable`` origins use dedicated builders. Anything else is
       dispatched on its origin through ``_CONSTRAINT_BUILDER_HANDLERS``,
       falling back to ``_from_class``.

    Args:
        t: The type to build constraints for.
        nullable: Whether ``None`` is an accepted value.
        name: Optional name for the constraints.
        cls: The enclosing class, if any. Forward references require a real
            class here (not ``...`` or ``None``).

    Returns:
        The constraints for ``t``.

    Raises:
        TypeError: If ``t`` is a forward reference and no enclosing class was
            provided.
    """
    while should_unwrap(t):
        nullable = nullable or isoptionaltype(t)
        t = get_args(t)[0]
    if t is cls or t in __stack:
        dc = DelayedConstraints(
            t, nullable=nullable, name=name, factory=get_constraints
        )
        return cast(ConstraintsProtocolT, dc)
    if isforwardref(t):
        # ``None`` is the default ``cls`` of several builders that forward it
        # here; without this check it crashed on ``None.__module__``.
        if cls is ... or cls is None:  # pragma: nocover
            raise TypeError(
                f"Cannot build constraints for {t} without an enclosing class."
            )
        fdc = ForwardDelayedConstraints(
            t,  # type: ignore
            cls.__module__,
            localns=getattr(cls, "__dict__", {}).copy(),
            nullable=nullable,
            name=name,
            factory=get_constraints,
        )
        return cast(ConstraintsProtocolT, fdc)
    if isconstrained(t):
        c: ConstraintsProtocolT = t.__constraints__  # type: ignore
        if (c.name, c.nullable) != (name, nullable):
            return dataclasses.replace(c, name=name, nullable=nullable)
        return c
    if isenumtype(t):
        ec = _from_enum_type(t, nullable=nullable, name=name)  # type: ignore
        return cast(ConstraintsProtocolT, ec)
    if isabstract(t):
        return cast(
            ConstraintsProtocolT, _from_strict_type(t, nullable=nullable, name=name)
        )
    if isnamedtuple(t) or istypeddict(t):
        handler = _from_class
    else:
        ot = origin(t)
        if ot in {type, abc.Callable}:
            handler = _from_strict_type  # type: ignore
            t = ot
        else:
            handler = _CONSTRAINT_BUILDER_HANDLERS.get_by_parent(ot, _from_class)  # type: ignore

    __stack.add(t)
    try:
        c = handler(t, nullable=nullable, name=name, cls=cls)
    finally:
        # Always reset recursion tracking, even if the handler raises.
        # Otherwise a stale entry makes later, unrelated calls for the same
        # type return ``DelayedConstraints``. Note this empties the whole
        # module-level set, not just ``t``.
        __stack.clear()
    return c
Exemplo n.º 6
0
 def _evaluate_contraints(self):
     """Resolve ``self.ref`` and build constraints with ``self.factory``.

     The forward reference is evaluated against a copy of the globals of
     ``self.module`` plus ``self.localns``. If evaluation raises ``NameError``,
     a warning is emitted and ``object`` is used instead. An optional result
     is unwrapped (its ``None`` member removed) and ``self.nullable`` is set
     to ``True``.
     """
     globalns = sys.modules[self.module].__dict__.copy()
     try:
         resolved = evaluate_forwardref(self.ref, globalns or {},
                                        self.localns or {})
     except NameError as e:  # pragma: nocover
         warnings.warn(
             f"Couldn't resolve forward reference: {e}. "
             f"Make sure this type is available in {self.module}.")
         resolved = object  # make it a no-op
     if checks.isoptionaltype(resolved):
         # Remove ``None`` by identity instead of assuming it is the last
         # argument of the union.
         args = tuple(
             a
             for a in util.get_args(resolved)
             if a is not None and a is not type(None)
         )
         resolved = args[0] if len(args) == 1 else Union[args]
         self.nullable = True
     c = self.factory(resolved, nullable=self.nullable, name=self.name)
     return c
Exemplo n.º 7
0
def _from_array_type(
    t: Type[Array], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ArrayConstraintsT:
    """Build constraints for an ``Array`` type.

    The constraints class is looked up by the origin of ``t``. Without type
    arguments a naive (unparameterized) constraint is returned.

    A tuple whose arguments do not contain ``...`` resolves its arguments
    with ``multi=False`` (presumably one constraint per position; confirm
    against ``_resolve_args``). Every other case uses ``multi=True``.
    """
    type_args = get_args(t)
    builder = cast(
        Type[ArrayConstraintsT], _ARRAY_CONSTRAINTS_BY_TYPE.get_by_parent(origin(t))
    )
    if not type_args:
        return builder(nullable=nullable, name=name)

    fixed_tuple = builder is TupleConstraints and ... not in type_args
    values = _resolve_args(
        *type_args, cls=cls, nullable=nullable, multi=not fixed_tuple
    )
    return builder(nullable=nullable, values=values, name=name)  # type: ignore
Exemplo n.º 8
0
def _from_mapping_type(
    t: Type[Mapping], *, nullable: bool = False, name: str = None, cls: Type = None
) -> Union[MappingConstraints, DictConstraints]:
    """Build constraints for a mapping type.

    If ``isbuiltintype(t)`` holds, plain ``DictConstraints`` are returned.
    Otherwise ``DictConstraints`` are used when the origin of ``t`` is
    ``dict``, and ``MappingConstraints`` in every other case. When ``t`` has
    type arguments, the key and value types are resolved into key/value
    constraints, with ``cls`` as the enclosing class.
    """
    if isbuiltintype(t):
        return DictConstraints(nullable=nullable, name=name)
    is_dict = getattr(t, "__origin__", t) is dict
    constr_class: Union[Type[MappingConstraints], Type[DictConstraints]] = (
        DictConstraints if is_dict else MappingConstraints
    )
    type_args = get_args(t)
    if not type_args:
        return constr_class(nullable=nullable, name=name)
    key_type, value_type = type_args
    keys = _resolve_args(key_type, cls=cls)
    values = _resolve_args(value_type, cls=cls)
    return constr_class(
        keys=keys, values=values, nullable=nullable, name=name  # type: ignore
    )
Exemplo n.º 9
0
    def _build_dict_serializer(self, func: gen.Function,
                               annotation: "Annotation"):
        """Generate the body of ``func`` as a serializer for a mapping type.

        Key and value serializers default to ``self.resolver.primitive``.
        When the resolved annotation has type arguments, the key and value
        types are turned into annotations (inheriting the parent's serde
        flags), and serializers for them are built with ``self.factory``.
        The key serializer is then wrapped by ``_build_key_serializer``
        before the serializer mapping is assembled and finalized.
        """
        kser_: SerializerT = self.resolver.primitive
        vser_: SerializerT = self.resolver.primitive
        type_args = util.get_args(annotation.resolved)
        if type_args:
            key_type, value_type = type_args
            serde_flags = annotation.serde.flags
            key_anno: "AnnotationT" = self.resolver.annotation(
                key_type, flags=serde_flags)
            value_anno: "AnnotationT" = self.resolver.annotation(
                value_type, flags=serde_flags)
            kser_ = self.factory(key_anno)
            vser_ = self.factory(value_anno)
        kser_ = self._build_key_serializer(f"{func.name}_kser", kser_,
                                           annotation)
        serdict = make_kv_serdict(annotation, kser_, vser_)
        self._finalize_mapping_serializer(func, serdict, annotation)
Exemplo n.º 10
0
 def _build_dict_serializer(self, func: gen.Function,
                            annotation: Annotation):
     """Emit a mapping serializer into the generated function ``func``.

     Key and value serializers default to ``self.resolver.primitive``. They
     are replaced by ``self.factory`` serializers when the resolved
     annotation has key/value type arguments. After the null/type checks,
     the emitted return statement:

     - maps outgoing keys through ``fields_out`` when the annotation defines
       it, and through the case transformer when a case flag is set;
     - drops entries whose value is in ``omit`` when that flag is set;
     - yields ``(key, value)`` pairs lazily when ``lazy`` is truthy, else
       builds a ``dict``.

     ``o`` and ``lazy`` are names expected in the generated function's scope.
     """
     # Check for args
     kser_: SerializerT
     vser_: SerializerT
     kser_, vser_ = cast(SerializerT, self.resolver.primitive), cast(
         SerializerT, self.resolver.primitive)
     args = util.get_args(annotation.resolved)
     if args:
         kt, vt = args
         ktr: Annotation = self.resolver.annotation(
             kt, flags=annotation.serde.flags)
         vtr: Annotation = self.resolver.annotation(
             vt, flags=annotation.serde.flags)
         kser_, vser_ = (self.factory(ktr), self.factory(vtr))
     # Add sanity checks.
     self._check_add_null_check(func, annotation)
     self._add_type_check(func, annotation)
     ns: Dict[str, Any] = {
         "kser": kser_,
         "vser": vser_,
     }
     ksercall = "kser(k)"
     if annotation.serde.fields_out:
         ns["fields_out"] = annotation.serde.fields_out
         ksercall = "kser(fields_out.get(k, k))"
     if annotation.serde.flags.case:
         ns["case"] = annotation.serde.flags.case.transformer
         ksercall = f"case({ksercall})"
     gencall = f"({ksercall}, vser(v)) for k, v in o.items()"
     itercall = f"{ksercall}: vser(v) for k, v in o.items()"
     if annotation.serde.flags.omit:
         ns["omit"] = annotation.serde.flags.omit
         # The leading space keeps the generated comprehension well-formed:
         # ``o.items() if ...`` rather than ``o.items()if ...``.
         omit_filter = " if v not in omit"
         gencall += omit_filter
         itercall += omit_filter
     func.l(f"{gen.Keyword.RET} ({gencall}) if lazy else {{{itercall}}}",
            **ns)
Exemplo n.º 11
0
    def annotation(
        self,
        annotation: Type[ObjectT],
        name: str = None,
        parameter: Optional[inspect.Parameter] = None,
        is_optional: bool = None,
        is_strict: StrictModeT = None,
        flags: "SerdeFlags" = None,
        default: Any = EMPTY,
        namespace: Type = None,
    ) -> AnnotationT:
        """Get a :py:class:`Annotation` for this type.

        Unlike a :py:class:`ResolvedAnnotation`, this does not provide access to a
        serializer/deserializer/validator protocol.

        Wrapper types for which ``checks.should_unwrap`` holds are peeled down to
        their first type argument. Optional/strict/literal-ness is accumulated
        along the way.

        Args:
            annotation: The (possibly un-resolved) type to inspect.
            name: Field name. Used for the synthesized parameter (``"_"`` if
                omitted) and passed through to delayed annotations.
            parameter: An existing signature parameter. When omitted, one is
                synthesized from ``name``, ``annotation`` and ``default``.
            is_optional: Force the result to be optional.
            is_strict: Force strict mode.
            flags: Serde flags. Flags stored on ``annotation`` under
                ``SERDE_FLAGS_ATTR`` take precedence. A fresh ``SerdeFlags()``
                is used when neither is available.
            default: Default value for the synthesized parameter.
            namespace: The enclosing class/namespace. Used to resolve forward
                references and to detect direct self-references.

        Returns:
            A ``ForwardDelayedAnnotation`` when the unwrapped type is a
            ``ForwardRef``. A ``DelayedAnnotation`` when the type is
            ``namespace`` itself or is already on the resolver's stack.
            Otherwise a fully built ``Annotation`` with its ``translator`` bound.

        Raises:
            TypeError: If a ``Literal`` has parameters that are not instances of
                ``self.LITERALS`` (see PEP 586).
        """
        # Flags attached to the type itself win over the ``flags`` argument.
        flags = cast(
            "SerdeFlags",
            getattr(annotation, SERDE_FLAGS_ATTR, flags or SerdeFlags()))
        if parameter is None:
            # Unhashable defaults are replaced with ``...``. Presumably this
            # keeps the ``parameter.default in self.OPTIONALS`` test below safe
            # (TODO confirm).
            parameter = inspect.Parameter(
                name or "_",
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=default if checks.ishashable(default) else ...,
            )
        # Check for the super-type
        non_super = util.resolve_supertype(annotation)
        # Note, this may be a generic, like Union.
        orig = util.origin(annotation)
        use = non_super
        # Get the unfiltered args
        args = getattr(non_super, "__args__", None)
        # Set whether this is optional/strict
        is_optional = (is_optional or checks.isoptionaltype(non_super)
                       or parameter.default in self.OPTIONALS)
        is_strict = is_strict or checks.isstrict(non_super) or self.STRICT
        is_static = util.origin(use) not in self._DYNAMIC
        is_literal = checks.isliteral(use)
        # Determine whether we should use the first arg of the annotation
        while checks.should_unwrap(use) and args:
            is_optional = is_optional or checks.isoptionaltype(use)
            is_strict = is_strict or checks.isstrict(use)
            if is_optional and len(args) > 2:
                # We can't resolve this annotation.
                # An optional union with several non-None members can't be
                # collapsed to one type: keep it as a (non-static) Union.
                # NOTE(review): ``args[:-1]`` assumes the ``None`` member is the
                # last argument. Also, ``is_optional`` may be True here only
                # because of the parameter default. Confirm both are intended.
                is_static = False
                use = Union[args[:-1]]
                break
            # Note that we don't re-assign `orig`.
            # This is intentional.
            # Special forms are needed for building the downstream validator.
            # Callers should be aware of this and perhaps use `util.origin` elsewhere.
            non_super = util.resolve_supertype(args[0])
            use = non_super
            args = util.get_args(use)
            is_static = util.origin(use) not in self._DYNAMIC
            is_literal = is_literal or checks.isliteral(use)

        # Only allow legal parameters at runtime, this has implementation implications.
        if is_literal:
            args = util.get_args(use)
            if any(not isinstance(a, self.LITERALS) for a in args):
                raise TypeError(
                    f"PEP 586: Unsupported parameters for 'Literal' type: {args}. "
                    "See https://www.python.org/dev/peps/pep-0586/"
                    "#legal-parameters-for-literal-at-type-check-time "
                    "for more information.")
        # The type definition doesn't exist yet.
        if use.__class__ is ForwardRef:
            # Default to this resolver's module and an empty local namespace.
            module, localns = self.__module__, {}
            # Ideally we have a namespace from a parent class/function to the field
            if namespace:
                module = namespace.__module__
                localns = getattr(namespace, "__dict__", {})

            return ForwardDelayedAnnotation(
                ref=use,
                resolver=self,
                _name=name,
                parameter=parameter,
                is_optional=is_optional,
                is_strict=is_strict,
                flags=flags,
                default=default,
                module=module,
                localns=localns,
            )
        # The type definition is recursive or within a recursive loop.
        elif use is namespace or use in self.__stack:
            # If detected via stack, we can remove it now.
            # Otherwise we'll cause another recursive loop.
            if use in self.__stack:
                self.__stack.remove(use)
            return DelayedAnnotation(
                type=use,
                resolver=self,
                _name=name,
                parameter=parameter,
                is_optional=is_optional,
                is_strict=is_strict,
                flags=flags,
                default=default,
            )
        # Otherwise, add this type to the stack to prevent a recursive loop from elsewhere.
        # Standard-library types are never tracked. Within this method, an
        # entry is only removed again by the recursion branch above.
        if not checks.isstdlibtype(use):
            self.__stack.add(use)
        # Static, non-Literal types get their configuration from
        # ``_get_configuration``; others get a bare ``SerdeConfig`` of the flags.
        serde = (self._get_configuration(util.origin(use), flags)
                 if is_static and not is_literal else SerdeConfig(flags))

        anno = Annotation(
            resolved=use,
            origin=orig,
            un_resolved=annotation,
            parameter=parameter,
            optional=is_optional,
            strict=is_strict,
            static=is_static,
            serde=serde,
        )
        # Bind this annotation as the first argument of the translator factory.
        anno.translator = functools.partial(self.translator.factory,
                                            anno)  # type: ignore
        return anno
Exemplo n.º 12
0
 def values(self):
     """Return the type arguments of ``self.type`` as given by ``util.get_args``."""
     target = self.type
     return util.get_args(target)
Exemplo n.º 13
0
def test_get_args(annotation, args):
    """``get_args`` returns exactly the expected arguments for ``annotation``."""
    actual = get_args(annotation)
    assert actual == args
Exemplo n.º 14
0
def _from_class(
    t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ConstraintsProtocolT[VT]:
    """Build object-like constraints from the signature and type hints of ``t``.

    Builtin subtypes (other than typed dicts / named tuples) use strict-type
    constraints, as do classes whose signature or hints raise ``ValueError``
    or ``TypeError`` when inspected.

    Constructor parameters are re-annotated with their resolved type hints.
    ``ClassVar`` hints that are not constructor parameters are added, in hint
    order, as keyword-only pseudo-parameters so they can be validated too.

    ``TypedDict`` classes produce ``TypedDictConstraints`` (with ``type=dict``
    and ``ttype=t``). Everything else produces ``ObjectConstraints``.

    Args:
        t: The class to inspect.
        nullable: Whether ``None`` is an accepted value.
        name: Name for the constraints; defaults to ``get_name(t)``.
        cls: Not used by this function.

    Returns:
        The constraints for ``t``.
    """
    if not istypeddict(t) and not isnamedtuple(t) and isbuiltinsubtype(t):
        return cast(
            ConstraintsProtocolT, _from_strict_type(t, nullable=nullable, name=name)
        )
    try:
        params: Dict[str, inspect.Parameter] = {**cached_signature(t).parameters}
        hints = cached_type_hints(t)
        # Walk ``hints`` in its own order rather than iterating set operations
        # on key views. Set iteration order of str keys depends on hash
        # randomization, which made the order of the ClassVar pseudo-parameters
        # (and therefore of ``items`` below) vary between interpreter runs.
        shared = [x for x in hints if x in params]
        extra = [x for x in hints if x not in params]
        for x in shared:
            p = params[x]
            # Use the resolved type hint as the parameter's annotation.
            params[x] = inspect.Parameter(
                p.name, p.kind, default=p.default, annotation=hints[x]
            )
        for x in extra:
            hint = hints[x]
            if not isclassvartype(hint):
                continue
            # Hack in the classvars as "parameters" to allow for validation.
            default = getattr(t, x, empty)
            args = get_args(hint)
            if not args:
                # A bare ``ClassVar`` is typed from the class of its value
                # (or of ``empty`` when the attribute is missing).
                hint = ClassVar[default.__class__]  # type: ignore
            params[x] = inspect.Parameter(
                x, inspect.Parameter.KEYWORD_ONLY, default=default, annotation=hint
            )
    except (ValueError, TypeError):
        return cast(
            ConstraintsProtocolT, _from_strict_type(t, nullable=nullable, name=name)
        )
    name = name or get_name(t)
    items: Optional[frozendict.FrozenDict[Hashable, ConstraintsT]] = (
        frozendict.FrozenDict(_resolve_params(t, **params)) or None
    )
    # Every non-variadic parameter without a default is required.
    required = frozenset(
        pname
        for pname, p in params.items()
        if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD} and p.default is p.empty
    )
    # ``*args``/``**kwargs`` mean the key set is open, i.e. not "total".
    has_varargs = any(
        p.kind in {p.VAR_KEYWORD, p.VAR_POSITIONAL} for p in params.values()
    )
    kwargs = {
        "type": t,
        "nullable": nullable,
        "name": name,
        "required_keys": required,
        "items": items,
        "total": not has_varargs,
    }
    # Separate name: the ``cls`` parameter means the enclosing class, not this.
    constr_class: Type = ObjectConstraints
    if istypeddict(t):
        constr_class = TypedDictConstraints
        kwargs.update(type=dict, ttype=t, total=getattr(t, "__total__", bool(required)))
    c = constr_class(**kwargs)  # type: ignore
    return cast(ConstraintsProtocolT, c)