示例#1
0
    def test_remove_in_and_out(self):
        """Remove an output and then an input from the same graph.

        Dropping output 1 (``op3_out``) must leave only ``op1_out`` as an
        output and purge ``op3_out`` and its owner node from the client map.
        Dropping input 0 (``var1``) must then also drop ``op1_out``, which is
        computed from ``var1``.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var2, var1)
        op2_out = op2(op1_out, var2)
        op3_out = op3(op2_out, var2, var2)
        fg = FunctionGraph([var1, var2], [op1_out, op3_out], clone=False)

        def node_is_a_client(node):
            # Each entry of a clients list is a ``(node, index)`` pair. Walk
            # the lists directly instead of flattening them with
            # ``sum(lists, [])``, which copies the accumulator at every step.
            return any(
                node in client
                for client_list in fg.clients.values()
                for client in client_list
            )

        # Remove an output
        fg.remove_output(1)
        fg.check_integrity()

        assert fg.outputs == [op1_out]
        assert op3_out not in fg.clients
        assert not node_is_a_client(op3_out.owner)

        # Remove an input
        fg.remove_input(0)
        fg.check_integrity()

        assert var1 not in fg.variables
        assert fg.inputs == [var2]
        assert fg.outputs == []
        assert not node_is_a_client(op1_out.owner)
示例#2
0
    def test_remove_input(self):
        """Removing an input drops every output that depends on it.

        ``var4`` feeds ``out1``. ``out2`` is built from ``out1``, and ``out3``
        is the same variable as ``out1``. So after ``remove_input(4)`` only
        ``out0`` is expected to remain as an output.
        """
        var0, var1, var2, var3, var4 = map(
            MyVariable, ("var0", "var1", "var2", "var3", "var4")
        )

        out0 = op2(op1(var1, var0), var2)

        out1 = op1(var3, var4)
        out1.name = "out1"

        out2 = op1(out1, var0)
        out2.name = "out2"

        # Listed twice on purpose: ``out3`` aliases ``out1``.
        out3 = out1

        fg = FunctionGraph(
            [var0, var1, var2, var3, var4],
            [out0, out1, out2, out3],
            clone=False,
        )
        fg.remove_input(4)
        fg.check_integrity()

        assert fg.inputs == [var0, var1, var2, var3]
        assert fg.outputs == [out0]
示例#3
0
    def test_remove_output_3(self):
        """Remove one output from a graph whose outputs share sub-graphs.

        Two outputs are not new nodes: ``out4`` is the intermediate
        ``op1_out`` and ``out5`` is the raw input ``var3``. Removing ``out1``
        must leave the inputs and every other output unchanged, and must
        drop ``out1`` from the client map.
        """
        var0, var1, var2, var3, var4, var5, var6 = map(
            MyVariable,
            ("var0", "var1", "var2", "var3", "var4", "var5", "var6"),
        )
        op1_out = op1(var1, var0)
        out0 = op2(op1_out, var2)

        out1 = op1(var3, var4)
        out1.name = "out1"

        out2 = op1(op1_out, var5)
        out2.name = "out2"

        out3 = op1(var3, var6)
        out3.name = "out3"

        out4, out5 = op1_out, var3

        fg = FunctionGraph(
            [var0, var1, var2, var3, var4, var5, var6],
            [out0, out1, out2, out3, out4, out5],
            clone=False,
        )
        fg.remove_output(1)
        fg.check_integrity()

        assert fg.inputs == [var0, var1, var2, var3, var4, var5, var6]
        assert fg.outputs == [out0, out2, out3, out4, out5]
        assert out1 not in fg.clients
示例#4
0
    def test_empty(self):
        """A graph with inputs but no outputs tracks no variables or nodes."""
        var1, var2 = MyVariable("var1"), MyVariable("var2")

        fg = FunctionGraph([var1, var2], [], clone=False)
        fg.check_integrity()

        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
        assert not fg.variables
        assert not fg.apply_nodes
        # The inputs still get client-map entries, each with an empty list.
        assert fg.clients == {var1: [], var2: []}
示例#5
0
    def test_init(self):
        """Check the state of freshly constructed ``FunctionGraph`` objects.

        Covers three cases: explicit inputs, inputs omitted (recovered from
        the outputs), and cloning with a caller-supplied ``memo`` dict.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var1)
        var4 = op2(var3, var2)

        fg = FunctionGraph([var1, var2], [var3, var4], clone=False)
        node3, node4 = var3.owner, var4.owner

        assert fg.inputs == [var1, var2]
        assert fg.outputs == [var3, var4]
        assert fg.apply_nodes == {node3, node4}
        assert fg.update_mapping is None
        assert fg.check_integrity() is None
        assert fg.variables == {var1, var2, var3, var4}

        # Clients are ``(node, input_index)`` pairs; graph outputs appear as
        # ``("output", output_index)``.
        expected_clients = (
            (var1, [(node3, 0)]),
            (var2, [(node4, 1)]),
            (var3, [(node4, 0), ("output", 0)]),
            (var4, [("output", 1)]),
        )
        for variable, clients in expected_clients:
            assert fg.get_clients(variable) == clients

        # When the inputs are omitted, they are recovered from the outputs.
        fg = FunctionGraph(outputs=[var3, var4], clone=False)
        assert fg.inputs == [var1, var2]

        # With ``clone=True`` the supplied ``memo`` ends up keyed by the
        # original variables; the input entries keep their type and name.
        memo = {}
        fg = FunctionGraph(outputs=[var3, var4], clone=True, memo=memo)

        for original in (var1, var2):
            assert memo[original].type == original.type
            assert memo[original].name == original.name
        assert var3 in memo
        assert var4 in memo
