async def test_publish_event(dask_client: distributed.Client):
    """Check that an event sent with ``publish_event`` is received by a subscriber.

    The test first polls (every 10 ms, for up to 60 s) until the publisher
    reports at least one subscriber on the topic. It then publishes a
    ``TaskLogEvent``, reads the raw message back through ``distributed.Sub``,
    parses it and checks that it equals the published event.
    """
    dask_pub = distributed.Pub("some_topic")
    dask_sub = distributed.Sub("some_topic")

    # Subscriber registration is not immediate: retry the assertion until the
    # publisher sees a subscriber (presumably so the message is not published
    # to nobody).
    async for attempt in AsyncRetrying(
        reraise=True,
        retry=retry_if_exception_type(AssertionError),
        wait=wait_fixed(0.01),
        stop=stop_after_delay(60),
    ):
        with attempt:
            print(
                f"waiting for subscribers... attempt={attempt.retry_state.attempt_number}"
            )
            assert dask_pub.subscribers
    print("we do have subscribers!")

    event_to_publish = TaskLogEvent(job_id="some_fake_job_id", log="the log")
    publish_event(dask_pub=dask_pub, event=event_to_publish)

    # NOTE: this tests runs a sync dask client,
    # and the CI seems to have sometimes difficulties having this run in a reasonable time
    # hence the long time out (previously only 1 second, which contradicted this note)
    sub_get_timeout_s = 10
    message = dask_sub.get(timeout=sub_get_timeout_s)
    assert message is not None

    received_task_log_event = TaskLogEvent.parse_raw(message)  # type: ignore
    assert received_task_log_event == event_to_publish
def test_task_is_aborted_using_pub(dask_client: distributed.Client):
    """Publishing a ``TaskCancelEvent`` for a running job makes the task return -1."""
    fake_job_id = "myfake_job_id"
    running_future = dask_client.submit(_some_long_running_task, key=fake_job_id)

    # presumably gives the task time to start before the cancel event is sent
    time.sleep(1)

    cancel_pub = distributed.Pub(TaskCancelEvent.topic_name())
    cancel_payload = TaskCancelEvent(job_id=fake_job_id).json()
    cancel_pub.put(cancel_payload)

    assert running_future.result(timeout=2) == -1
def test_monitor_task_abortion(dask_client: distributed.Client):
    """A monitored long-running task returns ``None`` after a cancel event is published."""
    fake_job_id = "myfake_job_id"
    monitored_future = dask_client.submit(
        _some_long_running_task_with_monitoring, key=fake_job_id
    )

    # presumably lets the task start (and its monitoring begin) before cancelling
    time.sleep(1)

    # trigger cancellation
    cancel_pub = distributed.Pub(TaskCancelEvent.topic_name())
    cancel_payload = TaskCancelEvent(job_id=fake_job_id).json()
    cancel_pub.put(cancel_payload)

    assert monitored_future.result(timeout=10) is None
def fake_remote_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
) -> TaskOutputData:
    """Fake remote task: publish one dummy message per task-event topic.

    The parameters are accepted (presumably to mirror the real remote
    function's signature) but are not used. After publishing to the state,
    progress and logs topics, the ``Event`` named ``_DASK_START_EVENT`` is
    set, and a ``TaskOutputData`` with ``some_output_key=123`` is returned.
    """
    # create all three publishers first (state, progress, logs), then publish
    publishers_and_payloads = [
        (distributed.Pub(event_type.topic_name()), payload)
        for event_type, payload in (
            (TaskStateEvent, "my name is state"),
            (TaskProgressEvent, "my name is progress"),
            (TaskLogEvent, "my name is logs"),
        )
    ]
    for publisher, payload in publishers_and_payloads:
        publisher.put(payload)

    # tell the client we are done
    Event(name=_DASK_START_EVENT).set()

    return TaskOutputData.parse_obj({"some_output_key": 123})
def __post_init__(self):
    """Create one dask ``Pub`` per task-event topic: state, progress and logs."""
    state_pub, progress_pub, logs_pub = (
        distributed.Pub(event_type.topic_name())
        for event_type in (TaskStateEvent, TaskProgressEvent, TaskLogEvent)
    )
    self.state = state_pub
    self.progress = progress_pub
    self.logs = logs_pub