def test_SubgraphCallable():
    """Exercise construction, equality, hashing, calling and pickling of
    ``SubgraphCallable``."""
    # A plain list inside a task; the graph must still work with it present.
    non_hashable = [1, 2, 3]

    dsk = {
        "a": (apply, add, ["in1", 2]),
        "b": (
            apply,
            partial_by_order,
            ["in2"],
            {
                "function": func_with_kwargs,
                "other": [(1, 20)],
                "c": 4,
            },
        ),
        "c": (
            apply,
            partial_by_order,
            ["in2", "in1"],
            {
                "function": func_with_kwargs,
                "other": [(1, 20)],
            },
        ),
        "d": (inc, "a"),
        "e": (add, "c", "d"),
        "f": ["a", 2, "b", (add, "b", (sum, non_hashable))],
        "h": (add, (sum, "f"), (sum, ["a", "b"])),
    }

    f = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f.name == "test"
    assert repr(f) == "test"

    f2 = SubgraphCallable(dsk, "h", ["in1", "in2"], name="test")
    assert f == f2
    # Equal objects must hash equally.
    assert hash(f) == hash(f2)

    f3 = SubgraphCallable(dsk, "g", ["in1", "in2"], name="test")
    assert f != f3

    # The callable must be usable as a dict key. ``dict(f=None)`` would build
    # ``{"f": None}`` with a string key and never hash ``f`` at all.
    assert {f: None}
    # Hashing must succeed even with None graph/outkey. The hash value's
    # truthiness is irrelevant (0 is a legal hash), so only check its type.
    assert isinstance(hash(SubgraphCallable(None, None, [None])), int)
    assert hash(f3) != hash(f2)

    # Calling the callable must match executing the culled graph with the
    # inputs bound as literal keys.
    dsk2 = dsk.copy()
    dsk2.update({"in1": 1, "in2": 2})
    assert f(1, 2) == get(cull(dsk2, ["h"])[0], ["h"])[0]
    assert f(1, 2) == f(1, 2)

    # The callable must survive a pickle round-trip.
    f_roundtrip = pickle.loads(pickle.dumps(f))
    assert f_roundtrip(1, 2) == f(1, 2)
def test_fuse_subgraphs_linear_chains_of_duplicate_deps():
    """A chain where every task reads its predecessor twice fuses into a
    single ``SubgraphCallable``, with ``add-5`` left as an alias to it."""

    def chain():
        # Fresh copy of:
        # {"x-1": 1, "add-1": (add, "x-1", "x-1"), ..., "add-5": (add, "add-4", "add-4")}
        graph = {"x-1": 1}
        prev = "x-1"
        for i in range(1, 6):
            key = f"add-{i}"
            graph[key] = (add, prev, prev)
            prev = key
        return graph

    res = fuse(chain(), "add-5", fuse_subgraphs=True)
    expected = with_deps(
        {
            "add-x-1": (SubgraphCallable(chain(), "add-5", ()),),
            "add-5": "add-x-1",
        }
    )
    assert res == expected
def test_SubgraphCallable_with_numpy():
    """SubgraphCallable equality must work when the graph holds numpy arrays,
    whose ``==`` is elementwise."""
    np = pytest.importorskip("numpy")

    graph = {"a": np.arange(10)}
    base = SubgraphCallable(graph, "a", [None], name="test")
    twin = SubgraphCallable(graph, "a", [None], name="test")
    assert base == twin

    # Graphs holding different arrays still compare equal: equality looks at
    # the name, output key and input keys only, not at the graph contents.
    shifted = {"a": np.arange(10) + 1}
    from_shifted = SubgraphCallable(shifted, "a", [None], name="test")
    assert base == from_shifted

    # A different name alone makes them unequal.
    renamed = SubgraphCallable(graph, "a", [None], name="test2")
    assert base != renamed
