Esempio n. 1
0
 def test_merge_outputs(self):
     """Two identical outputs should print as one shared sub-graph after merging."""
     x, y, z = inputs()
     first_out = op3(op2(x, y))
     second_out = op3(op2(x, y))
     graph = FunctionGraph([x, y, z], [first_out, second_out])
     MergeOptimizer().optimize(graph)
     # `*1` in the rendered string marks the sub-graph both outputs refer to
     expected = "FunctionGraph(*1 -> Op3(Op2(x, y)), *1)"
     assert str(graph) == expected
Esempio n. 2
0
 def test_merge_outputs(self):
     """After merging, both graph outputs should be the very same object."""
     x, y, z = inputs()
     first_out = op3(op2(x, y))
     second_out = op3(op2(x, y))
     graph = FunctionGraph([x, y, z], [first_out, second_out], clone=False)
     MergeOptimizer().optimize(graph)
     merged_first = graph.outputs[0]
     merged_second = graph.outputs[1]
     assert merged_first is merged_second
Esempio n. 3
0
 def test_deep_merge(self):
     """A duplicated nested sub-expression should end up as one shared variable."""
     x, y, z = inputs()
     # The same `op3(op2(x, y), z)` expression is built twice on purpose
     left_branch = op3(op2(x, y), z)
     right_branch = op4(op3(op2(x, y), z))
     root_expr = op1(left_branch, right_branch)
     graph = FunctionGraph([x, y, z], [root_expr], clone=False)
     MergeOptimizer().optimize(graph)
     root = graph.outputs[0]
     first_in, second_in = root.owner.inputs
     # The second input's own first input must be the first input itself
     assert second_in.owner.inputs[0] is first_in
Esempio n. 4
0
    def test_change_input(self):
        """Exercise `FunctionGraph.change_input` with rejected, no-op and valid changes."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # A `MyVariable2` replacement is expected to be refused with a
        # `TypeError`, both for an output slot and for a node input slot
        other_typed = MyVariable2("var6")
        with pytest.raises(TypeError):
            graph.change_input("output", 1, other_typed)

        with pytest.raises(TypeError):
            graph.change_input(var5.owner, 1, other_typed)

        # Snapshot the graph state before the no-op change below
        nodes_before = set(graph.apply_nodes)
        variables_before = set(graph.variables)
        var5_clients_before = list(graph.get_clients(var5))

        # Input 1 of `var5`'s node is already `var2`, so nothing should happen
        graph.change_input(var5.owner, 1, var2)

        assert nodes_before == graph.apply_nodes
        assert variables_before == graph.variables
        assert var5_clients_before == graph.get_clients(var5)

        # Now perform a real input change on the same `Apply` node
        graph.change_input(var5.owner, 1, var1)

        changed_node = var5.owner
        assert changed_node.inputs[1] is var1
        assert (changed_node, 1) not in graph.get_clients(var2)
Esempio n. 5
0
 def test_multi(self):
     """Run a pattern rewrite on a graph whose matched sub-graph is used twice.

     The expected string has the same structure the graph was built with.
     """
     x, y, z = inputs()
     shared = op1(x, y)
     root_expr = op3(op4(shared), shared)
     graph = FunctionGraph([x, y, z], [root_expr])
     in_pattern = (op4, (op1, "x", "y"))
     out_pattern = (op3, "x", "y")
     PatternOptimizer(in_pattern, out_pattern).optimize(graph)
     expected = "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
     assert str(graph) == expected
Esempio n. 6
0
    def test_import_node(self):
        """Exercise `FunctionGraph.import_node` with and without `import_missing`."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # A node whose only input is not part of the graph
        unknown_input = MyVariable("var8")
        unknown_out = op2(unknown_input)

        with pytest.raises(MissingInputError):
            graph.import_node(unknown_out.owner)

        assert unknown_input not in graph.variables

        # With `import_missing=True` the unknown input becomes a graph input
        graph.import_node(unknown_out.owner, import_missing=True)
        assert unknown_input in graph.inputs
        assert unknown_out.owner in graph.apply_nodes

        # A node built on an existing graph input imports without extra flags
        known_out = op2(var2)
        assert not hasattr(known_out.owner.tag, "imported_by")
        graph.import_node(known_out.owner)

        assert hasattr(known_out.owner.tag, "imported_by")
        assert known_out in graph.variables
        assert known_out.owner in graph.apply_nodes
        assert (known_out.owner, 0) in graph.get_clients(var2)
