def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema:
    """Extend the base object schema with properties for serialized methods.

    Each entry of ``get_serialized_methods(tp)`` becomes an extra property,
    keyed by its aliased name. It is marked required unless its return type
    is a union containing ``UndefinedType``. The merged properties are then
    reordered with ``sort_by_annotations_position``.

    When the base schema contains ``"allOf"``, its first sub-schema is the one
    updated; otherwise the base schema itself is updated.
    """
    result = super().object(tp, fields)
    # Keyed by the *aliased* name: the lookup below receives the keys of the
    # schema "properties", and those keys are aliased.
    name_by_aliases = {self.aliaser(f.alias): f.name for f in fields}
    properties = {}
    required = []
    for alias, (serialized, types) in get_serialized_methods(tp).items():
        aliased = self.aliaser(alias)
        return_type = types["return"]
        properties[aliased] = full_schema(
            self.visit_with_conv(return_type, serialized.conversion),
            serialized.schema,
        )
        if not is_union_of(return_type, UndefinedType):
            # Use the same (aliased) key as in "properties".
            required.append(aliased)
        # Methods are ordered by the position of their function name.
        name_by_aliases[aliased] = serialized.func.__name__
    to_update = result["allOf"][0] if "allOf" in result else result
    if required:
        required.extend(to_update.get("required", ()))
        to_update["required"] = sorted(required)
    if properties:
        properties.update(to_update.get("properties", {}))
        props = sort_by_annotations_position(
            get_origin_or_type(tp), properties, lambda p: name_by_aliases[p]
        )
        to_update["properties"] = {p: properties[p] for p in props}
    return result
def get_node_name(tp):
    """Return the GraphQL name registered for a connection's node type.

    If ``tp`` is a union containing ``NoneType`` (e.g. ``Optional[Node]``) with
    at least one argument, it is first replaced by its first member that is
    not ``NoneType``.

    Raises:
        TypeError: if no GraphQL name is registered for the (unwrapped) type.
    """
    node_type = tp
    if is_union_of(node_type, NoneType):
        union_args = get_args2(node_type)
        if len(union_args) > 0:
            node_type = next(a for a in union_args if a is not NoneType)
    graphql_name = get_type_name(node_type).graphql
    if graphql_name is not None:
        return graphql_name
    raise TypeError(
        f"Node {node_type} must have a ref registered to be used with connection")
def skip_if(self) -> Optional[Callable[[Any], Any]]:
    """Combine every serialization skip rule of this field into one predicate.

    The result starts from ``self.skip.serialization_if``. Extra rules are
    merged into it with ``merge_skip_if``, in this order:

    - equality with the default value, when a default factory exists and
      ``skip.serialization_default`` is set;
    - identity with ``Undefined``, when the field type is a union with
      ``UndefinedType``;
    - identity with ``None``, when ``none_as_undefined`` is set.
    """
    predicate = self.skip.serialization_if
    extra_rules: List[Callable[[Any], bool]] = []
    if self.default_factory is not None and self.skip.serialization_default:
        default_value = self.default_factory()  # type: ignore
        extra_rules.append(lambda obj: obj == default_value)
    if is_union_of(self.type, UndefinedType):
        extra_rules.append(lambda obj: obj is Undefined)
    if self.none_as_undefined:
        extra_rules.append(lambda obj: obj is None)
    for rule in extra_rules:
        predicate = merge_skip_if(predicate, rule)
    return predicate
def __post_init__(self, default: Any):
    """Normalize ``required``, ``default_factory`` and ``type`` after init.

    ``default`` is the raw default given at construction. It is presumably
    an ``InitVar`` of the dataclass; confirm against the class definition.

    - ``REQUIRED_METADATA`` in ``full_metadata`` forces ``required`` to True.
    - A ``MISSING`` default factory is normalized to None.
    - A non-required field without a factory gets ``LazyValue(default)``.
    - With ``none_as_undefined``, ``NoneType`` is removed from a union type,
      and annotations are kept through ``keep_annotations``.

    Raises:
        ValueError: if the field is not required, has no default factory,
            and ``default`` is ``MISSING_DEFAULT``.
    """

    # The instance may be frozen, so bypass its __setattr__.
    def assign(attr_name: str, value: Any) -> None:
        object.__setattr__(self, attr_name, value)

    if REQUIRED_METADATA in self.full_metadata:
        assign("required", True)
    if self.default_factory is MISSING:
        assign("default_factory", None)
    if not self.required and self.default_factory is None:
        if default is MISSING_DEFAULT:
            raise ValueError("Missing default for non-required ObjectField")
        assign("default_factory", LazyValue(default))
    if not self.none_as_undefined:
        return
    if is_union_of(self.type, NoneType):
        remaining = tuple(arg for arg in get_args2(self.type) if arg != NoneType)
        stripped_type = Union[remaining]  # type: ignore
        assign("type", keep_annotations(stripped_type, self.type))
