async def test_cluster_dump_state(c, s, a, b, tmp_path):
    """Check that a msgpack cluster dump mirrors the live scheduler/worker state.

    Two ``inc`` tasks are computed into memory and a third task is kept busy
    by blocking on a ``distributed.Event``. The dump taken at that moment must
    report, per state, the same task keys as the live scheduler and workers.
    """
    filename = tmp_path / "dump"
    futs = c.map(inc, range(2))
    fut_keys = {f.key for f in futs}
    await c.gather(futs)

    event = distributed.Event()
    blocked_fut = c.submit(blocked_inc, 1, event)
    # give the blocked task a moment to reach the executing state before dumping
    await asyncio.sleep(0.05)
    await c.dump_cluster_state(filename, format="msgpack")

    scheduler_tasks = list(s.tasks.values())
    worker_tasks = [t for w in (a, b) for t in w.tasks.values()]

    smem_tasks = [t for t in scheduler_tasks if t.state == "memory"]
    wmem_tasks = [t for t in worker_tasks if t.state == "memory"]

    assert len(smem_tasks) == 2
    assert len(wmem_tasks) == 2

    sproc_tasks = [t for t in scheduler_tasks if t.state == "processing"]
    wproc_tasks = [t for t in worker_tasks if t.state == "executing"]

    assert len(sproc_tasks) == 1
    assert len(wproc_tasks) == 1

    # Release the blocked task and wait for it to finish. Client.gather's
    # second positional parameter is ``errors``, so the event set and the
    # future must be awaited explicitly instead of being passed to gather().
    await event.set()
    await blocked_fut

    dump = DumpArtefact.from_url(f"{filename}.msgpack.gz")

    smem_keys = {t["key"] for t in dump.scheduler_tasks_in_state("memory")}
    wmem_keys = {t["key"] for t in dump.worker_tasks_in_state("memory")}

    assert smem_keys == fut_keys
    assert smem_keys == {t.key for t in smem_tasks}
    assert wmem_keys == fut_keys
    assert wmem_keys == {t.key for t in wmem_tasks}

    sproc_keys = {t["key"] for t in dump.scheduler_tasks_in_state("processing")}
    wproc_keys = {t["key"] for t in dump.worker_tasks_in_state("executing")}

    assert sproc_keys == {t.key for t in sproc_tasks}
    assert wproc_keys == {t.key for t in wproc_tasks}

    sall_keys = {t["key"] for t in dump.scheduler_tasks_in_state()}
    wall_keys = {t["key"] for t in dump.worker_tasks_in_state()}

    assert fut_keys | {blocked_fut.key} == sall_keys
    assert fut_keys | {blocked_fut.key} == wall_keys

    # Mapping API works
    assert "transition_log" in dump["scheduler"]
    assert "log" in dump["workers"][a.address]
    assert len(dump) == 3
def is_current_task_aborted() -> bool:
    """Return whether the current dask task should be treated as aborted.

    The task is considered aborted when either:

    - ``_get_current_task_state()`` returns ``None`` (the worker no longer
      tracks the task), or
    - the ``distributed.Event`` named ``TaskCancelEventName.format(task.key)``
      is set.

    Returns:
        True if the task shall be aborted, False otherwise.
    """
    task: Optional[TaskState] = _get_current_task_state()
    logger.debug("found following TaskState: %s", task)
    if task is None:
        # the task was removed from the list of tasks this worker should work on,
        # meaning it is aborted.
        # NOTE: this does not work in distributed mode, hence we need to use
        # Events, Variables, or PubSub
        # lazy %-formatting yields the same text as f"{task=}" but only builds
        # the repr when DEBUG logging is enabled
        logger.debug("task=%r shall be aborted", task)
        return True

    # NOTE: in distributed mode an event is necessary!
    cancel_event = distributed.Event(name=TaskCancelEventName.format(task.key))
    if cancel_event.is_set():
        logger.debug("task=%r shall be aborted", task)
        return True
    return False
