def find_field(field, fields_dict):
    """Look up *field* in *fields_dict* by its raw GraphQL name, falling
    back to the snake_case form of the name when the raw key is absent.
    Returns None when neither key exists."""
    raw_name = field.name.value
    snake_fallback = fields_dict.get(to_snake_case(raw_name))
    return fields_dict.get(raw_name, snake_fallback)
def to_snake_case_plus(name: str) -> str:
    """Extends the `to_snake_case` function to account for numbers in the
    name separated by underscores.

    Args:
        name (str): The camel-case name to be converted to snake-case.

    Returns:
        str: The converted name.
    """
    name_new = to_snake_case(name=name)
    # BUG FIX: use a raw string for the regex -- "\d" in a plain string is
    # an invalid escape sequence (DeprecationWarning, SyntaxError in future
    # Python).  Also drop the redundant list() around the set.
    for entry in set(re.findall(r"\d+", name_new)):
        # Prefix every distinct digit run with an underscore.
        name_new = name_new.replace(entry, "_{}".format(entry))
    return name_new
def get_queryset(cls, info, *_, **__):
    """Build the model's base queryset, applying select_related for
    requested to-one relations and prefetch_related for requested
    many-to-many fields."""
    queryset = cls._meta.model.objects.all()
    to_one_fields = cls.select_foreign_keys() + cls.select_o2o_related_objects()
    m2m_fields = cls.select_m2m_fields()
    requested = cls.convert_selections_to_fields(cls.get_selections(info), info)
    for requested_name in requested:
        requested_name = to_snake_case(requested_name)
        if requested_name in to_one_fields:
            queryset = queryset.select_related(requested_name)
        if requested_name in m2m_fields:
            queryset = queryset.prefetch_related(requested_name)
    return queryset
def recursive_params(selection_set, fragments, available_related_fields,
                     select_related, prefetch_related):
    """Walk a GraphQL selection set and collect relation names:
    to-many relations go to *prefetch_related*, to-one relations to
    *select_related*.

    Both lists are mutated in place and also returned as a tuple.
    """

    def _merge(target, extra):
        # Order-preserving, duplicate-free append (replaces the original
        # side-effecting list comprehensions).
        for item in extra:
            if item not in target:
                target.append(item)

    for field in selection_set.selections:
        if isinstance(field, FragmentSpread) and fragments:
            a, b = recursive_params(
                fragments[field.name.value].selection_set, fragments,
                available_related_fields, select_related, prefetch_related)
            _merge(select_related, a)
            _merge(prefetch_related, b)
            continue
        if isinstance(field, InlineFragment):
            a, b = recursive_params(
                field.selection_set, fragments, available_related_fields,
                select_related, prefetch_related)
            _merge(select_related, a)
            _merge(prefetch_related, b)
            continue
        temp = available_related_fields.get(
            field.name.value,
            available_related_fields.get(to_snake_case(field.name.value), None))
        # BUG FIX: the original tested membership against
        # ``[prefetch_related + select_related]`` -- a one-element list whose
        # single element is the concatenated list -- so a field name could
        # never match and the duplicate guard was dead.  Test against the
        # concatenation itself.
        if temp and temp.name not in prefetch_related + select_related:
            if temp.many_to_many or temp.one_to_many:
                prefetch_related.append(temp.name)
            else:
                select_related.append(temp.name)
        elif getattr(field, 'selection_set', None):
            a, b = recursive_params(
                field.selection_set, fragments, available_related_fields,
                select_related, prefetch_related)
            _merge(select_related, a)
            _merge(prefetch_related, b)
    return select_related, prefetch_related
def test_create_comment(schema, author, success):
    """Exercise the createComment mutation for both the success and the
    error path of the mocked meeting API.

    Args:
        schema: GraphQL schema fixture under test.
        author: Comment-author fixture (Instructor or Student); injected
            indirectly via ``fake_comment`` -- not referenced directly.
        success (bool): Whether the mocked API call should succeed.
    """
    context = MagicMock()
    comment = next(fake_comment())
    error = Err(fake.pystr())
    query = """
    mutation createComment($meetingId: UUID!, $contentText: String!) {
        createComment(meetingId: $meetingId, contentText: $contentText) {
            commentId
            meetingId
            author {
                ... on Instructor {
                    userName
                }
                ... on Student {
                    studentNumber
                }
            }
            timeStamp
            contentText
        }
    }
    """
    context.api.meeting_api.create_comment.return_value = (Ok(comment)
                                                           if success
                                                           else error)
    variables = {
        "meetingId": str(comment.meeting_id),
        "contentText": comment.content_text,
    }
    result = schema.execute(query, context=context, variables=variables)
    if success:
        assert not result.errors
        for key, val in result.data["createComment"].items():
            if key == "author":
                # Author is a union type: check the field that matches the
                # concrete type of the faked author.
                if isinstance(comment.author, Instructor):
                    assert val["userName"] == comment.author.user_name
                else:
                    assert val[
                        "studentNumber"] == comment.author.student_number
            else:
                # Every other field mirrors the comment attribute of the
                # snake_case name; UUIDs serialize to strings.
                expected = getattr(comment, to_snake_case(key))
                if isinstance(expected, UUID):
                    expected = str(expected)
                assert expected == val
    else:
        assert result.errors
        assert error.unwrap_err() in str(result.errors)
    context.api.meeting_api.create_comment.assert_called_once_with(
        comment.meeting_id, context.user, comment.content_text)
