コード例 #1
0
 def test_replace_output(self):
     """Rewrite the entire output expression with a single pattern match."""
     x, y, z = inputs()
     out = op1(op2(x, y), z)
     fgraph = FunctionGraph([x, y, z], [out])
     # op1(op2(a, b), c)  ->  op4(c, b)
     rewrite = PatternOptimizer(
         (op1, (op2, "1", "2"), "3"),
         (op4, "3", "2"),
     )
     rewrite.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op4(z, y))"
コード例 #2
0
ファイル: test_opt.py プロジェクト: luke14free/Theano-PyMC
 def test_2(self):
     """Chain of substitutions driven to a fixed point by `EquilibriumOptimizer`.

     op3 is rewritten through op4 -> op5 -> op6 -> op2, after which the
     ``op1(op2(...))`` pattern fires and feeds back into the same chain,
     peeling off each ``op1`` until only ``Op2(x, y)`` is left.
     """
     x, y, z = map(MyVariable, "xyz")
     out = op1(op1(op3(x, y)))
     fgraph = FunctionGraph([x, y, z], [out])
     substitutions = [
         PatternSub((op1, (op2, "x", "y")), (op4, "x", "y")),
         PatternSub((op3, "x", "y"), (op4, "x", "y")),
         PatternSub((op4, "x", "y"), (op5, "x", "y")),
         PatternSub((op5, "x", "y"), (op6, "x", "y")),
         PatternSub((op6, "x", "y"), (op2, "x", "y")),
     ]
     equilibrium = EquilibriumOptimizer(substitutions, max_use_ratio=10)
     equilibrium.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op2(x, y))"
コード例 #3
0
ファイル: test_opt.py プロジェクト: luke14free/Theano-PyMC
 def test_nested_out_pattern(self):
     """An output pattern may nest several ops that reuse the matched inputs."""
     x, y, z = inputs()
     out = op1(x, y)
     fgraph = FunctionGraph([x, y, z], [out])
     in_pattern = (op1, "1", "2")
     out_pattern = (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
     PatternOptimizer(in_pattern, out_pattern).optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
コード例 #4
0
 def test_no_recurse(self):
     """With ``ign=True`` a self-matching output pattern is applied once.

     The output pattern ``(op2, "2", "1")`` is itself a valid input pattern;
     with the ignore-new-trees flag set, the optimizer performs the swap and
     does not rewrite the freshly created node again.
     """
     x, y, z = inputs()
     out = op1(op2(x, y), z)
     fgraph = FunctionGraph([x, y, z], [out])
     swap_args = PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True)
     swap_args.optimize(fgraph)
     # NOTE(review): most snippets here expect the "FunctionGraph(...)" repr,
     # while this one expects the list-style "[...]" repr -- confirm against
     # the `FunctionGraph.__str__` of the version under test.
     assert str(fgraph) == "[Op1(Op2(y, x), z)]"
コード例 #5
0
 def test_unification_1(self):
     """A pattern variable used twice matches when both arguments are identical."""
     x, y, z = inputs()
     # op2 receives the same variable in both argument slots.
     out = op1(op2(x, x), z)
     fgraph = FunctionGraph([x, y, z], [out])
     # "1" appears twice in the input pattern, so both slots must unify.
     in_pattern = (op1, (op2, "1", "1"), "2")
     out_pattern = (op4, "2", "1")
     PatternOptimizer(in_pattern, out_pattern).optimize(fgraph)
     # The arguments unify, so the rewrite must have happened.
     # NOTE(review): most snippets here expect the "FunctionGraph(...)" repr,
     # while this one expects the list-style "[...]" repr -- confirm against
     # the `FunctionGraph.__str__` of the version under test.
     assert str(fgraph) == "[Op4(z, x)]"
コード例 #6
0
 def test_identical_constant_args(self):
     """Two equal constants feeding the same node are merged into one."""
     x = MyVariable("x")
     y = Constant(MyType(), 2, name="y")
     z = Constant(MyType(), 2, name="z")
     # Build the node with test-value computation switched off.
     with config.change_flags(compute_test_value="off"):
         e1 = op1(y, z)
     fgraph = FunctionGraph([x, y, z], [e1])
     MergeOptimizer().optimize(fgraph)
     # Either constant may be the one that survives the merge.
     assert str(fgraph) in (
         "FunctionGraph(Op1(y, y))",
         "FunctionGraph(Op1(z, z))",
     )
