Exemplo n.º 1
0
 def _has_conversion(
     tp: AnyType, conversion: Optional[AnyConversion]
 ) -> Tuple[bool, Optional[Deserialization]]:
     """Collect the conversions of ``conversion`` whose target is ``tp``.

     Only conversions whose target is a subclass of ``tp`` are kept. The first
     identity conversion is rewritten to deserialize from
     ``self_deserialization_wrapper`` of ``tp``'s origin (subscripted with
     ``tp``'s arguments when it has some); further identity conversions are
     dropped. Every kept conversion then goes through
     ``handle_dataclass_model``, gets its source type variables substituted
     according to ``subtyping_substitution(tp, target)``, and has its target
     replaced by ``tp``.

     Returns:
         ``(True, None)`` when the identity conversion is the only one kept;
         otherwise ``(bool(kept), tuple(kept) or None)``.
     """
     seen_identity = False
     kept = []
     for raw_conv in resolve_any_conversion(conversion):
         candidate = handle_identity_conversion(raw_conv, tp)
         if not is_subclass(candidate.target, tp):
             continue
         if is_identity(candidate):
             if seen_identity:
                 # Keep only the first identity conversion.
                 continue
             seen_identity = True
             source_wrapper: AnyType = self_deserialization_wrapper(
                 get_origin_or_type(tp))
             tp_args = get_args(tp)
             if tp_args:
                 source_wrapper = source_wrapper[tp_args]
             candidate = ResolvedConversion(
                 replace(candidate, source=source_wrapper))
         candidate = handle_dataclass_model(candidate)
         _, substitution = subtyping_substitution(tp, candidate.target)
         new_source = substitute_type_vars(candidate.source, substitution)
         kept.append(
             ResolvedConversion(
                 replace(candidate, source=new_source, target=tp)))
     if seen_identity and len(kept) == 1:
         return True, None
     return bool(kept), tuple(kept) or None
Exemplo n.º 2
0
def handle_generic_field_type(field_type: AnyType, base: AnyType,
                              other: AnyType, covariant: bool) -> AnyType:
    """Specialize ``other`` with the type variables bound by matching ``base``.

    * If ``base`` is a TypeVar, it is bound to ``field_type``.
    * If ``base`` is a parametrized generic alias whose arguments line up with
      ``field_type``'s (same count, no nested generic arguments on either
      side, every non-TypeVar argument equal to its counterpart), each
      argument of ``base`` is bound to the matching argument of
      ``field_type``. The binding is discarded when one argument would be
      bound to two different types, or when the origins differ in a
      direction the variance forbids (covariant: the ``base`` origin must
      subclass the ``field_type`` origin; otherwise the reverse).

    Returns ``resolve_type_vars(other, bindings)``, where ``bindings`` is
    ``None`` when nothing could be bound.
    """
    bindings = {base: field_type} if is_type_var(base) else None
    base_args, field_args = get_args(base), get_args(field_type)
    args_line_up = (
        get_origin(base) is not None
        and getattr(base, "__parameters__", ())
        and len(base_args) == len(field_args)
        and not any(map(get_origin, base_args))
        and not any(map(get_origin, field_args))
        and all(is_type_var(base_arg) or base_arg == field_arg
                for base_arg, field_arg in zip(base_args, field_args))
    )
    if args_line_up:
        bindings = {}
        for base_arg, field_arg in zip(base_args, field_args):
            if base_arg in bindings and bindings[base_arg] != field_arg:
                # Same argument bound to two different types: give up.
                bindings = None
                break
            bindings[base_arg] = field_arg
        base_origin, field_origin = get_origin(base), get_origin(field_type)
        assert field_origin is not None and base_origin is not None
        if base_origin != field_origin:
            if covariant and not issubclass(base_origin, field_origin):
                bindings = None
            if not covariant and not issubclass(field_origin, base_origin):
                bindings = None
    return resolve_type_vars(other, bindings)
