async def can_unset_closed_state_of_async_iterator():
        """Check that a closed MapAsyncIterator can be revived by resetting flags.

        After ``aclose()`` the mapped iterator and the wrapped iterator both
        report closed and raise StopAsyncIteration. Manually setting both
        ``is_closed`` flags back to False lets the remaining source item
        flow through again.
        """
        items = [1, 2, 3]

        class Iterator:
            # Minimal async iterator that records whether aclose() was called
            # and refuses to produce items while that flag is set.
            def __init__(self):
                self.is_closed = False

            def __aiter__(self):
                return self

            async def __anext__(self):
                if self.is_closed:
                    raise StopAsyncIteration
                try:
                    return items.pop(0)
                except IndexError:
                    raise StopAsyncIteration

            async def aclose(self):
                self.is_closed = True

        iterator = Iterator()
        doubles = MapAsyncIterator(iterator, lambda x: x + x)

        assert await anext(doubles) == 2
        assert await anext(doubles) == 4
        assert not iterator.is_closed
        # Closing the wrapper is expected to close the wrapped iterator too.
        await doubles.aclose()
        assert iterator.is_closed
        with raises(StopAsyncIteration):
            await anext(iterator)
        with raises(StopAsyncIteration):
            await anext(doubles)
        assert doubles.is_closed

        # Reset both closed flags by hand.
        iterator.is_closed = False
        doubles.is_closed = False
        assert not doubles.is_closed

        # Only item 3 is left in `items`, so exactly one more mapped value.
        assert await anext(doubles) == 6
        assert not doubles.is_closed
        assert not iterator.is_closed
        # Running out of items does not set either closed flag.
        with raises(StopAsyncIteration):
            await anext(iterator)
        with raises(StopAsyncIteration):
            await anext(doubles)
        assert not doubles.is_closed
        assert not iterator.is_closed
    async def allows_returning_early_from_mapped_async_iterator():
        """Closing a mapped plain async iterator early ends the iteration."""
        remaining = [1, 2, 3]

        class Numbers:
            def __aiter__(self):
                return self

            async def __anext__(self):
                try:
                    return remaining.pop(0)
                except IndexError:  # pragma: no cover
                    raise StopAsyncIteration

        def double(value):
            return value + value

        mapped = MapAsyncIterator(Numbers(), double)

        for expected in (2, 4):
            assert await anext(mapped) == expected

        # Stop early, before the third item is requested
        await mapped.aclose()

        # Every further request reports exhaustion
        for _attempt in range(2):
            with raises(StopAsyncIteration):
                await anext(mapped)
    async def allows_throwing_errors_through_async_iterators():
        """An error thrown into a mapped async iterator surfaces to the caller
        and the mapped iterator is exhausted afterwards."""
        pool = [1, 2, 3]

        class Numbers:
            def __aiter__(self):
                return self

            async def __anext__(self):
                try:
                    return pool.pop(0)
                except IndexError:  # pragma: no cover
                    raise StopAsyncIteration

        mapped = MapAsyncIterator(Numbers(), lambda n: n + n)

        for expected in (2, 4):
            assert await anext(mapped) == expected

        # Inject an error into the mapped iterator
        injected = RuntimeError("Ouch")
        with raises(RuntimeError, match="Ouch") as caught:
            await mapped.athrow(injected)

        assert str(caught.value) == "Ouch"

        for _attempt in range(2):
            with raises(StopAsyncIteration):
                await anext(mapped)