コード例 #7
0
 def test_unification_2(self):
     """A pattern variable used twice does not match two different arguments."""
     x, y, z = inputs()
     # op2 receives two different variables here.
     out = op1(op2(x, y), z)
     fgraph = FunctionGraph([x, y, z], [out])
     # "1" appears twice in the input pattern, so both slots must unify.
     in_pattern = (op1, (op2, "1", "1"), "2")
     out_pattern = (op4, "2", "1")
     PatternOptimizer(in_pattern, out_pattern).optimize(fgraph)
     # x and y do not unify, so the graph must be left untouched.
     assert str(fgraph) == "FunctionGraph(Op1(Op2(x, y), z))"
コード例 #8
0
 def test_constant_merging(self):
     """Equal constants let otherwise-distinct applications merge together."""
     x = MyVariable("x")
     y = Constant(MyType(), 2, name="y")
     z = Constant(MyType(), 2, name="z")
     args = [op2(x, y), op2(x, y), op2(x, z)]
     fgraph = FunctionGraph([x, y, z], [op1(*args)])
     MergeOptimizer().optimize(fgraph)
     # All three op2 nodes collapse into one; either constant may survive.
     assert str(fgraph) in (
         "FunctionGraph(Op1(*1 -> Op2(x, y), *1, *1))",
         "FunctionGraph(Op1(*1 -> Op2(x, z), *1, *1))",
     )
コード例 #9
0
    def test_pickle(self):
        """A pickled and unpickled `FunctionGraph` keeps its structure and types.

        ``inputs`` and ``outputs`` are ordered lists and are compared
        position-wise, after checking their lengths, because ``zip`` silently
        truncates. ``apply_nodes`` and ``variables`` are sets, whose
        iteration orders in two independent graphs are unrelated. They are
        therefore compared as collections rather than by ``zip``-ing the two
        sets, which would pair arbitrary elements.
        """
        var1 = op1()
        var2 = op2()
        var3 = op1(var1)
        var4 = op2(var3, var2)
        func = FunctionGraph([var1, var2], [var4])

        s = pickle.dumps(func)
        new_func = pickle.loads(s)

        assert len(func.inputs) == len(new_func.inputs)
        assert all(
            type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
        assert len(func.outputs) == len(new_func.outputs)
        assert all(
            type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))

        def op_type_names(fg):
            # Sorted, fully qualified op class names: an order-independent
            # summary of the graph's apply nodes.
            return sorted(
                f"{type(node.op).__module__}.{type(node.op).__qualname__}"
                for node in fg.apply_nodes)

        assert op_type_names(func) == op_type_names(new_func)

        # Every original variable's type must appear among the unpickled
        # graph's variable types (pairwise scan; does not require hashable
        # types).
        assert len(func.variables) == len(new_func.variables)
        assert all(
            any(a.type == b.type for b in new_func.variables)
            for a in func.variables)
コード例 #10
0
def test_pre_greedy_local_optimizer():
    """`pre_greedy_local_optimizer` folds constant-only subgraphs outside the fgraph."""
    empty_fgraph = FunctionGraph([], [])

    x = MyVariable("x")
    y = MyVariable("y")
    c1 = Constant(MyType(), 1, "c1")
    c2 = Constant(MyType(), 2, "c2")
    o1 = op2(c1, c2)
    o3 = op1(c1, y)
    o2 = op1(o1, c2, x, o3, o1)

    # `o1` is used twice (input slots 0 and 4) and starts out as a computed
    # (owned) variable.
    assert o2.owner.inputs[0].owner is not None
    assert o2.owner.inputs[4].owner is not None

    # With an empty fgraph, `o1` has only `Constant` arguments and is folded
    # into a `Constant`; both of its uses must share the folded result.
    folded = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2)
    new_inputs = folded.owner.inputs
    assert new_inputs[0].owner is None
    assert new_inputs[1] is c2
    assert new_inputs[2] is x
    assert new_inputs[3] is o3
    assert new_inputs[4] is new_inputs[0]

    # Second pass: `o1` now belongs to `fg`, so it must not be folded.
    fg = FunctionGraph([], [o1], clone=False)
    o2 = op1(o1, c2, x, o3, o1)

    folded = pre_greedy_local_optimizer(fg, [constant_folding], o2)
    assert folded.owner.inputs[0] is o1
    assert folded.owner.inputs[4] is folded.owner.inputs[0]

    # A `MakeSlice` applied to a literal is expected to fold into a
    # `SliceConstant` whose signature is hashable.
    ms = MakeSlice()(1)
    folded = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms)
    assert isinstance(folded, SliceConstant)
    assert isinstance(hash(folded.signature()), int)