def set_fields_and_attrs(
    klazz: Type[ObjectType],
    node_model: Type[SQLAlchemyInterface],
    field_dict: Mapping[str, Field],
):
    """Register singular and plural ("all_...") query fields for
    *node_model*, both in *field_dict* and as attributes on *klazz*."""
    singular = to_snake_case(node_model.__name__)
    plural = "all_{}".format(pluralize_name(singular))
    # Record the fields in the shared field dict ...
    field_dict[plural] = SQLAlchemyFilteredConnectionField(node_model)
    field_dict[singular] = node_model.Field()
    # ... and mirror them as attributes on the query class.
    setattr(klazz, singular, node_model.Field())
    setattr(klazz, plural, SQLAlchemyFilteredConnectionField(node_model))
def get_reference_objects(*args, **kwargs):
    """Batch-load referenced documents for a dataloader, projecting only
    the fields actually queried.

    NOTE(review): ``args[0]`` appears to be a packed tuple of
    (document class name, pks, ?, query-field info) -- confirm against
    the dataloader that builds it.
    """
    if args[0][1]:
        document = get_document(args[0][0])
        document_field = mongoengine.ReferenceField(document)
        document_field = convert_mongoengine_field(
            document_field, registry)
        document_field_type = document_field.get_type(
        ).type._meta.name
        # Restrict the mongo projection to the queried field names,
        # converted to snake_case.
        only_fields = [
            to_snake_case(i) for i in get_query_fields(
                args[0][3][0])[document_field_type].keys()
        ]
        return document.objects().no_dereference().only(
            *only_fields).filter(pk__in=args[0][1])
    else:
        return []
def reference_resolver(root, *args, **kwargs):
    """Resolve a lazy mongoengine reference field on *root*.

    Dereferences the stored reference without loading the whole document:
    the projection is limited to the type's declared only_fields plus the
    fields requested in the query.
    """
    dereferenced = getattr(root, field.name or field.db_name)
    if dereferenced:
        # ``_cls``/``_ref`` come from mongoengine's generic-reference
        # storage format.
        document = get_document(dereferenced["_cls"])
        document_field = mongoengine.ReferenceField(document)
        document_field = convert_mongoengine_field(document_field, registry)
        _type = document_field.get_type().type
        # only_fields is either a comma-separated string or unset.
        only_fields = _type._meta.only_fields.split(",") if isinstance(
            _type._meta.only_fields, str) else list()
        return document.objects().no_dereference().only(*list(
            set(only_fields + [
                to_snake_case(i) for i in get_query_fields(args[0])[
                    _type._meta.name].keys()
            ]))).get(pk=dereferenced["_ref"].id)
    return None
def resolver_query(g_object, m_object, args, info, is_list=False,
                   validator=None):
    """Run a mongo query for a Graphene resolver and map the result(s)
    onto the Graphene type *g_object*.

    Returns a list for list fields, a single object otherwise, and
    []/None respectively when nothing matches.
    """
    fields = [to_snake_case(name) for name in get_fields(info)]
    query = parse_operators(args)
    special_params = {key: args[key] for key in ('skip', 'limit')
                      if key in args}
    if validator:
        validator(m_object, fields, query, special_params)
    result = do_query(m_object, query, fields, special_params, is_list)
    if not result:
        return [] if is_list else None
    if is_list:
        return [mongo_to_graphene(obj, g_object, fields) for obj in result]
    return mongo_to_graphene(result.first(), g_object, fields)
def resolve_queryset(cls, connection, iterable, info, args, filtering_args,
                     filterset_class):
    """Resolve the base queryset and apply any ``orderby`` arguments
    (converted to snake_case) on top of it."""
    queryset = super().resolve_queryset(
        connection, iterable, info, args, filtering_args, filterset_class,
    )
    ordering = args.pop("orderby", None) or []
    if ordering:
        snake_fields = [to_snake_case(name) for name in ordering]
        queryset = queryset.order_by(*snake_fields)
    return queryset
def get_reference_objects(*args, **kwargs):
    """Batch-load referenced documents, projecting the type's required
    fields plus the queried (and filterable) fields.

    NOTE(review): ``args[0]`` appears to be a packed tuple of
    (document class name, pks, ?, query-field info) -- confirm against
    the caller.
    """
    document = get_document(args[0][0])
    document_field = mongoengine.ReferenceField(document)
    document_field = convert_mongoengine_field(document_field, registry)
    document_field_type = document_field.get_type().type
    queried_fields = list()
    filter_args = list()
    # Expand filter_fields into the key__op lookups mongoengine accepts.
    if document_field_type._meta.filter_fields:
        for key, values in document_field_type._meta.filter_fields.items():
            for each in values:
                filter_args.append(key + "__" + each)
    # Keep only queried names that are real document fields or filters.
    for each in get_query_fields(
            args[0][3][0])[document_field_type._meta.name].keys():
        item = to_snake_case(each)
        if item in document._fields_ordered + tuple(filter_args):
            queried_fields.append(item)
    return document.objects().no_dereference().only(
        *set(list(document_field_type._meta.required_fields) +
             queried_fields)).filter(pk__in=args[0][1])
