Example #1
0
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        """Fake sidecar entrypoint that waits to be told to cancel.

        Checks that the running dask task carries ``expected_annotations``,
        then iterates over messages of the ``TaskCancelEvent`` topic and raises
        ``asyncio.CancelledError`` once an event targeting this task's key
        arrives. Events addressed to other jobs are ignored.
        """
        sub = Sub(TaskCancelEvent.topic_name())
        current_worker = get_worker()
        # the state of the task we are running in, as seen by the worker
        task = current_worker.tasks.get(current_worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        assert task.annotations == expected_annotations

        # keep listening so the test has time to publish a cancellation event
        print("--> waiting for task to be aborted...")
        for msg in sub:
            assert msg
            print(f"--> received cancellation msg: {msg=}")
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            assert cancel_event
            if cancel_event.job_id != task.key:
                # addressed to a different job: keep waiting
                continue
            print("--> raising cancellation error now")
            raise asyncio.CancelledError("task cancelled")

        return TaskOutputData.parse_obj({"some_output_key": 123})
def test_task_is_aborted_using_pub(dask_client: distributed.Client):
    """A cancel event published on the topic makes the polling task return -1."""
    job_id = "myfake_job_id"
    future = dask_client.submit(_some_long_running_task, key=job_id)
    # leave the worker some time to start the task before aborting it
    time.sleep(1)

    cancellation_pub = distributed.Pub(TaskCancelEvent.topic_name())
    cancellation_pub.put(TaskCancelEvent(job_id=job_id).json())

    assert future.result(timeout=2) == -1
def test_monitor_task_abortion(dask_client: distributed.Client):
    """A cancel event published on the topic ends the monitored task with None."""
    job_id = "myfake_job_id"
    future = dask_client.submit(
        _some_long_running_task_with_monitoring, key=job_id
    )
    # leave the worker some time to start the task before aborting it
    time.sleep(1)

    # publish the cancellation for this job
    cancellation_pub = distributed.Pub(TaskCancelEvent.topic_name())
    cancellation_pub.put(TaskCancelEvent(job_id=job_id).json())

    assert future.result(timeout=10) is None
def _some_long_running_task() -> int:
    """Poll for an abort of the current dask task for up to 300 iterations.

    Each iteration sleeps 0.1 s and then asks ``is_current_task_aborted``.
    Returns -1 as soon as an abort is detected. If all iterations complete,
    the final assertion requires the task to be reported aborted at that
    point; ``12`` is only returned when that last check succeeds.
    """
    dask_sub = distributed.Sub(TaskCancelEvent.topic_name())
    # nobody should have aborted us before we even started
    assert not is_current_task_aborted(dask_sub)
    for i in range(300):
        print("running iteration", i)
        time.sleep(0.1)
        if is_current_task_aborted(dask_sub):
            print("task is aborted")
            return -1
    assert is_current_task_aborted(dask_sub)
    return 12
def is_current_task_aborted(sub: distributed.Sub) -> bool:
    """Return True if the dask task currently running here was aborted.

    The task counts as aborted when either:

    * ``_get_current_task_state()`` returns None, or
    * a ``TaskCancelEvent`` whose ``job_id`` equals the task key is read from
      ``sub`` within a 100 ms wait.

    At most one message is read from ``sub`` per call. A message for a
    different job makes this call return False.
    """
    task: Optional[TaskState] = _get_current_task_state()
    logger.debug("found following TaskState: %s", task)
    if task is None:
        # the task was removed from the list of tasks this worker should work
        # on, meaning it is aborted
        # NOTE: this does not work in distributed mode, hence we also need to
        # use Variables or PubSub
        return True

    # A timed-out get() may surface as asyncio.TimeoutError or as the builtin
    # TimeoutError depending on the Python/distributed versions. These are only
    # the same class from Python 3.11 on, so suppress both.
    with suppress(asyncio.TimeoutError, TimeoutError):
        msg = sub.get(timeout="100ms")
        if msg:
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            return bool(cancel_event.job_id == task.key)
    return False
 async def abort_computation_tasks(self, job_ids: List[str]) -> None:
     """Request cancellation of the dask tasks tracked under ``job_ids``.

     For every job id with an entry in ``self._taskid_to_future_map``, a
     ``TaskCancelEvent`` is published on ``self.cancellation_dask_pub`` and
     the dask future is cancelled. Unknown job ids are skipped silently.
     """
     # Cancelling a dask future only helps while no dask-sidecar has taken the
     # job yet. Once a sidecar runs it, cancellation has to be user-defined:
     # the event is published on the dask PUB so that the sidecar can abort
     # the job itself and report once it is properly cancelled.
     logger.debug("cancelling tasks with job_ids: [%s]", job_ids)
     for job_id in job_ids:
         future = self._taskid_to_future_map.get(job_id)
         if not future:
             continue
         cancel_event = TaskCancelEvent(job_id=job_id)
         self.cancellation_dask_pub.put(cancel_event.json())  # type: ignore
         await future.cancel()
         logger.debug("Dask task %s cancelled", future.key)
 async def periodicaly_check_if_aborted(task_name: str) -> None:
     """Cancel ``task_name`` once the current dask task is reported aborted.

     Every ``_TASK_ABORTION_INTERVAL_CHECK_S`` seconds the cancellation topic
     is polled via ``is_current_task_aborted``; when it reports an abort,
     ``cancel_task(task_name)`` is awaited. The loop does not stop after that
     call. If this coroutine is itself cancelled, the resulting CancelledError
     is swallowed and the coroutine simply returns.
     """
     try:
         logger.debug(
             "starting task to check for task cancellation for '%s'",
             f"{task_name=}",
         )
         abort_sub = distributed.Sub(TaskCancelEvent.topic_name())
         while True:
             await asyncio.sleep(_TASK_ABORTION_INTERVAL_CHECK_S)
             logger.debug("checking if task should be cancelled")
             # NOTE(review): is_current_task_aborted is a synchronous call;
             # if it waits on the subscription it blocks the event loop
             # meanwhile - confirm this is acceptable
             if not is_current_task_aborted(abort_sub):
                 continue
             logger.debug(
                 "Task was aborted. Cancelling fct [%s]...", f"{task_name=}"
             )
             await cancel_task(task_name)
     except asyncio.CancelledError:
         pass
 async def create(
     cls,
     app: FastAPI,
     settings: DaskSchedulerSettings,
     endpoint: AnyUrl,
     authentication: ClusterAuthentication,
 ) -> "DaskClient":
     """Connect to the dask scheduler/gateway at ``endpoint`` and build a client.

     The connection is attempted up to 3 times, 0.3 s apart. The last error is
     re-raised if every attempt fails. On success the returned instance gets a
     ``distributed.Pub`` on the ``TaskCancelEvent`` topic, bound to the newly
     created dask client.
     """
     scheduler_descr = f"dask-scheduler/gateway at {endpoint}"
     # NOTE(review): ``authentication`` is logged as-is; confirm its string
     # form does not expose credentials
     logger.info(
         "Initiating connection to %s with auth: %s",
         scheduler_descr,
         authentication,
     )
     retrying = AsyncRetrying(
         reraise=True,
         before_sleep=before_sleep_log(logger, logging.WARNING),
         wait=wait_fixed(0.3),
         stop=stop_after_attempt(3),
     )
     async for attempt in retrying:
         with attempt:
             logger.debug(
                 "Connecting to %s, attempt %s...",
                 endpoint,
                 attempt.retry_state.attempt_number,
             )
             # NOTE(review): if check_scheduler_status raises, the client
             # created here is not closed before the next attempt - verify
             subsystem = await _create_internal_client_based_on_auth(
                 endpoint, authentication
             )
             check_scheduler_status(subsystem.client)
             dask_client = cls(
                 app=app,
                 dask_subsystem=subsystem,
                 settings=settings,
                 cancellation_dask_pub=distributed.Pub(
                     TaskCancelEvent.topic_name(), client=subsystem.client
                 ),
             )
             logger.info(
                 "Connection to %s succeeded [%s]",
                 scheduler_descr,
                 json.dumps(attempt.retry_state.retry_object.statistics),
             )
             logger.info(
                 "Scheduler info:\n%s",
                 json.dumps(subsystem.client.scheduler_info(), indent=2),
             )
             return dask_client
     # only here to satisfy pylance: the loop above returns or re-raises
     raise ValueError("Could not create client")