Esempio n. 7
0
    def test_import_var(self):
        """Exercise `FunctionGraph.import_var` for ownerless, owned and null-typed variables."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        var0 = MyVariable("var0")

        with pytest.raises(MissingInputError):
            # We can't import a new `FunctionGraph` input (i.e. something
            # without an owner), at least not without setting `import_missing`
            fg.import_var(var0, "testing")

        fg.import_var(var0, import_missing=True)

        assert var0 in fg.inputs

        var5 = op2()
        # We can import variables with owners
        fg.import_var(var5, "testing")
        assert var5 in fg.variables
        assert var5.owner in fg.apply_nodes

        # The import and the variable construction happen outside the
        # `pytest.raises` block, so that only the `import_var` call itself can
        # satisfy the expected `TypeError`
        from aesara.graph.null_type import NullType

        null_var = NullType()()

        with pytest.raises(TypeError, match="Computation graph contains.*"):
            fg.import_var(null_var, "testing")
Esempio n. 8
0
    def test_check_integrity(self):
        """Corrupt the graph's bookkeeping step by step and check each error.

        Every corruption is applied *before* its `pytest.raises` block, so that
        only `check_integrity` can satisfy the expected exception; a failure
        while setting up the corruption surfaces as a test error instead of
        being mistaken for the expected one.
        """

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Drop a node from the set of tracked `Apply` nodes
        fg.apply_nodes.remove(var5.owner)

        with pytest.raises(Exception, match="The nodes are .*"):
            fg.check_integrity()

        # Restore the node, then drop one of `var2`'s client entries
        fg.apply_nodes.add(var5.owner)
        fg.remove_client(var2, (var5.owner, 1))

        with pytest.raises(Exception, match="Inconsistent clients.*"):
            fg.check_integrity()

        fg.add_client(var2, (var5.owner, 1))

        # Drop a variable from the set of tracked variables
        fg.variables.remove(var4)

        with pytest.raises(Exception, match="The variables are.*"):
            fg.check_integrity()

        fg.variables.add(var4)

        # Attach a variable that is neither a graph input nor a node output
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)

        with pytest.raises(Exception, match="Undeclared input.*"):
            fg.check_integrity()

        fg.variables.remove(var6)
        var5.owner.inputs.remove(var6)

        # TODO: What if the index value is greater than 1?  It will throw an
        # `IndexError`, but that doesn't sound like anything we'd want.
        fg.add_client(var4, ("output", 1))

        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()

        fg.remove_client(var4, ("output", 1))

        # `var6` was created directly, so `var6.owner` is not a graph node
        fg.add_client(var4, (var6.owner, 0))

        with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
            fg.check_integrity()

        fg.remove_client(var4, (var6.owner, 0))

        # Register `var4` as input 0 of `var3`'s node, which was built from
        # `var2` and `var1`
        fg.add_client(var4, (var3.owner, 0))

        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()
Esempio n. 9
0
    def test_remove_in_and_out(self):
        """Remove an output and then an input, checking integrity after each."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var2, var1)
        op2_out = op2(op1_out, var2)
        op3_out = op3(op2_out, var2, var2)
        graph = FunctionGraph([var1, var2], [op1_out, op3_out], clone=False)

        def appears_in_clients(node):
            # Concatenate every client list and look for `node` in each entry
            all_entries = sum(graph.clients.values(), [])
            return any(node in entry for entry in all_entries)

        # Remove an output
        graph.remove_output(1)
        graph.check_integrity()

        assert graph.outputs == [op1_out]
        assert op3_out not in graph.clients
        assert not appears_in_clients(op3_out.owner)

        # Remove an input
        graph.remove_input(0)
        graph.check_integrity()

        assert var1 not in graph.variables
        assert graph.inputs == [var2]
        assert graph.outputs == []
        assert not appears_in_clients(op1_out.owner)