def _field_required(field: ObjectField):
    """Return whether ``field`` is required.

    A field must be declared required, and its type must not be a union
    containing ``UndefinedType``.
    """
    if not field.required:
        return field.required
    return not is_union_of(field.type, UndefinedType)
def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
    """Build a lazily constructed GraphQL field for a resolver.

    The resolve function comes from ``resolver_resolve``, wrapped by
    ``self._wrap_resolve``. When ``field.parameters`` is not None, each
    parameter becomes a GraphQL argument, except parameters typed as
    ``graphql.GraphQLResolveInfo``. Argument types are built through thunks,
    and the returned lambda constructs the ``graphql.GraphQLField`` only
    when called.
    """
    resolve = self._wrap_resolve(
        resolver_resolve(
            field.resolver,
            field.types,
            self.aliaser,
            self.input_builder.default_conversion,
            self.default_conversion,
        )
    )
    args = None
    if field.parameters is not None:
        args = {}
        for param in field.parameters:
            default: Any = graphql.Undefined
            param_type = field.types[param.name]
            if is_union_of(param_type, graphql.GraphQLResolveInfo):
                # The info parameter is not a GraphQL argument. Skip it, but
                # keep processing the parameters declared after it
                # (previously `break` silently dropped them).
                continue
            param_field = ObjectField(
                param.name,
                param_type,
                param.default is Parameter.empty,
                field.metadata.get(param.name, empty_dict),
                default=param.default,
            )
            if param_field.required:
                pass
            # Don't put `null` default + handle Undefined as None
            # also https://github.com/python/typing/issues/775
            # Identity checks: `in {...}` would hash the default and raise
            # TypeError on unhashable defaults (e.g. a list).
            elif param.default is None or param.default is Undefined:
                param_type = Optional[param_type]
            # param.default == graphql.Undefined means the parameter is required
            # even if it has a default
            elif (
                param.default is not Parameter.empty
                and param.default is not graphql.Undefined
            ):
                try:
                    default = serialize(
                        param_type,
                        param.default,
                        fall_back_on_any=False,
                        check_type=True,
                    )
                except Exception:
                    # Best effort: a default that cannot be serialized is not
                    # exposed, and the argument is made nullable instead.
                    param_type = Optional[param_type]
            arg_factory = self.input_builder.visit_with_conv(
                param_type, param_field.deserialization
            )
            description = get_description(param_field.schema, param_field.type)

            # Loop values are bound as defaults to avoid late-binding closures.
            def arg_thunk(
                arg_factory=arg_factory, default=default, description=description
            ) -> graphql.GraphQLArgument:
                return graphql.GraphQLArgument(arg_factory.type, default, description)

            args[self.aliaser(param_field.alias)] = arg_thunk
    factory = self.visit_with_conv(field.type, field.resolver.conversion)
    return lambda: graphql.GraphQLField(
        factory.type,  # type: ignore
        {name: arg() for name, arg in args.items()} if args else None,
        resolve,
        field.subscribe,
        field.description,
        field.deprecated,
    )
