def _has_conversion(
    tp: AnyType, conversion: Optional[AnyConversion]
) -> Tuple[bool, Optional[Deserialization]]:
    """Resolve the deserialization conversions applying to ``tp``.

    Returns a pair ``(has_conversion, conversions)``: the boolean tells
    whether any conversion matched ``tp``; the tuple holds the resolved
    conversions, or ``None`` when there is no match or when the only match
    is the identity conversion.
    """
    identity_conv, result = False, []
    for conv in resolve_any_conversion(conversion):
        conv = handle_identity_conversion(conv, tp)
        # Keep only conversions whose target is compatible with tp.
        if is_subclass(conv.target, tp):
            if is_identity(conv):
                # Only one identity conversion is kept; skip duplicates.
                if identity_conv:
                    continue
                identity_conv = True
                # Wrap the origin type so identity deserialization goes
                # through the class itself (self-deserialization).
                wrapper: AnyType = self_deserialization_wrapper(
                    get_origin_or_type(tp))
                if get_args(tp):
                    wrapper = wrapper[get_args(tp)]
                conv = ResolvedConversion(replace(conv, source=wrapper))
            conv = handle_dataclass_model(conv)
            # Substitute the target's type vars so the conversion source
            # lines up with the concrete arguments of tp.
            _, substitution = subtyping_substitution(tp, conv.target)
            source = substitute_type_vars(conv.source, substitution)
            result.append(
                ResolvedConversion(replace(conv, source=source, target=tp)))
    if identity_conv and len(result) == 1:
        # Pure identity: signal it with (True, None) so callers can
        # short-circuit instead of applying a no-op conversion.
        return True, None
    else:
        return bool(result), tuple(result) or None
def handle_generic_field_type(field_type: AnyType, base: AnyType,
                              other: AnyType, covariant: bool) -> AnyType:
    """Compute ``other`` with type variables substituted by unifying
    ``field_type`` against ``base``.

    ``type_vars`` stays ``None`` (meaning: no substitution possible) unless
    ``base`` is a bare TypeVar or a generic alias whose arguments can be
    matched one-to-one with ``field_type``'s arguments.
    """
    contravariant = not covariant
    type_vars = None
    if is_type_var(base):
        # Simplest case: base is itself a TypeVar bound to the field type.
        type_vars = {base: field_type}
    # Unification is only attempted for non-nested generic aliases with the
    # same number of arguments, where every non-TypeVar argument of base
    # already equals the corresponding field_type argument.
    if (get_origin(base) is not None
            and getattr(base, "__parameters__", ())
            and len(get_args(base)) == len(get_args(field_type))
            and not any(map(get_origin, get_args(base)))
            and not any(map(get_origin, get_args(field_type)))
            and not any(not is_type_var(base_arg) and base_arg != field_arg
                        for base_arg, field_arg
                        in zip(get_args(base), get_args(field_type)))):
        type_vars = {}
        for base_arg, field_arg in zip(get_args(base), get_args(field_type)):
            # A TypeVar appearing twice must map to the same argument,
            # otherwise unification fails.
            if base_arg in type_vars and type_vars[base_arg] != field_arg:
                type_vars = None
                break
            type_vars[base_arg] = field_arg
        field_type_origin, base_origin = get_origin(field_type), get_origin(
            base)
        assert field_type_origin is not None and base_origin is not None
        if base_origin != field_type_origin:
            # Origins differ: keep the substitution only if the subclass
            # relation matches the requested variance.
            if covariant and not issubclass(base_origin, field_type_origin):
                type_vars = None
            if contravariant and not issubclass(field_type_origin, base_origin):
                type_vars = None
    return resolve_type_vars(other, type_vars)
