コード例 #1
0
async def test_cancel_worker_thread(cancellable, expected_last_active):
    """
    Cancel a task that is waiting on a worker thread, then check which side wrote
    ``last_active`` last.

    ``'task'`` means the task's ``finally`` clause ran after the thread had finished;
    ``'thread'`` means the task was released before the thread completed. ``cancellable``
    is forwarded to ``to_thread.run_sync``.
    """
    thread_sleeping = Event()
    thread_done = Event()
    last_active = None

    def thread_worker():
        nonlocal last_active
        # Tell the event loop the thread is running so the host can cancel the task now.
        from_thread.run_sync(thread_sleeping.set)
        time.sleep(0.2)
        last_active = 'thread'
        from_thread.run_sync(thread_done.set)

    async def task_worker():
        nonlocal last_active
        try:
            await to_thread.run_sync(thread_worker, cancellable=cancellable)
        finally:
            # Runs on both normal completion and cancellation.
            last_active = 'task'

    async with create_task_group() as tg:
        tg.start_soon(task_worker)
        await thread_sleeping.wait()
        tg.cancel_scope.cancel()

    # The thread may still be running after the task group exits; wait for it to finish.
    await thread_done.wait()
    assert last_active == expected_last_active
コード例 #2
0
    async def test_event(self):
        """A task started in a task group can set an event the host task is waiting on."""
        evt = Event()

        async def set_it():
            # The event must still be clear when this task first runs.
            assert not evt.is_set()
            evt.set()

        async with create_task_group() as tg:
            tg.start_soon(set_it)
            await evt.wait()

        assert evt.is_set()
コード例 #3
0
    async def test_event_cancel(self):
        """A task waiting on an event is cancelled before it can observe the event being set."""
        started = False
        woke_up = False
        evt = Event()

        async def waiter():
            nonlocal started, woke_up
            started = True
            await evt.wait()
            # Must never be reached: the cancel scope fires before the event is set.
            woke_up = True

        async with create_task_group() as tg:
            tg.start_soon(waiter)
            tg.cancel_scope.cancel()
            evt.set()

        assert started
        assert not woke_up
コード例 #4
0
    async def test_handshake_fail(self, server_context):
        """
        A client that connects over plain TCP and disconnects without a TLS handshake must
        make the listener report a ``BrokenResourceError`` through ``handle_handshake_error``.
        The connection handler must never be invoked.
        """
        def handler(stream):
            pytest.fail('This function should never be called in this scenario')

        class CustomTLSListener(TLSListener):
            async def handle_handshake_error(self, exc: BaseException,
                                             stream: AnyByteStream) -> None:
                nonlocal exception
                await super().handle_handshake_error(exc, stream)
                # The stream handed to the error hook is the raw socket stream.
                assert isinstance(stream, SocketStream)
                exception = exc
                event.set()

        exception = None
        event = Event()
        listener = await create_tcp_listener(local_host='127.0.0.1')
        tls_listener = CustomTLSListener(listener, server_context)
        async with tls_listener, create_task_group() as tg:
            tg.start_soon(tls_listener.serve, handler)
            # Connect and immediately disconnect without speaking TLS. The context manager
            # closes the socket even if connect() raises, so it cannot leak.
            with socket.socket() as sock:
                sock.connect(listener.extra(SocketAttribute.local_address))
            await event.wait()
            tg.cancel_scope.cancel()

        assert isinstance(exception, BrokenResourceError)
コード例 #5
0
async def test_get_running_tasks() -> None:
    """get_running_tasks() reports tasks spawned in a task group with their names and parent."""
    event = Event()
    task_infos: List[TaskInfo] = []
    host_task = get_current_task()

    async def inspect() -> None:
        # Let both waiter tasks reach event.wait() before taking the snapshot.
        await wait_all_tasks_blocked()
        spawned = set(get_running_tasks()) - baseline
        task_infos[:] = sorted(spawned, key=lambda info: info.name or "")
        event.set()

    async with create_task_group() as tg:
        # Snapshot the tasks that existed before spawning so only new ones are inspected.
        baseline = set(get_running_tasks())
        for waiter_name in ("task1", "task2"):
            tg.start_soon(event.wait, name=waiter_name)
        tg.start_soon(inspect)

    expected_names = [
        "task1",
        "task2",
        "tests.test_debugging.test_get_running_tasks.<locals>.inspect",
    ]
    assert len(task_infos) == len(expected_names)
    for task, expected_name in zip(task_infos, expected_names):
        assert task.parent_id == host_task.id
        assert task.name == expected_name
        assert repr(task) == f"TaskInfo(id={task.id}, name={expected_name!r})"
