def test_remove_in_and_out(self):
    """Removing an output and then an input detaches every dependent node."""
    in1 = MyVariable("var1")
    in2 = MyVariable("var2")
    head = op1(in2, in1)
    body = op2(head, in2)
    tail = op3(body, in2, in2)

    graph = FunctionGraph([in1, in2], [head, tail], clone=False)

    def referenced_as_client(node):
        # True when `node` appears in any (client, index) entry of the graph.
        return any(
            node in entry
            for entry_list in graph.clients.values()
            for entry in entry_list
        )

    # Dropping output 1 (`tail`) must also drop its node from every client list.
    graph.remove_output(1)
    graph.check_integrity()
    assert graph.outputs == [head]
    assert tail not in graph.clients
    assert not referenced_as_client(tail.owner)

    # Dropping input 0 (`in1`) takes `head` with it, and hence every output.
    graph.remove_input(0)
    graph.check_integrity()
    assert in1 not in graph.variables
    assert graph.inputs == [in2]
    assert graph.outputs == []
    assert not referenced_as_client(head.owner)
def test_remove_input(self):
    """Removing an input drops exactly the outputs that depend on it."""
    v0, v1, v2, v3, v4 = (
        MyVariable(label) for label in ("var0", "var1", "var2", "var3", "var4")
    )

    inner = op1(v1, v0)
    keeper = op2(inner, v2)

    dep_a = op1(v3, v4)
    dep_a.name = "out1"
    dep_b = op1(dep_a, v0)
    dep_b.name = "out2"

    # `dep_a` is listed twice as an output on purpose (the original `out3`
    # was the very same variable as `out1`).
    fg = FunctionGraph(
        [v0, v1, v2, v3, v4],
        [keeper, dep_a, dep_b, dep_a],
        clone=False,
    )

    fg.remove_input(4)
    fg.check_integrity()

    assert fg.inputs == [v0, v1, v2, v3]
    # Only `keeper` does not depend on var4.
    assert fg.outputs == [keeper]
def test_remove_output_3(self):
    """Removing one output leaves the other outputs and all inputs untouched."""
    v0, v1, v2, v3, v4, v5, v6 = (
        MyVariable(label)
        for label in ("var0", "var1", "var2", "var3", "var4", "var5", "var6")
    )

    shared = op1(v1, v0)
    first = op2(shared, v2)

    removed = op1(v3, v4)
    removed.name = "out1"

    sibling = op1(shared, v5)
    sibling.name = "out2"

    shares_input = op1(v3, v6)
    shares_input.name = "out3"

    # The last two outputs are an intermediate variable (`shared`) and a raw
    # input (`v3`) that the removed output also consumes.
    fg = FunctionGraph(
        [v0, v1, v2, v3, v4, v5, v6],
        [first, removed, sibling, shares_input, shared, v3],
        clone=False,
    )

    fg.remove_output(1)
    fg.check_integrity()

    assert fg.inputs == [v0, v1, v2, v3, v4, v5, v6]
    assert fg.outputs == [first, sibling, shares_input, shared, v3]
    assert removed not in fg.clients
def test_empty(self):
    """A graph with inputs but no outputs has no nodes and no tracked variables."""
    a = MyVariable("var1")
    b = MyVariable("var2")

    graph = FunctionGraph([a, b], [], clone=False)
    graph.check_integrity()

    assert graph.inputs == [a, b]
    assert graph.outputs == []
    assert not graph.variables
    assert not graph.apply_nodes
    # The inputs are still registered as clients keys, each with no clients.
    assert graph.clients == {a: [], b: []}
def test_init(self):
    """Construction fills in inputs, outputs, nodes, variables and clients.

    Also builds a graph from outputs alone and a cloned graph with a memo.

    NOTE(review): a second ``test_init`` is defined later in this file. If
    both live in the same class, the later definition replaces this one and
    this body never runs. Consider renaming or merging.
    """
    x = MyVariable("var1")
    y = MyVariable("var2")
    unary = op1(x)
    binary = op2(unary, y)

    fg = FunctionGraph([x, y], [unary, binary], clone=False)
    assert fg.inputs == [x, y]
    assert fg.outputs == [unary, binary]
    assert fg.apply_nodes == {unary.owner, binary.owner}
    assert fg.update_mapping is None
    assert fg.check_integrity() is None
    assert fg.variables == {x, y, unary, binary}

    expected_clients = {
        x: [(unary.owner, 0)],
        y: [(binary.owner, 1)],
        unary: [(binary.owner, 0), ("output", 0)],
        binary: [("output", 1)],
    }
    for var, clients in expected_clients.items():
        assert fg.get_clients(var) == clients

    # Inputs are recovered when only outputs are given.
    fg = FunctionGraph(outputs=[unary, binary], clone=False)
    assert fg.inputs == [x, y]

    # Cloning records entries for the original variables in `memo`.
    memo = {}
    fg = FunctionGraph(outputs=[unary, binary], clone=True, memo=memo)
    for original in (x, y):
        assert memo[original].type == original.type
        assert memo[original].name == original.name
    assert unary in memo
    assert binary in memo
