def test_type_var():
    """A bare TypeVar annotation is exposed unchanged as the field type."""
    type_var = TypeVar("T")
    field = StrawberryField(type_annotation=StrawberryAnnotation(type_var))

    assert field.type == type_var
def test_object():
    """A strawberry object type annotation resolves to that very class."""

    @strawberry.type
    class TypeyType:
        value: str

    field = StrawberryField(type_annotation=StrawberryAnnotation(TypeyType))

    assert field.type is TypeyType
def test_sort_creation_fields():
    """Fields without any default must be ordered before defaulted ones."""

    def _creation_field(field_name, default=UNSET, default_factory=UNSET):
        # Every candidate shares the same annotation/description; only the
        # default configuration differs between them.
        return DataclassCreationFields(
            name=field_name,
            type_annotation=str,
            field=StrawberryField(
                python_name=field_name,
                graphql_name=field_name,
                default=default,
                default_factory=default_factory,
                type_annotation=str,
                description="description",
            ),
        )

    has_default = _creation_field("has_default", default="default_str")
    has_default_factory = _creation_field(
        "has_default_factory", default_factory=lambda: "default_factory_str"
    )
    no_defaults = _creation_field("no_defaults")

    candidates = [has_default, has_default_factory, no_defaults]
    expected_order = [no_defaults, has_default, has_default_factory]

    # items carrying a default (plain or via factory) must end up last
    assert sort_creation_fields(candidates) == expected_order
def test_lazy_type_field():
    """A LazyType annotation is kept as-is and resolves to the target class on demand."""
    # The module path is short and relative because of how pytest runs this file
    lazy_ref = LazyType["LaziestType", "test_lazy_types"]
    field = StrawberryField(type_annotation=StrawberryAnnotation(lazy_ref))

    assert isinstance(field.type, LazyType)
    assert field.type is lazy_ref
    assert field.type.resolve_type() is LaziestType  # type: ignore
def test_enum():
    """An enum annotation resolves to the enum's strawberry definition."""

    @strawberry.enum
    class Egnum(Enum):
        a = "A"
        b = "B"

    field = StrawberryField(type_annotation=StrawberryAnnotation(Egnum))

    # TODO: Remove reference to ._enum_definition with StrawberryEnum
    assert field.type is Egnum._enum_definition
def test_forward_reference():
    """A string annotation resolves lazily once the named class exists.

    The annotation is created while ``RefForward`` is still undefined. The
    class is then defined at module scope (via ``global``) and the field must
    resolve to it.
    """
    global RefForward

    annotation = StrawberryAnnotation("RefForward", namespace=globals())
    field = StrawberryField(type_annotation=annotation)

    try:

        @strawberry.type
        class RefForward:
            ref: int

        assert field.type is RefForward
    finally:
        # Always drop the module-level name, even when the assertion (or the
        # class creation) fails. A stale global would otherwise leak into
        # later forward-reference lookups. ``pop`` with a default avoids a
        # NameError masking the real failure.
        globals().pop("RefForward", None)
def test_union():
    """A StrawberryUnion annotation resolves to the union object itself."""

    @strawberry.type
    class Un:
        fi: int

    @strawberry.type
    class Ion:
        eld: float

    members = (StrawberryAnnotation(Un), StrawberryAnnotation(Ion))
    union = StrawberryUnion(name="UnionName", type_annotations=members)
    field = StrawberryField(type_annotation=StrawberryAnnotation(union))

    assert field.type is union
def from_resolver(self, field: StrawberryField) -> Callable:
    """Return the callable produced by ``field.get_wrapped_resolver()``."""
    wrapped_resolver = field.get_wrapped_resolver()
    return wrapped_resolver