示例#6
0
    def test_remove_output_empty(self):
        """Removing the only output leaves the graph with no apply nodes.

        The inputs stay in place. ``op1_out`` must disappear from the client
        map, and so must the owner nodes of ``op1_out`` and ``op3_out``.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var1)
        op3_out = op3(op1_out, var2)
        fg = FunctionGraph([var1, var2], [op3_out], clone=False)

        def node_is_a_client(node):
            # Each entry of a clients list is a ``(node, index)`` pair. Walk
            # the lists directly instead of flattening them with
            # ``sum(lists, [])``, which copies the accumulator at every step.
            return any(
                node in client
                for client_list in fg.clients.values()
                for client in client_list
            )

        fg.remove_output(0)
        fg.check_integrity()

        assert fg.inputs == [var1, var2]
        assert not fg.apply_nodes
        assert op1_out not in fg.clients
        assert not node_is_a_client(op1_out.owner)
        assert not node_is_a_client(op3_out.owner)
示例#7
0
    def test_remove_duplicates(self):
        """Remove entries from a graph whose inputs and outputs repeat."""
        var1, var2 = MyVariable("var1"), MyVariable("var2")
        op1_out = op1(var2, var1)
        op3_out = op3(op2(op1_out, var2), var2, var2)
        fg = FunctionGraph(
            [var1, var1, var2], [op1_out, op3_out, op3_out], clone=False
        )

        # Dropping the second copy of ``op3_out`` keeps the first one.
        fg.remove_output(2)
        fg.check_integrity()
        assert fg.outputs == [op1_out, op3_out]

        # Dropping the first copy of ``var1``: both remaining outputs are
        # computed from ``var1``, so they go away. The second copy of
        # ``var1`` is still listed as an input, even though ``var1`` is no
        # longer in ``fg.variables``.
        fg.remove_input(0)
        fg.check_integrity()
        assert var1 not in fg.variables
        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
示例#8
0
    def test_remove_node(self):
        """``remove_node`` also removes the nodes that depend on the removed one."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        def graph_without_owner_of(outputs, doomed):
            # Build a fresh graph over ``outputs``, remove ``doomed``'s owner
            # node and validate the result.
            graph = FunctionGraph([var1, var2], outputs, clone=False)
            graph.remove_node(doomed.owner)
            graph.check_integrity()
            return graph

        # Removing the node of the only output leaves no apply nodes.
        fg = graph_without_owner_of([node3_out], node3_out)
        assert not fg.apply_nodes

        # ``node2_out`` is still an output, so its node and ``node1_out``'s
        # node survive.
        fg = graph_without_owner_of([node2_out, node3_out], node3_out)
        assert fg.apply_nodes == {node1_out.owner, node2_out.owner}

        # Removing the node of ``node2_out`` leaves no apply nodes at all.
        fg = graph_without_owner_of([node2_out, node3_out], node2_out)
        assert not fg.apply_nodes
