Exemplo n.º 1
0
 def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema:
     """Build the JSON schema of ``tp``, adding its serialized methods.

     The schema produced by the parent builder is extended with one property
     per serialized method of ``tp``. A method whose return type does not
     include ``UndefinedType`` is also listed in ``required``. Property keys
     and ``required`` entries both go through ``self.aliaser`` so that they
     designate the same names.

     When the parent schema is wrapped in ``allOf``, its first element is the
     one updated. Returns the (possibly mutated) parent schema.
     """
     result = super().object(tp, fields)
     # Keyed by *aliased* name, i.e. the same keys as "properties", so the
     # sort callback below can map every property back to its attribute name.
     name_by_aliases = {self.aliaser(f.alias): f.name for f in fields}
     properties = {}
     required = []
     for alias, (serialized, types) in get_serialized_methods(tp).items():
         aliased = self.aliaser(alias)
         return_type = types["return"]
         properties[aliased] = full_schema(
             self.visit_with_conv(return_type, serialized.conversion),
             serialized.schema,
         )
         # A method that may return Undefined can be absent from the output.
         if not is_union_of(return_type, UndefinedType):
             required.append(aliased)
         name_by_aliases[aliased] = serialized.func.__name__
     if "allOf" not in result:
         to_update = result
     else:
         to_update = result["allOf"][0]
     if required:
         required.extend(to_update.get("required", ()))
         to_update["required"] = sorted(required)
     if properties:
         # Existing (field) properties take precedence on key collision.
         properties.update(to_update.get("properties", {}))
         props = sort_by_annotations_position(get_origin_or_type(tp),
                                              properties,
                                              lambda p: name_by_aliases[p])
         to_update["properties"] = {p: properties[p] for p in props}
     return result
Exemplo n.º 2
0
def get_node_name(tp):
    """Return the GraphQL name registered for a connection node type.

    When ``tp`` is a union containing ``NoneType`` (e.g. ``Optional[Node]``),
    its first non-``None`` member is used as the node type.

    Raises:
        TypeError: if the node type has no GraphQL name registered.
    """
    if is_union_of(tp, NoneType):
        union_members = get_args2(tp)
        if union_members:
            tp = next(member for member in union_members
                      if member is not NoneType)
    graphql_name = get_type_name(tp).graphql
    if graphql_name is not None:
        return graphql_name
    raise TypeError(
        f"Node {tp} must have a ref registered to be used with connection")
Exemplo n.º 3
0
 def skip_if(self) -> Optional[Callable[[Any], Any]]:
     """Combine the serialization skip conditions of this field.

     Starts from ``self.skip.serialization_if`` and merges, in this order:
     - equality with the default value (when ``serialization_default`` is set
       and the field has a ``default_factory``);
     - identity with ``Undefined`` (when the type admits ``UndefinedType``);
     - identity with ``None`` (when ``none_as_undefined`` is set).
     """
     extra_checks: List[Callable[[Any], Any]] = []
     if self.default_factory is not None and self.skip.serialization_default:
         default_value = self.default_factory()  # type: ignore
         extra_checks.append(lambda value: value == default_value)
     if is_union_of(self.type, UndefinedType):
         extra_checks.append(lambda value: value is Undefined)
     if self.none_as_undefined:
         extra_checks.append(lambda value: value is None)
     combined = self.skip.serialization_if
     for check in extra_checks:
         combined = merge_skip_if(combined, check)
     return combined
Exemplo n.º 4
0
 def __post_init__(self, default: Any):
     """Normalize the field once the generated ``__init__`` has run.

     In order:
     - force ``required`` to ``True`` when ``REQUIRED_METADATA`` is present
       in ``full_metadata``;
     - replace a ``MISSING`` ``default_factory`` with ``None``;
     - give a non-required field without factory a ``LazyValue(default)``
       factory;
     - with ``none_as_undefined``, remove ``NoneType`` from a union ``type``
       (annotations are carried over by ``keep_annotations``).

     Raises:
         ValueError: if a non-required field has neither factory nor default.
     """

     def assign(attribute: str, value: Any) -> None:
         # Written through object.__setattr__, presumably because the
         # dataclass is frozen.
         object.__setattr__(self, attribute, value)

     if REQUIRED_METADATA in self.full_metadata:
         assign("required", True)
     if self.default_factory is MISSING:
         assign("default_factory", None)
     needs_default = not self.required and self.default_factory is None
     if needs_default:
         if default is MISSING_DEFAULT:
             raise ValueError(
                 "Missing default for non-required ObjectField")
         assign("default_factory", LazyValue(default))
     if not self.none_as_undefined or not is_union_of(self.type, NoneType):
         return
     remaining = tuple(arg for arg in get_args2(self.type) if arg != NoneType)
     stripped_type = Union[remaining]  # type: ignore
     assign("type", keep_annotations(stripped_type, self.type))