def gen_mutation(model, graphene_schema, operators_mutation, fields_mutation,
                 mutate_func, validator):
    """ We need to create a class that seems as follows (http://docs.graphene-python.org/en/latest/types/mutations/):

    class CreatePerson(graphene.Mutation):
        class Input:
            name = graphene.String()

        ok = graphene.Boolean()
        person = graphene.Field(lambda: Person)

        @staticmethod
        def mutate(root, args, context, info):
            person = Person(name=args.get('name'))
            ok = True
            return CreatePerson(person=person, ok=ok)
    """
    def user_mutate(root, info, **kwargs):
        # Delegate creation to the user-supplied mutate function; it must
        # return an instance of *model*.
        if validator:
            validator(model, kwargs, {}, {})
        obj = mutate_func(kwargs, info.context)
        if not isinstance(obj, model):
            raise TypeError('Failed to resolve mutation of the schema {}'
                            ' because mutate function must return a instance of {}, and the return type was {}.'
                            .format(graphene_schema.__name__, model.__name__,
                                    type(obj)))
        graphene_obj = mongo_to_graphene(obj, graphene_schema, fields_mutation)
        return Create(**{to_snake_case(model.__name__): graphene_obj})

    def generic_mutate(root, info, **kwargs):
        # Default behaviour: build and save the document from the raw kwargs.
        if validator:
            validator(model, kwargs, {}, {})
        obj = model(**kwargs)
        obj.save()
        graphene_obj = mongo_to_graphene(obj, graphene_schema, fields_mutation)
        return Create(**{to_snake_case(model.__name__): graphene_obj})

    # Both closures reference ``Create``, bound below -- by the time either
    # mutate function runs, the name is defined.
    Create = type('Create' + model.__name__, (graphene.Mutation,), {
        'Arguments': type('Arguments', (), operators_mutation),
        to_snake_case(model.__name__): graphene.Field(lambda: graphene_schema),
        'mutate': staticmethod(generic_mutate) if not mutate_func
        else staticmethod(user_mutate)
    })
    return Create
def order_queryset(qs, node_class, order_by):
    """Order *qs* by the camelCase field *order_by* (optionally prefixed
    with '-' for descending); raise ValueError for non-orderable fields.
    A None *order_by* leaves the queryset untouched."""
    if order_by is None:
        return qs
    orderable_fields = node_class.ORDERABLE_FIELDS
    descending_prefix = ''
    if order_by[0] == '-':
        descending_prefix = '-'
        order_by = order_by[1:]
    order_by = to_snake_case(order_by)
    if order_by not in orderable_fields:
        allowed = ', '.join([to_camel_case(x) for x in orderable_fields])
        raise ValueError('Only orderable fields are: %s' % allowed)
    return qs.order_by(descending_prefix + order_by)
def get_queryset(self, root, field_name, field_asts, fragments, **kwargs):
    """Return the queryset for this connection field.

    Prefetched attributes on *root* win outright; otherwise the queryset
    is built, filtered through the filterset, and passed through the
    optional post-optimize hook.
    """
    prefetched = get_prefetched_attr(root, to_snake_case(field_name))
    if prefetched:
        return prefetched
    filter_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in self.filtering_args
    }
    base_qs = queryset_factory(
        registry.get_type_for_model(self.type._meta.model),
        field_asts, fragments, **kwargs)
    filtered_qs = self.filterset_class(data=filter_kwargs,
                                       queryset=base_qs).qs
    if self.post_optimize:
        filtered_qs = self.post_optimize(filtered_qs, **kwargs)
    return maybe_queryset(filtered_qs)
def lazy_resolver(root, *args, **kwargs):
    """Resolve a lazy reference on *root*, projecting only the type's
    required fields plus the fields actually queried (or filterable)."""
    document = getattr(root, field.name or field.db_name)
    if document:
        queried_fields = list()
        _type = registry.get_type_for_model(document.document_type)
        filter_args = list()
        # Expand filter_fields into the key__op lookups mongoengine accepts.
        if _type._meta.filter_fields:
            for key, values in _type._meta.filter_fields.items():
                for each in values:
                    filter_args.append(key + "__" + each)
        # Keep queried names that are real document fields or filter args.
        for each in get_query_fields(args[0]).keys():
            item = to_snake_case(each)
            if item in document.document_type._fields_ordered + tuple(
                    filter_args):
                queried_fields.append(item)
        return document.document_type.objects().no_dereference().only(
            *(set((list(_type._meta.required_fields) + queried_fields)))).get(
                pk=document.pk)
    return None
