Пример #1
0
def converter_types(
    converter: Converter,
    source: Optional[AnyType] = None,
    target: Optional[AnyType] = None,
    namespace: Optional[Dict[str, Any]] = None,
) -> Tuple[AnyType, AnyType]:
    """Resolve the ``(source, target)`` types of a one-argument converter.

    Explicitly given ``source``/``target`` win; whatever is missing is read
    from the converter's type hints (first parameter and ``"return"``). A
    class used as converter is its own target when no return hint exists.

    Args:
        converter: Callable (function or class) whose first parameter
            receives the value to convert.
        source: Type accepted by the converter, or ``None`` to infer it.
        target: Type produced by the converter, or ``None`` to infer it.
        namespace: Namespace handed to ``get_type_hints`` when reading the
            converter's annotations.

    Returns:
        The resolved ``(source, target)`` pair.

    Raises:
        TypeError: If the converter has no parameter, has more than one
            parameter without default, or if its source or target type
            cannot be determined.
    """
    try:
        # Before Python 3.9, Generic.__new__ perturbs the signature of generic
        # classes, so their parameters are read from __init__ instead.
        if (
            isinstance(converter, type)
            and converter.__new__ is Generic.__new__ is not object.__new__
            and converter.__init__ is not object.__init__  # type: ignore
        ):
            init_params = signature(converter.__init__).parameters  # type: ignore
            # Skip the leading ``self`` parameter of ``__init__``.
            parameters = list(init_params.values())[1:]
        else:
            parameters = list(signature(converter).parameters.values())
    except ValueError:  # builtin types
        # No signature is available, so nothing can be inferred from hints.
        if source is None:
            raise TypeError("Converter source is unknown") from None
        if target is None:
            if not isclass(converter):
                # Without this check ``None`` would leak out as target type.
                raise TypeError("converter target is unknown") from None
            target = cast(Type[Any], converter)
        return source, target

    if not parameters:
        raise TypeError("converter must have at least one parameter")
    first_param, *other_params = parameters
    for p in other_params:
        # Only the first parameter may be required; variadic ones are fine.
        if p.default is Parameter.empty and p.kind not in (
            Parameter.VAR_POSITIONAL,
            Parameter.VAR_KEYWORD,
        ):
            raise TypeError(
                "converter must have at most one parameter without default"
            )
    if source is not None and target is not None:
        return source, target

    types = get_type_hints(converter, None, namespace, include_extras=True)
    if not types and isclass(converter):
        # A class without class-level hints: look at its constructor instead.
        types = get_type_hints(
            converter.__new__, None, namespace, include_extras=True
        ) or get_type_hints(
            converter.__init__,  # type: ignore
            None,
            namespace,
            include_extras=True,
        )
    if source is None:
        try:
            source = types.pop(first_param.name)
        except KeyError:
            raise TypeError("converter source is unknown") from None
    if target is None:
        try:
            target = types.pop("return")
        except KeyError:
            if isclass(converter):
                target = cast(Type, converter)
            else:
                raise TypeError("converter target is unknown") from None
    return source, target
Пример #2
0
def validator(arg=None, *, field=None, discard=None, owner=None):
    """Register a validator function, directly or as a parametrized decorator.

    With a callable ``arg`` a ``Validator`` is built and registered on its
    owner. Otherwise the options are checked and a decorator awaiting the
    function is returned.

    Args:
        arg: The function to register or, when not callable, a field (or
            field name) used if ``field`` is not given.
        field: Field or field name handed to ``Validator``.
        discard: Field(s) or field name(s) handed to ``Validator``; a string
            or a non-collection value is wrapped into a one-element list.
        owner: Type on which the validator is registered. When omitted it is
            taken from ``method_class(arg)`` for methods, else from the type
            hint of the function's first parameter.

    Returns:
        ``arg`` once registered; the unregistered ``Validator`` when ``arg``
        is a method for which ``method_class`` returns ``None``; or a
        decorator when ``arg`` is not callable.

    Raises:
        TypeError: If ``owner`` is given for a method for which
            ``method_class`` returns ``None``.
        ValueError: If the owner cannot be deduced from the first parameter.
    """
    if not callable(arg):
        # Decorator-factory usage: validate the options, then wait for the
        # function to decorate.
        field = field or arg
        if field is not None:
            check_field_or_name(field)
        if discard is not None:
            if isinstance(discard, str) or not isinstance(discard, Collection):
                discard = [discard]
            for discarded in discard:
                check_field_or_name(discarded)
        return lambda func: validator(
            func, field=field, discard=discard, owner=owner)

    validator_ = Validator(arg, field, discard)
    if is_method(arg):
        defining_cls = method_class(arg)
        if defining_cls is None:
            if owner is None:
                return validator_
            raise TypeError(
                "Validator owner cannot be set for class validator")
        if owner is None:
            owner = defining_cls
    if owner is None:
        # Plain function: the owner is the type of its first parameter.
        try:
            param_names = iter(signature(arg).parameters)
            first_param = next(param_names)
            owner = get_origin_or_type2(get_type_hints(arg)[first_param])
        except Exception:
            raise ValueError("Validator first parameter must be typed")
    validator_._register(owner)
    return arg