Esempio n. 10
0
    def test_remove_node(self):
        """Check which `Apply` nodes remain after `FunctionGraph.remove_node`."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        # Case 1: the removed node produces the graph's only output
        graph = FunctionGraph([var1, var2], [node3_out], clone=False)
        graph.remove_node(node3_out.owner)
        graph.check_integrity()

        assert not graph.apply_nodes

        # Case 2: `node2_out` is an output too, and the last node is removed
        graph = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
        graph.remove_node(node3_out.owner)
        graph.check_integrity()

        assert graph.apply_nodes == {node1_out.owner, node2_out.owner}

        # Case 3: same outputs, but the middle node is removed
        graph = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
        graph.remove_node(node2_out.owner)
        graph.check_integrity()

        assert not graph.apply_nodes
Esempio n. 11
0
    def test_remove_node_multi_out(self):
        """Check `FunctionGraph.remove_node` around a node with two outputs."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        multi_op = MyOp("mop", n_outs=2)
        op1_out = op1(var1)
        mop_out_1, mop_out_2 = multi_op(op1_out, var2)
        op3_out = op3(mop_out_2)

        # Case 1: remove the two-output node itself
        graph = FunctionGraph([var1, var2], [mop_out_1, op3_out], clone=False)
        graph.remove_node(mop_out_1.owner)
        graph.check_integrity()

        assert graph.inputs == [var1, var2]
        assert graph.outputs == []
        for removed_out in (mop_out_1, mop_out_2):
            assert removed_out not in graph.clients
        for removed_out in (mop_out_1, mop_out_2):
            assert removed_out not in graph.variables

        mop1_out_1, mop1_out_2 = multi_op(var1)
        op2_out = op2(mop1_out_1)
        op3_out = op3(mop1_out_1, mop1_out_2)

        # Case 2: remove a node that consumes both outputs
        graph = FunctionGraph([var1], [op2_out, op3_out], clone=False)
        graph.remove_node(op3_out.owner)
        graph.check_integrity()

        assert graph.inputs == [var1]
        assert graph.outputs == [op2_out]
        # If we only want to track "active" variables in the graphs, the
        # following would need to be true, as well
        # assert mop1_out_2 not in fg.clients
        # assert mop1_out_2 not in fg.variables

        # Case 3: as above, but the second output is also a graph output
        graph = FunctionGraph(
            [var1], [op2_out, op3_out, mop1_out_2], clone=False
        )
        graph.remove_node(op3_out.owner)
        graph.check_integrity()

        assert graph.inputs == [var1]
        assert graph.outputs == [op2_out, mop1_out_2]
        assert mop1_out_2 in graph.clients
        assert mop1_out_2 in graph.variables
        assert mop1_out_2 in graph.outputs
Esempio n. 12
0
 def test_allow_multiple_clients(self):
     """Run a pattern rewrite where the matched sub-graph has two clients.

     The expected string has the same structure the graph was built with.
     """
     x, y, z = inputs()
     shared = op1(x, y)
     # `shared` has multiple clients (i.e. the `op4` and `op3` nodes)
     root_expr = op3(op4(shared), shared)
     graph = FunctionGraph([x, y, z], [root_expr])
     in_pattern = (op4, (op1, "x", "y"))
     out_pattern = (op3, "x", "y")
     PatternOptimizer(in_pattern, out_pattern).optimize(graph)
     expected = "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
     assert str(graph) == expected
