Ejemplo n.º 1
0
def handle_generic_field_type(field_type: AnyType, base: AnyType,
                              other: AnyType, covariant: bool) -> AnyType:
    """Resolve the type variables of ``other`` by matching ``base`` to ``field_type``.

    A type-variable mapping is inferred from ``base`` and ``field_type`` and
    then applied to ``other`` with ``resolve_type_vars``:

    - if ``base`` is itself a type variable, it is mapped to ``field_type``;
    - if ``base`` is a generic alias that still has type parameters and whose
      arguments line up one-to-one with ``field_type``'s (same count, no
      nested generic arguments on either side, and every non-TypeVar
      argument of ``base`` equal to the argument of ``field_type`` at the
      same position), each ``base`` argument is mapped to the corresponding
      ``field_type`` argument.

    The mapping is dropped (``None`` is passed to ``resolve_type_vars``) when
    one type variable would be bound to two different arguments, or when the
    two origins differ and are not related in the direction required by the
    variance: with ``covariant`` the origin of ``base`` must be a subclass of
    the origin of ``field_type``; otherwise the reverse must hold.

    NOTE(review): what ``resolve_type_vars`` does with a ``None`` mapping is
    not visible here — presumably it leaves ``other`` unresolved; confirm.
    """
    contravariant = not covariant
    type_vars = None
    if is_type_var(base):
        # A bare type variable matches the whole field type.
        type_vars = {base: field_type}
    # Only flat, positionally aligned generic aliases are matched argument by
    # argument; anything more complex keeps the previous ``type_vars``.
    if (get_origin(base) is not None and getattr(base, "__parameters__", ())
            and len(get_args(base)) == len(get_args(field_type))
            and not any(map(get_origin, get_args(base)))
            and not any(map(get_origin, get_args(field_type)))
            and not any(not is_type_var(base_arg) and base_arg != field_arg
                        for base_arg, field_arg in zip(get_args(base),
                                                       get_args(field_type)))):
        type_vars = {}
        for base_arg, field_arg in zip(get_args(base), get_args(field_type)):
            # A repeated type variable bound to two different arguments makes
            # the match ambiguous: give up on the mapping.
            if base_arg in type_vars and type_vars[base_arg] != field_arg:
                type_vars = None
                break
            type_vars[base_arg] = field_arg
        field_type_origin, base_origin = get_origin(field_type), get_origin(
            base)
        assert field_type_origin is not None and base_origin is not None
        if base_origin != field_type_origin:
            # Different origins are only acceptable when related by
            # subclassing in the direction given by the variance.
            if covariant and not issubclass(base_origin, field_type_origin):
                type_vars = None
            if contravariant and not issubclass(field_type_origin,
                                                base_origin):
                type_vars = None
    return resolve_type_vars(other, type_vars)
Ejemplo n.º 2
0
 def generic(self, cls: AnyType) -> Return:
     """Visit the origin of generic alias ``cls`` with its type variables bound.

     ``self._type_vars`` is replaced by the context returned by
     ``type_var_context(cls, <current context>)`` for the duration of the
     visit, and is always restored afterwards, even if the visit raises.
     """
     saved_context = self._type_vars
     try:
         self._type_vars = type_var_context(cls, saved_context)
         origin = get_origin(cls)
         return self._visit(origin)
     finally:
         self._type_vars = saved_context
Ejemplo n.º 3
0
 def generic(self, cls: AnyType) -> JsonSchema:
     """Handle the schema of the origin of ``cls``, then defer to the base visitor.

     When the origin is hashable and has extra conversions, the pending
     ``self._schema`` is reset to ``None``; otherwise the origin's schema
     (``get_schema(origin)``) is merged into it.
     """
     origin = get_origin(cls)
     drop_schema = (is_hashable(origin)
                    and self.is_extra_conversions(origin))
     if not drop_schema:
         self._merge_schema(get_schema(origin))
     else:
         self._schema = None
     return super().generic(cls)
