Ejemplo n.º 1
0
def _get_entity_type(type_map: TypeMap):
    """Build the federation ``_Entity`` union from the types in ``type_map``.

    Every type whose definition carries federation keys (``@key``) becomes a
    member of the ``_Entity`` union. If no such type exists, ``None`` is
    returned: in that case the ``_Entity`` union and the ``Query._entities``
    field should be removed from the schema.

    See https://www.apollographql.com/docs/apollo-server/federation/federation-spec/#resolve-requests-for-entities
    """
    keyed_implementations = []
    for concrete in type_map.values():
        if _has_federation_keys(concrete.definition):
            keyed_implementations.append(concrete.implementation)

    if not keyed_implementations:
        # Nothing is annotated with @key, so there is no union to build.
        return None

    def _resolve_type(root, info, abstract_type):
        # The concrete GraphQL type is found through the name of the root
        # object's ``_type_definition``.
        definition_name = root._type_definition.name
        return type_map[definition_name].implementation

    entity_union = GraphQLUnionType("_Entity", keyed_implementations)  # type: ignore
    entity_union.resolve_type = _resolve_type
    return entity_union
Ejemplo n.º 2
0
def union(name: str, types: typing.Tuple[typing.Type, ...], *, description=None):
    """Create a new named GraphQL union type.

    Example usages::

        strawberry.union(
            "Name",
            (A, B),
        )

        strawberry.union(
            "Name",
            (A, B),
            description="A union of A and B",
        )

    Args:
        name: GraphQL name of the union.
        types: The member types, any number of them. Each one is converted
            with ``get_graphql_type_for_annotation(member, name,
            force_optional=True)``.
        description: Optional description attached to the union.

    Returns:
        A placeholder object whose ``graphql_type`` attribute is the built
        ``GraphQLUnionType``. Calling the placeholder raises ``ValueError``.
    """

    from .type_converter import get_graphql_type_for_annotation

    def _resolve_type(root, info, _type):
        # Values returned for a union field must expose ``graphql_type``.
        if not hasattr(root, "graphql_type"):
            raise WrongReturnTypeForUnion(info.field_name, str(type(root)))

        # NOTE(review): the generic branch returns without the membership
        # check applied below -- confirm _find_type_for_generic_union only
        # yields members of this union.
        if is_generic(type(root)):
            return _find_type_for_generic_union(root)

        if root.graphql_type not in _type.types:
            raise UnallowedReturnTypeForUnion(info.field_name, str(type(root)),
                                              _type.types)

        return root.graphql_type

    # TODO: union types don't work with scalar types
    # so we want to return a nice error
    # also we want to make sure we have been passed
    # strawberry types
    graphql_type = GraphQLUnionType(
        name,
        [
            get_graphql_type_for_annotation(member, name, force_optional=True)
            for member in types
        ],
        description=description,
    )
    graphql_type.resolve_type = _resolve_type

    # This is currently a temporary solution, this is ok for now
    # But in future we might want to change this so that it works
    # properly with mypy, but there's no way to return a type like NewType does
    # so we return this class instance as it allows us to reuse the rest of
    # our code without doing too many changes

    class X:
        def __init__(self, graphql_type):
            self.graphql_type = graphql_type

        def __call__(self):
            raise ValueError("Cannot use union type directly")

    return X(graphql_type)
Ejemplo n.º 3
0
 def graphql_type(cls):
     """Build a ``GraphQLUnionType`` for ``cls``.

     ``cls.Meta.name`` and ``cls.Meta.description`` override the union's
     name and description; when absent, the class name and docstring are
     used. Members come from ``cls.graphql_types`` and runtime type
     resolution is delegated to ``cls.resolve_type``.
     """
     meta = cls.Meta
     union_name = getattr(meta, "name", cls.__name__)
     union_description = getattr(meta, "description", cls.__doc__)
     return GraphQLUnionType(
         name=union_name,
         description=union_description,
         types=cls.graphql_types,
         resolve_type=cls.resolve_type,
     )
Ejemplo n.º 4
0
    def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
        """Convert a strawberry union into a ``GraphQLUnionType``.

        Results are memoised in ``self.type_map`` under ``union.name``: a
        union that was converted before is returned instead of rebuilt.
        Every member must convert to a ``GraphQLObjectType``.
        """
        if union.name in self.type_map:
            cached = self.type_map[union.name].implementation
            assert isinstance(cached, GraphQLUnionType)  # For mypy
            return cached

        def _to_object_type(member) -> GraphQLObjectType:
            # GraphQL unions may only contain object types.
            converted = self.from_type(member)
            assert isinstance(converted, GraphQLObjectType)
            return converted

        member_types: List[GraphQLObjectType] = [
            _to_object_type(member) for member in union.types
        ]

        graphql_union = GraphQLUnionType(
            name=union.name,
            types=member_types,
            description=union.description,
            resolve_type=union.get_type_resolver(self.type_map),
        )
        self.type_map[union.name] = ConcreteType(
            definition=union, implementation=graphql_union
        )
        return graphql_union
