Exemplo n.º 1
0
def find_field(field, fields_dict):
    """Look up the entry of ``fields_dict`` matching a GraphQL selection.

    The selection's name is tried verbatim first; when that key is absent the
    ``to_snake_case`` form of the name is tried instead.

    Args:
        field: Selection node exposing ``name.value``.
        fields_dict: Mapping of field names to field objects.

    Returns:
        The matching value, or ``None`` when neither spelling is a key.
    """
    raw_name = field.name.value
    # The fallback is computed eagerly, exactly like a ``dict.get`` default.
    snake_match = fields_dict.get(to_snake_case(raw_name), None)
    return fields_dict.get(raw_name, snake_match)
Exemplo n.º 2
0
def to_snake_case_plus(name: str) -> str:
    """Extend `to_snake_case` so that runs of digits are separated by underscores.

    After snake-casing, every maximal run of digits is prefixed with an
    underscore (for instance a snake-cased ``line21`` becomes ``line_21``).

    Args:
        name (str): The camel-case name to be converted to snake-case.

    Returns:
        str: The converted name.
    """
    name_new = to_snake_case(name=name)
    # One regex pass treats each digit run independently. Replacing each
    # distinct number with str.replace would also rewrite substrings of longer
    # numbers (e.g. "1" inside "21"), with results depending on set order.
    return re.sub(r"\d+", r"_\g<0>", name_new)
Exemplo n.º 3
0
    def get_queryset(cls, info, *_, **__):
        """Return all model rows, joining or prefetching the relations the query selects.

        Each requested field name is snake-cased. Names listed by
        ``select_foreign_keys()`` or ``select_o2o_related_objects()`` are added
        with ``select_related``. Names listed by ``select_m2m_fields()`` are
        added with ``prefetch_related``. Extra positional and keyword arguments
        are ignored.
        """
        queryset = cls._meta.model.objects.all()
        joinable = cls.select_foreign_keys() + cls.select_o2o_related_objects()
        prefetchable = cls.select_m2m_fields()
        requested = cls.convert_selections_to_fields(cls.get_selections(info), info)

        for name in map(to_snake_case, requested):
            if name in joinable:
                queryset = queryset.select_related(name)
            if name in prefetchable:
                queryset = queryset.prefetch_related(name)
        return queryset
Exemplo n.º 4
0
def recursive_params(selection_set, fragments, available_related_fields,
                     select_related, prefetch_related):
    """Collect relation names to ``select_related`` / ``prefetch_related``.

    Walks ``selection_set`` and follows named fragments and inline fragments.
    A selected field can match an entry of ``available_related_fields``, by its
    GraphQL name or its ``to_snake_case`` name. For each match the relation's
    ``name`` is recorded once: in ``prefetch_related`` for many-to-many and
    one-to-many relations, in ``select_related`` otherwise. Selections that
    match nothing but have their own selection set are descended into.

    Args:
        selection_set: Node with a ``selections`` sequence.
        fragments: Mapping of fragment name to a definition exposing
            ``selection_set``, or a falsy value.
        available_related_fields: Mapping of field name to a relation object
            exposing ``name``, ``many_to_many`` and ``one_to_many``.
        select_related: List extended in place.
        prefetch_related: List extended in place.

    Returns:
        tuple: The same ``select_related`` and ``prefetch_related`` list
        objects that were passed in.
    """
    for field in selection_set.selections:

        if isinstance(field, FragmentSpread) and fragments:
            # The recursive call extends both lists in place, so its return
            # value needs no merging.
            recursive_params(fragments[field.name.value].selection_set,
                             fragments, available_related_fields,
                             select_related, prefetch_related)
            continue

        if isinstance(field, InlineFragment):
            recursive_params(field.selection_set, fragments,
                             available_related_fields, select_related,
                             prefetch_related)
            continue

        temp = available_related_fields.get(
            field.name.value,
            available_related_fields.get(to_snake_case(field.name.value),
                                         None))

        if temp:
            # Record each relation only once.
            if temp.name in prefetch_related or temp.name in select_related:
                continue
            if temp.many_to_many or temp.one_to_many:
                prefetch_related.append(temp.name)
            else:
                select_related.append(temp.name)
        elif getattr(field, 'selection_set', None):
            recursive_params(field.selection_set, fragments,
                             available_related_fields, select_related,
                             prefetch_related)

    return select_related, prefetch_related
Exemplo n.º 5
0
def test_create_comment(schema, author, success):
    """Run the createComment mutation against a mocked meeting API.

    On success, every returned field must match the fake comment. UUIDs are
    compared as strings, and ``author`` is compared by its type-specific
    field. On failure, the API error text must appear in the GraphQL errors.
    In both cases the API must be called exactly once with the meeting id,
    the context user and the content text.
    """
    context = MagicMock()
    comment = next(fake_comment())
    error = Err(fake.pystr())
    query = """
    mutation createComment($meetingId: UUID!, $contentText: String!) {
        createComment(meetingId: $meetingId, contentText: $contentText) {
            commentId
            meetingId
            author {
                ... on Instructor {
                    userName
                }
                ... on Student {
                    studentNumber
                }
            }
            timeStamp
            contentText
        }
    }
    """
    meeting_api = context.api.meeting_api
    meeting_api.create_comment.return_value = Ok(comment) if success else error
    variables = {
        "meetingId": str(comment.meeting_id),
        "contentText": comment.content_text,
    }
    result = schema.execute(query, context=context, variables=variables)

    if not success:
        assert result.errors
        assert error.unwrap_err() in str(result.errors)
    else:
        assert not result.errors
        for key, val in result.data["createComment"].items():
            if key != "author":
                expected = getattr(comment, to_snake_case(key))
                if isinstance(expected, UUID):
                    expected = str(expected)
                assert expected == val
            elif isinstance(comment.author, Instructor):
                assert val["userName"] == comment.author.user_name
            else:
                assert val["studentNumber"] == comment.author.student_number

    meeting_api.create_comment.assert_called_once_with(
        comment.meeting_id, context.user, comment.content_text)