Ejemplo n.º 4
0
 def subscriptable_origin(cls: AnyType) -> AnyType:
     """Return an origin of ``cls`` that can itself be subscripted.

     For a ``typing`` alias such as ``List[int]``, ``get_origin`` gives the
     runtime class (``list``); the matching ``typing`` alias (``typing.List``)
     is returned instead, found through the alias ``_name``. Every other type,
     including ``typing`` aliases without a name such as ``Generic[T]``, falls
     back to ``get_origin(cls)``.
     """
     if (
         type(cls) == type(List[int])  # noqa: E721
         and cls.__module__ == "typing"
         # ``_name`` is ``None`` for unnamed aliases (e.g. ``Generic[T]``);
         # ``getattr(typing, None)`` would raise ``TypeError``.
         and getattr(cls, "_name", None)
     ):
         alias = getattr(typing, cls._name, None)
         if alias is not None:
             return alias
     return get_origin(cls)
Ejemplo n.º 5
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Deserialization],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> DeserializationMethodFactory:
     """Build the deserialization factory for ``tp``.

     Starts from the base visitor's factory, installs ``self._coercer`` when
     coercion is enabled and no coercer is set, and, for non-dynamic
     conversions, merges the constraints and validators registered for
     ``tp`` and, for a generic alias, those of its origin too.
     """
     factory = super().visit_conversion(
         tp, conversion, dynamic, next_conversion
     )
     if factory.coercer is None and self._coerce:
         factory = replace(factory, coercer=self._coercer)
     if dynamic:
         return factory
     factory = factory.merge(
         get_constraints(get_schema(tp)), get_validators(tp)
     )
     if get_args(tp):
         origin = get_origin(tp)
         factory = factory.merge(
             get_constraints(get_schema(origin)), get_validators(origin)
         )
     return factory
Ejemplo n.º 6
0
def get_type_name(tp: AnyType) -> TypeName:
    """Return the :class:`TypeName` registered for ``tp``, or a default one.

    Lookup order, on ``tp`` after ``replace_builtins``:

    1. the entry registered for the type itself;
    2. for a specialized alias without type variables, the entry registered
       for its origin, built with the alias arguments;
    3. ``settings.default_type_name(tp)``, or an empty ``TypeName()`` when it
       returns a falsy value.

    ``KeyError`` (not registered) and ``TypeError`` (presumably unhashable
    types) raised during steps 1 and 2 are ignored.
    """
    from apischema import settings

    resolved = replace_builtins(tp)
    try:
        return _type_names[resolved].to_type_name(resolved)
    except (KeyError, TypeError):
        pass
    origin = get_origin(resolved)
    args = get_args(resolved)
    if args and not has_type_vars(resolved):
        try:
            return _type_names[origin].to_type_name(origin, *args)
        except (KeyError, TypeError):
            pass
    default_name = settings.default_type_name(resolved)
    return default_name if default_name else TypeName()
Ejemplo n.º 7
0
def handle_generic_conversions(base: AnyType,
                               other: AnyType) -> Tuple[AnyType, AnyType]:
    """Re-express a conversion declared on a generic alias on its origin.

    When ``base`` is a generic alias whose arguments are all type variables,
    returns its origin, and ``other`` with those type variables replaced by
    the origin's own parameters (positionally). A non-generic ``base`` is
    returned unchanged together with ``other``.

    Raises:
        TypeError: if ``base`` is (even partially) specialized.
    """
    origin = get_origin(base)
    if origin is None:
        return base, other
    args = get_args(base)
    if any(not is_type_var(arg) for arg in args):
        base_name = type_name(base)
        arg_names = ",".join(map(type_name, args))
        raise TypeError(
            f"Generic conversion doesn't support specialization,"
            f" aka {base_name}[{arg_names}]")
    substitution = {
        arg: param for arg, param in zip(args, get_parameters(origin))
    }
    return origin, resolve_type_vars(other, substitution)
