コード例 #1
0
    def test_remove_input(self):
        """Removing an input drops every output that depends on it."""
        var0, var1, var2, var3, var4 = map(
            MyVariable, ("var0", "var1", "var2", "var3", "var4")
        )

        # `out0` only depends on var0, var1 and var2.
        out0 = op2(op1(var1, var0), var2)
        # `out1`, `out2` and `out3` all depend (directly or through `out1`)
        # on var4.
        out1 = op1(var3, var4)
        out1.name = "out1"
        out2 = op1(out1, var0)
        out2.name = "out2"
        out3 = out1

        fg = FunctionGraph(
            [var0, var1, var2, var3, var4],
            [out0, out1, out2, out3],
            clone=False,
        )

        # Drop var4 (input index 4).
        fg.remove_input(4)
        fg.check_integrity()

        assert fg.inputs == [var0, var1, var2, var3]
        assert fg.outputs == [out0]
コード例 #2
0
    def test_remove_output_3(self):
        """Removing one output keeps every input and all the other outputs."""
        var0, var1, var2, var3, var4, var5, var6 = map(
            MyVariable, ("var0", "var1", "var2", "var3", "var4", "var5", "var6")
        )

        op1_out = op1(var1, var0)
        out0 = op2(op1_out, var2)
        out1 = op1(var3, var4)
        out1.name = "out1"
        out2 = op1(op1_out, var5)
        out2.name = "out2"
        out3 = op1(var3, var6)
        out3.name = "out3"
        # An intermediate result and a raw input can also serve as outputs.
        out4, out5 = op1_out, var3

        fg = FunctionGraph(
            [var0, var1, var2, var3, var4, var5, var6],
            [out0, out1, out2, out3, out4, out5],
            clone=False,
        )

        # Drop output index 1 (`out1`).
        fg.remove_output(1)
        fg.check_integrity()

        assert fg.inputs == [var0, var1, var2, var3, var4, var5, var6]
        assert fg.outputs == [out0, out2, out3, out4, out5]
        assert out1 not in fg.clients
