async def can_unset_closed_state_of_async_iterator():
    """Clearing ``is_closed`` on wrapper and source lets iteration resume."""
    values = [1, 2, 3]

    class ClosableSource:
        def __init__(self):
            self.is_closed = False

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.is_closed:
                raise StopAsyncIteration
            try:
                return values.pop(0)
            except IndexError:
                raise StopAsyncIteration

        async def aclose(self):
            self.is_closed = True

    source = ClosableSource()
    doubled = MapAsyncIterator(source, lambda v: v + v)

    # Consume two values while everything is still open.
    assert await anext(doubled) == 2
    assert await anext(doubled) == 4
    assert not source.is_closed

    # Closing the wrapper also closes the wrapped iterator.
    await doubled.aclose()
    assert source.is_closed
    with raises(StopAsyncIteration):
        await anext(source)
    with raises(StopAsyncIteration):
        await anext(doubled)
    assert doubled.is_closed

    # Manually reopen both; iteration continues where it stopped.
    source.is_closed = doubled.is_closed = False
    assert not doubled.is_closed
    assert await anext(doubled) == 6
    assert not doubled.is_closed
    assert not source.is_closed

    # The source is now drained; the test expects this not to mark either closed.
    with raises(StopAsyncIteration):
        await anext(source)
    with raises(StopAsyncIteration):
        await anext(doubled)
    assert not doubled.is_closed
    assert not source.is_closed
async def allows_returning_early_from_mapped_async_iterator():
    """After ``aclose()`` on a wrapped plain iterator, ``anext`` stops iteration."""
    values = [1, 2, 3]

    class Source:
        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return values.pop(0)
            except IndexError:  # pragma: no cover
                raise StopAsyncIteration

    doubled = MapAsyncIterator(Source(), lambda v: v + v)
    for expected in (2, 4):
        assert await anext(doubled) == expected

    # Early return
    await doubled.aclose()

    # Subsequent next calls
    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(doubled)
async def allows_throwing_errors_through_async_iterators():
    """``athrow`` on a wrapped plain iterator re-raises, then iteration stops."""
    values = [1, 2, 3]

    class Source:
        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return values.pop(0)
            except IndexError:  # pragma: no cover
                raise StopAsyncIteration

    doubled = MapAsyncIterator(Source(), lambda v: v + v)
    assert await anext(doubled) == 2
    assert await anext(doubled) == 4

    # Throw error
    ouch = RuntimeError("Ouch")
    with raises(RuntimeError, match="Ouch") as raised:
        await doubled.athrow(ouch)
    assert str(raised.value) == "Ouch"

    # After the thrown error, the wrapper yields nothing more.
    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(doubled)