Ejemplo n.º 8
0
 def visit(self, cls: AnyType) -> Return:
     """Dispatch ``cls`` to the matching visit method.

     - Generic aliases go to ``_visit_generic``.
     - Type variables are resolved with ``_resolve_type_vars`` and visited
       again.
     - Any other type goes to ``_visit`` with ``self._type_vars`` temporarily
       cleared; the previous value is restored afterwards, even on error.
     """
     if get_origin(cls) is not None:
         return self._visit_generic(cls)
     if is_type_var(cls):
         resolved = self._resolve_type_vars(cls)
         return self.visit(resolved)
     outer_type_vars, self._type_vars = self._type_vars, None
     try:
         return self._visit(cls)
     finally:
         self._type_vars = outer_type_vars
Ejemplo n.º 9
0
 def check_type(self, cls: AnyType):
     """Raise ``TypeError`` if ``cls`` may not carry ``self.ref``.

     Rejected: a NewType of a non-builtin type, a TypeVar, an unspecialized
     generic (non-empty ``__parameters__``), and any generic alias when
     ``self.ref`` is ``...``.
     """
     is_new_type = hasattr(cls, "__supertype__")
     if is_new_type and not is_builtin(cls):
         # NewType of non-builtin types cannot have a ref because their
         # serialization could be customized, but the NewType ref would then
         # erase this customization in the schema.
         raise TypeError("NewType of non-builtin type can not have a ref")
     if is_type_var(cls):
         raise TypeError("TypeVar cannot have a ref")
     if getattr(cls, "__parameters__", ()):
         raise TypeError("Unspecialized generic types cannot have a ref")
     is_generic_alias = get_origin(cls) is not None
     if is_generic_alias and self.ref is ...:
         raise TypeError(f"Generic alias {cls} cannot have ... ref")
Ejemplo n.º 10
0
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Mapping[str, Tuple[S, Mapping[str, AnyType]]]:
    """Collect the methods registered for ``tp`` and its generic bases.

    Bases are walked from the most generic to ``tp`` itself (reversed
    ``generic_mro``), so entries of more derived classes override those of
    their bases. Each entry maps a method name to ``(method, types)``, where
    ``types`` is ``method.types(base)``. When ``tp`` has a model origin, the
    methods of that origin (with ``tp``'s arguments substituted for the
    parameters of ``get_origin(tp)``) are collected recursively and take
    precedence.
    """
    collected = {}
    for base in reversed(generic_mro(tp)):
        base_methods = all_methods[get_origin_or_type(base)]
        collected.update(
            (name, (method, method.types(base)))
            for name, method in base_methods.items()
        )
    if has_model_origin(tp):
        model_origin = get_model_origin(tp)
        tp_args = get_args(tp)
        if tp_args:
            parameters = get_parameters(get_origin(tp))
            model_origin = substitute_type_vars(
                model_origin, dict(zip(parameters, tp_args))
            )
        collected.update(_get_methods(model_origin, all_methods))
    return collected
