def test_get_collection_names():
    """get_collection_names() validates keys and extracts the unique name prefixes."""

    class FakeCollection:
        def __init__(self, graph, keys):
            self.graph = graph
            self.out_keys = keys

        def __dask_graph__(self):
            return self.graph

        def __dask_keys__(self):
            return self.out_keys

    # Objects that don't implement the collection protocol are rejected outright
    with pytest.raises(TypeError):
        get_collection_names(object())

    # Keys must either be a string or a tuple where the first element is a string
    for bad_graph, bad_keys in [
        ({1: 2}, [1]),
        ({(): 1}, [()]),
        ({(1,): 1}, [(1,)]),
    ]:
        with pytest.raises(TypeError):
            get_collection_names(FakeCollection(bad_graph, bad_keys))

    # A collection with no keys yields no names
    assert get_collection_names(FakeCollection({}, [])) == set()

    # Tuple keys may carry arbitrary hashables after the name, and
    # __dask_keys__() may return an arbitrarily nested list
    h1, h2 = object(), object()
    coll = FakeCollection(
        {("a-1", h1): 1, ("a-1", h2): 2, "b-2": 3, "c": 4},
        [[[("a-1", h1), ("a-1", h2), "b-2", "c"]]],
    )
    assert get_collection_names(coll) == {"a-1", "b-2", "c"}
def _checkpoint_one(collection, split_every) -> Delayed:
    """Build a Delayed that forces computation of every key of ``collection``
    through ``chunks.checkpoint``.

    Parameters
    ----------
    collection:
        A dask collection implementing ``__dask_keys__`` / ``__dask_graph__``
    split_every:
        Falsy to aggregate all map keys in a single reduce task; otherwise the
        maximum number of keys aggregated per reduction step (the loop below
        recurses until one output key remains)

    Returns
    -------
    Delayed with a single output key named ``checkpoint-<token>``
    """
    tok = tokenize(collection)
    name = "checkpoint-" + tok

    # Peek at the first two keys to learn whether there are >= 2 of them,
    # without materialising the whole flattened key list.
    keys_iter = flatten(collection.__dask_keys__())
    try:
        next(keys_iter)
        next(keys_iter)
    except StopIteration:
        # Collection has 0 or 1 keys; no need for a map step
        layer = {name: (chunks.checkpoint, collection.__dask_keys__())}
        dsk = HighLevelGraph.from_collections(name, layer, dependencies=(collection, ))
        return Delayed(name, dsk)

    # Collection has 2+ keys; apply a two-step map->reduce algorithm so that we
    # transfer over the network and store in RAM only a handful of None's instead of
    # the full computed collection's contents
    dsks = []
    map_names = set()
    map_keys = []

    for prev_name in get_collection_names(collection):
        # One map layer per name in the collection; each layer output depends on
        # the matching input key, forcing its computation.
        map_name = "checkpoint_map-" + tokenize(prev_name, tok)
        map_names.add(map_name)
        map_layer = _build_map_layer(chunks.checkpoint, prev_name, map_name, collection)
        map_keys += list(map_layer.get_output_keys())
        dsks.append(
            HighLevelGraph.from_collections(map_name, map_layer, dependencies=(collection, )))

    # recursive aggregation
    reduce_layer: dict = {}
    while split_every and len(map_keys) > split_every:
        # Fold the first split_every keys into one partial checkpoint key and
        # push that key back onto the work list for further aggregation.
        k = (name, len(reduce_layer))
        reduce_layer[k] = (chunks.checkpoint, map_keys[:split_every])
        map_keys = map_keys[split_every:] + [k]
    # Final aggregation task produces the Delayed's single output key
    reduce_layer[name] = (chunks.checkpoint, map_keys)
    dsks.append(
        HighLevelGraph({name: reduce_layer}, dependencies={name: map_names}))

    dsk = HighLevelGraph.merge(*dsks)
    return Delayed(name, dsk)
