def gather_errors(func: Callable[..., ValidatorResult[T]]) -> Callable[..., T]:
    """Turn a validator generator into a function that raises all errors at once.

    ``func`` must be a generator function. Every value it yields is collected as
    an error, and the value it returns becomes the wrapper's return value. If at
    least one error was yielded, ``build_validation_error(errors)`` is raised
    instead of returning.

    The wrapper's ``"return"`` annotation is rewritten to the inner result type:
    for a string annotation ``"ValidatorResult[X]"`` it becomes ``"X"``; for a
    generator type it becomes the generator's return type (third argument), with
    any ``Annotated`` extras re-applied.

    Raises:
        TypeError: if ``func`` is not a generator function.
    """
    if not isgeneratorfunction(func):
        raise TypeError("func must be a generator returning a ValidatorResult")

    @wraps(func)
    def wrapper(*args, **kwargs):
        result, errors = func(*args, **kwargs), []
        while True:
            try:
                errors.append(next(result))
            except StopIteration as stop:
                if errors:
                    raise build_validation_error(errors)
                return stop.value

    if "return" in func.__annotations__:
        ret = func.__annotations__["return"]
        if isinstance(ret, str):
            # Postponed (string) annotation: extract X from "ValidatorResult[X]".
            match = re.match(r"ValidatorResult\[(?P<ret>.*)\]", ret)
            if match is not None:
                # group() yields the captured text; groupdict() would return a dict.
                ret = match.group("ret")
        else:
            annotations = get_args(ret)[1:] if is_annotated(ret) else ()
            if get_origin2(ret) == GeneratorOrigin:
                ret = get_args2(ret)[2]
                if annotations:
                    ret = Annotated[(ret, *annotations)]
        # functools.wraps shares func's __annotations__ dict with the wrapper;
        # assign a copy so the wrapped generator keeps its original annotation.
        wrapper.__annotations__ = {**func.__annotations__, "return": ret}
    return wrapper
def __init_subclass__(cls, **kwargs):
    """Configure a ``TaggedUnion`` subclass at class-definition time.

    Tags already recorded under ``TAGS_ATTR`` (e.g. inherited) are kept. For each
    resolved type hint of the class:

    * a ``Tagged[...]`` hint becomes a dataclass field defaulting to
      ``Undefined``, using the metadata of the ``Tagged`` value assigned in the
      class body (or of a fresh ``Tagged()``), annotated as
      ``Union[<inner type>, UndefinedType]``, and recorded as a tag;
    * a hint whose name is already a known tag is left alone;
    * any other non-``ClassVar`` hint is re-annotated as ``ClassVar[hint]``;
    * a remaining hint whose origin is ``ClassVar`` raises ``TypeError``.

    The class is then processed by ``dataclass(init=False, repr=False)`` and
    ``schema(min_props=1, max_props=1)``. Finally, every tag attribute is set to
    ``Tag(tag, cls)``. ``kwargs`` are accepted but not used.
    """
    known_tags = set(getattr(cls, TAGS_ATTR, ()))
    for name, hint in get_type_hints(cls, include_extras=True).items():
        origin = get_origin2(hint)
        if origin == Tagged:
            declared = cls.__dict__.get(name, Tagged())
            setattr(cls, name, field(default=Undefined, metadata=declared.metadata))
            cls.__annotations__[name] = Union[get_args2(hint)[0], UndefinedType]
            known_tags.add(name)
        elif name in known_tags:
            continue
        elif origin != ClassVar:
            cls.__annotations__[name] = ClassVar[hint]
        else:
            # NOTE(review): the message says ClassVar fields are allowed, yet it
            # is exactly the explicit ClassVar hints that reach this raise —
            # confirm whether this check is intentionally inverted.
            raise TypeError(
                "Only Tagged or ClassVar fields are allowed in TaggedUnion"
            )
    setattr(cls, TAGS_ATTR, known_tags)
    # Fields must be in place before the dataclass/schema processing runs.
    as_dataclass = dataclass(init=False, repr=False)(cls)
    schema(min_props=1, max_props=1)(as_dataclass)
    # Only after dataclass processing are the field defaults replaced by Tags.
    for name in known_tags:
        setattr(cls, name, Tag(name, cls))
