def test_bind(layers):
    """Exercise ``bind`` on a small ``a-1 -> b-1 -> c-1`` chain of tasks.

    ``layers`` selects the graph flavour under test: when truthy, the graphs
    are wrapped in ``HighLevelGraph`` objects and ``bind`` is called with
    ``assume_layers=True``; otherwise the plain dicts are merged together and
    ``assume_layers=False`` is used.

    The parent collection runs its three tasks through ``cnt.f`` so that
    ``cnt.n`` can be inspected after each ``compute()`` of a bound child.
    """
    dsk1 = {("a-1", h1): 1, ("a-1", h2): 2}
    dsk2 = {"b-1": (add, ("a-1", h1), ("a-1", h2))}
    dsk3 = {"c-1": "b-1"}
    cnt = NodeCounter()
    dsk4 = {("d-1", h1): (cnt.f, 1), ("d-1", h2): (cnt.f, 2)}
    dsk4b = {"e": (cnt.f, 3)}

    if layers:
        dsk1 = HighLevelGraph.from_collections("a-1", dsk1)
        dsk2 = HighLevelGraph(
            {"a-1": dsk1, "b-1": dsk2},
            dependencies={"a-1": set(), "b-1": {"a-1"}},
        )
        dsk3 = HighLevelGraph(
            {"a-1": dsk1, "b-1": dsk2, "c-1": dsk3},
            dependencies={"a-1": set(), "b-1": {"a-1"}, "c-1": {"b-1"}},
        )
        # Every layer gets an entry in ``dependencies``, even when it depends
        # on nothing, so that the keys of both mappings agree (same convention
        # as for dsk2 and dsk3 above).
        dsk4 = HighLevelGraph(
            {"d-1": dsk4, "e": dsk4b},
            dependencies={"d-1": set(), "e": set()},
        )
    else:
        # Flat graphs: each dict must also carry the tasks it depends on
        dsk2.update(dsk1)
        dsk3.update(dsk2)
        dsk4.update(dsk4b)

    t2 = Tuple(dsk2, ["b-1"])
    t3 = Tuple(dsk3, ["c-1"])
    t4 = Tuple(dsk4, [("d-1", h1), ("d-1", h2), "e"])  # Multiple names

    # Without omit, the cloned "a-1" tasks are the ones wrapped in chunks.bind
    bound1 = bind(t3, t4, seed=1, assume_layers=layers)
    cloned_a_name = clone_key("a-1", seed=1)
    assert bound1.__dask_graph__()[cloned_a_name, h1][0] is chunks.bind
    assert bound1.__dask_graph__()[cloned_a_name, h2][0] is chunks.bind
    assert bound1.compute() == (3,)
    assert cnt.n == 3

    # With omit=t2, the cloned "c-1" task is the one wrapped in chunks.bind
    bound2 = bind(t3, t4, omit=t2, seed=1, assume_layers=layers)
    cloned_c_name = clone_key("c-1", seed=1)
    assert bound2.__dask_graph__()[cloned_c_name][0] is chunks.bind
    assert bound2.compute() == (3,)
    # Computing the second bound collection raised the counter by another 3
    assert cnt.n == 6
def test_split_every(split_every, nkeys):
    """Check graph sizes and results of checkpoint/wait_on/bind for ``split_every``.

    ``nkeys`` is the number of keys expected in the graph produced by
    ``checkpoint`` over a 100-key collection. The ``wait_on`` and ``bind``
    graphs are asserted to hold ``nkeys + 100`` and ``nkeys + 2`` keys
    respectively.
    """
    n_inputs = 100
    wide_graph = {("a", idx): idx for idx in range(n_inputs)}
    wide = Tuple(wide_graph, list(wide_graph))

    # checkpoint: the result carries no data
    ckpt = checkpoint(wide, split_every=split_every)
    ckpt_graph = ckpt.__dask_graph__()
    assert len(ckpt_graph) == nkeys
    assert ckpt.compute(scheduler="sync") is None

    # wait_on: same values as the input collection
    waited = wait_on(wide, split_every=split_every)
    waited_graph = waited.__dask_graph__()
    assert len(waited_graph) == nkeys + n_inputs
    assert waited.compute(scheduler="sync") == tuple(range(n_inputs))

    # bind: a two-key child bound to the wide collection as its parent
    small_graph = {"b": 1, "c": 2}
    small = Tuple(small_graph, list(small_graph))
    bound = bind(small, wide, split_every=split_every, assume_layers=False)
    bound_graph = bound.__dask_graph__()
    assert len(bound_graph) == nkeys + 2
    assert bound.compute(scheduler="sync") == (1, 2)
