Пример #1
0
def test_dataloader_thread_safety():
    """
    Dataloader should only batch `load` calls that happened on the same thread.

    Here we assert that `load` calls on thread 2 are not batched on thread 1 as
    thread 1 batches its own `load` calls.
    """
    # Upper bound (seconds) for every cross-thread wait, so a regression makes
    # the test fail instead of hanging the whole run.
    timeout = 5

    def load_many(keys):
        # Resolve each key to the name of the thread running this batch.
        # `Thread.name` replaces `getName()`, deprecated since Python 3.10.
        thread_name = threading.current_thread().name
        return Promise.resolve([thread_name for _ in keys])

    thread_name_loader = DataLoader(load_many)

    event_1 = threading.Event()
    event_2 = threading.Event()
    event_3 = threading.Event()

    # Start pessimistic: an exception raised inside a worker thread does not
    # reach this thread, so only a completed check may flip a flag to True.
    assert_object = {
        'is_same_thread_1': False,
        'is_same_thread_2': False,
    }

    def task_1():
        @Promise.safe
        def do():
            promise = thread_name_loader.load(1)
            event_1.set()
            # Wait for thread 2 to call `load`
            if not event_2.wait(timeout):
                return
            assert_object['is_same_thread_1'] = (
                promise.get() == threading.current_thread().name)
            event_3.set()  # Unblock thread 2

        do().get()

    def task_2():
        @Promise.safe
        def do():
            promise = thread_name_loader.load(2)
            event_2.set()
            # Wait for thread 1 to run `dispatch_queue_batch`
            if not event_3.wait(timeout):
                return
            assert_object['is_same_thread_2'] = (
                promise.get() == threading.current_thread().name)

        do().get()

    # Daemon threads cannot keep the interpreter alive if a wait times out.
    thread_1 = threading.Thread(target=task_1, daemon=True)
    thread_1.start()

    assert event_1.wait(timeout), "thread 1 never called `load`"

    thread_2 = threading.Thread(target=task_2, daemon=True)
    thread_2.start()

    for thread in (thread_1, thread_2):
        thread.join(timeout)
        assert not thread.is_alive(), "worker thread did not finish in time"

    assert assert_object['is_same_thread_1']
    assert assert_object['is_same_thread_2']
def id_loader(**options):
    """Build an identity DataLoader that records every batch it receives.

    Keyword ``options`` are forwarded to ``DataLoader``, except ``resolve``,
    which is removed from them and used to wrap each batch of keys
    (``Promise.resolve`` when not given).

    Returns a ``(loader, load_calls)`` pair, where ``load_calls`` gets the
    ``keys`` argument of each batch-function call appended, in call order.
    """
    recorded_batches = []
    wrap = options.pop("resolve", Promise.resolve)

    def batch_fn(keys):
        recorded_batches.append(keys)
        return wrap(keys)

    return DataLoader(batch_fn, **options), recorded_batches
Пример #3
0
def test_build_a_simple_data_loader():
    """`load` on an echoing DataLoader returns a Promise that resolves to the key."""
    identity_loader = DataLoader(lambda keys: Promise.resolve(keys))

    promise1 = identity_loader.load(1)
    assert isinstance(promise1, Promise)
    assert promise1.get() == 1
Пример #4
0
def test_supports_loading_multiple_keys_in_one_call():
    """`load_many` returns a Promise of all values, including for an empty key list."""
    def echo_keys(keys):
        return Promise.resolve(keys)

    identity_loader = DataLoader(echo_keys)

    for keys, expected in (([1, 2], [1, 2]), ([], [])):
        promise_all = identity_loader.load_many(keys)
        assert isinstance(promise_all, Promise)
        assert promise_all.get() == expected