Exemplo n.º 6
0
 def set_fields_and_attrs(
     klazz: Type[ObjectType],
     node_model: Type[SQLAlchemyInterface],
     field_dict: Mapping[str, Field],
 ):
     """Register ``node_model`` on ``klazz`` as a single-node field and a filtered connection.

     Two entries are added, each both to ``field_dict`` and as an attribute
     of ``klazz``:

     * ``<name>`` -> ``node_model.Field()``
     * ``all_<pluralize_name(name)>`` -> ``SQLAlchemyFilteredConnectionField(node_model)``

     where ``name`` is ``to_snake_case(node_model.__name__)``. The dict and
     the class each receive their own field instances.
     """
     single_name = to_snake_case(node_model.__name__)
     connection_name = f"all_{pluralize_name(single_name)}"

     field_dict[connection_name] = SQLAlchemyFilteredConnectionField(node_model)
     field_dict[single_name] = node_model.Field()

     setattr(klazz, single_name, node_model.Field())
     setattr(klazz, connection_name, SQLAlchemyFilteredConnectionField(node_model))
Exemplo n.º 7
0
 def get_reference_objects(*args, **kwargs):
     """Load the referenced documents of one batch, restricted to the queried fields.

     ``args[0]`` is indexed positionally. Item 0 is passed to ``get_document``
     (presumably a document class name; confirm with the caller). Item 1 holds
     the primary keys to load. ``args[0][3][0]`` is passed to
     ``get_query_fields``. Returns ``[]`` when item 1 is falsy.
     """
     batch = args[0]
     primary_keys = batch[1]
     if not primary_keys:
         return []

     document = get_document(batch[0])
     converted = convert_mongoengine_field(
         mongoengine.ReferenceField(document), registry)
     type_name = converted.get_type().type._meta.name
     requested = get_query_fields(batch[3][0])[type_name]
     only_fields = [to_snake_case(key) for key in requested.keys()]

     query = document.objects().no_dereference().only(*only_fields)
     return query.filter(pk__in=primary_keys)
Exemplo n.º 8
0
 def reference_resolver(root, *args, **kwargs):
     """Fetch the document referenced by ``root``'s field, loading only needed fields.

     The loaded fields are the graphene type's ``only_fields`` (when it is a
     comma-separated string) plus the snake-cased fields queried for that
     type. Returns ``None`` when the reference attribute is falsy.
     """
     dereferenced = getattr(root, field.name or field.db_name)
     if not dereferenced:
         return None

     document = get_document(dereferenced["_cls"])
     _type = convert_mongoengine_field(
         mongoengine.ReferenceField(document), registry).get_type().type

     declared_only = _type._meta.only_fields
     if isinstance(declared_only, str):
         only_fields = declared_only.split(",")
     else:
         only_fields = list()
     queried = [
         to_snake_case(key)
         for key in get_query_fields(args[0])[_type._meta.name].keys()
     ]

     selected = set(only_fields + queried)
     query = document.objects().no_dereference().only(*selected)
     return query.get(pk=dereferenced["_ref"].id)
def resolver_query(g_object, m_object, args, info, is_list=False, validator=None):
    """Run a query for ``m_object`` and convert the result to ``g_object``.

    Args:
        g_object: Graphene type the result is converted to.
        m_object: Model queried through ``do_query``.
        args: Query arguments; ``skip`` and ``limit`` are passed separately
            as special parameters.
        info: Resolve info; its requested fields (snake-cased) limit the query.
        is_list: Return every match instead of only the first.
        validator: Optional callable invoked as
            ``validator(m_object, fields, query, special_params)``.

    Returns:
        A list of converted objects (``[]`` when nothing matched) when
        ``is_list``; otherwise the converted first match, or ``None``.
    """
    fields = [to_snake_case(f) for f in get_fields(info)]
    query = parse_operators(args)
    special_params = {k: v for k, v in args.items() if k in ['skip', 'limit']}

    if validator:
        validator(m_object, fields, query, special_params)

    result = do_query(m_object, query, fields, special_params, is_list)

    # Test the result's truthiness once. On a queryset, bool() may itself
    # execute a query.
    if not result:
        return [] if is_list else None
    if is_list:
        return [mongo_to_graphene(obj, g_object, fields) for obj in result]
    return mongo_to_graphene(result.first(), g_object, fields)
Exemplo n.º 10
0
    def resolve_queryset(cls, connection, iterable, info, args, filtering_args,
                         filterset_class):
        """Resolve the parent queryset, then apply the optional ``orderby`` argument.

        ``orderby`` is popped from ``args`` after the parent resolver has run.
        Each entry is snake-cased before being passed to ``order_by``.
        """
        qs = super().resolve_queryset(connection, iterable, info, args,
                                      filtering_args, filterset_class)

        ordering = args.pop("orderby", None)
        if not ordering:
            return qs
        return qs.order_by(*map(to_snake_case, ordering))
Exemplo n.º 11
0
 def get_reference_objects(*args, **kwargs):
     """Load the referenced documents of one batch, restricted to queried fields.

     ``args[0]`` is indexed positionally. Item 0 is passed to ``get_document``
     (presumably a document class name; confirm with the caller). Item 1 holds
     the primary keys to load. ``args[0][3][0]`` is passed to
     ``get_query_fields``.

     A queried field is kept only if its snake-cased name is a document field
     or one of the type's ``<field>__<operator>`` filter arguments. The type's
     ``required_fields`` are always loaded.
     """
     batch = args[0]
     document = get_document(batch[0])
     document_field = convert_mongoengine_field(
         mongoengine.ReferenceField(document), registry)
     document_field_type = document_field.get_type().type

     filter_args = []
     if document_field_type._meta.filter_fields:
         for key, values in document_field_type._meta.filter_fields.items():
             filter_args.extend(key + "__" + each for each in values)

     # Build the allowed-name lookup once instead of concatenating tuples
     # for every queried field.
     selectable = set(document._fields_ordered) | set(filter_args)
     requested = get_query_fields(batch[3][0])[document_field_type._meta.name]
     queried_fields = [
         item for item in map(to_snake_case, requested.keys())
         if item in selectable
     ]

     return document.objects().no_dereference().only(
         *set(list(document_field_type._meta.required_fields) + queried_fields)
     ).filter(pk__in=batch[1])