def wrap(cls):
    """Create a strawberry type for ``cls`` from the enclosing pydantic ``model``.

    This is a closure. ``model``, ``fields``, ``name``, ``is_input``,
    ``is_interface``, ``description`` and ``federation`` come from the
    enclosing decorator call.

    The model fields named in ``fields`` are copied first. The fields declared
    on ``cls`` (including private ones) are appended after them. Everything is
    re-ordered so that fields without a default come first, rebuilt with
    ``dataclasses.make_dataclass`` and processed as a strawberry type.
    ``from_pydantic``/``to_pydantic`` helpers are attached to the result.

    Raises:
        MissingFieldsListError: if ``fields`` is empty.
    """
    if not fields:
        raise MissingFieldsListError(model)

    model_fields = model.__fields__
    fields_set = set(fields)

    all_fields: List[Tuple[str, Any, dataclasses.Field]] = [
        (
            field_name,
            get_type_for_field(field),
            StrawberryField(
                python_name=field.name,
                graphql_name=field.alias if field.has_alias else None,
                default=field.default if not field.required else UNSET,
                default_factory=(
                    field.default_factory if field.default_factory else UNSET
                ),
                type_annotation=get_type_for_field(field),
            ),
        )
        for field_name, field in model_fields.items()
        if field_name in fields_set
    ]

    wrapped = _wrap_dataclass(cls)
    extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped))
    private_fields = _get_private_fields(wrapped)

    all_fields.extend(
        (field.name, field.type, field) for field in extra_fields + private_fields
    )

    # dataclasses require fields without a default to precede fields with one.
    # A ``default_factory`` also counts as a default. Checking only
    # ``.default`` would put factory-only fields in the "missing" group and
    # make ``make_dataclass`` raise "non-default argument follows default
    # argument". ``sorted`` is stable, so declaration order is kept within
    # each group.
    def _has_default(entry: Tuple[str, Any, dataclasses.Field]) -> bool:
        dataclass_field = entry[2]
        return (
            dataclass_field.default is not dataclasses.MISSING
            # getattr because of https://github.com/python/mypy/issues/6910
            or getattr(dataclass_field, "default_factory") is not dataclasses.MISSING
        )

    sorted_fields = sorted(all_fields, key=_has_default)

    cls = dataclasses.make_dataclass(
        cls.__name__,
        sorted_fields,
        bases=cls.__bases__,
    )

    _process_type(
        cls,
        name=name,
        is_input=is_input,
        is_interface=is_interface,
        description=description,
        federation=federation,
    )

    model._strawberry_type = cls  # type: ignore
    cls._pydantic_type = model  # type: ignore

    def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
        return convert_pydantic_model_to_strawberry_class(
            cls=cls, model_instance=instance, extra=extra
        )

    def to_pydantic(self) -> Any:
        instance_kwargs = dataclasses.asdict(self)
        return model(**instance_kwargs)

    cls.from_pydantic = staticmethod(from_pydantic)
    cls.to_pydantic = to_pydantic
    return cls
def test_optional():
    """An Optional annotation is exposed unchanged as the field type."""
    expected = Optional[float]
    field = StrawberryField(type_annotation=StrawberryAnnotation(expected))

    assert field.type == expected