Exemplo n.º 3
0
    def unsupported(self, tp: AnyType) -> Result:
        """Fall back to ``settings.default_object_fields`` for unsupported types.

        When ``tp``'s origin is a class for which
        ``settings.default_object_fields`` returns fields, visit it as an
        object with those fields (their types specialized with ``tp``'s
        arguments, if any). Otherwise defer to the parent implementation.
        """
        from apischema import settings

        origin = get_origin_or_type(tp)
        if not isinstance(origin, type):
            return super().unsupported(tp)
        default_fields = settings.default_object_fields(origin)
        if default_fields is None:
            return super().unsupported(tp)
        type_args = get_args(tp)
        if type_args:
            # Map the class's type parameters to the concrete arguments.
            substitution = dict(zip(get_parameters(origin), type_args))
            default_fields = [
                replace(field, type=substitute_type_vars(field.type, substitution))
                for field in default_fields
            ]
        return self._object(origin, default_fields)
Exemplo n.º 4
0
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Mapping[str, Tuple[S, Mapping[str, AnyType]]]:
    """Collect the methods registered in ``all_methods`` for ``tp``.

    Bases of ``generic_mro(tp)`` are visited in reverse order, so when a name
    appears several times, the entry from the base visited last wins. Each
    entry maps a name to ``(method, method.types(base))``. If ``tp`` has a
    model origin, that origin (specialized with ``tp``'s arguments when it has
    some) is processed recursively and its entries override the others.
    """
    collected = {
        name: (method, method.types(base))
        for base in reversed(generic_mro(tp))
        for name, method in all_methods[get_origin_or_type(base)].items()
    }
    if not has_model_origin(tp):
        return collected
    model = get_model_origin(tp)
    type_args = get_args(tp)
    if type_args:
        model = substitute_type_vars(
            model, dict(zip(get_parameters(get_origin(tp)), type_args)))
    collected.update(_get_methods(model, all_methods))
    return collected