Пример #3
0
    def decorator(method: MethodOrProperty):
        """Register ``method`` for its owner and return what replaces it.

        When no ``owner`` was given and ``method`` is a method for which
        ``method_class`` returns ``None``, registration is deferred: a
        descriptor is returned and registers the method from
        ``__set_name__``. Otherwise the owner is the enclosing ``owner``, the
        method's class, or the type hint of the function's first parameter.

        Raises:
            TypeError: If the owner has to be deduced from the first
                parameter and that parameter is missing or not typed.
        """
        if owner is None and is_method(
                method) and method_class(method) is None:

            class Descriptor(MethodWrapper[MethodOrProperty]):
                def __set_name__(self, owner, name):
                    super().__set_name__(owner, name)
                    register(method_wrapper(method), owner, name)

            return Descriptor(method)
        else:
            owner2 = owner
            if is_method(method):
                if owner2 is None:
                    owner2 = method_class(method)
                method = method_wrapper(method)
            if owner2 is None:
                # Local import keeps this block self-contained.
                from inspect import signature

                try:
                    # Use the actual first parameter: the first entry of the
                    # hints mapping may be a later parameter or "return" when
                    # the first parameter is not annotated.
                    first_param = next(iter(signature(method).parameters))
                    owner2 = get_origin_or_type2(
                        get_type_hints(method)[first_param])
                except (KeyError, StopIteration, ValueError):
                    raise TypeError(
                        "First parameter of method must be typed") from None
            # Narrowing for type checkers.
            assert not isinstance(method, property)
            register(cast(Callable, method), owner2, method.__name__)
            return method
Пример #4
0
def to_raw_deserializer(func: Callable) -> Converter:
    """Wrap ``func`` into a converter taking one generated dataclass instance.

    A dataclass is generated with one field per parameter of ``func`` (a
    ``**kwargs`` parameter becomes a ``Mapping[str, ...]`` field defaulting to
    an empty dict). The returned converter reads those fields from its
    argument and calls ``func`` with them as keyword arguments.

    Args:
        func: Fully annotated function without positional-only or ``*args``
            parameters.

    Returns:
        A function annotated as ``(obj: <generated dataclass>) -> <return
        type of func>``.

    Raises:
        TypeError: If the return or a parameter is not annotated, or if a
            positional-only or variadic positional parameter is found.
    """
    hints = get_type_hints(func, include_extras=True)
    if "return" not in hints:
        raise TypeError("Return must be annotated")
    fields: List[MakeDataclassField] = []
    kwargs_param = None
    for name, param in signature(func).parameters.items():
        kind = param.kind
        if kind == Parameter.POSITIONAL_ONLY:  # pragma: no cover
            raise TypeError("Forbidden positional-only parameter")
        if kind == Parameter.VAR_POSITIONAL:
            raise TypeError("Forbidden variadic positional parameter")
        if kind == Parameter.VAR_KEYWORD:
            from apischema import properties

            kwargs_field = field(default_factory=dict, metadata=properties)
            kwargs_type = Mapping[str, hints.get(name, Any)]  # type: ignore
            fields.append((name, kwargs_type, kwargs_field))  # type: ignore
            kwargs_param = name
        else:
            if param.default is Parameter.empty:
                default = MISSING
            else:
                default = param.default
            try:
                annotation = hints[name]
            except KeyError:
                raise TypeError("All parameters must be annotated")
            fields.append((name, annotation, field(default=default)))

    def converter(obj):
        call_kwargs = {}
        for field_name, _, _ in fields:
            call_kwargs[field_name] = getattr(obj, field_name)
        if kwargs_param in call_kwargs:
            # Flatten the content of the ``**kwargs`` field into the call.
            extra = call_kwargs.pop(kwargs_param)
            call_kwargs.update(extra)
        return func(**call_kwargs)

    cls_name = to_camel_case(func.__name__)
    cls = with_fields_set(make_dataclass(cls_name, fields))
    converter.__annotations__ = {"obj": cls, "return": hints["return"]}
    return converter