Ejemplo n.º 5
0
def get_union_type(
    union_definition: UnionDefinition,
    type_map: TypeMap,
) -> GraphQLUnionType:
    """Create the ``GraphQLUnionType`` described by ``union_definition``.

    Members are converted with ``get_object_type`` against ``type_map``.
    The attached resolver maps a returned object to its GraphQL type through
    ``type_map``. It raises ``WrongReturnTypeForUnion`` if the object has no
    ``_type_definition``, and ``UnallowedReturnTypeForUnion`` if the
    resolved type is not a member of this union.
    """
    from .object_type import get_object_type

    def _resolve_type(root, info, _type):
        if not hasattr(root, "_type_definition"):
            raise WrongReturnTypeForUnion(info.field_name, str(type(root)))

        if is_generic(type(root)):
            # Generic instances: let the helper pick the matching definition.
            definition = _find_type_for_generic_union(root)
        else:
            definition = root._type_definition

        resolved = type_map[definition.name].implementation
        if resolved in _type.types:
            return resolved

        raise UnallowedReturnTypeForUnion(info.field_name, str(type(root)),
                                          _type.types)

    members = union_definition.types  # type: ignore
    graphql_members = [get_object_type(member, type_map) for member in members]

    return GraphQLUnionType(
        union_definition.name,
        graphql_members,
        description=union_definition.description,
        resolve_type=_resolve_type,
    )
Ejemplo n.º 6
0
 def compile_union(self, union: Union) -> GraphQLUnionType:
     """Compile ``union`` into a ``GraphQLUnionType``.

     Each member is converted with ``self.get_graphql_type``; the union's
     name, ``resolve_type`` and description are copied over unchanged.
     """
     assert isinstance(union, Union)
     member_types = tuple(map(self.get_graphql_type, union.types))
     return GraphQLUnionType(
         name=union.name,
         types=member_types,
         resolve_type=union.resolve_type,
         description=union.description,
     )
Ejemplo n.º 7
0
def get_graphql_type_for_annotation(annotation,
                                    field_name: str,
                                    force_optional: bool = False):
    """Translate a Python type annotation into a GraphQL type.

    Handled annotations:

    * objects with a ``field`` attribute: that attribute is the GraphQL type;
    * ``List[X]``: a ``GraphQLList`` of the translated ``X``;
    * ``AsyncGenerator[X, ...]``: the translated yield type ``X`` (used by
      subscriptions);
    * ``Union[X, None]`` (``Optional[X]``): the nullable translation of ``X``;
    * ``Union[A, B, ...]``: a ``GraphQLUnionType`` named ``field_name`` built
      from the ``field`` attribute of each non-None member;
    * anything else is looked up in ``REGISTRY``.

    Args:
        annotation: The annotation to translate.
        field_name: Name of the field being typed; also used as the name of
            any union type created for it.
        force_optional: When true, the result is not wrapped in
            ``GraphQLNonNull``.

    Returns:
        The GraphQL type, wrapped in ``GraphQLNonNull`` unless the field is
        optional. The ``List`` and ``AsyncGenerator`` cases return early and
        do not apply this call's optionality.

    Raises:
        ValueError: If no GraphQL type can be found for ``annotation``.
    """
    import collections.abc

    none_type = type(None)

    # TODO: this might lead to issues with types that have a field value
    is_field_optional = force_optional

    if hasattr(annotation, "field"):
        graphql_type = annotation.field
    else:
        annotation_name = getattr(annotation, "_name", None)

        if annotation_name == "List":
            list_of_type = get_graphql_type_for_annotation(
                annotation.__args__[0], field_name)

            return GraphQLList(list_of_type)

        annotation_origin = getattr(annotation, "__origin__", None)

        # On Python 3.7+ the origin of ``typing.AsyncGenerator[...]`` is
        # ``collections.abc.AsyncGenerator``, which does not compare equal to
        # the ``typing`` alias, so accept either.
        if annotation_origin in (AsyncGenerator,
                                 collections.abc.AsyncGenerator):
            # async generators are used in subscription, we only need the yield type
            # https://docs.python.org/3/library/typing.html#typing.AsyncGenerator
            return get_graphql_type_for_annotation(annotation.__args__[0],
                                                   field_name)

        elif is_union(annotation):
            types = annotation.__args__
            non_none_types = [x for x in types if x is not none_type]

            # optionals are represented as Union[type, None]
            if len(non_none_types) == 1:
                is_field_optional = True
                graphql_type = get_graphql_type_for_annotation(
                    non_none_types[0], field_name, force_optional=True)
            else:
                is_field_optional = none_type in types

                # TODO: union types don't work with scalar types
                # so we want to return a nice error
                # also we want to make sure we have been passed
                # strawberry types
                # Only non-None members go into the union: NoneType has no
                # ``field`` attribute, so including it made
                # ``Union[A, B, None]`` fail with AttributeError.
                graphql_type = GraphQLUnionType(
                    field_name, [member.field for member in non_none_types])
        else:
            graphql_type = REGISTRY.get(annotation)

    if not graphql_type:
        raise ValueError(f"Unable to get GraphQL type for {annotation}")

    if is_field_optional:
        return graphql_type

    return GraphQLNonNull(graphql_type)