def test_bind_clone_collections(func):
    """Run ``bind`` or ``clone`` over delayed, array, bag and dataframe collections.

    ``func`` selects the function under test. When it is ``bind``, the children
    are bound to a 4-block array whose blocks go through ``cnt.f``, and
    ``cnt.n`` is checked after every computation; when it is ``clone`` the
    counter checks are skipped.
    """

    @delayed
    def double(x):
        return x * 2

    # dask.delayed
    d1 = double(2)
    d2 = double(d1)
    # dask.array
    a1 = da.ones((10, 10), chunks=5)
    a2 = a1 + 1
    a3 = a2.T
    # dask.bag
    b1 = db.from_sequence([1, 2], npartitions=2)
    # b1's tasks are not callable, so we need an extra step to properly test bind
    b2 = b1.map(lambda x: x * 2)
    b3 = b2.map(lambda x: x + 1)
    b4 = b3.min()
    # dask.dataframe
    df = pd.DataFrame({"x": list(range(10))})
    ddf1 = dd.from_pandas(df, npartitions=2)
    # ddf1's tasks are not callable, so we need an extra step to properly test bind
    ddf2 = ddf1.map_partitions(lambda x: x * 2)
    ddf3 = ddf2.map_partitions(lambda x: x + 1)
    ddf4 = ddf3["x"]  # dd.Series
    ddf5 = ddf4.min()  # dd.Scalar

    # Collections to bind/clone, and the upstream ones to leave untouched
    children = (d2, a3, b3, b4, ddf3, ddf4, ddf5)
    omitted = (d1, a1, b2, ddf2)

    cnt = NodeCounter()
    if func is bind:
        parent = da.ones((10, 10), chunks=5).map_blocks(cnt.f)
        # Start counting from zero once the parent has been built
        cnt.n = 0
        clones = bind(children=children, parents=parent, omit=omitted, seed=0)
    else:
        clones = clone(*children, omit=omitted, seed=0)
    d2c, a3c, b3c, b4c, ddf3c, ddf4c, ddf5c = clones

    for cloned, original in (
        (d2c, d2),
        (a3c, a3),
        (b3c, b3),
        (b4c, b4),
        (ddf3c, ddf3),
        (ddf4c, ddf4),
        (ddf5c, ddf5),
    ):
        assert_did_not_materialize(cloned, original)

    for cloned, original, kept in (
        (d2c, d2, d1),
        (a3c, a3, a1),
        (b3c, b3, b2),
        (ddf3c, ddf3, ddf2),
        (ddf4c, ddf4, ddf2),
        (ddf5c, ddf5, ddf2),
    ):
        assert_no_common_keys(cloned, original, omit=kept, layers=True)

    def check_parent_runs(expected):
        # The counter is only checked for bind
        assert cnt.n == expected or func is clone

    assert d2.compute() == d2c.compute()
    check_parent_runs(4)
    da.utils.assert_eq(a3c, a3)
    check_parent_runs(8)
    db.utils.assert_eq(b3c, b3)
    check_parent_runs(12)
    db.utils.assert_eq(b4c, b4)
    check_parent_runs(16)
    dd.utils.assert_eq(ddf3c, ddf3)
    check_parent_runs(24)  # dd.utils.assert_eq calls compute() twice
    dd.utils.assert_eq(ddf4c, ddf4)
    check_parent_runs(32)  # dd.utils.assert_eq calls compute() twice
    dd.utils.assert_eq(ddf5c, ddf5)
    check_parent_runs(36)