Пример #5
0
def parameters_as_fields(
    func: Callable,
    parameters_metadata: Mapping[str, Mapping] = None,
) -> Sequence[ObjectField]:
    """Describe the parameters of ``func`` as a list of ``ObjectField``.

    Keyword-capable parameters give one field each, typed by their hint
    (``Any`` when missing). A ``**kwargs`` parameter gives a
    ``Mapping[str, <hint>]`` field built with ``default_factory=dict`` and
    ``properties`` metadata. ``*args`` parameters produce no field.

    Args:
        func: Function whose signature is inspected.
        parameters_metadata: Optional per-parameter metadata, keyed by
            parameter name.

    Returns:
        The fields, in signature order.

    Raises:
        TypeError: If ``func`` has a positional-only parameter.
    """
    metadata_by_param = parameters_metadata or {}
    hints = get_type_hints(func, include_extras=True)
    kinds = inspect.Parameter
    keyword_capable = (kinds.POSITIONAL_OR_KEYWORD, kinds.KEYWORD_ONLY)
    result = []
    for name, param in inspect.signature(func).parameters.items():
        if param.kind is kinds.POSITIONAL_ONLY:
            raise TypeError("Positional only parameters are not supported")
        annotation = hints.get(name, Any)
        if param.kind in keyword_capable:
            # Third argument is True when the parameter has no default.
            has_no_default = param.default is kinds.empty
            result.append(
                ObjectField(
                    name,
                    annotation,
                    has_no_default,
                    metadata_by_param.get(name, empty_dict),
                    default=param.default,
                ))
        elif param.kind == kinds.VAR_KEYWORD:
            result.append(
                ObjectField(
                    name,
                    Mapping[str, annotation],  # type: ignore
                    False,
                    properties | metadata_by_param.get(name, empty_dict),
                    default_factory=dict,
                ))
    return result
Пример #6
0
def is_async(func: Callable, types: Mapping[str, AnyType] = None) -> bool:
    """Tell whether ``func`` is asynchronous.

    ``func`` is asynchronous when the function found at the end of its
    ``__wrapped__`` chain is a coroutine function, or when the origin of its
    ``"return"`` hint equals ``awaitable_origin``.

    Args:
        func: Function to inspect.
        types: Type hints of ``func``; computed with ``get_type_hints`` when
            omitted (any failure there is treated as "no hints").

    Returns:
        ``True`` if ``func`` is considered asynchronous.
    """
    innermost = func
    not_wrapped = object()
    while True:
        inner = getattr(innermost, "__wrapped__", not_wrapped)
        if inner is not_wrapped:
            break
        innermost = inner  # type: ignore
    if inspect.iscoroutinefunction(innermost):
        return True
    if types is None:
        # Best effort: unresolvable annotations simply mean "no hints".
        try:
            types = get_type_hints(func)
        except Exception:
            types = {}
    return_hint = types.get("return")
    return get_origin_or_type2(return_hint) == awaitable_origin
Пример #7
0
 def decorator(func: Func) -> Func:
     """Register ``func`` as a resolver of ``cls`` and return it unchanged.

     When the type hint of the first parameter of ``func`` resolves to
     ``cls``, ``func`` is used as is and that parameter is dropped from the
     resolver parameters. Otherwise ``func`` is wrapped so that the first
     positional argument it is called with is discarded.

     The resolver is stored under ``name``, or ``func.__name__`` when ``name``
     is falsy.
     """
     params = resolver_parameters(func, skip_first=False)
     hints = get_type_hints(func)
     first_param_type = hints[params[0].name]
     if get_origin_or_class(first_param_type) == cls:
         params = params[1:]
         wrapper: Callable = func
     else:
         wrapper = lambda __, *args, **kwargs: func(*args, **kwargs)  # noqa: E731
     resolver = Resolver(func, wrapper, params, conversions, schema)
     _resolvers[cls][name or func.__name__] = resolver
     return func
