예제 #1
0
파일: type.py 프로젝트: tbarnier/strawberry
    def wrap(cls):
        """Turn ``cls`` into a strawberry type backed by the pydantic ``model``.

        ``fields``, ``model``, ``name``, ``is_input``, ``is_interface``,
        ``description`` and ``federation`` are captured from the enclosing
        scope (not visible here).

        The generated dataclass contains:

        * every field of ``model.__fields__`` whose name is listed in
          ``fields``, typed via ``get_type_for_field`` and exposed under the
          camel-cased pydantic alias;
        * every annotation found on ``cls`` (via ``__annotations__``), exposed
          under the camel-cased attribute name. An annotation that shares its
          name with a selected model field replaces that model field, rather
          than producing a duplicate dataclass field.

        The new class is registered with ``_process_type``, stored on
        ``model._strawberry_type``, and gets ``from_pydantic`` (static) and
        ``to_pydantic`` helpers attached.

        Raises:
            MissingFieldsListError: if ``fields`` is empty/falsy.
        """
        if not fields:
            raise MissingFieldsListError(model)

        model_fields = model.__fields__
        fields_set = set(fields)

        # Field specs keyed by field name. Using a dict lets an explicit
        # annotation on ``cls`` override the model-derived entry; passing both
        # to make_dataclass would raise "Field name duplicated".
        # NOTE: loop variables are deliberately not called ``name`` -- that
        # identifier is the captured type name passed to _process_type below.
        field_specs = {
            field_name: (
                field_name,
                get_type_for_field(field),
                dataclasses.field(default=strawberry.field(
                    name=to_camel_case(field.alias))),
            )
            for field_name, field in model_fields.items()
            if field_name in fields_set
        }

        cls_annotations = getattr(cls, "__annotations__", {})
        for field_name, type_ in cls_annotations.items():
            field_specs[field_name] = (
                field_name,
                type_,
                dataclasses.field(default=strawberry.field(
                    name=to_camel_case(field_name))),
            )

        # Every spec has a default, so field order cannot trip dataclasses'
        # "non-default argument follows default argument" check.
        cls = dataclasses.make_dataclass(
            cls.__name__,
            list(field_specs.values()),
        )

        _process_type(
            cls,
            name=name,
            is_input=is_input,
            is_interface=is_interface,
            description=description,
            federation=federation,
        )

        model._strawberry_type = cls  # type: ignore

        def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
            """Build a ``cls`` instance from a pydantic ``instance``.

            ``extra`` (may be None) is forwarded unchanged to
            ``convert_pydantic_model_to_strawberry_class``.
            """
            return convert_pydantic_model_to_strawberry_class(
                cls=cls, model_instance=instance, extra=extra)

        def to_pydantic(self) -> Any:
            """Build a ``model`` instance from this object's dataclass fields."""
            # asdict recurses into nested dataclasses, producing plain dicts
            # that are passed to the model constructor as keyword arguments.
            instance_kwargs = dataclasses.asdict(self)

            return model(**instance_kwargs)

        cls.from_pydantic = staticmethod(from_pydantic)
        cls.to_pydantic = to_pydantic

        return cls
예제 #2
0
def get_field_name(field_name):
    """Return ``field_name`` if it must be kept as an explicit name, else "".

    ``field_name`` is passed through ``str_converters.to_snake_case`` and
    ``str_converters.to_camel_case``. When the camel-case form differs from
    both the snake-case form and ``field_name`` itself, ``field_name`` is
    returned unchanged; otherwise the empty string is returned.
    """
    snake_name = str_converters.to_snake_case(field_name)
    camel_name = str_converters.to_camel_case(field_name)

    camel_differs = camel_name != snake_name and camel_name != field_name
    if camel_differs:
        return field_name
    return ""