Exemplo n.º 12
0
def gen_mutation(model, graphene_schema, operators_mutation, fields_mutation, mutate_func, validator):
    """Generate a ``graphene.Mutation`` subclass that creates ``model`` instances.

    The generated class follows the pattern from
    http://docs.graphene-python.org/en/latest/types/mutations/::

        class CreatePerson(graphene.Mutation):
            class Input:
                name = graphene.String()

            ok = graphene.Boolean()
            person = graphene.Field(lambda: Person)

            @staticmethod
            def mutate(root, args, context, info):
                person = Person(name=args.get('name'))
                ok = True
                return CreatePerson(person=person, ok=ok)

    Details of the generated class:

    * It is named ``'Create' + model.__name__``.
    * Its ``Arguments`` come from ``operators_mutation``.
    * Its single output field is named after the snake-cased model name and
      holds the object converted by ``mongo_to_graphene``.
    * A truthy ``validator`` is called as ``validator(model, kwargs, {}, {})``
      first.
    * If ``mutate_func`` is given, it produces the object and must return a
      ``model`` instance, else ``TypeError`` is raised. Otherwise the object
      is ``model(**kwargs)``, saved.
    """
    output_name = to_snake_case(model.__name__)

    def build_payload(obj):
        # Convert the object and wrap it in the generated mutation class.
        converted = mongo_to_graphene(obj, graphene_schema, fields_mutation)
        return Create(**{output_name: converted})

    def user_mutate(root, info, **kwargs):
        if validator:
            validator(model, kwargs, {}, {})

        obj = mutate_func(kwargs, info.context)
        if isinstance(obj, model):
            return build_payload(obj)
        raise TypeError('Failed to resolve mutation of the schema {}'
                        ' because mutate function must return a instance of {}, and the return type was {}.'
                        .format(graphene_schema.__name__, model.__name__, type(obj)))

    def generic_mutate(root, info, **kwargs):
        if validator:
            validator(model, kwargs, {}, {})

        obj = model(**kwargs)
        obj.save()
        return build_payload(obj)

    Create = type('Create' + model.__name__, (graphene.Mutation,), {
        'Arguments': type('Arguments', (), operators_mutation),
        output_name: graphene.Field(lambda: graphene_schema),
        'mutate': staticmethod(user_mutate if mutate_func else generic_mutate),
    })

    return Create
Exemplo n.º 13
0
def order_queryset(qs, node_class, order_by):
    """Order ``qs`` by a single client-supplied field name.

    Args:
        qs: Queryset to order.
        node_class: Class whose ``ORDERABLE_FIELDS`` lists the permitted field
            names (compared after ``to_snake_case``).
        order_by: Field name, optionally prefixed with ``-`` for descending
            order. ``None`` or an empty string leaves ``qs`` unchanged.

    Returns:
        The ordered queryset, or ``qs`` itself when no ordering is requested.

    Raises:
        ValueError: If the field is not in ``node_class.ORDERABLE_FIELDS``.
    """
    if not order_by:
        # An empty string would otherwise fail on order_by[0] with IndexError.
        return qs

    orderable_fields = node_class.ORDERABLE_FIELDS
    desc = '-' if order_by.startswith('-') else ''
    field_name = to_snake_case(order_by[len(desc):])
    if field_name not in orderable_fields:
        raise ValueError(
            'Only orderable fields are: %s' %
            ', '.join([to_camel_case(x) for x in orderable_fields]))
    return qs.order_by(desc + field_name)
Exemplo n.º 14
0
    def get_queryset(self, root, field_name, field_asts, fragments, **kwargs):
        """Return objects already prefetched on ``root``, or build a filtered queryset.

        When nothing truthy is prefetched under the snake-cased field name:

        * a queryset is built by ``queryset_factory`` for the type registered
          for ``self.type._meta.model``;
        * it is filtered by ``self.filterset_class`` using the ``kwargs``
          listed in ``self.filtering_args``;
        * it is passed through ``self.post_optimize`` when that is set;
        * it is wrapped with ``maybe_queryset``.
        """
        cached = get_prefetched_attr(root, to_snake_case(field_name))
        if cached:
            return cached

        filter_data = {
            name: value
            for name, value in kwargs.items()
            if name in self.filtering_args
        }
        node_type = registry.get_type_for_model(self.type._meta.model)
        queryset = queryset_factory(node_type, field_asts, fragments, **kwargs)
        queryset = self.filterset_class(data=filter_data, queryset=queryset).qs

        if self.post_optimize:
            queryset = self.post_optimize(queryset, **kwargs)

        return maybe_queryset(queryset)
Exemplo n.º 15
0
 def lazy_resolver(root, *args, **kwargs):
     """Load the lazily referenced document of ``root``, restricted to queried fields.

     A queried field is kept only if its snake-cased name is a field of the
     referenced document type or one of the graphene type's
     ``<field>__<operator>`` filter arguments. The type's ``required_fields``
     are always loaded. Returns ``None`` when the reference attribute is falsy.
     """
     document = getattr(root, field.name or field.db_name)
     if not document:
         return None

     document_type = document.document_type
     _type = registry.get_type_for_model(document_type)

     filter_args = []
     if _type._meta.filter_fields:
         for key, values in _type._meta.filter_fields.items():
             filter_args.extend(key + "__" + each for each in values)

     # Build the allowed-name lookup once instead of concatenating tuples
     # for every queried field.
     selectable = set(document_type._fields_ordered) | set(filter_args)
     queried_fields = [
         item for item in map(to_snake_case, get_query_fields(args[0]).keys())
         if item in selectable
     ]

     return document_type.objects().no_dereference().only(
         *set(list(_type._meta.required_fields) + queried_fields)).get(
         pk=document.pk)
Exemplo n.º 16
0
    def default_resolver(self, _root, info, only_fields=None, **args):
        """Resolve a connection over ``self.model``.

        Args:
            _root: Parent object, or ``None``. When given, the parent's
                attribute named after the snake-cased field is read. If that
                attribute is not ``None``, results are limited to the ``id``
                values of the objects it contains.
            info: Resolve info for the field being resolved.
            only_fields: Forwarded to ``self.get_queryset``. A new empty list
                is used when omitted.
            **args: Remaining query arguments. ``first``, ``last``, ``before``
                and ``after`` are removed and used for slicing. ``id`` is
                decoded with ``from_global_id`` into a ``pk`` filter. The rest
                is forwarded to ``self.get_queryset``.

        Returns:
            The connection built by ``connection_from_list_slice``, with
            ``iterable`` and ``list_length`` attributes attached.
        """
        if only_fields is None:
            # Fresh list per call; a list() default is shared by every call.
            only_fields = []

        if _root is not None:
            field_name = to_snake_case(info.field_name)
            related = getattr(_root, field_name, [])
            if related is not None:
                args["pk__in"] = [r.id for r in related]

        connection_args = {
            "first": args.pop("first", None),
            "last": args.pop("last", None),
            "before": args.pop("before", None),
            "after": args.pop("after", None),
        }

        _id = args.pop('id', None)

        if _id is not None:
            args['pk'] = from_global_id(_id)[-1]

        if callable(getattr(self.model, "objects", None)):
            iterables = self.get_queryset(self.model, info, only_fields,
                                          **args)
            if isinstance(info, ResolveInfo):
                if not info.context:
                    info.context = Context()
                # Store the queryset on the request context.
                info.context.queryset = iterables
            list_length = iterables.count()
        else:
            iterables = []
            list_length = 0

        connection = connection_from_list_slice(
            list_slice=iterables,
            args=connection_args,
            list_length=list_length,
            list_slice_length=list_length,
            connection_type=self.type,
            edge_type=self.type.Edge,
            pageinfo_type=graphene.PageInfo,
        )
        connection.iterable = iterables
        connection.list_length = list_length
        return connection