Пример #8
0
def object_deserialization(
    func: Callable[..., T],
    *input_class_modifiers: Callable[[type], Any],
    parameters_metadata: Mapping[str, Mapping] = None,
) -> Any:
    """Build a one-argument wrapper of ``func`` fed by a generated class.

    A class named after ``func`` is created; its ``__init__`` stores the
    keyword arguments it receives in ``self.kwargs``. The object fields
    derived from the parameters of ``func`` are attached to it. The returned
    wrapper calls ``func`` with the stored keyword arguments, flattening the
    field flagged ``additional_properties`` if there is one.

    Args:
        func: Function with an annotated return type.
        *input_class_modifiers: Callables applied, in order, to the generated
            class.
        parameters_metadata: Per-parameter metadata forwarded to
            ``parameters_as_fields``.

    Returns:
        The wrapper, annotated as ``(input: <generated class>) -> <return
        type of func>``.

    Raises:
        TypeError: If ``func`` has no return annotation.
    """
    fields = parameters_as_fields(func, parameters_metadata)
    hints = get_type_hints(func, include_extras=True)
    if "return" not in hints:
        raise TypeError("Object deserialization must be typed")
    return_type = hints["return"]
    type_params = getattr(return_type, "__parameters__", ())
    if type_params:
        # Generic return type: the generated class gets the same parameters.
        bases = (Generic[type_params], )  # type: ignore
    else:
        bases = ()
        if func.__name__ != "<lambda>":
            input_class_modifiers = (
                type_name(to_pascal_case(func.__name__)),
                *input_class_modifiers,
            )

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fill_namespace(ns):
        ns.update({"__init__": __init__})

    input_cls = new_class(
        to_pascal_case(func.__name__),
        bases,
        exec_body=fill_namespace,
    )
    for modifier in input_class_modifiers:
        modifier(input_cls)
    set_object_fields(input_cls, fields)

    additional = next((f for f in fields if f.additional_properties), None)
    if additional is None:

        def wrapper(input):
            return func(**input.kwargs)

    else:
        kwargs_param = additional.name

        def wrapper(input):
            kwargs = input.kwargs.copy()
            kwargs.update(kwargs.pop(kwargs_param))
            return func(**kwargs)

    wrapper.__annotations__["input"] = input_cls
    wrapper.__annotations__["return"] = return_type
    return wrapper
Пример #9
0
 def types(self, owner: AnyType = None) -> Mapping[str, AnyType]:
     """Return the type hints of ``self.func``, adapted to ``owner``.

     The ``"return"`` entry defaults to ``self.func`` itself when it is a
     class, and is always passed through ``self.return_type``. When
     ``get_args2(owner)`` is non-empty, every hint gets its type variables
     substituted, using the substitution computed between the first
     parameter's hint (or the origin of ``owner`` when it has none) and
     ``owner``.

     Raises:
         TypeError: If ``self.func`` has no return hint and is not a class.
     """
     hints = get_type_hints(self.func, include_extras=True)
     if "return" in hints:
         declared_return = hints["return"]
     elif isclass(self.func):
         declared_return = self.func
     else:
         raise TypeError("Function must be typed")
     hints["return"] = self.return_type(declared_return)
     if not get_args2(owner):
         return hints
     first_param = next(iter(signature(self.func).parameters))
     first_param_type = hints.get(first_param, get_origin_or_type2(owner))
     substitution, _ = subtyping_substitution(first_param_type, owner)
     return {
         name: substitute_type_vars(tp, substitution)
         for name, tp in hints.items()
     }