def test_remove_output_empty(self):
    """Removing a graph's only output leaves it with no outputs and no nodes.

    The inputs are kept, and no removed node may linger in any client list.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    op1_out = op1(var1)
    op3_out = op3(op1_out, var2)
    fg = FunctionGraph([var1, var2], [op3_out], clone=False)

    fg.remove_output(0)
    fg.check_integrity()

    assert fg.inputs == [var1, var2]
    # The test name promises an empty graph; pin the outputs explicitly.
    assert fg.outputs == []
    assert not fg.apply_nodes
    assert op1_out not in fg.clients

    def referenced_as_client(node):
        # Lazily walks every (client, index) entry instead of concatenating
        # the lists with `sum(..., [])`, which is quadratic.
        return any(
            node in entry
            for entry_list in fg.clients.values()
            for entry in entry_list
        )

    assert not referenced_as_client(op1_out.owner)
    assert not referenced_as_client(op3_out.owner)
def test_remove_duplicates(self):
    """Duplicated inputs and outputs are removed one occurrence at a time."""
    x = MyVariable("var1")
    y = MyVariable("var2")
    a = op1(y, x)
    b = op2(a, y)
    c = op3(b, y, y)

    fg = FunctionGraph([x, x, y], [a, c, c], clone=False)

    # Removing the second copy of `c` keeps the first one.
    fg.remove_output(2)
    fg.check_integrity()
    assert fg.outputs == [a, c]

    # Removing the first `x` keeps its duplicate in `inputs`, yet `x` is no
    # longer tracked in `variables`; this pins the current behavior. Every
    # output depends on `x`, so all outputs go away.
    fg.remove_input(0)
    fg.check_integrity()
    assert x not in fg.variables
    assert fg.inputs == [x, y]
    assert fg.outputs == []
def test_remove_node(self):
    """``remove_node`` also prunes the nodes that only fed it and the nodes that consume it."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    first = op1(var1)
    second = op2(var2, first)
    third = op3(second)

    def check(outputs, target, survivors):
        # Build a fresh graph, remove `target.owner`, and compare the nodes
        # left over with the owners of `survivors` (none when empty).
        fg = FunctionGraph([var1, var2], outputs, clone=False)
        fg.remove_node(target.owner)
        fg.check_integrity()
        if survivors:
            assert fg.apply_nodes == {v.owner for v in survivors}
        else:
            assert not fg.apply_nodes

    # Removing the sole output's node empties the graph.
    check([third], third, ())
    # Removing the last node keeps what the other output needs.
    check([second, third], third, (first, second))
    # Removing a middle node also removes its consumer.
    check([second, third], second, ())
def test_remove_output(self):
    """Removing an output keeps exactly the nodes the remaining outputs need."""
    a = MyVariable("var1")
    b = MyVariable("var2")
    n1 = op1(a)
    n2 = op2(b, n1)
    n3 = op3(n2)

    # The remaining output `n3` still needs every node.
    fg = FunctionGraph([a, b], [n2, n3], clone=False)
    fg.remove_output(0)
    fg.check_integrity()
    assert fg.apply_nodes == {n1.owner, n2.owner, n3.owner}
    assert fg.inputs == [a, b]
    assert fg.outputs == [n3]

    # Dropping `n3` frees only its own node.
    fg = FunctionGraph([a, b], [n2, n3], clone=False)
    fg.remove_output(1)
    fg.check_integrity()
    assert fg.apply_nodes == {n1.owner, n2.owner}
    assert fg.inputs == [a, b]
    assert fg.outputs == [n2]

    # Dropping an output that is just an input frees no node.
    fg = FunctionGraph([a, b], [n2, n3, a], clone=False)
    fg.remove_output(2)
    fg.check_integrity()
    assert fg.apply_nodes == {n1.owner, n2.owner, n3.owner}
    assert fg.inputs == [a, b]
    assert fg.outputs == [n2, n3]

    # Dropping the only output, itself an input, keeps the inputs.
    fg = FunctionGraph([a, b], [a], clone=False)
    fg.remove_output(0)
    fg.check_integrity()
    assert fg.inputs == [a, b]
    assert fg.outputs == []