Exemplo n.º 17
0
async def get_node_by_id(root, info, **args):
    """Resolver for returning job, task, family node.

    The id is read from ``root``. ``source_node`` and ``target_node`` read the
    ``source`` and ``target`` attributes; any other field reads the attribute
    matching its snake-cased name. A truthy value found there replaces
    ``args['id']``. Returns ``None`` when no id is available.
    """
    field_name = to_snake_case(info.field_name)
    attr_name = {
        'source_node': 'source',
        'target_node': 'target',
    }.get(field_name, field_name)

    field_id = getattr(root, attr_name, None)
    if field_id:
        args['id'] = field_id
    if args.get('id', None) is None:
        return None

    try:
        obj_type = str(info.return_type.of_type).replace('!', '')
    except AttributeError:
        obj_type = str(info.return_type)

    resolvers = info.context.get('resolvers')
    return await resolvers.get_node_by_id(NODE_MAP[obj_type], args)
Exemplo n.º 18
0
def recursive_params(selection_set, fragments, available_related_fields,
                     select_related, prefetch_related):
    """Collect relation names to ``select_related`` / ``prefetch_related``.

    Walks ``selection_set`` and follows named fragments. A selected field can
    match an entry of ``available_related_fields``, by its GraphQL name or its
    ``to_snake_case`` name. For each match the relation's ``name`` is recorded
    once: in ``prefetch_related`` for many-to-many and one-to-many relations,
    in ``select_related`` otherwise. Other selections that carry a selection
    set are descended into.

    Args:
        selection_set: Node with a ``selections`` sequence.
        fragments: Mapping of fragment name to a definition exposing
            ``selection_set``, or a falsy value.
        available_related_fields: Mapping of field name to a relation object
            exposing ``name``, ``many_to_many`` and ``one_to_many``.
        select_related: List extended in place.
        prefetch_related: List extended in place.

    Returns:
        tuple: The same ``select_related`` and ``prefetch_related`` list
        objects that were passed in.
    """
    for field in selection_set.selections:

        if isinstance(field, FragmentSpread) and fragments:
            # The recursive call extends both lists in place, so its return
            # value needs no merging.
            recursive_params(
                fragments[field.name.value].selection_set,
                fragments,
                available_related_fields,
                select_related, prefetch_related
            )
            continue

        name_node = getattr(field, 'name', None)
        if name_node is None:
            # Nameless selections (presumably inline fragments) have no field
            # to match. Descend into their selection set instead of failing
            # on ``field.name``.
            if getattr(field, 'selection_set', None):
                recursive_params(
                    field.selection_set,
                    fragments,
                    available_related_fields,
                    select_related,
                    prefetch_related
                )
            continue

        temp = available_related_fields.get(
            name_node.value,
            available_related_fields.get(
                to_snake_case(name_node.value),
                None)
        )

        if temp:
            # Record each relation only once.
            if temp.name in prefetch_related or temp.name in select_related:
                continue
            if temp.many_to_many or temp.one_to_many:
                prefetch_related.append(temp.name)
            else:
                select_related.append(temp.name)
        elif getattr(field, 'selection_set', None):
            recursive_params(
                field.selection_set,
                fragments,
                available_related_fields,
                select_related,
                prefetch_related
            )

    return select_related, prefetch_related
def test_gen_mutation(mock_person):
    """gen_mutation builds a named graphene.Mutation with the model's output field and arguments."""
    import graphene
    from graphene.utils.str_converters import to_snake_case
    from graphene.types.field import Field

    from graphene_mongodb.mutation import gen_mutation
    from graphene_mongodb.model import ModelSchema

    schema = ModelSchema(mock_person, mock_person._fields, None, None)
    mutation = gen_mutation(mock_person, schema.schema, schema.operators_mutation,
                            schema.fields_mutation, None, None)

    assert issubclass(mutation, graphene.Mutation)
    assert hasattr(mutation, 'mutate')

    meta = mutation._meta
    assert meta.name == 'Create' + mock_person.__name__
    output_field = meta.fields[to_snake_case(mock_person.__name__)]
    assert isinstance(output_field, Field)

    assert meta.arguments == schema.operators_mutation
Exemplo n.º 20
0
    def list_resolver(self, manager, filterset_class, filtering_args, root,
                      info, **kwargs):
        """Resolve the list field into a ``CustomDjangoListObjectBase``.

        Where the rows come from:

        * ``root.<self.accessor>.all()`` when ``self.accessor`` is set;
        * otherwise ``self.get_queryset``, further filtered by
          ``get_extra_filters`` when ``root`` is truthy and
          ``is_valid_django_model(root._meta.model)`` holds.

        Either way the rows pass through ``filterset_class`` with the
        ``kwargs`` named in ``filtering_args``. ``count`` is taken before
        pagination. When a pagination object is configured, the requested
        (or default) ordering is snake-cased and the queryset is paginated.
        ``page`` and ``pageSize`` are reported only when the pagination object
        has a ``page`` attribute.
        """
        filter_kwargs = {
            k: v
            for k, v in kwargs.items() if k in filtering_args
        }
        if self.accessor:
            qs = getattr(root, self.accessor).all()
            qs = filterset_class(data=filter_kwargs,
                                 queryset=qs,
                                 request=info.context).qs
        else:
            qs = self.get_queryset(manager, info, **kwargs)
            qs = filterset_class(data=filter_kwargs,
                                 queryset=qs,
                                 request=info.context).qs
            if root and is_valid_django_model(root._meta.model):
                extra_filters = get_extra_filters(root, manager.model)
                qs = qs.filter(**extra_filters)
        count = qs.count()

        # ``pagination`` is optional on this field. Read it once so the
        # page/pageSize checks below do not raise AttributeError when absent.
        pagination = getattr(self, "pagination", None)
        if pagination:
            ordering = kwargs.pop(pagination.ordering_param,
                                  None) or pagination.ordering
            if ordering:
                # Strip stray commas and spaces, then snake-case each entry.
                # Skipped when no ordering exists (None has no .strip()).
                ordering = ','.join([
                    to_snake_case(each)
                    for each in ordering.strip(',').replace(' ', '').split(',')
                ])
            # NOTE(review): this overwrites the pagination object's ordering,
            # which then applies to later calls too; confirm this is intended.
            pagination.ordering = ordering
            qs = pagination.paginate_queryset(qs, **kwargs)

        has_pages = hasattr(pagination, 'page')
        return CustomDjangoListObjectBase(
            count=count,
            results=maybe_queryset(qs),
            results_field_name=self.type._meta.results_field_name,
            page=kwargs.get('page', 1) if has_pages else None,
            pageSize=kwargs.get(
                'pageSize', graphql_api_settings.DEFAULT_PAGE_SIZE)
            if has_pages else None)