def default_resolver(self, _root, info, only_fields=None, **args):
    """Default connection resolver.

    Builds the queryset (restricted to the parent's related pks when
    resolving a nested list field) and wraps it in a Relay connection.

    Args:
        _root: Parent object, or None at the query root.
        info: Graphene resolve info.
        only_fields: Field names to restrict the query to; defaults to an
            empty list.
        **args: Filter and pagination arguments.
    """
    # BUG FIX: the original declared the mutable default ``only_fields=list()``,
    # which is a single shared list across every call; normalize from None.
    if only_fields is None:
        only_fields = list()
    args = args or {}
    if _root is not None:
        # Nested resolution: restrict to the pks referenced by the parent.
        field_name = to_snake_case(info.field_name)
        if getattr(_root, field_name, []) is not None:
            args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
    connection_args = {
        "first": args.pop("first", None),
        "last": args.pop("last", None),
        "before": args.pop("before", None),
        "after": args.pop("after", None),
    }
    _id = args.pop('id', None)
    if _id is not None:
        # Relay global IDs encode (type, pk); keep only the pk part.
        args['pk'] = from_global_id(_id)[-1]
    if callable(getattr(self.model, "objects", None)):
        iterables = self.get_queryset(self.model, info, only_fields, **args)
        if isinstance(info, ResolveInfo):
            if not info.context:
                info.context = Context()
            info.context.queryset = iterables
        list_length = iterables.count()
    else:
        iterables = []
        list_length = 0
    connection = connection_from_list_slice(
        list_slice=iterables,
        args=connection_args,
        list_length=list_length,
        list_slice_length=list_length,
        connection_type=self.type,
        edge_type=self.type.Edge,
        pageinfo_type=graphene.PageInfo,
    )
    connection.iterable = iterables
    connection.list_length = list_length
    return connection
async def get_node_by_id(root, info, **args):
    """Resolver for returning job, task, family node"""
    field_name = to_snake_case(info.field_name)
    # source_node/target_node read their id from dedicated root attributes.
    attr_overrides = {'source_node': 'source', 'target_node': 'target'}
    node_id = getattr(root, attr_overrides.get(field_name, field_name), None)
    if node_id:
        args['id'] = node_id
    if args.get('id') is None:
        return None
    try:
        obj_type = str(info.return_type.of_type).replace('!', '')
    except AttributeError:
        obj_type = str(info.return_type)
    resolvers = info.context.get('resolvers')
    return await resolvers.get_node_by_id(NODE_MAP[obj_type], args)
def recursive_params(selection_set, fragments, available_related_fields,
                     select_related, prefetch_related):
    """Walk a GraphQL selection set collecting relation names:
    to-many relations go to *prefetch_related*, to-one relations to
    *select_related*.

    Both lists are mutated in place and also returned as a tuple.
    """

    def _merge(target, extra):
        # Order-preserving, duplicate-free append (replaces the original
        # side-effecting list comprehensions).
        for item in extra:
            if item not in target:
                target.append(item)

    for field in selection_set.selections:
        if isinstance(field, FragmentSpread) and fragments:
            a, b = recursive_params(
                fragments[field.name.value].selection_set, fragments,
                available_related_fields, select_related, prefetch_related
            )
            _merge(select_related, a)
            _merge(prefetch_related, b)
            continue
        temp = available_related_fields.get(
            field.name.value,
            available_related_fields.get(
                to_snake_case(field.name.value), None)
        )
        # BUG FIX: the original checked membership against
        # ``[prefetch_related + select_related]`` -- a single-element list
        # containing the concatenated list -- so a field name could never
        # match and the duplicate guard was dead.  Check the concatenation
        # itself.
        if temp and temp.name not in prefetch_related + select_related:
            if temp.many_to_many or temp.one_to_many:
                prefetch_related.append(temp.name)
            else:
                select_related.append(temp.name)
        elif getattr(field, 'selection_set', None):
            a, b = recursive_params(
                field.selection_set, fragments, available_related_fields,
                select_related, prefetch_related
            )
            _merge(select_related, a)
            _merge(prefetch_related, b)
    return select_related, prefetch_related
def test_gen_mutation(mock_person):
    """gen_mutation should build a graphene.Mutation subclass named after
    the model, exposing the model field and the configured arguments."""
    import graphene
    from graphene.utils.str_converters import to_snake_case
    from graphene.types.field import Field

    from graphene_mongodb.mutation import gen_mutation
    from graphene_mongodb.model import ModelSchema

    model_schema = ModelSchema(mock_person, mock_person._fields, None, None)

    mutation_cls = gen_mutation(
        mock_person, model_schema.schema, model_schema.operators_mutation,
        model_schema.fields_mutation, None, None)

    assert issubclass(mutation_cls, graphene.Mutation)
    assert hasattr(mutation_cls, 'mutate')
    assert mutation_cls._meta.name == 'Create' + mock_person.__name__
    field_key = to_snake_case(mock_person.__name__)
    assert isinstance(mutation_cls._meta.fields[field_key], Field)
    assert mutation_cls._meta.arguments == model_schema.operators_mutation
