Example #1
    def test_match_same_illegal(self):
        x, y, z = inputs()
        e = op2(op1(x, x), op1(x, y))
        g = FunctionGraph([x, y, z], [e])

        def constraint(r):
            # Only replace when the two inputs are distinct variables
            return r.owner.inputs[0] is not r.owner.inputs[1]

        PatternOptimizer({
            "pattern": (op1, "x", "y"),
            "constraint": constraint
        }, (op3, "x", "y")).optimize(g)
        assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
Example #2
def test_jax_shape_ops():
    x_np = np.zeros((20, 3))
    x = Shape()(aet.as_tensor_variable(x_np))
    x_fg = FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [], must_be_device_array=False)

    x = Shape_i(1)(aet.as_tensor_variable(x_np))
    x_fg = FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [], must_be_device_array=False)

    x = SpecifyShape()(aet.as_tensor_variable(x_np), (20, 3))
    x_fg = FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [])

    with config.change_flags(compute_test_value="off"):
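        # (2, 3) contradicts the actual shape of `x_np`, (20, 3), so the
        # comparison should fail.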
        x = SpecifyShape()(aet.as_tensor_variable(x_np), (2, 3))
        x_fg = FunctionGraph([], [x])

        with pytest.raises(AssertionError):
            compare_jax_and_py(x_fg, [])
Example #3
    def test_constraints(self):
        x, y, z = inputs()
        e = op4(op1(op2(x, y)), op1(op1(x, y)))
        g = FunctionGraph([x, y, z], [e])

        def constraint(r):
            # Only replacing if the input is an instance of Op2
            return r.owner.op == op2

        PatternOptimizer((op1, {
            "pattern": "1",
            "constraint": constraint
        }), (op3, "1")).optimize(g)
        assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
Example #4
def test_nnet():
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    out = sigmoid(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_nnet.ultra_fast_sigmoid(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = softplus(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_nnet.softmax(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_nnet.logsoftmax(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example #5
def test_extra_ops_omni():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(
        aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
    )
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
Example #6
    def test_contains(self):

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        assert var1 in fg
        assert var3 in fg
        assert var3.owner in fg
        assert var5 in fg
        assert var5.owner in fg
Example #7
def test_jax_Join():
    a = matrix("a")
    b = matrix("b")

    x = aet.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    x = aet.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
Example #8
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, node, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op, **kwargs):
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
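    # `q + r` is constructed twice below, but the converted JAX function should
    # compute each shared output only once per call (see the `called` asserts).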
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2
Example #9
def test_local_optimizer():

    with pytest.raises(ValueError):

        @local_optimizer([])
        def local_bad_1(fgraph, node):
            return node.outputs

    with pytest.raises(TypeError):

        @local_optimizer([None])
        def local_bad_2(fgraph, node):
            return node.outputs

    x = MyVariable("x")
    y = MyVariable("y")

    o1 = op1(x, y)

    class MyNewOp(MyOp):
        pass

    o2 = MyNewOp("MyNewOp")(x, y)

    class MyNewOp2(MyOp):
        pass

    o3 = MyNewOp2("MyNewOp2")(x, y)

    fgraph = FunctionGraph([x, y], [o1, o2, o3], clone=False)

    hits = [0]

    @local_optimizer([op1, MyNewOp])
    def local_opt_1(fgraph, node, hits=hits):
        hits[0] += 1
        return node.outputs

    # This is allowed by the `op1` in `tracks`
    local_opt_1.transform(fgraph, fgraph.outputs[0].owner)
    assert hits[0] == 1

    # This is allowed by the `MyNewOp` in `tracks`
    local_opt_1.transform(fgraph, fgraph.outputs[1].owner)
    assert hits[0] == 2

    # This is not allowed by `tracks`
    local_opt_1.transform(fgraph, fgraph.outputs[2].owner)
    assert hits[0] == 2
Example #10
def test_random_stats(at_dist, dist_params, rng, size):
    # The RNG states are not 1:1, so the best we can do is check some summary
    # statistics of the samples
    out = at_dist(*dist_params, rng=rng, size=size)
    fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)

    def assert_fn(x, y):
        (x,) = x
        (y,) = y
        assert x.dtype.kind == y.dtype.kind

        d = 2 if config.floatX == "float64" else 1
        np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)

    compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
Example #11
    def test_fill(self):
        for linker, op, t, rval in zip(
            self.linkers,
            [self.op, self.cop],
            [self.type, self.ctype],
            [self.rand_val, self.rand_cval],
        ):
            x = t(aesara.config.floatX, (False, False))("x")
            y = t(aesara.config.floatX, (True, True))("y")
            # `Second` returns its second argument; with the `{0: 0}` inplace
            # pattern it overwrites `x` with (a broadcast of) `y`'s value.
            e = op(aes.Second(aes.transfer_type(0)), {0: 0})(x, y)
            f = make_function(linker().accept(FunctionGraph([x, y], [e])))
            xv = rval((5, 5))
            yv = rval((1, 1))
            f(xv, yv)
            assert (xv == yv).all()
Example #12
def test_jax_Dimshuffle():
    a_aet = matrix("a")

    x = a_aet.T
    x_fg = FunctionGraph([a_aet], [x])
    compare_jax_and_py(x_fg,
                       [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])

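    # In `dimshuffle`, "x" inserts a new broadcastable (length-1) dimension.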
    x = a_aet.dimshuffle([0, 1, "x"])
    x_fg = FunctionGraph([a_aet], [x])
    compare_jax_and_py(x_fg,
                       [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])

    a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
    x = a_aet.dimshuffle((0, ))
    x_fg = FunctionGraph([a_aet], [x])
    compare_jax_and_py(x_fg,
                       [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

    a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
    x = aet_elemwise.DimShuffle([False, True], (0, ))(a_aet)
    x_fg = FunctionGraph([a_aet], [x])
    compare_jax_and_py(x_fg,
                       [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
Example #13
def test_jax_FunctionGraph_names():
    import inspect

    from aesara.link.jax.dispatch import jax_funcify

    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    out_jx = jax_funcify(out_fg)
    sig = inspect.signature(out_jx)
    assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
    assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
Example #14
def test_jax_Composite(x, y, x_val, y_val):
    x_s = aes.float64("x")
    y_s = aes.float64("y")

    comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))

    out = comp_op(x, y)

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        x_val.astype(config.floatX),
        y_val.astype(config.floatX),
    ]
    _ = compare_jax_and_py(out_fg, test_input_vals)
Example #15
def test_patternsub_invalid_dtype(out_pattern):
    # PatternSub would wrongly return output of different dtype as the original node
    x = MyVariable("x")
    e = op_cast_type2(x)
    fg = FunctionGraph([x], [e])

    opt = EquilibriumOptimizer(
        [PatternSub(
            (op_cast_type2, "x"),
            out_pattern,
        )],
        max_use_ratio=1,
    )
    opt.optimize(fg)
    assert fg.apply_nodes.pop().op == op_cast_type2
Example #16
    def test_weird_strides(self):
        for linker, op, t, rval in zip(
            self.linkers,
            [self.op, self.cop],
            [self.type, self.ctype],
            [self.rand_val, self.rand_cval],
        ):
            x = t(aesara.config.floatX, (False,) * 5)("x")
            y = t(aesara.config.floatX, (False,) * 5)("y")
            e = op(aes.add)(x, y)
            f = make_function(linker().accept(FunctionGraph([x, y], [e])))
            xv = rval((2, 2, 2, 2, 2))
            yv = rval((2, 2, 2, 2, 2)).transpose(4, 0, 3, 1, 2)
            zv = xv + yv
            assert (f(xv, yv) == zv).all()
Example #17
    def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]:
        """Runs the scheduler (if set) or the toposort on the FunctionGraph.

        Parameters
        ----------
        fgraph : FunctionGraph
            A graph to compute the schedule for.

        Returns
        -------
        nodes : list of Apply nodes
            The result of the scheduling or toposort operation.
        """
        if callable(self._scheduler):
            return self._scheduler(fgraph)
        return fgraph.toposort()
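The method above only dispatches: if a `_scheduler` callable was set, it wins; otherwise the `FunctionGraph` is walked in topological order. A minimal sketch of exercising that hook follows (the `SketchLinker` class and `logging_scheduler` function are hypothetical illustrations for this write-up, not part of Aesara's API):

from typing import Callable, List, Optional

from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.tensor.type import vector


class SketchLinker:
    """Hypothetical minimal holder for the optional `_scheduler` hook."""

    def __init__(
        self, scheduler: Optional[Callable[[FunctionGraph], List[Apply]]] = None
    ):
        self._scheduler = scheduler

    def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
        # Same dispatch as above: prefer the custom scheduler, else toposort.
        if callable(self._scheduler):
            return self._scheduler(fgraph)
        return fgraph.toposort()


def logging_scheduler(fgraph: FunctionGraph) -> List[Apply]:
    # A trivial custom scheduler: the usual topological order, plus a log line.
    nodes = fgraph.toposort()
    print(f"scheduling {len(nodes)} apply nodes")
    return nodes


x = vector("x")
fg = FunctionGraph([x], [x + 1])
# Both paths yield the same node list here; only the second one logs.
assert SketchLinker().schedule(fg) == SketchLinker(logging_scheduler).schedule(fg)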
Example #18
def test_patternsub_different_output_lengths():
    # Test that PatternSub won't replace nodes with different numbers of outputs
    ps = PatternSub(
        (op1, "x"),
        "x",
        name="ps",
    )
    opt = in2out(ps)

    x = MyVariable("x")
    e1, e2 = op_multiple_outputs(x)
    o = op1(e1)

    fgraph = FunctionGraph(inputs=[x], outputs=[o])
    opt.optimize(fgraph)
    assert fgraph.outputs[0].owner.op == op1
Example #19
    def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
        """Runs the scheduler (if set) or the toposort on the FunctionGraph.

        Parameters
        ----------
        fgraph : :py:class:`aesara.graph.fg.FunctionGraph`
            A graph to compute the schedule for.

        Returns
        -------
        nodes : list of :py:class:`aesara.graph.basic.Apply` nodes
            The result of the scheduling or toposort operation.
        """
        if callable(self._scheduler):
            return self._scheduler(fgraph)
        return fgraph.toposort()
Example #20
    def test_multiple_merges(self):
        x, y, z = inputs()
        e1 = op1(x, y)
        e2 = op2(op3(x), y, z)
        e = op1(e1, op4(e2, e1), op1(e2))
        g = FunctionGraph([x, y, z], [e])
        MergeOptimizer().optimize(g)
        strg = str(g)
        # Note: `graph.as_string` can only produce the following two
        # possibilities, but, if the implementation were to change, there are
        # six other acceptable answers.
        assert (
            strg
            == "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))"
            or strg
            == "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))"
        )
Example #21
    def test_1(self):
        x, y, z = map(MyVariable, "xyz")
        # TODO FIXME: These `Op`s don't have matching/consistent `__props__`
        # and `__init__`s, so they can't be `etuplized` correctly
        e = op3(op4(x, y))
        g = FunctionGraph([x, y, z], [e])
        opt = EquilibriumOptimizer(
            [
                PatternSub((op1, "x", "y"), (op2, "x", "y")),
                PatternSub((op4, "x", "y"), (op1, "x", "y")),
                PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
            ],
            max_use_ratio=10,
        )
        opt.optimize(g)
        assert str(g) == "FunctionGraph(Op2(x, y))"
Example #22
def test_optimize_graph():

    x, y = vectors("xy")

    @optimizer
    def custom_opt(fgraph):
        fgraph.replace(x, y, import_missing=True)

    x_opt = optimize_graph(x, custom_opt=custom_opt)

    assert x_opt is y

    x_opt = optimize_graph(FunctionGraph(outputs=[x], clone=False),
                           custom_opt=custom_opt)

    assert x_opt.outputs[0] is y
Example #23
def test_KanrenRelationSub_filters():
    x_at = at.vector("x")
    y_at = at.vector("y")
    z_at = at.vector("z")
    A_at = at.matrix("A")

    fact(commutative, _dot)
    fact(commutative, at.add)
    fact(associative, at.add)

    Z_at = A_at.dot((x_at + y_at) + z_at)

    fgraph = FunctionGraph(outputs=[Z_at], clone=False)

    def distributes(in_lv, out_lv):
        A_lv, x_lv, y_lv, z_lv = vars(4)
        return lall(
            # lhs == A * (x + y + z)
            eq_assoccomm(
                etuple(_dot, A_lv,
                       etuple(at.add, x_lv, etuple(at.add, y_lv, z_lv))),
                in_lv,
            ),
            # This relation does nothing but provide us with a means of
            # generating associative-commutative matches in the `kanren`
            # output.
            eq((A_lv, x_lv, y_lv, z_lv), out_lv),
        )

    def results_filter(results):
        _results = [eval_if_etuple(v) for v in results]

        # Make sure that at least a couple permutations are present
        assert (A_at, x_at, y_at, z_at) in _results
        assert (A_at, y_at, x_at, z_at) in _results
        assert (A_at, z_at, x_at, y_at) in _results

        return None

    _ = KanrenRelationSub(distributes,
                          results_filter=results_filter).transform(
                              fgraph, fgraph.outputs[0].owner)

    res = KanrenRelationSub(distributes,
                            node_filter=lambda x: False).transform(
                                fgraph, fgraph.outputs[0].owner)
    assert res is False
Example #24
    def test_one_assert_merge(self):
        # Merge two nodes, one has assert, the other not.
        x1 = matrix("x1")
        x2 = matrix("x2")
        e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
        g = FunctionGraph([x1, x2], [e], clone=False)
        MergeOptimizer().optimize(g)

        assert g.outputs[0].owner.op == add
        add_inputs = g.outputs[0].owner.inputs
        assert isinstance(add_inputs[0].owner.op, Dot)
        # Confirm that the `Assert`s are correct
        assert_var = add_inputs[0].owner.inputs[0]
        assert_ref = assert_op(x1, (x1 > x2).all())
        assert equal_computations([assert_var], [assert_ref])
        # Confirm the merge
        assert add_inputs[0] is add_inputs[1]
Example #25
    def test_c_views(self):
        x_at = vector()
        thunk, inputs, outputs = (CLinker().accept(
            FunctionGraph([x_at], [x_at[None]])).make_thunk())

        # This is a little hackish, but we're hoping that--by running this more than
        # a few times--we're more likely to run into random memory that isn't the same
        # as the broadcasted value; that way, we'll be able to tell that we're getting
        # junk data from a poorly constructed array view.
        x_val = np.broadcast_to(2039, (5000, ))
        for i in range(1000):
            inputs[0].storage[0] = x_val
            thunk()
            # Make sure it's a view of the original data
            assert np.shares_memory(x_val, outputs[0].storage[0])
            # Confirm the broadcasted value in the output
            assert np.array_equiv(outputs[0].storage[0], 2039)
Example #26
    def test_pickle(self):
        var1 = op1()
        var2 = op2()
        var3 = op1(var1)
        var4 = op2(var3, var2)
        func = FunctionGraph([var1, var2], [var4])

        s = pickle.dumps(func)
        new_func = pickle.loads(s)

        assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs))
        assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs))
        assert all(
            type(a.op) is type(b.op)  # noqa: E721
            for a, b in zip(func.apply_nodes, new_func.apply_nodes)
        )
        assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
Example #27
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
Example #28
def test_KanrenRelationSub_dot():
    """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
    x_at = at.vector("x")
    c_at = at.vector("c")
    d_at = at.vector("d")
    A_at = at.matrix("A")
    B_at = at.matrix("B")

    Z_at = A_at.dot(x_at + B_at.dot(c_at + d_at))

    fgraph = FunctionGraph(outputs=[Z_at], clone=False)

    assert isinstance(fgraph.outputs[0].owner.op, Dot)

    def distributes(in_lv, out_lv):
        return lall(
            # lhs == A * (x + b)
            eq(
                etuple(_dot, var("A"), etuple(at.add, var("x"), var("b"))),
                in_lv,
            ),
            # rhs == A * x + A * b
            eq(
                etuple(
                    at.add,
                    etuple(_dot, var("A"), var("x")),
                    etuple(_dot, var("A"), var("b")),
                ),
                out_lv,
            ),
        )

    distribute_opt = EquilibriumOptimizer([KanrenRelationSub(distributes)],
                                          max_use_ratio=10)

    fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt)
    (expr_opt, ) = fgraph_opt.outputs

    assert expr_opt.owner.op == at.add
    assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
    assert fgraph_opt.inputs[0] is A_at
    assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
    assert expr_opt.owner.inputs[1].owner.op == at.add
    assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
    assert isinstance(expr_opt.owner.inputs[1].owner.inputs[1].owner.op, Dot)
Example #29
def test_fgraph_to_python_multiline_str():
    """Make sure that multiline `__str__` values are supported by `fgraph_to_python`."""

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            super().__init__()

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, node, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

        def __str__(self):
            return "Test\nOp()"

    @to_python.register(TestOp)
    def to_python_TestOp(op, **kwargs):
        def func(*args, op=op):
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_py = fgraph_to_python(out_fg, to_python)

    out_py_src = inspect.getsource(out_py)

    assert ("""
    # Elemwise{add,no_inplace}(Test
    # Op().0, Test
    # Op().1)
    """ in out_py_src)