コード例 #3
0
    def test_init(self):
        """Check the state of freshly constructed `FunctionGraph`s."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var1)
        var4 = op2(var3, var2)
        fg = FunctionGraph([var1, var2], [var3, var4], clone=False)

        node3, node4 = var3.owner, var4.owner
        assert fg.inputs == [var1, var2]
        assert fg.outputs == [var3, var4]
        assert fg.apply_nodes == {node3, node4}
        assert fg.update_mapping is None
        assert fg.check_integrity() is None
        assert fg.variables == {var1, var2, var3, var4}

        # Each variable's clients, as `(node or "output", position)` pairs.
        expected_clients = {
            var1: [(node3, 0)],
            var2: [(node4, 1)],
            var3: [("output", 0), (node4, 0)],
            var4: [("output", 1)],
        }
        for variable, clients in expected_clients.items():
            assert fg.get_clients(variable) == clients

        # When `inputs` is omitted it is inferred from the outputs; the
        # constant `varC` does not become an input.
        varC = MyConstant("varC")
        var5 = op1(var1, varC)
        fg = FunctionGraph(outputs=[var3, var4, var5], clone=False)
        assert fg.inputs == [var1, var2]

        # With `clone=True`, `memo` is filled with original -> clone entries.
        memo = {}
        fg = FunctionGraph(outputs=[var3, var4], clone=True, memo=memo)

        for original in (var1, var2):
            cloned = memo[original]
            assert cloned.type == original.type
            assert cloned.name == original.name
        assert var3 in memo
        assert var4 in memo
コード例 #4
0
 def test_constant(self):
     """A constant in a pattern matches a distinct Constant with the same value."""
     x = Constant(MyType(), 2, name="x")
     y = MyVariable("y")
     z = Constant(MyType(), 2, name="z")
     graph = FunctionGraph([y], [op1(op1(x, y), y)])
     # `z` (value 2) is used in the pattern; the graph contains `x` (also 2).
     rewrite = PatternOptimizer((op1, z, "1"), (op2, "1", z))
     rewrite.optimize(graph)
     assert str(graph) == "FunctionGraph(Op1(Op2(y, z), y))"
コード例 #5
0
 def test_nested_even(self):
     """Collapse an even-length chain of `op1` applications back to `x`."""
     # The result must not depend on the order in which nodes are rewritten.
     x, y, z = inputs()
     chain = op1(op1(op1(op1(x))))
     fgraph = FunctionGraph([x, y, z], [chain])
     PatternOptimizer((op1, (op1, "1")), "1").optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(x)"
コード例 #6
0
 def test_ambiguous(self):
     """Repeatedly rewrite `op1(op1(a))` to `op1(a)` on a chain of five `op1`s."""
     # Expected to hold with `TopoOptimizer` when `ignore_newtrees` is False
     # (`ign=False` below); with `ignore_newtrees=True` or with other
     # `NavigatorOptimizer`s the result may differ.
     x, y, z = inputs()
     chain = op1(op1(op1(op1(op1(x)))))
     fgraph = FunctionGraph([x, y, z], [chain])
     rewrite = TopoPatternOptimizer(
         (op1, (op1, "1")), (op1, "1"), ign=False
     )
     rewrite.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(x))"
コード例 #7
0
ファイル: test_fg.py プロジェクト: geofiber/aesara
    def test_validate_inputs(self):
        """`FunctionGraph` rejects malformed `inputs`/`outputs` arguments."""
        var1 = op1()
        var2 = op2()

        # A bare variable (instead of a list) is rejected for either argument.
        with pytest.raises(TypeError):
            FunctionGraph(var1, [var2])

        with pytest.raises(TypeError):
            FunctionGraph([var1], var2)

        # Build `var3` outside the `raises` block so that only the
        # `FunctionGraph` constructor can satisfy the `ValueError`
        # expectation. `var3` has an owner (the `op1` node); presumably that
        # is why it is not accepted as an input.
        var3 = op1(var1)
        with pytest.raises(ValueError):
            FunctionGraph([var3], [var2], clone=False)
コード例 #8
0
    def test_match_same_illegal(self):
        """A constraint can veto matches whose two inputs are the same variable."""
        x, y, z = inputs()
        e = op2(op1(x, x), op1(x, y))
        g = FunctionGraph([x, y, z], [e])

        def constraint(r):
            # Only replace when the matched node's two inputs are distinct
            # variables, so `op1(x, x)` is left untouched.
            return r.owner.inputs[0] is not r.owner.inputs[1]

        PatternOptimizer({
            "pattern": (op1, "x", "y"),
            "constraint": constraint
        }, (op3, "x", "y")).optimize(g)
        assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
コード例 #9
0
    def test_constraints(self):
        """A sub-pattern constraint restricts which variables it may bind."""
        x, y, z = inputs()
        # Only the first `op1` takes an `op2` output as its input.
        e = op4(op1(op2(x, y)), op1(op1(x, y)))
        g = FunctionGraph([x, y, z], [e])

        def produced_by_op2(var):
            # Accept the binding only if `var` is the output of an `op2` node.
            return var.owner.op == op2

        sub_pattern = {"pattern": "1", "constraint": produced_by_op2}
        PatternOptimizer((op1, sub_pattern), (op3, "1")).optimize(g)
        assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
コード例 #10
0
 def test_replace_subgraph(self):
     """Rewrite an inner node, swapping the operands of `op2` into an `op1`."""
     x, y, z = inputs()
     inner = op2(x, y)
     fgraph = FunctionGraph([x, y, z], [op1(inner, z)])
     PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(Op1(y, x), z))"
コード例 #11
0
 def test_multiple(self):
     """Every occurrence of the pattern is replaced, not just the first one."""
     x, y, z = inputs()
     out = op1(op2(x, y), op2(x, y), op2(y, z))
     fgraph = FunctionGraph([x, y, z], [out])
     rewrite = PatternOptimizer((op2, "1", "2"), (op4, "1"))
     rewrite.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
コード例 #12
0
ファイル: test_opt.py プロジェクト: ricardoV94/aesara
 def test_multiple_merges(self):
     """Run `MergeOptimizer` on a graph where `e1` and `e2` are each used twice."""
     x, y, z = inputs()
     e1 = op1(x, y)
     e2 = op2(op3(x), y, z)
     e = op1(e1, op4(e2, e1), op1(e2))
     g = FunctionGraph([x, y, z], [e])
     MergeOptimizer().optimize(g)
     # Shared sub-expressions are printed with `*N` labels. The current
     # printer can only produce these two labelings, though six other
     # answers would be acceptable if its implementation changed.
     acceptable = (
         "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))",
         "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))",
     )
     assert str(g) in acceptable
コード例 #13
0
ファイル: test_fg.py プロジェクト: blueskysir/aesara
    def test_check_integrity(self):
        """`check_integrity` detects each kind of bookkeeping inconsistency.

        Each section corrupts one piece of the graph's internal state, expects
        `check_integrity` to raise, and then (where later sections need it)
        restores the state. The corrupting and restoring statements are kept
        outside the `pytest.raises` blocks, so only `check_integrity` itself
        can satisfy each expectation.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # A node used by the graph but missing from `apply_nodes`.
        fg.apply_nodes.remove(var5.owner)
        with pytest.raises(Exception, match="The nodes are .*"):
            fg.check_integrity()
        fg.apply_nodes.add(var5.owner)

        # A client entry removed even though `var5.owner` still uses `var2`
        # at position 1.
        fg.remove_client(var2, (var5.owner, 1))
        with pytest.raises(Exception, match="Inconsistent clients.*"):
            fg.check_integrity()
        fg.add_client(var2, (var5.owner, 1))

        # A variable of the graph missing from `variables`.
        fg.variables.remove(var4)
        with pytest.raises(Exception, match="The variables are.*"):
            fg.check_integrity()
        fg.variables.add(var4)

        # A node input (`var6`) that is not one of the graph's declared inputs.
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)
        with pytest.raises(Exception, match="Undeclared input.*"):
            fg.check_integrity()
        fg.variables.remove(var6)
        var5.owner.inputs.remove(var6)
        # NOTE(review): `fg.clients[var6]` is intentionally left in place, as
        # in the original sequence of operations.

        # `var4` claims to be output 1, which is actually `var5`.
        # TODO: What if the index value is greater than 1?  It will throw an
        # `IndexError`, but that doesn't sound like anything we'd want.
        fg.add_client(var4, ("output", 1))
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()
        fg.remove_client(var4, ("output", 1))

        # A client node that is not part of the graph (`var6.owner` is
        # presumably `None`).
        fg.add_client(var4, (var6.owner, 0))
        with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
            fg.check_integrity()
        fg.remove_client(var4, (var6.owner, 0))

        # `var3.owner` uses `var2`, not `var4`, at position 0.
        fg.add_client(var4, (var3.owner, 0))
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            fg.check_integrity()
コード例 #14
0
ファイル: test_fg.py プロジェクト: blueskysir/aesara
    def test_change_input(self):
        """`change_input` validates types and keeps the client lists in sync."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # `var6` is a `MyVariable2`; substituting it for an output or a node
        # input is rejected with `TypeError` (presumably an incompatible type).
        var6 = MyVariable2("var6")
        with pytest.raises(TypeError):
            fg.change_input("output", 1, var6)

        with pytest.raises(TypeError):
            fg.change_input(var5.owner, 1, var6)

        old_apply_nodes = set(fg.apply_nodes)
        old_variables = set(fg.variables)
        old_var5_clients = list(fg.get_clients(var5))

        # We're replacing with the same variable, so nothing should happen
        fg.change_input(var5.owner, 1, var2)

        assert old_apply_nodes == fg.apply_nodes
        assert old_variables == fg.variables
        assert old_var5_clients == fg.get_clients(var5)

        # Perform a valid `Apply` node input change: the inputs of
        # `var5.owner` go from `(var4, var2, var2)` to `(var4, var1, var2)`.
        fg.change_input(var5.owner, 1, var1)
        fg.check_integrity()

        assert var5.owner.inputs[1] is var1
        # The client entry moves from the old input to the new one ...
        assert (var5.owner, 1) not in fg.get_clients(var2)
        assert (var5.owner, 1) in fg.get_clients(var1)
        # ... while the use of `var2` at position 2 is untouched.
        assert (var5.owner, 2) in fg.get_clients(var2)
コード例 #15
0
ファイル: test_fg.py プロジェクト: blueskysir/aesara
    def test_import_var(self):
        """`import_var` adds variables (and their owners) to the graph."""

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        var0 = MyVariable("var0")

        with pytest.raises(MissingInputError):
            # We can't import a new `FunctionGraph` input (i.e. something
            # without an owner), at least not without setting `import_missing`
            fg.import_var(var0, "testing")

        fg.import_var(var0, import_missing=True)

        assert var0 in fg.inputs

        # We can import variables with owners. Use a fresh name instead of
        # re-binding `var5`, which names an output of `fg`.
        owned_var = op2()
        fg.import_var(owned_var, "testing")
        assert owned_var in fg.variables
        assert owned_var.owner in fg.apply_nodes

        # Import and build the `NullType` variable outside the `raises` block,
        # so that only `import_var` can satisfy the `TypeError` expectation.
        from aesara.graph.null_type import NullType

        null_var = NullType()()
        with pytest.raises(TypeError, match="Computation graph contains.*"):
            fg.import_var(null_var, "testing")
コード例 #16
0
ファイル: test_fg.py プロジェクト: blueskysir/aesara
    def test_import_node(self):
        """`import_node` adds an `Apply` node and, optionally, its missing inputs."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # A node whose input `var8` is unknown to `fg`.
        var8 = MyVariable("var8")
        var6 = op2(var8)
        foreign_node = var6.owner

        with pytest.raises(MissingInputError):
            fg.import_node(foreign_node)
        # The failed import must not leave `var8` behind.
        assert var8 not in fg.variables

        fg.import_node(foreign_node, import_missing=True)
        assert var8 in fg.inputs
        assert foreign_node in fg.apply_nodes

        # A node whose only input (`var2`) is already part of the graph.
        var7 = op2(var2)
        node7 = var7.owner
        assert not hasattr(node7.tag, "imported_by")
        fg.import_node(node7)

        assert hasattr(node7.tag, "imported_by")
        assert var7 in fg.variables
        assert node7 in fg.apply_nodes
        assert (node7, 0) in fg.get_clients(var2)