Exemplo n.º 21
0
async def get_nodes_by_id(root, info, **args):
    """Resolver for returning job, task, family nodes.

    Argument handling:

    * A non-``None`` ``args['id']`` becomes ``args['ids'] = [id]``.
    * Ids found on ``root`` under the snake-cased field name are passed as
      ``native_ids``; a single string is wrapped in a list.
    * If that attribute is an empty list, ``[]`` is returned immediately.
    * ``ids`` and ``exids`` (treated as empty when absent) are parsed with
      ``parse_node_id`` for the node type of the field's return type.
    """
    field_name = to_snake_case(info.field_name)
    field_ids = getattr(root, field_name, None)
    # ``args`` is a dict, so test for the key; an attribute check never matches.
    if args.get('id') is not None:
        args['ids'] = [args['id']]
    if field_ids:
        if isinstance(field_ids, str):
            field_ids = [field_ids]
        args['native_ids'] = field_ids
    elif field_ids == []:
        return []
    try:
        obj_type = str(info.return_type.of_type).replace('!', '')
    except AttributeError:
        obj_type = str(info.return_type)
    node_type = NODE_MAP[obj_type]
    args['ids'] = [
        parse_node_id(n_id, node_type) for n_id in args.get('ids') or []]
    args['exids'] = [
        parse_node_id(n_id, node_type) for n_id in args.get('exids') or []]
    resolvers = info.context.get('resolvers')
    return await resolvers.get_nodes_by_id(node_type, args)
Exemplo n.º 22
0
        def resolve_entities(parent, info, representations):
            """Build a model instance for every federated entity representation.

            ``__typename`` selects the model from ``custom_entities``. The
            remaining keys become constructor keyword arguments, snake-cased
            when ``auto_camelcase`` is set. If the model has a
            ``_<typename>__resolve_reference`` attribute, it is called with the
            instance and ``info`` and its result is returned instead.
            """
            resolved = []
            for representation in representations:
                typename = representation["__typename"]
                model = custom_entities[typename]
                model_arguments = {
                    key: value
                    for key, value in representation.items()
                    if key != "__typename"
                }
                # TODO: use the schema to identify the correct mapping for field names
                if auto_camelcase:
                    model_arguments = {
                        to_snake_case(key): value
                        for key, value in model_arguments.items()
                    }
                instance = model(**model_arguments)

                try:
                    resolve_reference = getattr(
                        model, "_%s__resolve_reference" % typename)
                except AttributeError:
                    resolved.append(instance)
                    continue
                resolved.append(resolve_reference(instance, info))
            return resolved
Exemplo n.º 23
0
    def __init_subclass_with_meta__(
        cls,
        _meta=None,
        model=None,
        permissions=None,
        login_required=None,
        only_fields=(),
        exclude_fields=(),
        return_field_name=None,
        **kwargs,
    ):
        """Configure a single-object delete mutation for ``model``.

        The mutation takes a required ``id`` argument. It exposes ``found``,
        ``deleted_input_id``, ``deleted_id`` and ``deleted_raw_id`` output
        fields. ``return_field_name`` defaults to the snake-cased model class
        name. ``only_fields`` and ``exclude_fields`` are accepted but not used
        here. ``login_required`` becomes truthy whenever ``permissions`` is
        non-empty.

        Raises:
            AssertionError: If the global registry has no type for ``model``
                (only while assertions are enabled).
        """
        model_type = get_global_registry().get_type_for_model(model)
        assert model_type, f"Model type must be registered for model {model}"

        return_field_name = return_field_name or to_snake_case(model.__name__)

        arguments = OrderedDict(id=graphene.ID(required=True))
        output_fields = OrderedDict([
            ("found", graphene.Boolean()),
            ("deleted_input_id", graphene.ID()),
            ("deleted_id", graphene.ID()),
            ("deleted_raw_id", graphene.ID()),
        ])

        _meta = _meta if _meta is not None else DjangoDeleteMutationOptions(cls)
        _meta.model = model
        _meta.model_type = model_type
        _meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field)
        _meta.return_field_name = return_field_name
        _meta.permissions = permissions
        _meta.login_required = login_required or (
            _meta.permissions and len(_meta.permissions) > 0
        )

        super().__init_subclass_with_meta__(arguments=arguments, _meta=_meta, **kwargs)