Ejemplo n.º 8
0
 def map_union(self, ann: types.AUnion) -> GraphQLUnionType:
     """Map a union annotation ``ann`` to a ``GraphQLUnionType``.

     ``self.check_union_name(ann)`` is called first (presumably to validate
     the union's name -- confirm). Members are translated with
     ``translate_annotation_unwrapped``: ``translate_annotation`` returns a
     NonNull, but the union needs the underlying types.
     """
     self.check_union_name(ann)
     return GraphQLUnionType(
         name=ann.name,
         types=[
             # A distinct loop name avoids shadowing the ``ann`` parameter.
             self.translate_annotation_unwrapped(member)
             for member in ann.of_types
         ],
     )
Ejemplo n.º 9
0
    def convert(self, type_map: t.Dict[str, GraphQLType]) -> GraphQLUnionType:
        """Build (or fetch) the ``GraphQLUnionType`` named ``self.name``.

        ``type_map`` acts as a cache: if ``self.name`` is already present,
        that entry is returned unchanged. Otherwise each member of
        ``self.types`` is resolved -- strings are looked up by name in
        ``type_map``, anything else is used as-is -- and the new union is
        stored in ``type_map`` before being returned.

        Raises:
            KeyError: If a member given by name is not present in
                ``type_map``.
        """
        if self.name in type_map:
            return t.cast(GraphQLUnionType, type_map[self.name])

        member_types: t.List[GraphQLObjectType] = []
        for member in self.types:
            if isinstance(member, str):
                # Referenced by name: it must already be in ``type_map``.
                member_types.append(t.cast(GraphQLObjectType, type_map[member]))
            else:
                member_types.append(member)

        union_type = GraphQLUnionType(
            self.name, member_types, resolve_type=self.resolve_types
        )
        type_map[self.name] = union_type
        return union_type
Ejemplo n.º 10
0
 def __init__(self, session):
     """Populate ``self.type_map`` with hand-built GraphQL type objects.

     ``session`` is stored as ``self.context``.

     NOTE(review): several constructions below do not follow the usual
     graphql-core constructor contracts -- a bare ``GraphQLField`` where a
     fields mapping is expected, non-field entries (input field, list type,
     input object type, argument, scalar) inside an object type's fields,
     and a union of scalar types. Whether this is deliberate (e.g. to
     exercise invalid schemas) cannot be determined from here -- confirm.
     """
     self.type_map = {
         # NOTE(review): GraphQLObjectType's second argument is normally a
         # mapping of field names to fields; a single GraphQLField is passed.
         "Test":
         GraphQLObjectType("Test",
                           GraphQLField(GraphQLList(GraphQLString))),
         "TestEmptyObject":
         GraphQLObjectType("TestEmptyObject", {}),
         "TestNestedObjects":
         GraphQLObjectType(
             "TestNestedObjects",
             {
                 "EnumField":
                 GraphQLField(
                     GraphQLEnumType("TestEnum", {
                         "RED": 0,
                         "GREEN": 1,
                         "BLUE": 2
                     })),
                 # NOTE(review): the remaining entries are not GraphQLField
                 # instances; confirm this mix is intentional.
                 "InputField":
                 GraphQLInputField(GraphQLNonNull(GraphQLInt)),
                 "List":
                 GraphQLList(GraphQLString),
                 "InputObjectType":
                 GraphQLInputObjectType(
                     "TestInputObject",
                     # NOTE(review): union members are normally object types,
                     # and input-object fields normally a mapping; here a
                     # NonNull union of scalars is passed instead.
                     GraphQLNonNull(
                         GraphQLUnionType("TestUnion",
                                          [GraphQLString, GraphQLID])),
                 ),
                 "ArgumentType":
                 GraphQLArgument(GraphQLBoolean),
                 "Float":
                 GraphQLFloat,
             },
         ),
     }
     self.context = session
    return echo


def resolve_error(root, info):
    """Resolver that unconditionally raises ``RuntimeError``.

    Both ``root`` and ``info`` are ignored.
    """
    error = RuntimeError("Runtime Error!")
    raise error


try:
    hello_field = GraphQLField(GraphQLString, resolver=resolve_hello)
    library_field = GraphQLField(
        Library,
        resolver=resolve_library,
        args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))},
    )
    search_field = GraphQLField(
        GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)),
        args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))},
    )
    echo_field = GraphQLField(
        GraphQLString,
        resolver=resolve_echo,
        args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))},
    )
    storage_field = GraphQLField(
        Storage,
        resolver=resolve_storage,
    )
    storage_add_field = GraphQLField(
        Storage,
        resolver=resolve_storage_add,
        args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))},