async def test_cluster_dump_to_yamls(c, s, a, b, tmp_path):
    """Explode a msgpack cluster dump into YAML files and verify the file layout.

    One directory is expected for the scheduler and one per worker (named by
    the worker id), each holding a fixed set of YAML files. Converting to YAML
    must leave the in-memory dump intact.
    """
    filename = tmp_path / "dump"

    # keep references so the results stay in cluster memory during the dump
    done_futs = c.map(inc, range(2))
    await c.gather(done_futs)

    gate = distributed.Event()
    blocked_fut = c.submit(blocked_inc, 1, gate)
    await asyncio.sleep(0.05)
    await c.dump_cluster_state(filename, format="msgpack")
    await gate.set()
    await blocked_fut

    dump = DumpArtefact.from_url(f"{filename}.msgpack.gz")
    yaml_root = Path(filename)
    dump.to_yamls(yaml_root)

    scheduler_dir = yaml_root / "scheduler"
    scheduler_names = (
        "events.yaml",
        "extensions.yaml",
        "general.yaml",
        "log.yaml",
        "task_groups.yaml",
        "tasks.yaml",
        "transition_log.yaml",
        "workers.yaml",
    )
    assert set(scheduler_dir.iterdir()) == {scheduler_dir / n for n in scheduler_names}

    worker_names = (
        "config.yaml",
        "general.yaml",
        "log.yaml",
        "logs.yaml",
        "tasks.yaml",
    )
    for worker in (a, b):
        worker_dir = yaml_root / worker.id
        assert set(worker_dir.iterdir()) == {worker_dir / n for n in worker_names}

    # Internal dictionary state compaction
    # has not been destructive of the original dictionary
    assert "id" in dump["scheduler"]
    assert "address" in dump["scheduler"]
def fake_sidecar_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
    expected_annotations,
) -> TaskOutputData:
    """Stand-in for the sidecar entry point, executed on a dask worker.

    Asserts that the running task carries ``expected_annotations`` and that
    ``command`` is ``["run"]``. It then waits (timeout 5) on the distributed
    event named ``_DASK_EVENT_NAME`` and returns a fixed output.
    Most parameters are accepted only to match the sidecar signature.
    """
    # look up the state of the task this worker is currently running
    current_worker = get_worker()
    task_state = current_worker.tasks.get(current_worker.get_current_task())
    assert task_state is not None
    assert task_state.annotations == expected_annotations
    assert command == ["run"]

    # block until the test sets the event; the return value of wait() is ignored
    distributed.Event(_DASK_EVENT_NAME).wait(timeout=5)

    return TaskOutputData.parse_obj({"some_output_key": 123})
async def test_send_computation_task(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
):
    """Drive a single computational task through its whole lifecycle.

    The task is submitted, observed as STARTED, released via a distributed
    event, then observed as SUCCESS with its result. Releasing the result
    must finally make the task UNKNOWN / not found.
    """
    release_event_name = faker.pystr()

    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations,
    ) -> TaskOutputData:
        # inspect the task state on the worker that runs this function
        current_worker = get_worker()
        task_state = current_worker.tasks.get(current_worker.get_current_task())
        assert task_state is not None
        assert task_state.annotations == expected_annotations
        assert command == ["run"]
        # hold here until the test sets the event (timeout 5)
        distributed.Event(release_event_name).wait(timeout=5)
        return TaskOutputData.parse_obj({"some_output_key": 123})

    # NOTE: We pass another fct so it can run in our localy created dask cluster
    remote_fct = functools.partial(
        fake_sidecar_fct, expected_annotations=image_params.expected_annotations
    )
    submitted = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=remote_fct,
    )
    assert submitted
    assert len(submitted) == 1
    node_id, job_id = submitted[0]
    assert node_id in image_params.fake_tasks

    # the task must reach PENDING/STARTED while blocked on the event
    await _assert_wait_for_task_status(
        job_id, dask_client, expected_status=RunningState.STARTED
    )

    # setting the event lets the remote function run to completion
    await distributed.Event(release_event_name).set()  # type: ignore
    await _assert_wait_for_cb_call(
        mocked_user_completed_cb, timeout=_ALLOW_TIME_FOR_GATEWAY_TO_CREATE_WORKERS
    )

    await _assert_wait_for_task_status(
        job_id, dask_client, expected_status=RunningState.SUCCESS
    )

    task_result = await dask_client.get_task_result(job_id)
    assert isinstance(task_result, TaskOutputData)
    assert task_result.get("some_output_key") == 123

    # once released, the result is gone and the task is no longer known
    await dask_client.release_task_result(job_id)
    await _assert_wait_for_task_status(
        job_id, dask_client, expected_status=RunningState.UNKNOWN, timeout=60
    )
    with pytest.raises(ComputationalBackendTaskNotFoundError):
        await dask_client.get_task_result(job_id)