def fake_remote_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
) -> TaskOutputData:
    """Stand-in remote function that announces its start, then waits to be aborted.

    The function looks up the task it is running as, sets the start event so
    the client knows it is running, and then waits (``timeout=10``) on the
    task's cancel event.

    Returns:
        A ``TaskOutputData`` holding ``{"some_output_key": 123}`` when the
        cancel event is not set after the wait.

    Raises:
        TaskCancelledError: when the cancel event is set after the wait.
    """
    # Find the task this function is currently executing as.
    current_worker = get_worker()
    task = current_worker.tasks.get(current_worker.get_current_task())
    assert task is not None
    print(f"--> task {task=} started")
    abort_signal = Event(TaskCancelEventName.format(task.key))

    # Let the client know that we are up and running.
    Event(_DASK_EVENT_NAME).set()

    # Give the client a chance to abort us.
    print("--> waiting for task to be aborted...")
    abort_signal.wait(timeout=10)
    if not abort_signal.is_set():
        return TaskOutputData.parse_obj({"some_output_key": 123})

    # NOTE: asyncio.CancelledError is not propagated back to the client...
    print("--> raising cancellation error now")
    raise TaskCancelledError
def _maybe_run(event_name: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Call ``fn`` only if the named `distributed.Event` says the task should run.

    This offers stronger guarantees than distributed's current cancellation
    mechanism, which only cancels pending tasks.

    Args:
        event_name: name of the `distributed.Event` to consult.
        fn: the callable to run.
        *args: positional arguments forwarded to ``fn``.
        **kwargs: keyword arguments forwarded to ``fn``.

    Returns:
        The result of ``fn(*args, **kwargs)``, or ``None`` when the event could
        be read and is not set.
    """
    import dask
    from distributed import Event, get_client

    try:
        # Explicitly pass in the timeout from dask's config_dict. Some versions
        # of distributed hardcode this rather than using the value from the
        # config_dict. Can be removed once we bump our min requirements for
        # distributed to >= 2.31.0.
        connect_timeout = dask.config.get("distributed.comm.timeouts.connect")
        worker_client = get_client(timeout=connect_timeout)
        proceed = Event(event_name, client=worker_client).is_set()
    except Exception:
        # Failure to create an event is usually due to connection errors. These
        # are either due to flaky behavior in distributed's comms under high
        # loads, or due to the scheduler shutting down. Either way, the safest
        # course here is to assume we *should* run the task still. If we guess
        # wrong, we're either doing a bit of unnecessary work, or the cluster
        # is shutting down and the task will be cancelled anyway.
        proceed = True

    if not proceed:
        return None
    return fn(*args, **kwargs)
async def test_pause_executor_with_memory_monitor(c, s, a):
    """Pause a worker through the memory monitor and check task handling.

    One task is running, one is queued on the worker and one is queued on the
    scheduler when the worker pauses. The running task must complete, the
    others must wait until the reported memory drops back and the worker
    resumes.
    """
    assert memory_monitor_running(a)
    mocked_rss = 0

    def fake_process_memory():
        # Reads the enclosing variable so the test can change the reported RSS.
        return mocked_rss

    a.monitor.get_process_memory = fake_process_memory

    # Event used to keep task "x" busy until the test releases it
    gate = Event()

    def block_until_released(ev):
        ev.wait()
        return 1

    # "x" is already running on the worker when the worker pauses
    fut_x = c.submit(block_until_released, gate, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    memory_logger = logging.getLogger("distributed.worker_memory")
    with captured_logger(memory_logger) as logger:
        # "y" is queued on the worker when the worker pauses
        fut_y = c.submit(inc, 1, key="y")
        while "y" not in a.tasks:
            await asyncio.sleep(0.01)

        # Pretend the worker holds 900MB of unmanaged memory
        mocked_rss = 900_000_000
        while s.workers[a.address].status != Status.paused:
            await asyncio.sleep(0.01)

        assert "Pausing worker" in logger.getvalue()

        # "z" is queued on the scheduler when the worker pauses;
        # it must not be sent to the worker.
        fut_z = c.submit(inc, 2, key="z")
        while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
            await asyncio.sleep(0.01)
        assert s.unrunnable == {s.tasks["z"]}

        # A task that had already started when the worker paused can finish and
        # its result can be fetched. The slot it frees must stay unused.
        await gate.set()
        assert await fut_x == 1
        await asyncio.sleep(0.05)

        assert a.executing_count == 0
        assert len(a.ready) == 1
        assert a.tasks["y"].state == "ready"
        assert "z" not in a.tasks

        # Release the memory: tasks queued on the worker run, and tasks stuck
        # on the scheduler are sent to the worker and run as well.
        mocked_rss = 0
        assert await fut_y == 2
        assert await fut_z == 3

        assert a.status == Status.running
        assert "Resuming worker" in logger.getvalue()
async def test_timeout(c, s, a, b):
    """Waiting on event "x" times out unless the event is currently set."""
    flag = Event("x")

    # Never set so far: the wait must time out
    timed_out_before_set = await Event("x").wait(timeout=0.1)
    assert not timed_out_before_set

    # Once set, a fresh handle with the same name sees it
    await flag.set()
    seen_after_set = await Event("x").wait(timeout=0.1)
    assert seen_after_set

    # After clearing, the wait times out again
    await flag.clear()
    seen_after_clear = await Event("x").wait(timeout=0.1)
    assert not seen_after_clear
async def test_default_event(c, s, a, b):
    """A freshly created event is not set, and clearing it leaves no state behind."""
    flag = Event("x")
    initially_set = await flag.is_set()
    assert not initially_set

    await flag.clear()

    # The scheduler extension should hold no events or waiters anymore
    events_extension = s.extensions["events"]
    assert not events_extension._events
    assert not events_extension._waiter_count
async def test_event_types(c, s, a, b):
    """Events accept an int, tuple, list, bytes or str as their name."""
    candidate_names = [1, ("a", 1), ["a", 1], b"123", "123"]
    for name in candidate_names:
        flag = Event(name)
        assert flag.name == name

        # Set then clear: the event must end up unset
        await flag.set()
        await flag.clear()
        assert not await flag.is_set()

    # The scheduler extension should hold no events or waiters anymore
    events_extension = s.extensions["events"]
    assert not events_extension._events
    assert not events_extension._waiter_count
async def test_serializable(c, s, a, b):
    """An event can be passed to tasks and survives a pickle round trip."""

    def f(x, event=None):
        # The event received by the task must still carry its name
        assert event.name == "x"
        return x + 1

    original = Event("x")
    pending = c.map(f, range(10), event=original)
    await c.gather(pending)

    # Round trip through pickle in this process
    restored = pickle.loads(pickle.dumps(original))
    assert restored.name == original.name
    assert restored.client is original.client
def _pre_start_yield(self) -> None:
    """Create the should-run event when applicable and schedule the dask event watcher.

    When an address is configured, or the scheduler address starts with
    ``"inproc"``, a weak set of futures and a uniquely named, already-set
    `distributed.Event` are stored on the instance. The ``_watch_dask_events``
    coroutine is then submitted to the client's asyncio loop.
    """
    from distributed import Event

    scheduler_address = self.client.scheduler.address  # type: ignore
    is_inproc = scheduler_address.startswith("inproc")

    if self.address is not None or is_inproc:
        self._futures = weakref.WeakSet()
        event_name = f"prefect-{uuid.uuid4().hex}"
        self._should_run_event = Event(event_name, client=self.client)
        self._should_run_event.set()

    watcher = self._watch_dask_events()
    target_loop = self.client.loop.asyncio_loop  # type: ignore
    self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
        watcher, target_loop
    )
async def test_pause_executor_manual(c, s, a):
    """Pause a worker by setting its status directly and check task handling.

    One task is running, one is queued on the worker and one is queued on the
    scheduler when the worker pauses. The running task must complete, the
    others must wait until the worker status is set back to running.
    """
    assert not memory_monitor_running(a)

    # Event used to keep task "x" busy until the test releases it
    gate = Event()

    def block_until_released(ev):
        ev.wait()
        return 1

    # "x" is already running on the worker when the worker pauses
    fut_x = c.submit(block_until_released, gate, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    # "y" is queued on the worker when the worker pauses
    fut_y = c.submit(inc, 1, key="y")
    while "y" not in a.tasks:
        await asyncio.sleep(0.01)

    a.status = Status.paused
    # Wait until the scheduler has seen the new status
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # "z" is queued on the scheduler when the worker pauses;
    # it must not be sent to the worker.
    fut_z = c.submit(inc, 2, key="z")
    while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
        await asyncio.sleep(0.01)
    assert s.unrunnable == {s.tasks["z"]}

    # A task that had already started when the worker paused can finish and
    # its result can be fetched. The slot it frees must stay unused.
    await gate.set()
    assert await fut_x == 1
    await asyncio.sleep(0.05)

    assert a.executing_count == 0
    assert len(a.ready) == 1
    assert a.tasks["y"].state == "ready"
    assert "z" not in a.tasks

    # Unpause: tasks queued on the worker run, and tasks stuck on the
    # scheduler are sent to the worker and run as well.
    a.status = Status.running
    assert await fut_y == 2
    assert await fut_z == 3
def fake_remote_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
) -> TaskOutputData:
    """Stand-in remote function that waits on the start event, then succeeds or fails.

    Returns:
        A ``TaskOutputData`` holding ``{"some_output_key": 123}`` when
        ``fail_remote_fct`` is falsy.

    Raises:
        ValueError: when ``fail_remote_fct`` is truthy.
    """
    # Block here (timeout=5) until the client lets us go on
    Event(_DASK_EVENT_NAME).wait(timeout=5)

    if not fail_remote_fct:
        return TaskOutputData.parse_obj({"some_output_key": 123})
    raise ValueError("We fail because we're told to!")
def _pre_start_yield(self) -> None:
    """Create the should-run event and schedule the dask event watcher, as configured.

    The event (and the weak set of futures) is created when an address is
    configured or the scheduler address starts with ``"inproc"``, unless
    ``disable_cancellation_event`` is set. The ``_watch_dask_events`` coroutine
    is scheduled when ``watch_worker_status`` is ``True``, or when it is
    ``None`` and ``adapt_kwargs`` is falsy.
    """
    from distributed import Event

    scheduler_address = self.client.scheduler.address  # type: ignore
    is_inproc = scheduler_address.startswith("inproc")

    if self.address is not None or is_inproc:
        if not self.disable_cancellation_event:
            self._futures = weakref.WeakSet()
            event_name = f"prefect-{uuid.uuid4().hex}"
            self._should_run_event = Event(event_name, client=self.client)
            self._should_run_event.set()

    watch_setting = self.watch_worker_status
    if watch_setting is True or (watch_setting is None and not self.adapt_kwargs):
        watcher = self._watch_dask_events()
        target_loop = self.client.loop.asyncio_loop  # type: ignore
        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            watcher, target_loop
        )
async def test_set_not_set(c, s, a, b):
    """Setting and clearing an event is reflected by ``is_set``, including repeated sets."""
    flag = Event("x")

    await flag.clear()
    state = await flag.is_set()
    assert not state

    await flag.set()
    state = await flag.is_set()
    assert state

    # Setting an already-set event keeps it set
    await flag.set()
    state = await flag.is_set()
    assert state

    await flag.clear()
    state = await flag.is_set()
    assert not state

    # The scheduler extension should hold no events or waiters anymore
    events_extension = s.extensions["events"]
    assert not events_extension._events
    assert not events_extension._waiter_count
def fake_remote_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
) -> TaskOutputData:
    """Stand-in remote function that publishes one message per topic, then signals the client.

    Returns:
        A ``TaskOutputData`` holding ``{"some_output_key": 123}``.
    """
    # Create all three publishers first, then publish in the same order
    # (state, progress, logs).
    event_types = (TaskStateEvent, TaskProgressEvent, TaskLogEvent)
    publishers = [distributed.Pub(event_type.topic_name()) for event_type in event_types]
    messages = (
        "my name is state",
        "my name is progress",
        "my name is logs",
    )
    for publisher, message in zip(publishers, messages):
        publisher.put(message)

    # Signal the client that everything was published
    Event(name=_DASK_START_EVENT).set()
    return TaskOutputData.parse_obj({"some_output_key": 123})
async def test_set_not_set_many_events(c, s, a, b):
    """Set/clear state is tracked independently for each of 100 events."""
    flags = [Event(name) for name in range(100)]

    # Everything starts out cleared
    for flag in flags:
        await flag.clear()
        assert not await flag.is_set()

    # Set only the events at even positions; odd ones must stay unset
    for position, flag in enumerate(flags):
        if position % 2:
            assert not await flag.is_set()
            continue
        await flag.set()
        assert await flag.is_set()

    # Clear everything again
    for flag in flags:
        await flag.clear()
        assert not await flag.is_set()

    # The scheduler extension should hold no events or waiters anymore
    events_extension = s.extensions["events"]
    assert not events_extension._events
    assert not events_extension._waiter_count
def event_not_set(event_name):
    """Assert that waiting on the named event (``timeout=0.05``) reports it as not set."""
    signalled = Event(event_name).wait(timeout=0.05)
    assert not signalled
def event_is_set(event_name):
    """Assert that waiting on the named event (``timeout=0.5``) reports it as set."""
    signalled = Event(event_name).wait(timeout=0.5)
    assert signalled
async def test_get_tasks_status(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    cpu_image: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
    fail_remote_fct: bool,
):
    """Follow a task's status from STARTED to its final state, then to UNKNOWN.

    The remote function is held back by an event so the STARTED state can be
    observed; it then either fails or succeeds depending on ``fail_remote_fct``.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    _DASK_EVENT_NAME = faker.pystr()

    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # Block here (timeout=5) until the client lets us go on
        Event(_DASK_EVENT_NAME).wait(timeout=5)
        if not fail_remote_fct:
            return TaskOutputData.parse_obj({"some_output_key": 123})
        raise ValueError("We fail because we're told to!")

    submitted = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=cpu_image.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert submitted
    assert len(submitted) == 1
    node_id, job_id = submitted[0]
    assert node_id in cpu_image.fake_tasks

    # Hold a dask future for the task so dask does not remove the task from
    # the scheduler once it ends
    held_future = distributed.Future(key=job_id)
    assert held_future

    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # Release the remote function
    release_event = Event(_DASK_EVENT_NAME, dask_client.dask_subsystem.client)
    await release_event.set()  # type: ignore

    # The task should now reach its final state
    final_state = RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS
    await _assert_wait_for_task_status(job_id, dask_client, final_state)

    # Release the task results
    await dask_client.release_task_result(job_id)
    # The task is still known because we hold a future for it
    await _assert_wait_for_task_status(job_id, dask_client, final_state)

    # Dropping the future lets dask eventually delete the task from its
    # memory, so its status becomes undefined
    del held_future
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=60
    )
