async def test_when_raising_error_in_loader():
    """Check that a ValueError raised by the load function reaches the caller.

    Exercised twice: once through a single ``load`` and once through three
    loads awaited together with ``asyncio.gather``.
    """
    async def failing_batch(keys):
        raise ValueError()

    loader = DataLoader(load_fn=failing_batch)

    # A lone load re-raises the batch function's error.
    with pytest.raises(ValueError):
        await loader.load(1)

    # Gathering several loads raises as well.
    with pytest.raises(ValueError):
        await asyncio.gather(*(loader.load(key) for key in (1, 2, 3)))
async def test_caches_by_id_when_loading_many(mocker):
    """Check that two cached loads of one key compare equal and batch once.

    Both loads are awaited together via ``asyncio.gather`` and the mocked
    load function must have been called exactly once with ``[1]``.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn, cache=True)

    first = loader.load(1)
    second = loader.load(1)
    assert first == second

    gathered = await asyncio.gather(first, second)
    assert [1, 1] == gathered

    batch_fn.assert_called_once_with([1])
async def test_fetch_data_from_db(mocker):
    """Resolve five aliased fields through the async Django view in one batch.

    Rows are created by ``_prepare_db``; each ``getExample`` field loads its
    row through a DataLoader whose batch function queries ``Example`` by id.
    The test asserts the response payload and that the batch function was
    called exactly once with all five ids (as strings).
    """
    def _sync_batch_load(keys):
        data = Example.objects.filter(id__in=keys)

        return list(data)

    # The ORM calls are synchronous, so wrap them for use from async code.
    prepare_db = sync_to_async(_prepare_db)
    batch_load = sync_to_async(_sync_batch_load)
    reset_db = sync_to_async(lambda: Example.objects.all().delete())

    ids = await prepare_db()

    # Everything that runs once the rows exist sits under try/finally so a
    # failing request or assertion cannot leave rows behind for later tests.
    try:
        async def idx(keys) -> List[Example]:
            return await batch_load(keys)

        mock_loader = mocker.Mock(side_effect=idx)

        loader = DataLoader(load_fn=mock_loader)

        @strawberry.type
        class Query:
            hello: str = "strawberry"

            @strawberry.field
            async def get_example(self, id: strawberry.ID) -> str:
                example = await loader.load(id)

                return example.name

        schema = strawberry.Schema(query=Query)

        query = f"""{{
        a: getExample(id: "{ids[0]}")
        b: getExample(id: "{ids[1]}")
        c: getExample(id: "{ids[2]}")
        d: getExample(id: "{ids[3]}")
        e: getExample(id: "{ids[4]}")
    }}"""

        factory = RequestFactory()
        request = factory.post(
            "/graphql/", {"query": query}, content_type="application/json"
        )

        response = await AsyncGraphQLView.as_view(schema=schema)(request)
        data = json.loads(response.content.decode())

        assert not data.get("errors")
        assert data["data"] == {
            "a": "This is a demo async 0",
            "b": "This is a demo async 1",
            "c": "This is a demo async 2",
            "d": "This is a demo async 3",
            "e": "This is a demo async 4",
        }
    finally:
        await reset_db()

    # GraphQL ID arguments reach the loader as strings, hence the str() cast.
    mock_loader.assert_called_once_with([str(id_) for id_ in ids])
async def test_error():
    """Check that an exception instance returned as a result is raised.

    The load function returns ``[ValueError()]`` rather than raising, and
    awaiting the load is expected to raise ``ValueError``.
    """
    async def error_as_result(keys):
        return [ValueError()]

    failing_loader = DataLoader(load_fn=error_as_result)

    with pytest.raises(ValueError):
        await failing_loader.load(1)
async def test_cache_disabled(mocker):
    """Check that with ``cache=False`` repeated loads are not deduplicated.

    Two loads of the same key must compare unequal, both resolve to ``1``,
    and the mocked load function must have received ``[1, 1]``.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn, cache=False)

    first = loader.load(1)
    second = loader.load(1)
    assert first != second

    first_value = await first
    assert first_value == 1
    second_value = await second
    assert second_value == 1

    # The duplicate key appears twice in the batch.
    expected_calls = [mocker.call([1, 1])]
    batch_fn.assert_has_calls(expected_calls)