def unsupported(self, tp: AnyType) -> Result:
    """Fallback for unhandled types: try ``settings.default_object_fields``
    before delegating to the parent class."""
    from apischema import settings

    origin = get_origin_or_type(tp)
    if not isinstance(origin, type):
        return super().unsupported(tp)
    fields = settings.default_object_fields(origin)
    if fields is None:
        return super().unsupported(tp)
    args = get_args(tp)
    if args:
        # Specialize field types with the generic alias arguments.
        substitution = dict(zip(get_parameters(origin), args))
        fields = [
            replace(field, type=substitute_type_vars(field.type, substitution))
            for field in fields
        ]
    return self._object(origin, fields)
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Mapping[str, Tuple[S, Mapping[str, AnyType]]]:
    """Collect the methods registered for ``tp`` along its generic MRO,
    later (more derived) bases overriding earlier ones, then merge in the
    methods of its model origin if it has one."""
    methods = {}
    for base in reversed(generic_mro(tp)):
        for name, method in all_methods[get_origin_or_type(base)].items():
            methods[name] = (method, method.types(base))
    if has_model_origin(tp):
        model = get_model_origin(tp)
        if get_args(tp):
            # Specialize the model with the alias arguments before recursing.
            substitution = dict(
                zip(get_parameters(get_origin(tp)), get_args(tp))
            )
            model = substitute_type_vars(model, substitution)
        methods.update(_get_methods(model, all_methods))
    return methods
