def test_build_a_simple_data_loader():
    """A DataLoader over an identity batch function resolves each key to itself."""
    identity_loader = DataLoader(lambda keys: Promise.resolve(keys))

    loaded = identity_loader.load(1)
    # load() must hand back a Promise, not the resolved value itself.
    assert isinstance(loaded, Promise)
    assert loaded.get() == 1
def test_supports_loading_multiple_keys_in_one_call():
    """load_many returns one Promise resolving to values in key order."""
    identity_loader = DataLoader(lambda keys: Promise.resolve(keys))

    # Non-empty batch: values come back in the same order as the keys.
    batch = identity_loader.load_many([1, 2])
    assert isinstance(batch, Promise)
    assert batch.get() == [1, 2]

    # Empty batch: still a Promise, resolving to an empty list.
    batch = identity_loader.load_many([])
    assert isinstance(batch, Promise)
    assert batch.get() == []
def id_loader(**options):
    """Build an identity DataLoader that records every batch it receives.

    ``options`` may include ``resolve`` (a callable applied to each key
    batch; defaults to ``Promise.resolve``).  All remaining options are
    forwarded to the ``DataLoader`` constructor.

    Returns a ``(loader, load_calls)`` pair, where ``load_calls`` is a
    list of the key batches passed to the batch function, in call order.
    """
    load_calls = []
    resolve = options.pop('resolve', Promise.resolve)

    def batch_fn(keys):
        # Record the batch before resolving so tests can inspect batching.
        load_calls.append(keys)
        return resolve(keys)

    return DataLoader(batch_fn, **options), load_calls
def __init__(self, rest_object_class, source_field_name='id', filter_field_name='id', is_top_level=False, many=False, *args, **kwargs):
    """Set up a REST-backed graphene field.

    :param rest_object_class: graphene type (or lazy reference) for the REST objects.
    :param source_field_name: attribute on the parent object holding the lookup value(s).
    :param filter_field_name: REST query field the lookup values filter on.
    :param is_top_level: True when this field is queried without a parent object.
    :param many: True when the field resolves to a list of objects.
    """
    # A nested (non-top-level) field must override at least one of the
    # 'id' defaults, otherwise it cannot be joined to its parent.
    assert is_top_level or not (source_field_name == 'id' and filter_field_name == 'id')

    self.rest_object_class = rest_object_class
    self.source_field_name = source_field_name
    self.filter_field_name = filter_field_name
    self.is_top_level = is_top_level
    self.many = many

    self.request_maker = RequestMaker(
        filter_by_parent_fields=not is_top_level,
        filter_field_name=filter_field_name,
    )

    def batch_load_fn(source_values):
        # One REST round trip covers the whole batch of parent values.
        self.request_maker.filter_values = source_values
        response = self.request_maker.make_request()
        return Promise.resolve(response.json()['results'])

    self.data_loader = DataLoader(batch_load_fn)

    # List-valued fields wrap the object type in graphene.List.
    field_type = graphene.List(rest_object_class) if many else rest_object_class
    super().__init__(field_type, *args, **kwargs)
def GeoFromCoordsLoaderFactory():
    '''
    Build a DataLoader that geocodes coordinate dicts.

    GeoLoaderFactory({...}) .load({ latitude: y, longitude: x})
    GeoLoaderFactory({...}) .load_many([{ latitude: y, longitude: y}, { latitude: y, longitude: x}])
    '''
    def get_cache_key(key):
        # Coordinate dicts are unhashable, so cache on a "lat::lon" string.
        try:
            return f"{key.get('latitude')}::{key.get('longitude')}"
        except AttributeError:
            # FIX: was a bare `except:`, which swallowed every exception
            # (including KeyboardInterrupt). Only a non-dict key (no .get)
            # is expected here; such keys are used as-is.
            return key

    def initialise_graphql_type(row):
        # None (no geocode result) passes through untouched.
        if row is None:
            return row
        return create_graphql_instance(
            GeocodeResult,
            **row
        )

    def batch_load_fields(coordinates):
        geos = bulk_coordinate_geo(coordinates)
        if geos is None or len(geos) == 0:
            # Service returned nothing: every key resolves to None.
            return Promise.resolve([None for key in coordinates])
        # Re-align results to request order by matching each coordinate
        # against the echoed 'query'; unmatched coordinates map to None.
        key_map = [
            next((
                initialise_graphql_type(geo.get('result')) if geo.get('result') else None
                for geo in geos
                if geo['query'].get('latitude') == coord.get('latitude')
                and geo['query'].get('longitude') == coord.get('longitude')
            ), None)
            for coord in coordinates
        ]
        return Promise.resolve(key_map)

    return DataLoader(batch_load_fields, get_cache_key=get_cache_key)