def list_resolver(self, manager, filterset_class, filtering_args,
                  root, info, **kwargs):
    """Resolve a filtered (and optionally paginated) list of objects.

    When ``self.accessor`` is set, the queryset is read off *root*;
    otherwise it is built from *manager* and, for valid Django parents,
    narrowed by the implicit reverse-relation filters.
    """
    filter_kwargs = {
        k: v for k, v in kwargs.items() if k in filtering_args
    }
    if self.accessor:
        qs = getattr(root, self.accessor).all()
        qs = filterset_class(data=filter_kwargs, queryset=qs,
                             request=info.context).qs
    else:
        qs = self.get_queryset(manager, info, **kwargs)
        qs = filterset_class(data=filter_kwargs, queryset=qs,
                             request=info.context).qs
        if root and is_valid_django_model(root._meta.model):
            # Narrow by the relation between root and the managed model.
            extra_filters = get_extra_filters(root, manager.model)
            qs = qs.filter(**extra_filters)
    count = qs.count()
    if getattr(self, "pagination", None):
        # Normalize ordering: strip stray commas/spaces and convert each
        # field name to snake_case.
        ordering = kwargs.pop(self.pagination.ordering_param,
                              None) or self.pagination.ordering
        ordering = ','.join([
            to_snake_case(each)
            for each in ordering.strip(',').replace(' ', '').split(',')
        ])
        self.pagination.ordering = ordering
        qs = self.pagination.paginate_queryset(qs, **kwargs)
    return CustomDjangoListObjectBase(
        count=count,
        results=maybe_queryset(qs),
        results_field_name=self.type._meta.results_field_name,
        page=kwargs.get('page', 1) if hasattr(self.pagination,
                                              'page') else None,
        pageSize=kwargs.get(
            'pageSize', graphql_api_settings.DEFAULT_PAGE_SIZE) if hasattr(
                self.pagination, 'page') else None)
async def get_nodes_by_id(root, info, **args):
    """Resolver for returning job, task, family node"""
    field_name = to_snake_case(info.field_name)
    field_ids = getattr(root, field_name, None)
    # BUG FIX: the original used ``hasattr(args, 'id')``, which is always
    # False for a dict (dicts have no ``id`` attribute), so a single ``id``
    # argument was silently ignored.  Key membership is the intended check.
    if 'id' in args:
        args['ids'] = [args.get('id')]
    if field_ids:
        if isinstance(field_ids, str):
            field_ids = [field_ids]
        args['native_ids'] = field_ids
    elif field_ids == []:
        # Explicitly-empty relation: nothing to resolve.
        return []
    try:
        obj_type = str(info.return_type.of_type).replace('!', '')
    except AttributeError:
        obj_type = str(info.return_type)
    node_type = NODE_MAP[obj_type]
    args['ids'] = [parse_node_id(n_id, node_type) for n_id in args['ids']]
    args['exids'] = [parse_node_id(n_id, node_type) for n_id in args['exids']]
    resolvers = info.context.get('resolvers')
    return await resolvers.get_nodes_by_id(node_type, args)
def resolve_entities(parent, info, representations):
    """Instantiate each federated representation, delegating to the
    model's ``__resolve_reference`` hook when one is defined."""
    entities = []
    for representation in representations:
        model = custom_entities[representation["__typename"]]
        arguments = {
            key: value for key, value in representation.items()
            if key != "__typename"
        }
        # todo use schema to identify correct mapping for field names
        if auto_camelcase:
            arguments = {to_snake_case(k): v for k, v in arguments.items()}
        instance = model(**arguments)
        resolver_name = "_%s__resolve_reference" % representation["__typename"]
        reference_resolver = getattr(model, resolver_name, None)
        if reference_resolver is not None:
            instance = reference_resolver(instance, info)
        entities.append(instance)
    return entities
def __init_subclass_with_meta__(
    cls,
    _meta=None,
    model=None,
    permissions=None,
    login_required=None,
    only_fields=(),
    exclude_fields=(),
    return_field_name=None,
    **kwargs,
):
    """Configure the delete-mutation subclass for *model*.

    The model must already be registered in the global graphene registry.
    ``return_field_name`` defaults to the snake_case model name.
    """
    registry = get_global_registry()
    model_type = registry.get_type_for_model(model)
    assert model_type, f"Model type must be registered for model {model}"
    if not return_field_name:
        return_field_name = to_snake_case(model.__name__)

    arguments = OrderedDict(id=graphene.ID(required=True))

    output_fields = OrderedDict()
    output_fields["found"] = graphene.Boolean()
    output_fields["deleted_input_id"] = graphene.ID()
    output_fields["deleted_id"] = graphene.ID()
    output_fields["deleted_raw_id"] = graphene.ID()

    if _meta is None:
        _meta = DjangoDeleteMutationOptions(cls)

    _meta.model = model
    _meta.model_type = model_type
    _meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field)
    _meta.return_field_name = return_field_name
    _meta.permissions = permissions
    # Login is implicitly required whenever permissions are set.
    _meta.login_required = login_required or (
        _meta.permissions and len(_meta.permissions) > 0
    )

    super().__init_subclass_with_meta__(arguments=arguments, _meta=_meta,
                                        **kwargs)