async def test_caches_by_id(mocker):
    """Check that with ``cache=True`` two loads of one key share a result.

    The two loads must compare equal, both resolve to ``1``, and the mocked
    load function must have been called exactly once with ``[1]``.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn, cache=True)

    first = loader.load(1)
    second = loader.load(1)
    assert first == second

    first_value = await first
    assert first_value == 1
    second_value = await second
    assert second_value == 1

    batch_fn.assert_called_once_with([1])
async def test_gathering(mocker):
    """Check that loads awaited together are sent as a single batch.

    Three loads are gathered; the mocked load function must have been called
    once with ``[1, 2, 3]`` and each load must resolve to its own key.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn)

    first, second, third = await asyncio.gather(
        *(loader.load(key) for key in (1, 2, 3))
    )

    batch_fn.assert_called_once_with([1, 2, 3])

    assert (first, second, third) == (1, 2, 3)
async def test_max_batch_size(mocker):
    """Check that ``max_batch_size=2`` splits three gathered loads in two.

    The mocked load function must have been called with ``[1, 2]`` and then
    ``[3]``, and each load must resolve to its own key.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn, max_batch_size=2)

    first, second, third = await asyncio.gather(
        *(loader.load(key) for key in (1, 2, 3))
    )

    expected_calls = [mocker.call([1, 2]), mocker.call([3])]
    batch_fn.assert_has_calls(expected_calls)

    assert (first, second, third) == (1, 2, 3)
async def test_returning_wrong_number_of_results():
    """Check the error raised when the result count differs from the key count.

    The load function returns two results for a single requested key, which
    must raise ``WrongNumberOfResultsReturned`` with the expected message.
    """
    async def too_many_results(keys):
        return [1, 2]

    loader = DataLoader(load_fn=too_many_results)

    expected_message = (
        "Received wrong number of results in dataloader, "
        "expected: 1, received: 2"
    )

    with pytest.raises(WrongNumberOfResultsReturned, match=expected_message):
        await loader.load(1)
async def test_error_and_values():
    """Check that a failed load does not prevent a later successful one.

    The load function returns an exception instance for every batch except
    ``[2]``; loading ``1`` must raise and loading ``2`` must then return ``2``.
    """
    async def mixed_results(keys):
        return [2] if keys == [2] else [ValueError()]

    loader = DataLoader(load_fn=mixed_results)

    with pytest.raises(ValueError):
        await loader.load(1)

    second_value = await loader.load(2)
    assert second_value == 2
async def test_loading():
    """Check that sequentially awaited loads each resolve to their own key."""
    async def echo_keys(keys):
        return keys

    loader = DataLoader(load_fn=echo_keys)

    # Await one load at a time, in order, before checking any value.
    loaded = [await loader.load(key) for key in (1, 2, 3)]

    assert loaded == [1, 2, 3]
# Example #12
# 0
    def get_group_dataloader(
            self, key_field: str) -> DataLoader[int, List[LoaderType]]:
        """Build a DataLoader that resolves each key to a list of objects.

        The batch function fetches rows with
        ``self.model.load_all(<key_field>=keys)``, hands them to
        ``Loader.groupBy`` together with the keys, and builds one
        ``self.constructor`` instance per row in every returned group.

        NOTE(review): presumably ``Loader.groupBy`` yields one group per key,
        in key order -- confirm against its definition.
        """
        def key_of(row):
            return getattr(row, key_field, None)

        async def load_fn(keys: List[int]) -> List[List[LoaderType]]:
            # where `<key_field>` in `<keys>`
            rows = await self.model.load_all(**{key_field: keys})
            groups = Loader.groupBy(keys, rows, key_of)
            return [
                [self.constructor(**row._asdict()) for row in group]
                for group in groups
            ]

        return DataLoader(load_fn)
# Example #13
# 0
    def get_dataloader(
            self, key_field: str) -> DataLoader[int, Optional[LoaderType]]:
        """Build a DataLoader that resolves each key to one object or ``None``.

        The batch function fetches rows with
        ``self.model.load_all(<key_field>=keys)``, hands them to
        ``Loader.fillBy`` together with the keys, and converts every truthy
        row into a ``self.constructor`` instance; falsy rows become ``None``.

        NOTE(review): presumably ``Loader.fillBy`` yields one entry per key,
        in key order, with a falsy placeholder for missing rows -- confirm
        against its definition.
        """
        def key_of(row):
            return getattr(row, key_field, None)

        def build(row):
            if row:
                return self.constructor(**row._asdict())
            return None

        async def load_fn(keys: List[int]) -> List[Optional[LoaderType]]:
            # where `<key_field>` in `<keys>`
            rows = await self.model.load_all(**{key_field: keys})
            aligned = Loader.fillBy(keys, rows, key_of)
            return [build(row) for row in aligned]

        return DataLoader(load_fn)
async def test_cache_disabled_immediate_await(mocker):
    """Check that with ``cache=False`` each awaited load triggers its own batch.

    The same key is loaded and awaited twice in sequence; both results must
    be equal and the mocked load function must have been called with ``[1]``
    twice.
    """
    async def echo_keys(keys):
        return keys

    batch_fn = mocker.Mock(side_effect=echo_keys)
    loader = DataLoader(load_fn=batch_fn, cache=False)

    first_value = await loader.load(1)
    second_value = await loader.load(1)
    assert first_value == second_value

    expected_calls = [mocker.call([1]), mocker.call([1])]
    batch_fn.assert_has_calls(expected_calls)
# Example #15
# 0
def test_works_when_created_in_a_different_loop(mocker):
    """Check that a loader built outside a loop works on a fresh event loop.

    The DataLoader is constructed in synchronous code, then a load is run to
    completion on a newly created event loop. The load must resolve to ``1``
    and the mocked load function must have been called once with ``[1]``.
    """
    async def idx(keys):
        return keys

    mock_loader = mocker.Mock(side_effect=idx)
    loader = DataLoader(load_fn=mock_loader, cache=False)

    loop = asyncio.new_event_loop()

    async def run():
        return await loader.load(1)

    # Close the loop even if the load fails, so the test does not leak the
    # loop's resources (and emit a ResourceWarning).
    try:
        data = loop.run_until_complete(run())
    finally:
        loop.close()

    assert data == 1

    mock_loader.assert_called_once_with([1])
# Example #16
# 0
async def test_can_use_dataloaders(mocker):
    """Check that a DataLoader used inside resolvers batches across fields.

    Two aliased ``getUser`` fields are executed against the schema; the
    result must contain both ids and the mocked load function must have been
    called once with ``["1", "2"]``.
    """
    @dataclass
    class User:
        id: str

    async def build_users(keys) -> List[User]:
        return list(map(User, keys))

    batch_fn = mocker.Mock(side_effect=build_users)
    loader = DataLoader(load_fn=batch_fn)

    @strawberry.type
    class Query:
        @strawberry.field
        async def get_user(self, id: strawberry.ID) -> str:
            return (await loader.load(id)).id

    schema = strawberry.Schema(query=Query)

    query = """{
        a: getUser(id: "1")
        b: getUser(id: "2")
    }"""

    result = await schema.execute(query)

    assert not result.errors

    expected_data = {"a": "1", "b": "2"}
    assert result.data == expected_data

    batch_fn.assert_called_once_with(["1", "2"])
# Example #17
# 0
from strawberry.dataloader import DataLoader

from users.db import get_engine, get_session
from users.domain.entities import User
from users.domain.repository import UsersRepository


async def load_users(ids: list[int]) -> list[User]:
    """Batch-load users by id for use as a DataLoader load function.

    Args:
        ids: The user ids requested in this batch.

    Returns:
        One entry per requested id, in the same order as ``ids``. An id with
        no matching user yields ``None`` at its position.
    """
    # Acquire the engine before entering ``try``: if get_engine() itself
    # fails, the error must propagate unchanged instead of being masked by an
    # UnboundLocalError raised from the ``finally`` clause.
    engine = get_engine()
    try:
        async with get_session(engine) as session:
            users = await UsersRepository(session).get_batch_by_ids(ids)
            users_by_id = {user.id: user for user in users}
            # Re-align to the requested order; missing ids map to None.
            return [users_by_id.get(user_id) for user_id in ids]
    finally:
        await engine.dispose()


# Module-level DataLoader that routes user lookups through load_users.
# NOTE(review): as a module-level instance this loader (and whatever caching
# DataLoader applies by default) is shared by every importer -- confirm that
# is intended rather than creating one loader per request.
users_dataloader = DataLoader(load_fn=load_users)