async def cancellable_aiter(async_iterator: "MapAsyncIterator",
                            cancellation_event: Event,
                            *,
                            cancel_pending: bool = True,
                            timeout: Optional[float] = None) -> AsyncIterator:
    """Forward items from ``async_iterator`` until ``cancellation_event`` is set.

    Each ``__anext__`` call on the wrapped iterator runs as a task raced
    against the cancellation event (and, optionally, an idle timer), so a
    slow source never blocks cancellation.

    Args:
        async_iterator (MapAsyncIterator): The iterator whose items are
            forwarded.
        cancellation_event (Event): Once set, iteration stops.
        cancel_pending (bool, optional): What to do with tasks still in
            flight when the cancellation event fires. If True they are
            cancelled; if False each is awaited and its result yielded
            before stopping. Defaults to True.
        timeout (Optional[float], optional): If given, ``None`` is yielded
            whenever this many seconds pass without other activity.
            Defaults to None (no idle ticks).

    Returns:
        AsyncIterator: Items from ``async_iterator`` interleaved with
        ``None`` on idle timeouts. Iteration also ends normally when the
        wrapped iterator is exhausted.
    """
    result_iter = async_iterator.__aiter__()
    cancellation_task = asyncio.create_task(cancellation_event.wait())
    next_task = asyncio.create_task(result_iter.__anext__())
    pending: Set["Future[Any]"] = {cancellation_task, next_task}

    if timeout is None:
        sleep_task: "Optional[Future[Any]]" = None
    else:
        sleep_task = asyncio.create_task(asyncio.sleep(timeout))
        pending.add(sleep_task)

    try:
        while not cancellation_event.is_set():
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED)

            # Handle a ready item before the cancellation signal so a value
            # completing in the same round is not silently dropped.
            if next_task in done:
                try:
                    value = next_task.result()
                except StopAsyncIteration:
                    # Source exhausted. Letting StopAsyncIteration escape an
                    # async generator turns it into RuntimeError (PEP 525).
                    return
                yield value
                next_task = asyncio.create_task(result_iter.__anext__())
                pending.add(next_task)

            if sleep_task is not None and sleep_task in done:
                yield None

            if cancellation_task in done:
                for pending_task in pending:
                    if cancel_pending:
                        pending_task.cancel()
                        continue
                    try:
                        result = await pending_task
                    except StopAsyncIteration:
                        # The in-flight __anext__ found the source exhausted.
                        continue
                    yield result
                return

            if timeout is not None:
                # Restart the idle timer after every round of activity.
                if sleep_task in pending:
                    sleep_task.cancel()
                    pending.discard(sleep_task)
                sleep_task = asyncio.create_task(asyncio.sleep(timeout))
                pending.add(sleep_task)
    finally:
        # Runs on normal exit, on an early aclose() by the consumer and when
        # the consumer is cancelled: never leave orphaned tasks behind.
        for pending_task in pending:
            pending_task.cancel()
    async def stops_async_iteration_on_close():
        """Closing the mapped iterator ends an in-flight ``anext`` call.

        The source blocks forever after its first item. ``aclose()`` must
        make the pending ``anext`` finish with StopAsyncIteration, and the
        source generator must also be exhausted afterwards.
        """
        async def source():
            yield 1
            await Event().wait()  # Block forever
            yield 2  # pragma: no cover
            yield 3  # pragma: no cover

        singles = source()
        doubles = MapAsyncIterator(singles, lambda x: x * 2)

        result = await anext(doubles)
        assert result == 2

        # Make sure it is blocked
        # (the short sleep only gives the scheduled anext a chance to run)
        doubles_future = ensure_future(anext(doubles))
        await sleep(0.05)
        assert not doubles_future.done()

        # Unblock and watch StopAsyncIteration propagate
        await doubles.aclose()
        await sleep(0.05)
        assert doubles_future.done()
        assert isinstance(doubles_future.exception(), StopAsyncIteration)

        # The wrapped generator itself is finished as well.
        with raises(StopAsyncIteration):
            await anext(singles)
# Example 6
0
async def test():
    """Check that closing a MapAsyncIterator stops an ``async for`` consumer.

    Items are fed through a queue-backed async generator. After five items
    have been received, ``aclose()`` must run the generator's ``finally``
    block and let the consuming task finish.
    """
    queue = asyncio.Queue()
    # Set by the generator's finally block, i.e. when the generator is closed.
    agen_closed = False

    async def make_agen():
        try:
            while True:
                yield await queue.get()
        finally:
            nonlocal agen_closed
            agen_closed = True

    agen = make_agen()
    mai = MapAsyncIterator(agen, lambda v: v)

    received = []

    async def run():
        async for i in mai:
            received.append(i)

    task = asyncio.ensure_future(run())
    # Short sleeps presumably just let the consumer task advance; timing-based.
    await asyncio.sleep(.01)

    for i in range(5):
        await queue.put(i)
        await asyncio.sleep(.01)

    assert received == list(range(5))

    await mai.aclose()
    await asyncio.sleep(.01)
    assert agen_closed
    assert task.done()
    async def maps_over_async_values():
        """Every value from an async generator is passed through the mapper."""
        async def source():
            for number in (1, 2, 3):
                yield number

        mapped = MapAsyncIterator(source(), lambda value: value + value)

        collected = []
        async for item in mapped:
            collected.append(item)
        assert collected == [2, 4, 6]
    async def compatible_with_async_for():
        """A MapAsyncIterator can be consumed by ``async for``."""
        async def source():
            for number in range(1, 4):
                yield number

        def double(value):
            return value + value

        seen = []
        async for item in MapAsyncIterator(source(), double):
            seen.append(item)
        assert seen == [2, 4, 6]
    async def does_not_normally_map_over_externally_thrown_errors():
        """An error thrown in via ``athrow`` is re-raised rather than mapped."""
        async def source():
            yield "Hello"

        def double(text):
            return text + text

        mapped = MapAsyncIterator(source(), double)
        first = await anext(mapped)
        assert first == "HelloHello"

        with raises(RuntimeError):
            await mapped.athrow(RuntimeError("Goodbye"))