def resolve_entities(_, info, *, representations):
    """Federation ``_entities`` resolver with batched reference resolution.

    Enforces FEDERATED_QUERY_MAX_ENTITIES, then groups representations by
    ``__typename`` and hands each batch to the type's
    ``__resolve_references`` classmethod; typenames without one are skipped.
    """
    max_representations = settings.FEDERATED_QUERY_MAX_ENTITIES
    if max_representations and len(representations) > max_representations:
        representations_count = len(representations)
        raise GraphQLError(
            f"Federated query exceeded entity limit: {representations_count} "
            f"items requested over {max_representations}.")

    # Look up each type's batch resolver once per typename.
    resolvers = {}
    for representation in representations:
        if representation["__typename"] not in resolvers:
            try:
                model = federated_entities[representation["__typename"]]
                resolvers[representation["__typename"]] = getattr(
                    model,
                    "_%s__resolve_references" % representation["__typename"])
            except AttributeError:
                pass

    # Build model instances from the representation fields (snake_cased)
    # and group them by typename.
    batches = defaultdict(list)
    for representation in representations:
        model = federated_entities[representation["__typename"]]
        model_arguments = representation.copy()
        typename = model_arguments.pop("__typename")
        model_arguments = {
            to_snake_case(k): v for k, v in model_arguments.items()
        }
        model_instance = model(**model_arguments)
        batches[typename].append(model_instance)

    entities = []
    for typename, batch in batches.items():
        if typename not in resolvers:
            continue
        resolver = resolvers[typename]
        entities.extend(resolver(batch, info))
    return entities
def __init__(self, parent_model: 'Any', info: 'ResolveInfo',
             graphql_args: dict):
    """
    Dataloader for SQLAlchemy model relations.

    Args:
        parent_model: Parent SQLAlchemy model.
        info: Graphene resolve info object.
        graphql_args: Request args: filters, sort, ...
    """
    super().__init__()
    self.info: 'ResolveInfo' = info
    self.graphql_args: dict = graphql_args
    self.parent_model: 'Any' = parent_model
    # Primary-key attribute name of the parent model (used for batching).
    self.parent_model_pk_field: str = self._get_model_pk_field_name(
        self.parent_model)
    # The snake_cased GraphQL field name names the relation attribute.
    self.model_relation_field: str = to_snake_case(self.info.field_name)
    self.relation: 'Any' = getattr(self.parent_model,
                                   self.model_relation_field)
def __init_subclass_with_meta__(
    cls,
    _meta=None,
    model=None,
    permissions=None,
    login_required=None,
    only_fields=(),
    exclude_fields=(),
    return_field_name=None,
    **kwargs,
):
    """Configure the batch-delete-mutation subclass for *model*.

    ``return_field_name`` defaults to the snake_case model name.
    """
    # NOTE(review): ``registry`` is fetched but never used here -- confirm
    # whether a model-type lookup was intended (compare the single-delete
    # variant).
    registry = get_global_registry()
    if not return_field_name:
        return_field_name = to_snake_case(model.__name__)

    arguments = OrderedDict(ids=graphene.List(graphene.ID, required=True))

    output_fields = OrderedDict()
    output_fields["deletion_count"] = graphene.Int()
    output_fields["deleted_ids"] = graphene.List(graphene.ID)
    output_fields["missed_ids"] = graphene.List(graphene.ID)

    if _meta is None:
        _meta = DjangoBatchDeleteMutationOptions(cls)

    _meta.model = model
    _meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field)
    _meta.return_field_name = return_field_name
    _meta.permissions = permissions
    # Login is implicitly required whenever permissions are set.
    _meta.login_required = login_required or (_meta.permissions and
                                              len(_meta.permissions) > 0)

    super().__init_subclass_with_meta__(arguments=arguments, _meta=_meta,
                                        **kwargs)
def list_resolver(self, filterset_class, filtering_args, root, info,
                  **kwargs):
    """Resolve a filtered, optionally paginated list reached through
    ``self.accessor`` on *root*, wrapped in a list-object container."""
    filter_kwargs = {
        k: v for k, v in kwargs.items() if k in filtering_args
    }
    qs = getattr(root, self.accessor)
    if hasattr(qs, 'all'):
        qs = qs.all()
    qs = filterset_class(data=filter_kwargs, queryset=qs,
                         request=info.context).qs
    count = qs.count()
    if getattr(self, "pagination", None):
        # Normalize ordering: strip stray commas/spaces and convert each
        # field name to snake_case.
        ordering = kwargs.pop(self.pagination.ordering_param,
                              None) or self.pagination.ordering
        ordering = ','.join([
            to_snake_case(each)
            for each in ordering.strip(',').replace(' ', '').split(',')
        ])
        # IDIOM FIX: the original dropped an explicit ``pageSize=None`` via
        # an ``a and b and kwargs.pop(...)`` expression statement; a plain
        # ``if`` states the same intent readably.
        if 'pageSize' in kwargs and kwargs['pageSize'] is None:
            kwargs.pop('pageSize')
        kwargs[self.pagination.ordering_param] = ordering
        qs = self.pagination.paginate_queryset(qs, **kwargs)
    return CustomDjangoListObjectBase(
        count=count,
        results=maybe_queryset(qs),
        results_field_name=self.type._meta.results_field_name,
        page=kwargs.get('page', 1) if hasattr(self.pagination,
                                              'page') else None,
        # TODO: Need to add cutoff to send max page size instead of requested
        pageSize=kwargs.get(
            'pageSize', graphql_api_settings.DEFAULT_PAGE_SIZE) if hasattr(
                self.pagination, 'page') else None)