Exemplo n.º 5
0
 def _field_required(field: ObjectField):
     """Tell whether ``field`` must appear in a schema's ``required`` list.

     That is the case when the field is flagged ``required`` and its type does
     not admit ``UndefinedType``.
     """
     if not field.required:
         return field.required
     return not is_union_of(field.type, UndefinedType)
Exemplo n.º 6
0
    def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
        """Build a lazy GraphQL field for a resolver.

        The resolve function comes from ``resolver_resolve``, wrapped by
        ``self._wrap_resolve``. When ``field.parameters`` is not ``None``,
        every parameter not typed with ``graphql.GraphQLResolveInfo`` becomes a
        GraphQL argument named ``self.aliaser(param_field.alias)``.

        Returns:
            A zero-argument callable building the ``graphql.GraphQLField``;
            argument thunks are only evaluated when it is called.
        """
        resolve = self._wrap_resolve(
            resolver_resolve(
                field.resolver,
                field.types,
                self.aliaser,
                self.input_builder.default_conversion,
                self.default_conversion,
            ))
        args = None
        if field.parameters is not None:
            args = {}
            for param in field.parameters:
                default: Any = graphql.Undefined
                param_type = field.types[param.name]
                # The info parameter is injected by the resolver and is not a
                # GraphQL argument. Skip it but keep going, as resolver_resolve
                # does: `break` would drop every parameter declared after it.
                if is_union_of(param_type, graphql.GraphQLResolveInfo):
                    continue
                param_field = ObjectField(
                    param.name,
                    param_type,
                    param.default is Parameter.empty,
                    field.metadata.get(param.name, empty_dict),
                    default=param.default,
                )
                if param_field.required:
                    pass
                # Don't put `null` default + handle Undefined as None
                # also https://github.com/python/typing/issues/775
                # Identity checks against the sentinels: `in {...}` would hash
                # the default and raise TypeError for unhashable defaults.
                elif param.default is None or param.default is Undefined:
                    param_type = Optional[param_type]
                # param.default == graphql.Undefined means the parameter is required
                # even if it has a default
                elif (param.default is not Parameter.empty
                      and param.default is not graphql.Undefined):
                    try:
                        default = serialize(
                            param_type,
                            param.default,
                            fall_back_on_any=False,
                            check_type=True,
                        )
                    except Exception:
                        # Best effort: a default that cannot be serialized is
                        # not exposed; the argument is made nullable instead.
                        param_type = Optional[param_type]
                arg_factory = self.input_builder.visit_with_conv(
                    param_type, param_field.deserialization)
                description = get_description(param_field.schema,
                                              param_field.type)

                # Loop values are bound as defaults to avoid late binding.
                def arg_thunk(
                        arg_factory=arg_factory,
                        default=default,
                        description=description) -> graphql.GraphQLArgument:
                    return graphql.GraphQLArgument(arg_factory.type, default,
                                                   description)

                args[self.aliaser(param_field.alias)] = arg_thunk
        factory = self.visit_with_conv(field.type, field.resolver.conversion)
        return lambda: graphql.GraphQLField(
            factory.type,  # type: ignore
            {name: arg()
             for name, arg in args.items()} if args else None,
            resolve,
            field.subscribe,
            field.description,
            field.deprecated,
        )