# Example 10
0
    async def does_not_normally_map_over_thrown_errors():
        """An error raised inside the source propagates unmapped."""
        async def source():
            yield 'Hello'
            raise RuntimeError('Goodbye')

        def double(text):
            return text + text

        mapped = MapAsyncIterator(source(), double)

        first = await anext(mapped)
        assert first == 'HelloHello'

        with raises(RuntimeError):
            await anext(mapped)
    async def maps_over_async_values_with_async_function():
        """An async mapper is awaited for every source value."""
        async def source():
            for number in range(1, 4):
                yield number

        async def double(value):
            return 2 * value

        mapped = MapAsyncIterator(source(), double)
        results = [item async for item in mapped]
        assert results == [2, 4, 6]
    async def maps_over_async_generator():
        """Stepping a mapped async generator by hand yields mapped values,
        then StopAsyncIteration once the source is exhausted."""
        async def source():
            yield 1
            yield 2
            yield 3

        doubles = MapAsyncIterator(source(), lambda x: x + x)

        assert await anext(doubles) == 2
        assert await anext(doubles) == 4
        assert await anext(doubles) == 6
        # anext must raise here, so there is no value to assert on.
        with raises(StopAsyncIteration):
            await anext(doubles)
    async def does_not_normally_map_over_thrown_errors():
        """Errors raised by the source propagate unchanged instead of mapped."""
        async def source():
            yield "Hello"
            raise RuntimeError("Goodbye")

        def double(text):
            return text + text

        mapped = MapAsyncIterator(source(), double)
        assert await anext(mapped) == "HelloHello"

        with raises(RuntimeError) as caught:
            await anext(mapped)

        message = str(caught.value)
        assert message == "Goodbye"
    async def maps_over_thrown_errors_if_second_callback_provided():
        """With an error callback, a source error becomes a mapped value."""
        async def source():
            yield "Hello"
            raise RuntimeError("Goodbye")

        def double(text):
            return text + text

        def keep_error(error):
            return error

        mapped = MapAsyncIterator(source(), double, keep_error)
        assert await anext(mapped) == "HelloHello"

        mapped_error = await anext(mapped)
        assert isinstance(mapped_error, RuntimeError)
        assert str(mapped_error) == "Goodbye"

        with raises(StopAsyncIteration):
            await anext(mapped)
    async def maps_over_async_iterator():
        """A hand-written async iterator (not a generator) can be mapped."""
        queue = [1, 2, 3]

        class Numbers:
            def __aiter__(self):
                return self

            async def __anext__(self):
                if not queue:
                    raise StopAsyncIteration
                return queue.pop(0)

        mapped = MapAsyncIterator(Numbers(), lambda n: n + n)

        collected = []
        async for item in mapped:
            collected.append(item)
        assert collected == [2, 4, 6]
    async def can_cancel_async_iterator_while_waiting():
        """Cancelling a consumer blocked in ``__anext__`` closes everything.

        The source's ``__anext__`` sleeps. Cancelling the consuming task must
        deliver CancelledError into it (observed through ``value == -1``) and
        leave both the mapped and the wrapped iterator closed.
        """
        class Iterator:
            def __init__(self):
                self.is_closed = False
                # Flipped to -1 when __anext__ sees CancelledError.
                self.value = 1

            def __aiter__(self):
                return self

            async def __anext__(self):
                try:
                    await sleep(0.5)
                    return self.value  # pragma: no cover
                except CancelledError:
                    self.value = -1
                    raise

            async def aclose(self):
                self.is_closed = True

        iterator = Iterator()
        doubles = MapAsyncIterator(iterator,
                                   lambda x: x + x)  # pragma: no cover exit
        cancelled = False

        async def iterator_task():
            nonlocal cancelled
            try:
                async for _ in doubles:
                    assert False  # pragma: no cover
            except CancelledError:
                cancelled = True

        task = ensure_future(iterator_task())
        # Let the consumer start and block inside __anext__ (timing-based).
        await sleep(0.05)
        assert not cancelled
        assert not doubles.is_closed
        assert iterator.value == 1
        assert not iterator.is_closed
        task.cancel()
        await sleep(0.05)
        assert cancelled
        assert iterator.value == -1
        assert doubles.is_closed
        assert iterator.is_closed
    async def allows_returning_early_from_mapped_async_generator():
        """aclose() on a mapped async generator ends iteration immediately."""
        async def source():
            for number in (1, 2, 3):
                yield number

        mapped = MapAsyncIterator(source(), lambda n: n + n)

        for expected in (2, 4):
            assert await anext(mapped) == expected

        # Close before the last value is requested
        await mapped.aclose()

        # Every later request reports exhaustion
        for _attempt in range(2):
            with raises(StopAsyncIteration):
                await anext(mapped)
    async def allows_returning_early_from_async_values():
        """After an early aclose(), the mapped iterator stays exhausted."""
        async def source():
            for number in range(1, 4):
                yield number

        def double(value):
            return value + value

        mapped = MapAsyncIterator(source(), double)

        first, second = await anext(mapped), await anext(mapped)
        assert first == 2
        assert second == 4

        # Return early
        await mapped.aclose()

        async def expect_exhausted():
            with raises(StopAsyncIteration):
                await anext(mapped)

        await expect_exhausted()
        await expect_exhausted()
    async def passes_through_caught_errors_through_async_generators():
        """If the source catches a thrown error, athrow does not raise."""
        async def source():
            try:
                for number in (1, 2, 3):
                    yield number
            except Exception as error:
                yield error

        mapped = MapAsyncIterator(source(), lambda n: n + n)

        for expected in (2, 4):
            assert await anext(mapped) == expected

        # The source's except clause swallows this error
        await mapped.athrow(RuntimeError("ouch"))

        for _attempt in range(2):
            with raises(StopAsyncIteration):
                await anext(mapped)
    async def allows_throwing_errors_through_async_generators():
        """An error thrown into a mapped async generator is re-raised and
        the iterator is exhausted afterwards."""
        async def source():
            for number in range(1, 4):
                yield number

        mapped = MapAsyncIterator(source(), lambda n: n + n)

        assert [await anext(mapped), await anext(mapped)] == [2, 4]

        # The source has no handler, so athrow re-raises the error
        with raises(RuntimeError) as caught:
            await mapped.athrow(RuntimeError("ouch"))

        assert str(caught.value) == "ouch"

        for _attempt in range(2):
            with raises(StopAsyncIteration):
                await anext(mapped)