def copy_type_with(
    base: Type, *types: Type, params_to_type: Dict[Type, Union[Type, StrawberryUnion]] = None
) -> Type:
    """Return a copy of ``base`` with its type parameters replaced by ``types``.

    Behaviour by kind of ``base``:

    * ``StrawberryUnion``: delegated to ``copy_union_with`` on its member types.
    * A class with a ``_type_definition`` that has ``type_params``:
      - Each type param is paired positionally with ``types`` (via ``zip``,
        so surplus entries on either side are ignored).
      - Every field is rebuilt with its type (or list child type) copied
        recursively.
      - A new class is created with ``builtins.type`` carrying a non-generic
        ``TypeDefinition``, stored in ``base._copies[types]`` and returned.
    * A TypeVar: looked up in ``params_to_type``. This raises ``KeyError`` if
      it has not been mapped.
    * Anything else: returned unchanged.

    Args:
        base: The (possibly generic) type to copy.
        *types: Concrete types, matched positionally against
            ``definition.type_params.values()``.
        params_to_type: TypeVar -> concrete type mapping shared with the
            recursive calls. It is mutated in place; a fresh dict is used when
            ``None``.

    Returns:
        The specialised copy, the resolved concrete type, or ``base`` itself.
    """
    if params_to_type is None:
        params_to_type = {}

    if isinstance(base, StrawberryUnion):
        return copy_union_with(
            base.types, params_to_type=params_to_type, description=base.description
        )

    if hasattr(base, "_type_definition"):
        definition = cast(TypeDefinition, base._type_definition)

        if definition.type_params:
            fields = []

            type_params = definition.type_params.values()

            # Record the concrete type for each param; union arguments are
            # themselves copied so that nested TypeVars get substituted too.
            for param, type_ in zip(type_params, types):
                if is_union(type_):
                    params_to_type[param] = copy_union_with(
                        type_.__args__, params_to_type=params_to_type
                    )
                else:
                    params_to_type[param] = type_

            # NOTE(review): this uses every entry in the shared mapping,
            # including any added by outer recursive calls, not just this
            # call's params. Confirm that is the intended naming.
            name = get_name_from_types(params_to_type.values()) + definition.name

            for field in definition.fields:
                # Copy federation information
                federation = FederationFieldParams(**field.federation.__dict__)

                new_field = StrawberryField(
                    python_name=field.python_name,
                    graphql_name=field.graphql_name,
                    origin=field.origin,
                    type_=field.type,
                    default_value=field.default_value,
                    base_resolver=field.base_resolver,
                    child=field.child,
                    is_child_optional=field.is_child_optional,
                    is_list=field.is_list,
                    is_optional=field.is_optional,
                    is_subscription=field.is_subscription,
                    is_union=field.is_union,
                    federation=federation,
                    permission_classes=field.permission_classes,
                )

                # List fields keep their element type on ``child``, so the
                # substitution has to happen there rather than on ``type``.
                if field.is_list:
                    assert field.child is not None

                    child_type = copy_type_with(
                        field.child.type, params_to_type=params_to_type
                    )

                    new_field.child = StrawberryField(
                        python_name=field.child.python_name,
                        origin=field.child.origin,
                        graphql_name=field.child.graphql_name,
                        is_optional=field.child.is_optional,
                        type_=child_type,
                    )

                else:
                    new_field.type = copy_type_with(
                        field.type, params_to_type=params_to_type
                    )

                fields.append(new_field)

            type_definition = TypeDefinition(
                name=name,
                is_input=definition.is_input,
                origin=definition.origin,
                is_interface=definition.is_interface,
                is_generic=False,
                federation=definition.federation,
                interfaces=definition.interfaces,
                description=definition.description,
                _fields=fields,
            )
            # The copy is fully specialised, so it carries no type params.
            type_definition._type_params = {}

            copied_type = builtins.type(
                name,
                (base.__origin__,) if hasattr(base, "__origin__") else (),
                {"_type_definition": type_definition},
            )

            # NOTE(review): ``_copies`` is written here but never read in this
            # function, so each call builds a brand-new class.
            if not hasattr(base, "_copies"):
                base._copies = {}

            base._copies[types] = copied_type

            return copied_type

    if is_type_var(base):
        # TODO: we ignore the type issue here as we'll improve how types
        # are represented internally (using StrawberryTypes) so we can improve
        # typings later
        return params_to_type[base]  # type: ignore

    return base
def test_literal():
    """A plain builtin annotation resolves to that builtin type."""
    field = StrawberryField(type_annotation=StrawberryAnnotation(bool))

    assert field.type is bool
def test_list():
    """A List annotation is exposed unchanged as the field type."""
    expected = List[int]
    field = StrawberryField(type_annotation=StrawberryAnnotation(expected))

    assert field.type == expected
