Esempio n. 1
0
    def test_remove_node(self):
        """Check that `FunctionGraph.remove_node` removes dependents and orphans.

        The graph under test is::

            node1_out = op1(var1)
            node2_out = op2(var2, node1_out)
            node3_out = op3(node2_out)

        Each scenario builds a fresh `FunctionGraph` over the same (uncloned)
        variables, removes one `Apply` node, runs `check_integrity`, and then
        checks the surviving nodes together with the graph's inputs/outputs
        and variable bookkeeping.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        # Scenario 1: the removed node produces the only graph output, so the
        # test expects no `Apply` nodes to remain afterwards.
        fg = FunctionGraph([var1, var2], [node3_out], clone=False)

        fg.remove_node(node3_out.owner)
        fg.check_integrity()

        assert not fg.apply_nodes
        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
        assert node3_out not in fg.clients
        assert node3_out not in fg.variables

        # Scenario 2: `node2_out` is itself a graph output, so `op1` and `op2`
        # are expected to survive the removal of `op3`.
        fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)

        fg.remove_node(node3_out.owner)
        fg.check_integrity()

        assert fg.apply_nodes == {node1_out.owner, node2_out.owner}
        assert fg.inputs == [var1, var2]
        assert fg.outputs == [node2_out]
        assert node3_out not in fg.variables

        # Scenario 3: removing `op2` is expected to take its dependent `op3`
        # with it and leave no `Apply` nodes at all.
        fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)

        fg.remove_node(node2_out.owner)
        fg.check_integrity()

        assert not fg.apply_nodes
        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
        assert node2_out not in fg.variables
Esempio n. 2
0
    def test_remove_node_multi_out(self):
        """Check `FunctionGraph.remove_node` on graphs with a two-output `Apply`.

        Three scenarios are exercised:

        1. Removing the two-output node drops both of its outputs from the
           graph's clients/variables, along with the graph outputs built on them.
        2. Removing a node that consumes both outputs leaves only the other
           graph output in place.
        3. Same as (2), except that the second output of the two-output node
           is also a graph output, so it must stay tracked.
        """

        def drop(graph, apply_node):
            # Remove the node, then verify the graph's internal bookkeeping.
            graph.remove_node(apply_node)
            graph.check_integrity()

        x = MyVariable("var1")
        y = MyVariable("var2")
        two_out_op = MyOp("mop", n_outs=2)

        # Scenario 1: mop(op1(x), y) -> (first, second); op3(second)
        pre = op1(x)
        first, second = two_out_op(pre, y)
        tail = op3(second)

        graph = FunctionGraph([x, y], [first, tail], clone=False)
        drop(graph, first.owner)

        assert graph.inputs == [x, y]
        assert graph.outputs == []
        assert all(gone not in graph.clients for gone in (first, second))
        assert all(gone not in graph.variables for gone in (first, second))

        # Scenario 2: mop(x) -> (a, b); op2(a) and op3(a, b) are the outputs
        a, b = two_out_op(x)
        left = op2(a)
        right = op3(a, b)

        graph = FunctionGraph([x], [left, right], clone=False)
        drop(graph, right.owner)

        assert graph.inputs == [x]
        assert graph.outputs == [left]
        # Deliberately NOT asserted: that `b` has left `graph.clients` and
        # `graph.variables`. Those checks would only be required if the graph
        # tracked exclusively "active" variables.

        # Scenario 3: as above, but `b` is a graph output in its own right.
        graph = FunctionGraph([x], [left, right, b], clone=False)
        drop(graph, right.owner)

        assert graph.inputs == [x]
        assert graph.outputs == [left, b]
        assert b in graph.clients
        assert b in graph.variables
        assert b in graph.outputs