Example 1
async def with_portal_run_sync(sync_fn: Callable[..., T], *args: Any, **kwds: Any) -> T:
    """Execute ``sync_fn(*args, **kwds)`` in a context that is able
    to use :func:`greenback.await_`.

    If the current task already has a greenback portal set up via a
    call to one of the other ``greenback.*_portal()`` functions, then
    :func:`with_portal_run_sync` simply calls *sync_fn*.  If *sync_fn*
    uses :func:`greenback.await_`, the existing portal will take care
    of it.

    Otherwise (if there is no portal already available to the current task),
    :func:`with_portal_run_sync` creates a new portal which lasts only for the
    duration of the call to *sync_fn*.

    This function does *not* add any cancellation point or schedule point
    beyond those that already exist due to any :func:`await_`\\s inside *sync_fn*.
    """

    this_task = current_task()
    if this_task in task_has_portal:
        return sync_fn(*args, **kwds)
    task_has_portal.add(this_task)
    try:
        res: T = await _greenback_shim_sync(partial(sync_fn, *args, **kwds))
        return res
    finally:
        task_has_portal.remove(this_task)
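For context, a minimal usage sketch (assuming the function above is exposed as ``greenback.with_portal_run_sync`` and Trio is the backend; ``sync_helper`` and ``main`` are hypothetical):

import greenback
import trio

def sync_helper() -> str:
    # Synchronous code, yet it can wait on async operations through the portal.
    greenback.await_(trio.sleep(0.1))
    return "done"

async def main() -> None:
    # A portal exists only for the duration of this call.
    print(await greenback.with_portal_run_sync(sync_helper))

trio.run(main)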
Example 2
async def lock_entry_id(self, entry_id: EntryID) -> AsyncIterator[EntryID]:
    async with self.entry_locks[entry_id]:
        try:
            self.locking_tasks[entry_id] = lowlevel.current_task()
            yield entry_id
        finally:
            del self.locking_tasks[entry_id]
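In the surrounding class this method is presumably decorated with ``contextlib.asynccontextmanager``; under that assumption, a hypothetical caller might look like:

async def update_entry(manager, entry_id) -> None:
    # Hypothetical caller: hold the per-entry lock while mutating the entry;
    # the lock records the owning task for later ownership checks.
    async with manager.lock_entry_id(entry_id):
        ...  # modify the entry here while this task is the recorded owner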
Example 3
async def with_portal_run(
    async_fn: Callable[..., Awaitable[T]], *args: Any, **kwds: Any
) -> T:
    """Execute ``await async_fn(*args, **kwds)`` in a context that is able
    to use :func:`greenback.await_`.

    If the current task already has a greenback portal set up via a
    call to one of the other ``greenback.*_portal()`` functions, then
    :func:`with_portal_run` simply calls *async_fn*.  If *async_fn*
    uses :func:`greenback.await_`, the existing portal will take care
    of it.

    Otherwise (if there is no portal already available to the current task),
    :func:`with_portal_run` creates a new portal which lasts only for the
    duration of the call to *async_fn*. If *async_fn* then calls
    :func:`ensure_portal`, an additional portal will **not** be created:
    the task will still have just the portal installed by
    :func:`with_portal_run`, which will be removed when *async_fn* returns.

    This function does *not* add any cancellation point or schedule point
    beyond those that already exist inside *async_fn*.
    """

    this_task = current_task()
    if this_task in task_has_portal:
        return await async_fn(*args, **kwds)
    shim_coro = _greenback_shim(async_fn(*args, **kwds))  # type: ignore
    assert shim_coro.send(None) == "ready"
    task_has_portal.add(this_task)
    try:
        res: T = await shim_coro
        return res
    finally:
        task_has_portal.remove(this_task)
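A hedged usage sketch (assumes the function above is exposed as ``greenback.with_portal_run``; ``sync_step`` and ``main`` are hypothetical):

import greenback
import trio

def sync_step() -> None:
    # No async/await syntax here, but the portal lets us await anyway.
    greenback.await_(trio.sleep(0.1))

async def main() -> None:
    sync_step()  # works: with_portal_run() installed a portal for this task

trio.run(greenback.with_portal_run, main)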
Example 4
async def ensure_portal() -> None:
    """Ensure that the current async task is able to use :func:`greenback.await_`.

    If the current task has called :func:`ensure_portal` previously, calling
    it again is a no-op. Otherwise, :func:`ensure_portal` interposes a
    "coroutine shim" provided by `greenback` in between the event
    loop and the coroutine being used to run the task. For example,
    when running under Trio, `trio.lowlevel.Task.coro` is replaced with
    a wrapper around the coroutine it previously referred to. (The
    same thing happens under asyncio, but asyncio doesn't expose the
    coroutine field publicly, so some additional trickery is required
    in that case.)

    After installation of the coroutine shim, each task step passes
    through `greenback` on its way into and out of your code. At
    some performance cost, this effectively provides a **portal** that
    allows later calls to :func:`greenback.await_` in the same task to
    access an async environment, even if the function that calls
    :func:`await_` is a synchronous function.

    This function is a cancellation point and a schedule point (a checkpoint,
    in Trio terms) even if the calling task already had a portal set up.
    """

    this_task = current_task()
    if this_task not in task_has_portal:
        bestow_portal(this_task)

    # Execute a checkpoint so that we're now running inside the shim coroutine.
    # This is necessary in case the caller immediately invokes greenback.await_()
    # without any further checkpoints.
    library = sniffio.current_async_library()
    await sys.modules[library].sleep(0)
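A short usage sketch (``legacy_sync_code`` is a hypothetical stand-in for existing synchronous code):

import greenback
import trio

def legacy_sync_code() -> None:
    # Pre-existing synchronous code that needs to perform async I/O.
    greenback.await_(trio.sleep(0.1))

async def main() -> None:
    await greenback.ensure_portal()  # a checkpoint; installs the shim for this task
    legacy_sync_code()

trio.run(main)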
Example 5
async def get_current_task() -> TaskInfo:
    task = trio_lowlevel.current_task()

    parent_id = None
    if task.parent_nursery and task.parent_nursery.parent_task:
        parent_id = id(task.parent_nursery.parent_task)

    return TaskInfo(id(task), parent_id, task.name, task.coro)
Example 6
    async def park(self):
        task = lowlevel.current_task()
        self.tasks.add(task)

        def abort_fn(_):
            self.tasks.remove(task)
            return lowlevel.Abort.SUCCEEDED

        await lowlevel.wait_task_rescheduled(abort_fn)
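The wake-up side is not shown; a hypothetical counterpart could use ``trio.lowlevel.reschedule()``, which resumes a task suspended in ``wait_task_rescheduled()``:

    def unpark_all(self):
        # Hypothetical counterpart: wake every parked task. reschedule()
        # resumes a task blocked in wait_task_rescheduled() with, by default,
        # outcome.Value(None).
        while self.tasks:
            lowlevel.reschedule(self.tasks.pop())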
Example 7
async def _tracked_child():
    # calling get_elapsed_descheduled_time() initiates tracking
    task = trio_lowlevel.current_task()
    assert instrument.get_elapsed_descheduled_time(task) == 0
    await trio.sleep(0)
    assert instrument.get_elapsed_descheduled_time(task) == 10 - 5
    await trio.sleep(0)
    assert instrument.get_elapsed_descheduled_time(task) == 20 - 5
    # time function is called twice for each deschedule
    assert time_fn.call_count == 4
Example 8
    async def wait(self):
        if self._counter == 0:
            await _trio.checkpoint()
        else:
            task = _trio.current_task()
            self._tasks.add(task)

            def abort_fn(_):
                self._tasks.remove(task)
                return _trio.Abort.SUCCEEDED

            await _trio.wait_task_rescheduled(abort_fn)
Example 9
def has_portal(
    task: Optional[Union["trio.lowlevel.Task", "asyncio.Task[Any]"]] = None
) -> bool:
    """Return true if the given *task* is currently able to use
    :func:`greenback.await_`, false otherwise. If no *task* is
    specified, query the currently executing task.
    """
    if task is None:
        try:
            task = current_task()
        except sniffio.AsyncLibraryNotFoundError:
            return False
    return task in task_has_portal
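A hedged sketch of branching on portal availability (``pause`` is a hypothetical helper):

import time

import greenback
import trio

def pause(seconds: float) -> None:
    # Cooperative pause when a portal is available, blocking pause otherwise.
    if greenback.has_portal():
        greenback.await_(trio.sleep(seconds))
    else:
        time.sleep(seconds)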
Example 10
    async def get(self):
        if self._v is not PENDING:
            await _trio.checkpoint()
        else:
            task = _trio.current_task()
            self._tasks.add(task)

            def abort_fn(_):
                self._tasks.remove(task)
                return _trio.Abort.SUCCEEDED

            await _trio.wait_task_rescheduled(abort_fn)
        return self._v
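The setter side is not shown; assuming ``_trio`` aliases ``trio.lowlevel``, a hypothetical counterpart might be:

    def set(self, value):
        # Hypothetical counterpart: publish the value once, then wake every
        # task parked in get(); each waiter re-reads self._v on resumption.
        self._v = value
        while self._tasks:
            _trio.reschedule(self._tasks.pop())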
Example 11
async def test_serve_handler_nursery(nursery):
    task = current_task()
    async with trio.open_nursery() as handler_nursery:
        serve_with_nursery = partial(serve_websocket, echo_request_handler,
            HOST, 0, None, handler_nursery=handler_nursery)
        server = await nursery.start(serve_with_nursery)
        port = server.port
        # The server nursery begins with one task (server.listen).
        assert len(nursery.child_tasks) == 1
        no_clients_nursery_count = len(task.child_nurseries)
        async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as conn:
            # The handler nursery should have one task in it
            # (conn._reader_task).
            assert len(handler_nursery.child_tasks) == 1
Example 12
async def test_serve(nursery):
    task = current_task()
    server = await nursery.start(serve_websocket, echo_request_handler, HOST,
                                 0, None)
    port = server.port
    assert server.port != 0
    # The server nursery begins with one task (server.listen).
    assert len(nursery.child_tasks) == 1
    no_clients_nursery_count = len(task.child_nurseries)
    async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as conn:
        # The server nursery has the same number of tasks, but there is now
        # one additional nested nursery.
        assert len(nursery.child_tasks) == 1
        assert len(task.child_nurseries) == no_clients_nursery_count + 1
Example 13
            async def sub_and_print(
                delay: float,
            ) -> None:

                task = current_task()
                start = time.time()

                async with brx.subscribe() as lbrx:
                    while True:
                        print(f'{task.name}: starting consume loop')
                        try:
                            async for value in lbrx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                            if task.name == 'sub_1':
                                # trigger checkpoint to clean out other subs
                                await trio.sleep(0.01)

                                # the non-lagger got
                                # a ``trio.EndOfChannel``
                                # because the ``tx`` below was closed
                                assert len(lbrx._state.subs) == 1

                                await lbrx.aclose()

                                assert len(lbrx._state.subs) == 0

                        except trio.ClosedResourceError:
                            # only the fast sub will try to re-enter
                            # iteration on the now closed bcaster
                            assert task.name == 'sub_1'
                            return

                        except Lagged:
                            lag_time = time.time() - start
                            lags = laggers[task.name]
                            print(
                                f'restarting slow task {task.name} '
                                f'that bailed out on {lags}:{value} '
                                f'after {lag_time:.3f}')
                            if lags <= retries:
                                laggers[task.name] += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return
Example 14
def trio_perf_counter():
    """Trio task-local equivalent of time.perf_counter().

    For the current Trio task, return the value (in fractional seconds) of a
    performance counter, i.e. a clock with the highest available resolution to
    measure a short duration.  It includes time elapsed during time.sleep,
    but not trio.sleep.  The reference point of the returned value is
    undefined, so that only the difference between the results of consecutive
    calls is valid.

    Performance note: calling this function installs instrumentation on the
    Trio scheduler which may affect application performance.  The
    instrumentation is automatically removed when the corresponding tasks
    have exited.
    """
    trio_lowlevel.add_instrument(_instrument)
    task = trio_lowlevel.current_task()
    return perf_counter() - _instrument.get_elapsed_descheduled_time(task)
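For illustration, a sketch of how the counter might be used (assumes the ``trio_perf_counter`` above is importable):

import trio

async def busy_time() -> float:
    start = trio_perf_counter()
    await trio.sleep(1)        # descheduled time: excluded from the measurement
    sum(range(1_000_000))      # CPU work on this task: included
    return trio_perf_counter() - start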
Example 15
async def test_descheduled_time_instrument_exclude_children():
    time_fn = Mock(side_effect=[5, 10])
    instrument = _trio._DescheduledTimeInstrument(time_fn=time_fn)
    trio_lowlevel.add_instrument(instrument)

    task = trio_lowlevel.current_task()
    assert instrument.get_elapsed_descheduled_time(task) == 0

    async with trio.open_nursery() as nursery:

        @nursery.start_soon
        async def _untracked_child():
            await trio.sleep(0)

    assert instrument.get_elapsed_descheduled_time(task) == 10 - 5
    assert time_fn.call_count == 2  # 2 x 1 deschedule (due to nursery)

    # our task is still alive, so instrument remains active
    trio_lowlevel.remove_instrument(instrument)
Example 16
async def ensure_sequence(

    stream: tractor.ReceiveMsgStream,
    sequence: list,
    delay: Optional[float] = None,

) -> None:

    name = current_task().name
    async with stream.subscribe() as bcaster:
        assert not isinstance(bcaster, type(stream))
        async for value in bcaster:
            print(f'{name} rx: {value}')
            assert value == sequence[0]
            sequence.remove(value)

            if delay:
                await trio.sleep(delay)

            if not sequence:
                # fully consumed
                break
Example 17
def current_task() -> Union["trio.lowlevel.Task", "asyncio.Task[Any]"]:
    library = sniffio.current_async_library()
    if library == "trio":
        try:
            from trio.lowlevel import current_task
        except ImportError:  # pragma: no cover
            if not TYPE_CHECKING:
                from trio.hazmat import current_task

        return current_task()
    elif library == "asyncio":
        import asyncio

        if sys.version_info >= (3, 7):
            task = asyncio.current_task()
        else:
            task = asyncio.Task.current_task()
        if task is None:  # pragma: no cover
            # typeshed says this is possible, but I haven't been able to induce it
            raise RuntimeError("No asyncio task is running")
        return task
    else:
        raise RuntimeError(f"greenback does not support {library}")
Example 18
async def with_portal_run_tree(
    async_fn: Callable[..., Awaitable[T]], *args: Any, **kwds: Any
) -> T:
    """Execute ``await async_fn(*args, **kwds)`` in a context that allows use
    of :func:`greenback.await_` both in *async_fn* itself and in any tasks
    that are spawned into child nurseries of *async_fn*, recursively.

    You can use this to create an entire Trio run (except system
    tasks) that runs with :func:`greenback.await_` available: say
    ``trio.run(with_portal_run_tree, main)``.

    This function does *not* add any cancellation point or schedule point
    beyond those that already exist inside *async_fn*.

    Availability: Trio only.

    .. note:: The automatic "portalization" of child tasks is
       implemented using a Trio `instrument <trio.abc.Instrument>`,
       which has a small performance impact on task spawning for the
       entire Trio run. To minimize this impact, a single instrument
       is used even if you have multiple :func:`with_portal_run_tree`
       calls running simultaneously, and the instrument will be
       removed as soon as all such calls have completed.

    """
    try:
        import trio

        try:
            from trio import lowlevel as trio_lowlevel
        except ImportError:  # pragma: no cover
            if not TYPE_CHECKING:
                from trio import hazmat as trio_lowlevel

        this_task = trio_lowlevel.current_task()
    except Exception:
        raise RuntimeError("This function is only supported when running under Trio")

    global instrument_holder
    if instrument_holder is None:
        instrument_holder = trio_lowlevel.RunVar("greenback_instrument", default=None)
    instrument = instrument_holder.get()
    if instrument is None:
        # We're the only with_portal_run_tree() in this Trio run at the moment -->
        # set up the instrument and store it in the RunVar for other calls to find
        instrument = AutoPortalInstrument()
        trio_lowlevel.add_instrument(instrument)
        instrument_holder.set(instrument)
    elif this_task in instrument.tasks:
        # We're already inside another call to with_portal_run_tree(), so nothing
        # more needs to be done
        assert this_task in task_has_portal
        return await async_fn(*args, **kwds)

    # Store our current nursery depth. This allows the instrument to
    # distinguish new tasks spawned in child nurseries of async_fn()
    # (which should get auto-portalized) from new tasks spawned in
    # nurseries that enclose this call (which shouldn't, even if they
    # have the same parent task).
    instrument.tasks[this_task] = len(this_task.child_nurseries)
    instrument.refs += 1
    try:
        return await with_portal_run(async_fn, *args, **kwds)
    finally:
        del instrument.tasks[this_task]
        instrument.refs -= 1
        if instrument.refs == 0:
            # There are no more with_portal_run_tree() calls executing
            # in this run, so clean up the instrument.
            instrument_holder.set(None)
            trio_lowlevel.remove_instrument(instrument)
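A usage sketch following the docstring's ``trio.run(with_portal_run_tree, main)`` pattern (``blocking_helper``, ``child``, and ``main`` are hypothetical):

import greenback
import trio

def blocking_helper() -> None:
    greenback.await_(trio.sleep(0.1))

async def child() -> None:
    blocking_helper()  # portalized automatically: spawned under main()'s nursery

async def main() -> None:
    async with trio.open_nursery() as nursery:
        nursery.start_soon(child)

trio.run(greenback.with_portal_run_tree, main)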
Example 19
async def test_serve_with_zero_listeners(nursery):
    task = current_task()
    with pytest.raises(ValueError):
        server = WebSocketServer(echo_request_handler, [])
Example 20
def await_(aw: Awaitable[T]) -> T:
    """Run an async function or await an awaitable from a synchronous function,
    using the portal set up for the current async task by :func:`ensure_portal`,
    :func:`bestow_portal`, :func:`with_portal_run`, or :func:`with_portal_run_sync`.

    ``greenback.await_(foo())`` is equivalent to ``await foo()``, except that
    the `greenback` version can be written in a synchronous function while
    the native version cannot.
    """
    try:
        task = current_task()
        if task not in task_has_portal:
            raise RuntimeError(
                "you must 'await greenback.ensure_portal()' in this task first"
            ) from None
        gr = greenlet.getcurrent().parent
    except BaseException:
        if isinstance(aw, collections.abc.Coroutine):
            # Suppress the "coroutine was never awaited" warning
            aw.close()
        raise

    # If this is a non-coroutine awaitable, turn it into a coroutine
    if isinstance(aw, collections.abc.Coroutine):
        coro: Coroutine[Any, Any, T] = aw
        trim_tb_frames = 2
    else:
        coro = adapt_awaitable(aw)
        trim_tb_frames = 3

    # Step through the coroutine until it's exhausted, sending each trap
    # into the portal for the event loop to process.
    next_send: outcome.Outcome[Any] = outcome.Value(None)
    while True:
        try:
            # next_yield is a Future (under asyncio) or a checkpoint
            # or WaitTaskRescheduled marker (under Trio)
            next_yield: Any = next_send.send(coro)  # type: ignore
        except StopIteration as ex:
            return ex.value  # type: ignore
        except BaseException as ex:
            # Trim internal frames for a nicer traceback.
            # ex.__traceback__ covers the next_send.send(coro) line above;
            # its tb_next is in Value.send() or Error.send();
            # and tb_next of that covers the outermost frame in the user's
            # coroutine, which is what interests us.
            tb = ex.__traceback__
            assert tb is not None
            for _ in range(trim_tb_frames):
                if tb.tb_next is None:
                    # If we get here, there were fewer traceback frames
                    # than we expected, meaning we probably didn't
                    # even make it to the user's code. Don't do any
                    # trimming.
                    raise
                tb = tb.tb_next
            exception_from_greenbacked_function = ex.with_traceback(tb)
            # This line shows up in tracebacks, so give the variable a good name
            raise exception_from_greenbacked_function

        # next_send is an outcome.Outcome representing the value or error
        # with which the event loop wants to resume the task
        next_send = gr.switch(next_yield)
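A hedged example of the classic use case, async I/O behind a synchronous property (``Settings`` is hypothetical):

import greenback
import trio

class Settings:
    @property
    def value(self) -> str:
        # Synchronous property, yet it can await through the portal.
        return greenback.await_(self._load())

    async def _load(self) -> str:
        await trio.sleep(0.1)  # stands in for real async I/O
        return "loaded"

async def main() -> None:
    await greenback.ensure_portal()
    print(Settings().value)

trio.run(main)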
Example 21
def _check_lock_status(self, entry_id: EntryID) -> None:
    task = self.locking_tasks.get(entry_id)
    if task != lowlevel.current_task():
        raise RuntimeError(
            f"Entry `{entry_id}` modified without being locked")
Example 22
    async def receive(self) -> ReceiveType:

        key = self.key
        state = self._state

        # TODO: ideally we can make some way to "lock out" the
        # underlying receive channel in some way such that if some task
        # tries to pull from it directly (i.e. one we're unaware of)
        # then it errors out.

        # only tasks which have entered ``.subscribe()`` can
        # receive on this broadcaster.
        try:
            seq = state.subs[key]
        except KeyError:
            if self._closed:
                raise trio.ClosedResourceError

            raise RuntimeError(f'{self} is not registered as a subscriber')

        # check that task does not already have a value it can receive
        # immediately and/or that it has lagged.
        if seq > -1:
            # get the oldest value we haven't received immediately
            try:
                value = state.queue[seq]
            except IndexError:

                # adhere to ``tokio`` style "lagging":
                # "Once RecvError::Lagged is returned, the lagging
                # receiver's position is updated to the oldest value
                # contained by the channel. The next call to recv will
                # return this value."
                # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging

                # decrement to the last value and expect
                # consumer to either handle the ``Lagged`` and come back
                # or bail out on its own (thus un-subscribing)
                state.subs[key] = state.maxlen - 1

                # this task was overrun by the producer side
                task: Task = current_task()
                raise Lagged(f'Task {task.name} was overrun')

            state.subs[key] -= 1
            return value

        # current task already has the latest value **and** is the
        # first task to begin waiting for a new one
        if state.recv_ready is None:

            if self._closed:
                raise trio.ClosedResourceError

            event = trio.Event()
            state.recv_ready = key, event

            # if we're cancelled here it should be
            # fine to bail without affecting any other consumers
            # right?
            try:
                value = await self._recv()

                # items with lower indices are "newer"
                # NOTE: ``collections.deque`` implicitly takes care of
                # truncating values outside our ``state.maxlen``. In the
                # alt-backend-array-case we'll need to make sure this is
                # implemented in similar ring-buffer-ish style.
                state.queue.appendleft(value)

                # broadcast new value to all subscribers by increasing
                # all sequence numbers that will point in the queue to
                # their latest available value.

                # don't decrement the sequence for this task since we
                # already retrieved the last value

                # XXX: which of these impls is fastest?

                # subs = state.subs.copy()
                # subs.pop(key)

                for sub_key in filter(
                        # lambda k: k != key, state.subs,
                        partial(ne, key),
                        state.subs,
                ):
                    state.subs[sub_key] += 1

                # NOTE: this should ONLY be set if the above task was *NOT*
                # cancelled on the `._recv()` call.
                event.set()
                return value

            except trio.EndOfChannel:
                # if any one consumer gets an EOC from the underlying
                # receiver we need to unblock and send that signal to
                # all other consumers.
                self._state.eoc = True
                if event.statistics().tasks_waiting:
                    event.set()
                raise

            except (trio.Cancelled, ):
                # handle cancelled specially otherwise sibling
                # consumers will be awoken with a sequence of -1
                # and will potentially try to rewait the underlying
                # receiver instead of just cancelling immediately.
                self._state.cancelled = True
                if event.statistics().tasks_waiting:
                    event.set()
                raise

            finally:

                # Reset receiver waiter task event for next blocking condition.
                # this MUST be reset even if the above ``.recv()`` call
                # was cancelled to avoid the next consumer from blocking on
                # an event that won't be set!
                state.recv_ready = None

        # This task is all caught up and ready to receive the latest
        # value, so wait on the internal event until it is set.
        else:
            seq = state.subs[key]
            assert seq == -1  # sanity
            _, ev = state.recv_ready
            await ev.wait()

            # NOTE: if we ever would like the behaviour where if the
            # first task to recv on the underlying is cancelled but it
            # still DOES trigger the ``.recv_ready`` event, we'll likely need
            # this logic:

            if seq > -1:
                # stuff from above..
                seq = state.subs[key]

                value = state.queue[seq]
                state.subs[key] -= 1
                return value

            elif seq == -1:
                # XXX: In the case where the first task to allocate the
                # ``.recv_ready`` event is cancelled we will be woken with
                # a non-incremented sequence number and thus will read the
                # oldest value if we use that. Instead we need to detect if
                # we have not been incremented and then receive again.
                return await self.receive()

            else:
                raise ValueError(f'Invalid sequence {seq}!?')
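For context, a hypothetical consumer loop (assumes ``rx`` is a broadcast receiver obtained from the surrounding API) could handle the tokio-style lagging described above like so:

async def consume(rx) -> None:
    while True:
        try:
            value = await rx.receive()
        except Lagged:
            # Overrun by the producer: our cursor was reset to the oldest
            # buffered value, so keep receiving from there.
            continue
        except trio.EndOfChannel:
            break
        print(f'consumed {value}')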