Exemplo n.º 5
0
def gather_errors(func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Turn a validator generator function into a plain function.

    Every value yielded by ``func`` is collected as an error, and its
    ``return`` value is the result. When at least one error was yielded, the
    exception built by ``build_validation_error(errors)`` is raised; otherwise
    the generator's return value is returned.

    The wrapper's ``return`` annotation is rewritten to the generator's result
    type: a ``"ValidatorResult[X]"`` string becomes ``"X"``; an annotation
    whose origin (looking through ``Annotated``) is ``GeneratorOrigin``
    becomes its third type argument, re-wrapped with the same ``Annotated``
    metadata.

    :raises TypeError: if ``func`` is not a generator function
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")

    @wraps(func)
    def wrapper(*args, **kwargs):
        result, errors = func(*args, **kwargs), []
        while True:
            try:
                errors.append(next(result))
            except StopIteration as stop:
                if errors:
                    raise build_validation_error(errors)
                return stop.value

    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # ``group`` returns the captured text; ``groupdict`` would
                # return a whole ``{"ret": ...}`` dict instead.
                ret = match.group("ret")
        else:
            annotations = get_args(ret)[1:] if is_annotated(ret) else ()
            if get_origin2(ret) == GeneratorOrigin:
                ret = get_args2(ret)[2]
                if annotations:
                    ret = Annotated[(ret, *annotations)]
        # ``functools.wraps`` may share ``func.__annotations__`` with the
        # wrapper; assign a fresh dict so ``func`` itself is left untouched.
        wrapper.__annotations__ = {**wrapper.__annotations__, "return": ret}
    return wrapper
Exemplo n.º 6
0
 def full_metadata(self) -> Mapping[str, Any]:
     """Return this object's metadata layered over its ``Annotated`` metadata.

     When ``self.type`` is not annotated, ``self.metadata`` is returned as is.
     Otherwise a ``ChainMap`` is returned whose first layer is
     ``self.metadata``, followed by every ``Mapping`` argument of the
     ``Annotated`` type, from the last one to the first.
     """
     if is_annotated(self.type):
         annotation_maps = [
             arg for arg in reversed(get_args(self.type)[1:])
             if isinstance(arg, Mapping)
         ]
         return ChainMap(self.metadata, *annotation_maps)
     return self.metadata
Exemplo n.º 7
0
def subtyping_substitution(
    supertype: AnyType, subtype: AnyType
) -> Tuple[Mapping[AnyType, AnyType], Mapping[AnyType, AnyType]]:
    """Match the generic bases of ``subtype`` against ``supertype``.

    Both types are first passed through ``with_parameters``. The first base
    of ``generic_mro(subtype)`` whose origin equals ``supertype``'s origin (or
    where both origins belong to ``ITERABLE_TYPES``) is zipped argument-wise
    with ``supertype``: TypeVars of ``supertype`` are mapped to the base's
    arguments, and TypeVars of the base to ``supertype``'s arguments.

    Returns ``(supertype_to_subtype, subtype_to_supertype)``; both are empty
    when no base matches.
    """
    supertype = with_parameters(supertype)
    subtype = with_parameters(subtype)
    super_to_sub, sub_to_super = {}, {}
    target_origin = get_origin_or_type2(supertype)
    for base in generic_mro(subtype):
        origin = get_origin_or_type2(base)
        same_family = origin == target_origin or (
            origin in ITERABLE_TYPES and target_origin in ITERABLE_TYPES)
        if not same_family:
            continue
        for sub_arg, sup_arg in zip(get_args(base), get_args(supertype)):
            if is_type_var(sup_arg):
                super_to_sub[sup_arg] = sub_arg
            if is_type_var(sub_arg):
                sub_to_super[sub_arg] = sup_arg
        # Only the first matching base is used.
        break
    return super_to_sub, sub_to_super
Exemplo n.º 8
0
def default_type_name(tp: AnyType) -> Optional[TypeName]:
    """Return a ``TypeName`` built from ``tp.__name__``, or ``None``.

    ``None`` is returned when ``tp`` has no ``__name__``, has type arguments
    or type variables, or is one of ``PRIMITIVE_TYPES``. Otherwise
    ``tp.__name__`` is used for both of ``TypeName``'s positional fields.
    """
    if (
        not hasattr(tp, "__name__")
        or get_args(tp)
        or has_type_vars(tp)
        or tp in PRIMITIVE_TYPES
    ):
        return None
    name = tp.__name__
    return TypeName(name, name)
Exemplo n.º 9
0
 def check_type(self, tp: AnyType):
     """Validate that ``tp`` may be given a type name by this object.

     Raises:
         TypeError: if ``tp`` is a TypeVar; if it has type variables and type
             arguments (a generic alias); or if it has type variables but no
             arguments while ``json_schema`` or ``graphql`` is a plain string.
     """
     if is_type_var(tp):
         raise TypeError("TypeVar cannot have a type_name")
     if not has_type_vars(tp):
         return
     if get_args(tp):
         raise TypeError("Generic alias cannot have a type_name")
     uses_plain_name = isinstance(self.json_schema, str) or isinstance(
         self.graphql, str)
     if uses_plain_name:
         raise TypeError(
             "Unspecialized generic type must used factory type_name"
         )
Exemplo n.º 10
0
def get_type_name(tp: AnyType) -> TypeName:
    """Return the ``TypeName`` of ``tp``.

    Lookup order, after ``replace_builtins``:

    1. the entry registered for ``tp`` in ``_type_names``;
    2. for an alias with arguments and no type variables, the entry
       registered for its origin, called with the arguments;
    3. ``settings.default_type_name(tp)``, or an empty ``TypeName()``.

    A ``KeyError`` or ``TypeError`` raised by a lookup moves on to the next
    step.
    """
    from apischema import settings

    tp = replace_builtins(tp)
    with suppress(KeyError, TypeError):
        return _type_names[tp].to_type_name(tp)
    type_args = get_args(tp)
    if type_args and not has_type_vars(tp):
        origin = get_origin(tp)
        with suppress(KeyError, TypeError):
            return _type_names[origin].to_type_name(origin, *type_args)
    fallback = settings.default_type_name(tp)
    return fallback or TypeName()
Exemplo n.º 11
0
def handle_generic_conversions(base: AnyType,
                               other: AnyType) -> Tuple[AnyType, AnyType]:
    """Normalize a generic conversion type to its unspecialized origin.

    A non-generic ``base`` is returned unchanged along with ``other``. For a
    generic ``base`` whose arguments are all TypeVars, returns its origin and
    ``other`` with those TypeVars replaced by the origin's own parameters.

    :raises TypeError: if any argument of ``base`` is not a TypeVar
    """
    origin = get_origin(base)
    if origin is None:
        return base, other
    params = get_args(base)
    if any(not is_type_var(param) for param in params):
        raise TypeError(
            f"Generic conversion doesn't support specialization,"
            f" aka {type_name(base)}[{','.join(map(type_name, params))}]")
    renaming = dict(zip(params, get_parameters(origin)))
    return origin, resolve_type_vars(other, renaming)
Exemplo n.º 12
0
def merged_schema(
        schema: Optional[Schema],
        tp: Optional[AnyType]) -> Tuple[Optional[Schema], Mapping[str, Any]]:
    """Merge the schemas found in ``tp``'s ``Annotated`` metadata into ``schema``.

    Metadata is walked from last to first and the walk stops at the first
    ``TypeNameFactory`` met. Each ``Mapping`` holding a ``SCHEMA_METADATA`` key
    is merged with ``merge_schema(annotation_schema, schema)``.

    Returns the merged schema and the dict filled by its ``merge_into``
    (empty when the schema is ``None``).
    """
    if is_annotated(tp):
        for metadata in reversed(get_args(tp)[1:]):
            if isinstance(metadata, TypeNameFactory):
                # Metadata placed before this factory is ignored.
                break
            if isinstance(metadata, Mapping) and SCHEMA_METADATA in metadata:
                schema = merge_schema(metadata[SCHEMA_METADATA], schema)
    as_dict: Dict[str, Any] = {}
    if schema is not None:
        schema.merge_into(as_dict)
    return schema, as_dict
Exemplo n.º 13
0
def with_validation_error(
        func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Wrap a validator generator function with ``yield_to_raise``.

    The returned wrapper's ``return`` annotation is rewritten to the
    generator's result type: a ``"ValidatorResult[X]"`` string annotation
    becomes ``"X"``, and an annotation whose origin is ``GeneratorOrigin``
    becomes its third type argument.

    :raises TypeError: if ``func`` is not a generator function
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")
    wrapper = yield_to_raise(func)
    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # ``group`` returns the captured text; ``groupdict`` would
                # return a whole ``{"ret": ...}`` dict instead.
                ret = match.group("ret")
        elif get_origin(ret) == GeneratorOrigin:
            ret = get_args(ret)[2]
        # NOTE(review): if ``yield_to_raise`` uses ``functools.wraps``, this
        # dict is ``func.__annotations__`` itself and is mutated as well —
        # confirm, and assign a copy if so.
        wrapper.__annotations__["return"] = ret
    return wrapper
Exemplo n.º 14
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Conv],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> TypeFactory[GraphQLTp]:
     """Build the GraphQL type factory for ``tp``.

     ``tp`` is rendered as a non-null ``self.id_type`` when it equals ``ID``,
     or when ``self.is_id`` accepts it outside a dynamic conversion.
     Otherwise the parent factory is used; outside a dynamic conversion it
     is merged with ``tp``'s type name and schema, then with the schema of
     ``tp``'s origin when ``tp`` has type arguments.
     """
     is_id_type = (not dynamic and self.is_id(tp)) or tp == ID
     if is_id_type:
         return TypeFactory(lambda *_: graphql.GraphQLNonNull(self.id_type))
     factory = super().visit_conversion(tp, conversion, dynamic,
                                        next_conversion)
     if dynamic:
         return factory  # type: ignore
     factory = factory.merge(get_type_name(tp), get_schema(tp))
     if get_args(tp):
         factory = factory.merge(schema=get_schema(get_origin(tp)))
     return factory  # type: ignore
Exemplo n.º 15
0
 def _visit_generic(self, cls: AnyType) -> Return:
     """Dispatch the subscripted generic ``cls`` to the matching visitor method.

     ``Annotated``, ``Union``, fixed-length tuples, collections, mappings and
     ``Literal`` each get their dedicated method; anything else goes to
     ``self.generic``. A variadic tuple (second argument ``...``) skips the
     tuple branch and reaches the collection branch when ``TUPLE_TYPE`` is one
     of ``COLLECTION_TYPES``.
     """
     origin, args = get_origin(cls), get_args(cls)
     assert origin is not None
     if origin is Annotated:
         return self.annotated(args[0], args[1:])
     if origin is Union:
         return self.union(args)
     variadic_tuple = len(args) >= 2 and args[1] is ...
     if origin is TUPLE_TYPE and not variadic_tuple:
         return self.tuple(args)
     if origin in COLLECTION_TYPES:
         return self.collection(origin, args[0])
     if origin in MAPPING_TYPES:
         key_type, value_type = args[0], args[1]
         return self.mapping(origin, key_type, value_type)
     if origin is Literal:  # pragma: no cover py37+
         return self.literal(args)
     return self.generic(cls)
Exemplo n.º 16
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Conv],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> JsonSchema:
     """Build the JSON schema of ``tp``.

     Outside a dynamic conversion, the reference schema of the first type in
     ``self.resolve_conversion(tp)`` for which ``self.ref_schema`` returns one
     is returned directly. Otherwise the parent's schema is computed and
     folded with ``full_schema``, first with the schema registered for
     ``tp``'s origin (when ``tp`` has type arguments), then with the schema
     registered for ``tp`` (both only outside a dynamic conversion).
     """
     extra_schemas = []
     if not dynamic:
         candidate_refs = (
             self.ref_schema(get_type_name(ref_tp).json_schema)
             for ref_tp in self.resolve_conversion(tp)
         )
         existing_ref = next(
             (ref for ref in candidate_refs if ref is not None), None)
         if existing_ref is not None:
             return existing_ref
         if get_args(tp):
             extra_schemas.append(get_schema(get_origin_or_type(tp)))
         extra_schemas.append(get_schema(tp))
     result = super().visit_conversion(tp, conversion, dynamic,
                                       next_conversion)
     return reduce(full_schema, extra_schemas, result)