Esempio n. 13
0
    def test_replace_test_value(self):
        """Replacing with a variable whose test value has another shape should fail."""

        var1 = MyVariable("var1")
        var1.tag.test_value = 1
        var2 = MyVariable("var2")
        var2.tag.test_value = 2
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var4.tag.test_value = np.array([1, 2])
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        replacement = op3()
        replacement.tag.test_value = np.array(0)

        # Sanity check: the two test values really differ in shape
        old_shape = var4.tag.test_value.shape
        assert replacement.tag.test_value.shape != old_shape

        with pytest.raises(AssertionError, match="The replacement.*"):
            graph.replace(var4, replacement)
Esempio n. 14
0
    def test_contains(self):
        """Variables and `Apply` nodes of the graph should satisfy `in`."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # An input, both outputs, and the nodes that produce those outputs
        expected_members = (var1, var3, var3.owner, var5, var5.owner)
        for member in expected_members:
            assert member in graph
Esempio n. 15
0
    def test_replace_bad_state(self):
        """Replacing a graph input with an unknown variable should raise `MissingInputError`."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Built outside the `pytest.raises` block so that only the `replace`
        # call itself can satisfy the expected exception
        var0 = MyVariable("var0")

        with pytest.raises(MissingInputError):
            # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
            # because it doesn't check for validity of the replacement *first*.
            fg.replace(var1, var0, verbose=True)
Esempio n. 16
0
    def test_replace_circular(self):
        """Replace a variable with one of its own descendants.

        `FunctionGraph` allows the resulting cycle--for better or worse.
        """

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # `var4` is computed from `var3`, so this makes `var4` its own input
        replacements = [(var3, var4)]
        graph.replace_all(replacements)

        # The following works (and is kind of gross), because `var4` has been
        # mutated in-place
        assert graph.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var4, var2]
Esempio n. 17
0
 def test_1(self):
     """Apply three chained pattern substitutions and check the final graph."""
     x, y, z = map(MyVariable, "xyz")
     root_expr = op3(op4(x, y))
     graph = FunctionGraph([x, y, z], [root_expr])
     substitutions = [
         PatternSub((op1, "x", "y"), (op2, "x", "y")),
         PatternSub((op4, "x", "y"), (op1, "x", "y")),
         PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
     ]
     optimizer = EquilibriumOptimizer(substitutions, max_use_ratio=10)
     optimizer.optimize(graph)
     expected = "FunctionGraph(Op2(x, y))"
     assert str(graph) == expected
Esempio n. 18
0
 def test_multiple_merges(self):
     """Merge a graph with two reused sub-expressions and check its rendering."""
     x, y, z = inputs()
     e1 = op1(x, y)
     e2 = op2(op3(x), y, z)
     root_expr = op1(e1, op4(e2, e1), op1(e2))
     graph = FunctionGraph([x, y, z], [root_expr])
     MergeOptimizer().optimize(graph)
     rendered = str(graph)
     # note: graph.as_string can only produce the following two possibilities, but if
     # the implementation was to change there are 6 other acceptable answers.
     acceptable = (
         "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))",
         "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))",
     )
     assert rendered in acceptable