# Example 21
0
    async def stops_async_iteration_on_close():
        """Closing the mapped iterator ends an ``anext`` that is blocked.

        The source blocks forever after its first item. ``aclose()`` must
        make the pending ``anext`` future finish with StopAsyncIteration.
        """
        async def source():
            yield 1
            await Event().wait()  # Block forever
            yield 2
            yield 3

        doubles = MapAsyncIterator(source(), lambda x: x * 2)

        result = await anext(doubles)
        assert result == 2

        # Block at event.wait()
        fut = ensure_future(anext(doubles))
        await sleep(.01)
        assert not fut.done()

        # Trigger cancellation and watch StopAsyncIteration propagate
        await doubles.aclose()
        await sleep(.01)
        assert fut.done()
        assert isinstance(fut.exception(), StopAsyncIteration)
# Example 22
0
    async def passes_through_early_return_from_async_values():
        """A source whose ``finally`` block yields can still emit after aclose.

        After closing, the next value comes from the source's ``finally``
        block ('last' mapped to 'lastlast'). The following ``anext`` then
        raises GeneratorExit.
        """
        async def source():
            try:
                yield 1
                yield 2
                yield 3
            finally:
                yield 'done'
                yield 'last'

        doubles = MapAsyncIterator(source(), lambda x: x + x)

        assert await anext(doubles) == 2
        assert await anext(doubles) == 4

        # Early return
        await doubles.aclose()

        # Subsequent nexts may yield from finally block
        assert await anext(doubles) == 'lastlast'
        # anext must raise here, so there is no value to assert on.
        with raises(GeneratorExit):
            await anext(doubles)
    async def passes_through_early_return_from_async_values():
        """A source whose ``finally`` block yields can still emit after aclose.

        After closing, the next value comes from the source's ``finally``
        block ("Last" mapped to "LastLast"). The following ``anext`` then
        raises GeneratorExit.
        """
        async def source():
            try:
                yield 1
                yield 2
                yield 3  # pragma: no cover
            finally:
                yield "Done"
                yield "Last"

        doubles = MapAsyncIterator(source(), lambda x: x + x)

        assert await anext(doubles) == 2
        assert await anext(doubles) == 4

        # Early return
        await doubles.aclose()

        # Subsequent next calls may yield from finally block
        assert await anext(doubles) == "LastLast"
        # anext must raise here, so there is no value to assert on.
        with raises(GeneratorExit):
            await anext(doubles)
    async def allows_throwing_errors_with_traceback_through_async_iterators():
        """athrow accepts (type, value, traceback) and keeps the traceback."""
        class Ones:
            def __aiter__(self):
                return self

            async def __anext__(self):
                return 1

        def identity(value):
            return value

        mapped = MapAsyncIterator(Ones(), identity)

        first = await anext(mapped)
        assert first == 1

        # Raise for real so the error carries a traceback, then hand the
        # pieces to athrow separately (three-argument form)
        try:
            raise RuntimeError("Ouch")
        except RuntimeError as error:
            with raises(RuntimeError) as caught:
                await mapped.athrow(error.__class__, None, error.__traceback__)

            assert caught.tb and error.__traceback__
            assert caught.tb.tb_frame is error.__traceback__.tb_frame

        with raises(StopAsyncIteration):
            await anext(mapped)
    async def can_use_simple_iterator_instead_of_generator():
        """Run the same MapAsyncIterator scenarios on two kinds of source.

        Each scenario runs once with an async generator (``source``) and once
        with a hand-written async iterator class (``Source``); both produce
        1, 2, 3. The scenarios are:

        * closing before first use;
        * full consumption;
        * ``athrow`` mid-stream and after exhaustion;
        * ``athrow`` in the (type, value, traceback) form.
        """
        async def source():
            yield 1
            yield 2
            yield 3

        class Source:
            # Counter-based equivalent of ``source`` without generator syntax.
            def __init__(self):
                self.counter = 0

            def __aiter__(self):
                return self

            async def __anext__(self):
                self.counter += 1
                if self.counter > 3:
                    raise StopAsyncIteration
                return self.counter

        def double(x):
            return x + x

        for iterator in source, Source:
            # Closing before any item is requested ends iteration at once.
            doubles = MapAsyncIterator(iterator(), double)

            await doubles.aclose()

            with raises(StopAsyncIteration):
                await anext(doubles)

            # Full consumption: three mapped values, then exhaustion.
            doubles = MapAsyncIterator(iterator(), double)

            assert await anext(doubles) == 2
            assert await anext(doubles) == 4
            assert await anext(doubles) == 6

            with raises(StopAsyncIteration):
                await anext(doubles)

            # athrow mid-stream re-raises and leaves the iterator exhausted.
            doubles = MapAsyncIterator(iterator(), double)

            assert await anext(doubles) == 2
            assert await anext(doubles) == 4

            # Throw error
            with raises(RuntimeError) as exc_info:
                await doubles.athrow(RuntimeError("ouch"))

            assert str(exc_info.value) == "ouch"

            with raises(StopAsyncIteration):
                await anext(doubles)
            with raises(StopAsyncIteration):
                await anext(doubles)

            # Once exhausted, athrow does not raise and aclose can still be
            # called.
            await doubles.athrow(RuntimeError("no more ouch"))

            with raises(StopAsyncIteration):
                await anext(doubles)

            await doubles.aclose()

            # athrow with the (type, value, traceback) signature.
            doubles = MapAsyncIterator(iterator(), double)

            assert await anext(doubles) == 2
            assert await anext(doubles) == 4

            try:
                raise ValueError("bad")
            except ValueError:
                tb = sys.exc_info()[2]

            # Throw error
            with raises(ValueError):
                await doubles.athrow(ValueError, None, tb)

        # NOTE(review): presumably lets pending callbacks run before the test
        # ends; confirm whether this is still needed.
        await sleep(0)