def test_unpack_collections():
    a = delayed(1) + 5
    b = a + 1
    c = a + 2

    def build(a, b, c, iterator):
        t = (
            a,
            b,  # Top-level collections
            {
                "a": a,  # dict
                a: b,  # collections as keys
                "b": [1, 2, [b]],  # list
                "c": 10,  # other builtins pass through unchanged
                "d": (c, 2),  # tuple
                "e": {a, 2, 3},  # set
                "f": OrderedDict([("a", a)]),
            },  # OrderedDict
            iterator,
        )  # Iterator

        if dataclasses is not None:
            t[2]["f"] = ADataClass(a=a)
            t[2]["g"] = (ADataClass, a)
        return t

    args = build(a, b, c, (i for i in [a, b, c]))

    collections, repack = unpack_collections(*args)
    assert len(collections) == 3

    # Replace collections with `'~a'` strings
    result = repack(["~a", "~b", "~c"])
    sol = build("~a", "~b", "~c", ["~a", "~b", "~c"])
    assert result == sol

    # traverse=False
    collections, repack = unpack_collections(*args, traverse=False)
    assert len(collections) == 2  # just a and b
    assert repack(collections) == args

    # No collections
    collections, repack = unpack_collections(1, 2, {"a": 3})
    assert not collections
    assert repack(collections) == (1, 2, {"a": 3})

    # Result that looks like a task
    def fail(*args):
        raise ValueError("Shouldn't have been called")  # pragma: nocover

    collections, repack = unpack_collections(
        a, (fail, 1), [(fail, 2, 3)], traverse=False
    )
    repack(collections)  # Smoketest task literals
    repack([(fail, 1)])  # Smoketest results that look like tasks
def checkpoint(
    *collections,
    split_every: float | Literal[False] | None = None,
) -> Delayed:
    """Build a :doc:`delayed` which waits until all chunks of the input collection(s)
    have been computed before returning None.

    Parameters
    ----------
    collections
        Zero or more Dask collections or nested data structures containing zero or
        more collections
    split_every: int >= 2 or False, optional
        Determines the depth of the recursive aggregation. If greater than the number
        of input keys, the aggregation will be performed in multiple steps; the depth
        of the aggregation graph will be :math:`log_{split_every}(input keys)`. Setting
        to a low value can reduce cache size and network transfers, at the cost of more
        CPU and a larger dask graph.

        Set to False to disable. Defaults to 8.

    Returns
    -------
    :doc:`delayed` yielding None
    """
    if split_every is None:
        # FIXME https://github.com/python/typeshed/issues/5074
        split_every = 8  # type: ignore
    elif split_every is not False:
        split_every = int(split_every)  # type: ignore
        if split_every < 2:  # type: ignore
            raise ValueError("split_every must be False, None, or >= 2")

    collections, _ = unpack_collections(*collections)
    if len(collections) == 1:
        return _checkpoint_one(collections[0], split_every)
    else:
        return delayed(chunks.checkpoint)(
            *(_checkpoint_one(c, split_every) for c in collections)
        )
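# A hedged usage sketch, not part of the module's API: it shows one way checkpoint()
# might be called, assuming dask.array is available. The function name and variables
# below are hypothetical and exist only for illustration.
def _example_checkpoint_usage():  # pragma: no cover - illustrative sketch only
    import dask.array as da

    x = da.ones(10, chunks=5)
    y = da.zeros(10, chunks=5)
    # checkpoint accepts nested structures of collections and returns a single
    # Delayed that computes to None once every chunk of every collection is done.
    done = checkpoint(x, {"y": [y]}, split_every=2)
    assert done.compute() is None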
def wait_on(
    *collections,
    split_every: float | Literal[False] | None = None,
):
    """Ensure that all chunks of all input collections have been computed before
    computing the dependents of any of the chunks.

    The following example creates a dask array ``u`` that, when used in a computation,
    will only proceed when all chunks of the array ``x`` have been computed, but
    otherwise matches ``x``:

    >>> import dask.array as da
    >>> x = da.ones(10, chunks=5)
    >>> u = wait_on(x)

    The following example will create two arrays ``u`` and ``v`` that, when used in a
    computation, will only proceed when all chunks of the arrays ``x`` and ``y`` have
    been computed but otherwise match ``x`` and ``y``:

    >>> x = da.ones(10, chunks=5)
    >>> y = da.zeros(10, chunks=5)
    >>> u, v = wait_on(x, y)

    Parameters
    ----------
    collections
        Zero or more Dask collections or nested structures of Dask collections
    split_every
        See :func:`checkpoint`

    Returns
    -------
    Same as ``collections``
        Dask collection of the same type as the input, which computes to the same
        value, or a nested structure equivalent to the input where the original
        collections have been replaced.
    """
    blocker = checkpoint(*collections, split_every=split_every)

    def block_one(coll):
        tok = tokenize(coll, blocker)
        dsks = []
        rename = {}
        for prev_name in get_collection_names(coll):
            new_name = "wait_on-" + tokenize(prev_name, tok)
            rename[prev_name] = new_name
            layer = _build_map_layer(
                chunks.bind, prev_name, new_name, coll, dependencies=(blocker,)
            )
            dsks.append(
                HighLevelGraph.from_collections(
                    new_name, layer, dependencies=(coll, blocker)
                )
            )
        dsk = HighLevelGraph.merge(*dsks)
        rebuild, args = coll.__dask_postpersist__()
        return rebuild(dsk, *args, rename=rename)

    unpacked, repack = unpack_collections(*collections)
    out = repack([block_one(coll) for coll in unpacked])
    return out[0] if len(collections) == 1 else out
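# A hedged usage sketch, not part of the module's API: it illustrates that wait_on()
# repacks nested structures, returning the same shape with each collection replaced
# by a blocked equivalent. The function name and variables below are hypothetical.
def _example_wait_on_usage():  # pragma: no cover - illustrative sketch only
    import dask.array as da

    x = da.ones(10, chunks=5)
    y = da.zeros(10, chunks=5)
    out = wait_on({"a": x, "b": [y]})
    # ``out`` mirrors the input structure; each replacement array computes to the
    # same values but only proceeds once all chunks of both x and y are done.
    assert out["a"].sum().compute() == 10.0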
def bind(
    children: T,
    parents,
    *,
    omit=None,
    seed: Hashable = None,
    assume_layers: bool = True,
    split_every: float | Literal[False] | None = None,
) -> T:
    """
    Make ``children`` collection(s), optionally omitting sub-collections, dependent on
    ``parents`` collection(s). Two examples follow.

    The first example creates an array ``b2`` whose computation first computes an
    array ``a`` completely and then computes ``b`` completely, recomputing ``a`` in
    the process:

    >>> import dask
    >>> import dask.array as da
    >>> a = da.ones(4, chunks=2)
    >>> b = a + 1
    >>> b2 = bind(b, a)
    >>> len(b2.dask)
    9
    >>> b2.compute()
    array([2., 2., 2., 2.])

    The second example creates arrays ``b3`` and ``c3``, whose computation first
    computes an array ``a`` and then computes the additions, this time not recomputing
    ``a`` in the process:

    >>> c = a + 2
    >>> b3, c3 = bind((b, c), a, omit=a)
    >>> len(b3.dask), len(c3.dask)
    (7, 7)
    >>> dask.compute(b3, c3)
    (array([2., 2., 2., 2.]), array([3., 3., 3., 3.]))

    Parameters
    ----------
    children
        Dask collection or nested structure of Dask collections
    parents
        Dask collection or nested structure of Dask collections
    omit
        Dask collection or nested structure of Dask collections
    seed
        Hashable used to seed the key regeneration. Omit to default to a random number
        that will produce different keys at every call.
    assume_layers
        True
            Use a fast algorithm that works at layer level, which assumes that all
            collections in ``children`` and ``omit``

            #. use :class:`~dask.highlevelgraph.HighLevelGraph`,
            #. define the ``__dask_layers__()`` method, and
            #. never had their graphs squashed and rebuilt between the creation of the
               ``omit`` collections and the ``children`` collections; in other words if
               the keys of the ``omit`` collections can be found among the keys of the
               ``children`` collections, then the same must also hold true for the
               layers.
        False
            Use a slower algorithm that works at keys level, which makes none of the
            above assumptions.
    split_every
        See :func:`checkpoint`

    Returns
    -------
    Same as ``children``
        Dask collection or structure of Dask collections equivalent to ``children``,
        which compute to the same values. All keys of ``children`` will be regenerated,
        up to and excluding the keys of ``omit``. Nodes immediately above ``omit``, or
        the leaf nodes if the collections in ``omit`` are not found, are prevented from
        computing until all collections in ``parents`` have been fully computed.
    """
    if seed is None:
        seed = uuid.uuid4().bytes

    # parents=None is a special case invoked by the one-liner wrapper clone() below
    blocker = (
        checkpoint(parents, split_every=split_every) if parents is not None else None
    )

    omit, _ = unpack_collections(omit)
    if assume_layers:
        # Set of all the top-level layers of the collections in omit
        omit_layers = {layer for coll in omit for layer in coll.__dask_layers__()}
        omit_keys = set()
    else:
        omit_layers = set()
        # Set of *all* the keys, not just the top-level ones, of the collections in omit
        omit_keys = {key for coll in omit for key in coll.__dask_graph__()}

    unpacked_children, repack = unpack_collections(children)
    return repack(
        [
            _bind_one(child, blocker, omit_layers, omit_keys, seed)
            for child in unpacked_children
        ]
    )[0]
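# A hedged usage sketch, not part of the module's API: per the docstring above, an
# explicit ``seed`` makes key regeneration deterministic, so repeated calls with the
# same inputs and seed are expected to yield matching keys. Names are hypothetical.
def _example_bind_seed():  # pragma: no cover - illustrative sketch only
    import dask.array as da

    a = da.ones(4, chunks=2)
    b = a + 1
    b_first = bind(b, a, seed=123)
    b_second = bind(b, a, seed=123)
    # The bound copy computes to the same values as ``b``; with a fixed seed the
    # two bound copies should share the same regenerated name.
    assert (b_first.compute() == b.compute()).all()
    assert b_first.name == b_second.name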
def _get_dsk(node):
    # Convert the node to a Delayed, extract the Dask collections it contains
    # (without traversing into nested data), and merge their graphs into a single
    # optimized task graph.
    d = node.todelayed()
    collections, repack = unpack_collections(d, traverse=False)
    return collections_to_dsk(collections, True)