Esempio n. 19
0
    def test_remove_output(self):
        """Check nodes, inputs and outputs after `FunctionGraph.remove_output`."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        # Case 1: drop the first output; the second one needs every node
        graph = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
        graph.remove_output(0)
        graph.check_integrity()

        assert graph.apply_nodes == {
            node1_out.owner,
            node2_out.owner,
            node3_out.owner,
        }
        assert graph.inputs == [var1, var2]
        assert graph.outputs == [node3_out]

        # Case 2: drop the last output instead
        graph = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
        graph.remove_output(1)
        graph.check_integrity()

        assert graph.apply_nodes == {node1_out.owner, node2_out.owner}
        assert graph.inputs == [var1, var2]
        assert graph.outputs == [node2_out]

        # Case 3: drop an output that is also a graph input
        graph = FunctionGraph(
            [var1, var2], [node2_out, node3_out, var1], clone=False
        )
        graph.remove_output(2)
        graph.check_integrity()

        assert graph.apply_nodes == {
            node1_out.owner,
            node2_out.owner,
            node3_out.owner,
        }
        assert graph.inputs == [var1, var2]
        assert graph.outputs == [node2_out, node3_out]

        # Case 4: the only output is a graph input
        graph = FunctionGraph([var1, var2], [var1], clone=False)
        graph.remove_output(0)
        graph.check_integrity()

        assert graph.inputs == [var1, var2]
        assert graph.outputs == []
Esempio n. 20
0
 def test_1(self):
     """Apply three chained pattern substitutions and check the final graph."""
     x, y, z = map(MyVariable, "xyz")
     # TODO FIXME: These `Op`s don't have matching/consistent `__prop__`s
     # and `__init__`s, so they can't be `etuplized` correctly
     root_expr = op3(op4(x, y))
     graph = FunctionGraph([x, y, z], [root_expr])
     substitutions = [
         PatternSub((op1, "x", "y"), (op2, "x", "y")),
         PatternSub((op4, "x", "y"), (op1, "x", "y")),
         PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
     ]
     optimizer = EquilibriumOptimizer(substitutions, max_use_ratio=10)
     optimizer.optimize(graph)
     expected = "FunctionGraph(Op2(x, y))"
     assert str(graph) == expected
Esempio n. 21
0
    def test_remove_output_empty(self):
        """Removing the only output should leave no nodes or node clients behind."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var1)
        op3_out = op3(op1_out, var2)
        graph = FunctionGraph([var1, var2], [op3_out], clone=False)

        def appears_in_clients(node):
            # Concatenate every client list and look for `node` in each entry
            all_entries = sum(graph.clients.values(), [])
            return any(node in entry for entry in all_entries)

        graph.remove_output(0)
        graph.check_integrity()

        assert graph.inputs == [var1, var2]
        assert not graph.apply_nodes
        assert op1_out not in graph.clients
        assert not appears_in_clients(op1_out.owner)
        assert not appears_in_clients(op3_out.owner)
Esempio n. 22
0
    def test_remove_client(self):
        """Remove clients one at a time, ending with one that drops a whole node."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        graph = FunctionGraph([var1, var2], [var3, var5], clone=False)

        assert graph.variables == {var1, var2, var3, var4, var5}
        assert graph.get_clients(var2) == [
            (var3.owner, 0),
            (var4.owner, 1),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        # Drop `var4`'s node from `var2`'s clients; the other entries stay
        graph.remove_client(var2, (var4.owner, 1))

        assert graph.get_clients(var2) == [
            (var3.owner, 0),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        # `var1` is only used by `var3`'s node
        graph.remove_client(var1, (var3.owner, 1))

        assert graph.get_clients(var1) == []

        assert var4.owner in graph.apply_nodes

        # This next `remove_client` should trigger a complete removal of `var4`'s
        # variables and `Apply` node from the `FunctionGraph`.
        #
        # Also, notice that we already removed `var4` from `var2`'s client list
        # above, so, when we completely remove `var4`, `fg.remove_client` will
        # attempt to remove `(var4.owner, 1)` from `var2`'s client list again.
        # This attempt would previously raise a `ValueError` exception, because
        # the entry was not in the list.
        graph.remove_client(var4, (var5.owner, 0), reason="testing")

        removed_node = var4.owner
        assert removed_node not in graph.apply_nodes
        assert removed_node.tag.removed_by == ["testing"]
        assert not any(out in graph.variables for out in removed_node.outputs)
Esempio n. 23
0
    def test_replace(self):
        """Check that `replace` rejects a mismatched type and `replace_all` rewires inputs."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Built outside the `pytest.raises` block so that only the `replace`
        # call itself can satisfy the expected `TypeError`
        var0 = MyVariable2("var0")

        with pytest.raises(TypeError):
            # The types don't match and one cannot be converted to the other
            fg.replace(var3, var0)

        # Test a basic replacement
        fg.replace_all([(var3, var1)])
        assert var3 not in fg.variables
        assert fg.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var1, var2]