Exemplo n.º 17
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Deserialization],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> DeserializationMethodFactory:
     """Extend the parent's deserialization factory for ``tp``.

     When the parent factory has no coercer and coercion is enabled
     (``self._coerce``), ``self._coercer`` is installed. Outside a dynamic
     conversion, the constraints of ``tp``'s schema and ``tp``'s validators
     are merged in, then those of ``tp``'s origin when ``tp`` has type
     arguments.
     """
     factory = super().visit_conversion(tp, conversion, dynamic,
                                        next_conversion)
     if factory.coercer is None and self._coerce:
         factory = replace(factory, coercer=self._coercer)
     if dynamic:
         return factory
     factory = factory.merge(get_constraints(get_schema(tp)),
                             get_validators(tp))
     if get_args(tp):
         origin = get_origin(tp)
         factory = factory.merge(
             get_constraints(get_schema(origin)),
             get_validators(origin),
         )
     return factory
Exemplo n.º 18
0
def process_node(node_cls: Type[Node]):
    """Register ``node_cls`` as a concrete ``Node`` implementation.

    Classes that still have type variables, or that do not override
    ``Node.get_by_id``, are ignored. Otherwise:

    * the first class of ``node_cls.__mro__`` (other than ``Node``) defining
      ``get_by_id`` must define it as a classmethod or staticmethod;
    * the first ``Node[...]`` base of ``generic_mro(node_cls)`` gives the id
      type, stored on ``node_cls`` under ``ID_TYPE_ATTR``, and ``node_cls`` is
      recorded in ``_nodes`` under ``node_cls._node_key()``.

    :raises TypeError: if ``get_by_id`` is overridden by a plain method, or if
        no ``Node[...]`` base is found in the generic MRO.
    """
    if has_type_vars(node_cls):
        return
    # Each access to a classmethod creates a new bound method object, so an
    # ``is`` check on the attributes themselves never matches for one;
    # compare the underlying functions instead. Plain functions and
    # staticmethods have no ``__func__`` and are compared as is.
    own, default = node_cls.get_by_id, Node.get_by_id
    if getattr(own, "__func__", own) is getattr(default, "__func__", default):
        return
    for base in node_cls.__mro__:
        if base != Node and Node.get_by_id.__name__ in base.__dict__:
            if not isinstance(
                base.__dict__[Node.get_by_id.__name__], (classmethod, staticmethod)
            ):
                raise TypeError(
                    f"{node_cls.__name__}.get_by_id must be a"
                    f" classmethod/staticmethod"
                )
            break
    for base in generic_mro(node_cls):
        if get_origin(base) == Node:
            setattr(node_cls, ID_TYPE_ATTR, get_args(base)[0])
            _nodes[node_cls._node_key()] = node_cls
            break
    else:
        raise TypeError("Node type parameter Id must be specialized")