コード例 #17
0
    def test_verbose(self, capsys):
        """`replace_all_validate(..., verbose=True)` prints rewrites and failures."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        fg = FunctionGraph([var1, var2], [var3], clone=False)

        rv_feature = ReplaceValidate()
        fg.attach_feature(rv_feature)
        replace_kwargs = dict(reason="test-reason", verbose=True)

        rv_feature.replace_all_validate(fg, [(var3, var1)], **replace_kwargs)

        captured = capsys.readouterr()
        assert captured.err == ""
        assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in captured.out

        class TestFeature(Feature):
            # A feature whose validation always fails.
            def validate(self, *args):
                raise Exception()

        fg.attach_feature(TestFeature())

        with pytest.raises(Exception):
            rv_feature.replace_all_validate(fg, [(var3, var1)], **replace_kwargs)

        captured = capsys.readouterr()
        assert "optimizer: validate failed on node Op1.0" in captured.out
コード例 #18
0
    def test_remove_in_and_out(self):
        """Remove an output and then an input, checking the client bookkeeping."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var2, var1)
        op2_out = op2(op1_out, var2)
        op3_out = op3(op2_out, var2, var2)
        fg = FunctionGraph([var1, var2], [op1_out, op3_out], clone=False)

        def is_client_anywhere(node):
            # `fg.clients` maps each variable to a list of `(client, index)`
            # pairs; report whether `node` is a client of any variable.
            # (A nested generator avoids the quadratic `sum(lists, [])`.)
            return any(
                client == node
                for client_list in fg.clients.values()
                for client, _ in client_list
            )

        # Remove an output
        fg.remove_output(1)
        fg.check_integrity()

        assert fg.outputs == [op1_out]
        assert op3_out not in fg.clients
        assert not is_client_anywhere(op3_out.owner)

        # Remove an input; the remaining output `op1_out` depends on it, so
        # it is removed as well.
        fg.remove_input(0)
        fg.check_integrity()

        assert var1 not in fg.variables
        assert fg.inputs == [var2]
        assert fg.outputs == []
        assert not is_client_anywhere(op1_out.owner)
