Example #1
async def test_delete_unset_variable(c, s, a, b):
    x = Variable()
    assert x.client is c
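    # Deleting a variable that was never set should not log a KeyError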
    with captured_logger(logging.getLogger("distributed.utils")) as logger:
        x.delete()
        await c.close()
    text = logger.getvalue()
    assert "KeyError" not in text
Example #2
async def test_variables_do_not_leak_client(c, s, a, b):
    # https://github.com/dask/distributed/issues/3899
    clients_pre = set(s.clients)

    # setup variable with future
    x = Variable("x")
    future = c.submit(inc, 1)
    await x.set(future)

    # complete teardown
    x.delete()

    start = time()
    while set(s.clients) != clients_pre:
        await asyncio.sleep(0.01)
        assert time() < start + 5
Example #3
async def test_variable(c, s, a, b):
    x = Variable("x")
    xx = Variable("x")
    assert x.client is c

    future = c.submit(inc, 1)

    await x.set(future)
    future2 = await xx.get()
    assert future.key == future2.key

    del future, future2

    await asyncio.sleep(0.1)
    assert s.tasks  # future still present

    x.delete()

    start = time()
    while s.tasks:
        await asyncio.sleep(0.01)
        assert time() < start + 5
Example #4
class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler.

    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the address of the scheduler via the `address` kwarg.

    Note that if you have tasks with tags of the form `"dask-resource:KEY=NUM"`,
    they will be parsed and passed as
    [Worker Resources](https://distributed.dask.org/en/latest/resources.html)
    of the form `{"KEY": float(NUM)}` to the Dask Scheduler (see the last
    example below).

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a temporary cluster will be
            created in `executor.start()`.  Defaults to `None`.
        - cluster_class (string or callable, optional): the cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        - cluster_kwargs (dict, optional): additional kwargs to pass to the
            `cluster_class` when creating a temporary dask cluster.
        - adapt_kwargs (dict, optional): additional kwargs to pass to `cluster.adapt`
            when creating a temporary dask cluster. Note that adaptive scaling
            is only enabled if `adapt_kwargs` are provided.
        - client_kwargs (dict, optional): additional kwargs to use when creating a
            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
        - debug (bool, optional): When running with a local cluster, setting
            `debug=True` will increase dask's logging level, providing
            potentially useful debug info. Defaults to the `debug` value in
            your Prefect configuration.
        - **kwargs: DEPRECATED

    Using a temporary local dask cluster:

    ```python
    executor = DaskExecutor()
    ```

    Using a temporary cluster running elsewhere. Any Dask cluster class should
    work; here we use [dask-cloudprovider](https://cloudprovider.dask.org):

    ```python
    executor = DaskExecutor(
        cluster_class="dask_cloudprovider.FargateCluster",
        cluster_kwargs={
            "image": "prefecthq/prefect:latest",
            "n_workers": 5,
            ...
        },
    )
    ```

    Connecting to an existing dask cluster:

    ```python
    executor = DaskExecutor(address="192.0.2.255:8786")
    ```
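
    Requesting worker resources through task tags, per the note above. A
    minimal sketch, assuming Prefect's `@task` decorator; the `GPU` resource
    name is illustrative and must match a resource configured on the workers:

    ```python
    from prefect import task

    @task(tags=["dask-resource:GPU=1"])
    def train():
        ...
    ```

    The tag above is parsed by the executor and forwarded to `Client.submit`
    as `resources={"GPU": 1.0}`.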
    """
    def __init__(
        self,
        address: str = None,
        cluster_class: Union[str, Callable] = None,
        cluster_kwargs: dict = None,
        adapt_kwargs: dict = None,
        client_kwargs: dict = None,
        debug: bool = None,
        **kwargs: Any,
    ):
        if address is None:
            address = context.config.engine.executor.dask.address or None
        # XXX: deprecated
        if address == "local":
            warnings.warn(
                "`address='local'` is deprecated. To use a local cluster, leave the "
                "`address` field empty.")
            address = None

        # XXX: deprecated
        local_processes = kwargs.pop("local_processes", None)
        if local_processes is None:
            local_processes = context.config.engine.executor.dask.get(
                "local_processes", None)
        if local_processes is not None:
            warnings.warn(
                "`local_processes` is deprecated, please use "
                "`cluster_kwargs={'processes': local_processes}`. The default is "
                "now `local_processes=True`.")

        if address is not None:
            if cluster_class is not None or cluster_kwargs is not None:
                raise ValueError(
                    "Cannot specify `address` and `cluster_class`/`cluster_kwargs`"
                )
        else:
            if cluster_class is None:
                cluster_class = context.config.engine.executor.dask.cluster_class
            if isinstance(cluster_class, str):
                cluster_class = import_object(cluster_class)
            if cluster_kwargs is None:
                cluster_kwargs = {}
            else:
                cluster_kwargs = cluster_kwargs.copy()

            from distributed.deploy.local import LocalCluster

            if cluster_class == LocalCluster:
                if debug is None:
                    debug = context.config.debug
                cluster_kwargs.setdefault(
                    "silence_logs",
                    logging.CRITICAL if not debug else logging.WARNING)
                if local_processes is not None:
                    cluster_kwargs.setdefault("processes", local_processes)
                for_cluster = set(kwargs).difference(_valid_client_kwargs)
                if for_cluster:
                    warnings.warn(
                        "Forwarding executor kwargs to `LocalCluster` is now handled by the "
                        "`cluster_kwargs` parameter, please update accordingly"
                    )
                    for k in for_cluster:
                        cluster_kwargs[k] = kwargs.pop(k)

            if adapt_kwargs is None:
                adapt_kwargs = {}

        if client_kwargs is None:
            client_kwargs = {}
        else:
            client_kwargs = client_kwargs.copy()
        if kwargs:
            warnings.warn(
                "Forwarding executor kwargs to `Client` is now handled by the "
                "`client_kwargs` parameter, please update accordingly")
            client_kwargs.update(kwargs)
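        # Don't install this client as the global default unless the caller
        # explicitly asks for it via `client_kwargs`.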
        client_kwargs.setdefault("set_as_default", False)

        self.address = address
        self.cluster_class = cluster_class
        self.cluster_kwargs = cluster_kwargs
        self.adapt_kwargs = adapt_kwargs
        self.client_kwargs = client_kwargs
        # Runtime attributes
        self.client = None
        # These are coupled - they're either both None, or both non-None.
        # They're used in the case we can't forcibly kill all the dask workers,
        # and need to wait for all the dask tasks to cleanup before exiting.
        self._futures = None  # type: Optional[weakref.WeakSet[Future]]
        self._should_run_var = None  # type: Optional[Variable]

        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Creates a `dask.distributed.Client`, stores it on `self.client` for the
        duration of the block, and yields.
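
        A usage sketch (hypothetical; any configured executor works the same
        way):

        ```python
        executor = DaskExecutor()
        with executor.start():
            fut = executor.submit(lambda x: x + 1, 1)
            assert executor.wait(fut) == 2
        ```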
        """
        from distributed import Client

        try:
            if self.address is not None:
                with Client(self.address, **self.client_kwargs) as client:
                    self.client = client
                    try:
                        self._pre_start_yield()
                        yield
                    finally:
                        self._post_start_yield()
            else:
                with self.cluster_class(
                        **self.cluster_kwargs) as cluster:  # type: ignore
                    if self.adapt_kwargs:
                        cluster.adapt(**self.adapt_kwargs)
                    with Client(cluster, **self.client_kwargs) as client:
                        self.client = client
                        try:
                            self._pre_start_yield()
                            yield
                        finally:
                            self._post_start_yield()
        finally:
            self.client = None

    def _pre_start_yield(self) -> None:
        from distributed import Variable

        is_inproc = self.client.scheduler.address.startswith(
            "inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_var = Variable(f"prefect-{uuid.uuid4().hex}",
                                            client=self.client)
            self._should_run_var.set(True)

    def _post_start_yield(self) -> None:
        from distributed import wait

        if self._should_run_var is not None:
            # Multipart cleanup, ignoring exceptions in each stage
            # 1.) Stop pending tasks from starting
            try:
                self._should_run_var.set(False)
            except Exception:
                pass
            # 2.) Wait for all running tasks to complete
            try:
                futures = [f for f in list(self._futures)
                           if not f.done()]  # type: ignore
                if futures:
                    self.logger.info(
                        "Stopping executor, waiting for %d active tasks to complete",
                        len(futures),
                    )
                    wait(futures)
            except Exception:
                pass
            # 3.) Delete the distributed variable
            try:
                self._should_run_var.delete()
            except Exception:
                pass
        self._should_run_var = None
        self._futures = None

    def _prep_dask_kwargs(self, extra_context: dict = None) -> dict:
        if extra_context is None:
            extra_context = {}

        dask_kwargs = {"pure": False}  # type: dict

        # set a key for the dask scheduler UI
        key = _make_task_key(**extra_context)
        if key is not None:
            dask_kwargs["key"] = key

        # infer from context if dask resources are being utilized
        task_tags = extra_context.get("task_tags", [])
        dask_resource_tags = [
            tag for tag in task_tags if tag.lower().startswith("dask-resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
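        # Drop runtime state (client, tracked futures, the should-run variable)
        # when pickling; these reference live cluster objects and are rebuilt
        # by `start()`.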
        state = self.__dict__.copy()
        state.update(
            {k: None
             for k in ["client", "_futures", "_should_run_var"]})
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def submit(self,
               fn: Callable,
               *args: Any,
               extra_context: dict = None,
               **kwargs: Any) -> "Future":
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra information
                about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of `fn(*args, **kwargs)`
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        kwargs.update(self._prep_dask_kwargs(extra_context))
        if self._should_run_var is None:
            fut = self.client.submit(fn, *args, **kwargs)
        else:
            fut = self.client.submit(_maybe_run, self._should_run_var.name, fn,
                                     *args, **kwargs)
            self._futures.add(fut)
        return fut

    def wait(self, futures: Any) -> Any:
        """
        Resolves the Future objects to their values. Blocks until the computation is complete.

        Args:
            - futures (Any): single or iterable of future-like objects to compute

        Returns:
            - Any: an iterable of resolved futures with similar shape to the input
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        return self.client.gather(futures)
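
The `_maybe_run` helper submitted above is not shown in this example. A minimal
sketch of the idea, assuming only that it consults the shared `should_run`
Variable before invoking the wrapped function (Prefect's actual signature,
timeout, and failure handling may differ):

```python
from distributed import Variable


def _maybe_run(var_name, fn, *args, **kwargs):
    # Inside a dask task, Variable(name) resolves the worker's client.
    try:
        should_run = Variable(var_name).get(timeout=0)
    except Exception:
        # If the lookup fails (e.g. the variable was deleted during teardown),
        # skip the work; defaulting to run instead is an equally valid choice.
        should_run = False
    if should_run:
        return fn(*args, **kwargs)
```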