Exemplo n.º 19
0
def graphql_schema(
    *,
    query: Iterable[Callable] = (),
    mutation: Iterable[Callable] = (),
    subscription: Iterable[Union[Subscribe, Tuple[Subscribe, Callable]]] = (),
    types: Iterable[Type] = (),
    aliaser: Aliaser = to_camel_case,
    id_types: Optional[Union[Collection[AnyType], IdPredicate]] = None,
    error_as_null: bool = True,
    generic_ref_factory: Optional[GenericRefFactory] = None,
    union_ref_factory: Optional[UnionRefFactory] = None,
    directives: Optional[Collection[graphql.GraphQLDirective]] = None,
    description: Optional[str] = None,
    extensions: Optional[Dict[str, Any]] = None,
) -> graphql.GraphQLSchema:
    """Build a ``graphql.GraphQLSchema`` from plain Python callables.

    Each callable of ``query`` / ``mutation`` becomes a field of the root
    ``Query`` / ``Mutation`` type, named after the callable's ``__name__``.
    Each ``subscription`` item becomes a field of the root ``Subscription``
    type; it is either:

    * a subscribe callable, whose return type must have its origin in
      ``async_iterable_origins``; the field type is the first type argument
      of that return type, or
    * a ``(subscribe, event_handler)`` tuple; the field takes its name,
      schema and return type from the event handler, which must have at
      least one parameter (``skip_first=True`` is used for it).

    A root type without fields is omitted (``None``).

    :param types: extra types visited by the output schema builder and
        passed to ``graphql.GraphQLSchema``
    :param aliaser: passed to every field resolver and to the output schema
        builder
    :param id_types: a collection of types (tested with ``in``) or a
        predicate, passed to the output schema builder as ``is_id``
    :param error_as_null: forwarded to return-type wrapping, resolvers and
        the output schema builder
    :param generic_ref_factory: forwarded to the output schema builder
    :param union_ref_factory: forwarded to the output schema builder
    :param directives: passed to ``graphql.GraphQLSchema``
    :param description: passed to ``graphql.GraphQLSchema``
    :param extensions: passed to ``graphql.GraphQLSchema``
    :raises TypeError: if a subscription event handler has no parameter, or
        if a plain subscribe callable does not return an AsyncIterable /
        AsyncIterator
    """
    def operation_resolver(operation: Callable,
                           *,
                           skip_first: bool = False) -> Resolver:
        """Wrap ``operation`` into a ``Resolver``.

        Unless ``skip_first`` is set, the wrapper drops its first positional
        argument before calling ``operation`` (presumably the GraphQL root
        value — confirm). With ``skip_first``, ``operation`` is used as is.
        ``skip_first`` is also forwarded to ``resolver_parameters``.
        """
        if skip_first:
            wrapper = operation
        else:

            def wrapper(_, *args, **kwargs):
                return operation(*args, **kwargs)

        parameters = resolver_parameters(operation, skip_first=skip_first)
        return Resolver(operation, wrapper, parameters)

    query_fields: List[ObjectField] = []
    mutation_fields: List[ObjectField] = []
    subscription_fields: List[ObjectField] = []
    # Queries and mutations: one field per callable, named after it.
    for operations, fields in [(query, query_fields),
                               (mutation, mutation_fields)]:
        for operation in operations:
            resolver = operation_resolver(operation)
            fields.append(
                ObjectField(
                    operation.__name__,
                    wrap_return_type(resolver.return_type, error_as_null),
                    resolve=resolver_resolve(resolver, aliaser, error_as_null),
                    parameters=field_parameters(resolver),
                    schema=get_schema(operation),
                ))
    for operation in subscription:  # type: ignore
        resolve: Callable
        if isinstance(operation, tuple):
            # (subscribe, event_handler): the event handler gives the field
            # its name, schema, parameters and return type, and resolves it.
            operation, event_handler = operation
            name, schema = event_handler.__name__, get_schema(event_handler)
            try:
                resolver = operation_resolver(event_handler, skip_first=True)
            except MissingFirstParameter:
                raise TypeError(
                    "Subscription resolver must have at least one parameter"
                ) from None
            return_type = resolver.return_type
            subscribe = resolver_resolve(
                operation_resolver(operation),
                aliaser,
                error_as_null,
                serialized=False,
            )
            resolve = resolver_resolve(resolver, aliaser, error_as_null)
        else:
            # Plain subscribe callable: the field type is the first type
            # argument of its async iterable return type.
            name, schema = operation.__name__, get_schema(operation)
            resolver = operation_resolver(operation)
            if get_origin(resolver.return_type) not in async_iterable_origins:
                raise TypeError(
                    "Subscriptions must return an AsyncIterable/AsyncIterator")
            return_type = get_args(resolver.return_type)[0]
            subscribe = resolver_resolve(resolver,
                                         aliaser,
                                         error_as_null,
                                         serialized=False)

            def resolve(_, *args, **kwargs):
                # Return the first argument unchanged (presumably the event
                # produced by ``subscribe`` — confirm).
                return _

        subscription_fields.append(
            ObjectField(
                name,
                wrap_return_type(return_type, error_as_null),
                parameters=field_parameters(resolver),
                resolve=resolve,
                subscribe=subscribe,
                schema=schema,
            ))

    # A collection of ID types becomes a membership predicate; anything else
    # (a predicate, or the default ``None``) is passed through unchanged.
    # NOTE(review): with the default ``None`` the builder receives
    # ``is_id=None``; presumably it handles that — confirm.
    is_id = id_types.__contains__ if isinstance(id_types,
                                                Collection) else id_types
    builder = OutputSchemaBuilder(aliaser, is_id, error_as_null,
                                  generic_ref_factory, union_ref_factory)

    def root_type(
        name: str, fields: Collection[ObjectField]
    ) -> Optional[graphql.GraphQLObjectType]:
        """Build the root object type ``name``, or ``None`` without fields."""
        if not fields:
            return None
        return exec_thunk(builder.object(type(name, (), {}), fields),
                          non_null=False)

    return graphql.GraphQLSchema(
        root_type("Query", query_fields),
        root_type("Mutation", mutation_fields),
        root_type("Subscription", subscription_fields),
        [exec_thunk(builder.visit(cls), non_null=False) for cls in types],
        directives,
        description,
        extensions,
    )