예제 #3
0
def _get_fields(cls: Type) -> List[FieldDefinition]:
    """Collect the FieldDefinitions that make up the strawberry type ``cls``.

    The result contains, in order:

    1. one entry per dataclass field of ``cls`` and per attribute in
       ``cls.__dict__`` that carries a ``_field_definition`` (fields declared
       with the ``@strawberry.field`` decorator);
    2. the fields of each direct base class that has a ``_type_definition``,
       skipping any whose ``origin_name`` was already collected -- including
       one already contributed by an earlier base, so a field reachable through
       several bases (diamond inheritance) is only listed once.
    """
    fields = []

    # get all the fields from the dataclass
    dataclass_fields = dataclasses.fields(cls)

    # plus the fields that are defined with the resolvers, using
    # the @strawberry.field decorator
    dataclass_fields += tuple(
        field for field in cls.__dict__.values() if hasattr(field, "_field_definition")
    )

    seen_fields = set()

    for field in dataclass_fields:
        if hasattr(field, "_field_definition"):
            field_definition = field._field_definition  # type: ignore

            # we make sure that the origin is either the field's resolver
            # when called as:
            # >>> @strawberry.field
            # >>> def x(self): ...
            # or the class where this field was defined, so we always have
            # the correct origin for determining field types when resolving
            # the types.

            field_definition.origin = field_definition.origin or cls
        else:
            # for fields that don't have a field definition, we create one
            # based on the dataclass field

            field_definition = FieldDefinition(
                origin_name=field.name,
                name=to_camel_case(field.name),
                type=field.type,
                origin=cls,
            )

        fields.append(field_definition)
        seen_fields.add(field_definition.origin_name)

    # let's also add fields that are declared with @strawberry.field in
    # parent classes, we do this by checking if parents have a type definition
    # and we haven't seen a field already

    # TODO: maybe we want to add a warning when overriding a field, as it might be
    # a mistake

    for base in cls.__bases__:
        if not hasattr(base, "_type_definition"):
            continue

        for field in base._type_definition.fields:  # type: ignore
            if field.origin_name in seen_fields:
                continue

            fields.append(field)
            # Record the name so a later base that exposes the same field
            # (e.g. both bases inherit it from a common ancestor) does not
            # add it a second time.
            seen_fields.add(field.origin_name)

    return fields
예제 #4
0
    def get_graphql_name(self, auto_camel_case: bool) -> str:
        """Return the name under which this object is exposed in GraphQL.

        An explicitly set ``graphql_name`` is returned as-is. Otherwise
        ``python_name`` (asserted to be truthy) is used, passed through
        ``to_camel_case`` when ``auto_camel_case`` is true.
        """
        explicit_name = self.graphql_name
        if explicit_name is not None:
            return explicit_name

        assert self.python_name

        if not auto_camel_case:
            return self.python_name
        return to_camel_case(self.python_name)
예제 #5
0
    def get_graphql_name(self, obj: HasGraphQLName) -> str:
        """Return the GraphQL name to expose for ``obj``.

        ``obj.graphql_name`` wins whenever it is not None. Otherwise
        ``obj.python_name`` (asserted to be truthy) is used, camel-cased via
        ``to_camel_case`` when this object's ``auto_camel_case`` flag is set.
        """
        explicit_name = obj.graphql_name
        if explicit_name is not None:
            return explicit_name

        assert obj.python_name

        if not self.auto_camel_case:
            return obj.python_name
        return to_camel_case(obj.python_name)
예제 #6
0
    def _wrap(f):
        """Attach a ``DirectiveDefinition`` to ``f`` and return ``f`` itself.

        The directive is named after the captured ``name`` when it is truthy,
        otherwise after the camel-cased ``f.__name__``. ``name``,
        ``locations`` and ``description`` come from the enclosing scope (not
        visible here); ``f`` becomes the directive's resolver.
        """
        if name:
            directive_name = name
        else:
            directive_name = to_camel_case(f.__name__)

        definition = DirectiveDefinition(
            name=directive_name,
            locations=locations,
            description=description,
            resolver=f,
        )
        f.directive_definition = definition

        return f