Пример #10
0
def check_converter(
    converter: Converter,
    param: Optional[AnyType],
    ret: Optional[AnyType],
    namespace: Optional[Dict[str, Any]] = None,
) -> Tuple[AnyType, AnyType]:
    """Validate a converter and resolve its parameter and return types.

    Explicitly given ``param``/``ret`` win; whatever is missing is read from
    the converter's type hints (first parameter and ``"return"``). A class
    used as converter is its own return type when no return hint exists.

    Args:
        converter: Callable (function or class) whose first parameter
            receives the value to convert.
        param: Type accepted by the converter, or ``None`` to infer it.
        ret: Type produced by the converter, or ``None`` to infer it.
        namespace: Namespace handed to ``get_type_hints`` when reading the
            converter's annotations.

    Returns:
        The resolved ``(param, ret)`` pair.

    Raises:
        TypeError: If the converter has no parameter, has more than one
            parameter without default, or if its parameter or return type
            cannot be determined.
    """
    try:
        parameters = iter(signature(converter).parameters.values())
    except ValueError:  # builtin types
        # No signature is available, so nothing can be inferred from hints.
        if param is None:
            raise TypeError("converter parameter must be typed") from None
        if ret is None:
            if not isclass(converter):
                # Without this check ``None`` would leak out as return type.
                raise TypeError("converter return must be typed") from None
            ret = cast(Type[Any], converter)
        return param, ret

    try:
        first = next(parameters)
    except StopIteration:
        raise TypeError("converter must have at least one parameter") from None
    types = get_type_hints(converter, None, namespace, include_extras=True)
    # ``parameters`` is an iterator: only the remaining parameters are left.
    for p in parameters:
        if p.default is Parameter.empty and p.kind not in (
            Parameter.VAR_POSITIONAL,
            Parameter.VAR_KEYWORD,
        ):
            raise TypeError("converter must have at most one parameter "
                            "without default")
    if param is None:
        try:
            param = types.pop(first.name)
        except KeyError:
            raise TypeError("converter parameter must be typed") from None
    if ret is None:
        try:
            ret = types.pop("return")
        except KeyError:
            if isclass(converter):
                ret = cast(Type, converter)
            else:
                raise TypeError("converter return must be typed") from None
    return param, ret
Пример #11
0
 def __init_subclass__(cls, **kwargs):
     """Process the annotations of a new subclass and make it a dataclass.

     For every type hint of the class:

     * a ``Tagged`` hint becomes a dataclass field defaulting to
       ``Undefined``, annotated ``Union[<tagged type>, UndefinedType]``, and
       its name is recorded as a tag;
     * a name already recorded as a tag is left untouched;
     * any other hint is re-annotated as ``ClassVar[...]``.

     The tags are stored under ``TAGS_ATTR``, the class is turned into a
     dataclass (``init=False, repr=False``) with a ``schema`` of exactly one
     property, and each tag attribute is finally replaced by a ``Tag``.

     Raises:
         TypeError: If a non-tag hint already has ``ClassVar`` as origin.
     """
     known_tags = set(getattr(cls, TAGS_ATTR, ()))
     hints = get_type_hints(cls, include_extras=True)
     for attr, hint in hints.items():
         if get_origin2(hint) == Tagged:
             declared = cls.__dict__.get(attr, Tagged())
             tag_field = field(default=Undefined, metadata=declared.metadata)
             setattr(cls, attr, tag_field)
             tagged_type = get_args2(hint)[0]
             cls.__annotations__[attr] = Union[tagged_type, UndefinedType]
             known_tags.add(attr)
             continue
         if attr in known_tags:
             continue
         if get_origin2(hint) != ClassVar:
             cls.__annotations__[attr] = ClassVar[hint]
             continue
         # NOTE(review): the hint rejected here is the one that already is a
         # ClassVar, which looks at odds with the message below -- confirm
         # the intent before relying on it.
         raise TypeError(
             "Only Tagged or ClassVar fields are allowed in TaggedUnion"
         )
     setattr(cls, TAGS_ATTR, known_tags)
     apply_schema = schema(min_props=1, max_props=1)
     apply_schema(dataclass(init=False, repr=False)(cls))
     for tag in known_tags:
         setattr(cls, tag, Tag(tag, cls))
