def test_init(self):
    """Check the state of a freshly constructed `FunctionGraph`.

    Covers a graph built from explicit inputs, a graph whose inputs are
    inferred from its outputs, and the `memo` filled in when cloning.
    """
    in1 = MyVariable("var1")
    in2 = MyVariable("var2")
    mid = op1(in1)
    out = op2(mid, in2)

    fg = FunctionGraph([in1, in2], [mid, out], clone=False)
    assert fg.inputs == [in1, in2]
    assert fg.outputs == [mid, out]
    assert fg.apply_nodes == {mid.owner, out.owner}
    assert fg.update_mapping is None
    assert fg.check_integrity() is None
    assert fg.variables == {in1, in2, mid, out}

    # Client lists are ordered; an ("output", i) entry marks the variable
    # as `fg.outputs[i]`.
    expected_clients = (
        (in1, [(mid.owner, 0)]),
        (in2, [(out.owner, 1)]),
        (mid, [(out.owner, 0), ("output", 0)]),
        (out, [("output", 1)]),
    )
    for variable, clients in expected_clients:
        assert fg.get_clients(variable) == clients

    # With no explicit inputs, they are inferred from the outputs.
    fg = FunctionGraph(outputs=[mid, out], clone=False)
    assert fg.inputs == [in1, in2]

    # When cloning, `memo` maps the original variables to their clones.
    memo = {}
    fg = FunctionGraph(outputs=[mid, out], clone=True, memo=memo)
    for original in (in1, in2):
        assert memo[original].type == original.type
        assert memo[original].name == original.name
    assert mid in memo
    assert out in memo
def test_change_input(self):
    """Exercise `FunctionGraph.change_input` on outputs and `Apply` inputs.

    Covers three cases:
    - rejection of a replacement whose type does not match;
    - a no-op replacement with the same variable;
    - a real replacement, which must move the client entry from the old
      input to the new one and leave the graph consistent.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # `MyVariable2` is expected to carry an incompatible type, so using it
    # to replace either an output or an `Apply` input must be rejected.
    var6 = MyVariable2("var6")
    with pytest.raises(TypeError):
        fg.change_input("output", 1, var6)
    with pytest.raises(TypeError):
        fg.change_input(var5.owner, 1, var6)

    old_apply_nodes = set(fg.apply_nodes)
    old_variables = set(fg.variables)
    old_var5_clients = list(fg.get_clients(var5))

    # We're replacing with the same variable, so nothing should happen
    fg.change_input(var5.owner, 1, var2)
    assert old_apply_nodes == fg.apply_nodes
    assert old_variables == fg.variables
    assert old_var5_clients == fg.get_clients(var5)

    # Perform a valid `Apply` node input change
    fg.change_input(var5.owner, 1, var1)
    assert var5.owner.inputs[1] is var1
    # The client entry must move from the old input to the new one, not
    # merely disappear from the old one.
    assert (var5.owner, 1) not in fg.get_clients(var2)
    assert (var5.owner, 1) in fg.get_clients(var1)
    # The rewired graph must still pass the integrity check.
    assert fg.check_integrity() is None
def test_import_node(self):
    """Check `FunctionGraph.import_node` with unknown and known inputs."""
    in1 = MyVariable("var1")
    in2 = MyVariable("var2")
    n1 = op1(in2, in1)
    n2 = op2(n1, in2)
    n3 = op3(n2, in2, in2)
    graph = FunctionGraph([in1, in2], [n1, n3], clone=False)

    # A node that depends on a variable unknown to the graph is rejected by
    # default, and the unknown variable is not added...
    outside = MyVariable("var8")
    orphan = op2(outside)
    with pytest.raises(MissingInputError):
        graph.import_node(orphan.owner)
    assert outside not in graph.variables

    # ...but with `import_missing=True` the node is accepted and the unknown
    # variable becomes a graph input.
    graph.import_node(orphan.owner, import_missing=True)
    assert outside in graph.inputs
    assert orphan.owner in graph.apply_nodes

    # Importing a node built on an existing graph variable tags the node,
    # adds its output and registers it as a client of that variable.
    attached = op2(in2)
    assert not hasattr(attached.owner.tag, "imported_by")
    graph.import_node(attached.owner)
    assert hasattr(attached.owner.tag, "imported_by")
    assert attached in graph.variables
    assert attached.owner in graph.apply_nodes
    assert (attached.owner, 0) in graph.get_clients(in2)
def test_init(self):
    """Check the state of a `FunctionGraph` built from explicit inputs."""
    first = MyVariable("var1")
    second = MyVariable("var2")
    head = op1(first)
    tail = op2(head, second)
    graph = FunctionGraph([first, second], [head, tail], clone=False)

    assert graph.inputs == [first, second]
    assert graph.outputs == [head, tail]
    assert graph.apply_nodes == {head.owner, tail.owner}
    assert graph.update_mapping is None
    assert graph.check_integrity() is None
    assert graph.variables == {first, second, head, tail}

    # Each client is a (node, input_index) pair; ("output", i) marks the
    # variable as `graph.outputs[i]`.
    assert graph.get_clients(first) == [(head.owner, 0)]
    assert graph.get_clients(second) == [(tail.owner, 1)]
    assert graph.get_clients(head) == [(tail.owner, 0), ("output", 0)]
    assert graph.get_clients(tail) == [("output", 1)]
def test_remove_client(self):
    """Check `FunctionGraph.remove_client`, including cascading removal."""
    in1 = MyVariable("var1")
    in2 = MyVariable("var2")
    n1 = op1(in2, in1)
    n2 = op2(n1, in2)
    n3 = op3(n2, in2, in2)
    graph = FunctionGraph([in1, in2], [n1, n3], clone=False)

    assert graph.variables == {in1, in2, n1, n2, n3}
    assert graph.get_clients(in2) == [
        (n1.owner, 0),
        (n2.owner, 1),
        (n3.owner, 1),
        (n3.owner, 2),
    ]

    # Dropping one client leaves the remaining ones in their original order.
    graph.remove_client(in2, (n2.owner, 1))
    assert graph.get_clients(in2) == [
        (n1.owner, 0),
        (n3.owner, 1),
        (n3.owner, 2),
    ]

    graph.remove_client(in1, (n1.owner, 1))
    assert graph.get_clients(in1) == []
    assert n2.owner in graph.apply_nodes

    # `n2`'s only client is `n3` (input 0), so removing that client should
    # remove `n2`'s outputs and its `Apply` node from the graph entirely.
    #
    # `(n2.owner, 1)` was already dropped from `in2`'s client list above,
    # so this cascading removal tries to drop it a second time. That
    # repeated removal used to raise a `ValueError`, because the entry was
    # no longer in the list.
    graph.remove_client(n2, (n3.owner, 0), reason="testing")
    assert n2.owner not in graph.apply_nodes
    assert n2.owner.tag.removed_by == ["testing"]
    assert not any(out in graph.variables for out in n2.owner.outputs)