def __post_init__(self) -> None:
    """Subscribe to the task state, progress and log event topics.

    Each subscription is a ``distributed.Sub`` on the topic returned by the
    corresponding event class's ``topic_name()``, bound to ``self.client``.
    """
    client = self.client
    self.state_sub = distributed.Sub(TaskStateEvent.topic_name(), client=client)
    self.progress_sub = distributed.Sub(
        TaskProgressEvent.topic_name(), client=client
    )
    self.logs_sub = distributed.Sub(TaskLogEvent.topic_name(), client=client)
async def test_publish_event(dask_client: distributed.Client):
    """Publish a TaskLogEvent on a dask topic and check the subscriber gets it intact.

    First waits until the publisher sees at least one subscriber, so the event
    is not published before anyone listens. Then it publishes the event, reads
    it back from the subscriber and compares the parsed result.
    """
    # one budget for both waits: CI can be slow with a sync dask client
    max_wait_s = 60
    # bind explicitly to the fixture's client instead of relying on the implicit default
    dask_pub = distributed.Pub("some_topic", client=dask_client)
    dask_sub = distributed.Sub("some_topic", client=dask_client)
    async for attempt in AsyncRetrying(
        reraise=True,
        retry=retry_if_exception_type(AssertionError),
        wait=wait_fixed(0.01),
        stop=stop_after_delay(max_wait_s),
    ):
        with attempt:
            print(
                f"waiting for subscribers... attempt={attempt.retry_state.attempt_number}"
            )
            assert dask_pub.subscribers
            print("we do have subscribers!")

    event_to_publish = TaskLogEvent(job_id="some_fake_job_id", log="the log")
    publish_event(dask_pub=dask_pub, event=event_to_publish)
    # NOTE: this tests runs a sync dask client,
    # and the CI seems to have sometimes difficulties having this run in a reasonable time
    # hence the long time out (previously 1 s, which contradicted this note)
    message = dask_sub.get(timeout=max_wait_s)
    assert message is not None
    received_task_log_event = TaskLogEvent.parse_raw(message)  # type: ignore
    assert received_task_log_event == event_to_publish
def _some_long_running_task() -> int:
    """Simulate a long (~30 s) task that polls for its own abortion.

    Returns -1 as soon as ``is_current_task_aborted`` reports an abort.
    If all 300 iterations finish without an abort, the trailing assert fails.
    The caller is therefore expected to abort the task in time; 12 is returned
    only if the abort lands between the last poll and that final check.
    """
    dask_sub = distributed.Sub(TaskCancelEvent.topic_name())
    # the task must not start out already aborted
    assert not is_current_task_aborted(dask_sub)
    for i in range(300):  # 300 x 0.1 s sleep: at least ~30 s of work
        print("running iteration", i)
        time.sleep(0.1)
        if is_current_task_aborted(dask_sub):
            print("task is aborted")
            return -1
    assert is_current_task_aborted(
        dask_sub
    ), "task was expected to be aborted before the loop finished"
    return 12
# Example #4
async def dask_sub_consumer(
    task_event: DaskTaskEvents,
    handler: Callable[[str], Awaitable[None]],
    dask_client: distributed.Client,
):
    """Forward every message on ``task_event``'s dask topic to ``handler``.

    Subscribes with ``dask_client`` and iterates the subscription for as long
    as it yields messages. Each message is logged at debug level, then awaited
    through ``handler`` one at a time, in arrival order.
    """
    topic = task_event.topic_name()
    subscription = distributed.Sub(topic, client=dask_client)
    async for received in subscription:
        logger.debug("received dask event '%s' of topic %s", received, topic)
        await handler(received)
# Example #5
 async def periodicaly_check_if_aborted(task_name: str) -> None:
     """Poll the task-cancel topic and cancel ``task_name`` once an abort is seen.

     Every ``_TASK_ABORTION_INTERVAL_CHECK_S`` seconds this asks
     ``is_current_task_aborted``. When it reports an abort, the coroutine
     awaits ``cancel_task(task_name)``. A ``CancelledError`` raised anywhere
     inside is swallowed, so the coroutine returns quietly instead.

     NOTE(review): polling continues after ``cancel_task`` returns, so
     ``cancel_task`` may be invoked again on later checks — confirm intended.
     """
     try:
         logger.debug(
             "starting task to check for task cancellation for '%s'",
             f"{task_name=}",
         )
         cancel_sub = distributed.Sub(TaskCancelEvent.topic_name())
         while True:
             await asyncio.sleep(_TASK_ABORTION_INTERVAL_CHECK_S)
             logger.debug("checking if task should be cancelled")
             if not is_current_task_aborted(cancel_sub):
                 continue
             logger.debug(
                 "Task was aborted. Cancelling fct [%s]...", f"{task_name=}"
             )
             await cancel_task(task_name)
     except asyncio.CancelledError:
         pass