Ejemplo n.º 11
0
def with_validation_error(
        func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Wrap generator validator ``func`` with ``yield_to_raise``.

    When ``func`` has a ``return`` annotation, the wrapper's ``return``
    annotation is rewritten to the inner result type:

    - a string annotation ``"ValidatorResult[X]"`` becomes the string ``"X"``;
    - an annotation whose origin is ``GeneratorOrigin`` becomes its third
      argument (the generator's return type);
    - any other annotation is kept unchanged.

    Raises:
        TypeError: if ``func`` is not a generator function.
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")
    wrapper = yield_to_raise(func)
    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # ``group`` returns the captured text. ``groupdict("ret")``
                # returned a dict (its argument is only the default value
                # for unmatched groups), corrupting the annotation.
                ret = match.group("ret")
        elif get_origin(ret) == GeneratorOrigin:
            ret = get_args(ret)[2]
        # Assign a new dict instead of mutating in place: if the wrapper was
        # built with ``functools.wraps``, ``wrapper.__annotations__`` is the
        # very same object as ``func.__annotations__``, and mutating it would
        # also rewrite ``func``'s own return annotation.
        wrapper.__annotations__ = {**wrapper.__annotations__, "return": ret}
    return wrapper
Ejemplo n.º 12
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Conv],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> TypeFactory[GraphQLTp]:
     """Build the GraphQL type factory for ``tp``.

     ID types map directly to a non-null ``self.id_type``. Other types use
     the base visitor's factory; for non-dynamic conversions, it is merged
     with the type name and schema of ``tp`` and, for a generic alias, with
     the schema of its origin.
     """
     # Precedence made explicit: ``tp == ID`` applies even when ``dynamic``.
     is_id_type = (not dynamic and self.is_id(tp)) or tp == ID
     if is_id_type:
         return TypeFactory(lambda *_: graphql.GraphQLNonNull(self.id_type))
     factory = super().visit_conversion(
         tp, conversion, dynamic, next_conversion
     )
     if dynamic:
         return factory  # type: ignore
     factory = factory.merge(get_type_name(tp), get_schema(tp))
     if get_args(tp):
         origin_schema = get_schema(get_origin(tp))
         factory = factory.merge(schema=origin_schema)
     return factory  # type: ignore
Ejemplo n.º 13
0
 def _visit_generic(self, cls: AnyType) -> Return:
     """Dispatch generic alias ``cls`` according to its origin.

     Checked in order: ``Annotated``, ``Union``, fixed-length tuple,
     collection, mapping, ``Literal``; any other origin goes to
     ``self.generic``.
     """
     origin = get_origin(cls)
     args = get_args(cls)
     assert origin is not None
     if origin is Annotated:
         return self.annotated(args[0], args[1:])
     if origin is Union:
         return self.union(args)
     # A variadic ``Tuple[X, ...]`` is not handled as a tuple here and falls
     # through to the checks below (presumably the collection branch).
     is_variadic_tuple = len(args) >= 2 and args[1] is ...
     if origin is TUPLE_TYPE and not is_variadic_tuple:
         return self.tuple(args)
     if origin in COLLECTION_TYPES:
         return self.collection(origin, args[0])
     if origin in MAPPING_TYPES:
         return self.mapping(origin, args[0], args[1])
     if origin is Literal:  # pragma: no cover py37+
         return self.literal(args)
     return self.generic(cls)
Ejemplo n.º 14
0
def process_node(node_cls: Type[Node]):
    """Validate and register a concrete ``Node`` subclass.

    Does nothing when ``node_cls`` still has type variables, or when
    ``node_cls.get_by_id is Node.get_by_id``. Otherwise:

    - the first class of the MRO other than ``Node`` whose ``__dict__``
      defines ``get_by_id`` must define it as a classmethod or staticmethod;
    - the first ``Node[...]`` base of ``generic_mro(node_cls)`` gives the id
      type, stored on the class under ``ID_TYPE_ATTR``, and the class is
      registered in ``_nodes`` under ``node_cls._node_key()``.

    NOTE(review): if ``Node.get_by_id`` is a classmethod, each attribute
    access builds a new bound method, so the ``is`` comparison above can
    never be true; confirm whether this early return is effective.

    Raises:
        TypeError: if ``get_by_id`` is a plain method, or if no ``Node[...]``
            base is found.
    """
    if has_type_vars(node_cls) or node_cls.get_by_id is Node.get_by_id:
        return
    method_name = Node.get_by_id.__name__
    owner = next(
        (base for base in node_cls.__mro__
         if base != Node and method_name in base.__dict__),
        None,
    )
    if owner is not None and not isinstance(
        owner.__dict__[method_name], (classmethod, staticmethod)
    ):
        raise TypeError(
            f"{node_cls.__name__}.get_by_id must be a"
            f" classmethod/staticmethod"
        )
    node_base = next(
        (base for base in generic_mro(node_cls) if get_origin(base) == Node),
        None,
    )
    if node_base is None:
        raise TypeError("Node type parameter Id must be specialized")
    setattr(node_cls, ID_TYPE_ATTR, get_args(node_base)[0])
    _nodes[node_cls._node_key()] = node_cls
Ejemplo n.º 15
0
)
from apischema.typing import (
    _LiteralMeta,
    _TypedDictMeta,
    get_args,
    get_origin,
    get_type_hints,
)
from apischema.utils import is_type_var