async def cancellable_aiter(
    async_iterator: MapAsyncIterator,
    cancellation_event: Event,
    *,
    cancel_pending: bool = True,
    timeout: Optional[float] = None,
) -> AsyncIterator:
    """Yield values from ``async_iterator`` until ``cancellation_event`` is set.

    Each value produced by the wrapped iterator is yielded as soon as it is
    available. If ``timeout`` is given and no value arrives within that many
    seconds, ``None`` is yielded and the timer restarts. The timer also
    restarts after every value.

    Iteration ends when the cancellation event is set or when the wrapped
    iterator is exhausted. Helper tasks still running when the generator
    finishes are cancelled. This covers normal completion, an exception, or
    the consumer closing the generator.

    Args:
        async_iterator (MapAsyncIterator): The iterator to consume.
        cancellation_event (Event): Setting this event stops the iteration.
        cancel_pending (bool, optional): If True, tasks still pending when
            cancellation is observed are cancelled. If False, each pending
            task (including a running heartbeat timer) is awaited and its
            result yielded before stopping. Defaults to True.
        timeout (Optional[float], optional): Seconds without a new value
            after which ``None`` is yielded. Defaults to None (no heartbeat).

    Returns:
        AsyncIterator: The async iterator
    """
    result_iter = async_iterator.__aiter__()
    cancellation_task = asyncio.create_task(cancellation_event.wait())
    pending: Set["Future[Any]"] = {
        cancellation_task,
        asyncio.create_task(result_iter.__anext__()),
    }
    if timeout is None:
        sleep_task: "Optional[Future[Any]]" = None
    else:
        sleep_task = asyncio.create_task(asyncio.sleep(timeout))
        pending.add(sleep_task)

    try:
        while not cancellation_event.is_set():
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for done_task in done:
                if done_task is cancellation_task:
                    for pending_task in pending:
                        if cancel_pending:
                            pending_task.cancel()
                            continue
                        try:
                            value = await pending_task
                        except StopAsyncIteration:
                            # The source ran dry while draining; nothing to yield.
                            continue
                        yield value
                    break
                elif done_task is sleep_task:
                    # No value arrived within ``timeout`` seconds: heartbeat.
                    yield None
                else:
                    try:
                        value = done_task.result()
                    except StopAsyncIteration:
                        # Letting StopAsyncIteration escape an async generator
                        # turns it into RuntimeError (PEP 525), so end cleanly.
                        return
                    yield value
                    pending.add(asyncio.create_task(result_iter.__anext__()))
            else:
                # No cancellation this round: restart the heartbeat timer.
                if timeout is not None:
                    if sleep_task in pending:
                        sleep_task.cancel()
                        pending.discard(sleep_task)
                    sleep_task = asyncio.create_task(asyncio.sleep(timeout))
                    pending.add(sleep_task)
    finally:
        # However we leave (exhaustion, error, outer cancellation, consumer
        # aclose), don't orphan the waiter / timer / in-flight __anext__ tasks.
        # cancel() is a no-op on tasks that already finished.
        for pending_task in pending:
            pending_task.cancel()
async def stops_async_iteration_on_close():
    """Closing the wrapper ends a blocked ``anext`` with StopAsyncIteration."""

    async def source():
        yield 1
        await Event().wait()  # Block forever
        yield 2  # pragma: no cover
        yield 3  # pragma: no cover

    singles = source()
    doubles = MapAsyncIterator(singles, lambda x: x * 2)

    assert await anext(doubles) == 2

    # The next value never arrives, so this call stays pending.
    pending_next = ensure_future(anext(doubles))
    await sleep(0.05)
    assert not pending_next.done()

    # Closing the wrapper should finish the pending call with StopAsyncIteration.
    await doubles.aclose()
    await sleep(0.05)
    assert pending_next.done()
    assert isinstance(pending_next.exception(), StopAsyncIteration)

    # The underlying generator is finished too.
    with raises(StopAsyncIteration):
        await anext(singles)
async def test():
    """Closing the wrapper while ``async for`` runs finalizes the source generator."""
    queue = asyncio.Queue()
    agen_closed = False

    async def make_agen():
        nonlocal agen_closed
        try:
            while True:
                yield await queue.get()
        finally:
            # Flag that the generator's cleanup actually ran.
            agen_closed = True

    mapped = MapAsyncIterator(make_agen(), lambda v: v)
    received = []

    async def consume():
        async for item in mapped:
            received.append(item)

    consumer = asyncio.ensure_future(consume())
    await asyncio.sleep(.01)

    # Feed values one at a time, giving the consumer a chance to pick each up.
    for value in range(5):
        await queue.put(value)
        await asyncio.sleep(.01)
    assert received == list(range(5))

    await mapped.aclose()
    await asyncio.sleep(.01)
    assert agen_closed
    assert consumer.done()
async def maps_over_async_values():
    """Every value from an async generator goes through the mapper."""

    async def source():
        for value in (1, 2, 3):
            yield value

    doubles = MapAsyncIterator(source(), lambda x: x + x)
    collected = [value async for value in doubles]
    assert collected == [2, 4, 6]