Пример #5
0
def GeoFromCoordsLoaderFactory():
    '''
    Get geographic information for a coordinate.

    GeoFromCoordsLoaderFactory()
        .load({"latitude": y, "longitude": x})
    GeoFromCoordsLoaderFactory()
        .load_many([{"latitude": y1, "longitude": x1}, {"latitude": y2, "longitude": x2}])

    The returned DataLoader resolves each coordinate to the object built by
    ``create_graphql_instance(GeocodeResult, ...)`` from its result, or None
    when the bulk lookup has no (or an empty) result for it.
    '''

    def get_cache_key(key):
        # Coordinate dicts are unhashable, so derive a string cache key from
        # them; keys without a ``.get`` method are used as their own key.
        try:
            return f"{key.get('latitude')}::{key.get('longitude')}"
        except AttributeError:
            return key

    def initialise_graphql_type(row):
        if row is None:
            return row
        return create_graphql_instance(
            GeocodeResult,
            **row
        )

    def coord_key(point):
        # Hashable identity of a coordinate, matching on latitude + longitude.
        return (point.get('latitude'), point.get('longitude'))

    def batch_load_fields(coordinates):
        geos = bulk_coordinate_geo(coordinates)

        if not geos:
            return Promise.resolve([None for _ in coordinates])

        # Index the results by their query coordinate once, keeping the first
        # match per coordinate, instead of rescanning ``geos`` for every key.
        geo_by_coord = {}
        for geo in geos:
            geo_by_coord.setdefault(coord_key(geo['query']), geo)

        # One entry per requested coordinate, in request order.
        key_map = []
        for coord in coordinates:
            geo = geo_by_coord.get(coord_key(coord))
            result = geo.get('result') if geo is not None else None
            key_map.append(initialise_graphql_type(result) if result else None)

        return Promise.resolve(key_map)

    return DataLoader(batch_load_fields, get_cache_key=get_cache_key)
Пример #6
0
def GeoFromPostcodesLoaderFactory():
    '''
    Get geographic information for a postcode.

    GeoFromPostcodesLoaderFactory()
        .load("LE115AG")
    GeoFromPostcodesLoaderFactory()
        .load_many(["LE115AG", "WDD351"])

    Spaces are stripped from every postcode before it is looked up.
    '''

    def initialise_graphql_type(row):
        if row is None:
            return row
        return create_graphql_instance(
            GeocodeResult,
            **row
        )

    def batch_load_fields(postcodes):
        postcodes = [p.replace(" ", "") for p in postcodes]
        if len(postcodes) == 1:
            # NOTE(review): this path returns ``postcode_geo``'s value as-is,
            # while the bulk path below wraps each result with
            # ``initialise_graphql_type``; confirm the two shapes agree.
            geo = postcode_geo(postcodes[0])
            return Promise.resolve([geo])

        geos = bulk_postcode_geo(postcodes)

        if not geos:
            return Promise.resolve([None for _ in postcodes])

        # Index results by their space-stripped query postcode once, keeping
        # the first match, instead of rescanning ``geos`` for every postcode.
        geo_by_postcode = {}
        for geo in geos:
            geo_by_postcode.setdefault(geo['query'].replace(" ", ""), geo)

        # One entry per requested postcode, in request order.
        key_map = []
        for postcode in postcodes:
            geo = geo_by_postcode.get(postcode)
            result = geo.get('result') if geo is not None else None
            key_map.append(initialise_graphql_type(result) if result else None)

        return Promise.resolve(key_map)

    return DataLoader(batch_load_fields)
Пример #7
0
    def __init__(self, rest_object_class, source_field_name='id', filter_field_name='id', is_top_level=False, many=False, *args, **kwargs):
        # A nested (non top-level) field must not use 'id' for both the source
        # and the filter field name.
        assert is_top_level or not (source_field_name == 'id' and filter_field_name == 'id')
        self.source_field_name = source_field_name
        self.filter_field_name = filter_field_name
        self.rest_object_class = rest_object_class
        self.is_top_level = is_top_level

        # Only nested fields ask the request maker to filter by parent fields.
        self.request_maker = RequestMaker(
            filter_by_parent_fields=not is_top_level,
            filter_field_name=filter_field_name,
        )

        def fetch_batch(values):
            # NOTE(review): results are returned in whatever order/length the
            # API sends them; DataLoader expects one entry per key, in key
            # order — confirm the endpoint guarantees that.
            maker = self.request_maker
            maker.filter_values = values
            payload = maker.make_request().json()
            return Promise.resolve(payload['results'])

        self.data_loader = DataLoader(fetch_batch)

        self.many = many
        # A "many" field exposes a list of the REST object type.
        field_type = graphene.List(rest_object_class) if self.many else rest_object_class
        super().__init__(field_type, *args, **kwargs)