def block_one(coll):
    """Return a rebuilt copy of ``coll`` whose keys additionally depend on
    ``blocker`` (a free variable from the enclosing scope)."""
    token = tokenize(coll, blocker)
    # Map every name in the collection to its new, blocked name
    rename = {
        old_name: "wait_on-" + tokenize(old_name, token)
        for old_name in get_collection_names(coll)
    }
    graphs = []
    for old_name, new_name in rename.items():
        bound_layer = _build_map_layer(
            chunks.bind, old_name, new_name, coll, dependencies=(blocker,)
        )
        graphs.append(
            HighLevelGraph.from_collections(
                new_name, bound_layer, dependencies=(coll, blocker)
            )
        )
    merged = HighLevelGraph.merge(*graphs)
    rebuild, extra_args = coll.__dask_postpersist__()
    return rebuild(merged, *extra_args, rename=rename)
def test_custom_collection():
    """Exercise the full dask collection protocol on the custom Tuple class."""
    # Keys may contain arbitrary hashables after the name prefix
    o1, o2 = object(), object()
    dsk_x = {("x", o1): 1, ("x", o2): 2}
    dsk_y = dict(dsk_x)
    dsk_y.update(
        {("y", o1): (add, ("x", o1), ("x", o2)), ("y", o2): (add, ("y", o1), 1)}
    )
    dsk_z = dict(dsk_y)
    dsk_z["z"] = (add, ("y", o1), ("y", o2))

    w = Tuple({}, [])  # A collection can have no keys at all
    x = Tuple(dsk_x, [("x", o1), ("x", o2)])
    y = Tuple(dsk_y, [("y", o1), ("y", o2)])
    z = Tuple(dsk_z, ["z"])
    t = w + x + y + z  # Collection with multiple names

    # __slots__ defined on base mixin class propagates
    with pytest.raises(AttributeError):
        x.foo = 1

    # is_dask_collection
    for coll in (w, x, y, z, t):
        assert is_dask_collection(coll)

    # tokenize: deterministic per collection, unique across collections
    tokens = [tokenize(coll) for coll in (w, x, y, z, t)]
    assert tokens == [tokenize(coll) for coll in (w, x, y, z, t)]
    assert len(set(tokens)) == 5

    # get_collection_names
    assert get_collection_names(w) == set()
    assert get_collection_names(x) == {"x"}
    assert get_collection_names(y) == {"y"}
    assert get_collection_names(z) == {"z"}
    assert get_collection_names(t) == {"x", "y", "z"}

    # compute
    assert w.compute() == ()
    assert x.compute() == (1, 2)
    assert y.compute() == (3, 4)
    assert z.compute() == (7,)
    assert dask.compute(w, [{"x": x}, y, z]) == ((), [{"x": (1, 2)}, (3, 4), (7,)])
    assert t.compute() == (1, 2, 3, 4, 7)

    # persist
    t2 = t.persist()
    assert isinstance(t2, Tuple)
    assert t2._keys == t._keys
    assert sorted(t2._dask.values()) == [1, 2, 3, 4, 7]
    assert t2.compute() == (1, 2, 3, 4, 7)

    w2, x2, y2, z2 = dask.persist(w, x, y, z)
    assert y2._keys == y._keys
    assert y2._dask == {("y", o1): 3, ("y", o2): 4}
    assert y2.compute() == (3, 4)
    t3 = x2 + y2 + z2
    assert t3.compute() == (1, 2, 3, 4, 7)

    # __dask_postpersist__ with name change
    rebuild, args = w.__dask_postpersist__()
    w3 = rebuild({}, *args, rename={"w": "w3"})
    assert w3.compute() == ()

    rebuild, args = x.__dask_postpersist__()
    x3 = rebuild({("x3", o1): 10, ("x3", o2): 20}, *args, rename={"x": "x3"})
    assert x3.compute() == (10, 20)

    rebuild, args = z.__dask_postpersist__()
    z3 = rebuild({"z3": 70}, *args, rename={"z": "z3"})
    assert z3.compute() == (70,)
def __dask_layers__(self):
    """Return the HighLevelGraph layer names making up this collection."""
    names = get_collection_names(self)
    return tuple(names)