def GeoFromPostcodesLoaderFactory():
    '''
    Get geographic information for this postcode.

    GeoLoaderFactory({...}) .load("LE115AG")
    GeoLoaderFactory({...}) .load_many(["LE115AG", "WDD351"])
    '''
    def initialise_graphql_type(row):
        # None (no geocode result) passes through untouched.
        if row is None:
            return row
        return create_graphql_instance(
            GeocodeResult,
            **row
        )

    def batch_load_fields(postcodes):
        # Normalise: internal spaces are stripped before querying.
        postcodes = [p.replace(" ", "") for p in postcodes]
        if len(postcodes) == 1:
            # Single-key fast path via the non-bulk endpoint.
            # NOTE(review): this path returns the raw `geo` without passing it
            # through initialise_graphql_type, unlike the bulk path below —
            # confirm postcode_geo already returns the final shape, otherwise
            # single and bulk loads yield different types.
            geo = postcode_geo(postcodes[0])
            return Promise.resolve([geo])
        geos = bulk_postcode_geo(postcodes)
        if geos is None or len(geos) == 0:
            # Service returned nothing: every key resolves to None.
            return Promise.resolve([None for key in postcodes])
        # Re-align results to request order by matching each postcode against
        # the echoed (space-stripped) 'query'; unmatched postcodes map to None.
        key_map = [
            next((
                initialise_graphql_type(geo.get('result')) if geo.get('result') else None
                for geo in geos
                if geo['query'].replace(" ", "") == postcode
            ), None)
            for postcode in postcodes
        ]
        return Promise.resolve(key_map)

    # Default cache key (the postcode string itself) is hashable, so no
    # custom get_cache_key is needed here.
    return DataLoader(batch_load_fields)
def test_should_query_dataloader_fields():
    """End-to-end check that a DjangoObjectType field resolved through a
    promise DataLoader is batched and returned correctly via a relay
    connection query."""
    from promise import Promise
    from promise.dataloader import DataLoader

    def article_batch_load_fn(keys):
        # One DB query for the whole batch of reporter ids; results are
        # regrouped per key so each reporter gets only its own articles.
        queryset = Article.objects.filter(reporter_id__in=keys)
        return Promise.resolve(
            [[article for article in queryset if article.reporter_id == id] for id in keys])

    article_loader = DataLoader(article_batch_load_fn)

    class ArticleType(DjangoObjectType):
        class Meta:
            model = Article
            interfaces = (Node, )

    class ReporterType(DjangoObjectType):
        class Meta:
            model = Reporter
            interfaces = (Node, )
            use_connection = True

        articles = DjangoConnectionField(ArticleType)

        def resolve_articles(self, info, **args):
            # Defer to the loader so sibling resolvers share one batch.
            return article_loader.load(self.id)

    class Query(graphene.ObjectType):
        all_reporters = DjangoConnectionField(ReporterType)

    # Fixture: one reporter with two articles (created in headline order).
    r = Reporter.objects.create(first_name="John", last_name="Doe", email="*****@*****.**", a_choice=1)
    Article.objects.create(
        headline="Article Node 1",
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang="es",
    )
    Article.objects.create(
        headline="Article Node 2",
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang="en",
    )
    schema = graphene.Schema(query=Query)
    query = """
        query ReporterPromiseConnectionQuery {
            allReporters(first: 1) {
                edges {
                    node {
                        id
                        articles(first: 2) {
                            edges {
                                node {
                                    headline
                                }
                            }
                        }
                    }
                }
            }
        }
    """
    # The id is the base64 relay global id for ReporterType:1.
    expected = {
        "allReporters": {
            "edges": [{
                "node": {
                    "id": "UmVwb3J0ZXJUeXBlOjE=",
                    "articles": {
                        "edges": [
                            {
                                "node": {
                                    "headline": "Article Node 1"
                                }
                            },
                            {
                                "node": {
                                    "headline": "Article Node 2"
                                }
                            },
                        ]
                    },
                }
            }]
        }
    }
    result = schema.execute(query)
    assert not result.errors
    assert result.data == expected