Пример #12
0
def _fields_and_init(
    cls: type,
    fields_and_methods: Union[Iterable[Any], Callable[[], Iterable[Any]]],
) -> Tuple[Sequence[ObjectField], Callable[[Any, Any], None]]:
    """Compute the output fields of ``cls`` and the ``__init__`` filling them.

    Args:
        cls: Class whose object fields are the starting point.
        fields_and_methods: Members to output, or a callable returning them.
            Each member is either ``...`` (all the fields of ``cls``), a
            field or field name, a method or callable, or a
            ``(member, metadata)`` pair.

    Returns:
        A pair made of the output fields (as a tuple) and an
        ``__init__(self, obj)`` function that copies the selected attributes
        of ``obj`` and the results of the selected callables onto ``self``.

    Raises:
        TypeError: If a metadata is not a mapping, or a member is neither a
            known field nor callable.
    """
    fields = object_fields(cls)
    output_fields: Dict[str, ObjectField] = OrderedDict()
    methods = []
    if callable(fields_and_methods):
        fields_and_methods = fields_and_methods()
    for member in fields_and_methods:
        if member is ...:
            output_fields.update(fields)
            continue
        # Normalize to a (member, metadata) pair.
        if isinstance(member, tuple):
            member, extra = member
        else:
            extra = empty_dict
        if not isinstance(extra, Mapping):
            raise TypeError(f"Invalid metadata {extra}")
        # Resolve dataclass fields and field names to the class object field.
        if isinstance(member, Field):
            member = member.name
        if isinstance(member, str) and member in fields:
            member = fields[member]
        if is_method(member):
            member = method_wrapper(member)

        if isinstance(member, ObjectField):
            if extra:
                merged = {**member.metadata, **extra}
                updated = replace(member,
                                  metadata=merged,
                                  default=MISSING_DEFAULT)
                output_fields[member.name] = updated
            else:
                output_fields[member.name] = member
            continue
        if not callable(member):
            raise TypeError(
                f"Invalid serialization member {member} for class {cls}")

        hints = get_type_hints(member)
        first_param = next(iter(inspect.signature(member).parameters))
        first_param_type = hints.get(first_param, with_parameters(cls))
        substitution, _ = subtyping_substitution(first_param_type, cls)
        ret = substitute_type_vars(hints.get("return", Any), substitution)
        method_field = ObjectField(member.__name__, ret, metadata=extra)
        output_fields[member.__name__] = method_field
        methods.append((member, method_field))

    # Keep only the callables whose field was not overridden by a later member.
    serialized_methods = [
        method for method, method_field in methods
        if output_fields[method_field.name] is method_field
    ]
    method_names = {method.__name__ for method in serialized_methods}
    serialized_fields = list(output_fields.keys() - method_names)

    def __init__(self, obj):
        for attr in serialized_fields:
            setattr(self, attr, getattr(obj, attr))
        for func in serialized_methods:
            setattr(self, func.__name__, func(obj))

    return tuple(output_fields.values()), __init__
Пример #13
0
 def error_type(self) -> AnyType:
     """Return the annotated return type of ``self.error_handler``.

     Raises:
         TypeError: If the error handler has no return annotation.
     """
     assert self.error_handler is not None
     hints = get_type_hints(self.error_handler, include_extras=True)
     if "return" in hints:
         return hints["return"]
     raise TypeError("Error handler must be typed")
Пример #14
0
def type_hints_cache(obj) -> Mapping[str, AnyType]:
    """Return the type hints of ``obj`` as a read-only mapping."""
    hints = get_type_hints(obj, include_extras=True)
    # The result is meant to be cached, so hand out a read-only view that
    # callers cannot mutate.
    return MappingProxyType(hints)
Пример #15
0
 def types(self) -> Mapping[str, AnyType]:
     """Return the type hints of ``self.func`` (``include_extras=True``)."""
     wrapped = self.func
     return get_type_hints(wrapped, include_extras=True)