def wrap(cls):
    """Create a strawberry type for ``cls`` from the enclosing pydantic ``model``.

    This is a closure. ``model``, ``fields``, ``all_fields``, ``name``,
    ``is_input``, ``is_interface``, ``description`` and ``directives`` come
    from the enclosing decorator call.

    Field selection:

    * Start from the deprecated ``fields`` list (a DeprecationWarning is
      emitted if it is given).
    * Add every annotation on ``cls`` that is ``strawberry.auto``.
    * When ``all_fields`` is set, every model field is selected instead.

    The selected model fields plus the non-``auto`` fields declared on
    ``cls`` are ordered so that fields without a default come first. They are
    then rebuilt with ``dataclasses.make_dataclass`` and processed as a
    strawberry type.

    Raises:
        MissingFieldsListError: if no model field ends up selected.
    """
    model_fields = model.__fields__
    fields_set = set(fields) if fields else set()

    if fields:
        warnings.warn(
            "`fields` is deprecated, use `auto` type annotations instead",
            DeprecationWarning,
        )

    existing_fields = getattr(cls, "__annotations__", {})
    fields_set = fields_set.union(
        {
            field_name
            for field_name, typ in existing_fields.items()
            if typ is strawberry.auto
        }
    )

    if all_fields:
        if fields_set:
            warnings.warn(
                "Using all_fields overrides any explicitly defined fields "
                "in the model, using both is likely a bug",
                stacklevel=2,
            )
        fields_set = set(model_fields.keys())

    if not fields_set:
        raise MissingFieldsListError(cls)

    all_model_fields: List[Tuple[str, Any, dataclasses.Field]] = [
        (
            field_name,
            get_type_for_field(field),
            StrawberryField(
                python_name=field.name,
                graphql_name=field.alias if field.has_alias else None,
                default=field.default if not field.required else UNSET,
                default_factory=(
                    field.default_factory if field.default_factory else UNSET
                ),
                type_annotation=get_type_for_field(field),
                description=field.field_info.description,
            ),
        )
        for field_name, field in model_fields.items()
        if field_name in fields_set
    ]

    wrapped = _wrap_dataclass(cls)
    extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped))
    private_fields = get_private_fields(wrapped)

    # ``auto`` fields were already taken from the model above.
    all_model_fields.extend(
        (field.name, field.type, field)
        for field in extra_fields + private_fields
        if field.type != strawberry.auto
    )

    # dataclasses require fields without a default to precede fields with one.
    # A ``default_factory`` also counts as a default. Checking only
    # ``.default`` would put factory-only fields in the "missing" group and
    # make ``make_dataclass`` raise "non-default argument follows default
    # argument". ``sorted`` is stable, so declaration order is kept within
    # each group.
    def _has_default(entry: Tuple[str, Any, dataclasses.Field]) -> bool:
        dataclass_field = entry[2]
        return (
            dataclass_field.default is not dataclasses.MISSING
            # getattr because of https://github.com/python/mypy/issues/6910
            or getattr(dataclass_field, "default_factory") is not dataclasses.MISSING
        )

    sorted_fields = sorted(all_model_fields, key=_has_default)

    cls = dataclasses.make_dataclass(
        cls.__name__,
        sorted_fields,
        bases=cls.__bases__,
    )

    _process_type(
        cls,
        name=name,
        is_input=is_input,
        is_interface=is_interface,
        description=description,
        directives=directives,
    )

    model._strawberry_type = cls  # type: ignore
    cls._pydantic_type = model  # type: ignore

    def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
        return convert_pydantic_model_to_strawberry_class(
            cls=cls, model_instance=instance, extra=extra
        )

    def to_pydantic(self) -> Any:
        instance_kwargs = dataclasses.asdict(self)
        return model(**instance_kwargs)

    cls.from_pydantic = staticmethod(from_pydantic)
    cls.to_pydantic = to_pydantic
    return cls