コード例 #6
0
def test_wait_generator_based_task_blocked():
    """
    Check that wait_all_tasks_blocked() returns only once a generator-based (``@coroutine``)
    task is suspended inside ``Event.wait()``.

    NOTE(review): ``asyncio.coroutine`` was deprecated in Python 3.8 and removed in 3.11, so
    the import below fails on newer interpreters. Presumably this test is skipped or
    retired there; confirm.
    """
    # Local import: within this function ``Event`` is asyncio's, shadowing any module-level Event.
    from asyncio import DefaultEventLoopPolicy, Event, coroutine, set_event_loop

    async def native_coro_part():
        await wait_all_tasks_blocked()
        # ``gen_task._coro`` (a private Task attribute) is the generator from generator_part().
        # By now it must be suspended, not executing...
        assert not gen_task._coro.gi_running
        # ...and delegating via ``yield from`` to the event's ``wait``. The delegate exposes
        # ``gi_code`` on Python < 3.7 and ``cr_code`` otherwise.
        if sys.version_info < (3, 7):
            assert gen_task._coro.gi_yieldfrom.gi_code.co_name == 'wait'
        else:
            assert gen_task._coro.gi_yieldfrom.cr_code.co_name == 'wait'

        # Wake the generator-based task that is waiting on the event.
        event.set()

    @coroutine
    def generator_part():
        yield from event.wait()

    # Create a new loop from the default policy and install it as the current event loop.
    loop = DefaultEventLoopPolicy().new_event_loop()
    try:
        set_event_loop(loop)
        event = Event()
        gen_task = loop.create_task(generator_part())
        loop.run_until_complete(native_coro_part())
    finally:
        # Unset the current event loop and close the one created above.
        set_event_loop(None)
        loop.close()
コード例 #7
0
async def test_multi_listener(tmp_path_factory):
    """
    A MultiListener serves a TCP listener (plus a UNIX listener off Windows) at once.

    One client connects to each listener in turn. The handler must see every connection,
    and the remote addresses it records must match the clients' local addresses.
    """
    client_addresses = []
    expected_addresses = []

    async def handle(stream):
        client_addresses.append(stream.extra(SocketAttribute.remote_address))
        # ``event`` is rebound for every listener below; this sets the current one.
        event.set()
        await stream.aclose()

    listeners = [await create_tcp_listener(local_host='localhost')]
    if sys.platform != 'win32':
        unix_path = tmp_path_factory.mktemp('unix').joinpath('socket')
        listeners.append(await create_unix_listener(unix_path))

    async with MultiListener(listeners) as multi_listener:
        async with create_task_group() as tg:
            tg.start_soon(multi_listener.serve, handle)
            for listener in multi_listener.listeners:
                event = Event()
                local_address = listener.extra(SocketAttribute.local_address)
                # Short-circuits on Windows, so the family is only queried elsewhere.
                is_unix = (
                    sys.platform != 'win32'
                    and listener.extra(SocketAttribute.family) == socket.AddressFamily.AF_UNIX
                )
                if is_unix:
                    stream = await connect_unix(local_address)
                else:
                    stream = await connect_tcp(*local_address)

                expected_addresses.append(stream.extra(SocketAttribute.local_address))
                # Wait until the handler has recorded this connection before moving on.
                await event.wait()
                await stream.aclose()

            tg.cancel_scope.cancel()

    assert client_addresses == expected_addresses
コード例 #8
0
    async def remote_select_channel_contents(self, **kwargs):
        """
        Query the peers subscribed to a channel and return the first set of results received.

        Requests are staggered. After each request, the next peer is queried only if no
        response arrived within ``happy_eyeballs_delay``. Once any response has arrived,
        the remaining requests are cancelled.

        :param kwargs: forwarded to ``send_remote_select``. Must include ``channel_pk``
            and ``origin_id``.
        :return: ``to_simple_dict()`` of each ``md_obj`` from the first responding peer.
        :raises NoChannelSourcesException: if no known peer is subscribed to the channel.
        """
        peers = self.get_known_subscribed_peers_for_node(
            kwargs["channel_pk"], kwargs["origin_id"])
        if not peers:
            raise NoChannelSourcesException()

        result = []
        async with create_task_group() as tg:
            first_response = Event()

            async def _send_remote_select(peer):
                request = self.send_remote_select(peer, force_eva_response=True, **kwargs)
                await request.processing_results

                # Another coroutine already delivered its results; discard ours.
                if result or first_response.is_set():
                    return

                result.extend(request.processing_results.result())
                first_response.set()

            for peer in peers:
                # A previous peer may have answered while we were waiting.
                if first_response.is_set():
                    break
                tg.start_soon(_send_remote_select, peer)
                # Give this peer a head start before querying the next one.
                with move_on_after(happy_eyeballs_delay):
                    await first_response.wait()

            # Block until some peer (possibly one started earlier) has answered.
            await first_response.wait()

            # Cancel outstanding requests so the task group does not wait for them.
            tg.cancel_scope.cancel()

        return [r.md_obj.to_simple_dict() for r in result]