def gather_errors(func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Wrap a generator-based validator so that every yielded error is
    gathered and raised at once as a validation error.

    ``func`` must be a generator function yielding errors and returning the
    validated value; the wrapper exhausts it, raises if any error was
    yielded, and returns the generator's return value otherwise. The
    wrapper's return annotation is rewritten from ``ValidatorResult[X]``
    to ``X`` (preserving ``Annotated`` metadata).

    Raises:
        TypeError: if ``func`` is not a generator function.
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")

    @wraps(func)
    def wrapper(*args, **kwargs):
        result, errors = func(*args, **kwargs), []
        while True:
            try:
                errors.append(next(result))
            except StopIteration as stop:
                if errors:
                    raise build_validation_error(errors)
                return stop.value

    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # BUG FIX: was `match.groupdict("ret")`, which returns the
                # whole group dict (with "ret" as default value), not the
                # captured type string.
                ret = match.group("ret")
        else:
            annotations = get_args(ret)[1:] if is_annotated(ret) else ()
            if get_origin2(ret) == GeneratorOrigin:
                # Extract the generator's return type and re-apply any
                # Annotated metadata stripped by get_args2.
                ret = get_args2(ret)[2]
                if annotations:
                    ret = Annotated[(ret, *annotations)]
        wrapper.__annotations__["return"] = ret
    return wrapper
def full_metadata(self) -> Mapping[str, Any]:
    """Return the field metadata, layered with every Mapping found in the
    ``Annotated`` arguments of the field type (field metadata wins, later
    annotations take precedence over earlier ones)."""
    if not is_annotated(self.type):
        return self.metadata
    annotation_maps = [
        annotation
        for annotation in reversed(get_args(self.type)[1:])
        if isinstance(annotation, Mapping)
    ]
    return ChainMap(self.metadata, *annotation_maps)
def subtyping_substitution(
    supertype: AnyType, subtype: AnyType
) -> Tuple[Mapping[AnyType, AnyType], Mapping[AnyType, AnyType]]:
    """Compute the type-var substitutions relating ``supertype`` and
    ``subtype`` in both directions, by matching the first base of the
    subtype's generic MRO with the supertype's origin (iterable types are
    matched with each other)."""
    supertype = with_parameters(supertype)
    subtype = with_parameters(subtype)
    super_to_sub: dict = {}
    sub_to_super: dict = {}
    target_origin = get_origin_or_type2(supertype)
    for base in generic_mro(subtype):
        base_origin = get_origin_or_type2(base)
        iterable_match = (
            base_origin in ITERABLE_TYPES and target_origin in ITERABLE_TYPES
        )
        if base_origin == target_origin or iterable_match:
            for sub_arg, super_arg in zip(get_args(base), get_args(supertype)):
                if is_type_var(super_arg):
                    super_to_sub[super_arg] = sub_arg
                if is_type_var(sub_arg):
                    sub_to_super[sub_arg] = super_arg
            break
    return super_to_sub, sub_to_super
def default_type_name(tp: AnyType) -> Optional[TypeName]:
    """Derive a TypeName from ``tp.__name__`` for plain named classes;
    return None for generic aliases, unspecialized generics and primitives."""
    eligible = (
        hasattr(tp, "__name__")
        and not get_args(tp)
        and not has_type_vars(tp)
        and tp not in PRIMITIVE_TYPES
    )
    if not eligible:
        return None
    return TypeName(tp.__name__, tp.__name__)
def check_type(self, tp: AnyType):
    """Validate that ``tp`` may receive a type_name.

    Raises:
        TypeError: for a TypeVar, for a generic alias still containing type
            vars, or for an unspecialized generic given a plain string name
            (which must use a factory instead).
    """
    if is_type_var(tp):
        raise TypeError("TypeVar cannot have a type_name")
    if has_type_vars(tp):
        if get_args(tp):
            raise TypeError("Generic alias cannot have a type_name")
        elif isinstance(self.json_schema, str) or isinstance(self.graphql, str):
            # BUG FIX: message grammar was "must used factory".
            raise TypeError(
                "Unspecialized generic type must use factory type_name"
            )
def get_type_name(tp: AnyType) -> TypeName:
    """Look up the registered TypeName of ``tp``: first the exact type, then
    its origin for fully-specialized aliases, finally the settings default
    (or an empty TypeName)."""
    from apischema import settings

    tp = replace_builtins(tp)
    try:
        return _type_names[tp].to_type_name(tp)
    except (KeyError, TypeError):
        pass
    args = get_args(tp)
    if args and not has_type_vars(tp):
        origin = get_origin(tp)
        try:
            return _type_names[origin].to_type_name(origin, *args)
        except (KeyError, TypeError):
            pass
    return settings.default_type_name(tp) or TypeName()
def handle_generic_conversions(base: AnyType, other: AnyType) -> Tuple[AnyType, AnyType]:
    """Normalize a conversion declared on a generic alias: return the alias
    origin and ``other`` with the alias type vars renamed to the origin's
    parameters. Non-generic ``base`` is returned unchanged.

    Raises:
        TypeError: if the alias is specialized (non-TypeVar arguments).
    """
    origin = get_origin(base)
    if origin is None:
        return base, other
    args = get_args(base)
    if any(not is_type_var(arg) for arg in args):
        raise TypeError(
            f"Generic conversion doesn't support specialization,"
            f" aka {type_name(base)}[{','.join(map(type_name, args))}]")
    substitution = dict(zip(args, get_parameters(origin)))
    return origin, resolve_type_vars(other, substitution)
def merged_schema(
    schema: Optional[Schema], tp: Optional[AnyType]
) -> Tuple[Optional[Schema], Mapping[str, Any]]:
    """Merge the schemas carried by ``tp``'s Annotated metadata into
    ``schema`` (innermost first, stopping at a TypeNameFactory), and return
    the merged schema together with its dict representation."""
    if is_annotated(tp):
        for annotation in reversed(get_args(tp)[1:]):
            # A TypeNameFactory marks a reference boundary: schemas beyond
            # it belong to the referenced name, not to this type.
            if isinstance(annotation, TypeNameFactory):
                break
            if isinstance(annotation, Mapping) and SCHEMA_METADATA in annotation:
                schema = merge_schema(annotation[SCHEMA_METADATA], schema)
    as_dict: Dict[str, Any] = {}
    if schema is not None:
        schema.merge_into(as_dict)
    return schema, as_dict
def with_validation_error(
        func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Wrap a generator-based validator so yielded errors are raised
    (via ``yield_to_raise``), rewriting the return annotation from
    ``ValidatorResult[X]`` to ``X``.

    Raises:
        TypeError: if ``func`` is not a generator function.
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")
    wrapper = yield_to_raise(func)
    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # BUG FIX: was `match.groupdict("ret")`, which returns the
                # whole group dict (with "ret" as default value), not the
                # captured type string.
                ret = match.group("ret")
        elif get_origin(ret) == GeneratorOrigin:
            ret = get_args(ret)[2]
        wrapper.__annotations__["return"] = ret
    return wrapper
def visit_conversion(
    self,
    tp: AnyType,
    conversion: Optional[Conv],
    dynamic: bool,
    next_conversion: Optional[AnyConversion] = None,
) -> TypeFactory[GraphQLTp]:
    """Visit a conversion, mapping ID-like types to the GraphQL ID scalar
    and merging name/schema metadata for non-dynamic conversions."""
    # Note the precedence: a non-dynamic id-predicate match OR the ID
    # sentinel itself (regardless of dynamic) short-circuits to the scalar.
    if (not dynamic and self.is_id(tp)) or tp == ID:
        return TypeFactory(lambda *_: graphql.GraphQLNonNull(self.id_type))
    factory = super().visit_conversion(tp, conversion, dynamic, next_conversion)
    if dynamic:
        return factory  # type: ignore
    factory = factory.merge(get_type_name(tp), get_schema(tp))
    if get_args(tp):
        # Generic alias: the origin's schema also applies.
        factory = factory.merge(schema=get_schema(get_origin(tp)))
    return factory  # type: ignore
def _visit_generic(self, cls: AnyType) -> Return:
    """Dispatch a generic alias (``Annotated``, ``Union``, tuple, collection,
    mapping, ``Literal``) to the matching visit method based on its origin."""
    origin, args = get_origin(cls), get_args(cls)
    assert origin is not None
    if origin is Annotated:
        return self.annotated(args[0], args[1:])
    if origin is Union:
        return self.union(args)
    if origin is TUPLE_TYPE:
        # Heterogeneous tuples (e.g. Tuple[int, str]) are visited as tuples;
        # homogeneous Tuple[X, ...] deliberately falls through and is handled
        # as a plain collection of X below.
        if len(args) < 2 or args[1] is not ...:
            return self.tuple(args)
    if origin in COLLECTION_TYPES:
        return self.collection(origin, args[0])
    if origin in MAPPING_TYPES:
        return self.mapping(origin, args[0], args[1])
    if origin is Literal:  # pragma: no cover py37+
        return self.literal(args)
    return self.generic(cls)
def visit_conversion(
    self,
    tp: AnyType,
    conversion: Optional[Conv],
    dynamic: bool,
    next_conversion: Optional[AnyConversion] = None,
) -> JsonSchema:
    """Visit a conversion, short-circuiting to a ``$ref`` schema when the
    type resolves to a named reference, otherwise layering the type's (and
    its origin's) schema on top of the visited result."""
    extra_schemas = []
    if not dynamic:
        for ref_tp in self.resolve_conversion(tp):
            ref_schema = self.ref_schema(get_type_name(ref_tp).json_schema)
            if ref_schema is not None:
                return ref_schema
        if get_args(tp):
            # Generic alias: the origin's schema applies first.
            extra_schemas.append(get_schema(get_origin_or_type(tp)))
        extra_schemas.append(get_schema(tp))
    visited = super().visit_conversion(tp, conversion, dynamic, next_conversion)
    return reduce(full_schema, extra_schemas, visited)
def visit_conversion(
    self,
    tp: AnyType,
    conversion: Optional[Deserialization],
    dynamic: bool,
    next_conversion: Optional[AnyConversion] = None,
) -> DeserializationMethodFactory:
    """Visit a conversion, attaching the coercer, then (for non-dynamic
    conversions) the constraints and validators of the type and, for a
    generic alias, of its origin."""
    factory = super().visit_conversion(tp, conversion, dynamic, next_conversion)
    if factory.coercer is None and self._coerce:
        factory = replace(factory, coercer=self._coercer)
    if dynamic:
        return factory
    factory = factory.merge(get_constraints(get_schema(tp)), get_validators(tp))
    if get_args(tp):
        origin = get_origin(tp)
        factory = factory.merge(
            get_constraints(get_schema(origin)), get_validators(origin)
        )
    return factory
def process_node(node_cls: Type[Node]):
    """Register a concrete Node subclass: validate that ``get_by_id`` is
    overridden as a classmethod/staticmethod, record its Id type parameter,
    and add it to the node registry."""
    if has_type_vars(node_cls) or node_cls.get_by_id is Node.get_by_id:
        # Unspecialized generics and classes without their own get_by_id
        # are not registrable nodes.
        return
    method_name = Node.get_by_id.__name__
    for base in node_cls.__mro__:
        if base != Node and method_name in base.__dict__:
            if not isinstance(base.__dict__[method_name], (classmethod, staticmethod)):
                raise TypeError(
                    f"{node_cls.__name__}.get_by_id must be a"
                    f" classmethod/staticmethod"
                )
            break
    for base in generic_mro(node_cls):
        if get_origin(base) == Node:
            # Store the Id specialization and register the node.
            setattr(node_cls, ID_TYPE_ATTR, get_args(base)[0])
            _nodes[node_cls._node_key()] = node_cls
            break
    else:
        raise TypeError("Node type parameter Id must be specialized")
def graphql_schema(
    *,
    query: Iterable[Callable] = (),
    mutation: Iterable[Callable] = (),
    subscription: Iterable[Union[Subscribe, Tuple[Subscribe, Callable]]] = (),
    types: Iterable[Type] = (),
    aliaser: Aliaser = to_camel_case,
    id_types: Union[Collection[AnyType], IdPredicate] = None,
    error_as_null: bool = True,
    generic_ref_factory: GenericRefFactory = None,
    union_ref_factory: UnionRefFactory = None,
    directives: Optional[Collection[graphql.GraphQLDirective]] = None,
    description: Optional[str] = None,
    extensions: Optional[Dict[str, Any]] = None,
) -> graphql.GraphQLSchema:
    """Build a ``graphql.GraphQLSchema`` from query/mutation/subscription
    callables and extra output types.

    Subscriptions are either a bare async-iterable-returning callable, or a
    ``(subscribe, resolver)`` pair where the resolver transforms each event.

    Raises:
        TypeError: for a subscription resolver without parameters, or a bare
            subscription not returning an AsyncIterable/AsyncIterator.
    """
    def operation_resolver(operation: Callable, *, skip_first=False) -> Resolver:
        # Wrap a plain callable as a GraphQL resolver; when skip_first is
        # False, a shim discards the GraphQL "root" first argument.
        if skip_first:
            wrapper = operation
        else:
            def wrapper(_, *args, **kwargs):
                return operation(*args, **kwargs)
        parameters = resolver_parameters(operation, skip_first=skip_first)
        return Resolver(operation, wrapper, parameters)

    query_fields: List[ObjectField] = []
    mutation_fields: List[ObjectField] = []
    subscription_fields: List[ObjectField] = []
    # Query and mutation operations share the same field-building logic.
    for operations, fields in [(query, query_fields), (mutation, mutation_fields)]:
        for operation in operations:
            resolver = operation_resolver(operation)
            fields.append(
                ObjectField(
                    operation.__name__,
                    wrap_return_type(resolver.return_type, error_as_null),
                    resolve=resolver_resolve(resolver, aliaser, error_as_null),
                    parameters=field_parameters(resolver),
                    schema=get_schema(operation),
                ))
    for operation in subscription:  # type: ignore
        resolve: Callable
        if isinstance(operation, tuple):
            # (subscribe, event_handler) pair: the handler resolves each
            # event; its first parameter receives the event itself.
            operation, event_handler = operation
            name, schema = event_handler.__name__, get_schema(event_handler)
            try:
                resolver = operation_resolver(event_handler, skip_first=True)
            except MissingFirstParameter:
                raise TypeError(
                    "Subscription resolver must have at least one parameter"
                ) from None
            return_type = resolver.return_type
            subscribe = resolver_resolve(
                operation_resolver(operation),
                aliaser,
                error_as_null,
                serialized=False,
            )
            resolve = resolver_resolve(resolver, aliaser, error_as_null)
        else:
            # Bare subscription: must return an async iterable; the field
            # type is the iterable's element type and events resolve to
            # themselves.
            name, schema = operation.__name__, get_schema(operation)
            resolver = operation_resolver(operation)
            if get_origin(resolver.return_type) not in async_iterable_origins:
                raise TypeError(
                    "Subscriptions must return an AsyncIterable/AsyncIterator")
            return_type = get_args(resolver.return_type)[0]
            subscribe = resolver_resolve(resolver, aliaser, error_as_null,
                                         serialized=False)

            def resolve(_, *args, **kwargs):
                return _

        subscription_fields.append(
            ObjectField(
                name,
                wrap_return_type(return_type, error_as_null),
                parameters=field_parameters(resolver),
                resolve=resolve,
                subscribe=subscribe,
                schema=schema,
            ))
    # id_types may be a collection of types or a predicate.
    is_id = id_types.__contains__ if isinstance(id_types, Collection) else id_types
    builder = OutputSchemaBuilder(aliaser, is_id, error_as_null,
                                  generic_ref_factory, union_ref_factory)

    def root_type(
        name: str, fields: Collection[ObjectField]
    ) -> Optional[graphql.GraphQLObjectType]:
        # GraphQL forbids empty object types; omit the root entirely.
        if not fields:
            return None
        return exec_thunk(builder.object(type(name, (), {}), fields),
                          non_null=False)

    return graphql.GraphQLSchema(
        root_type("Query", query_fields),
        root_type("Mutation", mutation_fields),
        root_type("Subscription", subscription_fields),
        [exec_thunk(builder.visit(cls), non_null=False) for cls in types],
        directives,
        description,
        extensions,
    )
def is_skipped(cls: AnyType, *, schema_only) -> bool:
    """Tell whether ``cls`` must be skipped: the Undefined sentinel type, or
    an Annotated type carrying Skip (or SkipSchema when ``schema_only``)."""
    if cls is UndefinedType:
        return True
    if get_origin(cls) is not Annotated:
        return False
    annotations = get_args(cls)[1:]
    if Skip in annotations:
        return True
    return schema_only and SkipSchema in annotations
def _annotated(tp: AnyType) -> AnyType:
    """Strip the ``Annotated`` wrapper from ``tp``, if any."""
    if is_annotated(tp):
        return get_args(tp)[0]
    return tp
def type_var_context(cls: AnyType, type_vars: TypeVarContext = None) -> TypeVarContext:
    """Build the type-var substitution of a generic alias: map each origin
    parameter to the corresponding argument, after resolving ``cls`` in the
    enclosing ``type_vars`` context."""
    resolved = resolve_type_vars(cls, type_vars)
    origin = get_origin(resolved)
    assert origin is not None
    return dict(zip(get_parameters(origin), get_args(resolved)))
def get_args2(tp: AnyType) -> Tuple[AnyType, ...]:
    """Like ``get_args`` but looking through an ``Annotated`` wrapper."""
    unwrapped = _annotated(tp)
    return get_args(unwrapped)
    # NOTE(review): these methods belong to a compatibility ``Wrapper`` class
    # whose ``class`` header lies outside this chunk. Each instance proxies a
    # ``typing`` alias (Dict, List, ...) so it can be subscripted like the
    # alias but called/isinstance-checked like the concrete builtin.
    def __init__(self, cls):
        self.cls = cls
        # __origin__ holds the concrete type on 3.7+; on 3.6 it is None and
        # the concrete type lives in __extra__.
        self.implem = cls.__origin__ or cls.__extra__  # extra in 3.6

    def __getitem__(self, item):
        # Subscription delegates to the typing alias (e.g. Dict[str, int]).
        return self.cls[item]

    def __call__(self, *args, **kwargs):
        # Instantiation delegates to the concrete builtin type.
        return self.implem(*args, **kwargs)

    def __instancecheck__(self, instance):
        return isinstance(instance, self.implem)

    def __subclasscheck__(self, subclass):
        return issubclass(subclass, self.implem)


# Replace the typing aliases in this module's namespace with wrappers.
for cls in (Dict, List, Set, FrozenSet, Tuple, Type):  # noqa
    wrapper = Wrapper(cls)
    globals()[wrapper.implem.__name__] = wrapper
# ``set``'s wrapper was registered under the name "set"; rebind Set to the
# AbstractSet alias instead.
Set = AbstractSet
del Wrapper
if sys.version_info < (3, 7):
    # Backport asyncio.run and make inspect.isclass reject generic aliases
    # (which are instances of type on 3.6).
    asyncio.run = lambda coro: asyncio.get_event_loop().run_until_complete(coro
                                                                           )
    inspect.isclass = lambda tp: isinstance(tp, type) and not get_args(tp)
def return_type(self) -> AnyType:
    """Return the declared return type, unwrapping ``Awaitable[X]`` to ``X``."""
    ret = self.types["return"]
    if get_origin(ret) == awaitable_origin:
        return get_args(ret)[0]
    return ret
def keep_annotations(tp: AnyType, annotated: AnyType) -> AnyType:
    """Transfer the ``Annotated`` metadata of ``annotated`` (if any) onto
    ``tp``; return ``tp`` unchanged otherwise."""
    if not is_annotated(annotated):
        return tp
    metadata = get_args(annotated)[1:]
    return Annotated[(tp, *metadata)]