Exemplo n.º 20
0
def is_skipped(cls: AnyType, *, schema_only) -> bool:
    """Tell whether ``cls`` must be skipped.

    ``UndefinedType`` is always skipped. An ``Annotated`` type is skipped when
    ``Skip`` is among its metadata, or, when ``schema_only`` is truthy, when
    ``SkipSchema`` is. Any other type is not skipped.
    """
    if cls is UndefinedType:
        return True
    if get_origin(cls) is not Annotated:
        return False
    metadata = get_args(cls)[1:]
    return Skip in metadata or (schema_only and SkipSchema in metadata)
Exemplo n.º 21
0
def _annotated(tp: AnyType) -> AnyType:
    """Return the type wrapped by ``Annotated``, or ``tp`` itself otherwise."""
    if not is_annotated(tp):
        return tp
    return get_args(tp)[0]
Exemplo n.º 22
0
def type_var_context(cls: AnyType,
                     type_vars: TypeVarContext = None) -> TypeVarContext:
    """Map the type parameters of ``cls``'s origin to its arguments.

    ``cls`` is first resolved against ``type_vars``; the result must be a
    subscripted generic (this is asserted).
    """
    resolved = resolve_type_vars(cls, type_vars)
    resolved_origin = get_origin(resolved)
    assert resolved_origin is not None
    parameters = get_parameters(resolved_origin)
    return dict(zip(parameters, get_args(resolved)))