Exemplo n.º 24
0
def resolve_entities(_, info, *, representations):
    """Resolve federation ``_entities`` representations in per-type batches.

    Representations are grouped by ``__typename``. Each group is turned into
    model instances (keys snake-cased) and passed to the model's
    ``_<typename>__resolve_references(batch, info)``. Types whose model has no
    such attribute are skipped. Resolved entities are returned in the order
    of their ``representations``, not grouped by type.

    Raises:
        GraphQLError: If ``settings.FEDERATED_QUERY_MAX_ENTITIES`` is truthy
            and more representations than that are requested.
    """
    max_representations = settings.FEDERATED_QUERY_MAX_ENTITIES
    if max_representations and len(representations) > max_representations:
        representations_count = len(representations)
        raise GraphQLError(
            f"Federated query exceeded entity limit: {representations_count} "
            f"items requested over {max_representations}.")

    resolvers = {}
    for representation in representations:
        typename = representation["__typename"]
        if typename not in resolvers:
            try:
                model = federated_entities[typename]
                resolvers[typename] = getattr(
                    model, "_%s__resolve_references" % typename)
            except AttributeError:
                pass

    batches = defaultdict(list)
    # Index in ``representations`` of each batched instance, so results can
    # be put back in request order.
    positions = defaultdict(list)
    for index, representation in enumerate(representations):
        model = federated_entities[representation["__typename"]]
        model_arguments = representation.copy()
        typename = model_arguments.pop("__typename")
        model_arguments = {
            to_snake_case(k): v
            for k, v in model_arguments.items()
        }
        batches[typename].append(model(**model_arguments))
        positions[typename].append(index)

    # Federation expects ``_entities`` results in the same order as
    # ``representations``. Place each batch's results back at their original
    # positions instead of concatenating the batches type by type.
    unresolved = object()
    ordered = [unresolved] * len(representations)
    for typename, batch in batches.items():
        if typename not in resolvers:
            continue
        resolved = resolvers[typename](batch, info)
        for index, entity in zip(positions[typename], resolved):
            ordered[index] = entity

    return [entity for entity in ordered if entity is not unresolved]
    def __init__(self, parent_model: 'Any', info: 'ResolveInfo',
                 graphql_args: dict):
        """
        Dataloader for SQLAlchemy model relations.

        The relation attribute is taken from ``parent_model`` under the
        snake-cased name of the field being resolved (``info.field_name``).

        Args:
            parent_model: Parent SQLAlchemy model.
            info: Graphene resolve info object.
            graphql_args: Request args: filters, sort, ...

        """
        super().__init__()
        self.info: 'ResolveInfo' = info
        self.graphql_args: dict = graphql_args

        self.parent_model: 'Any' = parent_model
        pk_field: str = self._get_model_pk_field_name(parent_model)
        self.parent_model_pk_field: str = pk_field

        relation_field: str = to_snake_case(info.field_name)
        self.model_relation_field: str = relation_field
        self.relation: 'Any' = getattr(parent_model, relation_field)
Exemplo n.º 26
0
    def __init_subclass_with_meta__(
            cls,
            _meta=None,
            model=None,
            permissions=None,
            login_required=None,
            only_fields=(),
            exclude_fields=(),
            return_field_name=None,
            **kwargs,
    ):
        """Configure a batch-delete mutation for ``model``.

        The mutation takes a required ``ids`` list argument. It exposes
        ``deletion_count``, ``deleted_ids`` and ``missed_ids`` output fields.
        ``return_field_name`` defaults to the snake-cased model class name.
        ``only_fields`` and ``exclude_fields`` are accepted but not used here.
        ``login_required`` becomes truthy whenever ``permissions`` is
        non-empty.
        """
        if not return_field_name:
            return_field_name = to_snake_case(model.__name__)

        arguments = OrderedDict(ids=graphene.List(graphene.ID, required=True))

        output_fields = OrderedDict()
        output_fields["deletion_count"] = graphene.Int()
        output_fields["deleted_ids"] = graphene.List(graphene.ID)
        output_fields["missed_ids"] = graphene.List(graphene.ID)

        if _meta is None:
            _meta = DjangoBatchDeleteMutationOptions(cls)

        _meta.model = model
        _meta.fields = yank_fields_from_attrs(output_fields,
                                              _as=graphene.Field)
        _meta.return_field_name = return_field_name
        _meta.permissions = permissions
        _meta.login_required = login_required or (_meta.permissions and
                                                  len(_meta.permissions) > 0)

        super().__init_subclass_with_meta__(arguments=arguments,
                                            _meta=_meta,
                                            **kwargs)
Exemplo n.º 27
0
    def list_resolver(self, filterset_class, filtering_args, root, info,
                      **kwargs):
        """Resolve the list field into a ``CustomDjangoListObjectBase``.

        Rows are read from ``root.<self.accessor>`` (calling ``.all()`` when
        available). They are filtered through ``filterset_class`` with the
        ``kwargs`` named in ``filtering_args``. ``count`` is taken before
        pagination.

        When a pagination object is configured:

        * the requested (or default) ordering is snake-cased and passed to the
          paginator under its ``ordering_param``;
        * an explicit ``pageSize`` of ``None`` is dropped.

        ``page`` and ``pageSize`` are reported only when the pagination object
        has a ``page`` attribute.
        """
        filter_kwargs = {
            k: v
            for k, v in kwargs.items() if k in filtering_args
        }
        qs = getattr(root, self.accessor)
        if hasattr(qs, 'all'):
            qs = qs.all()
        qs = filterset_class(data=filter_kwargs,
                             queryset=qs,
                             request=info.context).qs
        count = qs.count()

        # ``pagination`` is optional on this field. Read it once so the
        # page/pageSize checks below do not raise AttributeError when absent.
        pagination = getattr(self, "pagination", None)
        if pagination:
            ordering = kwargs.pop(pagination.ordering_param,
                                  None) or pagination.ordering
            if ordering:
                # Strip stray commas and spaces, then snake-case each entry.
                # Skipped when no ordering exists (None has no .strip()).
                ordering = ','.join([
                    to_snake_case(each)
                    for each in ordering.strip(',').replace(' ', '').split(',')
                ])
            # Drop an explicit pageSize of None so the ``kwargs.get`` default
            # below applies (and presumably the paginator's default too).
            if 'pageSize' in kwargs and kwargs['pageSize'] is None:
                kwargs.pop('pageSize')
            kwargs[pagination.ordering_param] = ordering
            qs = pagination.paginate_queryset(qs, **kwargs)

        has_pages = hasattr(pagination, 'page')
        return CustomDjangoListObjectBase(
            count=count,
            results=maybe_queryset(qs),
            results_field_name=self.type._meta.results_field_name,
            page=kwargs.get('page', 1) if has_pages else None,
            # TODO: Need to add cutoff to send max page size instead of requested
            pageSize=kwargs.get(
                'pageSize', graphql_api_settings.DEFAULT_PAGE_SIZE)
            if has_pages else None)