def test_fuse_subgraphs():
    """Check the exact output of ``fuse(..., fuse_subgraphs=True)``.

    The input is the chain ``x-1 -> inc-1 -> inc-2 -> add-1 -> inc-3 -> inc-4
    -> add-2 -> inc-5 -> inc-6``, where ``add-1`` also reads ``x-1`` and
    ``add-2`` also reads ``add-1``. Each scenario protects different keys and
    compares the result against the expected fused graph.
    """
    dsk = {
        "x-1": 1,
        "inc-1": (inc, "x-1"),
        "inc-2": (inc, "inc-1"),
        "add-1": (add, "x-1", "inc-2"),
        "inc-3": (inc, "add-1"),
        "inc-4": (inc, "inc-3"),
        "add-2": (add, "add-1", "inc-4"),
        "inc-5": (inc, "add-2"),
        "inc-6": (inc, "inc-5"),
    }

    def whole_graph_subgraph():
        # Entire graph as one subgraph producing inc-6, with no inputs.
        return SubgraphCallable(
            {
                "x-1": 1,
                "add-1": (add, "x-1", (inc, (inc, "x-1"))),
                "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))),
            },
            "inc-6",
            (),
        )

    # Protect inc-6 only: everything fuses under a renamed key.
    res = fuse(dsk, "inc-6", fuse_subgraphs=True)
    expected = with_deps(
        {
            "inc-6": "add-inc-x-1",
            "add-inc-x-1": (whole_graph_subgraph(),),
        }
    )
    assert res == expected

    # Same, without key renaming: the subgraph lives directly under inc-6.
    res = fuse(dsk, "inc-6", fuse_subgraphs=True, rename_keys=False)
    expected = with_deps({"inc-6": (whole_graph_subgraph(),)})
    assert res == expected

    # Protect add-2: the head of the graph fuses; the tail stays linear.
    res = fuse(dsk, "add-2", fuse_subgraphs=True)
    head = SubgraphCallable(
        {
            "x-1": 1,
            "add-1": (add, "x-1", (inc, (inc, "x-1"))),
            "add-2": (add, "add-1", (inc, (inc, "add-1"))),
        },
        "add-2",
        (),
    )
    expected = with_deps(
        {
            "add-inc-x-1": (head,),
            "add-2": "add-inc-x-1",
            "inc-6": (inc, (inc, "add-2")),
        }
    )
    assert res == expected

    # Argument order of the fused subgraph is unstable, so the remaining
    # scenarios accept any permutation of its input keys.
    orderings = list(itertools.permutations(("x-1", "inc-2")))

    # Protect inc-2: the tail from add-1 onward fuses, fed by x-1 and inc-2.
    res = fuse(dsk, "inc-2", fuse_subgraphs=True)
    candidates = [
        with_deps(
            {
                "x-1": 1,
                "inc-2": (inc, (inc, "x-1")),
                "inc-6": "inc-add-1",
                "inc-add-1": (
                    SubgraphCallable(
                        {
                            "add-1": (add, "x-1", "inc-2"),
                            "inc-6": (
                                inc,
                                (inc, (add, "add-1", (inc, (inc, "add-1")))),
                            ),
                        },
                        "inc-6",
                        inkeys,
                    ),
                )
                + inkeys,
            }
        )
        for inkeys in orderings
    ]
    assert res in candidates

    # Protect inc-2 and add-2: only the middle section fuses.
    res = fuse(dsk, ["inc-2", "add-2"], fuse_subgraphs=True)
    candidates = [
        with_deps(
            {
                "x-1": 1,
                "inc-2": (inc, (inc, "x-1")),
                "inc-add-1": (
                    SubgraphCallable(
                        {
                            "add-1": (add, "x-1", "inc-2"),
                            "add-2": (add, "add-1", (inc, (inc, "add-1"))),
                        },
                        "add-2",
                        inkeys,
                    ),
                )
                + inkeys,
                "add-2": "inc-add-1",
                "inc-6": (inc, (inc, "add-2")),
            }
        )
        for inkeys in orderings
    ]
    assert res in candidates