try:
    from apischema.typing import Annotated, Literal
except ImportError:
    Annotated, Literal = ..., ...  # type: ignore

# Runtime origin of ``Tuple[...]`` aliases, computed once so tuple types can
# be recognized by identity against ``get_origin`` results.
TUPLE_TYPE = get_origin(Tuple[Any])


@lru_cache()
def type_hints_cache(obj) -> Mapping[str, AnyType]:
    """Return the resolved type hints of ``obj``, keeping ``Annotated`` extras.

    Results are memoized by ``lru_cache``, so the same mapping object is
    shared between callers; it is wrapped in a read-only
    ``MappingProxyType`` to prevent callers from mutating the cached value.
    """
    hints = get_type_hints(obj, include_extras=True)
    return MappingProxyType(hints)


class Unsupported(TypeError):
    """``TypeError`` signalling an unsupported type, kept in ``cls``."""

    def __init__(self, cls: Type):
        # Forward ``cls`` to ``TypeError`` so ``args`` is populated: without
        # it ``str(exc)`` is empty and unpickling fails, because
        # ``Unsupported`` would be re-created with no argument.
        super().__init__(cls)
        self.cls = cls


# Covariant type variable, presumably the result type of visitor methods
# (annotated ``-> Return``).
Return = TypeVar("Return", covariant=True)
Ejemplo n.º 16
0
def is_skipped(cls: AnyType, *, schema_only) -> bool:
    """Tell whether ``cls`` must be skipped.

    True for ``UndefinedType``, and for an ``Annotated`` type whose metadata
    contains ``Skip``, or ``SkipSchema`` when ``schema_only`` is true.
    """
    if cls is UndefinedType:
        return True
    if get_origin(cls) is not Annotated:
        return False
    metadata = get_args(cls)[1:]
    return Skip in metadata or (schema_only and SkipSchema in metadata)
Ejemplo n.º 17
0
        ({
            T: int
        }, int, int),
        ({
            T: int
        }, T, int),
        ({
            T: int
        }, Foo[T], Foo[int]),
    ],
)
def test_resolve_type_vars_no_context(ctx, tv, expected):
    """``resolve_type_vars(tv, ctx)`` must equal ``expected`` for each case."""
    assert resolve_type_vars(tv, ctx) == expected


# First type parameter reported by ``get_parameters`` for the origin of
# ``Deque[Any]``; presumably used as an expected type variable in the
# parametrized cases below — confirm.
T0 = next(iter(get_parameters(get_origin(Deque[Any]))))