Exemplo n.º 28
0
 def reference_resolver(root, *args, **kwargs):
     """Fetch the document referenced by ``root``'s field, loading only queried fields.

     If the query requests fields for the referenced graphene type, the
     document is loaded with the type's ``required_fields``. It also gets
     every queried field whose snake-cased name is a document field or one of
     the type's ``<field>__<operator>`` filter arguments. If no fields are
     requested for that type, the document class from ``get_document`` is
     returned. Returns ``None`` when the reference attribute is falsy.
     """
     de_referenced = getattr(root, field.name or field.db_name)
     if not de_referenced:
         return None

     document = get_document(de_referenced["_cls"])
     document_field = convert_mongoengine_field(
         mongoengine.ReferenceField(document), registry)
     _type = document_field.get_type().type

     filter_args = []
     if _type._meta.filter_fields:
         for key, values in _type._meta.filter_fields.items():
             filter_args.extend(key + "__" + each for each in values)

     query_fields = get_query_fields(args[0])
     type_name = _type._meta.name
     # Check membership with the same key used for the lookup below.
     # ``_type.__name__`` differs from ``_meta.name`` whenever the type is
     # renamed, which caused a KeyError or a skipped query.
     if type_name not in query_fields:
         # NOTE(review): this returns the document class itself, not an
         # instance; confirm this is intended.
         return document

     selectable = set(document._fields_ordered) | set(filter_args)
     queried_fields = [
         item for item in map(to_snake_case, query_fields[type_name].keys())
         if item in selectable
     ]
     return document.objects().no_dereference().only(
         *set(list(_type._meta.required_fields) + queried_fields)).get(
         pk=de_referenced["_ref"].id)
Exemplo n.º 29
0
def build_schema(allowed_list: List[SchemaRestriction]):
    """Build a graphene schema exposing only the allowed queries/mutations.

    Args:
        allowed_list: Restrictions whose registered query and mutation
            classes (from ``_query_registry`` / ``_mutation_registry``)
            should be included.

    Returns:
        A ``graphene.Schema`` whose ``Query`` type inherits from every allowed
        query class.  When any mutations are allowed, it also has a
        ``Mutation`` type with one field per allowed mutation class, named by
        ``to_snake_case`` of the class name.
    """
    # A class registered under several allowed restrictions would otherwise
    # appear twice in Query's bases and raise "duplicate base class".
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # resulting MRO follows registration order.
    allowed_queries = tuple(dict.fromkeys(chain.from_iterable(
        query_lst
        for allow, query_lst in _query_registry.items()
        if allow in allowed_list)))
    allowed_mutations = list(dict.fromkeys(chain.from_iterable(
        mutation_lst
        for allow, mutation_lst in _mutation_registry.items()
        if allow in allowed_list)))

    class Query(*allowed_queries):
        pass

    if allowed_mutations:
        Mutation = type(
            "Mutation",
            (graphene.ObjectType, ),
            {
                to_snake_case(mutation.__name__): mutation.Field()
                for mutation in allowed_mutations
            },
        )

        return graphene.Schema(query=Query, mutation=Mutation)
    return graphene.Schema(query=Query)
Exemplo n.º 30
0
    def test_experiments(self):
        # Query the experiment list with an auth header, then check that
        # every returned field equals the str() of the matching attribute
        # on the factory-built experiment.
        user_email = "*****@*****.**"
        experiment = NimbusExperimentFactory.create_with_status(
            NimbusExperiment.Status.DRAFT)

        response = self.query(
            """
            query {
                experiments {
                    name
                    slug
                    publicDescription
                }
            }
            """,
            headers={settings.OPENIDC_EMAIL_HEADER: user_email},
        )
        self.assertEqual(response.status_code, 200)

        payload = json.loads(response.content)
        returned = payload["data"]["experiments"]
        self.assertEqual(len(returned), 1)

        first = returned[0]
        # Response keys are camelCase; model attributes are snake_case.
        for field_name, field_value in first.items():
            expected = str(getattr(experiment, to_snake_case(field_name)))
            self.assertEqual(field_value, expected)
Exemplo n.º 31
0
    def resolve(self, next_, root, info, **args):
        """Middleware resolver; handles field according to operation.

        For operations in ``STRIP_OPS`` this propagates ``STRIP_ARG`` from
        ancestor fields to their children via ``self.args_tree``.  When that
        flag is set, it returns ``None`` for fields that are unset on a
        protobuf-like ``root`` (one with ``ListFields``) or empty on a dict
        ``root``.  Otherwise it wraps the resolved value with ``null_setter``
        / ``self.async_null_setter``.  All other fields go straight to
        ``next_`` (through ``self.async_resolve`` for ``ASYNC_OPS`` or
        coroutine resolvers).

        Args:
            next_: The next resolver in the middleware chain.
            root: Parent value being resolved (protobuf-like message, dict,
                or other object).
            info: GraphQL resolve info for the current field.
            **args: Field arguments; ``STRIP_ARG`` may be added from the
                stored ancestor arguments.

        Returns:
            The resolved value, ``None`` for stripped fields, or the
            awaitable produced by the async helpers.
        """
        # GraphiQL introspection is 'query' but not async
        if getattr(info.operation.name, 'value', None) == 'IntrospectionQuery':
            return next_(root, info, **args)

        if info.operation.operation in STRIP_OPS:
            path_string = f'{info.path}'
            # Needed for child fields that resolve without args.
            # Store arguments of parents as leaves of schema tree from path
            # to respective field.
            # no need to regrow the tree on every subscription push/delta
            if args and path_string not in self.tree_paths:
                grow_tree(self.args_tree, info.path, args)
                self.tree_paths.add(path_string)
            if STRIP_ARG not in args:
                # Walk the stored args tree along this field's path. Each
                # ancestor that recorded STRIP_ARG overwrites the previous
                # value, so the deepest recorded one wins.
                branch = self.args_tree
                for section in info.path:
                    if section not in branch:
                        break
                    branch = branch[section]
                    # Only set if present on branch section
                    if 'leaves' in branch and STRIP_ARG in branch['leaves']:
                        args[STRIP_ARG] = branch['leaves'][STRIP_ARG]

            # Now flag empty fields as null for stripping
            if args.get(STRIP_ARG, False):
                field_name = to_snake_case(info.field_name)

                # Clear field set so recreated via first child field,
                # as path may be a parent.
                # Done here as parent may be in NODE_MAP
                if path_string in self.field_sets:
                    del self.field_sets[path_string]

                # Avoid using the protobuf default if field isn't set.
                if (hasattr(root, 'ListFields') and hasattr(root, field_name)
                        and get_type_str(info.return_type) not in NODE_MAP):

                    # Gather fields set in root. The set is cached per parent
                    # path and rebuilt whenever root's 'stamp' differs from
                    # the cached stamp.
                    parent_path_string = f'{info.path[:-1:]}'
                    stamp = getattr(root, 'stamp', '')
                    if (parent_path_string not in self.field_sets
                            or self.field_sets[parent_path_string]['stamp'] !=
                            stamp):
                        self.field_sets[parent_path_string] = {
                            'stamp': stamp,
                            'fields':
                            {field.name
                             for field, _ in root.ListFields()}
                        }

                    # Field is not among those set on root: resolve as null.
                    if (parent_path_string in self.field_sets
                            and field_name not in
                            self.field_sets[parent_path_string]['fields']):
                        return None
                # Do not resolve subfields of an empty type
                # by setting as null in parent/root.
                elif (isinstance(root, dict) and field_name in root):
                    field_value = root[field_name]
                    if (field_value in EMPTY_VALUES
                            or (hasattr(field_value, 'ListFields')
                                and not field_value.ListFields())):
                        return None
                # Resolve normally, then post-process the result with the
                # null setter (see null_setter / async_null_setter for what
                # exactly gets nulled).
                if (info.operation.operation in self.ASYNC_OPS
                        or iscoroutinefunction(next_)):
                    return self.async_null_setter(next_, root, info, **args)
                return null_setter(next_(root, info, **args))

        if (info.operation.operation in self.ASYNC_OPS
                or iscoroutinefunction(next_)):
            return self.async_resolve(next_, root, info, **args)
        return next_(root, info, **args)