示例#9
0
    def test_remove_output(self):
        """``remove_output`` keeps nodes still needed by the remaining outputs."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        def graph_without_output(outputs, index):
            # Build a fresh graph over ``outputs``, drop output ``index`` and
            # validate the result.
            graph = FunctionGraph([var1, var2], outputs, clone=False)
            graph.remove_output(index)
            graph.check_integrity()
            return graph

        # ``node3_out`` is computed from ``node2_out``: every node is kept.
        fg = graph_without_output([node2_out, node3_out], 0)
        assert fg.apply_nodes == {
            node1_out.owner, node2_out.owner, node3_out.owner
        }
        assert fg.inputs == [var1, var2]
        assert fg.outputs == [node3_out]

        # Dropping ``node3_out`` removes its node and keeps the rest.
        fg = graph_without_output([node2_out, node3_out], 1)
        assert fg.apply_nodes == {node1_out.owner, node2_out.owner}
        assert fg.inputs == [var1, var2]
        assert fg.outputs == [node2_out]

        # Dropping an output that is also a graph input removes no nodes.
        fg = graph_without_output([node2_out, node3_out, var1], 2)
        assert fg.apply_nodes == {
            node1_out.owner, node2_out.owner, node3_out.owner
        }
        assert fg.inputs == [var1, var2]
        assert fg.outputs == [node2_out, node3_out]

        # A graph whose only output is an input ends up with no outputs.
        fg = graph_without_output([var1], 0)
        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
示例#10
0
 def test_init(self):
     """Check the state of a ``FunctionGraph`` built from explicit inputs."""
     var1, var2 = MyVariable("var1"), MyVariable("var2")
     var3 = op1(var1)
     var4 = op2(var3, var2)

     fg = FunctionGraph([var1, var2], [var3, var4], clone=False)
     node3, node4 = var3.owner, var4.owner

     assert fg.inputs == [var1, var2]
     assert fg.outputs == [var3, var4]
     assert fg.apply_nodes == {node3, node4}
     assert fg.update_mapping is None
     assert fg.check_integrity() is None
     assert fg.variables == {var1, var2, var3, var4}

     # Clients are ``(node, input_index)`` pairs; graph outputs appear as
     # ``("output", output_index)``.
     expected_clients = (
         (var1, [(node3, 0)]),
         (var2, [(node4, 1)]),
         (var3, [(node4, 0), ("output", 0)]),
         (var4, [("output", 1)]),
     )
     for variable, clients in expected_clients:
         assert fg.get_clients(variable) == clients
示例#11
0
    def test_remove_node_multi_out(self):
        """Remove nodes from graphs that contain a two-output ``MyOp``."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        multi_op = MyOp("mop", n_outs=2)

        # Case 1: remove the multi-output node itself.
        op1_out = op1(var1)
        mop_out_1, mop_out_2 = multi_op(op1_out, var2)
        op3_out = op3(mop_out_2)
        fg = FunctionGraph([var1, var2], [mop_out_1, op3_out], clone=False)

        fg.remove_node(mop_out_1.owner)
        fg.check_integrity()

        assert fg.inputs == [var1, var2]
        # Both outputs came from the removed node, so no outputs remain.
        assert fg.outputs == []
        for gone in (mop_out_1, mop_out_2):
            assert gone not in fg.clients
            assert gone not in fg.variables

        # Case 2: remove a consumer of both outputs of a multi-output node.
        mop1_out_1, mop1_out_2 = multi_op(var1)
        op2_out = op2(mop1_out_1)
        op3_out = op3(mop1_out_1, mop1_out_2)
        fg = FunctionGraph([var1], [op2_out, op3_out], clone=False)

        fg.remove_node(op3_out.owner)
        fg.check_integrity()

        assert fg.inputs == [var1]
        assert fg.outputs == [op2_out]
        # ``mop1_out_2`` is now unused, but it is deliberately not asserted
        # to be absent from ``fg.clients`` / ``fg.variables``. That would only
        # hold if the graph tracked just its "active" variables.

        # Case 3: the same removal, but ``mop1_out_2`` is itself an output.
        fg = FunctionGraph([var1], [op2_out, op3_out, mop1_out_2], clone=False)

        fg.remove_node(op3_out.owner)
        fg.check_integrity()

        assert fg.inputs == [var1]
        assert fg.outputs == [op2_out, mop1_out_2]
        assert mop1_out_2 in fg.clients
        assert mop1_out_2 in fg.variables
        assert mop1_out_2 in fg.outputs
示例#12
0
    def test_check_integrity(self):
        """``check_integrity`` reports each kind of manually introduced corruption.

        Each scenario corrupts the graph *before* entering ``pytest.raises``,
        so only ``check_integrity`` itself can produce the expected error.
        It then restores the graph *after* the block, so scenarios do not
        leak state into one another.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # An apply node used by an output is missing from ``apply_nodes``.
        fg.apply_nodes.remove(var5.owner)
        with pytest.raises(Exception, match="The nodes are .*"):
            fg.check_integrity()
        fg.apply_nodes.add(var5.owner)

        # ``var2`` lost one of its client entries.
        fg.remove_client(var2, (var5.owner, 1))
        with pytest.raises(Exception, match="Inconsistent clients.*"):
            fg.check_integrity()
        fg.add_client(var2, (var5.owner, 1))

        # A variable used by the graph is missing from ``variables``.
        fg.variables.remove(var4)
        with pytest.raises(Exception, match="The variables are.*"):
            fg.check_integrity()
        fg.variables.add(var4)

        # A node reads ``var6``, which is not one of the graph's inputs.
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)
        with pytest.raises(Exception, match="Undeclared input.*"):
            fg.check_integrity()
        fg.variables.remove(var6)
        var5.owner.inputs.remove(var6)
        del fg.clients[var6]

        # An "output" client entry points at the wrong output position.
        # TODO: What if the index value is greater than 1?  It will throw an
        # `IndexError`, but that doesn't sound like anything we'd want.
        fg.add_client(var4, ("output", 1))
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()
        fg.remove_client(var4, ("output", 1))

        # A client entry refers to a node that is not part of the graph.
        fg.add_client(var4, (var6.owner, 0))
        with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
            fg.check_integrity()
        fg.remove_client(var4, (var6.owner, 0))

        # A client entry names a node whose input at that index is not var4.
        fg.add_client(var4, (var3.owner, 0))
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()