def reference_resolver(root, *args, **kwargs):
    """Resolve a generic reference field on *root*.

    When the referenced type is part of the query selection, the lookup
    projects only the type's required fields plus the queried/filterable
    field names.
    """
    de_referenced = getattr(root, field.name or field.db_name)
    if de_referenced:
        # ``_cls``/``_ref`` come from mongoengine's reference storage.
        document = get_document(de_referenced["_cls"])
        document_field = mongoengine.ReferenceField(document)
        document_field = convert_mongoengine_field(document_field, registry)
        _type = document_field.get_type().type
        filter_args = list()
        # Expand filter_fields into the key__op lookups mongoengine accepts.
        if _type._meta.filter_fields:
            for key, values in _type._meta.filter_fields.items():
                for each in values:
                    filter_args.append(key + "__" + each)
        querying_types = list(get_query_fields(args[0]).keys())
        if _type.__name__ in querying_types:
            queried_fields = list()
            for each in get_query_fields(args[0])[_type._meta.name].keys():
                item = to_snake_case(each)
                if item in document._fields_ordered + tuple(filter_args):
                    queried_fields.append(item)
            return document.objects().no_dereference().only(*list(
                set(list(_type._meta.required_fields) + queried_fields))).get(
                    pk=de_referenced["_ref"].id)
        # NOTE(review): this returns the document *class* (not an
        # instance) when the referenced type is not in the selection --
        # confirm that is intentional.
        return document
    return None
def build_schema(allowed_list: List[SchemaRestriction]):
    """Assemble a graphene Schema from the query and mutation registries,
    keeping only entries registered under a restriction in *allowed_list*.
    The Mutation type is omitted when no mutations are allowed."""
    allowed_queries = chain(
        *(query_lst for allow, query_lst in _query_registry.items()
          if allow in allowed_list))
    allowed_mutations = [
        mutation
        for allow, mutation_lst in _mutation_registry.items()
        if allow in allowed_list
        for mutation in mutation_lst
    ]

    class Query(*allowed_queries):
        pass

    if not allowed_mutations:
        return graphene.Schema(query=Query)
    mutation_fields = {
        to_snake_case(mutation.__name__): mutation.Field()
        for mutation in allowed_mutations
    }
    Mutation = type("Mutation", (graphene.ObjectType,), mutation_fields)
    return graphene.Schema(query=Query, mutation=Mutation)
def test_experiments(self):
    """The experiments query should list the single draft experiment with
    each camelCase field mirroring the model attribute."""
    user_email = "*****@*****.**"
    experiment = NimbusExperimentFactory.create_with_status(
        NimbusExperiment.Status.DRAFT)

    response = self.query(
        """
        query {
            experiments {
                name
                slug
                publicDescription
            }
        }
        """,
        headers={settings.OPENIDC_EMAIL_HEADER: user_email},
    )
    self.assertEqual(response.status_code, 200)
    content = json.loads(response.content)
    experiments = content["data"]["experiments"]
    self.assertEqual(len(experiments), 1)
    # Each returned field equals the stringified model attribute of the
    # snake_case name.
    for key in experiments[0]:
        self.assertEqual(experiments[0][key],
                         str(getattr(experiment, to_snake_case(key))))