Exemplo n.º 32
0
    async def subscribe_delta(self, root, info, args):
        """Delta subscription async generator.

        Async generator mapping the incoming protobuf deltas to
        yielded GraphQL subscription objects.

        Args:
            root: Root value, forwarded to the field's ``sub_resolver``.
            info: GraphQL resolve info. ``info.context['sub_id']`` is set to
                this subscription's id, and ``info.field_name`` (snake_cased)
                selects the entry in ``SUB_RESOLVERS``.
            args: Subscription arguments. The following keys are read:
                ``'workflows'`` (falling back to ``'ids'``) gives the
                workflows to follow; when empty, every workflow currently in
                ``delta_queues`` is followed. ``'ignore_interval'`` is
                compared against ``time()`` differences; a delta arriving
                within it of the last accepted delta is dropped.
                ``'initial_burst'``, when truthy, makes a workflow's first
                registration emit its whole data-store as an added delta.

        Yields:
            The workflow's delta store when no ``sub_resolver`` is registered
            for this field, otherwise each truthy ``sub_resolver`` result.
            A final ``None`` is yielded from the ``finally`` clause.
        """
        workflow_ids = set(args.get('workflows', args.get('ids', ())))
        sub_id = uuid4()
        info.context['sub_id'] = sub_id
        self.delta_store[sub_id] = {}
        delta_queues = self.data_store_mgr.delta_queues
        # A single queue for this subscription, registered under every
        # followed workflow in delta_queues[w_id][sub_id].
        deltas_queue = queue.Queue()
        try:
            # Iterate over the queue yielding deltas
            w_ids = workflow_ids
            sub_resolver = SUB_RESOLVERS.get(to_snake_case(info.field_name))
            interval = args['ignore_interval']
            old_time = time()
            while True:
                if not workflow_ids:
                    # Following all workflows: refresh the set and drop
                    # stored deltas for workflows that disappeared.
                    old_ids = w_ids
                    w_ids = set(delta_queues.keys())
                    for remove_id in old_ids.difference(w_ids):
                        if remove_id in self.delta_store[sub_id]:
                            del self.delta_store[sub_id][remove_id]
                for w_id in w_ids:
                    if w_id in self.data_store_mgr.data:
                        if sub_id not in delta_queues[w_id]:
                            delta_queues[w_id][sub_id] = deltas_queue
                            # On new yield workflow data-store as added delta
                            if args.get('initial_burst'):
                                delta_store = create_delta_store(
                                    workflow_id=w_id)
                                delta_store[DELTA_ADDED] = (
                                    self.data_store_mgr.data[w_id])
                                self.delta_store[sub_id][w_id] = delta_store
                                if sub_resolver is None:
                                    yield delta_store
                                else:
                                    result = await sub_resolver(
                                        root, info, **args)
                                    if result:
                                        yield result
                    elif w_id in self.delta_store[sub_id]:
                        del self.delta_store[sub_id][w_id]
                try:
                    # Non-blocking get; raises queue.Empty when nothing is
                    # waiting, handled by sleeping below.
                    w_id, topic, delta_store = deltas_queue.get(False)
                    if topic != 'shutdown':
                        new_time = time()
                        elapsed = new_time - old_time
                        # ignore deltas that are more frequent than interval.
                        # (The skipped delta is discarded, not merged.)
                        if elapsed <= interval:
                            continue
                        old_time = new_time
                    else:
                        delta_store['shutdown'] = True
                    self.delta_store[sub_id][w_id] = delta_store
                    if sub_resolver is None:
                        yield delta_store
                    else:
                        result = await sub_resolver(root, info, **args)
                        if result:
                            yield result
                except queue.Empty:
                    await asyncio.sleep(DELTA_SLEEP_INTERVAL)
        except (GeneratorExit, asyncio.CancelledError):
            # Let closing/cancellation propagate to the caller.
            raise
        except Exception:
            # Any other error ends the subscription after logging it.
            import traceback
            logger.warning(traceback.format_exc())
        finally:
            # Unregister this subscription's queue and stored deltas.
            for w_id in w_ids:
                if delta_queues.get(w_id, {}).get(sub_id):
                    del delta_queues[w_id][sub_id]
            if sub_id in self.delta_store:
                del self.delta_store[sub_id]
            # NOTE(review): yielding inside ``finally`` is unsafe when the
            # generator is being closed via ``aclose()`` (GeneratorExit):
            # Python raises RuntimeError("async generator ignored
            # GeneratorExit"). Confirm this final None is only meant for the
            # error/normal-exit path.
            yield None
Exemplo n.º 33
0
 def get_name(cls):
     """Return the class name, minus every 'GraphQLDirective', in snake_case."""
     stripped = cls.__name__.replace('GraphQLDirective', '')
     return to_snake_case(stripped)
Exemplo n.º 34
0
 def resolve(value, directive, root, info, **kwargs):
     """Snake-case ``value`` after title-casing it and removing spaces.

     Non-string values are converted with ``str`` first.  For example,
     ``"hello world"`` becomes ``"HelloWorld"`` before ``to_snake_case``
     is applied.
     """
     if not isinstance(value, six.string_types):
         value = str(value)
     # Splitting on a single space and re-joining drops every space.
     camel_case = ''.join(value.title().split(' '))
     return to_snake_case(camel_case)