def test_as_dask_graph_multiple_links(self, graph, sum_op, square_op,
                                          negative_op):
        """Check the dask graph when one output feeds several operations.

        sum_op's "sum" output is linked both to square_op's "n" input and to
        my_op's "x" input, square_op's "square" output is linked to my_op's
        "y" input, and negative_op is added without any links.
        """
        # NOTE(review): the tuple return annotation is kept as-is; it may be
        # inspected by `operation` -- confirm before changing it.
        def my_func(x: int, y: int) -> (int, int):
            return y, x

        # Connect sum_op to square_op.
        # Connect sum_op to my_op's x, square_op to my_op's y.
        # Leave negative_op unconnected.
        my_op = operation(my_func, output_names=("y", "x"))()
        graph.add_operations(sum_op, square_op, negative_op, my_op)
        graph.add_link(sum_op, square_op, "sum", "n")
        graph.add_link(sum_op, my_op, "sum", "x")
        graph.add_link(square_op, my_op, "square", "y")
        dask_graph, end_ids = graph.as_dask_graph()

        # Should look like:
        # { "0": (<sum_op>,), "1": (<square_op>, "0"), "2": (<negative_op>,),
        #   "3": (<my_op>, "0", "1") }
        sum_wrapper = dask_graph["0"]
        square_wrapper = dask_graph["1"]
        negative_wrapper = dask_graph["2"]
        my_wrapper = dask_graph["3"]

        # sum_op has no dependencies (no ops connect into it)
        assert len(sum_wrapper) == 1
        assert sum_wrapper[0].node is sum_op

        # square_op has 1 dependency: it takes sum_op's output
        assert len(square_wrapper) == 2
        assert square_wrapper[0].node is square_op
        assert square_wrapper[1] == "0"  # sum_op

        # negative_op has no dependencies; it is unconnected
        assert len(negative_wrapper) == 1
        assert negative_wrapper[0].node is negative_op

        # my_op has two dependencies; sum_op and square_op connect to its inputs
        assert len(my_wrapper) == 3
        assert my_wrapper[0].node is my_op
        assert my_wrapper[1] == "0"  # sum_op
        assert my_wrapper[2] == "1"  # square_op

        # negative_op and my_op should be end nodes
        assert sorted(end_ids) == sorted(["2", "3"])
    def test_as_dask_graph(self, graph, sum_op, square_op, negative_op):
        """Check the dask graph for one link plus one unconnected operation.

        Expected task mapping:
        { "0": (<sum_op>,), "1": (<square_op>, "0"), "2": (<negative_op>,) }
        """
        # Feed sum_op's "sum" output into square_op's "n" input; negative_op
        # is added to the graph but deliberately left without links.
        graph.add_operations(sum_op, square_op, negative_op)
        graph.add_link(sum_op, square_op, "sum", "n")
        tasks, terminal_ids = graph.as_dask_graph()

        sum_task, square_task, negative_task = tasks["0"], tasks["1"], tasks["2"]

        # sum_op: just the wrapper, no upstream keys.
        assert len(sum_task) == 1
        assert sum_task[0].node is sum_op

        # square_op: wrapper followed by the key of sum_op's task.
        assert len(square_task) == 2
        assert square_task[0].node is square_op
        assert square_task[1] == "0"

        # negative_op: just the wrapper, since nothing links into it.
        assert len(negative_task) == 1
        assert negative_task[0].node is negative_op

        # Nothing consumes square_op or negative_op, so both are end nodes.
        assert sorted(terminal_ids) == ["1", "2"]
 def test_as_dask_graph_no_links(self, graph, sum_op):
     """Check the dask graph of a graph holding a single, unlinked operation."""
     graph.add_operation(sum_op)
     tasks, terminal_ids = graph.as_dask_graph()

     # The only task is keyed "0" and carries nothing but the wrapper.
     only_task = tasks["0"]
     assert len(only_task) == 1
     assert only_task[0].node is sum_op

     # With no links, the single operation is also the single end node.
     assert terminal_ids == ["0"]
 def test_as_dask_graph_empty(self, graph):
     """Check that a graph with no operations yields no tasks and no end nodes."""
     result = graph.as_dask_graph()
     expected = ({}, [])
     assert result == expected