def _bind_one(
    child: T,
    blocker: Delayed | None,
    omit_layers: set[str],
    omit_keys: set[Hashable],
    seed: Hashable,
) -> T:
    """Clone ``child``'s graph under new seeded key names, optionally binding every
    cloned layer to ``blocker``, while copying verbatim the layers named in
    ``omit_layers`` and excluding ``omit_keys`` from cloning.

    Parameters
    ----------
    child:
        Dask collection to clone
    blocker:
        Optional Delayed; when given, its graph is merged in and its key is passed
        to ``Layer.clone(bind_to=...)`` so cloned layers gain it as a dependency
    omit_layers:
        Layer names to copy as-is instead of cloning
    omit_keys:
        Keys excluded from the set of keys to be renamed
    seed:
        Hashable mixed into the new key/layer names via ``clone_key``

    Returns
    -------
    Collection of the same type as ``child``, rebuilt with renamed keys
    """
    prev_coll_names = get_collection_names(child)
    if not prev_coll_names:
        # Collection with no keys; this is a legitimate use case but, at the moment of
        # writing, can only happen with third-party collections
        return child

    dsk = child.__dask_graph__()  # type: ignore
    new_layers: dict[str, Layer] = {}
    new_deps: dict[str, Set[Any]] = {}

    if isinstance(dsk, HighLevelGraph):
        try:
            layers_to_clone = set(child.__dask_layers__())  # type: ignore
        except AttributeError:
            # Collection doesn't implement __dask_layers__; fall back to treating
            # each collection name as a layer name
            layers_to_clone = prev_coll_names.copy()
    else:
        # Low-level dict graph: wrap it into a single-layer HighLevelGraph first
        if len(prev_coll_names) == 1:
            hlg_name = next(iter(prev_coll_names))
        else:
            hlg_name = tokenize(*prev_coll_names)
        dsk = HighLevelGraph.from_collections(hlg_name, dsk)
        layers_to_clone = {hlg_name}

    # Keys eligible for renaming: all external keys minus explicit omissions and
    # minus the output keys of every omitted layer that exists in this graph
    clone_keys = dsk.get_all_external_keys() - omit_keys
    for layer_name in omit_layers:
        try:
            layer = dsk.layers[layer_name]
        except KeyError:
            continue
        clone_keys -= layer.get_output_keys()
    # Note: when assume_layers=True, clone_keys can contain keys of the omit collections
    # that are not top-level. This is OK, as they will never be encountered inside the
    # values of their dependent layers.

    if blocker is not None:
        blocker_key = blocker.key
        blocker_dsk = blocker.__dask_graph__()
        assert isinstance(blocker_dsk, HighLevelGraph)
        # Merge the blocker's own layers into the output graph
        new_layers.update(blocker_dsk.layers)
        new_deps.update(blocker_dsk.dependencies)
    else:
        blocker_key = None

    layers_to_copy_verbatim = set()

    # Walk the layer dependency graph, cloning every reachable layer that is not
    # in omit_layers
    while layers_to_clone:
        prev_layer_name = layers_to_clone.pop()
        new_layer_name = clone_key(prev_layer_name, seed=seed)
        if new_layer_name in new_layers:
            continue
        layer = dsk.layers[prev_layer_name]
        layer_deps = dsk.dependencies[prev_layer_name]
        layer_deps_to_clone = layer_deps - omit_layers
        layer_deps_to_omit = layer_deps & omit_layers
        layers_to_clone |= layer_deps_to_clone
        layers_to_copy_verbatim |= layer_deps_to_omit

        new_layers[new_layer_name], is_bound = layer.clone(keys=clone_keys, seed=seed, bind_to=blocker_key)
        # Cloned dependencies get renamed names; omitted ones keep their originals
        new_dep = {clone_key(dep, seed=seed) for dep in layer_deps_to_clone} | layer_deps_to_omit
        if is_bound:
            new_dep.add(blocker_key)
        new_deps[new_layer_name] = new_dep

    # Add the layers of the collections from omit from child.dsk. Note that, when
    # assume_layers=False, it would be unsafe to simply do HighLevelGraph.merge(dsk,
    # omit[i].dsk). Also, collections in omit may or may not be parents of this specific
    # child, or of any children at all.
    while layers_to_copy_verbatim:
        layer_name = layers_to_copy_verbatim.pop()
        if layer_name in new_layers:
            continue
        layer_deps = dsk.dependencies[layer_name]
        layers_to_copy_verbatim |= layer_deps
        new_deps[layer_name] = layer_deps
        new_layers[layer_name] = dsk.layers[layer_name]

    rebuild, args = child.__dask_postpersist__()  # type: ignore
    return rebuild(
        HighLevelGraph(new_layers, new_deps),
        *args,
        rename={
            prev_name: clone_key(prev_name, seed) for prev_name in prev_coll_names
        },
    )