Пример #16
0
    def __init_subclass__(cls, **kwargs):
        """Build the mutation described by a subclass defining ``mutate``.

        Nothing happens when the subclass has no ``mutate`` attribute.
        Otherwise:

        * the class is given the type name ``<ClassName>Payload``;
        * an input dataclass ``<ClassName>Input`` is generated with one field
          per keyword-capable parameter of ``mutate`` (variadic parameters
          are ignored);
        * a ``ClientMutationId`` field is handled according to
          ``cls._client_mutation_id``: dropped when ``False``, made required
          when ``True``, and added to the input when ``mutate`` declares none
          (unless ``False``);
        * a wrapper taking the input instance and calling ``mutate`` is
          stored, as a ``Mutation_``, in ``cls._mutation``. Unless
          ``cls._client_mutation_id`` is ``False``, the wrapper also copies
          the client mutation id from the input onto the result.

        Raises:
            TypeError: If ``mutate`` is not a classmethod/staticmethod, has a
                positional-only or untyped parameter, or declares a required
                ``ClientMutationId`` parameter while
                ``_client_mutation_id`` is ``False``.
        """
        super().__init_subclass__(**kwargs)
        if not hasattr(cls, "mutate"):
            return
        # ``hasattr`` also succeeds for a ``mutate`` inherited from a base
        # class, which ``cls.__dict__`` does not contain: search the MRO for
        # the raw descriptor instead of failing with ``KeyError``.
        raw_mutate = next(
            (vars(base)["mutate"]
             for base in cls.__mro__ if "mutate" in vars(base)),
            None,
        )
        if not isinstance(raw_mutate, (classmethod, staticmethod)):
            raise TypeError(
                f"{cls.__name__}.mutate must be a classmethod/staticmethod")
        mutate = getattr(cls, "mutate")
        type_name(f"{cls.__name__}Payload")(cls)
        # The class name is added to the local namespace so that annotations
        # of ``mutate`` referring to the class itself can be resolved.
        types = get_type_hints(mutate,
                               localns={cls.__name__: cls},
                               include_extras=True)
        async_mutate = is_async(mutate, types)
        fields: List[Tuple[str, AnyType, Field]] = []
        cmi_param = None
        for param_name, param in signature(mutate).parameters.items():
            if param.kind is Parameter.POSITIONAL_ONLY:
                raise TypeError("Positional only parameters are not supported")
            if param.kind in {
                    Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY
            }:
                if param_name not in types:
                    raise TypeError("Mutation parameters must be typed")
                field_type = types[param_name]
                field_ = MISSING if param.default is Parameter.empty else param.default
                if is_union_of(field_type, ClientMutationId):
                    cmi_param = param_name
                    if cls._client_mutation_id is False:
                        if field_ is MISSING:
                            raise TypeError(
                                "Cannot have a ClientMutationId parameter"
                                " when _client_mutation_id = False")
                        # Not part of the input: ``mutate`` uses its default.
                        continue
                    elif cls._client_mutation_id is True:
                        # Required in the input whatever its default is.
                        field_ = MISSING
                    field_ = field(default=field_,
                                   metadata=alias(CLIENT_MUTATION_ID))
                fields.append((param_name, field_type, field_))
        # Captured before the implicit client mutation id field is appended:
        # only these names are passed to ``mutate``.
        field_names = [name for (name, _, _) in fields]
        if cmi_param is None and cls._client_mutation_id is not False:
            fields.append((
                CLIENT_MUTATION_ID,
                ClientMutationId
                if cls._client_mutation_id else Optional[ClientMutationId],
                MISSING if cls._client_mutation_id else None,
            ))
            cmi_param = CLIENT_MUTATION_ID
        input_cls = make_dataclass(f"{cls.__name__}Input", fields)

        def wrapper(input):
            return mutate(
                **{name: getattr(input, name)
                   for name in field_names})

        wrapper.__annotations__["input"] = input_cls
        wrapper.__annotations__[
            "return"] = Awaitable[cls] if async_mutate else cls
        if cls._client_mutation_id is not False:
            # The payload carries the client mutation id back to the caller.
            cls.__annotations__[
                CLIENT_MUTATION_ID] = input_cls.__annotations__[cmi_param]
            setattr(cls, CLIENT_MUTATION_ID, field(init=False))
            wrapped = wrapper

            if async_mutate:

                async def wrapper(input):
                    result = await wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID,
                            getattr(input, cmi_param))
                    return result

            else:

                def wrapper(input):
                    result = wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID,
                            getattr(input, cmi_param))
                    return result

            # Carry over the metadata (including annotations) of ``wrapped``.
            wrapper = wraps(wrapped)(wrapper)

        cls._mutation = Mutation_(
            function=wrapper,
            alias=camel_to_snake(cls.__name__),
            schema=cls._schema,
            error_handler=cls._error_handler,
        )