Exemplo n.º 7
0
    def __init_subclass__(cls, **kwargs):
        """Register a subclass exposing ``mutate`` as a mutation.

        For such a subclass, this hook:
        - names the subclass ``<ClassName>Payload`` via ``type_name``;
        - builds a ``<ClassName>Input`` dataclass from ``mutate``'s typed
          parameters;
        - unless ``cls._client_mutation_id is False``, makes the input carry a
          client mutation id and copies it onto the returned payload;
        - stores a ``Mutation_`` wrapping ``mutate`` in ``cls._mutation``.

        Raises:
            TypeError: if ``mutate`` is not a classmethod/staticmethod, has a
                positional-only or untyped parameter, or declares a
                ClientMutationId parameter without default while
                ``_client_mutation_id`` is ``False``.
        """
        super().__init_subclass__(**kwargs)
        # Classes without a mutate method are not registered as mutations.
        if not hasattr(cls, "mutate"):
            return
        # NOTE(review): cls.__dict__ only holds attributes defined on cls
        # itself; if mutate is inherited, hasattr() above is true but this
        # lookup raises KeyError — confirm whether inherited mutate should
        # be supported.
        if not isinstance(cls.__dict__["mutate"], (classmethod, staticmethod)):
            raise TypeError(
                f"{cls.__name__}.mutate must be a classmethod/staticmethod")
        mutate = getattr(cls, "mutate")
        type_name(f"{cls.__name__}Payload")(cls)
        # The class name is put in localns so annotations can refer to it.
        types = get_type_hints(mutate,
                               localns={cls.__name__: cls},
                               include_extras=True)
        async_mutate = is_async(mutate, types)
        # (name, type, default-or-Field) triples for make_dataclass.
        fields: List[Tuple[str, AnyType, Field]] = []
        # Name of the input field holding the client mutation id, if any.
        cmi_param = None
        for param_name, param in signature(mutate).parameters.items():
            if param.kind is Parameter.POSITIONAL_ONLY:
                raise TypeError("Positional only parameters are not supported")
            # *args / **kwargs parameters are ignored.
            if param.kind in {
                    Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY
            }:
                if param_name not in types:
                    raise TypeError("Mutation parameters must be typed")
                field_type = types[param_name]
                field_ = MISSING if param.default is Parameter.empty else param.default
                if is_union_of(field_type, ClientMutationId):
                    cmi_param = param_name
                    if cls._client_mutation_id is False:
                        if field_ is MISSING:
                            raise TypeError(
                                "Cannot have a ClientMutationId parameter"
                                " when _client_mutation_id = False")
                        # Not exposed in the input: mutate uses its default.
                        continue
                    elif cls._client_mutation_id is True:
                        # Forced to be required in the input.
                        field_ = MISSING
                    field_ = field(default=field_,
                                   metadata=alias(CLIENT_MUTATION_ID))
                fields.append((param_name, field_type, field_))
        # Computed before the synthetic field below is appended, so only real
        # mutate parameters are forwarded to mutate.
        field_names = [name for (name, _, _) in fields]
        if cmi_param is None and cls._client_mutation_id is not False:
            # No explicit parameter: add the client mutation id field,
            # required when _client_mutation_id is truthy, else optional
            # with a None default.
            fields.append((
                CLIENT_MUTATION_ID,
                ClientMutationId
                if cls._client_mutation_id else Optional[ClientMutationId],
                MISSING if cls._client_mutation_id else None,
            ))
            cmi_param = CLIENT_MUTATION_ID
        input_cls = make_dataclass(f"{cls.__name__}Input", fields)

        def wrapper(input):
            return mutate(
                **{name: getattr(input, name)
                   for name in field_names})

        wrapper.__annotations__["input"] = input_cls
        wrapper.__annotations__[
            "return"] = Awaitable[cls] if async_mutate else cls
        if cls._client_mutation_id is not False:
            # The payload gets a non-init field typed like the input's one.
            cls.__annotations__[
                CLIENT_MUTATION_ID] = input_cls.__annotations__[cmi_param]
            setattr(cls, CLIENT_MUTATION_ID, field(init=False))
            wrapped = wrapper

            # Re-wrap to copy the input's client mutation id onto the result.
            if async_mutate:

                async def wrapper(input):
                    result = await wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID,
                            getattr(input, cmi_param))
                    return result

            else:

                def wrapper(input):
                    result = wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID,
                            getattr(input, cmi_param))
                    return result

            # functools.wraps copies __annotations__ among other attributes,
            # so the input/return annotations set above carry over.
            wrapper = wraps(wrapped)(wrapper)

        cls._mutation = Mutation_(
            function=wrapper,
            alias=camel_to_snake(cls.__name__),
            schema=cls._schema,
            error_handler=cls._error_handler,
        )