コード例 #19
0
    def test_remove_node(self):
        """Check which `Apply` nodes survive `remove_node` in three small graphs."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        node1_out = op1(var1)
        node2_out = op2(var2, node1_out)
        node3_out = op3(node2_out)

        def build(*outputs):
            # A fresh, non-cloning graph over var1/var2 with the given outputs.
            return FunctionGraph([var1, var2], list(outputs), clone=False)

        # Removing the node behind the only output empties the graph.
        fg = build(node3_out)
        fg.remove_node(node3_out.owner)
        fg.check_integrity()
        assert not fg.apply_nodes

        # `node2_out` is itself an output, so its sub-graph survives.
        fg = build(node2_out, node3_out)
        fg.remove_node(node3_out.owner)
        fg.check_integrity()
        assert fg.apply_nodes == {node1_out.owner, node2_out.owner}

        # Removing the middle node leaves no nodes at all.
        fg = build(node2_out, node3_out)
        fg.remove_node(node2_out.owner)
        fg.check_integrity()
        assert not fg.apply_nodes
コード例 #20
0
ファイル: test_opt.py プロジェクト: ricardoV94/aesara
 def test_multi(self):
     """The pattern is not applied through `op1(x, y)`, which has two clients."""
     x, y, z = inputs()
     shared = op1(x, y)
     out = op3(op4(shared), shared)
     fgraph = FunctionGraph([x, y, z], [out])
     PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(fgraph)
     # The graph prints unchanged; `*1` marks the shared `op1` node.
     assert str(fgraph) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
コード例 #21
0
def test_patternsub_values_eq_approx(out_pattern, tracks):
    """`PatternSub` works when both `values_eq_approx` and `get_nodes` are given.

    Regression test: this combination used to make `PatternSub` fail.
    """
    x = MyVariable("x")
    e = op1(x)
    fg = FunctionGraph([x], [e], clone=False)

    def only_this_node(fgraph, node):
        return [node]

    substitution = PatternSub(
        (op1, "x"),
        out_pattern,
        tracks=[op1] if tracks else (),
        get_nodes=only_this_node if tracks else None,
        values_eq_approx=values_eq_approx_always_true,
    )
    EquilibriumOptimizer([substitution], max_use_ratio=1).optimize(fg)

    output = fg.outputs[0]
    substituted = isinstance(out_pattern, tuple) or out_pattern == "x"
    if not substituted:
        # The replacement types do not match, so the substitution should
        # have failed and the original output is kept.
        assert output is e
        return

    if isinstance(out_pattern, tuple):
        assert output.owner.op == op2
    else:
        assert output is x
    assert output.tag.values_eq_approx is values_eq_approx_always_true
コード例 #22
0
    def test_optimizer_verbose(self, capsys):
        """`LocalOptGroup.transform` logs each rewrite when `optimizer_verbose` is on."""
        x = MyVariable("x")
        y = MyVariable("y")
        o1 = op1(x, y)

        fgraph = FunctionGraph([x, y], [o1], clone=False)

        @local_optimizer(None)
        def local_opt_1(fgraph, node):
            # Rewrite a node whose first input is `x` into `op2(y, ...)`.
            if not node.inputs[0] == x:
                return None
            return [op2(y, *node.inputs[1:])]

        @local_optimizer(None)
        def local_opt_2(fgraph, node):
            # Rewrite a node whose first input is `y` into `op2(x, ...)`.
            if not node.inputs[0] == y:
                return None
            return [op2(x, *node.inputs[1:])]

        opt_group = LocalOptGroup(local_opt_1, local_opt_2)

        with config.change_flags(optimizer_verbose=True):
            (new_res,) = opt_group.transform(fgraph, o1.owner)
            opt_group.transform(fgraph, new_res.owner)

        captured = capsys.readouterr()
        assert captured.err == ""
        assert (
            "optimizer: rewrite local_opt_1 replaces Op1(x, y) with [Op2.0]"
            in captured.out
        )
        assert (
            "optimizer: rewrite local_opt_2 replaces Op2(y, y) with [Op2.0]"
            in captured.out
        )
コード例 #23
0
 def test_nested_out_pattern(self):
     """The replacement pattern may itself contain nested sub-patterns."""
     x, y, z = inputs()
     fgraph = FunctionGraph([x, y, z], [op1(x, y)])
     in_pattern = (op1, "1", "2")
     out_pattern = (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
     PatternOptimizer(in_pattern, out_pattern).optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
コード例 #24
0
 def test_replace_output(self):
     """A pattern rooted at the output node replaces the whole graph."""
     x, y, z = inputs()
     fgraph = FunctionGraph([x, y, z], [op1(op2(x, y), z)])
     rewrite = PatternOptimizer(
         (op1, (op2, "1", "2"), "3"), (op4, "3", "2")
     )
     rewrite.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op4(z, y))"
コード例 #25
0
 def test_2(self):
     """`EquilibriumOptimizer` follows a chain of rewrites to a fixed point."""
     x, y, z = map(MyVariable, "xyz")
     g = FunctionGraph([x, y, z], [op1(op1(op3(x, y)))])
     # op3 -> op4 -> op5 -> op6 -> op2, and op1(op2(...)) -> op4(...).
     rewrites = [
         PatternSub((op1, (op2, "x", "y")), (op4, "x", "y")),
         PatternSub((op3, "x", "y"), (op4, "x", "y")),
         PatternSub((op4, "x", "y"), (op5, "x", "y")),
         PatternSub((op5, "x", "y"), (op6, "x", "y")),
         PatternSub((op6, "x", "y"), (op2, "x", "y")),
     ]
     EquilibriumOptimizer(rewrites, max_use_ratio=10).optimize(g)
     assert str(g) == "FunctionGraph(Op2(x, y))"
コード例 #26
0
 def test_deep_merge(self):
     """Identical nested sub-graphs built separately are merged into one."""
     x, y, z = inputs()
     # `op3(op2(x, y), z)` is constructed twice, as independent nodes.
     e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
     g = FunctionGraph([x, y, z], [e], clone=False)
     MergeOptimizer().optimize(g)
     first, second = g.outputs[0].owner.inputs
     # After merging, `op4` consumes the very variable that `op1` takes first.
     assert second.owner.inputs[0] is first
コード例 #27
0
 def test_allow_multiple_clients(self):
     """The pattern is not applied through a node that has several clients."""
     x, y, z = inputs()
     shared = op1(x, y)
     # `shared` has two clients: the `op4` node and the `op3` node.
     out = op3(op4(shared), shared)
     fgraph = FunctionGraph([x, y, z], [out])
     PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
コード例 #28
0
ファイル: test_fg.py プロジェクト: mrtommyb/aesara
    def test_pickle(self):
        """A pickled and restored `FunctionGraph` has the same structure."""
        var1 = op1()
        var2 = op2()
        var3 = op1(var1)
        var4 = op2(var3, var2)
        func = FunctionGraph([var1, var2], [var4])

        s = pickle.dumps(func)
        new_func = pickle.loads(s)

        # Inputs and outputs are ordered lists, so they can be compared
        # pairwise (after checking that `zip` will not truncate).
        assert len(func.inputs) == len(new_func.inputs)
        assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
        assert [v.type for v in func.inputs] == [v.type for v in new_func.inputs]
        assert len(func.outputs) == len(new_func.outputs)
        assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))

        # `apply_nodes` and `variables` are sets, so zipping them pairs up
        # elements in arbitrary order. Compare nodes in topological order
        # instead, which follows the graph structure.
        old_nodes = func.toposort()
        new_nodes = new_func.toposort()
        assert len(old_nodes) == len(new_nodes)
        for old_node, new_node in zip(old_nodes, new_nodes):
            assert type(old_node.op) is type(new_node.op)
            assert [o.type for o in old_node.outputs] == [
                o.type for o in new_node.outputs
            ]
        assert len(func.variables) == len(new_func.variables)
        assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
コード例 #29
0
 def test_straightforward(self):
     """Equal sub-expressions are merged; different ones stay separate."""
     x, y, z = inputs()
     e = op1(op2(x, y), op2(x, y), op2(x, z))
     g = FunctionGraph([x, y, z], [e], clone=False)
     MergeOptimizer().optimize(g)
     first, second, third = g.outputs[0].owner.inputs
     # The two `op2(x, y)` applications collapse into one variable ...
     assert first is second
     # ... while `op2(x, z)` is kept distinct.
     assert first is not third
コード例 #30
0
 def test_eq(self):
     """A pattern written with `op_z` matches a node built with `op_y`."""
     # NOTE(review): presumably `op_y` and `op_z` are distinct but equal
     # `Op` instances, so this checks that matching uses `Op` equality.
     x, y, z = map(MyVariable, "xyz")
     g = FunctionGraph([x, y, z], [op1(op_y(x, y), z)])
     # The pattern is rooted at the output node, so the whole graph is replaced.
     rewrite = PatternOptimizer(
         (op1, (op_z, "1", "2"), "3"), (op4, "3", "2")
     )
     rewrite.optimize(g)
     assert str(g) == "FunctionGraph(Op4(z, y))"