def wrap(cls):
    """Create a strawberry type for ``cls`` from the enclosing pydantic ``model``.

    This is a closure over ``model``, ``fields``, ``all_fields``, ``name``,
    ``is_input``, ``is_interface``, ``description`` and ``directives`` from
    the enclosing decorator call.

    Field selection: the exposed model fields are the deprecated ``fields``
    list plus every annotation on ``cls`` that is ``strawberry.auto``. With
    ``all_fields`` set, every model field is exposed instead.

    Build: those fields plus the non-``auto`` fields declared on ``cls`` are
    ordered by ``sort_creation_fields`` and rebuilt with
    ``dataclasses.make_dataclass``. The generated class gets an
    ``is_type_of`` classmethod that also accepts instances of ``model``.

    Raises:
        MissingFieldsListError: if no model field ends up selected.
    """
    model_fields = model.__fields__

    selected = set(fields) if fields else set()
    if fields:
        warnings.warn(
            "`fields` is deprecated, use `auto` type annotations instead",
            DeprecationWarning,
        )

    annotations = getattr(cls, "__annotations__", {})
    auto_names = {
        attr for attr, annotation in annotations.items() if annotation is strawberry.auto
    }
    selected = selected.union(auto_names)

    if all_fields:
        if selected:
            warnings.warn(
                "Using all_fields overrides any explicitly defined fields "
                "in the model, using both is likely a bug",
                stacklevel=2,
            )
        selected = set(model_fields.keys())

    if not selected:
        raise MissingFieldsListError(cls)

    ensure_all_auto_fields_in_pydantic(
        model=model, auto_fields=selected, cls_name=cls.__name__
    )

    def _from_model_field(attr, model_field):
        # ``default`` stays UNSET on purpose: the model default is supplied
        # through ``default_factory`` instead.
        return DataclassCreationFields(
            name=attr,
            type_annotation=get_type_for_field(model_field),
            field=StrawberryField(
                python_name=model_field.name,
                graphql_name=model_field.alias if model_field.has_alias else None,
                default=UNSET,
                default_factory=get_default_factory_for_field(model_field),
                type_annotation=get_type_for_field(model_field),
                description=model_field.field_info.description,
            ),
        )

    creation_fields: List[DataclassCreationFields] = [
        _from_model_field(attr, model_field)
        for attr, model_field in model_fields.items()
        if attr in selected
    ]

    wrapped = _wrap_dataclass(cls)
    declared_fields = cast(List[dataclasses.Field], _get_fields(wrapped))
    hidden_fields = get_private_fields(wrapped)

    # ``auto`` fields were already produced from the model above.
    for extra in declared_fields + hidden_fields:
        if extra.type != strawberry.auto:
            creation_fields.append(
                DataclassCreationFields(
                    name=extra.name,
                    type_annotation=extra.type,
                    field=extra,
                )
            )

    # Fields without defaults have to come before fields with defaults
    ordered = sort_creation_fields(creation_fields)

    # Let interfaces/unions resolve raw pydantic objects as well as the
    # generated strawberry type
    @classmethod  # type: ignore
    def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool:
        return isinstance(obj, (cls, model))

    strawberry_cls = dataclasses.make_dataclass(
        cls.__name__,
        [entry.to_tuple() for entry in ordered],
        bases=cls.__bases__,
        namespace={"is_type_of": is_type_of},
    )

    _process_type(
        strawberry_cls,
        name=name,
        is_input=is_input,
        is_interface=is_interface,
        description=description,
        directives=directives,
    )

    model._strawberry_type = strawberry_cls  # type: ignore
    strawberry_cls._pydantic_type = model  # type: ignore

    def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
        return convert_pydantic_model_to_strawberry_class(
            cls=strawberry_cls, model_instance=instance, extra=extra
        )

    def to_pydantic(self) -> Any:
        return model(**dataclasses.asdict(self))

    strawberry_cls.from_pydantic = staticmethod(from_pydantic)
    strawberry_cls.to_pydantic = to_pydantic
    return strawberry_cls