async def compatible_with_async_for():
    """The wrapper can be consumed with a plain ``async for`` loop."""

    async def source():
        yield 1
        yield 2
        yield 3

    results = []
    async for value in MapAsyncIterator(source(), lambda x: x + x):
        results.append(value)
    assert results == [2, 4, 6]
async def does_not_normally_map_over_externally_thrown_errors():
    """An error injected via ``athrow`` reaches the caller instead of being mapped."""

    async def source():
        yield "Hello"

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    first = await anext(mapped)
    assert first == "HelloHello"

    goodbye = RuntimeError("Goodbye")
    with raises(RuntimeError):
        await mapped.athrow(goodbye)
async def does_not_normally_map_over_thrown_errors():
    """Without an error callback, an exception raised by the source propagates."""
    # NOTE(review): a later function in this file reuses this exact name; if
    # both live in the same scope, that one shadows this one — confirm.

    async def source():
        yield 'Hello'
        raise RuntimeError('Goodbye')

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    first = await anext(mapped)
    assert first == 'HelloHello'

    # The source's RuntimeError comes out of the next anext() call.
    with raises(RuntimeError):
        await anext(mapped)
async def maps_over_async_values_with_async_function():
    """A coroutine function works as the mapper; its awaited results are yielded."""

    async def source():
        yield 1
        yield 2
        yield 3

    async def double(x):
        return x + x

    results = []
    async for value in MapAsyncIterator(source(), double):
        results.append(value)
    assert results == [2, 4, 6]
async def maps_over_async_generator():
    """Values from an async generator are mapped one at a time via ``anext``."""

    async def source():
        yield 1
        yield 2
        yield 3

    doubles = MapAsyncIterator(source(), lambda x: x + x)
    assert await anext(doubles) == 2
    assert await anext(doubles) == 4
    assert await anext(doubles) == 6
    # The call itself must raise. The former ``assert`` around it was dead
    # code, and a falsy return would have surfaced as AssertionError instead
    # of pytest's "DID NOT RAISE".
    with raises(StopAsyncIteration):
        await anext(doubles)
async def does_not_normally_map_over_thrown_errors():
    """Without an error callback, the source's exception (and message) propagates."""

    async def source():
        yield "Hello"
        raise RuntimeError("Goodbye")

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    first = await anext(mapped)
    assert first == "HelloHello"

    with raises(RuntimeError) as raised:
        await anext(mapped)
    assert str(raised.value) == "Goodbye"
async def maps_over_thrown_errors_if_second_callback_provided():
    """With an error callback, a source exception becomes a yielded value."""

    async def source():
        yield "Hello"
        raise RuntimeError("Goodbye")

    mapped = MapAsyncIterator(source(), lambda x: x + x, lambda error: error)
    assert await anext(mapped) == "HelloHello"

    # The error callback returned the exception itself, so it comes out as a value.
    mapped_error = await anext(mapped)
    assert isinstance(mapped_error, RuntimeError)
    assert str(mapped_error) == "Goodbye"

    with raises(StopAsyncIteration):
        await anext(mapped)
async def maps_over_async_iterator():
    """A hand-written async iterator (not a generator) can be wrapped and mapped."""
    remaining = [1, 2, 3]

    class CountingSource:
        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return remaining.pop(0)
            except IndexError:
                raise StopAsyncIteration

    mapped = MapAsyncIterator(CountingSource(), lambda x: x + x)
    results = []
    async for value in mapped:
        results.append(value)
    assert results == [2, 4, 6]