def __init_subclass__(cls, **kwargs):
    """Build the mutation for each subclass that has a ``mutate`` attribute.

    Steps performed below:

    - require ``mutate`` to be a classmethod/staticmethod in the class body;
    - register ``f"{cls.__name__}Payload"`` as the type name of ``cls``;
    - build a ``f"{cls.__name__}Input"`` dataclass from ``mutate``'s
      positional-or-keyword and keyword-only parameters;
    - handle the client mutation id according to ``cls._client_mutation_id``:
      ``False`` excludes it, ``True`` makes it required, and any other value
      makes an auto-added id field ``Optional`` with default None;
    - wrap ``mutate`` as a one-argument ``wrapper(input)``. When the id is
      enabled, the wrapper also copies it onto the returned payload;
    - store the resulting ``Mutation_`` in ``cls._mutation``.

    Raises:
        TypeError: if ``mutate`` is not a classmethod/staticmethod, has
            positional-only or untyped parameters, or declares a
            ``ClientMutationId`` parameter without default while
            ``_client_mutation_id`` is False.
    """
    super().__init_subclass__(**kwargs)
    # Nothing to build for classes that neither define nor inherit `mutate`.
    if not hasattr(cls, "mutate"):
        return
    # NOTE(review): `cls.__dict__["mutate"]` raises KeyError when `mutate` is
    # only inherited (hasattr above is True but the name is not in this class's
    # own __dict__). Confirm whether such subclasses should be supported.
    if not isinstance(cls.__dict__["mutate"], (classmethod, staticmethod)):
        raise TypeError(
            f"{cls.__name__}.mutate must be a classmethod/staticmethod")
    mutate = getattr(cls, "mutate")
    type_name(f"{cls.__name__}Payload")(cls)
    # localns lets `mutate`'s annotations refer to the class being defined.
    types = get_type_hints(mutate, localns={cls.__name__: cls}, include_extras=True)
    async_mutate = is_async(mutate, types)
    # (name, type, default value or dataclasses.Field) triples for make_dataclass.
    fields: List[Tuple[str, AnyType, Field]] = []
    # Name of the input attribute holding the client mutation id, if any.
    cmi_param = None
    for param_name, param in signature(mutate).parameters.items():
        if param.kind is Parameter.POSITIONAL_ONLY:
            raise TypeError("Positional only parameters are not supported")
        # *args / **kwargs parameters are not matched here and are ignored.
        if param.kind in {
            Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY
        }:
            if param_name not in types:
                raise TypeError("Mutation parameters must be typed")
            field_type = types[param_name]
            field_ = MISSING if param.default is Parameter.empty else param.default
            if is_union_of(field_type, ClientMutationId):
                cmi_param = param_name
                if cls._client_mutation_id is False:
                    # Disabled: the parameter must have a default. It is left
                    # out of the input, so `mutate` uses that default.
                    if field_ is MISSING:
                        raise TypeError(
                            "Cannot have a ClientMutationId parameter"
                            " when _client_mutation_id = False")
                    continue
                elif cls._client_mutation_id is True:
                    # Forced required: any declared default is dropped.
                    field_ = MISSING
                # Expose the parameter under the CLIENT_MUTATION_ID alias.
                field_ = field(default=field_, metadata=alias(CLIENT_MUTATION_ID))
            fields.append((param_name, field_type, field_))
    # Names forwarded to `mutate`. This is captured before an id field may
    # be auto-appended below, so that field is never passed to `mutate`.
    field_names = [name for (name, _, _) in fields]
    if cmi_param is None and cls._client_mutation_id is not False:
        # NOTE(review): when _client_mutation_id is True, this appends a field
        # without default after the parameter fields. If any of them has a
        # default, make_dataclass raises TypeError ("non-default argument ...
        # follows default argument"). Confirm whether that case is supported.
        fields.append((
            CLIENT_MUTATION_ID,
            ClientMutationId if cls._client_mutation_id else Optional[ClientMutationId],
            MISSING if cls._client_mutation_id else None,
        ))
        cmi_param = CLIENT_MUTATION_ID
    input_cls = make_dataclass(f"{cls.__name__}Input", fields)

    # Adapts `mutate(**params)` to a single `input` argument.
    def wrapper(input):
        return mutate(
            **{name: getattr(input, name) for name in field_names})

    # Annotate input/return types; presumably read by Mutation_ / schema
    # generation (TODO confirm).
    wrapper.__annotations__["input"] = input_cls
    wrapper.__annotations__[
        "return"] = Awaitable[cls] if async_mutate else cls
    if cls._client_mutation_id is not False:
        # Give the payload class an id attribute typed like the input's id.
        # It is declared with field(init=False), so it is not an __init__ argument.
        # NOTE(review): before Python 3.10, a class without its own
        # annotations resolves `cls.__annotations__` to a base class's dict,
        # which would be mutated here. Confirm payload classes always
        # declare annotations.
        cls.__annotations__[
            CLIENT_MUTATION_ID] = input_cls.__annotations__[cmi_param]
        setattr(cls, CLIENT_MUTATION_ID, field(init=False))
        wrapped = wrapper
        # Re-wrap so the id received in the input is copied onto the payload.
        if async_mutate:
            async def wrapper(input):
                result = await wrapped(input)
                setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
                return result
        else:
            def wrapper(input):
                result = wrapped(input)
                setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
                return result
        # functools.wraps copies metadata, including __annotations__, from
        # the inner wrapper.
        wrapper = wraps(wrapped)(wrapper)
    cls._mutation = Mutation_(
        function=wrapper,
        alias=camel_to_snake(cls.__name__),
        schema=cls._schema,
        error_handler=cls._error_handler,
    )