def create_data_loader():
    """Build a fresh DataLoader around the module-level ``batch_load_fn``.

    Each underlying batch call is capped at 500 keys.
    """
    return DataLoader(batch_load_fn=batch_load_fn, max_batch_size=500)
def get_for_model(model):
    """Create a DataLoader that batch-loads instances of ``model``.

    Cache keys are computed by the shared ``_get_model_cache_key`` helper.
    """
    batch_fn = _get_model_batch_load_fn(model)
    return DataLoader(batch_fn, get_cache_key=_get_model_cache_key)
def test_should_query_dataloader_fields():
    """End-to-end check that a DjangoObjectType field resolved through a
    promise DataLoader is batched and returned correctly via a relay
    connection query."""
    from promise import Promise
    from promise.dataloader import DataLoader

    def article_batch_load_fn(keys):
        # One DB query for the whole batch of reporter ids; results are
        # regrouped per key so each reporter gets only its own articles.
        queryset = Article.objects.filter(reporter_id__in=keys)
        return Promise.resolve([
            [article for article in queryset if article.reporter_id == id]
            for id in keys
        ])

    article_loader = DataLoader(article_batch_load_fn)

    class ArticleType(DjangoObjectType):
        class Meta:
            model = Article
            interfaces = (Node, )

    class ReporterType(DjangoObjectType):
        class Meta:
            model = Reporter
            interfaces = (Node, )
            use_connection = True

        articles = DjangoConnectionField(ArticleType)

        def resolve_articles(self, info, **args):
            # Defer to the loader so sibling resolvers share one batch.
            return article_loader.load(self.id)

    class Query(graphene.ObjectType):
        all_reporters = DjangoConnectionField(ReporterType)

    # Fixture: one reporter with two articles (created in headline order).
    r = Reporter.objects.create(
        first_name='John',
        last_name='Doe',
        email='*****@*****.**',
        a_choice=1
    )
    Article.objects.create(
        headline='Article Node 1',
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang='es'
    )
    Article.objects.create(
        headline='Article Node 2',
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang='en'
    )
    schema = graphene.Schema(query=Query)
    query = '''
        query ReporterPromiseConnectionQuery {
            allReporters(first: 1) {
                edges {
                    node {
                        id
                        articles(first: 2) {
                            edges {
                                node {
                                    headline
                                }
                            }
                        }
                    }
                }
            }
        }
    '''
    # The id is the base64 relay global id for ReporterType:1.
    expected = {
        'allReporters': {
            'edges': [{
                'node': {
                    'id': 'UmVwb3J0ZXJUeXBlOjE=',
                    'articles': {
                        'edges': [{
                            'node': {
                                'headline': 'Article Node 1',
                            }
                        }, {
                            'node': {
                                'headline': 'Article Node 2'
                            }
                        }]
                    }
                }
            }]
        }
    }
    result = schema.execute(query)
    assert not result.errors
    assert result.data == expected
# Module-level fixtures used by the field/resolver tests below.
key = 1
keys = [1, 2, 3]


class InnerClass(object):
    # Shadows the module-level fixtures with distinct values so tests can
    # tell whether attribute or module lookup was used.
    key = 2
    keys = [4, 5, 6]

    def resolver(self):
        return "resolver method"


def batch_load_fn(keys):
    # Identity batch function: resolves each key batch to itself.
    return Promise.all(keys)


data_loader = DataLoader(batch_load_fn=batch_load_fn)


class PermissionFieldTests(TestCase):
    def test_permission_field(self):
        """A permission-guarded field resolves when the user holds any one
        of the listed permissions (here only "perm2" is granted)."""
        MyType = object()
        field = DjangoField(MyType, permissions=["perm1", "perm2"], source="resolver")
        resolver = field.get_resolver(None)

        class Viewer(object):
            # Grants exactly one of the two required permissions.
            def has_perm(self, perm):
                return perm == "perm2"

        info = mock.Mock(context=mock.Mock(user=Viewer()))
        # NOTE(review): MyInstance is not defined in this chunk — presumably a
        # test fixture class with a `resolver` attribute defined elsewhere in
        # the file; verify before refactoring.
        self.assertEqual(resolver(MyInstance(), info), MyInstance().resolver())
def __init__(self, keys: List[str], labbook: LabBook, username: str):
    """Initialise the loader with its key list, target labbook, and acting username."""
    # Explicit base-class init (matches the original's direct call style).
    DataLoader.__init__(self)
    self.keys = keys
    self.labbook = labbook
    self.username = username
    # Per-key cache of resolved versions, populated lazily.
    self.latest_versions = {}