async def can_cancel_async_iterator_while_waiting():
    """Cancelling a consumer blocked in ``__anext__`` cancels and closes the source."""

    class SlowSource:
        def __init__(self):
            self.is_closed = False
            self.value = 1

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                await sleep(0.5)
                return self.value  # pragma: no cover
            except CancelledError:
                # Record that the in-flight wait was cancelled, then re-raise.
                self.value = -1
                raise

        async def aclose(self):
            self.is_closed = True

    source = SlowSource()
    doubles = MapAsyncIterator(source, lambda x: x + x)  # pragma: no cover exit

    cancelled = False

    async def consume():
        nonlocal cancelled
        try:
            async for _ in doubles:
                assert False  # pragma: no cover
        except CancelledError:
            cancelled = True

    consumer = ensure_future(consume())
    await sleep(0.05)

    # The consumer is still waiting inside SlowSource.__anext__.
    assert not cancelled
    assert not doubles.is_closed
    assert source.value == 1
    assert not source.is_closed

    consumer.cancel()
    await sleep(0.05)

    # Cancellation reached the source, and both iterators are now closed.
    assert cancelled
    assert source.value == -1
    assert doubles.is_closed
    assert source.is_closed
async def allows_returning_early_from_mapped_async_generator():
    """After ``aclose()`` on a wrapped generator, ``anext`` stops iteration."""

    async def source():
        yield 1
        yield 2
        yield 3  # pragma: no cover

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    for expected in (2, 4):
        assert await anext(mapped) == expected

    # Early return
    await mapped.aclose()

    # Subsequent next calls
    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(mapped)
async def allows_returning_early_from_async_values():
    """Closing early leaves the wrapper permanently exhausted."""

    async def source():
        yield 1
        yield 2
        yield 3

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    for expected in (2, 4):
        assert await anext(mapped) == expected

    # Early return
    await mapped.aclose()

    # Subsequent nexts
    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(mapped)
async def passes_through_caught_errors_through_async_generators():
    """An error the source generator catches does not escape ``athrow``."""

    async def source():
        try:
            yield 1
            yield 2
            yield 3  # pragma: no cover
        except Exception as e:
            yield e

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    assert await anext(mapped) == 2
    assert await anext(mapped) == 4

    # Throw error; the source handles it, so the test expects no exception here.
    await mapped.athrow(RuntimeError("ouch"))

    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(mapped)
async def allows_throwing_errors_through_async_generators():
    """An error thrown into an unprotected generator propagates, then iteration ends."""

    async def source():
        yield 1
        yield 2
        yield 3

    mapped = MapAsyncIterator(source(), lambda x: x + x)
    assert await anext(mapped) == 2
    assert await anext(mapped) == 4

    # Throw error
    with raises(RuntimeError) as raised:
        await mapped.athrow(RuntimeError("ouch"))
    assert str(raised.value) == "ouch"

    for _ in range(2):
        with raises(StopAsyncIteration):
            await anext(mapped)
async def stops_async_iteration_on_close():
    """Closing the wrapper ends a blocked ``anext`` with StopAsyncIteration."""

    async def source():
        yield 1
        await Event().wait()  # Block forever
        yield 2
        yield 3

    doubles = MapAsyncIterator(source(), lambda x: x * 2)
    first = await anext(doubles)
    assert first == 2

    # Block at event.wait()
    blocked = ensure_future(anext(doubles))
    await sleep(.01)
    assert not blocked.done()

    # Trigger cancellation and watch StopAsyncIteration propagate
    await doubles.aclose()
    await sleep(.01)
    assert blocked.done()
    assert isinstance(blocked.exception(), StopAsyncIteration)
async def passes_through_early_return_from_async_values():
    """Values yielded from the source's ``finally`` block are still visible after ``aclose``."""

    async def source():
        try:
            yield 1
            yield 2
            yield 3
        finally:
            yield 'done'
            yield 'last'

    doubles = MapAsyncIterator(source(), lambda x: x + x)
    assert await anext(doubles) == 2
    assert await anext(doubles) == 4

    # Early return
    await doubles.aclose()

    # Subsequent nexts may yield from finally block
    assert await anext(doubles) == 'lastlast'
    # The call itself must raise; an ``assert`` around it would be dead code.
    with raises(GeneratorExit):
        await anext(doubles)