Пример #8
0
def test_should_query_dataloader_fields():
    """A connection field whose resolver returns a DataLoader promise resolves.

    ``ReporterType.articles`` is served by ``article_loader``; the query must
    return both articles of the single reporter.
    """
    import base64

    from promise import Promise
    from promise.dataloader import DataLoader

    def article_batch_load_fn(keys):
        # One query for the whole batch, then bucket articles per reporter so
        # the returned list lines up with ``keys``.
        articles_by_reporter = {}
        for article in Article.objects.filter(reporter_id__in=keys):
            articles_by_reporter.setdefault(article.reporter_id, []).append(article)
        return Promise.resolve(
            [list(articles_by_reporter.get(reporter_id, [])) for reporter_id in keys])

    article_loader = DataLoader(article_batch_load_fn)

    class ArticleType(DjangoObjectType):
        class Meta:
            model = Article
            interfaces = (Node, )

    class ReporterType(DjangoObjectType):
        class Meta:
            model = Reporter
            interfaces = (Node, )
            use_connection = True

        articles = DjangoConnectionField(ArticleType)

        def resolve_articles(self, info, **args):
            return article_loader.load(self.id)

    class Query(graphene.ObjectType):
        all_reporters = DjangoConnectionField(ReporterType)

    r = Reporter.objects.create(first_name="John",
                                last_name="Doe",
                                email="*****@*****.**",
                                a_choice=1)

    Article.objects.create(
        headline="Article Node 1",
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang="es",
    )
    Article.objects.create(
        headline="Article Node 2",
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang="en",
    )

    schema = graphene.Schema(query=Query)
    query = """
        query ReporterPromiseConnectionQuery {
            allReporters(first: 1) {
                edges {
                    node {
                        id
                        articles(first: 2) {
                            edges {
                                node {
                                    headline
                                }
                            }
                        }
                    }
                }
            }
        }
    """

    # The global ID is base64("ReporterType:<pk>"); derive it from the saved
    # row instead of assuming the database handed out pk 1.
    reporter_global_id = base64.b64encode(
        f"ReporterType:{r.pk}".encode("utf-8")).decode("ascii")

    expected = {
        "allReporters": {
            "edges": [{
                "node": {
                    "id": reporter_global_id,
                    "articles": {
                        "edges": [
                            {
                                "node": {
                                    "headline": "Article Node 1"
                                }
                            },
                            {
                                "node": {
                                    "headline": "Article Node 2"
                                }
                            },
                        ]
                    },
                }
            }]
        }
    }

    result = schema.execute(query)
    assert not result.errors
    assert result.data == expected
Пример #9
0
 def create_data_loader(max_batch_size=500):
     """Create a DataLoader around the ``batch_load_fn`` in scope.

     :param max_batch_size: forwarded to ``DataLoader`` as ``max_batch_size``;
         defaults to 500, the value that used to be hard-coded.
     """
     return DataLoader(
         batch_load_fn=batch_load_fn,
         max_batch_size=max_batch_size,
     )
Пример #10
0
def get_for_model(model):
    """Return a DataLoader for ``model``.

    The batch function comes from ``_get_model_batch_load_fn(model)`` and the
    cache keys from ``_get_model_cache_key``.
    """
    batch_fn = _get_model_batch_load_fn(model)
    return DataLoader(batch_fn, get_cache_key=_get_model_cache_key)