@mark.parametrize(
    "ctx, cls, expected",
    [
        (None, Foo[int], {
            T: int
        }),
        (None, Foo[U], {
            T: Any
        }),
        ({
            T: int
        }, Foo[T], {
            T: int
Ejemplo n.º 18
0
                    supertype_to_subtype[super_arg] = base_arg
                if is_type_var(base_arg):
                    subtype_to_supertype[base_arg] = super_arg
            break
    return supertype_to_subtype, subtype_to_supertype


def literal_values(values: Sequence[Any]) -> Sequence[Any]:
    """Return the raw values of ``Literal``/``Enum`` members.

    Enum members are replaced by their ``value``; other values are kept.

    Raises:
        TypeError: if a value is neither an ``Enum`` member nor of a type
            listed in ``PRIMITIVE_TYPES`` (exact type, not subclasses).
    """
    for value in values:
        if type(value) not in PRIMITIVE_TYPES and not isinstance(value, Enum):
            raise TypeError(
                "Only primitive types are supported for Literal/Enum")
    return [
        value.value if isinstance(value, Enum) else value for value in values
    ]


# Origin of ``Awaitable[...]`` aliases; ``is_async`` compares return
# annotations against it to detect asynchronous callables.
awaitable_origin = get_origin(Awaitable[Any])


def is_async(func: Callable, types: Mapping[str, AnyType] = None) -> bool:
    """Tell whether ``func`` is asynchronous.

    ``func`` is asynchronous when the innermost callable of its
    ``__wrapped__`` chain is a coroutine function, or when the origin of its
    ``return`` annotation (as computed by ``get_origin_or_type2``) is
    ``awaitable_origin``.

    Args:
        func: callable to inspect.
        types: precomputed type hints of ``func``; when omitted they are
            resolved with ``get_type_hints``, and any resolution failure is
            treated as "no annotations".
    """
    innermost = func
    while hasattr(innermost, "__wrapped__"):
        innermost = innermost.__wrapped__  # type: ignore
    if inspect.iscoroutinefunction(innermost):
        return True
    if types is not None:
        hints = types
    else:
        try:
            hints = get_type_hints(func)
        except Exception:
            # Best effort: unresolvable annotations are treated as absent.
            hints = {}
    return_annotation = hints.get("return")
    return get_origin_or_type2(return_annotation) == awaitable_origin
Ejemplo n.º 19
0
def type_var_context(cls: AnyType,
                     type_vars: TypeVarContext = None) -> TypeVarContext:
    """Map the type parameters of the origin of ``cls`` to its arguments.

    ``cls`` is first resolved against the outer context ``type_vars`` with
    ``resolve_type_vars``; the result must be a generic alias (asserted).
    """
    resolved = resolve_type_vars(cls, type_vars)
    origin = get_origin(resolved)
    assert origin is not None
    parameters = get_parameters(origin)
    return {param: arg for param, arg in zip(parameters, get_args(resolved))}
Ejemplo n.º 20
0
def edge_name(tp: Type["Edge"], *args) -> str:
    """Return ``"<node name>Edge"`` for an ``Edge`` type.

    ``tp`` is first specialized with ``args`` when any are given; the node
    type is the first argument of the ``Edge[...]`` base found in its
    generic MRO.

    Raises:
        NotImplementedError: if no ``Edge[...]`` base is found.
    """
    specialized = tp[tuple(args)] if args else tp  # type: ignore
    for base in generic_mro(specialized):
        if get_origin(base) == Edge:
            node_type = get_args(base)[0]
            return f"{get_node_name(node_type)}Edge"
    raise NotImplementedError
Ejemplo n.º 21
0
def get_origin_or_type(tp: AnyType) -> AnyType:
    """Return ``get_origin(tp)``, or ``tp`` itself when it has no origin."""
    origin = get_origin(tp)
    if origin is None:
        return tp
    return origin
Ejemplo n.º 22
0
 def return_type(self) -> AnyType:
     """Return ``self.types["return"]``, unwrapped from an awaitable alias."""
     annotation = self.types["return"]
     if get_origin(annotation) == awaitable_origin:
         return get_args(annotation)[0]
     return annotation
Ejemplo n.º 23
0
def get_origin2(tp: AnyType) -> Optional[Type]:
    """Return ``get_origin`` of ``_annotated(tp)`` rather than of ``tp``."""
    normalized = _annotated(tp)
    return get_origin(normalized)
Ejemplo n.º 24
0
def connection_name(tp: Type["Connection"], *args) -> str:
    """Return ``"<node name>Connection"`` for a ``Connection`` type.

    ``tp`` is first specialized with ``args`` when any are given; the node
    type is the first argument of the ``Connection[...]`` base found in its
    generic MRO.

    Raises:
        NotImplementedError: if no ``Connection[...]`` base is found.
    """
    specialized = tp[tuple(args)] if args else tp  # type: ignore
    for base in generic_mro(specialized):
        if get_origin(base) == Connection:
            node_type = get_args(base)[0]
            return f"{get_node_name(node_type)}Connection"
    raise NotImplementedError
Ejemplo n.º 25
0
def graphql_schema(
    *,
    query: Iterable[Callable] = (),
    mutation: Iterable[Callable] = (),
    subscription: Iterable[Union[Subscribe, Tuple[Subscribe, Callable]]] = (),
    types: Iterable[Type] = (),
    aliaser: Aliaser = to_camel_case,
    id_types: Union[Collection[AnyType], IdPredicate] = None,
    error_as_null: bool = True,
    generic_ref_factory: GenericRefFactory = None,
    union_ref_factory: UnionRefFactory = None,
    directives: Optional[Collection[graphql.GraphQLDirective]] = None,
    description: Optional[str] = None,
    extensions: Optional[Dict[str, Any]] = None,
) -> graphql.GraphQLSchema:
    """Build a ``graphql.GraphQLSchema`` from Python callables and types.

    Args:
        query: callables exposed as fields of the ``Query`` root type, each
            named after the callable's ``__name__``.
        mutation: callables exposed as fields of the ``Mutation`` root type.
        subscription: fields of the ``Subscription`` root type. Each item is
            either a callable whose return annotation has an
            ``AsyncIterable``/``AsyncIterator`` origin (the field type is
            its first type argument and ``resolve`` returns its first
            argument unchanged), or a ``(subscribe, event_handler)`` tuple:
            the field is named after ``event_handler``, typed with its return
            type and resolved by it; its first parameter is excluded from the
            field arguments (presumably it receives the event).
        types: additional types visited by the builder and passed to the
            schema.
        aliaser: forwarded to the resolvers and to the schema builder.
        id_types: either a collection of types (tested with ``__contains__``)
            or a predicate; passed to the builder as ``is_id``.
        error_as_null: forwarded to ``wrap_return_type``, ``resolver_resolve``
            and the builder.
        generic_ref_factory: forwarded to the builder.
        union_ref_factory: forwarded to the builder.
        directives: forwarded to ``graphql.GraphQLSchema``.
        description: forwarded to ``graphql.GraphQLSchema``.
        extensions: forwarded to ``graphql.GraphQLSchema``.

    Returns:
        The schema; a root type without any field is passed as ``None``.

    Raises:
        TypeError: if a subscription event handler has no parameter, or if a
            plain subscription callable does not return an
            ``AsyncIterable``/``AsyncIterator``.
    """
    def operation_resolver(operation: Callable,
                           *,
                           skip_first=False) -> Resolver:
        # Build a Resolver for ``operation``. Without ``skip_first`` the
        # wrapper drops the first positional argument (presumably GraphQL's
        # root value) before calling ``operation``.
        if skip_first:
            wrapper = operation
        else:

            def wrapper(_, *args, **kwargs):
                return operation(*args, **kwargs)

        parameters = resolver_parameters(operation, skip_first=skip_first)
        return Resolver(operation, wrapper, parameters)

    query_fields: List[ObjectField] = []
    mutation_fields: List[ObjectField] = []
    subscription_fields: List[ObjectField] = []
    # Queries and mutations are built identically, only the target differs.
    for operations, fields in [(query, query_fields),
                               (mutation, mutation_fields)]:
        for operation in operations:
            resolver = operation_resolver(operation)
            fields.append(
                ObjectField(
                    operation.__name__,
                    wrap_return_type(resolver.return_type, error_as_null),
                    resolve=resolver_resolve(resolver, aliaser, error_as_null),
                    parameters=field_parameters(resolver),
                    schema=get_schema(operation),
                ))
    for operation in subscription:  # type: ignore
        resolve: Callable
        if isinstance(operation, tuple):
            # ``(subscribe, event_handler)``: the handler names, types and
            # resolves the field; ``subscribe`` only produces the events.
            operation, event_handler = operation
            name, schema = event_handler.__name__, get_schema(event_handler)
            try:
                resolver = operation_resolver(event_handler, skip_first=True)
            except MissingFirstParameter:
                raise TypeError(
                    "Subscription resolver must have at least one parameter"
                ) from None
            return_type = resolver.return_type
            subscribe = resolver_resolve(
                operation_resolver(operation),
                aliaser,
                error_as_null,
                serialized=False,
            )
            resolve = resolver_resolve(resolver, aliaser, error_as_null)
        else:
            # Plain subscription: the field type is the async iterable's item
            # type and each event is returned unchanged by ``resolve``.
            name, schema = operation.__name__, get_schema(operation)
            resolver = operation_resolver(operation)
            if get_origin(resolver.return_type) not in async_iterable_origins:
                raise TypeError(
                    "Subscriptions must return an AsyncIterable/AsyncIterator")
            return_type = get_args(resolver.return_type)[0]
            subscribe = resolver_resolve(resolver,
                                         aliaser,
                                         error_as_null,
                                         serialized=False)

            def resolve(_, *args, **kwargs):
                return _

        subscription_fields.append(
            ObjectField(
                name,
                wrap_return_type(return_type, error_as_null),
                parameters=field_parameters(resolver),
                resolve=resolve,
                subscribe=subscribe,
                schema=schema,
            ))

    # A collection of id types is turned into a membership predicate; a
    # predicate (or ``None``) is passed through unchanged.
    is_id = id_types.__contains__ if isinstance(id_types,
                                                Collection) else id_types
    builder = OutputSchemaBuilder(aliaser, is_id, error_as_null,
                                  generic_ref_factory, union_ref_factory)

    def root_type(
        name: str, fields: Collection[ObjectField]
    ) -> Optional[graphql.GraphQLObjectType]:
        # Root types without fields are omitted (``None``); otherwise an empty
        # placeholder class named ``name`` carries the fields.
        if not fields:
            return None
        return exec_thunk(builder.object(type(name, (), {}), fields),
                          non_null=False)

    return graphql.GraphQLSchema(
        root_type("Query", query_fields),
        root_type("Mutation", mutation_fields),
        root_type("Subscription", subscription_fields),
        [exec_thunk(builder.visit(cls), non_null=False) for cls in types],
        directives,
        description,
        extensions,
    )
Ejemplo n.º 26
0
def get_origin_or_type2(tp: AnyType) -> AnyType:
    """Return the origin of ``_annotated(tp)``, or ``_annotated(tp)`` itself.

    The second form is returned when ``_annotated(tp)`` has no origin.
    """
    normalized = _annotated(tp)
    origin = get_origin(normalized)
    return normalized if origin is None else origin
Ejemplo n.º 27
0
    Lazy,
    PREFIX,
    context_setter,
    deprecate_kwargs,
    get_origin_or_type,
    identity,
    literal_values,
    opt_or,
)
from apischema.validation import get_validators
from apischema.validation.errors import ErrorKey, ValidationError, merge_errors
from apischema.validation.mock import ValidatorMock
from apischema.validation.validators import Validator, validate
from apischema.visitor import Unsupported

# Runtime origins of ``Dict[...]`` / ``List[...]`` aliases, computed once for
# identity comparisons against ``get_origin`` results.
DICT_TYPE = get_origin(Dict[Any, Any])
LIST_TYPE = get_origin(List[Any])

# Prebuilt module-level errors for missing / unexpected object properties.
MISSING_PROPERTY = ValidationError(["missing property"])
UNEXPECTED_PROPERTY = ValidationError(["unexpected property"])

# Unique sentinel object; presumably used to tell "no value" apart from an
# explicit ``None`` — confirm at the use sites.
NOT_NONE = object()

# Attribute name built from apischema's ``PREFIX``; presumably where init-only
# variables are recorded on classes — confirm at the use sites.
INIT_VARS_ATTR = f"{PREFIX}_init_vars"

T = TypeVar("T")

# A deserialization method takes the input data (``Any``) and returns a ``T``.
DeserializationMethod = Callable[[Any], T]


@dataclass(frozen=True)