Example #1
def dask_to_tfrecords(df,
                      folder,
                      compression_type="GZIP",
                      compression_level=9):
    """Store Dask.dataframe to TFRecord files."""
    makedirs(folder, exist_ok=True)
    compression_ext = get_compression_ext(compression_type)
    filenames = [
        get_part_filename(i, compression_ext) for i in range(df.npartitions)
    ]

    # Also write a metadata file
    write_meta(df, folder, compression_type)

    dsk = {}
    name = "to-tfrecord-" + tokenize(df, folder)
    part_tasks = []
    kwargs = {}

    for d, filename in enumerate(filenames):
        dsk[(name, d)] = (apply, pandas_df_to_tfrecords, [
            (df._name, d),
            os.path.join(folder, filename), compression_type, compression_level
        ], kwargs)
        part_tasks.append((name, d))

    dsk[name] = (lambda x: None, part_tasks)

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[df])
    out = Delayed(name, graph)
    out = out.compute()
    return out
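A minimal sketch of the same collector pattern: one task per "partition" plus a root key that gathers them, mirroring dsk[(name, d)] and dsk[name] above. write_part and all key names here are hypothetical stand-ins, not from the original project.

from dask.delayed import Delayed

def write_part(i):
    # stand-in for writing one partition to disk
    return f"wrote part {i}"

dsk = {("write", i): (write_part, i) for i in range(3)}
# root task depends on every per-partition key and returns None, like dsk[name] above
dsk["write-all"] = (lambda *parts: None, *[("write", i) for i in range(3)])
Delayed("write-all", dsk).compute()  # runs all three write_part tasks, returns None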
Example #2
def test_persist_delayed_custom_key(key):
    d = Delayed(key, {key: "b", "b": 1})
    assert d.compute() == 1
    dp = d.persist()
    assert dp.compute() == 1
    assert dp.key == key
    assert dict(dp.dask) == {key: 1}
Example #3
def test_persist_delayed_rename(key, rename, new_key):
    d = Delayed(key, {key: 1})
    assert d.compute() == 1
    rebuild, args = d.__dask_postpersist__()
    dp = rebuild({new_key: 2}, *args, rename=rename)
    assert dp.compute() == 2
    assert dp.key == new_key
    assert dict(dp.dask) == {new_key: 2}
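For context, a plain persist round-trip shows what the rebuild/rename machinery in these tests supports. This sketch uses only public dask APIs and assumes the default local scheduler.

import dask
from dask.delayed import Delayed

d = dask.delayed(sum)([1, 2, 3])
(dp,) = dask.persist(d)
assert dp.compute() == 6
assert dict(dp.dask) == {dp.key: 6}  # the persisted graph holds the concrete value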
Example #4
def _chunked_array_copy(spec: CopySpec) -> Delayed:
    """Chunked copy between arrays."""
    if spec.intermediate.array is None:
        target_store_delayed = _direct_array_copy(
            spec.read.array,
            spec.write.array,
            spec.read.chunks,
        )

        # fuse
        target_dsk = dask.utils.ensure_dict(target_store_delayed.dask)
        dsk_fused, _ = fuse(target_dsk)

        return Delayed(target_store_delayed.key, dsk_fused)

    else:
        # do intermediate store
        int_store_delayed = _direct_array_copy(
            spec.read.array,
            spec.intermediate.array,
            spec.read.chunks,
        )
        target_store_delayed = _direct_array_copy(
            spec.intermediate.array,
            spec.write.array,
            spec.write.chunks,
        )

        # now do some hacking to chain these together into a single graph.
        # get the two graphs as dicts
        int_dsk = dask.utils.ensure_dict(int_store_delayed.dask)
        target_dsk = dask.utils.ensure_dict(target_store_delayed.dask)

        # find the root store key representing the read
        root_keys = []
        for key in target_dsk:
            if isinstance(key, str):
                if key.startswith("from-zarr"):
                    root_keys.append(key)
        assert len(root_keys) == 1
        root_key = root_keys[0]

        # now rewrite the graph
        target_dsk[root_key] = (
            lambda a, *b: a,
            target_dsk[root_key],
            *int_dsk[int_store_delayed.key],
        )
        target_dsk.update(int_dsk)

        # fuse
        dsk_fused, _ = fuse(target_dsk)
        return Delayed(target_store_delayed.key, dsk_fused)
Example #5
def comp(dag, blocker_list):
    from copy import deepcopy
    dag = deepcopy(dag)
    _b = convert_ldicts_to_sdict(blocker_list)

    last_node = get_lastnode(dict(_b))
    if last_node != dag.key:
        dag = Delayed(last_node, _b)
    else:
        dag.dask = _b
    x = dag.compute()
    return x
Example #6
    def dask_finalise(sink: Delayed,
                      *deps,
                      extract=False,
                      strict=False,
                      return_value=_UNSET) -> Delayed:
        """

        When extract=True --> returns bytes (doubles memory requirements!!!)
        When extract=False -> returns return_value if supplied, or sink after completing everything
        """
        tk = tokenize(sink, extract, strict)
        delayed_close = dask.delayed(lambda sink, idx, *deps: sink.close(idx))
        parts = [
            delayed_close(sink,
                          idx,
                          *deps,
                          dask_key_name=(f"cog_close-{tk}", idx))
            for idx in range(8)
        ]

        def _copy_cog(sink, extract, strict, return_value, *parts):
            bb = sink._copy_cog(extract=extract, strict=strict)
            if return_value == _UNSET:
                return bb if extract else sink
            else:
                return return_value

        return dask.delayed(_copy_cog)(sink,
                                       extract,
                                       strict,
                                       return_value,
                                       *parts,
                                       dask_key_name=f"cog_copy-{tk}")
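The dask_key_name keyword used above is what pins predictable keys such as cog_close-... onto delayed calls. A minimal, self-contained illustration (the key name is made up):

import dask

d = dask.delayed(sum)([1, 2, 3], dask_key_name="my-sum")
assert d.key == "my-sum"
assert d.compute() == 6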
Example #7
def to_csv(df,
           filename,
           name_function=None,
           compression=None,
           get=None,
           compute=True,
           **kwargs):
    if compression:
        raise NotImplementedError("Writing compressed csv files not supported")
    name = 'to-csv-' + uuid.uuid1().hex

    kwargs2 = kwargs.copy()

    if name_function is None:
        name_function = build_name_function(df.npartitions - 1)

    if '*' in filename:
        if filename.count('*') > 1:
            raise ValueError(
                "A maximum of one asterisk is accepted in filename")

        if 'mode' in kwargs and kwargs['mode'] != 'w':
            raise ValueError(
                "to_csv does not support writing to multiple files in append mode, "
                "please specify mode='w'")

        formatted_names = [name_function(i) for i in range(df.npartitions)]
        if formatted_names != sorted(formatted_names):
            warn("To preserve order between partitions name_function "
                 "must preserve the order of its input")

        single_file = False
    else:
        kwargs2.update({'mode': 'a', 'header': False})
        single_file = True

    dsk = dict()
    dsk[(name,
         0)] = (lambda df, fn, kwargs: df.to_csv(fn, **kwargs), (df._name, 0),
                filename.replace('*', name_function(0)), kwargs)

    for i in range(1, df.npartitions):
        filename_i = filename.replace('*', name_function(i))

        task = (lambda df, fn, kwargs: df.to_csv(fn, **kwargs), (df._name, i),
                filename_i, kwargs2)
        if single_file:
            task = (_link, (name, i - 1), task)
        dsk[(name, i)] = task

    dsk = merge(dsk, df.dask)
    if single_file:
        keys = [(name, df.npartitions - 1)]
    else:
        keys = [(name, i) for i in range(df.npartitions)]

    if compute:
        return DataFrame._get(dsk, keys, get=get)
    else:
        return delayed([Delayed(key, [dsk]) for key in keys])
Example #8
File: dag.py  Project: jakirkham/persist
def dask_to_collections(dask):
    funcs = dict()
    for key in dask.keys():
        # dsk, _ = cull(dask, key)
        # funcs[key] = Delayed(key, dsk)
        funcs[key] = Delayed(key, dask)
    return funcs
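Illustrative use of the helper above on a hand-written graph (assumes dask_to_collections is in scope); every value in the returned dict is a Delayed that shares the full graph, since the cull step is commented out in the original.

dsk = {"a": 2, "b": (pow, "a", 3)}
funcs = dask_to_collections(dsk)
assert funcs["a"].compute() == 2
assert funcs["b"].compute() == 8  # pow(2, 3)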
Example #9
def da_yxbt_sink(
    bands: Tuple[da.Array, ...], chunks: Tuple[int, ...], name="yxbt"
) -> da.Array:
    """
    Each band is in <t,y,x>; the output is <y,x,b,t>.

    eval(bands) |> transpose(YXBT) |> Store(RAM) |> DaskArray(RAM, chunks)
    """
    tk = tokenize(*bands)

    b = bands[0]
    dtype = b.dtype
    nt, ny, nx = b.shape
    nb = len(bands)
    shape = (ny, nx, nb, nt)

    token = Cache.dask_new(shape, dtype, f"{name}_alloc")

    sinks = [dask.delayed(_YXBTSink)(token, idx) for idx in range(nb)]
    fut = da.store(bands, sinks, lock=False, compute=False)
    sink_name = f"{name}_collect-{tk}"
    dsk = dict(fut.dask)
    dsk[sink_name] = (lambda *x: x[0], token.key, *fut.dask[fut.key])
    dsk = HighLevelGraph.from_collections(sink_name, dsk, dependencies=sinks)
    token_done = Delayed(sink_name, dsk)

    return _da_from_mem(token_done, shape=shape, dtype=dtype, chunks=chunks, name=name)
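The da.store(..., compute=False) call above returns a Delayed whose graph can be rewritten before execution. A minimal, self-contained version that stores into a plain NumPy array instead of a custom sink:

import numpy as np
import dask.array as da

src = da.ones((4, 4), chunks=(2, 2))
dst = np.zeros((4, 4))
fut = da.store(src, dst, lock=False, compute=False)  # a Delayed; nothing written yet
fut.compute()
assert (dst == 1).all()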
Example #10
File: io.py  Project: gdmcbain/dask
def to_castra(df, fn=None, categories=None, sorted_index_column=None,
              compute=True, get=get_sync):
    """ Write DataFrame to Castra on-disk store

    See https://github.com/blosc/castra for details

    See Also
    --------
    Castra.to_dask
    """
    from castra import Castra
    if isinstance(categories, list):
        categories = (list, categories)

    name = 'to-castra-' + uuid.uuid1().hex

    if sorted_index_column:
        func = lambda part: (M.set_index, part, sorted_index_column)
    else:
        func = lambda part: part

    dsk = dict()
    dsk[(name, -1)] = (Castra, fn, func((df._name, 0)), categories)
    for i in range(0, df.npartitions):
        dsk[(name, i)] = (_link, (name, i - 1),
                          (Castra.extend, (name, -1), func((df._name, i))))

    dsk = merge(dsk, df.dask)
    keys = [(name, -1), (name, df.npartitions - 1)]
    if compute:
        return DataFrame._get(dsk, keys, get=get)[0]
    else:
        return delayed([Delayed(key, [dsk]) for key in keys])[0]
Example #11
    def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame):
        for group in col_selector.subgroups:
            if len(group.names) > 1:
                name = nvt_cat._make_name(*group.names, sep=self.name_sep)
                for col in group.names:
                    self.storage_name[col] = name

        # Check metadata type to reset on_host and cat_cache if the
        # underlying ddf is already a pandas-backed collection
        if isinstance(ddf._meta, pd.DataFrame):
            self.on_host = False
            # Cannot use "device" caching if the data is pandas-backed
            self.cat_cache = "host" if self.cat_cache == "device" else self.cat_cache

        dsk, key = nvt_cat._category_stats(
            ddf,
            nvt_cat.FitOptions(
                col_selector,
                self.cont_names,
                self.stats,
                self.out_path,
                0,
                self.tree_width,
                self.on_host,
                concat_groups=False,
                name_sep=self.name_sep,
            ),
        )
        return Delayed(key, dsk)
Example #12
    def dask_new(shape: ShapeLike, dtype: DtypeLike, name: str = "") -> Delayed:
        if name == "":
            name = f"mem_array_{str(dtype)}"

        name = name + "-" + tokenize(name, shape, dtype)
        dsk = {name: (Cache.new, shape, dtype)}
        return Delayed(name, dsk)
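Stripped of the Cache class, the construction above is just a one-key task dict wrapped in a Delayed. In this sketch _alloc is a hypothetical stand-in for Cache.new:

import numpy as np
from dask.base import tokenize
from dask.delayed import Delayed

def _alloc(shape, dtype):
    return np.zeros(shape, dtype=dtype)

name = "mem_array_float32-" + tokenize("mem_array_float32", (4, 4), "float32")
dsk = {name: (_alloc, (4, 4), "float32")}
arr = Delayed(name, dsk).compute()  # a 4x4 float32 array of zeros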
Example #13
File: io.py  Project: zmyer/dask
def to_castra(df, fn=None, categories=None, sorted_index_column=None,
              compute=True, get=get_sync):
    """ Write DataFrame to Castra on-disk store

    The Castra project has been deprecated.  We recommend using Parquet instead.

    See Also
    --------
    Castra.to_dask
    """
    from castra import Castra

    name = 'to-castra-' + uuid.uuid1().hex

    if sorted_index_column:
        func = lambda part: (M.set_index, part, sorted_index_column)
    else:
        func = lambda part: part

    dsk = dict()
    dsk[(name, -1)] = (Castra, fn, func((df._name, 0)), categories)
    for i in range(0, df.npartitions):
        dsk[(name, i)] = (_link, (name, i - 1),
                          (Castra.extend, (name, -1), func((df._name, i))))

    dsk = merge(dsk, df.dask)
    keys = [(name, -1), (name, df.npartitions - 1)]
    if compute:
        return DataFrame._get(dsk, keys, get=get)[0]
    else:
        return delayed([Delayed(key, dsk) for key in keys])[0]
Example #14
def test_delayed_optimize():
    x = Delayed('b', {'a': 1,
                      'b': (inc, 'a'),
                      'c': (inc, 'b')})
    (x2,) = dask.optimize(x)
    # Delayed's __dask_optimize__ culls out 'c'
    assert sorted(x2.dask.keys()) == ['a', 'b']
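The culling that Delayed.__dask_optimize__ performs here can be reproduced directly with dask.optimization.cull; inc is any one-argument function:

from dask.optimization import cull

inc = lambda x: x + 1
dsk = {"a": 1, "b": (inc, "a"), "c": (inc, "b")}
culled, dependencies = cull(dsk, keys=["b"])
assert sorted(culled) == ["a", "b"]  # 'c' is unreachable from 'b' and is dropped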
Example #15
def test_annotations_survive_optimization():
    with dask.annotate(foo="bar"):
        graph = HighLevelGraph.from_collections(
            "b",
            {
                "a": 1,
                "b": (inc, "a"),
                "c": (inc, "b")
            },
            [],
        )
        d = Delayed("b", graph)

    assert type(d.dask) is HighLevelGraph
    assert len(d.dask.layers) == 1
    assert len(d.dask.layers["b"]) == 3
    assert d.dask.layers["b"].annotations == {"foo": "bar"}

    # Ensure optimizing a Delayed object returns a HighLevelGraph
    # and doesn't lose annotations
    (d_opt, ) = dask.optimize(d)
    assert type(d_opt.dask) is HighLevelGraph
    assert len(d_opt.dask.layers) == 1
    assert len(d_opt.dask.layers["b"]) == 2  # c is culled
    assert d_opt.dask.layers["b"].annotations == {"foo": "bar"}
Example #16
def _checkpoint_one(collection, split_every) -> Delayed:
    tok = tokenize(collection)
    name = "checkpoint-" + tok

    keys_iter = flatten(collection.__dask_keys__())
    try:
        next(keys_iter)
        next(keys_iter)
    except StopIteration:
        # Collection has 0 or 1 keys; no need for a map step
        layer = {name: (chunks.checkpoint, collection.__dask_keys__())}
        dsk = HighLevelGraph.from_collections(name,
                                              layer,
                                              dependencies=(collection, ))
        return Delayed(name, dsk)

    # Collection has 2+ keys; apply a two-step map->reduce algorithm so that we
    # transfer over the network and store in RAM only a handful of None's instead of
    # the full computed collection's contents
    dsks = []
    map_names = set()
    map_keys = []

    for prev_name in get_collection_names(collection):
        map_name = "checkpoint_map-" + tokenize(prev_name, tok)
        map_names.add(map_name)
        map_layer = _build_map_layer(chunks.checkpoint, prev_name, map_name,
                                     collection)
        map_keys += list(map_layer.get_output_keys())
        dsks.append(
            HighLevelGraph.from_collections(map_name,
                                            map_layer,
                                            dependencies=(collection, )))

    # recursive aggregation
    reduce_layer: dict = {}
    while split_every and len(map_keys) > split_every:
        k = (name, len(reduce_layer))
        reduce_layer[k] = (chunks.checkpoint, map_keys[:split_every])
        map_keys = map_keys[split_every:] + [k]
    reduce_layer[name] = (chunks.checkpoint, map_keys)

    dsks.append(
        HighLevelGraph({name: reduce_layer}, dependencies={name: map_names}))
    dsk = HighLevelGraph.merge(*dsks)

    return Delayed(name, dsk)
Example #17
def fuse_delayed(tasks: dask.delayed) -> dask.delayed:
    """
    Apply task fusion optimization to tasks. Useful (or even required)
    because dask.delayed optimization doesn't do this step.
    """
    dsk_fused, deps = fuse(dask.utils.ensure_dict(tasks.dask))
    fused = Delayed(tasks._key, dsk_fused)
    return fused
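Illustrative use of fuse_delayed on a short linear chain of delayed calls (assumes the helper above is in scope); the three inc tasks are fused into a single task while the result stays reachable under the original key:

import dask

@dask.delayed
def inc(x):
    return x + 1

chain = inc(inc(inc(1)))
fused = fuse_delayed(chain)
assert fused.compute() == 4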
Example #18
def test_dask_layers():
    d1 = delayed(1)
    assert d1.dask.layers.keys() == {d1.key}
    assert d1.dask.dependencies == {d1.key: set()}
    assert d1.__dask_layers__() == (d1.key,)
    d2 = modlevel_delayed1(d1)
    assert d2.dask.layers.keys() == {d1.key, d2.key}
    assert d2.dask.dependencies == {d1.key: set(), d2.key: {d1.key}}
    assert d2.__dask_layers__() == (d2.key,)

    hlg = HighLevelGraph.from_collections("foo", {"alias": d2.key}, dependencies=[d2])
    with pytest.raises(ValueError, match="not in"):
        Delayed("alias", hlg)

    explicit = Delayed("alias", hlg, layer="foo")
    assert explicit.__dask_layers__() == ("foo",)
    explicit.dask.validate()
Example #19
File: base.py  Project: biglyan/dask
def redict_collection(c, dsk):
    from dask.delayed import Delayed
    if isinstance(c, Delayed):
        return Delayed(c.key, [dsk])
    else:
        cc = copy.copy(c)
        cc.dask = dsk
        return cc
Example #20
File: core.py  Project: tym1062/dask-cudf
def to_delayed(df):
    """Create Dask Delayed objects from a dask_cudf DataFrame.

    Returns a list of delayed values, one value per partition.
    """
    from dask.delayed import Delayed

    keys = df.__dask_keys__()
    dsk = df.__dask_optimize__(df.dask, keys)
    return [Delayed(k, dsk) for k in keys]
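For comparison, plain dask.dataframe exposes the same behaviour through its public to_delayed method, which dask_cudf (a dask.dataframe subclass) also provides:

import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(pd.DataFrame({"x": range(6)}), npartitions=3)
parts = ddf.to_delayed()    # one Delayed per partition
first = parts[0].compute()  # a pandas DataFrame holding the first partition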
Example #21
def decorate_delayed(delayed_func, dump, decorate_mode=None):
    key = delayed_func._key
    dsk = dict(delayed_func.dask)
    if decorate_mode is None:
        # replace the function by decorated ones (with standard mechanism)
        task = list(dsk[key])
        task[0] = dump_result(dump, task[0], key)
        dsk[key] = tuple(task)
        return Delayed(key, dsk)
Example #22
 def _fit_transform(self, X, y, **kwargs):
     token = tokenize(self, X, y, kwargs)
     fit_tr_name = 'fit-transform-' + token
     fit_name = 'fit-' + token
     tr_name = 'tr-' + token
     X, y, dsk = unpack_arguments(X, y)
     dsk[fit_tr_name] = (_fit_transform, self._name, X, y, kwargs)
     dsk1 = merge({fit_name: (getitem, fit_tr_name, 0)}, dsk, self.dask)
     dsk2 = merge({tr_name: (getitem, fit_tr_name, 1)}, dsk, self.dask)
     return Wrapped(self._est, dsk1, fit_name), Delayed(tr_name, [dsk2])
Example #23
    def fit(self, columns: ColumnNames, ddf: dd.DataFrame):
        # User passed in a list of column groups. We need to figure out
        # if this list contains any multi-column groups, and if there
        # are any (obvious) problems with these groups
        columns_uniq = list(set(flatten(columns, container=tuple)))
        columns_all = list(flatten(columns, container=tuple))
        if sorted(columns_all) != sorted(
                columns_uniq) and self.encode_type == "joint":
            # If we are doing "joint" encoding, there must be unique mapping
            # between input column names and column groups.  Otherwise, more
            # than one unique-value table could be used to encode the same
            # column.
            raise ValueError("Same column name included in multiple groups.")

        for group in columns:
            if isinstance(group, tuple) and len(group) > 1:
                # For multi-column groups, we concatenate column names
                # to get the "group" name.
                name = _make_name(*group, sep=self.name_sep)
                for col in group:
                    self.storage_name[col] = name

        # Check metadata type to reset on_host and cat_cache if the
        # underlying ddf is already a pandas-backed collection
        if isinstance(ddf._meta, pd.DataFrame):
            self.on_host = False
            # Cannot use "device" caching if the data is pandas-backed
            self.cat_cache = "host" if self.cat_cache == "device" else self.cat_cache
            if self.search_sorted:
                # Pandas' search_sorted only works with Series.
                # For now, it is safest to disallow this option.
                self.search_sorted = False
                warnings.warn(
                    "Cannot use `search_sorted=True` for pandas-backed data.")

        # convert tuples to lists
        columns = [list(c) if isinstance(c, tuple) else c for c in columns]
        dsk, key = _category_stats(
            ddf,
            columns,
            [],
            [],
            self.out_path,
            self.freq_threshold,
            self.tree_width,
            self.on_host,
            concat_groups=self.encode_type == "joint",
            name_sep=self.name_sep,
            max_size=self.max_size,
            num_buckets=self.num_buckets,
        )
        # TODO: we can't check the dtypes on the ddf here since they are incorrect
        # for cudf's list type. So, we're checking the partitions. fix.
        return Delayed(key,
                       dsk), ddf.map_partitions(lambda df: _is_list_dtype(df))
Example #24
 def _fit_transform(self, X, y, **kwargs):
     clsname = type(self._est).__name__
     token = tokenize(self, X, y, kwargs)
     fit_tr_name = 'fit-transform-%s-%s' % (clsname, token)
     fit_name = 'fit-%s-%s' % (clsname, token)
     tr_name = 'tr-%s-%s' % (clsname, token)
     X, y, dsk = unpack_arguments(X, y)
     dsk[fit_tr_name] = (_fit_transform, self._name, X, y, kwargs)
     dsk1 = merge({fit_name: (getitem, fit_tr_name, 0)}, dsk, self.dask)
     dsk2 = merge({tr_name: (getitem, fit_tr_name, 1)}, dsk, self.dask)
     return Wrapped(self._est, dsk1, fit_name), Delayed(tr_name, [dsk2])
Example #25
 def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]:
     keys = self.__dask_keys__()
     graph = self.__dask_graph__()
     layer = self.__dask_layers__()[0]
     if optimize_graph:
         graph = self.__dask_optimize__(graph, keys)
         layer = f"delayed-{self.name}"
         graph = HighLevelGraph.from_collections(layer,
                                                 graph,
                                                 dependencies=())
     return [Delayed(k, graph, layer=layer) for k in keys]
Example #26
 def transform(self, raw_X, y=None):
     name = 'transform-' + tokenize(self, raw_X)
     sk_est = self._est
     if isinstance(raw_X, db.Bag):
         dsk = dict(((name, i), (_transform, sk_est, k))
                    for (i, k) in enumerate(raw_X._keys()))
         dsk.update(raw_X.dask)
         return dm.Matrix(dsk, name, raw_X.npartitions, dtype=self.dtype,
                          shape=(None, self.n_features))
     raw_X, dsk = unpack_arguments(raw_X)
     dsk[name] = (_transform, sk_est, raw_X)
     return Delayed(name, [dsk])
Example #27
def _ddf_to_dataset(
    ddf,
    fs,
    output_path,
    shuffle,
    out_files_per_proc,
    cat_names,
    cont_names,
    label_names,
    output_format,
    client,
    num_threads,
    cpu,
):
    # Construct graph for Dask-based dataset write
    name = "write-processed"
    write_name = name + tokenize(
        ddf, shuffle, out_files_per_proc, cat_names, cont_names, label_names
    )
    # Check that the data is in the correct place
    assert isinstance(ddf._meta, pd.DataFrame) is cpu
    task_list = []
    dsk = {}
    for idx in range(ddf.npartitions):
        key = (write_name, idx)
        dsk[key] = (
            _write_output_partition,
            (ddf._name, idx),
            output_path,
            shuffle,
            out_files_per_proc,
            fs,
            cat_names,
            cont_names,
            label_names,
            output_format,
            num_threads,
            cpu,
        )
        task_list.append(key)
    dsk[name] = (lambda x: x, task_list)
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[ddf])
    out = Delayed(name, graph)

    # Trigger write execution
    if client:
        out = client.compute(out).result()
    else:
        out = dask.compute(out, scheduler="synchronous")[0]

    # Follow-up Shuffling and _metadata creation
    _finish_dataset(client, ddf, output_path, fs, output_format, cpu)
Example #28
 def predict(self, X):
     name = 'predict-' + tokenize(self, X)
     if isinstance(X, (da.Array, dm.Matrix, db.Bag)):
         keys, dsk = unpack_as_lists_of_keys(X)
         dsk.update(((name, i), (_predict, self._name, k))
                    for (i, k) in enumerate(keys))
         dsk.update(self.dask)
         if isinstance(X, da.Array):
             return da.Array(dsk, name, chunks=(X.chunks[0], ))
         return dm.Matrix(dsk, name, npartitions=len(keys))
     X, dsk = unpack_arguments(X)
     dsk[name] = (_predict, self._name, X)
     return Delayed(name, [dsk, self.dask])
Example #29
def _make_pipeline(pipeline: Pipeline) -> Delayed:
    token = dask.base.tokenize(pipeline)

    # we are constructing a HighLevelGraph from scratch
    # https://docs.dask.org/en/latest/high-level-graphs.html
    layers = dict()  # type: Dict[str, Dict[Union[str, Tuple[str, int]], Any]]
    dependencies = dict()  # type: Dict[str, Set[str]]

    # start with just the config as a standalone layer
    # create a custom delayed object for the config
    config_key = append_token("config", token)
    layers[config_key] = {config_key: pipeline.config}
    dependencies[config_key] = set()

    prev_key: str = config_key
    for stage in pipeline.stages:
        if stage.mappable is None:
            stage_key = append_token(stage.name, token)
            func = wrap_standalone_task(stage.function)
            layers[stage_key] = {stage_key: (func, config_key, prev_key)}
            dependencies[stage_key] = {config_key, prev_key}
        else:
            func = wrap_map_task(stage.function)
            map_key = append_token(stage.name, token)
            layers[map_key] = map_layer = blockwise(
                func,
                map_key,
                "x",  # <-- dimension name doesn't matter
                BlockwiseDepDict({(i, ): x
                                  for i, x in enumerate(stage.mappable)}),
                # ^ this is extra annoying. `BlockwiseDepList` at least would be nice.
                "x",
                config_key,
                None,
                prev_key,
                None,
                numblocks={},
                # ^ also annoying; the default of None breaks Blockwise
            )
            dependencies[map_key] = {config_key, prev_key}

            stage_key = f"{stage.name}-checkpoint-{token}"
            layers[stage_key] = {
                stage_key: (checkpoint, *map_layer.get_output_keys())
            }
            dependencies[stage_key] = {map_key}
        prev_key = stage_key

    hlg = HighLevelGraph(layers, dependencies)
    delayed = Delayed(prev_key, hlg)
    return delayed
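A much smaller HighLevelGraph built from explicit layers and dependencies, mirroring the config -> stage chaining above without Blockwise. All layer and key names here are illustrative:

from dask.highlevelgraph import HighLevelGraph
from dask.delayed import Delayed

layers = {
    "config": {"config": {"setting": 1}},
    "stage": {"stage": (lambda cfg, prev: cfg["setting"] + 1, "config", None)},
}
dependencies = {"config": set(), "stage": {"config"}}
hlg = HighLevelGraph(layers, dependencies)
assert Delayed("stage", hlg).compute() == 2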
Example #30
def delayed_using_cache(delayed,
                        serializers=None,
                        cache=None,
                        *args,
                        **kwargs):
    """
    *args and **kwargs are passed to collections_to_dsk
    """
    key = delayed._key
    dsk = delayed.dask
    collections = dask_to_collections(dsk)
    collections = collections.values()
    persistent_dsk = persistent_collections_to_dsk(collections, key,
                                                   serializers, cache, *args,
                                                   **kwargs)
    return Delayed(key, persistent_dsk)