예제 #7
0
def _get_fields(cls: Type) -> List[FieldDefinition]:
    """Get all the strawberry field definitions off a strawberry.type cls

    This function returns a list of FieldDefinitions (one for each field item),
    without duplicates, while also paying attention to the name and typing of
    the field.

    Strawberry fields can be defined on a strawberry.type class in 4 different
    ways:

    >>> import strawberry
    >>> @strawberry.type
    ... class Query:
    ...     type_1: int = 5
    ...     type_2a: int = strawberry.field(...)
    ...     type_2b: int = strawberry.field(resolver=...)
    ...     @strawberry.field
    ...     def type_2c(self) -> int:
    ...         ...

    Type #1:
        A pure dataclass-style field. Will not have a FieldDefinition; one will
        need to be created in this function. Type annotation is required.

    Type #2a:
        A field defined using strawberry.field as a function, but without
        supplying a resolver. Again, a FieldDefinition will need to be created
        in this function. Type annotation is required.

    Type #2b:
        A field defined using strawberry.field as a function, with a supplied
        resolver.

        The type hint is optional, but if supplied, it must match the return
        type of the resolver. Type annotation is required if resolver is not
        type annotated; if both are annotated, they must match.

        Implementation note: If both type annotations are provided, there will
        be a redundant Type #1-style entry in the dataclass' field list. If the
        strawberry.field call does not specify a `name`, the name of the field
        on the class will be used.

    Type #2c:
        A field defined using @strawberry.field as a decorator around the
        resolver. The resolver must be type-annotated.

    Final `name` attribute priority:
    1. Type #2 `name` attribute. This will be defined with an explicit
       strawberry.field(name=...). No camelcase-ification will be done, as the
       user has explicitly stated the field name.
    2. Field name on the cls. Will exist for all fields other than Type #2c.
       Field names will be converted to camelcase.

    Every definition taken from a strawberry.field ends up with an
    ``origin_name``: if the strawberry.field left it unset, the attribute name
    on the class is used, so that each field gets its own key in the result.

    Raises:
        MissingFieldAnnotationError: a strawberry.field without a resolver has
            no matching annotated dataclass field.
        MissingReturnAnnotationError: a resolver-backed field has no type.
        PrivateStrawberryFieldError: a strawberry.field is typed as Private.
    """
    field_definitions: Dict[str, FieldDefinition] = {}

    # before trying to find any fields, let's first add the fields defined in
    # parent classes, we do this by checking if parents have a type definition
    for base in cls.__bases__:
        if hasattr(base, "_type_definition"):
            base_field_definitions = {
                field.origin_name: field
                # TODO: we need to rename _fields to something else
                for field in base._type_definition._fields  # type: ignore
            }

            # Add base's field definitions to cls' field definitions
            field_definitions = {**field_definitions, **base_field_definitions}

    # then we can proceed with finding the fields for the current class

    # type #1 fields: every dataclass field, keyed by attribute name
    type_1_fields: Dict[str, dataclasses.Field] = {
        field.name: field
        for field in dataclasses.fields(cls)
    }

    # type #2 fields: class attributes carrying a strawberry field definition.
    # NOTE(review): these objects are not necessarily dataclasses.Field
    # instances despite the annotation below.
    type_2_fields: Dict[str, dataclasses.Field] = {}
    for field_name, field in cls.__dict__.items():
        if hasattr(field, "_field_definition"):
            type_2_fields[field_name] = field

    for field_name, field in type_2_fields.items():
        field_definition: FieldDefinition = field._field_definition

        # Check if there is a matching type #1 field:
        if field_name in type_1_fields:
            # Make sure field and resolver types are the same if both are
            # defined
            # TODO: https://github.com/strawberry-graphql/strawberry/issues/396
            # >>> assert field.type == resolver.type

            # Grab the type from the field if the resolver has no type
            if field_definition.type is None:
                field_type = type_1_fields[field_name].type
                field_definition.type = field_type

            # Stop tracking the type #1 field, an explicit strawberry.field was
            # defined
            type_1_fields.pop(field_name)

        # Otherwise, ensure that a resolver has been specified
        else:
            if field_definition.base_resolver is None:
                # This should be caught by _wrap_dataclass in type.py, but just
                # in case, we'll check again
                raise MissingFieldAnnotationError(field_name)

            # resolver with @strawberry.field decorator must be typed
            if field_definition.type is None:
                resolver_name = field_definition.base_resolver.__name__
                raise MissingReturnAnnotationError(resolver_name)

    all_fields = {**type_1_fields, **type_2_fields}

    for field_name, field in all_fields.items():
        if hasattr(field, "_field_definition"):
            # Use the existing FieldDefinition
            field_definition = field._field_definition

            # Check that the field type is not Private. Report the attribute
            # name: type #2 objects taken from cls.__dict__ are not guaranteed
            # to carry a ``.name`` of their own.
            if isinstance(field_definition.type, Private):
                raise PrivateStrawberryFieldError(field_name, cls.__name__)

            if not field_definition.name:
                field_definition.name = to_camel_case(field_name)
                field_definition.origin_name = field_name
            elif not field_definition.origin_name:
                # An explicit strawberry.field(name=...) may leave origin_name
                # unset; without this fallback every such field would be stored
                # under the same (empty) key below and overwrite the others.
                field_definition.origin_name = field_name

            # we make sure that the origin is either the field's resolver when
            # called as:
            #
            # >>> @strawberry.field
            # ... def x(self): ...
            #
            # or the class where this field was defined, so we always have
            # the correct origin for determining field types when resolving
            # the types.
            field_definition.origin = field_definition.origin or cls

        else:
            # if the field doesn't have a field definition and has already been
            # processed, we skip the creation of the field definition, as it
            # seems dataclasses recreates the field in some cases when extending
            # other dataclasses.
            if field_name in field_definitions:
                continue

            if isinstance(field.type, Private):
                continue

            # Create a FieldDefinition, for fields of Types #1 and #2a
            field_definition = FieldDefinition(
                origin_name=field.name,
                name=to_camel_case(field.name),
                type=field.type,
                origin=cls,
                default_value=getattr(cls, field.name, undefined),
            )

        field_name = cast(str, field_definition.origin_name)
        field_definitions[field_name] = field_definition

    return list(field_definitions.values())