Esempio n. 24
0
    def test_remove_duplicates(self):
        """Remove an output and an input from a graph that lists each of them twice."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var2, var1)
        op2_out = op2(op1_out, var2)
        op3_out = op3(op2_out, var2, var2)
        # `var1` appears twice among the inputs, `op3_out` twice among the outputs
        graph = FunctionGraph(
            [var1, var1, var2], [op1_out, op3_out, op3_out], clone=False
        )

        # Drop the second copy of `op3_out`
        graph.remove_output(2)
        graph.check_integrity()

        assert graph.outputs == [op1_out, op3_out]

        # Drop the first copy of `var1`
        graph.remove_input(0)
        graph.check_integrity()

        assert var1 not in graph.variables
        assert graph.inputs == [var1, var2]
        assert graph.outputs == []
Esempio n. 25
0
 def test_low_use_ratio(self):
     """Run the chained substitutions with a `max_use_ratio` of one use per node."""
     x, y, z = map(MyVariable, "xyz")
     root_expr = op3(op4(x, y))
     graph = FunctionGraph([x, y, z], [root_expr])
     # Silence the 'aesara.graph.opt' logger while optimizing, and restore
     # its previous level afterwards no matter what happens
     opt_logger = logging.getLogger("aesara.graph.opt")
     saved_level = opt_logger.level
     opt_logger.setLevel(logging.CRITICAL)
     try:
         substitutions = [
             PatternSub((op1, "x", "y"), (op2, "x", "y")),
             PatternSub((op4, "x", "y"), (op1, "x", "y")),
             PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
         ]
         # each opt can only be applied once
         use_ratio = 1.0 / len(graph.apply_nodes)
         optimizer = EquilibriumOptimizer(substitutions, max_use_ratio=use_ratio)
         optimizer.optimize(graph)
     finally:
         opt_logger.setLevel(saved_level)
     expected = "FunctionGraph(Op1(x, y))"
     assert str(graph) == expected
Esempio n. 26
0
 def test_no_merge(self):
     """Sub-expressions that differ only in input order should stay separate."""
     x, y, z = inputs()
     left_branch = op3(op2(x, y))
     right_branch = op3(op2(y, x))
     root_expr = op1(left_branch, right_branch)
     graph = FunctionGraph([x, y, z], [root_expr])
     MergeOptimizer().optimize(graph)
     expected = "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
     assert str(graph) == expected
Esempio n. 27
0
 def test_no_merge(self):
     """Run the merge optimizer with an `AssertNoChanges` feature attached.

     Presumably the feature fails the test if the optimizer alters the
     graph; its definition is not visible here.
     """
     x, y, z = inputs()
     left_branch = op3(op2(x, y))
     right_branch = op3(op2(y, x))
     root_expr = op1(left_branch, right_branch)
     graph = FunctionGraph([x, y, z], [root_expr])
     graph.attach_feature(AssertNoChanges())
     MergeOptimizer().optimize(graph)
Esempio n. 28
0
 def test_straightforward_2(self):
     """`OpSubOptimizer(op3, op4)` should turn the `Op3` node into an `Op4` node."""
     x, y, z = inputs()
     root_expr = op1(op2(x), op3(y), op4(z))
     graph = FunctionGraph([x, y, z], [root_expr])
     substitution = OpSubOptimizer(op3, op4)
     substitution.optimize(graph)
     expected = "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
     assert str(graph) == expected
Esempio n. 29
0
 def test_deep_merge(self):
     """A duplicated nested sub-expression should print as one shared sub-graph."""
     x, y, z = inputs()
     # The same `op3(op2(x, y), z)` expression is built twice on purpose
     left_branch = op3(op2(x, y), z)
     right_branch = op4(op3(op2(x, y), z))
     root_expr = op1(left_branch, right_branch)
     graph = FunctionGraph([x, y, z], [root_expr])
     MergeOptimizer().optimize(graph)
     expected = "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))"
     assert str(graph) == expected