def test_remove_node(self):
    """Check ``FunctionGraph.remove_node`` on a simple three-node chain.

    The chain is ``op1(var1) -> op2(var2, .) -> op3(.)``. Three scenarios are
    exercised, each on a newly constructed graph over the same variables
    (``clone=False``):

    * only the ``op3`` output is tracked and its node is removed;
    * both the ``op2`` and ``op3`` outputs are tracked and the ``op3`` node
      is removed;
    * both outputs are tracked and the ``op2`` node is removed.
    """
    x = MyVariable("var1")
    y = MyVariable("var2")
    head = op1(x)
    middle = op2(y, head)
    tail = op3(middle)

    def remove_and_check(outputs, target):
        # Build a graph over the shared variables, drop ``target``'s owner
        # node, and make sure the graph is still internally consistent.
        graph = FunctionGraph([x, y], outputs, clone=False)
        graph.remove_node(target.owner)
        graph.check_integrity()
        return graph

    # Removing the node behind the sole output leaves no apply nodes.
    assert not remove_and_check([tail], tail).apply_nodes

    # ``middle`` is still an output, so its node and ``head``'s node remain.
    assert remove_and_check([middle, tail], tail).apply_nodes == {
        head.owner,
        middle.owner,
    }

    # Removing the middle node leaves no apply nodes at all.
    assert not remove_and_check([middle, tail], middle).apply_nodes
def test_remove_node_multi_out(self):
    """Check ``FunctionGraph.remove_node`` on graphs with a two-output node.

    First scenario: the multi-output node itself is removed, and both of its
    outputs must disappear from ``clients`` and ``variables``, leaving no
    graph outputs. Second and third scenarios: a consumer of both
    multi-output results is removed, with and without the second result
    also being a graph output.
    """
    x = MyVariable("var1")
    y = MyVariable("var2")
    two_out = MyOp("mop", n_outs=2)

    # --- Scenario 1: remove the multi-output node directly. ---
    pre = op1(x)
    first_res, second_res = two_out(pre, y)
    downstream = op3(second_res)

    graph = FunctionGraph([x, y], [first_res, downstream], clone=False)
    graph.remove_node(first_res.owner)
    graph.check_integrity()

    assert graph.inputs == [x, y]
    assert graph.outputs == []
    for tracked in (graph.clients, graph.variables):
        assert first_res not in tracked
        assert second_res not in tracked

    # --- Scenario 2: remove a node that consumes both multi-outputs. ---
    res_a, res_b = two_out(x)
    keep = op2(res_a)
    drop = op3(res_a, res_b)

    graph = FunctionGraph([x], [keep, drop], clone=False)
    graph.remove_node(drop.owner)
    graph.check_integrity()

    assert graph.inputs == [x]
    assert graph.outputs == [keep]
    # NOTE: ``res_b`` is intentionally not checked against ``graph.clients``
    # or ``graph.variables`` here. If the graph were meant to track only
    # "active" variables, ``res_b`` would also have to be absent from both.

    # --- Scenario 3: same removal, but ``res_b`` is itself an output. ---
    graph = FunctionGraph([x], [keep, drop, res_b], clone=False)
    graph.remove_node(drop.owner)
    graph.check_integrity()

    assert graph.inputs == [x]
    assert graph.outputs == [keep, res_b]
    assert res_b in graph.clients
    assert res_b in graph.variables
    assert res_b in graph.outputs