예제 #8
0
def _get_fields(cls: Type) -> List[StrawberryField]:
    """Get all the strawberry fields off a strawberry.type cls

    This function returns a list of StrawberryFields (one for each field item), while
    also paying attention to the name and typing of the field.

    StrawberryFields can be defined on a strawberry.type class as either a dataclass-
    style field or using strawberry.field as a decorator.

    >>> import strawberry
    >>> @strawberry.type
    ... class Query:
    ...     type_1a: int = 5
    ...     type_1b: int = strawberry.field(...)
    ...     type_1c: int = strawberry.field(resolver=...)
    ...
    ...     @strawberry.field
    ...     def type_2(self) -> int:
    ...         ...

    Type #1:
        A dataclass-style field. A plain field (type_1a) has no StrawberryField, so
        one is created in this function; fields declared with strawberry.field
        (type_1b, type_1c) already are StrawberryFields. Type annotation is required.

    Type #2:
        A field defined using @strawberry.field as a decorator around the resolver. The
        resolver must be type-annotated.

    The StrawberryField.python_name value will be assigned to the field's name on the
    class if one is not set by either using an explicit strawberry.field(name=...) or by
    passing a named function (i.e. not an anonymous lambda) to strawberry.field
    (typically as a decorator).

    Fields of direct base classes that have a ``_type_definition`` are collected
    first; fields of ``cls`` with the same ``graphql_name`` replace them.

    Raises:
        PrivateStrawberryFieldError: a StrawberryField is typed as Private.
        FieldWithResolverAndDefaultValueError: a field has both a resolver and a
            default value.
        FieldWithResolverAndDefaultFactoryError: a field has both a resolver and a
            default_factory.
    """
    # Deferred import to avoid import cycles
    from strawberry.field import StrawberryField

    fields: Dict[str, StrawberryField] = {}

    # before trying to find any fields, let's first add the fields defined in
    # parent classes, we do this by checking if parents have a type definition
    for base in cls.__bases__:
        if hasattr(base, "_type_definition"):
            base_fields = {
                field.graphql_name: field
                # TODO: we need to rename _fields to something else
                for field in base._type_definition._fields  # type: ignore
            }

            # Add base's fields to cls' fields
            fields = {**fields, **base_fields}

    # then we can proceed with finding the fields for the current class
    for field in dataclasses.fields(cls):

        if isinstance(field, StrawberryField):
            # Check that the field type is not Private
            if isinstance(field.type, Private):
                raise PrivateStrawberryFieldError(field.python_name, cls.__name__)

            # dataclasses.MISSING is a sentinel, so compare by identity: ``!=``
            # would dispatch to the default value's own ``__ne__``, which may not
            # return a plain bool (e.g. array-like defaults).

            # Check that default is not set if a resolver is defined
            if field.default is not dataclasses.MISSING and field.base_resolver is not None:
                raise FieldWithResolverAndDefaultValueError(
                    field.python_name, cls.__name__)

            # Check that default_factory is not set if a resolver is defined
            # Note: using getattr because of this issue:
            # https://github.com/python/mypy/issues/6910
            default_factory = getattr(field, "default_factory")
            if default_factory is not dataclasses.MISSING and field.base_resolver is not None:
                raise FieldWithResolverAndDefaultFactoryError(
                    field.python_name, cls.__name__)

            # we make sure that the origin is either the field's resolver when
            # called as:
            #
            # >>> @strawberry.field
            # ... def x(self): ...
            #
            # or the class where this field was defined, so we always have
            # the correct origin for determining field types when resolving
            # the types.
            field.origin = field.origin or cls

        # Create a StrawberryField for fields that didn't use strawberry.field
        else:
            # Only ignore Private fields that weren't defined using StrawberryFields
            if isinstance(field.type, Private):
                continue

            field_type = field.type

            # Create a StrawberryField for a plain dataclass-style field (type_1a)
            field = StrawberryField(
                python_name=field.name,
                graphql_name=to_camel_case(field.name),
                type_=field_type,
                origin=cls,
                default_value=getattr(cls, field.name, UNSET),
            )

        field_name = field.graphql_name

        assert_message = "Field must have a name by the time the schema is generated"
        assert field_name is not None, assert_message

        # TODO: Raise exception if field_name already in fields
        fields[field_name] = field

    return list(fields.values())