def _get_fields(cls: Type) -> List[StrawberryField]:
    """Get all the strawberry fields off a strawberry.type cls

    This function returns a list of StrawberryFields (one for each field item),
    while also paying attention to the name and typing of the field.

    StrawberryFields can be defined on a strawberry.type class as either a
    dataclass-style field or using strawberry.field as a decorator.

    >>> import strawberry
    >>> @strawberry.type
    ... class Query:
    ...     type_1a: int = 5
    ...     type_1b: int = strawberry.field(...)
    ...     type_1c: int = strawberry.field(resolver=...)
    ...
    ...     @strawberry.field
    ...     def type_2(self) -> int:
    ...         ...

    Type #1:
        A pure dataclass-style field. A plain attribute (like ``type_1a``)
        does not have a StrawberryField yet; one is created in this function.
        Type annotation is required.

    Type #2:
        A field defined using @strawberry.field as a decorator around the
        resolver. The resolver must be type-annotated.

    The StrawberryField.python_name value will be assigned to the field's name
    on the class if one is not set by either using an explicit
    strawberry.field(name=...) or by passing a named function (i.e. not an
    anonymous lambda) to strawberry.field (typically as a decorator).

    Fields inherited from bases with a ``_type_definition`` are included.
    A field redefined on ``cls`` replaces the inherited one with the same
    ``python_name``.

    Args:
        cls: A dataclass (``dataclasses.fields`` is called on it).

    Returns:
        The collected fields, in dict insertion order (inherited first).

    Raises:
        PrivateStrawberryFieldError: a strawberry.field is typed as Private.
        FieldWithResolverAndDefaultValueError: a field has both a resolver
            and a default.
        FieldWithResolverAndDefaultFactoryError: a field has both a resolver
            and a default_factory.
    """
    # Deferred import to avoid import cycles
    from strawberry.field import StrawberryField

    fields: Dict[str, StrawberryField] = {}

    # before trying to find any fields, let's first add the fields defined in
    # parent classes, we do this by checking if parents have a type definition
    for base in cls.__bases__:
        if hasattr(base, "_type_definition"):
            base_fields = {
                field.python_name: field
                # TODO: we need to rename _fields to something else
                for field in base._type_definition._fields  # type: ignore
            }

            # Add base's fields to cls' fields
            fields = {**fields, **base_fields}

    # Find the class each field was originally defined on so we can use
    # that scope later when resolving the type, as it may have different names
    # available to it.
    origins: Dict[str, type] = {field_name: cls for field_name in cls.__annotations__}

    # ``setdefault`` keeps the first match: names annotated on ``cls`` itself
    # are already seeded above, and otherwise the earliest class in the MRO
    # that annotates the field wins.
    for base in cls.__mro__:
        if hasattr(base, "_type_definition"):
            for field in base._type_definition._fields:  # type: ignore
                if field.python_name in base.__annotations__:
                    origins.setdefault(field.name, base)

    # then we can proceed with finding the fields for the current class
    for field in dataclasses.fields(cls):
        if isinstance(field, StrawberryField):
            # Check that the field type is not Private
            if isinstance(field.type, Private):
                raise PrivateStrawberryFieldError(field.python_name, cls.__name__)

            # Check that default is not set if a resolver is defined
            if (
                field.default is not dataclasses.MISSING
                and field.base_resolver is not None
            ):
                raise FieldWithResolverAndDefaultValueError(
                    field.python_name, cls.__name__
                )

            # Check that default_factory is not set if a resolver is defined
            # Note: using getattr because of this issue:
            # https://github.com/python/mypy/issues/6910
            if (
                getattr(field, "default_factory") is not dataclasses.MISSING  # noqa
                and field.base_resolver is not None
            ):
                raise FieldWithResolverAndDefaultFactoryError(
                    field.python_name, cls.__name__
                )

            # we make sure that the origin is either the field's resolver when
            # called as:
            #
            # >>> @strawberry.field
            # ... def x(self): ...
            #
            # or the class where this field was defined, so we always have
            # the correct origin for determining field types when resolving
            # the types.
            field.origin = field.origin or cls

            # Make sure types are StrawberryAnnotations; string annotations
            # are resolved against the globals of the origin's module.
            if not isinstance(field.type_annotation, StrawberryAnnotation):
                module = sys.modules[field.origin.__module__]
                field.type_annotation = StrawberryAnnotation(
                    annotation=field.type_annotation, namespace=module.__dict__
                )

        # Create a StrawberryField for fields that didn't use strawberry.field
        else:
            # Only ignore Private fields that weren't defined using StrawberryFields
            if isinstance(field.type, Private):
                continue

            field_type = field.type

            origin = origins.get(field.name, cls)
            module = sys.modules[origin.__module__]

            # Create a StrawberryField, for fields of Types #1 and #2a.
            # The default is read from the class attribute (UNSET if absent);
            # the right-hand side still refers to the dataclass field here.
            field = StrawberryField(
                python_name=field.name,
                graphql_name=None,
                type_annotation=StrawberryAnnotation(
                    annotation=field_type,
                    namespace=module.__dict__,
                ),
                origin=origin,
                default=getattr(cls, field.name, UNSET),
            )

        field_name = field.python_name

        assert_message = "Field must have a name by the time the schema is generated"
        assert field_name is not None, assert_message

        # TODO: Raise exception if field_name already in fields
        # (currently a later field with the same name silently overwrites)
        fields[field_name] = field

    return list(fields.values())