def graphql_schema(
    *,
    query: Iterable[Union[Callable, Query]] = (),
    mutation: Iterable[Union[Callable, Mutation]] = (),
    subscription: Iterable[Union[Callable[..., AsyncIterable], Subscription]] = (),
    types: Iterable[Type] = (),
    directives: Optional[Collection[graphql.GraphQLDirective]] = None,
    description: Optional[str] = None,
    extensions: Optional[Dict[str, Any]] = None,
    aliaser: Optional[Aliaser] = to_camel_case,
    enum_aliaser: Optional[Aliaser] = str.upper,
    enum_schemas: Optional[Mapping[Enum, Schema]] = None,
    id_types: Union[Collection[AnyType], IdPredicate] = (),
    id_encoding: Tuple[
        Optional[Callable[[str], Any]], Optional[Callable[[Any], str]]
    ] = (None, None),
    # TODO deprecate union_ref parameter
    union_ref: UnionNameFactory = "Or".join,
    union_name: UnionNameFactory = "Or".join,
    default_deserialization: DefaultConversion = None,
    default_serialization: DefaultConversion = None,
) -> graphql.GraphQLSchema:
    """Build a ``graphql.GraphQLSchema`` from Python operations and types.

    ``query`` and ``mutation`` items (callables or ``Query``/``Mutation``
    objects) become fields of the ``Query``/``Mutation`` root types.
    ``subscription`` items (callables or ``Subscription`` objects) become fields
    of the ``Subscription`` root type. A root type with no fields is omitted.
    ``types`` are additional types added to the schema.

    ``None`` for ``aliaser``, ``default_deserialization`` or
    ``default_serialization`` falls back to the global ``settings``; ``None`` for
    ``enum_aliaser`` means enum names are left unchanged. ``id_types`` is either
    a collection of types or a predicate deciding which types map to the ID
    scalar. ``id_encoding`` is an optional ``(deserializer, serializer)`` pair;
    when either is set, a custom ``ID`` scalar uses it in place of the default
    ``GraphQLID`` behavior. ``union_ref`` is only used when ``union_name`` is
    falsy.

    Raises:
        TypeError: if a subscription without a separate resolver does not
            return an AsyncIterable/AsyncIterator.
    """
    # Fall back to global settings for options explicitly left to None.
    if aliaser is None:
        aliaser = settings.aliaser
    if enum_aliaser is None:
        enum_aliaser = lambda s: s
    if default_deserialization is None:
        default_deserialization = settings.deserialization.default_conversion
    if default_serialization is None:
        default_serialization = settings.serialization.default_conversion

    def operation_field(operation, op_class) -> ResolverField:
        """Wrap one query or mutation operation into a ResolverField."""
        alias, resolver = operation_resolver(operation, op_class)
        return ResolverField(
            alias,
            resolver,
            resolver.types(),
            resolver.parameters,
            resolver.parameters_metadata,
        )

    def subscription_field(sub_op) -> ResolverField:
        """Wrap one subscription (callable or Subscription) into a ResolverField."""
        if not isinstance(sub_op, Subscription):
            sub_op = Subscription(sub_op)  # type: ignore
        sub_parameters: Sequence[Parameter]
        if sub_op.resolver is None:
            # No dedicated resolver: an identity resolver is used, and the
            # field type is derived from the subscriber's AsyncIterable item type.
            alias, subscriber2 = operation_resolver(sub_op, Subscription)
            resolver = Resolver(
                lambda _: _,
                sub_op.conversion,
                sub_op.schema,
                subscriber2.error_handler,
                (),
                {},
            )
            subscriber = replace(subscriber2, error_handler=None)
            sub_parameters = subscriber.parameters
            sub_types = subscriber.types()
            if get_origin2(sub_types["return"]) not in async_iterable_origins:
                raise TypeError(
                    "Subscriptions must return an AsyncIterable/AsyncIterator"
                )
            event_type = get_args2(sub_types["return"])[0]
            subscribe = resolver_resolve(
                subscriber,
                sub_types,
                aliaser,
                default_deserialization,
                default_serialization,
                serialized=False,
            )
            sub_types = {**sub_types, "return": resolver.return_type(event_type)}
        else:
            alias = sub_op.alias or sub_op.resolver.__name__
            _, subscriber2 = operation_resolver(sub_op, Subscription)
            # The resolver's first parameter is dropped from the exposed
            # parameters (presumably it receives the event — confirm).
            _, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False)
            resolver = Resolver(
                sub_op.resolver,
                sub_op.conversion,
                sub_op.schema,
                subscriber2.error_handler,
                sub_parameters,
                sub_op.parameters_metadata,
            )
            sub_types = resolver.types()
            # The error handler stays on the resolver; the subscriber gets none.
            subscriber = replace(subscriber2, error_handler=None)
            subscribe = resolver_resolve(
                subscriber,
                subscriber.types(),
                aliaser,
                default_deserialization,
                default_serialization,
                serialized=False,
            )
        return ResolverField(
            alias,
            resolver,
            sub_types,
            sub_parameters,
            sub_op.parameters_metadata,
            subscribe,
        )

    # Built in the same order as before: queries, mutations, then subscriptions.
    query_fields: List[ResolverField] = [
        operation_field(op, Query) for op in query  # type: ignore
    ]
    mutation_fields: List[ResolverField] = [
        operation_field(op, Mutation) for op in mutation  # type: ignore
    ]
    subscription_fields: List[ResolverField] = [
        subscription_field(op) for op in subscription  # type: ignore
    ]

    is_id = id_types.__contains__ if isinstance(id_types, Collection) else id_types
    id_type: graphql.GraphQLScalarType
    if id_encoding == (None, None):
        id_type = graphql.GraphQLID
    else:
        # Missing halves of the encoding fall back to GraphQLID's own functions.
        id_deserializer, id_serializer = id_encoding
        id_type = graphql.GraphQLScalarType(
            name="ID",
            serialize=id_serializer or graphql.GraphQLID.serialize,
            parse_value=id_deserializer or graphql.GraphQLID.parse_value,
            parse_literal=graphql.GraphQLID.parse_literal,
            description=graphql.GraphQLID.description,
        )

    output_builder = OutputSchemaBuilder(
        aliaser,
        enum_aliaser,
        enum_schemas or {},
        default_serialization,
        id_type,
        is_id,
        union_name or union_ref,
        default_deserialization,
    )

    def root_type(
        name: str, fields: Sequence[ResolverField]
    ) -> Optional[graphql.GraphQLObjectType]:
        """Build the named root object type, or None when it has no fields."""
        if not fields:
            return None
        # A throwaway class stands in as the Python type for the root object.
        placeholder = type(name, (), {})
        type_name = TypeName(graphql=name)
        built = output_builder.object(placeholder, (), fields)
        return built.merge(type_name, None).raw_type  # type: ignore

    return graphql.GraphQLSchema(
        query=root_type("Query", query_fields),
        mutation=root_type("Mutation", mutation_fields),
        subscription=root_type("Subscription", subscription_fields),
        types=[output_builder.visit(cls).raw_type for cls in types],  # type: ignore
        directives=directives,
        description=description,
        extensions=extensions,
    )