コード例 #11
0
def test_pre_constant_merge():
    """`pre_constant_merge` unifies equal constants unless the node is in the fgraph."""
    empty_fgraph = FunctionGraph([], [])

    x = MyVariable("x")
    y = MyVariable("y")
    # Two distinct Constant objects with the same type, value and name.
    c1 = Constant(MyType(), 1, "c1")
    c2 = Constant(MyType(), 1, "c1")
    o1 = op2(c1, x)
    o2 = op1(o1, y, c2)

    assert c1 is not c2

    # With an empty fgraph, the `c2` input of `o2` is replaced by `c1`.
    merged = pre_constant_merge(empty_fgraph, [o2])
    assert [o2] == merged
    assert o2.owner.inputs[2] is c1

    # Once `o2`'s node belongs to `fg`, its constant input is left alone.
    o2 = op1(o1, y, c2)
    fg = FunctionGraph([x, y], [o2], clone=False)
    assert o2.owner in fg.apply_nodes

    merged = pre_constant_merge(fg, [o2])
    assert merged == [o2]
    assert o2.owner.inputs[2] is c2

    # The result for a `MakeSlice` output must equal `[ms]`.
    ms = MakeSlice()(1)
    merged = pre_constant_merge(empty_fgraph, [ms])
    assert merged == [ms]

    const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
    assert isinstance(const_slice, Constant)

    adv = AdvancedSubtensor()(tt.matrix(), [2, 3], const_slice)

    # NOTE: `adv` is passed bare (not wrapped in a list), yet the result is
    # still expected to be a one-element list.
    merged = pre_constant_merge(empty_fgraph, adv)
    assert merged == [adv]
コード例 #12
0
 def test_identical_constant_args(self):
     """Two equal constants feeding the same node are merged into one."""
     x = MyVariable("x")
     y = Constant(MyType(), 2, name="y")
     z = Constant(MyType(), 2, name="z")
     # Switch test-value computation off while building the node, and
     # restore the previous setting even if construction raises.
     saved_flag = config.compute_test_value
     config.compute_test_value = "off"
     try:
         e1 = op1(y, z)
     finally:
         config.compute_test_value = saved_flag
     fgraph = FunctionGraph([x, y, z], [e1])
     MergeOptimizer().optimize(fgraph)
     # Either constant may be the one that survives the merge.
     assert str(fgraph) in (
         "FunctionGraph(Op1(y, y))",
         "FunctionGraph(Op1(z, z))",
     )
コード例 #13
0
ファイル: test_fg.py プロジェクト: kyleabeauchamp/Theano-PyMC
    def test_contains(self):
        """The `in` operator accepts both variables and apply nodes of the graph."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        for member in (var1, var3, var3.owner, var5, var5.owner):
            assert member in fg
コード例 #14
0
ファイル: test_fg.py プロジェクト: kyleabeauchamp/Theano-PyMC
    def test_replace_bad_state(self):
        """Replacing `var1` with a fresh variable not among the inputs raises `MissingInputError`."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Build the replacement outside `pytest.raises`, so that only the call
        # under test can produce the expected exception.
        var0 = MyVariable("var0")

        with pytest.raises(MissingInputError):
            # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
            # because it doesn't check for validity of the replacement *first*.
            fg.replace(var1, var0, verbose=True)