コード例 #9
0
    async def test_increase_tokens(self):
        """Raising ``total_tokens`` lets a second task enter a limiter whose only token is held."""
        limiter = CapacityLimiter(1)
        holder_inside = Event()
        second_entered = Event()

        async def second_borrower():
            # Only try to acquire once the first borrower holds the sole token.
            await holder_inside.wait()
            async with limiter:
                # Reachable only after total_tokens has been raised.
                second_entered.set()

        async def first_borrower():
            async with limiter:
                holder_inside.set()
                await second_entered.wait()

        async with create_task_group() as tg:
            tg.start_soon(second_borrower)
            tg.start_soon(first_borrower)
            await wait_all_tasks_blocked()
            assert holder_inside.is_set()
            assert not second_entered.is_set()
            limiter.total_tokens = 2

        assert second_entered.is_set()
コード例 #10
0
    async def test_statistics(self):
        """``Event.statistics().tasks_waiting`` tracks how many tasks are blocked in ``wait()``."""
        event = Event()

        async def waiter():
            await event.wait()

        async with create_task_group() as tg:
            assert event.statistics().tasks_waiting == 0
            for expected_waiting in (1, 2):
                tg.start_soon(waiter)
                # Let the new waiter reach event.wait() before sampling the statistics.
                await wait_all_tasks_blocked()
                assert event.statistics().tasks_waiting == expected_waiting

            event.set()

        assert event.statistics().tasks_waiting == 0
コード例 #11
0
    async def prepare_mod(cls, mod_obj: Extension):
        """
        Instantiate ``mod_obj`` once all of its dependencies have been instantiated.

        The steps are:

        1. Validate the Minecraft version requirement.
        2. Check that every dependency is present and within its allowed version range.
        3. Wait for each dependency that has no instance yet.
        4. Create this mod's instance.
        5. Wake every mod that registered itself as waiting on this one.

        :param mod_obj: the extension to prepare.
        :raises Exception: if the mod requires a newer Minecraft version than the lowest
            supported one, or a dependency is missing, too old, or too new.
        """
        lowest_mc_version = min(map(map_version,
                                    ServerCore.minecraft_versions))

        if mod_obj.mc_version > lowest_mc_version:
            raise Exception(
                f"Expected at least minecraft version {mod_obj.mc_version}, "
                f"got {lowest_mc_version}")

        # Validate every dependency before registering any wait events. A failed check
        # then cannot leave stale events behind in other mods' ``_locks`` lists.
        pending_deps = []
        for dep in mod_obj.dependencies:
            if dep.dependency_id not in cls.mods:
                raise Exception(f"Missing dependency {dep.dependency_id} "
                                f"for mod {mod_obj.id}")

            dep_obj = cls.get_mod(dep.dependency_id)
            dep_version = dep_obj.version

            if map_version(dep.dependency_min_version) > dep_version:
                raise Exception(
                    f"Dependency {dep.dependency_id} out of date! "
                    f"Expected at least version {dep.dependency_min_version}, got {dep_version}"
                )

            if dep.dependency_max_version is not None and map_version(
                    dep.dependency_max_version) < dep_version:
                raise Exception(
                    f"Dependency {dep.dependency_id} too new! "
                    f"Expected at most version {dep.dependency_max_version}, got {dep_version}"
                )

            # Only dependencies without an instance yet need to be waited for.
            if dep_obj._instance is None:
                pending_deps.append(dep_obj)

        # Register one event per pending dependency; the dependency sets it once instantiated.
        mod_locks = []
        for dep_obj in pending_deps:
            lock = Event()
            dep_obj._locks.append(lock)
            mod_locks.append(lock)

        for lock in mod_locks:
            await lock.wait()

        mod_obj._instance = mod_obj.cls()

        # Event.set() is synchronous (this assumes anyio >= 3 or asyncio, matching the
        # un-awaited set() calls elsewhere). Awaiting its return value is deprecated in
        # anyio 3 and fails once set() returns None.
        for lock in mod_obj._locks:
            lock.set()
コード例 #12
0
    async def test_wait_cancel(self):
        """A task cancelled while blocked in ``Condition.wait()`` never runs the code after it."""
        started = notified = False
        ready = Event()
        condition = Condition()

        async def task():
            nonlocal started, notified
            started = True
            async with condition:
                # Signal that the condition's lock is held and wait() is about to be called.
                ready.set()
                await condition.wait()
                notified = True

        async with create_task_group() as tg:
            tg.start_soon(task)
            await ready.wait()
            # Make sure the task is actually parked inside condition.wait().
            await wait_all_tasks_blocked()
            tg.cancel_scope.cancel()

        assert started
        assert not notified
コード例 #13
0
ファイル: test_compat.py プロジェクト: agronholm/anyio
 async def test_event_set(self) -> None:
     """Awaiting the return value of ``Event.set()`` must emit a deprecation warning."""
     evt = Event()
     with pytest.deprecated_call():
         await evt.set()
コード例 #14
0
 def suspend(self) -> None:
     """Temporarily suspends the network event detector.

     Each call increments the ``_suspended`` counter. Presumably a matching resume method
     (not shown here) decrements it; confirm. If no ``_resume_event`` exists yet, a fresh
     ``Event`` is created, which waiters can presumably block on until resumption.
     """
     self._suspended += 1
     # NOTE(review): after the increment this is truthy unless _suspended was negative.
     if self._suspended and not self._resume_event:
         self._resume_event = Event()