Exemplo n.º 23
0
def get_args2(tp: AnyType) -> Tuple[AnyType, ...]:
    """Return the type arguments of ``tp``, looking through ``Annotated``."""
    unwrapped = _annotated(tp)
    return get_args(unwrapped)
Exemplo n.º 24
0
    def __init__(self, cls):
        """Wrap the ``typing`` generic ``cls``.

        ``implem`` is the runtime class behind ``cls``: its ``__origin__``,
        or its ``__extra__`` when ``__origin__`` is falsy.
        """
        self.cls = cls
        origin = cls.__origin__
        # ``__extra__`` is only looked up when ``__origin__`` is falsy
        # (the Python 3.6 typing layout).
        self.implem = origin if origin else cls.__extra__

    def __getitem__(self, item):
        """Subscript the wrapped ``typing`` generic (``wrapper[int]``)."""
        specialized = self.cls[item]
        return specialized

    def __call__(self, *args, **kwargs):
        """Instantiate the runtime class with the given arguments."""
        factory = self.implem
        return factory(*args, **kwargs)

    def __instancecheck__(self, instance):
        """Make ``isinstance(x, wrapper)`` test against the runtime class."""
        runtime_cls = self.implem
        return isinstance(instance, runtime_cls)

    def __subclasscheck__(self, subclass):
        """Make ``issubclass(c, wrapper)`` test against the runtime class."""
        runtime_cls = self.implem
        return issubclass(subclass, runtime_cls)