def resolver_resolve(
    resolver: Resolver,
    types: Mapping[str, AnyType],
    aliaser: Aliaser,
    default_deserialization: DefaultConversion,
    default_serialization: DefaultConversion,
    serialized: bool = True,
) -> Callable:
    """Build the GraphQL ``resolve`` function for ``resolver``.

    The returned ``resolve(__self, __info, **kwargs)`` does the following:

    - deserializes each GraphQL argument, looked up by its aliased name;
    - raises ``ValueError`` holding the argument ``ValidationError``s, keyed
      by the argument alias;
    - passes ``__info`` to the parameter typed ``graphql.GraphQLResolveInfo``,
      if any;
    - calls ``resolver.func`` and serializes its result, unless
      ``serialized`` is False;
    - on exception, delegates to ``resolver.error_handler`` when one is set
      and serializes its result.

    For async resolvers, exceptions raised while the result is awaited also
    go through the error handler.

    Args:
        resolver: resolver to wrap.
        types: type hints of ``resolver.func``, including ``"return"``.
        aliaser: applied to parameter aliases; also passed to serialization.
        default_deserialization: default argument conversion (Enum types
            bypass it, see ``handle_enum``).
        default_serialization: default result conversion.
        serialized: whether to serialize the resolver's result.
    """

    # graphql deserialization will give Enum objects instead of strings
    def handle_enum(tp: AnyType) -> Optional[AnyConversion]:
        if is_type(tp) and issubclass(tp, Enum):
            return Conversion(identity, source=Any, target=tp)
        return default_deserialization(tp)

    parameters, info_parameter = [], None
    for param in resolver.parameters:
        param_type = types[param.name]
        if is_union_of(param_type, graphql.GraphQLResolveInfo):
            info_parameter = param.name
        else:
            param_field = ObjectField(
                param.name,
                param_type,
                param.default is Parameter.empty,
                resolver.parameters_metadata.get(param.name, empty_dict),
                param.default,
            )
            deserializer = deserialization_method(
                param_type,
                additional_properties=False,
                aliaser=aliaser,
                coerce=False,
                conversion=param_field.deserialization,
                default_conversion=handle_enum,
                fall_back_on_default=False,
                schema=param_field.schema,
            )
            opt_param = is_union_of(param_type, NoneType) or param.default is None
            parameters.append(
                (
                    aliaser(param_field.alias),
                    param.name,
                    deserializer,
                    opt_param,
                    param_field.required,
                )
            )

    func, error_handler = resolver.func, resolver.error_handler
    func_is_async = is_async(resolver.func)
    method_factory = partial_serialization_method_factory(
        aliaser, resolver.conversion, default_serialization
    )
    serialize_result: Callable[[Any], Any]
    if not serialized:
        serialize_result = lambda res: res
    elif func_is_async:
        serialize_result = as_async(method_factory(types["return"]))
    else:
        serialize_result = method_factory(types["return"])
    serialize_error: Optional[Callable[[Any], Any]]
    error_handler_is_async = False
    if error_handler is None:
        serialize_error = None
    elif is_async(error_handler):
        error_handler_is_async = True
        serialize_error = as_async(method_factory(resolver.error_type()))
    else:
        serialize_error = method_factory(resolver.error_type())

    def handle_error(error: Exception, obj: Any, info: Any, kwargs: Dict[str, Any]) -> Any:
        # Run the user error handler and serialize its result. This returns
        # an awaitable when the error handler is async.
        assert error_handler is not None and serialize_error is not None
        return serialize_error(error_handler(error, obj, info, **kwargs))

    async def guard_async(awaitable: Any, obj: Any, info: Any, kwargs: Dict[str, Any]) -> Any:
        # Exceptions of an async resolver only surface when it is awaited,
        # so they must be caught here rather than in `resolve`.
        try:
            return await awaitable
        except Exception as error:
            handled = handle_error(error, obj, info, kwargs)
            return (await handled) if error_handler_is_async else handled

    def resolve(__self, __info, **kwargs):
        values = {}
        errors: Dict[str, ValidationError] = {}
        for alias, param_name, deserializer, opt_param, required in parameters:
            if alias in kwargs:
                # It is possible for the parameter to be non-optional in Python
                # type hints but optional in the generated schema. In this case
                # we should ignore it.
                # See: https://github.com/wyfo/apischema/pull/130#issuecomment-845497392
                if not opt_param and kwargs[alias] is None:
                    assert not required
                    continue
                try:
                    values[param_name] = deserializer(kwargs[alias])
                except ValidationError as err:
                    # Report the error under the argument name the client used.
                    errors[alias] = err
            elif opt_param and required:
                # Nullable parameter without default, omitted by the client:
                # pass None explicitly.
                values[param_name] = None
        if errors:
            raise ValueError(ValidationError(children=errors).errors)
        if info_parameter:
            values[info_parameter] = __info
        try:
            result = serialize_result(func(__self, **values))
        except Exception as error:
            if error_handler is None:
                raise
            return handle_error(error, __self, __info, kwargs)
        if func_is_async and error_handler is not None:
            return guard_async(result, __self, __info, kwargs)
        return result

    return resolve