class ExternalRESTField(graphene.Field):
    """A graphene field whose values come from an external REST endpoint.

    Nested (non-top-level) fields are resolved through a shared DataLoader so
    that many parent objects produce a single batched REST request.
    """

    def __init__(self, rest_object_class, source_field_name='id', filter_field_name='id', is_top_level=False, many=False, *args, **kwargs):
        # A nested field must override at least one of the 'id' defaults,
        # otherwise it cannot be joined to its parent object.
        assert is_top_level or not (source_field_name == 'id' and filter_field_name == 'id')
        self.source_field_name = source_field_name
        self.filter_field_name = filter_field_name
        self.rest_object_class = rest_object_class
        self.is_top_level = is_top_level
        self.request_maker = RequestMaker(
            filter_by_parent_fields=(not is_top_level),
            filter_field_name=filter_field_name
        )

        def batch_load_fn(source_values):
            # One REST round trip covers the whole batch of parent values.
            self.request_maker.filter_values = source_values
            response = self.request_maker.make_request()
            return Promise.resolve(response.json()['results'])

        self.data_loader = DataLoader(batch_load_fn)
        self.many = many
        # List-valued fields wrap the object type in graphene.List.
        if self.many:
            super().__init__(graphene.List(rest_object_class), *args, **kwargs)
        else:
            super().__init__(rest_object_class, *args, **kwargs)

    def get_resolver(self, parent_resolver):
        # Prefer an explicitly assigned resolver; otherwise build one for
        # the concrete (possibly lazily referenced) object class.
        if self.resolver:
            return self.resolver
        else:
            return self.generate_resolver(get_actual_object_class(self.rest_object_class))

    def generate_resolver(self, rest_object_class, *class_args, **class_kwargs):
        """Build the graphene resolver for this field.

        Top-level fields issue their own REST request; nested fields go
        through the shared DataLoader and filter the batched response down
        to the rows that match the parent object.
        """

        def endpoint_resolver_promise(parent_object, results):
            # Keep only rows whose filter field matches (or contains) the
            # parent's source value.
            relevant_results = list(filter(lambda h: equals_or_contains(h[self.filter_field_name], getattr(parent_object, self.source_field_name)), results))
            if not self.many:
                # Scalar fields must match exactly one row.
                assert len(relevant_results) == 1
                relevant_results = relevant_results[0]
            obj = reduce_fields_to_objects(rest_object_class, relevant_results, is_list=self.many)
            return obj

        def endpoint_resolver(parent_object, args, context, info):
            # This is called for every parent object where we want nested objects.
            # Therefore we don't want to do unnecessary computation (ex:
            # processing query params/headers from the original request)
            # Instead, we do initial processing in request_maker.initialize_x
            # and final processing in request_maker.generate_x
            self.request_maker.headers = context.headers
            self.request_maker.data = context.data
            self.request_maker.base_url = rest_object_class.base_url
            self.request_maker.query_string = context.query_string
            self.request_maker.graphql_arguments = args
            if self.is_top_level:
                # Top level: single direct request, no batching needed.
                response = self.request_maker.make_request()
                return reduce_fields_to_objects(rest_object_class, response.json()['results'])
            else:
                # Nested: batch this parent's value(s) with its siblings'.
                source_values = getattr(parent_object, self.source_field_name)
                if not is_non_str_iterable(source_values):
                    source_values = [source_values]
                result = self.data_loader.load_many(source_values)
                return result.then(
                    functools.partial(endpoint_resolver_promise, parent_object)
                )

        return endpoint_resolver
# print('ConnectionClass.Edge : ', args, kwargs) # super().__init__(*args, **kwargs) # other = graphene.String() from promise import Promise from promise.dataloader import DataLoader def article_batch_load_fn(keys): queryset = models.TodoExtra.objects.filter(pk__in=keys) return Promise.resolve( [[article for article in queryset if article.reporter_id == id] for id in keys]) article_loader = DataLoader(article_batch_load_fn) @test_dec2 class TodoExtraNode(DjangoObjectType): class Meta: model = models.TodoExtra exclude = () # fields = ['description'] filter_fields = {} interfaces = (graphene.relay.Node, ) @classmethod def get_queryset(cls, queryset, info): # TodoNode.extraからアクセスするときに呼び出されないのはなぜ・・・?? # manyの場合にしかconnectionは張られない・・・