# Rebind the builtin container names (``dict``, ``list``, ``set``,
# ``frozenset``, ``tuple``, ``type``) in this module's globals to wrappers
# around the matching ``typing`` generics, so they can be subscripted while
# still being callable and usable with isinstance/issubclass.
# Note: ``cls`` and ``wrapper`` remain defined as module globals afterwards.
for cls in (Dict, List, Set, FrozenSet, Tuple, Type):  # noqa
    wrapper = Wrapper(cls)
    globals()[wrapper.implem.__name__] = wrapper

# From here on, ``Set`` refers to ``AbstractSet``.
Set = AbstractSet

del Wrapper

# ``asyncio.run`` does not exist before Python 3.7; install a minimal
# fallback running the coroutine on the current event loop.
if sys.version_info < (3, 7):
    asyncio.run = lambda coro: asyncio.get_event_loop().run_until_complete(coro
                                                                           )

# Patch ``inspect.isclass`` so that it only returns True for ``type``
# instances without type arguments.
inspect.isclass = lambda tp: isinstance(tp, type) and not get_args(tp)
Exemplo n.º 25
0
 def return_type(self) -> AnyType:
     """Return the declared return type, unwrapping ``awaitable_origin``.

     If ``self.types["return"]`` has ``awaitable_origin`` as its origin, its
     first type argument is returned; otherwise the annotation itself.
     """
     declared = self.types["return"]
     if get_origin(declared) == awaitable_origin:
         return get_args(declared)[0]
     return declared
Exemplo n.º 26
0
def keep_annotations(tp: AnyType, annotated: AnyType) -> AnyType:
    """Re-apply the ``Annotated`` metadata of ``annotated`` to ``tp``.

    Returns ``Annotated[tp, *metadata]`` when ``annotated`` is an
    ``Annotated`` type, otherwise ``tp`` unchanged.
    """
    if not is_annotated(annotated):
        return tp
    metadata = get_args(annotated)[1:]
    return Annotated[(tp, *metadata)]