def wrap(cls):
    """Create a strawberry type for ``cls`` from the enclosing pydantic ``model``.

    This is a closure over ``model``, ``fields``, ``name``, ``is_input``,
    ``is_interface``, ``description`` and ``federation`` from the enclosing
    decorator call.

    Build steps:

    * The model fields named in ``fields`` come first.
    * They are followed by one field per annotation declared on ``cls``.
    * The result is built with ``dataclasses.make_dataclass`` (no bases are
      passed) and processed as a strawberry type.

    NOTE(review): no re-ordering happens here. Whether a required model field
    after an optional one trips dataclasses' "non-default argument follows
    default argument" check depends on how this StrawberryField maps
    ``default_value`` onto the dataclass default, which is not visible in
    this block.

    Raises:
        MissingFieldsListError: if ``fields`` is empty.
    """
    if not fields:
        raise MissingFieldsListError(model)

    model_fields = model.__fields__
    selected = set(fields)
    creation_specs = []

    for attr, model_field in model_fields.items():
        if attr not in selected:
            continue
        creation_specs.append(
            (
                attr,
                get_type_for_field(model_field),
                StrawberryField(
                    python_name=model_field.name,
                    graphql_name=model_field.alias if model_field.has_alias else None,
                    default_value=model_field.default if not model_field.required else UNSET,
                    default_factory=(
                        model_field.default_factory
                        if model_field.default_factory
                        else UNSET
                    ),
                    type_=get_type_for_field(model_field),
                ),
            )
        )

    for attr, annotation in getattr(cls, "__annotations__", {}).items():
        # Extra fields get a None default. Pydantic Optional fields always
        # default to None, and a field without a default cannot follow one
        # with a default in a dataclass. Prepending the extras instead would
        # only shift the problem once defaults are supported on them.
        creation_specs.append(
            (
                attr,
                annotation,
                StrawberryField(
                    python_name=attr,
                    graphql_name=None,
                    type_=annotation,
                    default_value=None,
                ),
            )
        )

    strawberry_cls = dataclasses.make_dataclass(cls.__name__, creation_specs)

    _process_type(
        strawberry_cls,
        name=name,
        is_input=is_input,
        is_interface=is_interface,
        description=description,
        federation=federation,
    )

    model._strawberry_type = strawberry_cls  # type: ignore

    def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
        return convert_pydantic_model_to_strawberry_class(
            cls=strawberry_cls, model_instance=instance, extra=extra
        )

    def to_pydantic(self) -> Any:
        return model(**dataclasses.asdict(self))

    strawberry_cls.from_pydantic = staticmethod(from_pydantic)
    strawberry_cls.to_pydantic = to_pydantic
    return strawberry_cls