def test_init(self):
    """Construction fills in inputs, outputs, nodes, variables and clients.

    Also covers building a graph from outputs alone (inputs are recovered
    from the outputs) and cloning into a caller-supplied ``memo``. Python
    keeps this definition, because it rebinds the name ``test_init``, so it
    carries the full set of initialization checks.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var1)
    var4 = op2(var3, var2)
    fg = FunctionGraph([var1, var2], [var3, var4], clone=False)
    assert fg.inputs == [var1, var2]
    assert fg.outputs == [var3, var4]
    assert fg.apply_nodes == {var3.owner, var4.owner}
    assert fg.update_mapping is None
    assert fg.check_integrity() is None
    assert fg.variables == {var1, var2, var3, var4}
    assert fg.get_clients(var1) == [(var3.owner, 0)]
    assert fg.get_clients(var2) == [(var4.owner, 1)]
    assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)]
    assert fg.get_clients(var4) == [("output", 1)]

    # With only outputs given, the graph's inputs are recovered from them.
    fg = FunctionGraph(outputs=[var3, var4], clone=False)
    assert fg.inputs == [var1, var2]

    # With clone=True, `memo` gains entries keyed by the original variables.
    memo = {}
    fg = FunctionGraph(outputs=[var3, var4], clone=True, memo=memo)
    assert memo[var1].type == var1.type
    assert memo[var1].name == var1.name
    assert memo[var2].type == var2.type
    assert memo[var2].name == var2.name
    assert var3 in memo
    assert var4 in memo
def test_remove_node_multi_out(self):
    """Removing a multi-output node, or one of its consumers, keeps the graph consistent."""
    x = MyVariable("var1")
    y = MyVariable("var2")
    mop = MyOp("mop", n_outs=2)

    # Case 1: removing the two-output node drops both of its outputs and,
    # since every graph output depends on them, all the outputs.
    pre = op1(x)
    first, second = mop(pre, y)
    tail = op3(second)
    fg = FunctionGraph([x, y], [first, tail], clone=False)
    fg.remove_node(first.owner)
    fg.check_integrity()
    assert fg.inputs == [x, y]
    assert fg.outputs == []
    for dropped in (first, second):
        assert dropped not in fg.clients
        assert dropped not in fg.variables

    # Case 2: removing a consumer of both outputs keeps the other consumer.
    m1, m2 = mop(x)
    left = op2(m1)
    right = op3(m1, m2)
    fg = FunctionGraph([x], [left, right], clone=False)
    fg.remove_node(right.owner)
    fg.check_integrity()
    assert fg.inputs == [x]
    assert fg.outputs == [left]
    # `m2` is deliberately not required to vanish from `fg.clients` or
    # `fg.variables` here. That would only hold if the graph tracked just
    # "active" variables.

    # Case 3: when `m2` is itself an output, it must survive the removal.
    fg = FunctionGraph([x], [left, right, m2], clone=False)
    fg.remove_node(right.owner)
    fg.check_integrity()
    assert fg.inputs == [x]
    assert fg.outputs == [left, m2]
    assert m2 in fg.clients
    assert m2 in fg.variables
    assert m2 in fg.outputs
def test_check_integrity(self):
    """Corrupt a graph's bookkeeping one way at a time and check that
    ``check_integrity`` raises a matching error each time.

    Each step mutates the shared ``fg`` state directly, expects an exception
    whose message matches the given pattern, and (where a restore line
    follows) undoes the corruption before the next step. The steps are
    order-dependent.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # A node missing from `apply_nodes` must be reported.
    with pytest.raises(Exception, match="The nodes are .*"):
        fg.apply_nodes.remove(var5.owner)
        fg.check_integrity()

    # Restore the node, then drop one of var2's client entries: var5's node
    # still has var2 as input 1, so the clients no longer agree with it.
    with pytest.raises(Exception, match="Inconsistent clients.*"):
        fg.apply_nodes.add(var5.owner)
        fg.remove_client(var2, (var5.owner, 1))
        fg.check_integrity()

    fg.add_client(var2, (var5.owner, 1))

    # A variable missing from `variables` must be reported.
    with pytest.raises(Exception, match="The variables are.*"):
        fg.variables.remove(var4)
        fg.check_integrity()

    fg.variables.add(var4)

    # Attach var6 as input 3 of var5's node. It is neither a graph input
    # nor, presumably, produced by any node (verify MyVariable2's owner).
    with pytest.raises(Exception, match="Undeclared input.*"):
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)
        fg.check_integrity()

    # NOTE(review): this restore leaves the `fg.clients[var6]` entry in place.
    fg.variables.remove(var6)
    var5.owner.inputs.remove(var6)

    # Output index 1 is var5, not var4, so this client entry is inconsistent.
    # TODO: What if the index value is greater than 1? It will throw an
    # `IndexError`, but that doesn't sound like anything we'd want.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, ("output", 1))
        fg.check_integrity()

    fg.remove_client(var4, ("output", 1))

    # A client whose node is not in the graph (var6.owner is expected to be
    # outside the graph here; presumably None, TODO confirm) must be reported.
    with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
        fg.add_client(var4, (var6.owner, 0))
        fg.check_integrity()

    fg.remove_client(var4, (var6.owner, 0))

    # var3's node has var2 as input 0, not var4, so this entry is inconsistent.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, (var3.owner, 0))
        fg.check_integrity()