async def passes_through_early_return_from_async_values():
    """Values yielded from the source's ``finally`` block are still visible after ``aclose``."""

    async def source():
        try:
            yield 1
            yield 2
            yield 3  # pragma: no cover
        finally:
            yield "Done"
            yield "Last"

    doubles = MapAsyncIterator(source(), lambda x: x + x)
    assert await anext(doubles) == 2
    assert await anext(doubles) == 4

    # Early return
    await doubles.aclose()

    # Subsequent next calls may yield from finally block
    assert await anext(doubles) == "LastLast"
    # The call itself must raise; an ``assert`` around it would be dead code.
    with raises(GeneratorExit):
        await anext(doubles)
async def allows_throwing_errors_with_traceback_through_async_iterators():
    """The three-argument ``athrow`` form raises with the supplied traceback."""

    class Ones:
        def __aiter__(self):
            return self

        async def __anext__(self):
            return 1

    identity = MapAsyncIterator(Ones(), lambda x: x)
    assert await anext(identity) == 1

    # Throw error with traceback passed separately
    try:
        raise RuntimeError("Ouch")
    except RuntimeError as err:
        with raises(RuntimeError) as raised:
            await identity.athrow(err.__class__, None, err.__traceback__)

        # The raised error must carry the very traceback we handed in.
        assert raised.tb and err.__traceback__
        assert raised.tb.tb_frame is err.__traceback__.tb_frame

    with raises(StopAsyncIteration):
        await anext(identity)
async def can_use_simple_iterator_instead_of_generator():
    """Run identical MapAsyncIterator scenarios over a generator and a plain iterator.

    ``source`` (an async generator) and ``Source`` (a class implementing
    ``__aiter__``/``__anext__``) both produce 1, 2, 3. Every scenario in the
    loop body is executed once for each of them.
    """

    async def source():
        yield 1
        yield 2
        yield 3

    class Source:
        # Hand-written async iterator: returns 1, 2, 3, then stops.
        def __init__(self):
            self.counter = 0

        def __aiter__(self):
            return self

        async def __anext__(self):
            self.counter += 1
            if self.counter > 3:
                raise StopAsyncIteration
            return self.counter

    def double(x):
        return x + x

    for iterator in source, Source:
        # Closing before consuming anything: no values come out.
        doubles = MapAsyncIterator(iterator(), double)
        await doubles.aclose()
        with raises(StopAsyncIteration):
            await anext(doubles)

        # Full consumption followed by exhaustion.
        doubles = MapAsyncIterator(iterator(), double)
        assert await anext(doubles) == 2
        assert await anext(doubles) == 4
        assert await anext(doubles) == 6
        with raises(StopAsyncIteration):
            await anext(doubles)

        # Partial consumption, then an error thrown in from outside.
        doubles = MapAsyncIterator(iterator(), double)
        assert await anext(doubles) == 2
        assert await anext(doubles) == 4

        # Throw error
        with raises(RuntimeError) as exc_info:
            await doubles.athrow(RuntimeError("ouch"))
        assert str(exc_info.value) == "ouch"

        with raises(StopAsyncIteration):
            await anext(doubles)
        with raises(StopAsyncIteration):
            await anext(doubles)

        # Once exhausted, a further athrow is expected not to raise, and
        # aclose() afterwards is tolerated as well.
        await doubles.athrow(RuntimeError("no more ouch"))

        with raises(StopAsyncIteration):
            await anext(doubles)

        await doubles.aclose()

        # Throwing with an explicit (type, value, traceback) triple.
        doubles = MapAsyncIterator(iterator(), double)
        assert await anext(doubles) == 2
        assert await anext(doubles) == 4

        try:
            raise ValueError("bad")
        except ValueError:
            tb = sys.exc_info()[2]

        # Throw error
        with raises(ValueError):
            await doubles.athrow(ValueError, None, tb)

    # NOTE(review): presumably yields once to the event loop so pending
    # callbacks can run before the test ends — confirm.
    await sleep(0)