Code example #1
    def _maybe_run(event_name: str, fn: Callable, *args: Any,
                   **kwargs: Any) -> Any:
        """Check if the task should run against a `distributed.Event` before
        starting the task. This offers stronger guarantees than distributed's
        current cancellation mechanism, which only cancels pending tasks."""
        import dask
        from distributed import Event, get_client

        try:
            # Explicitly pass in the timeout from dask's config. Some versions of
            # distributed hardcode this rather than using the configured value.
            # Can be removed once we bump our min requirements for
            # distributed to >= 2.31.0.
            timeout = dask.config.get("distributed.comm.timeouts.connect")
            event = Event(event_name, client=get_client(timeout=timeout))
            should_run = event.is_set()
        except Exception:
            # Failure to create an event is usually due to connection errors. These
            # are either due to flaky behavior in distributed's comms under high
            # loads, or due to the scheduler shutting down. Either way, the safest
            # course here is to assume we *should* run the task still. If we guess
            # wrong, we're either doing a bit of unnecessary work, or the cluster
            # is shutting down and the task will be cancelled anyway.
            should_run = True

        if should_run:
            return fn(*args, **kwargs)
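
The wrapper above is only half of the pattern: the client keeps a shared `distributed.Event` set while work is allowed and clears it to "soft-cancel" anything still queued. Below is a minimal, self-contained sketch of that round trip on a throwaway in-process cluster; the event name, the `add` helper, and the cluster settings are illustrative assumptions, not part of the original code.

from distributed import Client, Event, get_client


def _maybe_run(event_name, fn, *args, **kwargs):
    # simplified version of the wrapper above: do the work only while
    # the shared "should run" event is still set
    if Event(event_name, client=get_client()).is_set():
        return fn(*args, **kwargs)


def add(x, y):
    return x + y


if __name__ == "__main__":
    client = Client(processes=False)                 # throwaway in-process cluster
    should_run = Event("should-run", client=client)
    should_run.set()                                 # allow submitted tasks to run

    fut = client.submit(_maybe_run, "should-run", add, 1, 2, pure=False)
    print(fut.result())                              # -> 3

    should_run.clear()                               # later submissions skip their work
    fut2 = client.submit(_maybe_run, "should-run", add, 1, 2, pure=False)
    print(fut2.result())                             # -> None
    client.close()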
Code example #2
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        cancel_event = Event(TaskCancelEventName.format(task.key))
        # tell the client we are started
        start_event = Event(_DASK_EVENT_NAME)
        start_event.set()
        # sleep a bit in case someone is aborting us
        print("--> waiting for task to be aborted...")
        cancel_event.wait(timeout=10)
        if cancel_event.is_set():
            # NOTE: asyncio.CancelledError is not propagated back to the client...
            print("--> raising cancellation error now")
            raise TaskCancelledError

        return TaskOutputData.parse_obj({"some_output_key": 123})
Code example #3
File: dask.py Project: tank0226/prefect
    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith("inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(
                f"prefect-{uuid.uuid4().hex}", client=self.client
            )
            self._should_run_event.set()

        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            self._watch_dask_events(), self.client.loop.asyncio_loop  # type: ignore
        )
Code example #4
 def fake_remote_fct(
     docker_auth: DockerBasicAuth,
     service_key: str,
     service_version: str,
     input_data: TaskInputData,
     output_data_keys: TaskOutputDataSchema,
     log_file_url: AnyUrl,
     command: List[str],
 ) -> TaskOutputData:
     # wait here until the client allows us to continue
     start_event = Event(_DASK_EVENT_NAME)
     start_event.wait(timeout=5)
     if fail_remote_fct:
         raise ValueError("We fail because we're told to!")
     return TaskOutputData.parse_obj({"some_output_key": 123})
Code example #5
async def test_pause_executor_with_memory_monitor(c, s, a):
    assert memory_monitor_running(a)
    mocked_rss = 0
    a.monitor.get_process_memory = lambda: mocked_rss

    # Task that is running when the worker pauses
    ev_x = Event()

    def f(ev):
        ev.wait()
        return 1

    # Task that is running on the worker when the worker pauses
    x = c.submit(f, ev_x, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    with captured_logger(
            logging.getLogger("distributed.worker_memory")) as logger:
        # Task that is queued on the worker when the worker pauses
        y = c.submit(inc, 1, key="y")
        while "y" not in a.tasks:
            await asyncio.sleep(0.01)

        # Hog the worker with 900MB unmanaged memory
        mocked_rss = 900_000_000
        while s.workers[a.address].status != Status.paused:
            await asyncio.sleep(0.01)

        assert "Pausing worker" in logger.getvalue()

        # Task that is queued on the scheduler when the worker pauses.
        # It is not sent to the worker.
        z = c.submit(inc, 2, key="z")
        while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
            await asyncio.sleep(0.01)
        assert s.unrunnable == {s.tasks["z"]}

        # Test that a task that already started when the worker paused can complete
        # and its output can be retrieved. Also test that the now free slot won't be
        # used by other tasks.
        await ev_x.set()
        assert await x == 1
        await asyncio.sleep(0.05)

        assert a.executing_count == 0
        assert len(a.ready) == 1
        assert a.tasks["y"].state == "ready"
        assert "z" not in a.tasks

        # Release the memory. Tasks that were queued on the worker are executed.
        # Tasks that were stuck on the scheduler are sent to the worker and executed.
        mocked_rss = 0
        assert await y == 2
        assert await z == 3

        assert a.status == Status.running
        assert "Resuming worker" in logger.getvalue()
Code example #6
    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith(
            "inproc")  # type: ignore
        if (self.address is not None
                or is_inproc) and not self.disable_cancellation_event:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(f"prefect-{uuid.uuid4().hex}",
                                           client=self.client)
            self._should_run_event.set()

        if self.watch_worker_status is True or (
                self.watch_worker_status is None and not self.adapt_kwargs):
            self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
                self._watch_dask_events(),
                self.client.loop.asyncio_loop  # type: ignore
            )
Code example #7
async def test_default_event(c, s, a, b):
    # The default flag for events should be false
    event = Event("x")
    assert not await event.is_set()

    await event.clear()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Code example #8
async def test_timeout(c, s, a, b):
    # The event should not be set and the timeout should happen
    event = Event("x")
    assert not await Event("x").wait(timeout=0.1)

    await event.set()
    assert await Event("x").wait(timeout=0.1)

    await event.clear()
    assert not await Event("x").wait(timeout=0.1)
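
The two tests above exercise the awaitable Event API used inside distributed's test harness. For reference, the same semantics with the ordinary synchronous client look like the minimal sketch below, run against a throwaway in-process cluster (cluster settings are illustrative): a fresh event starts unset, `wait` returns `False` on timeout, and `set`/`clear` toggle the shared flag.

from distributed import Client, Event

with Client(processes=False) as client:      # throwaway in-process cluster
    event = Event("x", client=client)
    assert not event.is_set()                # default flag is False
    assert not event.wait(timeout=0.1)       # not set: wait times out, returns False
    event.set()
    assert event.wait(timeout=0.1)           # already set: returns True immediately
    event.clear()
    assert not event.is_set()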
Code example #9
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:

        state_pub = distributed.Pub(TaskStateEvent.topic_name())
        progress_pub = distributed.Pub(TaskProgressEvent.topic_name())
        logs_pub = distributed.Pub(TaskLogEvent.topic_name())
        state_pub.put("my name is state")
        progress_pub.put("my name is progress")
        logs_pub.put("my name is logs")
        # tell the client we are done
        published_event = Event(name=_DASK_START_EVENT)
        published_event.set()

        return TaskOutputData.parse_obj({"some_output_key": 123})
Code example #10
async def test_serializable(c, s, a, b):
    # Pickling an event should work properly
    def f(x, event=None):
        assert event.name == "x"
        return x + 1

    event = Event("x")
    futures = c.map(f, range(10), event=event)
    await c.gather(futures)

    event2 = pickle.loads(pickle.dumps(event))
    assert event2.name == event.name
    assert event2.client is event.client
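
Because events pickle cleanly, as the test above checks, an `Event` instance can be passed directly as a task argument instead of being recreated by name on the worker. A minimal synchronous sketch of that, on a throwaway in-process cluster (function and event names are illustrative):

from distributed import Client, Event

def wait_then_double(x, event):
    # the event deserializes on the worker and reconnects via the worker's client
    event.wait(timeout=5)
    return 2 * x

with Client(processes=False) as client:
    go = Event("go", client=client)
    fut = client.submit(wait_then_double, 21, event=go, pure=False)
    go.set()                 # release the waiting task
    print(fut.result())      # -> 42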
Code example #11
async def test_event_types(c, s, a, b):
    # Event names can be strings, numbers, tuples, lists, or bytes
    for name in [1, ("a", 1), ["a", 1], b"123", "123"]:
        event = Event(name)
        assert event.name == name

        await event.set()
        await event.clear()
        result = await event.is_set()
        assert not result

    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Code example #12
async def test_pause_executor_manual(c, s, a):
    assert not memory_monitor_running(a)

    # Task that is running when the worker pauses
    ev_x = Event()

    def f(ev):
        ev.wait()
        return 1

    # Task that is running on the worker when the worker pauses
    x = c.submit(f, ev_x, key="x")
    while a.executing_count != 1:
        await asyncio.sleep(0.01)

    # Task that is queued on the worker when the worker pauses
    y = c.submit(inc, 1, key="y")
    while "y" not in a.tasks:
        await asyncio.sleep(0.01)

    a.status = Status.paused
    # Wait for sync to scheduler
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # Task that is queued on the scheduler when the worker pauses.
    # It is not sent to the worker.
    z = c.submit(inc, 2, key="z")
    while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
        await asyncio.sleep(0.01)
    assert s.unrunnable == {s.tasks["z"]}

    # Test that a task that already started when the worker paused can complete
    # and its output can be retrieved. Also test that the now free slot won't be
    # used by other tasks.
    await ev_x.set()
    assert await x == 1
    await asyncio.sleep(0.05)

    assert a.executing_count == 0
    assert len(a.ready) == 1
    assert a.tasks["y"].state == "ready"
    assert "z" not in a.tasks

    # Unpause. Tasks that were queued on the worker are executed.
    # Tasks that were stuck on the scheduler are sent to the worker and executed.
    a.status = Status.running
    assert await y == 2
    assert await z == 3
Code example #13
async def test_set_not_set(c, s, a, b):
    # Set and unset the event and check if the flag is
    # propagated correctly
    event = Event("x")

    await event.clear()
    assert not await event.is_set()

    await event.set()
    assert await event.is_set()

    await event.set()
    assert await event.is_set()

    await event.clear()
    assert not await event.is_set()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Code example #14
async def test_set_not_set_many_events(c, s, a, b):
    # Set and unset the event and check if the flag is
    # propagated correctly with many events
    events = [Event(name) for name in range(100)]

    for event in events:
        await event.clear()
        assert not await event.is_set()

    for i, event in enumerate(events):
        if i % 2 == 0:
            await event.set()
            assert await event.is_set()
        else:
            assert not await event.is_set()

    for event in events:
        await event.clear()
        assert not await event.is_set()

    # Cleanup should have happened
    assert not s.extensions["events"]._events
    assert not s.extensions["events"]._waiter_count
Code example #15
    def wait_for_it_ok(x):
        event = Event("x")

        # Event is set in another task
        assert event.wait(timeout=0.5)
        assert event.is_set()
Code example #16
 def event_is_set(event_name):
     assert Event(event_name).wait(timeout=0.5)
Code example #17
 def set_it():
     event = Event("x")
     event.set()
Code example #18
 def clear_it():
     event = Event("x")
     event.clear()
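
The small helpers in the last few examples are meant to be submitted as tasks, so that one task toggles an event that another task observes. A hedged, self-contained sketch of that round trip on a throwaway in-process cluster (the helper bodies are repeated here so the snippet runs on its own):

from distributed import Client, Event

def set_it():
    Event("x").set()

def event_is_set(event_name):
    assert Event(event_name).wait(timeout=0.5)

with Client(processes=False) as client:
    client.submit(set_it, pure=False).result()              # one task sets the event
    client.submit(event_is_set, "x", pure=False).result()   # another task sees it set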
Code example #19
async def test_get_tasks_status(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    cpu_image: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
    fail_remote_fct: bool,
):
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    _DASK_EVENT_NAME = faker.pystr()

    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # wait here until the client allows us to continue
        start_event = Event(_DASK_EVENT_NAME)
        start_event.wait(timeout=5)
        if fail_remote_fct:
            raise ValueError("We fail because we're told to!")
        return TaskOutputData.parse_obj({"some_output_key": 123})

    node_id_to_job_ids = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=cpu_image.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert node_id_to_job_ids
    assert len(node_id_to_job_ids) == 1
    node_id, job_id = node_id_to_job_ids[0]
    assert node_id in cpu_image.fake_tasks
    # let's get a dask future for the task here so dask will not remove the task from the scheduler at the end
    computation_future = distributed.Future(key=job_id)
    assert computation_future

    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # let the remote fct run through now
    start_event = Event(_DASK_EVENT_NAME, dask_client.dask_subsystem.client)
    await start_event.set()  # type: ignore
    # the task should now complete (or fail, if fail_remote_fct is set)
    await _assert_wait_for_task_status(
        job_id,
        dask_client,
        RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS,
    )
    # release the task results
    await dask_client.release_task_result(job_id)
    # the task is still present since we hold a future here
    await _assert_wait_for_task_status(
        job_id,
        dask_client,
        RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS,
    )

    # removing the future will let dask eventually delete the task from its memory, so its status becomes undefined
    del computation_future
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=60
    )
Code example #20
    def wait_for_it_failing(x):
        event = Event("x")

        # Event is not set in another task so far
        assert not event.wait(timeout=0.05)
        assert not event.is_set()
Code example #21
async def test_abort_computation_tasks(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
    faker: Faker,
):
    _DASK_EVENT_NAME = faker.pystr()
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_remote_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        cancel_event = Event(TaskCancelEventName.format(task.key))
        # tell the client we are started
        start_event = Event(_DASK_EVENT_NAME)
        start_event.set()
        # sleep a bit in case someone is aborting us
        print("--> waiting for task to be aborted...")
        cancel_event.wait(timeout=10)
        if cancel_event.is_set():
            # NOTE: asyncio.CancelledError is not propagated back to the client...
            print("--> raising cancellation error now")
            raise TaskCancelledError

        return TaskOutputData.parse_obj({"some_output_key": 123})

    node_id_to_job_ids = await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_tasks,
        callback=mocked_user_completed_cb,
        remote_fct=fake_remote_fct,
    )
    assert node_id_to_job_ids
    assert len(node_id_to_job_ids) == 1
    node_id, job_id = node_id_to_job_ids[0]
    assert node_id in image_params.fake_tasks
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.STARTED)

    # we wait to be sure the remote fct is started
    start_event = Event(_DASK_EVENT_NAME)
    await start_event.wait(timeout=10)  # type: ignore

    # now let's abort the computation
    await dask_client.abort_computation_task(job_id)
    await _assert_wait_for_cb_call(mocked_user_completed_cb)
    await _assert_wait_for_task_status(job_id, dask_client, RunningState.ABORTED)

    # getting the results should throw the cancellation error
    with pytest.raises(TaskCancelledError):
        await dask_client.get_task_result(job_id)

    # after releasing the results, the task shall be UNKNOWN
    await dask_client.release_task_result(job_id)
    await _assert_wait_for_task_status(
        job_id, dask_client, RunningState.UNKNOWN, timeout=120
    )
Code example #22
class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler.

    Check https://docs.dask.org/en/latest/setup.html for all kinds of cluster types.

    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start_loop()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the address of the scheduler via the `address` kwarg.

    Note that if you have tasks with tags of the form `"dask-accum_resource:KEY=NUM"`
    they will be parsed and passed as
    [Worker Resources](https://distributed.dask.org/en/latest/resources.html)
    of the form `{"KEY": float(NUM)}` to the Dask scheduler.

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a temporary cluster will be
            created in `executor.start_loop()`.  Defaults to `None`.
        - cluster_class (string or callable, optional): the cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        - cluster_kwargs (dict, optional): additional kwargs to pass to the
           `cluster_class` when creating a temporary dask cluster.
        - adapt_kwargs (dict, optional): additional kwargs to pass to `cluster.adapt`
            when creating a temporary dask cluster. Note that adaptive scaling
            is only enabled if `adapt_kwargs` are provided.
        - client_kwargs (dict, optional): additional kwargs to use when creating a
            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
        - debug (bool, optional): When running with a local cluster, setting
            `debug=True` will increase dask's logging level, providing
            potentially useful debug info. Defaults to the `debug` value in
            your Prefect configuration.

    Examples:

    Using a temporary local dask cluster:

    ```python
    executor = DaskExecutor()
    ```

    Using a temporary cluster running elsewhere. Any Dask cluster class should
    work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):

    ```python
    executor = DaskExecutor(
        cluster_class="dask_cloudprovider.FargateCluster",
        cluster_kwargs={
            "image": "prefecthq/prefect:latest",
            "n_workers": 5,
            ...
        },
    )
    ```

    Connecting to an existing dask cluster

    ```python
    executor = DaskExecutor(address="192.0.2.255:8786")
    ```
    """
    def __init__(self,
                 address: str = None,
                 cluster_class: Union[str, Callable] = None,
                 cluster_kwargs: dict = None,
                 adapt_kwargs: dict = None,
                 client_kwargs: dict = None,
                 debug: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        if address is not None and (cluster_class is not None
                                    or cluster_kwargs is not None):
            raise ValueError(
                "Cannot specify both `address` and `cluster_class`/`cluster_kwargs`"
            )
        from distributed import Client
        from distributed.deploy.local import LocalCluster
        if isinstance(cluster_class, str):
            cluster_class = import_object(cluster_class)
        elif not cluster_class:
            cluster_class = LocalCluster

        self.cluster_class = cluster_class

        self.cluster_kwargs = {} if not cluster_kwargs else cluster_kwargs.copy(
        )
        if cluster_class == LocalCluster:
            self.cluster_kwargs.setdefault(
                'silence_logs',
                logging.CRITICAL if not debug else logging.WARNING)

        self.adapt_kwargs = {} if not adapt_kwargs else adapt_kwargs.copy()
        self.client_kwargs = {} if not client_kwargs else client_kwargs.copy()
        self.client_kwargs.setdefault('set_as_default', False)
        self.client: Optional[Client] = None
        self.cluster: Optional[Union[str, LocalCluster]] = None

        self.address = address
        self._futures = None
        self._should_run_event = None
        self._watch_dask_events_task = None

    async def __aenter__(self):
        from distributed import Client
        if self.address is None:
            self.cluster = self.cluster_class(**self.cluster_kwargs,
                                              asynchronous=True)
            await self.cluster.__aenter__()
            if self.adapt_kwargs:
                self.cluster.adapt(**self.adapt_kwargs)
        else:
            self.cluster = self.address
        self.client = Client(self.cluster,
                             **self.client_kwargs,
                             asynchronous=True)
        await self.client.__aenter__()
        self._pre_start_yield()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._post_start_yield()
        if not isinstance(self.cluster, str):
            await self.cluster.__aexit__(exc_type, exc_val, exc_tb)
        await self.client.__aexit__(exc_type, exc_val, exc_tb)

    async def run(self,
                  fn: Callable,
                  *args: Any,
                  extra_context: dict = None,
                  **kwargs: Any) -> Future:
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra information
                about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of `fn(*args, **kwargs)`
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        kwargs.update(self._prep_dask_kwargs(extra_context))
        if self._should_run_event is None:
            fut = self.client.submit(fn, *args, **kwargs)
        else:
            fut = self.client.submit(self._maybe_run,
                                     self._should_run_event.name, fn, *args,
                                     **kwargs)
        res = await fut
        return res

    async def _watch_dask_events(self) -> None:
        scheduler_comm = None
        comm = None
        from distributed.core import rpc

        try:
            scheduler_comm = rpc(
                self.client.scheduler.address,  # type: ignore
                connection_args=self.client.security.get_connection_args(
                    "client"),  # type: ignore
            )
            # due to a bug in distributed's inproc comms, letting cancellation
            # bubble up here will kill the listener. wrap with a shield to
            # prevent that.
            comm = await asyncio.shield(scheduler_comm.live_comm())
            await comm.write({"op": "subscribe_worker_status"})
            _ = await comm.read()
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                for op, msg in msgs:
                    if op == "add":
                        for worker in msg.get("workers", ()):
                            flowsaber.context.logger.debug(
                                "Worker %s added", worker)
                    elif op == "remove":
                        flowsaber.context.logger.debug("Worker %s removed",
                                                       msg)
        except asyncio.CancelledError:
            pass
        except Exception:
            flowsaber.context.logger.debug(
                "Failure while watching dask worker events", exc_info=True)
        finally:
            if comm is not None:
                try:
                    await comm.close()
                except Exception:
                    pass
            if scheduler_comm is not None:
                scheduler_comm.close_rpc()

    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith(
            "inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(f"flowsaber-{uuid.uuid4().hex}",
                                           client=self.client)
            self._should_run_event.set()

        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            self._watch_dask_events(),
            self.client.loop.asyncio_loop  # type: ignore
        )

    def _post_start_yield(self) -> None:
        from distributed import wait

        if self._watch_dask_events_task is not None:
            try:
                self._watch_dask_events_task.cancel()
            except Exception:
                pass
            self._watch_dask_events_task = None

        if self._should_run_event is not None:
            # Multipart cleanup, ignoring exceptions in each stage
            # 1.) Stop pending tasks from starting
            try:
                self._should_run_event.clear()
            except Exception:
                pass
            # 2.) Wait for all running tasks to complete
            try:
                futures = [f for f in list(self._futures)
                           if not f.done()]  # type: ignore
                if futures:
                    flowsaber.context.logger.info(
                        "Stopping executor, waiting for %d active tasks to complete",
                        len(futures),
                    )
                    wait(futures)
            except Exception:
                pass
        self._should_run_event = None
        self._futures = None

    def _prep_dask_kwargs(self, extra_context: dict = None) -> dict:
        if extra_context is None:
            extra_context = {}

        dask_kwargs = {"pure": False}  # type: dict

        # set a key for the dask scheduler UI
        key = self._make_task_key(**extra_context)
        if key is not None:
            dask_kwargs["key"] = key

        # infer from context if dask resources are being utilized
        task_tags = extra_context.get("task_tags", [])
        dask_resource_tags = [
            tag for tag in task_tags
            if tag.lower().startswith("dask-accum_resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state.update({
            k: None
            for k in [
                "client", "_futures", "_should_run_event",
                "_watch_dask_events_task"
            ]
        })
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    @staticmethod
    def _make_task_key(task_name: str = "",
                       task_index: int = None,
                       **kwargs: Any) -> Optional[str]:
        """A helper for generating a dask task key from field set in `extra_context`"""
        if task_name:
            suffix = uuid.uuid4().hex
            if task_index is not None:
                return f"{task_name}-{task_index}-{suffix}"
            return f"{task_name}-{suffix}"
        return None

    @staticmethod
    def _maybe_run(event_name: str, fn: Callable, *args: Any,
                   **kwargs: Any) -> Any:
        """Check if the task should run against a `distributed.Event` before
        starting the task. This offers stronger guarantees than distributed's
        current cancellation mechanism, which only cancels pending tasks."""
        import dask
        from distributed import Event, get_client

        try:
            # Explicitly pass in the timeout from dask's config. Some versions of
            # distributed hardcode this rather than using the configured value.
            # Can be removed once we bump our min requirements for
            # distributed to >= 2.31.0.
            timeout = dask.config.get("distributed.comm.timeouts.connect")
            event = Event(event_name, client=get_client(timeout=timeout))
            should_run = event.is_set()
        except Exception:
            # Failure to create an event is usually due to connection errors. These
            # are either due to flaky behavior in distributed's comms under high
            # loads, or due to the scheduler shutting down. Either way, the safest
            # course here is to assume we *should* run the task still. If we guess
            # wrong, we're either doing a bit of unnecessary work, or the cluster
            # is shutting down and the task will be cancelled anyway.
            should_run = True

        if should_run:
            return fn(*args, **kwargs)
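
A hedged usage sketch for the async executor above. It assumes flowsaber's `Executor` base class and the module-level imports the class relies on (`asyncio`, `weakref`, `uuid`, `logging`, `flowsaber`, `distributed`) are available; `add` is an illustrative placeholder, not part of the original code.

import asyncio

def add(x, y):
    return x + y

async def main():
    # spins up a temporary distributed.LocalCluster, runs one task, tears it down
    async with DaskExecutor() as executor:
        result = await executor.run(add, 1, 2)
        print(result)  # -> 3

asyncio.run(main())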
Code example #23
 def event_not_set(event_name):
     assert not Event(event_name).wait(timeout=0.05)
Code example #24
class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler.

    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the address of the scheduler via the `address` kwarg.

    Note that if you have tasks with tags of the form `"dask-resource:KEY=NUM"`
    they will be parsed and passed as
    [Worker Resources](https://distributed.dask.org/en/latest/resources.html)
    of the form `{"KEY": float(NUM)}` to the Dask Scheduler.

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a temporary cluster will be
            created in `executor.start()`.  Defaults to `None`.
        - cluster_class (string or callable, optional): the cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        - cluster_kwargs (dict, optional): additional kwargs to pass to the
           `cluster_class` when creating a temporary dask cluster.
        - adapt_kwargs (dict, optional): additional kwargs to pass to `cluster.adapt`
            when creating a temporary dask cluster. Note that adaptive scaling
            is only enabled if `adapt_kwargs` are provided.
        - client_kwargs (dict, optional): additional kwargs to use when creating a
            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
        - debug (bool, optional): When running with a local cluster, setting
            `debug=True` will increase dask's logging level, providing
            potentially useful debug info. Defaults to the `debug` value in
            your Prefect configuration.
        - performance_report_path (str, optional): An optional path for the [dask performance
            report](https://distributed.dask.org/en/latest/api.html#distributed.performance_report).
        - disable_cancellation_event (bool, optional): By default, Prefect uses a
            Dask event to allow for better cancellation of task runs. Sometimes this
            can cause strain on the scheduler as each task needs to retrieve a client
            to check the status of the cancellation event. If set to `True`, this
            check is skipped.

    Examples:

    Using a temporary local dask cluster:

    ```python
    executor = DaskExecutor()
    ```

    Using a temporary cluster running elsewhere. Any Dask cluster class should
    work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):

    ```python
    executor = DaskExecutor(
        cluster_class="dask_cloudprovider.FargateCluster",
        cluster_kwargs={
            "image": "prefecthq/prefect:latest",
            "n_workers": 5,
            ...
        },
    )
    ```

    Connecting to an existing dask cluster

    ```python
    executor = DaskExecutor(address="192.0.2.255:8786")
    ```
    """
    def __init__(
        self,
        address: str = None,
        cluster_class: Union[str, Callable] = None,
        cluster_kwargs: dict = None,
        adapt_kwargs: dict = None,
        client_kwargs: dict = None,
        debug: bool = None,
        performance_report_path: str = None,
        disable_cancellation_event: bool = False,
    ):
        if address is None:
            address = context.config.engine.executor.dask.address or None

        if address is not None:
            if cluster_class is not None or cluster_kwargs is not None:
                raise ValueError(
                    "Cannot specify `address` and `cluster_class`/`cluster_kwargs`"
                )
        else:
            if cluster_class is None:
                cluster_class = context.config.engine.executor.dask.cluster_class
            if isinstance(cluster_class, str):
                cluster_class = import_object(cluster_class)
            if cluster_kwargs is None:
                cluster_kwargs = {}
            else:
                cluster_kwargs = cluster_kwargs.copy()

            from distributed.deploy.local import LocalCluster

            if cluster_class == LocalCluster:
                if debug is None:
                    debug = context.config.debug
                cluster_kwargs.setdefault(
                    "silence_logs",
                    logging.CRITICAL if not debug else logging.WARNING)

            if adapt_kwargs is None:
                adapt_kwargs = {}

        if client_kwargs is None:
            client_kwargs = {}
        else:
            client_kwargs = client_kwargs.copy()
        client_kwargs.setdefault("set_as_default", False)

        self.address = address
        self.cluster_class = cluster_class
        self.cluster_kwargs = cluster_kwargs
        self.adapt_kwargs = adapt_kwargs
        self.client_kwargs = client_kwargs
        self.disable_cancellation_event = disable_cancellation_event
        # Runtime attributes
        self.client = None
        # These are coupled - they're either both None, or both non-None.
        # They're used in the case we can't forcibly kill all the dask workers,
        # and need to wait for all the dask tasks to cleanup before exiting.
        self._futures = None  # type: Optional[weakref.WeakSet[Future]]
        self._should_run_event = None  # type: Optional[Event]
        # A ref to a background task subscribing to dask cluster events
        self._watch_dask_events_task = None  # type: Optional[concurrent.futures.Future]

        self.performance_report_path = performance_report_path

        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Creates a `dask.distributed.Client` and yields it.
        """
        if sys.platform != "win32":
            # Fix for https://github.com/dask/distributed/issues/4168
            import multiprocessing.popen_spawn_posix  # noqa
        from distributed import Client, performance_report

        performance_report_context = (performance_report(
            self.performance_report_path) if self.performance_report_path else
                                      nullcontext())

        try:
            if self.address is not None:
                self.logger.info(
                    "Connecting to an existing Dask cluster at %s",
                    self.address)
                with Client(self.address, **self.client_kwargs) as client:

                    with performance_report_context:
                        self.client = client
                        try:
                            self._pre_start_yield()
                            yield
                        finally:
                            self._post_start_yield()
            else:
                assert callable(self.cluster_class)  # mypy
                assert isinstance(self.cluster_kwargs, dict)  # mypy
                self.logger.info(
                    "Creating a new Dask cluster with `%s.%s`...",
                    self.cluster_class.__module__,
                    self.cluster_class.__qualname__,
                )
                with self.cluster_class(**self.cluster_kwargs) as cluster:
                    if getattr(cluster, "dashboard_link", None):
                        self.logger.info(
                            "The Dask dashboard is available at %s",
                            cluster.dashboard_link,
                        )
                    if self.adapt_kwargs:
                        cluster.adapt(**self.adapt_kwargs)
                    with Client(cluster, **self.client_kwargs) as client:
                        with performance_report_context:
                            self.client = client
                            try:
                                self._pre_start_yield()
                                yield
                            finally:
                                self._post_start_yield()
        finally:
            self.client = None

    async def on_worker_status_changed(self, op: str, message: dict) -> None:
        """
        This method is triggered when a worker is added or removed from the cluster.

        Args:
            - op (str): Either "add" or "remove"
            - message (dict): Information about the event that the scheduler has sent
        """
        if op == "add":
            for worker in message.get("workers", ()):
                self.logger.debug("Worker %s added", worker)
        elif op == "remove":
            self.logger.debug("Worker %s removed", message)

    async def _watch_dask_events(self) -> None:
        scheduler_comm = None
        comm = None
        from distributed.core import rpc

        try:
            scheduler_comm = rpc(
                self.client.scheduler.address,  # type: ignore
                connection_args=self.client.security.get_connection_args(
                    "client"),  # type: ignore
            )
            # due to a bug in distributed's inproc comms, letting cancellation
            # bubble up here will kill the listener. wrap with a shield to
            # prevent that.
            comm = await asyncio.shield(scheduler_comm.live_comm())
            await comm.write({"op": "subscribe_worker_status"})
            _ = await comm.read()
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                for op, msg in msgs:
                    await self.on_worker_status_changed(op, msg)
        except asyncio.CancelledError:
            pass
        except Exception:
            self.logger.debug("Failure while watching dask worker events",
                              exc_info=True)
        finally:
            if comm is not None:
                try:
                    await comm.close()
                except Exception:
                    pass
            if scheduler_comm is not None:
                scheduler_comm.close_rpc()

    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith(
            "inproc")  # type: ignore
        if (self.address is not None
                or is_inproc) and not self.disable_cancellation_event:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(f"prefect-{uuid.uuid4().hex}",
                                           client=self.client)
            self._should_run_event.set()

        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            self._watch_dask_events(),
            self.client.loop.asyncio_loop  # type: ignore
        )

    def _post_start_yield(self) -> None:
        from distributed import wait

        if self._watch_dask_events_task is not None:
            try:
                self._watch_dask_events_task.cancel()
            except Exception:
                pass
            self._watch_dask_events_task = None

        if self._should_run_event is not None:
            # Multipart cleanup, ignoring exceptions in each stage
            # 1.) Stop pending tasks from starting
            try:
                self._should_run_event.clear()
            except Exception:
                pass
            # 2.) Wait for all running tasks to complete
            try:
                futures = [f for f in list(self._futures)
                           if not f.done()]  # type: ignore
                if futures:
                    self.logger.info(
                        "Stopping executor, waiting for %d active tasks to complete",
                        len(futures),
                    )
                    wait(futures)
            except Exception:
                pass
        self._should_run_event = None
        self._futures = None

    def _prep_dask_kwargs(self, extra_context: dict = None) -> dict:
        if extra_context is None:
            extra_context = {}

        dask_kwargs = {"pure": False}  # type: dict

        # set a key for the dask scheduler UI
        key = _make_task_key(**extra_context)
        if key is not None:
            dask_kwargs["key"] = key

        # infer from context if dask resources are being utilized
        task_tags = extra_context.get("task_tags", [])
        dask_resource_tags = [
            tag for tag in task_tags if tag.lower().startswith("dask-resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state.update({
            k: None
            for k in [
                "client",
                "_futures",
                "_should_run_event",
                "_watch_dask_events_task",
            ]
        })
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def submit(self,
               fn: Callable,
               *args: Any,
               extra_context: dict = None,
               **kwargs: Any) -> "Future":
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra information
                about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of `fn(*args, **kwargs)`
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        kwargs.update(self._prep_dask_kwargs(extra_context))
        if self._should_run_event is None:
            fut = self.client.submit(fn, *args, **kwargs)
        else:
            fut = self.client.submit(_maybe_run, self._should_run_event.name,
                                     fn, *args, **kwargs)
            self._futures.add(fut)
        return fut

    def wait(self, futures: Any) -> Any:
        """
        Resolves the Future objects to their values. Blocks until the computation is complete.

        Args:
            - futures (Any): single or iterable of future-like objects to compute

        Returns:
            - Any: an iterable of resolved futures with similar shape to the input
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        return self.client.gather(futures)

    @property
    def performance_report(self) -> str:
        """The performance report html string."""
        if self.performance_report_path is None:
            self.logger.warning(
                "Executor was not configured to generate performance report")
            return ""
        self.logger.debug(
            f"Retreiving dask performance report from {self.performance_report_path!r}"
        )
        try:
            with open(self.performance_report_path, "r",
                      encoding="utf-8") as f:
                report = f.read()
                return report
        except Exception as exc:
            self.logger.error(
                f"Failed to get dask performance report with exception {exc}")
            return ""