# Imports and the `gen_cluster` test harness inferred from usage in the
# test excerpts below.
import asyncio
import logging

from tornado import gen

from distributed import Variable
from distributed.metrics import time
from distributed.utils_test import captured_logger, gen_cluster, inc


@gen_cluster(client=True)
async def test_delete_unset_variable(c, s, a, b):
    x = Variable()
    assert x.client is c
    with captured_logger(logging.getLogger("distributed.utils")) as logger:
        x.delete()
        await c.close()
    text = logger.getvalue()
    assert "KeyError" not in text
@gen_cluster(client=True)
async def test_variables_do_not_leak_client(c, s, a, b):
    # https://github.com/dask/distributed/issues/3899
    clients_pre = set(s.clients)

    # setup variable with future
    x = Variable("x")
    future = c.submit(inc, 1)
    await x.set(future)

    # complete teardown
    x.delete()

    start = time()
    while set(s.clients) != clients_pre:
        await asyncio.sleep(0.01)
        assert time() < start + 5
@gen_cluster(client=True)
def test_variable(c, s, a, b):
    x = Variable("x")
    xx = Variable("x")
    assert x.client is c

    future = c.submit(inc, 1)

    yield x.set(future)
    future2 = yield xx.get()
    assert future.key == future2.key

    del future, future2
    yield gen.sleep(0.1)
    assert s.tasks  # future still present

    x.delete()

    start = time()
    while s.tasks:
        yield gen.sleep(0.01)
        assert time() < start + 5
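# `DaskExecutor.submit` below routes work through a module-level helper named
# `_maybe_run`, which is referenced but not defined in this excerpt. The
# following is a minimal, hedged sketch of what such a helper could look like
# (the real implementation may differ): it consults the shared
# `distributed.Variable` published in `_pre_start_yield`, so tasks that are
# still pending when the executor begins shutting down can decline to start.
from typing import Any, Callable


def _maybe_run(var_name: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    from distributed import Variable, get_client

    try:
        # Worker-side lookup of the "should run" flag set by the executor
        should_run = Variable(var_name, client=get_client()).get(timeout=5)
    except Exception:
        # If the variable (or the client) is already gone, the flow runner
        # has shut down; don't start new work.
        should_run = False
    if should_run:
        return fn(*args, **kwargs)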
class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler.

    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the address of the scheduler via the `address` kwarg.

    Note that if you have tasks with tags of the form `"dask-resource:KEY=NUM"`
    they will be parsed and passed as
    [Worker Resources](https://distributed.dask.org/en/latest/resources.html)
    of the form `{"KEY": float(NUM)}` to the Dask Scheduler.

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a temporary cluster will be
            created in `executor.start()`. Defaults to `None`.
        - cluster_class (string or callable, optional): the cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        - cluster_kwargs (dict, optional): additional kwargs to pass to the
            `cluster_class` when creating a temporary dask cluster.
        - adapt_kwargs (dict, optional): additional kwargs to pass to
            `cluster.adapt` when creating a temporary dask cluster. Note that
            adaptive scaling is only enabled if `adapt_kwargs` are provided.
        - client_kwargs (dict, optional): additional kwargs to use when creating
            a [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
        - debug (bool, optional): When running with a local cluster, setting
            `debug=True` will increase dask's logging level, providing
            potentially useful debug info. Defaults to the `debug` value in
            your Prefect configuration.
        - **kwargs: DEPRECATED

    Using a temporary local dask cluster:

    ```python
    executor = DaskExecutor()
    ```

    Using a temporary cluster running elsewhere. Any Dask cluster class should
    work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):

    ```python
    executor = DaskExecutor(
        cluster_class="dask_cloudprovider.FargateCluster",
        cluster_kwargs={
            "image": "prefecthq/prefect:latest",
            "n_workers": 5,
            ...
        },
    )
    ```

    Connecting to an existing dask cluster:

    ```python
    executor = DaskExecutor(address="192.0.2.255:8786")
    ```
    """

    def __init__(
        self,
        address: str = None,
        cluster_class: Union[str, Callable] = None,
        cluster_kwargs: dict = None,
        adapt_kwargs: dict = None,
        client_kwargs: dict = None,
        debug: bool = None,
        **kwargs: Any,
    ):
        if address is None:
            address = context.config.engine.executor.dask.address or None
        # XXX: deprecated
        if address == "local":
            warnings.warn(
                "`address='local'` is deprecated. To use a local cluster, leave the "
                "`address` field empty."
            )
            address = None

        # XXX: deprecated
        local_processes = kwargs.pop("local_processes", None)
        if local_processes is None:
            local_processes = context.config.engine.executor.dask.get(
                "local_processes", None
            )
        if local_processes is not None:
            warnings.warn(
                "`local_processes` is deprecated, please use "
                "`cluster_kwargs={'processes': local_processes}`. The default is "
                "now `local_processes=True`."
            )

        if address is not None:
            if cluster_class is not None or cluster_kwargs is not None:
                raise ValueError(
                    "Cannot specify `address` and `cluster_class`/`cluster_kwargs`"
                )
        else:
            if cluster_class is None:
                cluster_class = context.config.engine.executor.dask.cluster_class
            if isinstance(cluster_class, str):
                cluster_class = import_object(cluster_class)
            if cluster_kwargs is None:
                cluster_kwargs = {}
            else:
                cluster_kwargs = cluster_kwargs.copy()

            from distributed.deploy.local import LocalCluster

            if cluster_class == LocalCluster:
                if debug is None:
                    debug = context.config.debug
                cluster_kwargs.setdefault(
                    "silence_logs", logging.CRITICAL if not debug else logging.WARNING
                )
                if local_processes is not None:
                    cluster_kwargs.setdefault("processes", local_processes)
                for_cluster = set(kwargs).difference(_valid_client_kwargs)
                if for_cluster:
                    warnings.warn(
                        "Forwarding executor kwargs to `LocalCluster` is now handled by the "
                        "`cluster_kwargs` parameter, please update accordingly"
                    )
                    for k in for_cluster:
                        cluster_kwargs[k] = kwargs.pop(k)

        if adapt_kwargs is None:
            adapt_kwargs = {}
        if client_kwargs is None:
            client_kwargs = {}
        else:
            client_kwargs = client_kwargs.copy()
        if kwargs:
            warnings.warn(
                "Forwarding executor kwargs to `Client` is now handled by the "
                "`client_kwargs` parameter, please update accordingly"
            )
            client_kwargs.update(kwargs)
        client_kwargs.setdefault("set_as_default", False)

        self.address = address
        self.cluster_class = cluster_class
        self.cluster_kwargs = cluster_kwargs
        self.adapt_kwargs = adapt_kwargs
        self.client_kwargs = client_kwargs
        # Runtime attributes
        self.client = None
        # These are coupled - they're either both None, or both non-None.
        # They're used in the case we can't forcibly kill all the dask workers,
        # and need to wait for all the dask tasks to cleanup before exiting.
        self._futures = None  # type: Optional[weakref.WeakSet[Future]]
        self._should_run_var = None  # type: Optional[Variable]

        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Creates a `dask.distributed.Client` and yields it.
        """
        from distributed import Client

        try:
            if self.address is not None:
                with Client(self.address, **self.client_kwargs) as client:
                    self.client = client
                    try:
                        self._pre_start_yield()
                        yield
                    finally:
                        self._post_start_yield()
            else:
                with self.cluster_class(**self.cluster_kwargs) as cluster:  # type: ignore
                    if self.adapt_kwargs:
                        cluster.adapt(**self.adapt_kwargs)
                    with Client(cluster, **self.client_kwargs) as client:
                        self.client = client
                        try:
                            self._pre_start_yield()
                            yield
                        finally:
                            self._post_start_yield()
        finally:
            self.client = None

    def _pre_start_yield(self) -> None:
        from distributed import Variable

        is_inproc = self.client.scheduler.address.startswith("inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_var = Variable(
                f"prefect-{uuid.uuid4().hex}", client=self.client
            )
            self._should_run_var.set(True)

    def _post_start_yield(self) -> None:
        from distributed import wait

        if self._should_run_var is not None:
            # Multipart cleanup, ignoring exceptions in each stage
            # 1.) Stop pending tasks from starting
            try:
                self._should_run_var.set(False)
            except Exception:
                pass
            # 2.) Wait for all running tasks to complete
            try:
                futures = [f for f in list(self._futures) if not f.done()]  # type: ignore
                if futures:
                    self.logger.info(
                        "Stopping executor, waiting for %d active tasks to complete",
                        len(futures),
                    )
                    wait(futures)
            except Exception:
                pass
            # 3.) Delete the distributed variable
            try:
                self._should_run_var.delete()
            except Exception:
                pass
        self._should_run_var = None
        self._futures = None

    def _prep_dask_kwargs(self, extra_context: dict = None) -> dict:
        if extra_context is None:
            extra_context = {}

        dask_kwargs = {"pure": False}  # type: dict

        # set a key for the dask scheduler UI
        key = _make_task_key(**extra_context)
        if key is not None:
            dask_kwargs["key"] = key

        # infer from context if dask resources are being utilized
        task_tags = extra_context.get("task_tags", [])
        dask_resource_tags = [
            tag for tag in task_tags if tag.lower().startswith("dask-resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state.update({k: None for k in ["client", "_futures", "_should_run_var"]})
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> "Future":
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra
                information about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of
                `fn(*args, **kwargs)`
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        kwargs.update(self._prep_dask_kwargs(extra_context))
        if self._should_run_var is None:
            fut = self.client.submit(fn, *args, **kwargs)
        else:
            fut = self.client.submit(
                _maybe_run, self._should_run_var.name, fn, *args, **kwargs
            )
            self._futures.add(fut)
        return fut

    def wait(self, futures: Any) -> Any:
        """
        Resolves the Future objects to their values. Blocks until the
        computation is complete.

        Args:
            - futures (Any): single or iterable of future-like objects to compute

        Returns:
            - Any: an iterable of resolved futures with similar shape to the input
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        return self.client.gather(futures)
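# A small, self-contained demonstration of the `"dask-resource:KEY=NUM"` tag
# parsing performed in `_prep_dask_kwargs` above; the "GPU" tag is
# hypothetical. Each matching tag becomes a {"KEY": float(NUM)} entry in the
# `resources=` kwarg forwarded to `Client.submit`.
def _demo_parse_resource_tags() -> None:
    task_tags = ["dask-resource:GPU=1", "some-unrelated-tag"]
    resources = {}
    for tag in task_tags:
        if tag.lower().startswith("dask-resource"):
            prefix, val = tag.split("=")
            resources[prefix.split(":")[1]] = float(val)
    assert resources == {"GPU": 1.0}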