def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
    """Convert a ``StrawberryUnion`` into a graphql-core ``GraphQLUnionType``.

    Conversions are memoised in ``self.type_map`` keyed by ``union.name``.
    When an entry already exists, its ``implementation`` is returned and
    nothing new is built.

    Otherwise every member of ``union.types`` is converted with
    ``self.from_type`` and must come back as a ``GraphQLObjectType``. The
    resulting union type uses ``union.get_type_resolver(self.type_map)`` as
    its ``resolve_type``. It is stored in ``self.type_map`` as a
    ``ConcreteType`` before being returned.
    """
    # Don't reevaluate known types: hand back the previously built union.
    if union.name in self.type_map:
        existing = self.type_map[union.name].implementation
        assert isinstance(existing, GraphQLUnionType)  # For mypy
        return existing

    member_types: List[GraphQLObjectType] = []
    for member in union.types:
        converted = self.from_type(member)
        # NOTE(review): `assert` is stripped under `python -O`; if non-object
        # members can reach this point, an explicit error may be preferable.
        assert isinstance(converted, GraphQLObjectType)
        member_types.append(converted)

    built = GraphQLUnionType(
        name=union.name,
        types=member_types,
        description=union.description,
        resolve_type=union.get_type_resolver(self.type_map),
    )

    # Register the result so later lookups hit the cache branch above.
    self.type_map[union.name] = ConcreteType(definition=union, implementation=built)
    return built
def create_union(self, evaled_type) -> "StrawberryUnion":
    """Return a ``StrawberryUnion`` for an already-evaluated union annotation.

    An ``evaled_type`` that is already a ``StrawberryUnion`` is returned
    unchanged. Otherwise each element of ``evaled_type.__args__`` is wrapped
    in a ``StrawberryAnnotation``. Those wrappers become the new union's
    ``type_annotations``; no name is passed to ``StrawberryUnion``.
    """
    # Imported locally to prevent an import cycle.
    from strawberry.union import StrawberryUnion

    # TODO: Deal with Forward References/origin
    if not isinstance(evaled_type, StrawberryUnion):
        member_annotations = tuple(map(StrawberryAnnotation, evaled_type.__args__))
        return StrawberryUnion(type_annotations=member_annotations)

    return evaled_type
def test_union():
    """A field annotated with a StrawberryUnion exposes that same union object."""

    @strawberry.type
    class Un:
        fi: int

    @strawberry.type
    class Ion:
        eld: float

    members = (StrawberryAnnotation(Un), StrawberryAnnotation(Ion))
    union = StrawberryUnion(name="UnionName", type_annotations=members)

    field = StrawberryField(type_annotation=StrawberryAnnotation(union))

    # Identity rather than equality: the field must hand back this exact instance.
    assert field.type is union
def test_python_union_short_syntax():
    """A PEP 604 ``A | B`` annotation resolves to an equivalent StrawberryUnion.

    The ``|`` operator between classes only exists on Python 3.10+. On older
    interpreters the test is skipped instead of erroring with ``TypeError``
    while building the annotation.
    """
    import sys

    import pytest

    if sys.version_info < (3, 10):
        pytest.skip("short syntax for union is only available on python 3.10+")

    @strawberry.type
    class User:
        name: str

    @strawberry.type
    class Error:
        name: str

    annotation = StrawberryAnnotation(User | Error)
    resolved = annotation.resolve()

    assert isinstance(resolved, StrawberryUnion)
    assert resolved.types == (User, Error)
    # The expected union name is "UserError", built from the member class names.
    assert resolved == StrawberryUnion(
        name="UserError",
        type_annotations=(StrawberryAnnotation(User), StrawberryAnnotation(Error)),
    )
    assert resolved == Union[User, Error]
def test_strawberry_union():
    """An explicit ``union(...)`` resolves to a StrawberryUnion that keeps its name."""

    @strawberry.type
    class User:
        name: str

    @strawberry.type
    class Error:
        name: str

    cool_union = union(name="CoolUnion", types=(User, Error))
    resolved = StrawberryAnnotation(cool_union).resolve()

    assert isinstance(resolved, StrawberryUnion)
    assert resolved.types == (User, Error)

    expected = StrawberryUnion(
        name="CoolUnion",
        type_annotations=tuple(StrawberryAnnotation(member) for member in (User, Error)),
    )
    assert resolved == expected

    # Name will be different: "CoolUnion" does not match the name a bare
    # Union[User, Error] gets, so the two must not compare equal.
    assert resolved != Union[User, Error]
@strawberry.type class TypeB: age: int @pytest.mark.parametrize( "types,expected_name", [ ([StrawberryList(str)], "StrListExample"), ([StrawberryList(StrawberryList(str))], "StrListListExample"), ([StrawberryOptional(StrawberryList(str))], "StrListOptionalExample"), ([StrawberryList(StrawberryOptional(str))], "StrOptionalListExample"), ([StrawberryList(Enum)], "EnumListExample"), ([StrawberryUnion("Union", (TypeA, TypeB))], "UnionExample"), # type: ignore ([TypeA], "TypeAExample"), ([CustomInt], "CustomIntExample"), ([TypeA, TypeB], "TypeATypeBExample"), ([TypeA, LazyType["TypeB", "test_names"] ], "TypeATypeBExample"), # type: ignore ], ) def test_name_generation(types, expected_name): config = StrawberryConfig() @strawberry.type class Example(Generic[T]): a: T type_definition = Example._type_definition # type: ignore