Example #1
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        cancel_event = Event(TaskCancelEventName.format(task.key))
        # tell the client we are started
        start_event = Event(_DASK_EVENT_NAME)
        start_event.set()
        # wait a bit in case someone aborts us
        print("--> waiting for task to be aborted...")
        cancel_event.wait(timeout=10)
        if cancel_event.is_set():
            # NOTE: asyncio.CancelledError is not propagated back to the client...
            print("--> raising cancellation error now")
            raise TaskCancelledError

        return TaskOutputData.parse_obj({"some_output_key": 123})
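
On the client side, aborting this task amounts to setting the same named event: the worker's cancel_event.wait() then returns and the function raises TaskCancelledError. A minimal sketch, assuming an asynchronous distributed.Client and the TaskCancelEventName template from the example's own codebase:

from distributed import Event

async def abort_job(client, job_key: str) -> None:
    # Build a handle to the same scheduler-side event the worker waits on,
    # then set it to wake up cancel_event.wait() inside fake_remote_fct.
    cancel_event = Event(TaskCancelEventName.format(job_key), client=client)
    await cancel_event.set()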
Example #2
    def _maybe_run(event_name: str, fn: Callable, *args: Any,
                   **kwargs: Any) -> Any:
        """Check if the task should run against a `distributed.Event` before
        starting the task. This offers stronger guarantees than distributed's
        current cancellation mechanism, which only cancels pending tasks."""
        import dask
        from distributed import Event, get_client

        try:
            # Explicitly pass in the timeout from dask's config. Some versions
            # of distributed hardcode this rather than using the value from
            # the config. Can be removed once we bump our min requirements for
            # distributed to >= 2.31.0.
            timeout = dask.config.get("distributed.comm.timeouts.connect")
            event = Event(event_name, client=get_client(timeout=timeout))
            should_run = event.is_set()
        except Exception:
            # Failure to create an event is usually due to connection errors. These
            # are either due to flaky behavior in distributed's comms under high
            # loads, or due to the scheduler shutting down. Either way, the safest
            # course here is to assume we *should* run the task still. If we guess
            # wrong, we're either doing a bit of unnecessary work, or the cluster
            # is shutting down and the task will be cancelled anyway.
            should_run = True

        if should_run:
            return fn(*args, **kwargs)
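
A hedged usage sketch of the helper above (the in-process cluster and the event name are illustrative, not Prefect's API): tasks are routed through _maybe_run, and clearing the shared event turns any task that has not started yet into a no-op returning None. pure=False forces dask to actually re-run the second submission instead of reusing a cached result:

from distributed import Client, Event

client = Client(processes=False)
Event("should-run-demo", client=client).set()

ran = client.submit(_maybe_run, "should-run-demo", sum, [1, 2, 3], pure=False)
assert ran.result() == 6          # event was set, so fn executed

Event("should-run-demo", client=client).clear()
skipped = client.submit(_maybe_run, "should-run-demo", sum, [1, 2, 3], pure=False)
assert skipped.result() is None   # event cleared: the task skipped fn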
Example #3
async def test_pause_executor_with_memory_monitor(c, s, a):
    assert memory_monitor_running(a)
    mocked_rss = 0
    a.monitor.get_process_memory = lambda: mocked_rss

    # Event gating the task that will be running when the worker pauses
    ev_x = Event()

    def f(ev):
        ev.wait()
        return 1

    # Task that is running on the worker when the worker pauses
    x = c.submit(f, ev_x, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    with captured_logger(
            logging.getLogger("distributed.worker_memory")) as logger:
        # Task that is queued on the worker when the worker pauses
        y = c.submit(inc, 1, key="y")
        while "y" not in a.tasks:
            await asyncio.sleep(0.01)

        # Hog the worker with 900MB unmanaged memory
        mocked_rss = 900_000_000
        while s.workers[a.address].status != Status.paused:
            await asyncio.sleep(0.01)

        assert "Pausing worker" in logger.getvalue()

        # Task that is queued on the scheduler when the worker pauses.
        # It is not sent to the worker.
        z = c.submit(inc, 2, key="z")
        while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
            await asyncio.sleep(0.01)
        assert s.unrunnable == {s.tasks["z"]}

        # Test that a task that already started when the worker paused can complete
        # and its output can be retrieved. Also test that the now free slot won't be
        # used by other tasks.
        await ev_x.set()
        assert await x == 1
        await asyncio.sleep(0.05)

        assert a.executing_count == 0
        assert len(a.ready) == 1
        assert a.tasks["y"].state == "ready"
        assert "z" not in a.tasks

        # Release the memory. Tasks that were queued on the worker are executed.
        # Tasks that were stuck on the scheduler are sent to the worker and executed.
        mocked_rss = 0
        assert await y == 2
        assert await z == 3

        assert a.status == Status.running
        assert "Resuming worker" in logger.getvalue()
Example #4
async def test_timeout(c, s, a, b):
    # The event should not be set and the timeout should happen
    event = Event("x")
    assert not await Event("x").wait(timeout=0.1)

    await event.set()
    assert await Event("x").wait(timeout=0.1)

    await event.clear()
    assert not await Event("x").wait(timeout=0.1)
Example #5
async def test_default_event(c, s, a, b):
    # The default flag for events should be false
    event = Event("x")
    assert not await event.is_set()

    await event.clear()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Example #6
async def test_event_types(c, s, a, b):
    # Event names can be strings, bytes, numbers, tuples, or lists
    for name in [1, ("a", 1), ["a", 1], b"123", "123"]:
        event = Event(name)
        assert event.name == name

        await event.set()
        await event.clear()
        result = await event.is_set()
        assert not result

    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Example #7
async def test_serializable(c, s, a, b):
    # Pickling an event should work properly
    def f(x, event=None):
        assert event.name == "x"
        return x + 1

    event = Event("x")
    futures = c.map(f, range(10), event=event)
    await c.gather(futures)

    event2 = pickle.loads(pickle.dumps(event))
    assert event2.name == event.name
    assert event2.client is event.client
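
Events are identified by name on the scheduler, which is why the pickled handle above still works: any handle with the same name refers to the same flag. A small sketch:

from distributed import Client, Event

client = Client(processes=False)
Event("x").set()            # set the flag through one handle
assert Event("x").is_set()  # a brand-new handle sees the same state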
Example #8
    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith("inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(
                f"prefect-{uuid.uuid4().hex}", client=self.client
            )
            self._should_run_event.set()

        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            self._watch_dask_events(), self.client.loop.asyncio_loop  # type: ignore
        )
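
The event created here is the one consulted on the worker side by the _maybe_run helper shown in Example #2. A hedged sketch of a matching teardown (the method name _post_start_yield and the attribute defaults are assumptions, not Prefect's confirmed API):

    def _post_start_yield(self) -> None:
        if getattr(self, "_should_run_event", None) is not None:
            # Clearing the event makes every task still routed through
            # _maybe_run skip its work instead of executing.
            self._should_run_event.clear()
        self._watch_dask_events_task.cancel()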
Example #9
async def test_pause_executor_manual(c, s, a):
    assert not memory_monitor_running(a)

    # Event gating the task that will be running when the worker pauses
    ev_x = Event()

    def f(ev):
        ev.wait()
        return 1

    # Task that is running on the worker when the worker pauses
    x = c.submit(f, ev_x, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    # Task that is queued on the worker when the worker pauses
    y = c.submit(inc, 1, key="y")
    while "y" not in a.tasks:
        await asyncio.sleep(0.01)

    a.status = Status.paused
    # Wait for sync to scheduler
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # Task that is queued on the scheduler when the worker pauses.
    # It is not sent to the worker.
    z = c.submit(inc, 2, key="z")
    while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
        await asyncio.sleep(0.01)
    assert s.unrunnable == {s.tasks["z"]}

    # Test that a task that already started when the worker paused can complete
    # and its output can be retrieved. Also test that the now free slot won't be
    # used by other tasks.
    await ev_x.set()
    assert await x == 1
    await asyncio.sleep(0.05)

    assert a.executing_count == 0
    assert len(a.ready) == 1
    assert a.tasks["y"].state == "ready"
    assert "z" not in a.tasks

    # Unpause. Tasks that were queued on the worker are executed.
    # Tasks that were stuck on the scheduler are sent to the worker and executed.
    a.status = Status.running
    assert await y == 2
    assert await z == 3
Example #10
def fake_remote_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
) -> TaskOutputData:
    # wait here until the client allows us to continue
    start_event = Event(_DASK_EVENT_NAME)
    start_event.wait(timeout=5)
    if fail_remote_fct:
        raise ValueError("We fail because we're told to!")
    return TaskOutputData.parse_obj({"some_output_key": 123})
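
Example #17 below shows the client side of this handshake: once the test has verified the task status, it sets _DASK_EVENT_NAME so that start_event.wait() returns and the function runs to completion (or raises, when fail_remote_fct is set).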
Example #11
    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith(
            "inproc")  # type: ignore
        if (self.address is not None
                or is_inproc) and not self.disable_cancellation_event:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(f"prefect-{uuid.uuid4().hex}",
                                           client=self.client)
            self._should_run_event.set()

        if self.watch_worker_status is True or (
                self.watch_worker_status is None and not self.adapt_kwargs):
            self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
                self._watch_dask_events(),
                self.client.loop.asyncio_loop  # type: ignore
            )
Example #12
async def test_set_not_set(c, s, a, b):
    # Set and unset the event and check if the flag is
    # propagated correctly
    event = Event("x")

    await event.clear()
    assert not await event.is_set()

    await event.set()
    assert await event.is_set()

    await event.set()
    assert await event.is_set()

    await event.clear()
    assert not await event.is_set()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Example #13
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:

        state_pub = distributed.Pub(TaskStateEvent.topic_name())
        progress_pub = distributed.Pub(TaskProgressEvent.topic_name())
        logs_pub = distributed.Pub(TaskLogEvent.topic_name())
        state_pub.put("my name is state")
        progress_pub.put("my name is progress")
        logs_pub.put("my name is logs")
        # tell the client we are done
        published_event = Event(name=_DASK_START_EVENT)
        published_event.set()

        return TaskOutputData.parse_obj({"some_output_key": 123})
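
The consuming side of these Pubs (a sketch, not part of the original snippet): a distributed.Sub constructed with the same topic name receives what the corresponding Pub publishes.

from distributed import Sub

state_sub = Sub(TaskStateEvent.topic_name())
assert state_sub.get(timeout=5) == "my name is state"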
Example #14
async def test_set_not_set_many_events(c, s, a, b):
    # Set and unset the event and check if the flag is
    # propagated correctly with many events
    events = [Event(name) for name in range(100)]

    for event in events:
        await event.clear()
        assert not await event.is_set()

    for i, event in enumerate(events):
        if i % 2 == 0:
            await event.set()
            assert await event.is_set()
        else:
            assert not await event.is_set()

    for event in events:
        await event.clear()
        assert not await event.is_set()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Example #15
def event_not_set(event_name):
    assert not Event(event_name).wait(timeout=0.05)
Example #16
def event_is_set(event_name):
    assert Event(event_name).wait(timeout=0.5)
Example #17
async def test_get_tasks_status(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    cpu_image: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
    fail_remote_fct: bool,
):
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    _DASK_EVENT_NAME = faker.pystr()

    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # wait here until the client allows us to continue
        start_event = Event(_DASK_EVENT_NAME)
        start_event.wait(timeout=5)
        if fail_remote_fct:
            raise ValueError("We fail because we're told to!")
        return TaskOutputData.parse_obj({"some_output_key": 123})

    node_id_to_job_ids = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=cpu_image.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert node_id_to_job_ids
    assert len(node_id_to_job_ids) == 1
    node_id, job_id = node_id_to_job_ids[0]
    assert node_id in cpu_image.fake_tasks
    # let's get a dask future for the task here so dask will not remove the task from the scheduler at the end
    computation_future = distributed.Future(key=job_id)
    assert computation_future

    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # let the remote fct run through now
    start_event = Event(_DASK_EVENT_NAME, dask_client.dask_subsystem.client)
    await start_event.set()  # type: ignore
    # the task should now complete successfully (or fail, if fail_remote_fct is set)
    await _assert_wait_for_task_status(
        job_id,
        dask_client,
        RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS,
    )
    # release the task results
    await dask_client.release_task_result(job_id)
    # the task is still present since we hold a future here
    await _assert_wait_for_task_status(
        job_id,
        dask_client,
        RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS,
    )

    # removing the future will let dask eventually delete the task from its memory, so its status becomes undefined
    del computation_future
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=60
    )
Example #18
    def wait_for_it_failing(x):
        event = Event("x")

        # Event is not set in another task so far
        assert not event.wait(timeout=0.05)
        assert not event.is_set()
Example #19
async def test_abort_computation_tasks(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
):
    _DASK_EVENT_NAME = faker.pystr()
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        cancel_event = Event(TaskCancelEventName.format(task.key))
        # tell the client we are started
        start_event = Event(_DASK_EVENT_NAME)
        start_event.set()
        # wait a bit in case someone aborts us
        print("--> waiting for task to be aborted...")
        cancel_event.wait(timeout=10)
        if cancel_event.is_set():
            # NOTE: asyncio.CancelledError is not propagated back to the client...
            print("--> raising cancellation error now")
            raise TaskCancelledError

        return TaskOutputData.parse_obj({"some_output_key": 123})

    node_id_to_job_ids = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert node_id_to_job_ids
    assert len(node_id_to_job_ids) == 1
    node_id, job_id = node_id_to_job_ids[0]
    assert node_id in image_params.fake_tasks
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # we wait to be sure the remote fct is started
    start_event = Event(_DASK_EVENT_NAME)
    await start_event.wait(timeout=10)  # type: ignore

    # now let's abort the computation
    await dask_client.abort_computation_task(job_id)
    await _assert_wait_for_cb_call(mocked_user_completed_cb)
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.ABORTED)

    # getting the results should throw the cancellation error
    with pytest.raises(TaskCancelledError):
        await dask_client.get_task_result(job_id)

    # after releasing the results, the task shall be UNKNOWN
    await dask_client.release_task_result(job_id)
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=120
    )
Example #20
def set_it():
    event = Event("x")
    event.set()
Example #21
    def wait_for_it_ok(x):
        event = Event("x")

        # Event is set in another task
        assert event.wait(timeout=0.5)
        assert event.is_set()
Example #22
def clear_it():
    event = Event("x")
    event.clear()