Exemplo n.º 8
0
def resolver_resolve(
    resolver: Resolver,
    types: Mapping[str, AnyType],
    aliaser: Aliaser,
    default_deserialization: DefaultConversion,
    default_serialization: DefaultConversion,
    serialized: bool = True,
) -> Callable:
    """Build the GraphQL resolve function of ``resolver``.

    The parameter typed with ``graphql.GraphQLResolveInfo`` (the last one if
    several are) receives the resolve info. Every other parameter is
    deserialized from the GraphQL argument named ``aliaser(param_field.alias)``.

    Args:
        resolver: the resolver whose ``func`` is called.
        types: resolved type hints of the resolver, including ``"return"``.
        aliaser: maps Python names to GraphQL names.
        default_deserialization: default conversion for arguments (not used
            for Enum types, see ``handle_enum``).
        default_serialization: default conversion for results and errors.
        serialized: when False, the result of ``func`` is returned as is.

    Returns:
        ``resolve(__self, __info, **kwargs)``. It raises ``ValueError``
        carrying the validation errors when arguments are invalid. If
        ``resolver.error_handler`` is set, exceptions raised while calling
        ``func`` (or serializing its result) are passed to it and its result
        is serialized instead.
    """

    # graphql deserialization will give Enum objects instead of strings
    def handle_enum(tp: AnyType) -> Optional[AnyConversion]:
        if is_type(tp) and issubclass(tp, Enum):
            return Conversion(identity, source=Any, target=tp)
        return default_deserialization(tp)

    parameters, info_parameter = [], None
    for param in resolver.parameters:
        param_type = types[param.name]
        if is_union_of(param_type, graphql.GraphQLResolveInfo):
            info_parameter = param.name
        else:
            param_field = ObjectField(
                param.name,
                param_type,
                param.default is Parameter.empty,
                resolver.parameters_metadata.get(param.name, empty_dict),
                param.default,
            )
            deserializer = deserialization_method(
                param_type,
                additional_properties=False,
                aliaser=aliaser,
                coerce=False,
                conversion=param_field.deserialization,
                default_conversion=handle_enum,
                fall_back_on_default=False,
                schema=param_field.schema,
            )
            opt_param = is_union_of(param_type, NoneType) or param.default is None
            parameters.append(
                (
                    aliaser(param_field.alias),
                    param.name,
                    deserializer,
                    opt_param,
                    param_field.required,
                )
            )
    func, error_handler = resolver.func, resolver.error_handler
    method_factory = partial_serialization_method_factory(
        aliaser, resolver.conversion, default_serialization
    )

    serialize_result: Callable[[Any], Any]
    if not serialized:
        serialize_result = lambda res: res
    elif is_async(resolver.func):
        serialize_result = as_async(method_factory(types["return"]))
    else:
        serialize_result = method_factory(types["return"])
    serialize_error: Optional[Callable[[Any], Any]]
    if error_handler is None:
        serialize_error = None
    elif is_async(error_handler):
        serialize_error = as_async(method_factory(resolver.error_type()))
    else:
        serialize_error = method_factory(resolver.error_type())

    def resolve(__self, __info, **kwargs):
        values = {}
        errors: Dict[str, ValidationError] = {}
        for alias, param_name, deserializer, opt_param, required in parameters:
            if alias in kwargs:
                # It is possible for the parameter to be non-optional in Python
                # type hints but optional in the generated schema. In this case
                # we should ignore it.
                # See: https://github.com/wyfo/apischema/pull/130#issuecomment-845497392
                if not opt_param and kwargs[alias] is None:
                    # Internal invariant: a required non-optional parameter
                    # is presumably non-null in the schema.
                    assert not required
                    continue
                try:
                    values[param_name] = deserializer(kwargs[alias])
                except ValidationError as err:
                    # Report under the GraphQL argument name actually sent by
                    # the client (honours a metadata alias).
                    errors[alias] = err
            elif opt_param and required:
                values[param_name] = None

        if errors:
            raise ValueError(ValidationError(children=errors).errors)
        if info_parameter:
            values[info_parameter] = __info
        try:
            return serialize_result(func(__self, **values))
        except Exception as error:
            if error_handler is None:
                raise
            assert serialize_error is not None
            return serialize_error(error_handler(error, __self, __info, **kwargs))

    return resolve