Пример #11
0
def test_should_query_dataloader_fields():
    """A connection field whose resolver returns a DataLoader promise resolves.

    ``ReporterType.articles`` is served by ``article_loader``; the query must
    return both articles of the single reporter.
    """
    import base64

    from promise import Promise
    from promise.dataloader import DataLoader

    def article_batch_load_fn(keys):
        # One query for the whole batch, then bucket articles per reporter so
        # the returned list lines up with ``keys``.
        articles_by_reporter = {}
        for article in Article.objects.filter(reporter_id__in=keys):
            articles_by_reporter.setdefault(article.reporter_id, []).append(article)
        return Promise.resolve([
            list(articles_by_reporter.get(reporter_id, []))
            for reporter_id in keys
        ])

    article_loader = DataLoader(article_batch_load_fn)

    class ArticleType(DjangoObjectType):

        class Meta:
            model = Article
            interfaces = (Node, )

    class ReporterType(DjangoObjectType):

        class Meta:
            model = Reporter
            interfaces = (Node, )
            use_connection = True

        articles = DjangoConnectionField(ArticleType)

        def resolve_articles(self, info, **args):
            return article_loader.load(self.id)

    class Query(graphene.ObjectType):
        all_reporters = DjangoConnectionField(ReporterType)

    r = Reporter.objects.create(
        first_name='John',
        last_name='Doe',
        email='*****@*****.**',
        a_choice=1
    )

    Article.objects.create(
        headline='Article Node 1',
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang='es'
    )
    Article.objects.create(
        headline='Article Node 2',
        pub_date=datetime.date.today(),
        pub_date_time=datetime.datetime.now(),
        reporter=r,
        editor=r,
        lang='en'
    )

    schema = graphene.Schema(query=Query)
    query = '''
        query ReporterPromiseConnectionQuery {
            allReporters(first: 1) {
                edges {
                    node {
                        id
                        articles(first: 2) {
                            edges {
                                node {
                                    headline
                                }
                            }
                        }
                    }
                }
            }
        }
    '''

    # The global ID is base64("ReporterType:<pk>"); derive it from the saved
    # row instead of assuming the database handed out pk 1.
    reporter_global_id = base64.b64encode(
        f'ReporterType:{r.pk}'.encode('utf-8')).decode('ascii')

    expected = {
        'allReporters': {
            'edges': [{
                'node': {
                    'id': reporter_global_id,
                    'articles': {
                        'edges': [{
                            'node': {
                                'headline': 'Article Node 1',
                            }
                        }, {
                            'node': {
                                'headline': 'Article Node 2'
                            }
                        }]
                    }
                }
            }]
        }
    }

    result = schema.execute(query)
    assert not result.errors
    assert result.data == expected
Пример #12
0
    # Fixture attributes on the outer class.
    key, keys = 1, [1, 2, 3]

    class InnerClass(object):
        # Nested class carrying its own, different fixture values.
        key, keys = 2, [4, 5, 6]

    def resolver(self):
        """Return a fixed marker string that field resolution can be compared against."""
        message = "resolver method"
        return message


def batch_load_fn(keys):
    """Batch function that combines ``keys`` with ``Promise.all``."""
    combined = Promise.all(keys)
    return combined


data_loader = DataLoader(batch_load_fn=batch_load_fn)


class PermissionFieldTests(TestCase):
    def test_permission_field(self):
        """A viewer granted only "perm2" still gets the value from the `resolver` source."""
        class Viewer(object):
            def has_perm(self, perm):
                # Only the second of the field's two permissions is granted.
                return perm == "perm2"

        field = DjangoField(object(), permissions=["perm1", "perm2"], source="resolver")
        resolver = field.get_resolver(None)
        info = mock.Mock(context=mock.Mock(user=Viewer()))

        actual = resolver(MyInstance(), info)
        self.assertEqual(actual, MyInstance().resolver())
Пример #13
0
#             print('ConnectionClass.Edge : ', args, kwargs)
#             super().__init__(*args, **kwargs)
#         other = graphene.String()

from promise import Promise
from promise.dataloader import DataLoader


def article_batch_load_fn(keys):
    """Batch-load ``TodoExtra`` rows for a list of primary keys.

    The query filters on ``pk__in=keys``, so rows are matched back to their
    key by ``pk`` (the previous ``reporter_id`` comparison did not correspond
    to that filter). For every key, in order, the result holds a list of the
    matching rows, empty when none exist.
    """
    rows_by_pk = {}
    for row in models.TodoExtra.objects.filter(pk__in=keys):
        rows_by_pk.setdefault(row.pk, []).append(row)
    # NOTE(review): the list-per-key shape is kept from the original; with pk
    # keys each list holds at most one row — confirm callers want a list.
    return Promise.resolve([list(rows_by_pk.get(key, [])) for key in keys])


article_loader = DataLoader(article_batch_load_fn)


@test_dec2
class TodoExtraNode(DjangoObjectType):
    class Meta:
        model = models.TodoExtra
        exclude = ()
        # fields = ['description']
        filter_fields = {}
        interfaces = (graphene.relay.Node, )

    @classmethod
    def get_queryset(cls, queryset, info):
        # TodoNode.extraからアクセスするときに呼び出されないのはなぜ・・・??
        # manyの場合にしかconnectionは張られない・・・