def fit(self, X, y=None): """Fit a gradient boosting classifier Parameters ---------- X : array-like [n_samples, n_features] Feature Matrix. May be a dask.array or dask.dataframe y : array-like Labels Returns ------- self : XGBClassifier Notes ----- This differs from the XGBoost version in three ways 1. The ``sample_weight``, ``eval_set``, ``eval_metric``, ``early_stopping_rounds`` and ``verbose`` fit kwargs are not supported. 2. The labels are not automatically label-encoded 3. The ``classes_`` and ``n_classes_`` attributes are not learned """ client = default_client() xgb_options = self.get_xgb_params() self._Booster = train(client, xgb_options, X, y, num_boost_round=self.n_estimators) return self
def __init__(self, child, n, loop=None):
    client = default_client()
    self.queue = dask.distributed.Queue(maxsize=n, client=client)
    core.Stream.__init__(self, child, loop=loop or client.loop)
    self.loop.add_callback(self.cb)
def fit(self, X, y=None): """Fit the gradient boosting model Parameters ---------- X : array-like [n_samples, n_features] y : array-like Returns ------- self : the fitted Regressor Notes ----- This differs from the XGBoost version not supporting the ``eval_set``, ``eval_metric``, ``early_stopping_rounds`` and ``verbose`` fit kwargs. """ client = default_client() xgb_options = self.get_xgb_params() self._Booster = train(client, xgb_options, X, y, num_boost_round=self.n_estimators) return self
def test_from_partitions(axis, index, columns, row_lengths, column_widths):
    num_rows = 2**16
    num_cols = 2**8
    data = np.random.randint(0, 100, size=(num_rows, num_cols))
    df1, df2 = pandas.DataFrame(data), pandas.DataFrame(data)
    expected_df = pandas.concat([df1, df2], axis=1 if axis is None else axis)

    index = expected_df.index if index == "index" else None
    columns = expected_df.columns if columns == "columns" else None
    row_lengths = (
        None
        if row_lengths is None
        else [num_rows, num_rows]
        if axis == 0
        else [num_rows]
    )
    column_widths = (
        None
        if column_widths is None
        else [num_cols]
        if axis == 0
        else [num_cols, num_cols]
    )

    if Engine.get() == "Ray":
        if axis is None:
            futures = [[ray.put(df1), ray.put(df2)]]
        else:
            futures = [ray.put(df1), ray.put(df2)]
    if Engine.get() == "Dask":
        client = default_client()
        if axis is None:
            futures = [client.scatter([df1, df2], hash=False)]
        else:
            futures = client.scatter([df1, df2], hash=False)

    actual_df = from_partitions(
        futures,
        axis,
        index=index,
        columns=columns,
        row_lengths=row_lengths,
        column_widths=column_widths,
    )
    df_equals(expected_df, actual_df)
def deploy(cls, func, *args, num_returns=1, pure=None, **kwargs):
    """
    Deploy a function in a worker process.

    Parameters
    ----------
    func : callable
        Function to be deployed in a worker process.
    *args : list
        Additional positional arguments to be passed in `func`.
    num_returns : int, default: 1
        The number of returned objects.
    pure : bool, optional
        Whether or not `func` is pure. See `Client.submit` for details.
    **kwargs : dict
        Additional keyword arguments to be passed in ``func``.

    Returns
    -------
    list
        The result of ``func`` split into parts in accordance with
        ``num_returns``.
    """
    client = default_client()
    remote_task_future = client.submit(func, *args, pure=pure, **kwargs)
    if num_returns != 1:
        return [
            client.submit(lambda l, i: l[i], remote_task_future, i)
            for i in range(num_returns)
        ]
    return remote_task_future
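
# --- Example (hedged): splitting one future into per-element futures -----------
# A small sketch of the pattern the ``deploy`` classmethod above relies on,
# written against the plain ``distributed`` API so it runs on its own.  The
# helper name below is illustrative, not taken from the snippet.
from distributed import Client

def make_pair():
    # A task whose result we want to hand out as two separate futures.
    return ("left", "right")

if __name__ == "__main__":
    client = Client(processes=False)
    whole = client.submit(make_pair, pure=False)
    # One extra tiny task per element gives us independent futures without
    # pulling the full result back to the driver.
    parts = [client.submit(lambda l, i: l[i], whole, i) for i in range(2)]
    print(client.gather(parts))  # ['left', 'right']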
def wait(collections, client=None, return_when=FIRST_COMPLETED):
    """Compute collections on the client.

    Parameters
    ----------
    collections : list of dask tasks
    client : None or distributed.client.Client
        If None, uses ``distributed.client.default_client()``.
    return_when : int
        If ``FIRST_COMPLETED``, returns when the first task completes.
        If ``ALL_COMPLETED``, returns when all tasks have completed.
        Currently only ``FIRST_COMPLETED`` is supported.

    Returns
    -------
    tuple : (result, index, unfinished_futures)
    """
    if return_when not in (FIRST_COMPLETED, ALL_COMPLETED):
        raise ValueError(
            "Unknown value for 'return_when'. "
            + "Expected {} or {}. ".format(FIRST_COMPLETED, ALL_COMPLETED)
            + "Received {}.".format(return_when))
    if return_when == ALL_COMPLETED:
        raise NotImplementedError("Support for ALL_COMPLETED not implemented.")

    client = client or dc.default_client()
    futures = client.compute(collections)
    f = dc.as_completed(futures).__next__()
    i = futures.index(f)
    del futures[i]
    res = f.result()
    del f
    return res, i, futures
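
# --- Example (hedged): waiting for the first finished collection ---------------
# A sketch of how the ``wait`` helper above might be used with dask.delayed
# tasks.  It assumes ``wait`` is importable from the surrounding module and a
# local Client is acceptable; both are assumptions.
import time
import dask
from distributed import Client

@dask.delayed
def slow(seconds):
    time.sleep(seconds)
    return seconds

if __name__ == "__main__":
    client = Client(processes=False)
    tasks = [slow(3), slow(0.1), slow(2)]
    result, index, pending = wait(tasks, client=client)
    # ``result`` is most likely 0.1 (the fastest task), ``index`` its position,
    # and ``pending`` the two futures that are still running.
    print(result, index, len(pending))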
def mask(self, row_labels, col_labels):
    """
    Lazily create a mask that extracts the indices provided.

    Parameters
    ----------
    row_labels : list-like, slice or label
        The row labels for the rows to extract.
    col_labels : list-like, slice or label
        The column labels for the columns to extract.

    Returns
    -------
    PandasOnDaskDataframePartition
        A new ``PandasOnDaskDataframePartition`` object.
    """
    new_obj = super().mask(row_labels, col_labels)
    client = default_client()
    if isinstance(row_labels, slice) and isinstance(self._length_cache, Future):
        new_obj._length_cache = client.submit(
            compute_sliced_len, row_labels, self._length_cache
        )
    if isinstance(col_labels, slice) and isinstance(self._width_cache, Future):
        new_obj._width_cache = client.submit(
            compute_sliced_len, col_labels, self._width_cache
        )
    return new_obj
def initialize_dask():
    """Initialize Dask environment."""
    from distributed.client import default_client

    try:
        client = default_client()
    except ValueError:
        from distributed import Client

        # The indentation here is intentional, we want the code to be indented.
        ErrorMessage.not_initialized(
            "Dask",
            """
    from distributed import Client

    client = Client()
""",
        )
        num_cpus = CpuCount.get()
        memory_limit = Memory.get()
        worker_memory_limit = memory_limit // num_cpus if memory_limit else "auto"
        client = Client(n_workers=num_cpus, memory_limit=worker_memory_limit)

    num_cpus = len(client.ncores())
    NPartitions._put(num_cpus)
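
# --- Example (hedged): supplying your own Client up front ----------------------
# Because ``initialize_dask`` first tries ``default_client()``, a Client created
# by the user before the engine initializes is reused instead of a new default
# one.  The Modin import below only illustrates how that flow is typically
# triggered and is an assumption, not part of the snippet.
from distributed import Client

if __name__ == "__main__":
    client = Client(n_workers=4, memory_limit="2GB")  # becomes the default client

    import modin.pandas as pd  # assumed consumer; initialization reuses `client`
    df = pd.DataFrame({"a": range(10)})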
def __init__(self, child, limit=10, client=None):
    self.client = client or default_client()
    self.queue = Queue(maxsize=limit)
    self.condition = Condition()
    Stream.__init__(self, child)
    self.client.loop.add_callback(self.cb)
def predict(self, X):
    client = default_client()
    class_probs = predict(client, self._Booster, X)
    if class_probs.ndim > 1:
        cidx = da.argmax(class_probs, axis=1)
    else:
        cidx = (class_probs > 0).astype(np.int64)
    return cidx
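
# --- Example (hedged): the two branches of the predict above -------------------
# With a 2-D (multiclass) output the method argmaxes over class columns; with a
# 1-D output it thresholds at zero.  A tiny dask.array illustration of both
# branches, independent of any trained booster:
import numpy as np
import dask.array as da

if __name__ == "__main__":
    multi = da.from_array(np.array([[0.1, 0.7, 0.2],
                                    [0.6, 0.3, 0.1]]), chunks=2)
    print(da.argmax(multi, axis=1).compute())        # [1 0]

    scores = da.from_array(np.array([-1.2, 0.4, 2.5]), chunks=3)
    print((scores > 0).astype(np.int64).compute())   # [0 1 1]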
def from_kafka_batched(topic, consumer_params, poll_interval="1s", npartitions=1, start=False, dask=False, **kwargs): """ Get messages from Kafka in batches Uses the confluent-kafka library, https://docs.confluent.io/current/clients/confluent-kafka-python/ This source will emit lists of messages for each partition of a single given topic per time interval, if there is new data. If using dask, one future will be produced per partition per time-step, if there is data. Parameters ---------- topic: str Kafka topic to consume from consumer_params: dict Settings to set up the stream, see https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md Examples: bootstrap.servers: Connection string(s) (host:port) by which to reach Kafka group.id: Identity of the consumer. If multiple sources share the same group, each message will be passed to only one of them. poll_interval: number Seconds that elapse between polling Kafka for new messages npartitions: int Number of partitions in the topic start: bool (False) Whether to start polling upon instantiation Example ------- >>> source = Stream.from_kafka_batched('mytopic', ... {'bootstrap.servers': 'localhost:9092', ... 'group.id': 'rapidz'}, npartitions=4) # doctest: +SKIP """ if dask: from distributed.client import default_client kwargs["loop"] = default_client().loop source = FromKafkaBatched(topic, consumer_params, poll_interval=poll_interval, npartitions=npartitions, **kwargs) if dask: source = source.scatter() if start: source.start() return source.starmap(get_message_batch)
def update(self, x, who=None, metadata=None):
    client = default_client()
    self._retain_refs(metadata)
    future = yield client.scatter(x, asynchronous=True)
    f = yield self._emit(future, metadata=metadata)
    self._release_refs(metadata)
    raise gen.Return(f)
def update(self, x, who=None, metadata=None):
    client = default_client()
    self._retain_refs(metadata)
    result = yield client.gather(x, asynchronous=True)
    result2 = yield self._emit(result, metadata=metadata)
    self._release_refs(metadata)
    raise gen.Return(result2)
def update(self, x, emit_id=None, who=None):
    client = default_client()
    result = client.submit(
        self.func, x, *self.args, **self.kwargs, pure=False,
        key=f'{self.func.__name__}--{emit_id}--{str(uuid.uuid4())}')
    return self._emit(result, emit_id=emit_id)
def update(self, x, emit_id=None, who=None):
    try:
        client = default_client()
        future_as_list = yield client.scatter([x], asynchronous=True, hash=False)
        future = future_as_list[0]
        f = yield self._emit(future, emit_id=emit_id)
        raise gen.Return(f)
    except Exception as e:
        self.error = True
        raise
def update(self, x, sleep_between_gather=0.1, emit_id=None, who=None):
    try:
        client = default_client()
        result = yield client.gather(x, asynchronous=True)
        result2 = yield self._emit(result, emit_id=emit_id)
        raise gen.Return(result2)
        # result1 = yield get_stream(self).wait(sleep_between_gather=sleep_between_gather)
        # raise gen.Return(result1)
    except Exception as e:
        self.error = True
        raise
def update(self, x, who=None):
    if self.state is core.no_default:
        self.state = x
        return self._emit(self.state)
    else:
        client = default_client()
        result = client.submit(self.func, self.state, x, **self.kwargs)
        if self.returns_state:
            state = client.submit(getitem, result, 0)
            result = client.submit(getitem, result, 1)
        else:
            state = result
        self.state = state
        return self._emit(result)
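
# --- Example (hedged): the (state, emitted) convention used above --------------
# When ``returns_state`` is True, the accumulator is expected to return a 2-tuple
# of (new_state, value_to_emit); the node above then splits it with two tiny
# ``getitem`` tasks so both parts stay on the cluster.  The running-mean
# accumulator below is an illustrative stand-in, not code from the snippet.
def running_mean(state, new_value):
    total, count = state
    total, count = total + new_value, count + 1
    new_state = (total, count)
    emitted = total / count
    return new_state, emitted  # consumed as state = result[0], emit = result[1]

if __name__ == "__main__":
    state = (0.0, 0)
    for x in [1.0, 2.0, 3.0]:
        state, mean = running_mean(state, x)
    print(mean)  # 2.0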
def from_kafka_batched_scatter(
    topics: List[str],
    consumer_params: Dict,
    poll_interval: int = 5,
    batch_size: int = 1000,
    dask: bool = False,
    db=None,
    **kwargs,
):
    """
    Parameters
    ----------
    topics: List[str]
        Labels of Kafka topics to consume from.
    consumer_params: Dict
        Settings to set up the stream, see
        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
        Examples:
        bootstrap.servers, Connection string(s) (host:port) by which to reach Kafka;
        group.id, Identity of the consumer. If multiple sources share the same
        group, each message will be passed to only one of them.
    batch_size: int, optional (default=1000)
        Batch size of polling.
    poll_interval: float, optional (default=5.0)
        Seconds that elapse between polling Kafka for new messages.
    dask: bool, optional (default=False)
        If True, will poll events from each partition distributed among Dask workers.
    db: Callable, optional (default=None)
        Persistent layer to check Kafka offsets, to provide once-semantics.
        If None, will initiate waterhealer.db.expiringdict.Database.
    """
    if dask:
        from distributed.client import default_client

        kwargs['loop'] = default_client().loop
    source = FromKafkaBatched(
        topics=topics,
        consumer_params=consumer_params,
        poll_interval=poll_interval,
        batch_size=batch_size,
        db=db,
        **kwargs,
    )
    if dask:
        source = source.scatter()
    return source.starmap(get_message_batch)
def deploy_func_between_two_axis_partitions(
    cls, axis, func, num_splits, len_of_left, other_shape, kwargs, *partitions
):
    """
    Deploy a function along a full axis between two data sets.

    Parameters
    ----------
    axis : {0, 1}
        The axis to perform the function along.
    func : callable
        The function to perform.
    num_splits : int
        The number of splits to return (see `split_result_of_axis_func_pandas`).
    len_of_left : int
        The number of values in `partitions` that belong to the left data set.
    other_shape : np.ndarray
        The shape of right frame in terms of partitions, i.e.
        (other_shape[i-1], other_shape[i]) will indicate slice to restore
        i-1 axis partition.
    kwargs : dict
        Additional keywords arguments to be passed in `func`.
    *partitions : iterable
        All partitions that make up the full axis (row or column) for both
        data sets.

    Returns
    -------
    list
        A list of distributed.Future.
    """
    client = default_client()
    axis_result = client.submit(
        deploy_dask_func,
        PandasDataframeAxisPartition.deploy_func_between_two_axis_partitions,
        axis,
        func,
        num_splits,
        len_of_left,
        other_shape,
        kwargs,
        *partitions,
        pure=False,
    )
    # We have to do this to split it back up. It is already split, but we need to
    # get futures for each.
    return [
        client.submit(lambda l: l[i], axis_result, pure=False)
        for i in range(num_splits * 4)
    ]
def preprocess_func(cls, func):
    """
    Preprocess a function before an ``apply`` call.

    Parameters
    ----------
    func : callable
        The function to preprocess.

    Returns
    -------
    callable
        An object that can be accepted by ``apply``.
    """
    return default_client().scatter(func, hash=False, broadcast=True)
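
# --- Example (hedged): pre-scattering a callable before many submits -----------
# ``preprocess_func`` ships the function to the workers once (``broadcast=True``
# replicates it), so later calls can pass a small future instead of
# re-serializing the callable each time.  A standalone sketch of that idea, with
# illustrative names:
from distributed import Client

def expensive_to_ship(x):
    return x * 2

if __name__ == "__main__":
    client = Client(processes=False)
    func_future = client.scatter(expensive_to_ship, hash=False, broadcast=True)
    # Workers dereference the future back into the callable before applying it.
    futures = [client.submit(lambda f, v: f(v), func_future, i) for i in range(4)]
    print(client.gather(futures))  # [0, 2, 4, 6]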
def update(self, x, who=None, metadata=None):
    client = default_client()
    self._retain_refs(metadata)
    # We need to make sure that x is treated as it is by dask.
    # However, client.scatter works internally differently for
    # lists and dicts, so we always use a list here to be sure
    # we know the format exactly.  We do not use a key to avoid
    # issues like https://github.com/python-streamz/streams/issues/397.
    future_as_list = yield client.scatter([x], asynchronous=True, hash=False)
    future = future_as_list[0]
    f = yield self._emit(future, metadata=metadata)
    self._release_refs(metadata)
    raise gen.Return(f)
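
# --- Example (hedged): why the value is wrapped in a single-element list -------
# ``Client.scatter`` unpacks lists and dicts instead of shipping them as one
# object, so wrapping ``x`` in ``[x]`` guarantees exactly one future comes back
# for exactly the object we passed.  A quick illustration, assuming a local
# client is acceptable:
from distributed import Client

if __name__ == "__main__":
    client = Client(processes=False)

    data = [1, 2, 3]
    many = client.scatter(data)        # three futures, one per element
    one = client.scatter([data])[0]    # a single future holding the whole list

    print(len(many), one.result())     # 3 [1, 2, 3]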
def get_client(return_exception=False):
    error = 'no error'
    try:
        from distributed.client import default_client

        client = default_client()
    except Exception as e:
        error = str(e)
        logger.error(e)
        client = None
    if return_exception:
        return client, error
    else:
        return client
def deploy_axis_func(
    cls, axis, func, num_splits, kwargs, maintain_partitioning, *partitions
):
    """
    Deploy a function along a full axis.

    Parameters
    ----------
    axis : {0, 1}
        The axis to perform the function along.
    func : callable
        The function to perform.
    num_splits : int
        The number of splits to return (see `split_result_of_axis_func_pandas`).
    kwargs : dict
        Additional keywords arguments to be passed in `func`.
    maintain_partitioning : bool
        If True, keep the old partitioning if possible.
        If False, create a new partition layout.
    *partitions : iterable
        All partitions that make up the full axis (row or column).

    Returns
    -------
    list
        A list of distributed.Future.
    """
    client = default_client()
    axis_result = client.submit(
        deploy_dask_func,
        PandasDataframeAxisPartition.deploy_axis_func,
        axis,
        func,
        num_splits,
        kwargs,
        maintain_partitioning,
        *partitions,
        pure=False,
    )
    lengths = kwargs.get("_lengths", None)
    result_num_splits = len(lengths) if lengths else num_splits
    # We have to do this to split it back up. It is already split, but we need to
    # get futures for each.
    return [
        client.submit(lambda l: l[i], axis_result, pure=False)
        for i in range(result_num_splits * 4)
    ]
def _column_widths(self):
    """
    Compute the column partitions widths if they are not cached.

    Returns
    -------
    list
        A list of column partitions widths.
    """
    client = default_client()
    if self._column_widths_cache is None:
        self._column_widths_cache = client.gather([
            obj.apply(lambda df: len(df.columns)).future
            for obj in self._partitions[0]
        ])
    return self._column_widths_cache
def materialize(cls, future):
    """
    Materialize data matching `future` object.

    Parameters
    ----------
    future : distributed.Future or list
        Future object or list of future objects whose data needs to be
        materialized.

    Returns
    -------
    Any
        An object(s) from the distributed memory.
    """
    client = default_client()
    return client.gather(future)
def _get_files_from_location(input_item, file_format: str = None, **kwargs):
    if file_format == "memory":
        client = default_client()
        return client.get_dataset(input_item, **kwargs)

    if not file_format:
        _, extension = os.path.splitext(input_item)
        file_format = extension.lstrip(".")

    try:
        read_function = getattr(dd, f"read_{file_format}")
    except AttributeError:
        raise AttributeError(f"Can not read files of format {file_format}")

    return read_function(input_item, **kwargs)
def _row_lengths(self):
    """
    Compute the row partitions lengths if they are not cached.

    Returns
    -------
    list
        A list of row partitions lengths.
    """
    client = default_client()
    if self._row_lengths_cache is None:
        self._row_lengths_cache = client.gather([
            obj.apply(lambda df: len(df)).future
            for obj in self._partitions.T[0]
        ])
    return self._row_lengths_cache
def update(self, x, who=None, metadata=None):
    client = default_client()
    self._retain_refs(metadata)
    # We need to make sure that x is treated as it is by dask.
    # However, client.scatter works internally differently for
    # lists and dicts, so we always use a dict here to be sure
    # we know the format exactly.  The key will be taken as the
    # dask identifier of the data.
    tokenized_x = f"{type(x).__name__}-{tokenize(x)}"
    future_as_dict = yield client.scatter({tokenized_x: x}, asynchronous=True)
    future = future_as_dict[tokenized_x]
    f = yield self._emit(future, metadata=metadata)
    self._release_refs(metadata)
    raise gen.Return(f)
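
# --- Example (hedged): deterministic keys via dask.base.tokenize ---------------
# The scatter above names the future after a content hash, so re-emitting an
# equal value maps to the same dask key instead of a random one.  A small
# illustration of that property:
import pandas as pd
from dask.base import tokenize

if __name__ == "__main__":
    a = pd.DataFrame({"x": [1, 2, 3]})
    b = pd.DataFrame({"x": [1, 2, 3]})
    key_a = f"{type(a).__name__}-{tokenize(a)}"
    key_b = f"{type(b).__name__}-{tokenize(b)}"
    print(key_a == key_b)  # True: equal content gives the same key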
def put(cls, obj):
    """
    Put an object into distributed memory and wrap it with partition object.

    Parameters
    ----------
    obj : any
        An object to be put.

    Returns
    -------
    PandasOnDaskDataframePartition
        A new ``PandasOnDaskDataframePartition`` object.
    """
    client = default_client()
    return cls(client.scatter(obj, hash=False))
def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
    if format == "memory":
        client = default_client()
        return client.get_dataset(input_item, **kwargs)

    if not format:
        _, extension = os.path.splitext(input_item)
        format = extension.lstrip(".")

    try:
        read_function = getattr(dd, f"read_{format}")
    except AttributeError:
        raise AttributeError(f"Can not read files of format {format}")

    return read_function(input_item, **kwargs)
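
# --- Example (hedged): the "memory" format via published datasets --------------
# Both readers above fetch the input from the cluster's published datasets when
# the format is "memory".  The sketch below shows the matching producer side
# using distributed's ``publish_dataset`` / ``get_dataset`` pair; the dataset
# name is illustrative.
import dask.dataframe as dd
import pandas as pd
from distributed import Client

if __name__ == "__main__":
    client = Client(processes=False)

    ddf = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)
    client.publish_dataset(my_table=ddf)        # register under the name "my_table"

    same_ddf = client.get_dataset("my_table")   # what the "memory" branch returns
    print(same_ddf.x.sum().compute())           # 45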
def read_bytes(urlpath, delimiter=None, not_zero=False, blocksize=2**27,
               sample=True, compression=None, **kwargs):
    """Convert path to a list of delayed values

    The path may be a filename like ``'2015-01-01.csv'`` or a globstring
    like ``'2015-*-*.csv'``.

    The path may be preceded by a protocol, like ``s3://`` or ``hdfs://`` if
    those libraries are installed.

    This cleanly breaks data by a delimiter if given, so that block boundaries
    start directly after a delimiter and end on the delimiter.

    Parameters
    ----------
    urlpath: string
        Absolute or relative filepath, URL (may include protocols like
        ``s3://``), or globstring pointing to data.
    delimiter: bytes
        An optional delimiter, like ``b'\\n'`` on which to split blocks of bytes.
    not_zero: bool
        Force seek of start-of-file delimiter, discarding header.
    blocksize: int (=128MB)
        Chunk size
    compression: string or None
        String like 'gzip' or 'xz'.  Must support efficient random access.
    sample: bool or int
        Whether or not to return a header sample. If an integer is given it is
        used as sample size, otherwise the default sample size is 10kB.
    **kwargs: dict
        Extra options that make sense to a particular storage connection, e.g.
        host, port, username, password, etc.

    Examples
    --------
    >>> sample, blocks = read_bytes('2015-*-*.csv', delimiter=b'\\n')  # doctest: +SKIP
    >>> sample, blocks = read_bytes('s3://bucket/2015-*-*.csv', delimiter=b'\\n')  # doctest: +SKIP

    Returns
    -------
    A sample header and list of ``dask.Delayed`` objects or list of lists of
    delayed objects if ``fn`` is a globstring.
    """
    fs, paths, myopen = get_fs_paths_myopen(urlpath, compression, 'rb', None, **kwargs)
    client = None
    if len(paths) == 0:
        raise IOError("%s resolved to no files" % urlpath)

    blocks, lengths, machines = fs.get_block_locations(paths)
    if blocks:
        offsets = blocks
    elif blocksize is None:
        offsets = [[0]] * len(paths)
        lengths = [[None]] * len(offsets)
        machines = [[None]] * len(offsets)
    else:
        offsets = []
        lengths = []
        for path in paths:
            try:
                size = fs.logical_size(path, compression)
            except KeyError:
                raise ValueError('Cannot read compressed files (%s) in byte chunks, '
                                 'use blocksize=None' % infer_compression(urlpath))
            off = list(range(0, size, blocksize))
            length = [blocksize] * len(off)
            if not_zero:
                off[0] = 1
                length[0] -= 1
            offsets.append(off)
            lengths.append(length)
        machines = [[None]] * len(offsets)

    out = []
    for path, offset, length, machine in zip(paths, offsets, lengths, machines):
        ukey = fs.ukey(path)
        keys = ['read-block-%s-%s' % (o, tokenize(path, compression, offset,
                                                  ukey, kwargs, delimiter))
                for o in offset]
        L = [delayed(read_block_from_file)(myopen(path, mode='rb'), o, l,
                                           delimiter, dask_key_name=key)
             for (o, key, l) in zip(offset, keys, length)]
        out.append(L)

        if machine is not None:  # blocks are in preferred locations
            if client is None:
                try:
                    from distributed.client import default_client
                    client = default_client()
                except (ImportError, ValueError):  # no distributed client
                    client = False
            if client:
                restrictions = {key: w for key, w in zip(keys, machine)}
                client._send_to_scheduler({'op': 'update-graph',
                                           'tasks': {},
                                           'dependencies': [],
                                           'keys': [],
                                           'restrictions': restrictions,
                                           'loose_restrictions': list(restrictions),
                                           'client': client.id})

    if sample is not True:
        nbytes = sample
    else:
        nbytes = 10000
    if sample:
        # myopen = OpenFileCreator(urlpath, compression)
        with myopen(paths[0], 'rb') as f:
            sample = read_block(f, 0, nbytes, delimiter)
    return sample, out
def persist(*args, **kwargs): """ Persist multiple Dask collections into memory This turns lazy Dask collections into Dask collections with the same metadata, but now with their results fully computed or actively computing in the background. For example a lazy dask.array built up from many lazy calls will now be a dask.array of the same shape, dtype, chunks, etc., but now with all of those previously lazy tasks either computed in memory as many small :class:`numpy.array` (in the single-machine case) or asynchronously running in the background on a cluster (in the distributed case). This function operates differently if a ``dask.distributed.Client`` exists and is connected to a distributed scheduler. In this case this function will return as soon as the task graph has been submitted to the cluster, but before the computations have completed. Computations will continue asynchronously in the background. When using this function with the single machine scheduler it blocks until the computations have finished. When using Dask on a single machine you should ensure that the dataset fits entirely within memory. Examples -------- >>> df = dd.read_csv('/path/to/*.csv') # doctest: +SKIP >>> df = df[df.name == 'Alice'] # doctest: +SKIP >>> df['in-debt'] = df.balance < 0 # doctest: +SKIP >>> df = df.persist() # triggers computation # doctest: +SKIP >>> df.value().min() # future computations are now fast # doctest: +SKIP -10 >>> df.value().max() # doctest: +SKIP 100 >>> from dask import persist # use persist function on multiple collections >>> a, b = persist(a, b) # doctest: +SKIP Parameters ---------- *args: Dask collections get : callable, optional A scheduler ``get`` function to use. If not provided, the default is to check the global settings first, and then fall back to the collection defaults. optimize_graph : bool, optional If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging. **kwargs Extra keywords to forward to the scheduler ``get`` function. Returns ------- New dask collections backed by in-memory data """ collections = [a for a in args if is_dask_collection(a)] if not collections: return args get = kwargs.pop('get', None) or _globals['get'] if get is None and getattr(thread_state, 'key', False): from distributed.worker import get_worker get = get_worker().client.get if inspect.ismethod(get): try: from distributed.client import default_client except ImportError: pass else: try: client = default_client() except ValueError: pass else: if client.get == _globals['get']: collections = client.persist(collections, **kwargs) if isinstance(collections, list): # distributed is inconsistent here collections = tuple(collections) else: collections = (collections,) results_iter = iter(collections) return tuple(a if not is_dask_collection(a) else next(results_iter) for a in args) optimize_graph = kwargs.pop('optimize_graph', True) if not get: get = collections[0].__dask_scheduler__ if not all(a.__dask_scheduler__ == get for a in collections): raise ValueError("Compute called on multiple collections with " "differing default schedulers. 
Please specify a " "scheduler `get` function using either " "the `get` kwarg or globally with `set_options`.") dsk = collections_to_dsk(collections, optimize_graph, **kwargs) keys, postpersists = [], [] for a in args: if is_dask_collection(a): a_keys = list(flatten(a.__dask_keys__())) rebuild, state = a.__dask_postpersist__() keys.extend(a_keys) postpersists.append((rebuild, a_keys, state)) else: postpersists.append((None, None, a)) results = get(dsk, keys, **kwargs) d = dict(zip(keys, results)) return tuple(s if r is None else r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists)
def persist(*args, **kwargs): """ Persist multiple Dask collections into memory This turns lazy Dask collections into Dask collections with the same metadata, but now with their results fully computed or actively computing in the background. For example a lazy dask.array built up from many lazy calls will now be a dask.array of the same shape, dtype, chunks, etc., but now with all of those previously lazy tasks either computed in memory as many small :class:`numpy.array` (in the single-machine case) or asynchronously running in the background on a cluster (in the distributed case). This function operates differently if a ``dask.distributed.Client`` exists and is connected to a distributed scheduler. In this case this function will return as soon as the task graph has been submitted to the cluster, but before the computations have completed. Computations will continue asynchronously in the background. When using this function with the single machine scheduler it blocks until the computations have finished. When using Dask on a single machine you should ensure that the dataset fits entirely within memory. Examples -------- >>> df = dd.read_csv('/path/to/*.csv') # doctest: +SKIP >>> df = df[df.name == 'Alice'] # doctest: +SKIP >>> df['in-debt'] = df.balance < 0 # doctest: +SKIP >>> df = df.persist() # triggers computation # doctest: +SKIP >>> df.value().min() # future computations are now fast # doctest: +SKIP -10 >>> df.value().max() # doctest: +SKIP 100 >>> from dask import persist # use persist function on multiple collections >>> a, b = persist(a, b) # doctest: +SKIP Parameters ---------- *args: Dask collections scheduler : string, optional Which scheduler to use like "threads", "synchronous" or "processes". If not provided, the default is to check the global settings first, and then fall back to the collection defaults. traverse : bool, optional By default dask traverses builtin python collections looking for dask objects passed to ``persist``. For large collections this can be expensive. If none of the arguments contain any dask objects, set ``traverse=False`` to avoid doing this traversal. optimize_graph : bool, optional If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging. **kwargs Extra keywords to forward to the scheduler function. Returns ------- New dask collections backed by in-memory data """ traverse = kwargs.pop('traverse', True) optimize_graph = kwargs.pop('optimize_graph', True) collections, repack = unpack_collections(*args, traverse=traverse) if not collections: return args schedule = get_scheduler(get=kwargs.pop('get', None), scheduler=kwargs.pop('scheduler', None), collections=collections) if inspect.ismethod(schedule): try: from distributed.client import default_client except ImportError: pass else: try: client = default_client() except ValueError: pass else: if client.get == schedule: results = client.persist(collections, optimize_graph=optimize_graph, **kwargs) return repack(results) dsk = collections_to_dsk(collections, optimize_graph, **kwargs) keys, postpersists = [], [] for a in collections: a_keys = list(flatten(a.__dask_keys__())) rebuild, state = a.__dask_postpersist__() keys.extend(a_keys) postpersists.append((rebuild, a_keys, state)) results = schedule(dsk, keys, **kwargs) d = dict(zip(keys, results)) results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists] return repack(results2)