def resolve_type_field(field: StrawberryField) -> None:
    """Resolve ``field.type`` in place and set the field's shape flags.

    Cases are handled in this order:

    1. String annotation: eval'd in the origin module's namespace.
    2. ``LazyType``: resolved to its target.
    3. Forward reference: looked up by name in the origin module's globals.
    4. ``AsyncGenerator``: unwrapped to its yield type, then re-resolved.
    5. ``Optional[X]``: ``is_optional``/``is_child_optional`` set, unwrapped,
       then re-resolved.
    6. ``List[X]``: element resolved into ``field.child``, ``is_list`` set,
       and ``field.type`` set to ``None``.
    7. ``Union[...]``: turned into a strawberry union (``NoneType`` dropped),
       ``is_union`` set.
    8. Generic strawberry type: copied with its concrete type arguments.

    Finally ``is_union`` is set if the result is a ``StrawberryUnion``.

    Raises:
        MissingTypesForGenericError: a generic strawberry type is used
            without type arguments.
        KeyError: a forward reference's name is missing from the origin
            module's globals (see TODO below).
    """
    # TODO: This should be handled by StrawberryType in the future
    if isinstance(field.type, str):
        module = sys.modules[field.origin.__module__]

        # NOTE(review): eval executes the annotation string. Presumably it
        # always comes from source code, never external input — confirm.
        field.type = eval(field.type, module.__dict__)

    if isinstance(field.type, LazyType):
        field.type = field.type.resolve_type()

    if is_forward_ref(field.type):
        # if the type is a forward reference we try to resolve the type by
        # finding it in the global namespace of the module where the field
        # was initially declared. This will break when the type is not declared
        # in the main scope, but we don't want to support that use case
        # see https://mail.python.org/archives/list/[email protected]/thread/SNKJB2U5S74TWGDWVD6FMXOP63WVIGDR/  # noqa: E501

        type_name = field.type.__forward_arg__

        module = sys.modules[field.origin.__module__]

        # TODO: we should probably raise an error if we can't find the type
        field.type = module.__dict__[type_name]

        return

    if is_async_generator(field.type):
        # TODO: shall we raise a warning if field is not used in a subscription?

        # async generators are used in subscription, we only need the yield type
        # https://docs.python.org/3/library/typing.html#typing.AsyncGenerator
        field.type = get_async_generator_annotation(field.type)

        return resolve_type_field(field)

    # check for Optional[A] which is represented as Union[A, None], we
    # have an additional check for proper unions below
    if is_optional(field.type) and len(field.type.__args__) == 2:
        # this logics works around List of optionals and Optional lists of Optionals:
        # >>> Optional[List[Str]]
        # >>> Optional[List[Optional[Str]]]
        # the field is only optional if it is not a list or if it was already optional
        # since we mark the child as optional when the field is a list
        # (precedence: parsed as ``(not field.is_list) or field.is_optional``;
        # the leading ``True and`` is a no-op)
        field.is_optional = True and not field.is_list or field.is_optional
        field.is_child_optional = field.is_list

        field.type = get_optional_annotation(field.type)

        return resolve_type_field(field)

    elif is_list(field.type):
        # The element type is resolved on a separate child field; the list
        # field itself keeps no type of its own.
        child_field = StrawberryField(
            python_name=None,
            graphql_name=None,
            origin=field.origin,  # type: ignore
            type_=get_list_annotation(field.type),
        )
        resolve_type_field(child_field)

        field.is_list = True
        field.child = child_field

        # TODO: Fix StrawberryField.type typing
        field.type = typing.cast(type, None)

        return

    # case for Union[A, B, C], it also handles Optional[Union[A, B, C]] as optionals
    # type hints are represented as Union[..., None].
    elif is_union(field.type):
        # Optional[Union[A, B]] is represented as Union[A, B, None] so we need
        # to check again if the field is optional as the check above only checks
        # for single Optionals
        field.is_optional = is_optional(field.type)

        types = field.type.__args__

        # we use a simplified version of resolve_type since unions in GraphQL
        # are simpler and cannot contain lists or optionals
        # (``None.__class__`` is NoneType, i.e. the ``None`` member of the union)
        types = tuple(
            _resolve_generic_type(t, field.python_name)
            for t in types
            if t is not None.__class__
        )

        field.is_union = True

        # TODO: Fix StrawberryField.type typing
        strawberry_union = typing.cast(
            type, union(get_name_from_types(types), types)
        )

        field.type = strawberry_union

    # case for Type[A], we want to convert generics to have the concrete types
    # when we pass them, so that we don't have to deal with generics when
    # generating the GraphQL types later on.
    elif (
        hasattr(field.type, "_type_definition")
        and field.type._type_definition.is_generic
    ):
        args = get_args(field.type)

        # raise an error when using generics without passing any type parameter, ie:
        # >>> class X(Generic[T]): ...
        # >>> a: X
        # instead of
        # >>> a: X[str]
        if len(args) == 0:
            raise MissingTypesForGenericError(field.python_name, field.type)

        # we only make a copy when all the arguments are not type vars
        if not all(is_type_var(a) for a in args):
            field.type = copy_type_with(field.type, *args)

    if isinstance(field.type, StrawberryUnion):
        field.is_union = True