예제 #9
0
def _get_fields(cls: Type) -> List[FieldDefinition]:
    """Collect the FieldDefinitions that make up the strawberry type ``cls``.

    Definitions are gathered in two passes into a dict keyed by ``origin_name``
    (a later entry replaces an earlier one with the same key):

    1. The ``_type_definition._fields`` of each direct base class that has a
       ``_type_definition``.
    2. Each dataclass field of ``cls``:

       * a ``StrawberryField`` reuses its ``_field_definition``; the definition's
         ``origin`` falls back to ``cls`` and its ``origin_name`` is set to the
         dataclass field name;
       * any other field annotated as ``Private`` is skipped;
       * any other field gets a fresh FieldDefinition named with the camel-cased
         attribute name, defaulting to the class attribute value (``undefined``
         when the class has none).

    Supported declaration styles, for reference::

        @strawberry.type
        class Query:
            type_1a: int = 5
            type_1b: int = strawberry.field(...)
            type_1c: int = strawberry.field(resolver=...)

            @strawberry.field
            def type_2(self) -> int:
                ...

    Raises:
        PrivateStrawberryFieldError: a StrawberryField is typed as Private.
    """
    collected: Dict[str, FieldDefinition] = {}

    # Pass 1: inherit definitions from strawberry-typed direct bases.
    for base in cls.__bases__:
        if not hasattr(base, "_type_definition"):
            continue

        # TODO: we need to rename _fields to something else
        inherited = base._type_definition._fields  # type: ignore
        collected.update(
            {definition.origin_name: definition for definition in inherited}
        )

    # Deferred import to avoid import cycles
    from strawberry.field import StrawberryField

    # Pass 2: the dataclass fields declared on (or inherited by) cls itself.
    for dc_field in dataclasses.fields(cls):
        if isinstance(dc_field, StrawberryField):
            definition = dc_field._field_definition

            if isinstance(definition.type, Private):
                raise PrivateStrawberryFieldError(dc_field.name, cls.__name__)

            # Keep a resolver-provided origin; otherwise the declaring class
            # is the origin used when resolving the field's type.
            definition.origin = definition.origin or cls
            definition.origin_name = dc_field.name
        elif isinstance(dc_field.type, Private):
            # Private plain dataclass fields are not exposed at all.
            continue
        else:
            definition = FieldDefinition(
                origin_name=dc_field.name,
                name=to_camel_case(dc_field.name),
                type=dc_field.type,
                origin=cls,
                default_value=getattr(cls, dc_field.name, undefined),
            )

        key = definition.origin_name
        assert key is not None, "Field must have a name by the time the schema is generated"

        collected[key] = definition

    return list(collected.values())