コード例 #15
0
    def test_replace_circular(self):
        """`FunctionGraph` accepts replacements that introduce a cycle."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # `var4` consumes `var3`; substituting one for the other closes a loop.
        fg.replace_all([(var3, var4)])

        # Because `var4`'s node was rewired in place, `var4` now feeds itself.
        assert fg.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var4, var2]
コード例 #16
0
 def test_init(self):
     """A new `FunctionGraph` records its inputs, outputs, nodes and clients."""
     var1 = MyVariable("var1")
     var2 = MyVariable("var2")
     var3 = op1(var1)
     var4 = op2(var3, var2)
     fg = FunctionGraph([var1, var2], [var3, var4], clone=False)

     assert fg.inputs == [var1, var2]
     assert fg.outputs == [var3, var4]
     assert fg.apply_nodes == {var3.owner, var4.owner}
     assert fg.update_mapping is None
     assert fg.check_integrity() is None
     assert fg.variables == {var1, var2, var3, var4}

     # Clients are (node, input index) pairs; graph outputs appear as
     # ("output", output index).
     expected_clients = [
         (var1, [(var3.owner, 0)]),
         (var2, [(var4.owner, 1)]),
         (var3, [(var4.owner, 0), ("output", 0)]),
         (var4, [("output", 1)]),
     ]
     for var, clients in expected_clients:
         assert fg.clients(var) == clients
コード例 #17
0
    def test_remove_client(self):
        """`remove_client` drops client entries and can remove a node entirely."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        assert fg.variables == {var1, var2, var3, var4, var5}
        assert fg.clients(var2) == [
            (var3.owner, 0),
            (var4.owner, 1),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        # Removing one entry keeps the remaining ones in order.
        fg.remove_client(var2, (var4.owner, 1))
        assert fg.clients(var2) == [
            (var3.owner, 0),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        # `var1` only fed input slot 1 of `var3`'s node.
        fg.remove_client(var1, (var3.owner, 1))
        assert fg.clients(var1) == []

        assert var4.owner in fg.apply_nodes

        # Dropping `var4`'s only client removes its `Apply` node and outputs
        # from the graph. That removal tries to drop `(var4.owner, 1)` from
        # `var2`'s clients again, even though it was already removed above;
        # this used to raise a `ValueError` because the entry was missing.
        fg.remove_client(var4, (var5.owner, 0), reason="testing")

        assert var4.owner not in fg.apply_nodes
        assert var4.owner.tag.removed_by == ["testing"]
        assert all(o not in fg.variables for o in var4.owner.outputs)
コード例 #18
0
    def test_replace_test_value(self):
        """`replace` refuses a replacement whose test value's shape differs."""
        var1 = MyVariable("var1")
        var1.tag.test_value = 1
        var2 = MyVariable("var2")
        var2.tag.test_value = 2
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var4.tag.test_value = np.array([1, 2])
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        replacement = op3()
        replacement.tag.test_value = np.array(0)

        # A 0-d test value versus a length-2 one: the shapes cannot agree.
        assert replacement.tag.test_value.shape != var4.tag.test_value.shape

        with pytest.raises(AssertionError, match="The replacement.*"):
            fg.replace(var4, replacement)
コード例 #19
0
ファイル: test_fg.py プロジェクト: kyleabeauchamp/Theano-PyMC
    def test_replace(self):
        """`replace` rejects incompatible types, and `replace_all` rewires clients."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # The types don't match and one cannot be converted to the other.
        # Build the replacement outside `pytest.raises`, so that only the call
        # under test can produce the expected exception.
        var0 = MyVariable2("var0")
        with pytest.raises(BadOptimization):
            fg.replace(var3, var0)

        # Test a basic replacement
        fg.replace_all([(var3, var1)])
        assert var3 not in fg.variables
        assert fg.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var1, var2]
コード例 #20
0
    def test_import_node(self):
        """`import_node` rejects nodes with unknown inputs and tags the nodes it imports."""
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # A node fed by a variable that `fg` does not know cannot be imported.
        # (Bound to a new name instead of rebinding `var5`, the graph output.)
        foreign = MyVariable("var5")
        orphan = op2(foreign)
        with pytest.raises(MissingInputError):
            fg.import_node(orphan.owner)

        # A node built on an existing graph input imports cleanly and is tagged.
        var6 = op2(var2)
        assert not hasattr(var6.owner.tag, "imported_by")
        fg.import_node(var6.owner)

        assert hasattr(var6.owner.tag, "imported_by")
        assert var6 in fg.variables
        assert var6.owner in fg.apply_nodes
        assert (var6.owner, 0) in var2.clients
コード例 #21
0
 def test_match_same(self):
     """Distinct pattern variables may still bind to the same graph variable."""
     x, y, z = inputs()
     fgraph = FunctionGraph([x, y, z], [op1(x, x)])
     rewrite = PatternOptimizer((op1, "x", "y"), (op3, "x", "y"))
     rewrite.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op3(x, x))"
コード例 #22
0
    def test_check_integrity(self):
        """Corrupt the graph's bookkeeping in several ways and check that
        `check_integrity` reports each corruption with a matching message.

        Each corruption is undone after its check, so later checks start
        from a consistent graph (except for the leftover `var6` object).

        NOTE(review): the corrupting mutations sit inside `pytest.raises`, so
        an unexpected error raised by a mutation itself would also be caught
        (the `match` pattern would likely reject it, though).
        """

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Drop an apply node from the node set behind the graph's back.
        with pytest.raises(Exception, match="The nodes are .*"):
            fg.apply_nodes.remove(var5.owner)

            fg.check_integrity()

        # Restore the node, but drop one of var2's client entries
        # (var2 is input 1 of var5's node).
        with pytest.raises(Exception, match="Inconsistent clients.*"):
            fg.apply_nodes.add(var5.owner)
            var2.clients.remove((var5.owner, 1))

            fg.check_integrity()

        # Restore var2's client entry.
        var2.clients.append((var5.owner, 1))

        # Drop a variable from the variable set.
        with pytest.raises(Exception, match="The variables are.*"):
            fg.variables.remove(var4)

            fg.check_integrity()

        fg.variables.add(var4)

        # Attach a variable as an extra (4th, index 3) input of var5's node
        # without making it a graph input or giving it an owner in the graph.
        with pytest.raises(Exception, match="Undeclared input.*"):
            var6 = MyVariable2("var6")
            var6.fgraph = fg
            var6.clients = [(var5.owner, 3)]
            fg.variables.add(var6)
            var5.owner.inputs.append(var6)

            fg.check_integrity()

        # Undo the extra input (var6's own `fgraph`/`clients` are left set).
        fg.variables.remove(var6)
        var5.owner.inputs.remove(var6)

        # Claim var4 is graph output 1; output 1 is actually var5.
        # TODO: What if the index value is greater than 1?  It will throw an
        # `IndexError`, but that doesn't sound like anything we'd want.
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            var4.clients.append(("output", 1))

            fg.check_integrity()

        var4.clients.remove(("output", 1))

        # Add a client node that is not part of the graph. var6 was built
        # directly via `MyVariable2(...)`, so its owner is presumably None.
        with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
            var4.clients.append((var6.owner, 0))

            fg.check_integrity()

        var4.clients.remove((var6.owner, 0))

        # Add a client entry for a node that does not use var4 at that slot
        # (input 0 of var3's node is var2).
        with pytest.raises(Exception, match="Inconsistent clients list.*"):
            var4.clients.append((var3.owner, 0))

            fg.check_integrity()
コード例 #23
0
 def test_no_merge(self):
     """Subgraphs that differ only in argument order must not be merged."""
     x, y, z = inputs()
     left = op3(op2(x, y))
     right = op3(op2(y, x))
     fgraph = FunctionGraph([x, y, z], [op1(left, right)])
     MergeOptimizer().optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
コード例 #24
0
 def test_deep_merge(self):
     """Identical multi-level subtrees are merged into a single shared node."""
     x, y, z = inputs()
     first = op3(op2(x, y), z)
     # `second` wraps another, structurally identical op3(op2(x, y), z).
     second = op4(op3(op2(x, y), z))
     fgraph = FunctionGraph([x, y, z], [op1(first, second)])
     MergeOptimizer().optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))"
コード例 #25
0
 def test_straightforward(self):
     """Duplicate applications merge; a different one is kept separate."""
     x, y, z = inputs()
     args = [op2(x, y), op2(x, y), op2(x, z)]
     fgraph = FunctionGraph([x, y, z], [op1(*args)])
     MergeOptimizer().optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, Op2(x, z)))"
コード例 #26
0
 def test_straightforward_2(self):
     """`OpSubOptimizer` swaps only the targeted op."""
     x, y, z = inputs()
     out = op1(op2(x), op3(y), op4(z))
     fgraph = FunctionGraph([x, y, z], [out])
     swap = OpSubOptimizer(op3, op4)
     swap.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
コード例 #27
0
 def test_straightforward(self):
     """`OpSubOptimizer` replaces every occurrence in a chain of the same op."""
     x, y, z = inputs()
     out = x
     for _ in range(5):
         out = op1(out)
     fgraph = FunctionGraph([x, y, z], [out])
     OpSubOptimizer(op1, op2).optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
コード例 #28
0
 def test_expand(self):
     """Wrap every op1 in op2 exactly once when new trees are ignored."""
     x, y, z = inputs()
     out = op1(op1(op1(x)))
     fgraph = FunctionGraph([x, y, z], [out])
     # The output contains op1("1") again; ign=True prevents re-expansion.
     wrap = PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True)
     wrap.optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
コード例 #29
0
 def test_nested_odd(self):
     """Cancelling op1 pairs on an odd-length chain leaves a single op1."""
     x, y, z = inputs()
     out = x
     for _ in range(5):
         out = op1(out)
     fgraph = FunctionGraph([x, y, z], [out])
     # op1(op1(a)) -> a
     PatternOptimizer((op1, (op1, "1")), "1").optimize(fgraph)
     assert str(fgraph) == "FunctionGraph(Op1(x))"