def wait_for_it_failing(x):
    """Assert that event "x" is not set, both after a short wait and when queried."""
    flag = Event("x")

    # No other task has set the event so far
    signalled = flag.wait(timeout=0.05)
    assert not signalled
    assert not flag.is_set()
async def test_abort_computation_tasks(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
):
    """Abort a running task and check it ends up ABORTED, then UNKNOWN once released."""
    _DASK_EVENT_NAME = faker.pystr()

    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # Find the task this function is currently executing as
        current_worker = get_worker()
        task = current_worker.tasks.get(current_worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        abort_signal = Event(TaskCancelEventName.format(task.key))

        # Let the client know that we are up and running
        Event(_DASK_EVENT_NAME).set()

        # Give the client a chance to abort us
        print("--> waiting for task to be aborted...")
        abort_signal.wait(timeout=10)
        if not abort_signal.is_set():
            return TaskOutputData.parse_obj({"some_output_key": 123})

        # NOTE: asyncio.CancelledError is not propagated back to the client...
        print("--> raising cancellation error now")
        raise TaskCancelledError

    submitted = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert submitted
    assert len(submitted) == 1
    node_id, job_id = submitted[0]
    assert node_id in image_params.fake_tasks
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # Wait until the remote function has signalled that it started
    started_event = Event(_DASK_EVENT_NAME)
    await started_event.wait(timeout=10)  # type: ignore

    # Abort the computation
    await dask_client.abort_computation_task(job_id)
    await _assert_wait_for_cb_call(mocked_user_completed_cb)
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.ABORTED)

    # Fetching the result must raise the cancellation error
    with pytest.raises(TaskCancelledError):
        await dask_client.get_task_result(job_id)

    # Once the results are released, the task shall be UNKNOWN
    await dask_client.release_task_result(job_id)
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=120
    )
def set_it():
    """Set the event named "x"."""
    Event("x").set()
def wait_for_it_ok(x):
    """Assert that event "x" is set, both after a wait and when queried."""
    flag = Event("x")

    # Another task is expected to set the event
    signalled = flag.wait(timeout=0.5)
    assert signalled
    assert flag.is_set()
def clear_it():
    """Clear the event named "x"."""
    Event("x").clear()