Example #1
import pytest

from dask.base import get_collection_names


def test_get_collection_names():
    class DummyCollection:
        def __init__(self, dsk, keys):
            self.dask = dsk
            self.keys = keys

        def __dask_graph__(self):
            return self.dask

        def __dask_keys__(self):
            return self.keys

    with pytest.raises(TypeError):
        get_collection_names(object())
    # Keys must either be a string or a tuple where the first element is a string
    with pytest.raises(TypeError):
        get_collection_names(DummyCollection({1: 2}, [1]))
    with pytest.raises(TypeError):
        get_collection_names(DummyCollection({(): 1}, [()]))
    with pytest.raises(TypeError):
        get_collection_names(DummyCollection({(1,): 1}, [(1,)]))

    assert get_collection_names(DummyCollection({}, [])) == set()

    # Arbitrary hashables
    h1 = object()
    h2 = object()
    # __dask_keys__() returns a nested list
    assert get_collection_names(
        DummyCollection(
            {("a-1", h1): 1, ("a-1", h2): 2, "b-2": 3, "c": 4},
            [[[("a-1", h1), ("a-1", h2), "b-2", "c"]]],
        )
    ) == {"a-1", "b-2", "c"}
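
For orientation, a quick sketch of what get_collection_names returns for a real
collection (the exact name prefix and hash token vary by dask version):

import dask.array as da
from dask.base import get_collection_names

x = da.ones(10, chunks=5)
# Both chunk keys of `x` share a single name, so a one-element set comes back,
# e.g. {"ones_like-<token>"}.
print(get_collection_names(x))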
Example #2
from dask.base import get_collection_names, tokenize
from dask.core import flatten
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph

# `chunks` (a namespace of no-op callables inserted into the graph) and
# `_build_map_layer` are private helpers defined alongside this function in
# dask.graph_manipulation.


def _checkpoint_one(collection, split_every) -> Delayed:
    tok = tokenize(collection)
    name = "checkpoint-" + tok

    keys_iter = flatten(collection.__dask_keys__())
    try:
        next(keys_iter)
        next(keys_iter)
    except StopIteration:
        # Collection has 0 or 1 keys; no need for a map step
        layer = {name: (chunks.checkpoint, collection.__dask_keys__())}
        dsk = HighLevelGraph.from_collections(name, layer, dependencies=(collection,))
        return Delayed(name, dsk)

    # Collection has 2+ keys; apply a two-step map->reduce algorithm so that only a
    # handful of Nones, rather than the full contents of the computed collection, is
    # transferred over the network and held in RAM
    dsks = []
    map_names = set()
    map_keys = []

    for prev_name in get_collection_names(collection):
        map_name = "checkpoint_map-" + tokenize(prev_name, tok)
        map_names.add(map_name)
        map_layer = _build_map_layer(chunks.checkpoint, prev_name, map_name, collection)
        map_keys += list(map_layer.get_output_keys())
        dsks.append(
            HighLevelGraph.from_collections(
                map_name, map_layer, dependencies=(collection,)
            )
        )

    # recursive aggregation
    reduce_layer: dict = {}
    while split_every and len(map_keys) > split_every:
        k = (name, len(reduce_layer))
        reduce_layer[k] = (chunks.checkpoint, map_keys[:split_every])
        map_keys = map_keys[split_every:] + [k]
    reduce_layer[name] = (chunks.checkpoint, map_keys)

    dsks.append(HighLevelGraph({name: reduce_layer}, dependencies={name: map_names}))
    dsk = HighLevelGraph.merge(*dsks)

    return Delayed(name, dsk)
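
The while-loop above is the heart of the tree reduction: it consumes split_every
map keys at a time, emits a partial checkpoint key, and re-enqueues that key until
few enough keys remain. A standalone sketch of just that batching logic, using a
hypothetical batch_reduce helper and a placeholder task instead of
chunks.checkpoint:

def batch_reduce(map_keys, split_every, name="checkpoint-demo"):
    # Same batching as above: take `split_every` keys, emit a partial key,
    # and re-enqueue it until no more than `split_every` keys remain.
    reduce_layer = {}
    while split_every and len(map_keys) > split_every:
        k = (name, len(reduce_layer))
        reduce_layer[k] = ("checkpoint", map_keys[:split_every])
        map_keys = map_keys[split_every:] + [k]
    reduce_layer[name] = ("checkpoint", map_keys)
    return reduce_layer

# With 5 keys and split_every=2, three partial keys (name, 0), (name, 1) and
# (name, 2) are created before the final `name` key aggregates the last two.
print(batch_reduce([f"k{i}" for i in range(5)], split_every=2))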
Example #3
# `block_one` is a closure inside dask.graph_manipulation.wait_on; `blocker`
# is a Delayed checkpoint of all input collections, created by the enclosing
# function.
def block_one(coll):
    tok = tokenize(coll, blocker)
    dsks = []
    rename = {}
    for prev_name in get_collection_names(coll):
        new_name = "wait_on-" + tokenize(prev_name, tok)
        rename[prev_name] = new_name
        layer = _build_map_layer(
            chunks.bind, prev_name, new_name, coll, dependencies=(blocker,)
        )
        dsks.append(
            HighLevelGraph.from_collections(new_name, layer, dependencies=(coll, blocker))
        )
    dsk = HighLevelGraph.merge(*dsks)
    rebuild, args = coll.__dask_postpersist__()
    return rebuild(dsk, *args, rename=rename)
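
block_one is applied once per input collection by wait_on; the user-visible effect
is that no chunk of the outputs can start computing before every chunk of all the
inputs exists. A minimal usage sketch of the public entry point (values are
illustrative):

import dask.array as da
from dask.graph_manipulation import wait_on

x = da.ones(10, chunks=5)
y = da.zeros(10, chunks=5)
# x2 and y2 hold the same values as x and y, but every chunk of both now also
# depends on a shared checkpoint of all chunks of x and y.
x2, y2 = wait_on(x, y)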
Example #4
from operator import add

import pytest

import dask
from dask.base import get_collection_names, is_dask_collection, tokenize

# `Tuple` is a minimal custom collection defined elsewhere in dask's test
# suite; see the sketch after this example.


def test_custom_collection():
    # Arbitrary hashables
    h1 = object()
    h2 = object()

    dsk = {("x", h1): 1, ("x", h2): 2}
    dsk2 = {
        ("y", h1): (add, ("x", h1), ("x", h2)),
        ("y", h2): (add, ("y", h1), 1)
    }
    dsk2.update(dsk)
    dsk3 = {"z": (add, ("y", h1), ("y", h2))}
    dsk3.update(dsk2)

    w = Tuple({}, [])  # A collection can have no keys at all
    x = Tuple(dsk, [("x", h1), ("x", h2)])
    y = Tuple(dsk2, [("y", h1), ("y", h2)])
    z = Tuple(dsk3, ["z"])
    # Collection with multiple names
    t = w + x + y + z

    # __slots__ defined on base mixin class propagates
    with pytest.raises(AttributeError):
        x.foo = 1

    # is_dask_collection
    assert is_dask_collection(w)
    assert is_dask_collection(x)
    assert is_dask_collection(y)
    assert is_dask_collection(z)
    assert is_dask_collection(t)

    # tokenize
    assert tokenize(w) == tokenize(w)
    assert tokenize(x) == tokenize(x)
    assert tokenize(y) == tokenize(y)
    assert tokenize(z) == tokenize(z)
    assert tokenize(t) == tokenize(t)
    # All tokens are unique
    assert len({tokenize(coll) for coll in (w, x, y, z, t)}) == 5

    # get_collection_names
    assert get_collection_names(w) == set()
    assert get_collection_names(x) == {"x"}
    assert get_collection_names(y) == {"y"}
    assert get_collection_names(z) == {"z"}
    assert get_collection_names(t) == {"x", "y", "z"}

    # compute
    assert w.compute() == ()
    assert x.compute() == (1, 2)
    assert y.compute() == (3, 4)
    assert z.compute() == (7,)
    assert dask.compute(w, [{"x": x}, y, z]) == ((), [{"x": (1, 2)}, (3, 4), (7,)])
    assert t.compute() == (1, 2, 3, 4, 7)

    # persist
    t2 = t.persist()
    assert isinstance(t2, Tuple)
    assert t2._keys == t._keys
    assert sorted(t2._dask.values()) == [1, 2, 3, 4, 7]
    assert t2.compute() == (1, 2, 3, 4, 7)

    w2, x2, y2, z2 = dask.persist(w, x, y, z)
    assert y2._keys == y._keys
    assert y2._dask == {("y", h1): 3, ("y", h2): 4}
    assert y2.compute() == (3, 4)

    t3 = x2 + y2 + z2
    assert t3.compute() == (1, 2, 3, 4, 7)

    # __dask_postpersist__ with name change
    rebuild, args = w.__dask_postpersist__()
    w3 = rebuild({}, *args, rename={"w": "w3"})
    assert w3.compute() == ()

    rebuild, args = x.__dask_postpersist__()
    x3 = rebuild({("x3", h1): 10, ("x3", h2): 20}, *args, rename={"x": "x3"})
    assert x3.compute() == (10, 20)

    rebuild, args = z.__dask_postpersist__()
    z3 = rebuild({"z3": 70}, *args, rename={"z": "z3"})
    assert z3.compute() == (70,)
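
The test above relies on Tuple, which is defined elsewhere in dask's test suite.
A minimal sketch of such a custom collection, following dask's custom-collections
protocol (an approximation, not the exact test class):

import dask
import dask.threaded
from dask.base import DaskMethodsMixin, replace_name_in_key


class Tuple(DaskMethodsMixin):
    # Empty __slots__ on DaskMethodsMixin propagates, so instances reject
    # arbitrary attribute assignment (as the test checks).
    __slots__ = ("_dask", "_keys")

    def __init__(self, dsk, keys):
        self._dask = dsk
        self._keys = keys

    def __dask_graph__(self):
        return self._dask

    def __dask_keys__(self):
        return self._keys

    def __dask_tokenize__(self):
        return self._keys

    # Use the threaded scheduler by default
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # Pack the computed values of __dask_keys__() into a plain tuple
        return tuple, ()

    def __dask_postpersist__(self):
        return Tuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        if rename:
            keys = [replace_name_in_key(key, rename) for key in keys]
        return Tuple(dsk, keys)

    def __add__(self, other):
        # Concatenation merges the graphs and chains the keys
        dsk = dict(self._dask)
        dsk.update(other._dask)
        return Tuple(dsk, self._keys + other._keys)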
Example #5
def __dask_layers__(self):
    # One HighLevelGraph layer name per collection name
    return tuple(get_collection_names(self))
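
Whatever a collection returns here must name actual output layers of its
HighLevelGraph. Deriving the layer names from the key names via
get_collection_names, as above, only works for collections whose layers are named
after their keys, which is the common convention for dask's built-in collections.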
Example #6
from __future__ import annotations

from collections.abc import Hashable, Set
from typing import Any, TypeVar

from dask.base import clone_key, get_collection_names, tokenize
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph, Layer

T = TypeVar("T")


def _bind_one(
    child: T,
    blocker: Delayed | None,
    omit_layers: set[str],
    omit_keys: set[Hashable],
    seed: Hashable,
) -> T:
    prev_coll_names = get_collection_names(child)
    if not prev_coll_names:
        # Collection with no keys; this is a legitimate use case but, at the moment of
        # writing, can only happen with third-party collections
        return child

    dsk = child.__dask_graph__()  # type: ignore
    new_layers: dict[str, Layer] = {}
    new_deps: dict[str, Set[Any]] = {}

    if isinstance(dsk, HighLevelGraph):
        try:
            layers_to_clone = set(child.__dask_layers__())  # type: ignore
        except AttributeError:
            layers_to_clone = prev_coll_names.copy()
    else:
        if len(prev_coll_names) == 1:
            hlg_name = next(iter(prev_coll_names))
        else:
            hlg_name = tokenize(*prev_coll_names)
        dsk = HighLevelGraph.from_collections(hlg_name, dsk)
        layers_to_clone = {hlg_name}

    clone_keys = dsk.get_all_external_keys() - omit_keys
    for layer_name in omit_layers:
        try:
            layer = dsk.layers[layer_name]
        except KeyError:
            continue
        clone_keys -= layer.get_output_keys()
    # Note: when assume_layers=True, clone_keys can contain keys of the omit collections
    # that are not top-level. This is OK, as they will never be encountered inside the
    # values of their dependent layers.

    if blocker is not None:
        blocker_key = blocker.key
        blocker_dsk = blocker.__dask_graph__()
        assert isinstance(blocker_dsk, HighLevelGraph)
        new_layers.update(blocker_dsk.layers)
        new_deps.update(blocker_dsk.dependencies)
    else:
        blocker_key = None

    layers_to_copy_verbatim = set()

    while layers_to_clone:
        prev_layer_name = layers_to_clone.pop()
        new_layer_name = clone_key(prev_layer_name, seed=seed)
        if new_layer_name in new_layers:
            continue

        layer = dsk.layers[prev_layer_name]
        layer_deps = dsk.dependencies[prev_layer_name]
        layer_deps_to_clone = layer_deps - omit_layers
        layer_deps_to_omit = layer_deps & omit_layers
        layers_to_clone |= layer_deps_to_clone
        layers_to_copy_verbatim |= layer_deps_to_omit

        new_layers[new_layer_name], is_bound = layer.clone(
            keys=clone_keys, seed=seed, bind_to=blocker_key
        )
        new_dep = {
            clone_key(dep, seed=seed) for dep in layer_deps_to_clone
        } | layer_deps_to_omit
        if is_bound:
            new_dep.add(blocker_key)
        new_deps[new_layer_name] = new_dep

    # Copy over, verbatim, the layers of the omitted collections that appear in the
    # child's graph. Note that, when assume_layers=False, it would be unsafe to simply
    # do HighLevelGraph.merge(dsk, omit[i].dsk). Also, collections in omit may or may
    # not be parents of this specific child, or of any children at all.
    while layers_to_copy_verbatim:
        layer_name = layers_to_copy_verbatim.pop()
        if layer_name in new_layers:
            continue
        layer_deps = dsk.dependencies[layer_name]
        layers_to_copy_verbatim |= layer_deps
        new_deps[layer_name] = layer_deps
        new_layers[layer_name] = dsk.layers[layer_name]

    rebuild, args = child.__dask_postpersist__()  # type: ignore
    return rebuild(
        HighLevelGraph(new_layers, new_deps),
        *args,
        rename={prev_name: clone_key(prev_name, seed) for prev_name in prev_coll_names},
    )
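
_bind_one is the per-collection helper behind dask.graph_manipulation.bind. A
minimal usage sketch of the public entry point (values are illustrative):

import dask.array as da
from dask.graph_manipulation import bind

x = da.ones(10, chunks=5)
parent = x.sum()
child = (x + 1).sum()
# `bound` computes the same result as `child`, but regenerates x's chunks under
# cloned keys and will not start until `parent` has completed.
bound = bind(child, parent)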