def resolve(self, next_, root, info, **args):
    """Middleware resolver; handles field according to operation.

    For stripping operations (STRIP_OPS) it propagates the STRIP_ARG flag
    down the args tree and, when set, returns None for fields that are
    unset on their protobuf/dict parent instead of resolving them.
    """
    # GraphiQL introspection is 'query' but not async
    if getattr(info.operation.name, 'value', None) == 'IntrospectionQuery':
        return next_(root, info, **args)
    if info.operation.operation in STRIP_OPS:
        path_string = f'{info.path}'
        # Needed for child fields that resolve without args.
        # Store arguments of parents as leaves of schema tree from path
        # to respective field.
        # no need to regrow the tree on every subscription push/delta
        if args and path_string not in self.tree_paths:
            grow_tree(self.args_tree, info.path, args)
            self.tree_paths.add(path_string)
        if STRIP_ARG not in args:
            # Inherit the strip flag from the nearest ancestor that set it.
            branch = self.args_tree
            for section in info.path:
                if section not in branch:
                    break
                branch = branch[section]
            # Only set if present on branch section
            if 'leaves' in branch and STRIP_ARG in branch['leaves']:
                args[STRIP_ARG] = branch['leaves'][STRIP_ARG]
        # Now flag empty fields as null for stripping
        if args.get(STRIP_ARG, False):
            field_name = to_snake_case(info.field_name)
            # Clear field set so recreated via first child field,
            # as path may be a parent.
            # Done here as parent may be in NODE_MAP
            if path_string in self.field_sets:
                del self.field_sets[path_string]
            # Avoid using the protobuf default if field isn't set.
            if (hasattr(root, 'ListFields')
                    and hasattr(root, field_name)
                    and get_type_str(info.return_type) not in NODE_MAP):
                # Gather fields set in root
                parent_path_string = f'{info.path[:-1:]}'
                stamp = getattr(root, 'stamp', '')
                if (parent_path_string not in self.field_sets
                        or self.field_sets[parent_path_string]['stamp']
                        != stamp):
                    self.field_sets[parent_path_string] = {
                        'stamp': stamp,
                        'fields': {
                            field.name
                            for field, _ in root.ListFields()
                        }
                    }
                if (parent_path_string in self.field_sets
                        and field_name not in
                        self.field_sets[parent_path_string]['fields']):
                    return None
            # Do not resolve subfields of an empty type
            # by setting as null in parent/root.
            elif (isinstance(root, dict) and field_name in root):
                field_value = root[field_name]
                if (field_value in EMPTY_VALUES
                        or (hasattr(field_value, 'ListFields')
                            and not field_value.ListFields())):
                    return None
            if (info.operation.operation in self.ASYNC_OPS
                    or iscoroutinefunction(next_)):
                return self.async_null_setter(next_, root, info, **args)
            return null_setter(next_(root, info, **args))
    if (info.operation.operation in self.ASYNC_OPS
            or iscoroutinefunction(next_)):
        return self.async_resolve(next_, root, info, **args)
    return next_(root, info, **args)
async def subscribe_delta(self, root, info, args):
    """Delta subscription async generator.

    Async generator mapping the incoming protobuf deltas to yielded
    GraphQL subscription objects.
    """
    workflow_ids = set(args.get('workflows', args.get('ids', ())))
    sub_id = uuid4()
    info.context['sub_id'] = sub_id
    self.delta_store[sub_id] = {}
    delta_queues = self.data_store_mgr.delta_queues
    deltas_queue = queue.Queue()
    try:
        # Iterate over the queue yielding deltas
        w_ids = workflow_ids
        sub_resolver = SUB_RESOLVERS.get(to_snake_case(info.field_name))
        interval = args['ignore_interval']
        old_time = time()
        while True:
            if not workflow_ids:
                # No explicit subscription set: track every known workflow
                # and drop state for ones that disappeared.
                old_ids = w_ids
                w_ids = set(delta_queues.keys())
                for remove_id in old_ids.difference(w_ids):
                    if remove_id in self.delta_store[sub_id]:
                        del self.delta_store[sub_id][remove_id]
            for w_id in w_ids:
                if w_id in self.data_store_mgr.data:
                    if sub_id not in delta_queues[w_id]:
                        delta_queues[w_id][sub_id] = deltas_queue
                        # On new yield workflow data-store as added delta
                        if args.get('initial_burst'):
                            delta_store = create_delta_store(
                                workflow_id=w_id)
                            delta_store[DELTA_ADDED] = (
                                self.data_store_mgr.data[w_id])
                            self.delta_store[sub_id][w_id] = delta_store
                            if sub_resolver is None:
                                yield delta_store
                            else:
                                result = await sub_resolver(
                                    root, info, **args)
                                if result:
                                    yield result
                elif w_id in self.delta_store[sub_id]:
                    del self.delta_store[sub_id][w_id]
            try:
                # Non-blocking get; queue.Empty drops to the sleep below.
                w_id, topic, delta_store = deltas_queue.get(False)
                if topic != 'shutdown':
                    new_time = time()
                    elapsed = new_time - old_time
                    # ignore deltas that are more frequent than interval.
                    if elapsed <= interval:
                        continue
                    old_time = new_time
                else:
                    delta_store['shutdown'] = True
                self.delta_store[sub_id][w_id] = delta_store
                if sub_resolver is None:
                    yield delta_store
                else:
                    result = await sub_resolver(root, info, **args)
                    if result:
                        yield result
            except queue.Empty:
                await asyncio.sleep(DELTA_SLEEP_INTERVAL)
    except (GeneratorExit, asyncio.CancelledError):
        raise
    except Exception:
        import traceback
        logger.warning(traceback.format_exc())
    finally:
        # Unregister this subscription from every workflow queue and drop
        # its cached delta state.
        for w_id in w_ids:
            if delta_queues.get(w_id, {}).get(sub_id):
                del delta_queues[w_id][sub_id]
        if sub_id in self.delta_store:
            del self.delta_store[sub_id]
        yield None
def get_name(cls):
    """Return the directive's schema name: the class name with the
    'GraphQLDirective' suffix removed, converted to snake_case."""
    base_name = cls.__name__.replace('GraphQLDirective', '')
    return to_snake_case(base_name)
def resolve(value, directive, root, info, **kwargs):
    """Directive resolver: coerce *value* to text, title-case it, strip
    spaces and return the snake_case form."""
    if not isinstance(value, six.string_types):
        value = str(value)
    collapsed = value.title().replace(' ', '')
    return to_snake_case(collapsed)