Example #1
    def test_c_views(self):
        x_at = vector()
        thunk, inputs, outputs = (CLinker().accept(
            FunctionGraph([x_at], [x_at[None]])).make_thunk())

        # This is a little hackish, but we're hoping that--by running this more than
        # a few times--we're more likely to run into random memory that isn't the same
        # as the broadcasted value; that way, we'll be able to tell that we're getting
        # junk data from a poorly constructed array view.
        x_val = np.broadcast_to(2039, (5000, ))
        for i in range(1000):
            inputs[0].storage[0] = x_val
            thunk()
            # Make sure it's a view of the original data
            assert np.shares_memory(x_val, outputs[0].storage[0])
            # Confirm the broadcasted value in the output
            assert np.array_equiv(outputs[0].storage[0], 2039)
Example #2
    def test_2(self):
        x, y, z = map(MyVariable, "xyz")
        e = op1(op1(op3(x, y)))
        g = FunctionGraph([x, y, z], [e])
        # print g
        opt = EquilibriumOptimizer(
            [
                PatternSub((op1, (op2, "x", "y")), (op4, "x", "y")),
                PatternSub((op3, "x", "y"), (op4, "x", "y")),
                PatternSub((op4, "x", "y"), (op5, "x", "y")),
                PatternSub((op5, "x", "y"), (op6, "x", "y")),
                PatternSub((op6, "x", "y"), (op2, "x", "y")),
            ],
            max_use_ratio=10,
        )
        opt.optimize(g)
        assert str(g) == "FunctionGraph(Op2(x, y))"
Example #3
    def test_pickle(self):
        var1 = op1()
        var2 = op2()
        var3 = op1(var1)
        var4 = op2(var3, var2)
        func = FunctionGraph([var1, var2], [var4])

        s = pickle.dumps(func)
        new_func = pickle.loads(s)

        assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs))
        assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs))
        assert all(
            type(a.op) is type(b.op)  # noqa: E721
            for a, b in zip(func.apply_nodes, new_func.apply_nodes)
        )
        assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
Example #4
    def test_one_assert_merge(self):
        """Merge two nodes, one has assert, the other not."""
        x1 = matrix("x1")
        x2 = matrix("x2")
        e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
        g = FunctionGraph([x1, x2], [e], clone=False)
        MergeOptimizer().optimize(g)

        assert g.outputs[0].owner.op == add
        add_inputs = g.outputs[0].owner.inputs
        assert isinstance(add_inputs[0].owner.op, Dot)
        # Confirm that the `Assert`s are correct
        assert_var = add_inputs[0].owner.inputs[0]
        assert_ref = assert_op(x1, (x1 > x2).all())
        assert equal_computations([assert_var], [assert_ref])
        # Confirm the merge
        assert add_inputs[0] is add_inputs[1]
Example #5
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
Example #6
def test_KanrenRelationSub_dot():
    """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
    x_at = at.vector("x")
    c_at = at.vector("c")
    d_at = at.vector("d")
    A_at = at.matrix("A")
    B_at = at.matrix("B")

    Z_at = A_at.dot(x_at + B_at.dot(c_at + d_at))

    fgraph = FunctionGraph(outputs=[Z_at], clone=False)

    assert isinstance(fgraph.outputs[0].owner.op, Dot)

    def distributes(in_lv, out_lv):
        return lall(
            # lhs == A * (x + b)
            eq(
                etuple(_dot, var("A"), etuple(at.add, var("x"), var("b"))),
                in_lv,
            ),
            # rhs == A * x + A * b
            eq(
                etuple(
                    at.add,
                    etuple(_dot, var("A"), var("x")),
                    etuple(_dot, var("A"), var("b")),
                ),
                out_lv,
            ),
        )

    distribute_opt = EquilibriumOptimizer([KanrenRelationSub(distributes)],
                                          max_use_ratio=10)

    fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt)
    (expr_opt, ) = fgraph_opt.outputs

    assert expr_opt.owner.op == at.add
    assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
    assert fgraph_opt.inputs[0] is A_at
    assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
    assert expr_opt.owner.inputs[1].owner.op == at.add
    assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
    assert isinstance(expr_opt.owner.inputs[1].owner.inputs[1].owner.op, Dot)
Example #7
def test_fgraph_to_python_multiline_str():
    """Make sure that multiline `__str__` values are supported by `fgraph_to_python`."""

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            super().__init__()

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

        def __str__(self):
            return "Test\nOp()"

    @to_python.register(TestOp)
    def to_python_TestOp(op, **kwargs):
        def func(*args, op=op):
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_py = fgraph_to_python(out_fg, to_python)

    out_py_src = inspect.getsource(out_py)

    assert ("""
    # Elemwise{add,no_inplace}(Test
    # Op().0, Test
    # Op().1)
    """ in out_py_src)
Example #8
    def test_1(self):
        x, y, z = map(MyVariable, "xyz")
        # TODO FIXME: These `Op`s don't have matching/consistent `__props__`
        # and `__init__`s, so they can't be `etuplized` correctly
        e = op3(op4(x, y))
        g = FunctionGraph([x, y, z], [e])
        # print g
        opt = EquilibriumOptimizer(
            [
                PatternSub((op1, "x", "y"), (op2, "x", "y")),
                PatternSub((op4, "x", "y"), (op1, "x", "y")),
                PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
            ],
            max_use_ratio=10,
        )
        opt.optimize(g)
        # print g
        assert str(g) == "FunctionGraph(Op2(x, y))"
Example #9
    def test_remove_output_empty(self):

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var1)
        op3_out = op3(op1_out, var2)
        fg = FunctionGraph([var1, var2], [op3_out], clone=False)

        fg.remove_output(0)
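        # With the graph's only output removed, all of its `Apply` nodes and
        # client entries should have been pruned (verified below).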
        fg.check_integrity()

        assert fg.inputs == [var1, var2]
        assert not fg.apply_nodes
        assert op1_out not in fg.clients
        assert not any(op1_out.owner in clients
                       for clients in sum(fg.clients.values(), []))
        assert not any(op3_out.owner in clients
                       for clients in sum(fg.clients.values(), []))
Example #10
def test_local_dimshuffle_alloc():

    reshape_dimshuffle = out2in(local_dimshuffle_alloc)

    x = vector("x")

    out = aet.alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    l = PerformLinker()
    l.accept(g)
    f = l.make_function()

    assert f([3, 4]).ndim == 4

    topo = g.toposort()
    assert any(not isinstance(x, DimShuffle) for x in topo)
Example #11
def test_jax_logp():

    mu = vector("mu")
    mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX)
    tau = vector("tau")
    tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX)
    sigma = vector("sigma")
    sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX)
    value = vector("value")
    value.tag.test_value = np.r_[0.1, -10].astype(config.floatX)

    logp = (-tau * (value - mu)**2 + log(tau / np.pi / 2.0)) / 2.0
    conditions = [sigma > 0]
    alltrue = aet_all([aet_all(1 * val) for val in conditions])
    normal_logp = aet.switch(alltrue, logp, -np.inf)

    fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example #12
def test_jax_SolveTriangular(trans, lower, check_finite):
    x = matrix("x")
    b = vector("b")

    out = at_slinalg.solve_triangular(
        x,
        b,
        trans=trans,
        lower=lower,
        check_finite=check_finite,
    )
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )
Example #13
    def test_both_assert_merge_1(self):
        # Merge two nodes, both have assert on the same node
        # with different conditions.
        x1 = matrix("x1")
        x2 = matrix("x2")
        x3 = matrix("x3")
        e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot(
            assert_op(x1, (x1 > x2).all()), x2)
        g = FunctionGraph([x1, x2, x3], [e])
        MergeOptimizer().optimize(g)
        strg = aesara.printing.debugprint(g, file="str")
        strref1 = """Elemwise{add,no_inplace} [id A] ''   6
 |dot [id B] ''   5
 | |Assert{msg='Aesara Assert failed!'} [id C] ''   4
 | | |x1 [id D]
 | | |All [id E] ''   3
 | | | |Elemwise{gt,no_inplace} [id F] ''   1
 | | |   |x1 [id D]
 | | |   |x3 [id G]
 | | |All [id H] ''   2
 | |   |Elemwise{gt,no_inplace} [id I] ''   0
 | |     |x1 [id D]
 | |     |x2 [id J]
 | |x2 [id J]
 |dot [id B] ''   5
"""
        strref2 = """Elemwise{add,no_inplace} [id A] ''   6
 |dot [id B] ''   5
 | |Assert{msg='Aesara Assert failed!'} [id C] ''   4
 | | |x1 [id D]
 | | |All [id E] ''   3
 | | | |Elemwise{gt,no_inplace} [id F] ''   1
 | | |   |x1 [id D]
 | | |   |x2 [id G]
 | | |All [id H] ''   2
 | |   |Elemwise{gt,no_inplace} [id I] ''   0
 | |     |x1 [id D]
 | |     |x3 [id J]
 | |x2 [id G]
 |dot [id B] ''   5
"""
        # print(strg)
        assert strg == strref1 or strg == strref2, (strg, strref1, strref2)
Example #14
def test_normal_ShapeFeature():
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    d_rv = normal(aet.ones((M_aet, )), sd_aet, size=(2, M_aet))
    d_rv.tag.test_value

    fg = FunctionGraph(
        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )
    s1, s2 = fg.shape_feature.shape_of[d_rv]

    assert get_test_value(s1) == get_test_value(d_rv).shape[0]
    assert get_test_value(s2) == get_test_value(d_rv).shape[1]
Example #15
    def test_replace_test_value(self):

        var1 = MyVariable("var1")
        var1.tag.test_value = 1
        var2 = MyVariable("var2")
        var2.tag.test_value = 2
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var4.tag.test_value = np.array([1, 2])
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        var6 = op3()
        var6.tag.test_value = np.array(0)

        assert var6.tag.test_value.shape != var4.tag.test_value.shape

        with pytest.raises(AssertionError, match="The replacement.*"):
            fg.replace(var4, var6)
Example #16
def test_fgraph_to_python_names():
    import inspect

    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")
    r = NoneConst

    out_fg = FunctionGraph([x, y, z, q, r], [x, y, z, q, r], clone=False)
    out_jx = fgraph_to_python(out_fg, to_python)

    sig = inspect.signature(out_jx)
    assert (x.auto_name, "_", z.auto_name, q.auto_name,
            r.name) == tuple(sig.parameters.keys())
    assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)

    obj = object()
    assert get_name_for_object(obj) == type(obj).__name__
Example #17
    def test_both_assert_merge_identical(self):
        # Merge two nodes, both have assert on the same node
        # with the same conditions.
        x1 = matrix("x1")
        x2 = matrix("x2")
        e = dot(assert_op(x1, (x1 > x2).all()), x2) + dot(
            assert_op(x1, (x1 > x2).all()), x2)
        g = FunctionGraph([x1, x2], [e], clone=False)
        MergeOptimizer().optimize(g)

        assert g.outputs[0].owner.op == add
        add_inputs = g.outputs[0].owner.inputs
        assert isinstance(add_inputs[0].owner.op, Dot)
        # Confirm that the `Assert`s are correct
        assert_var = add_inputs[0].owner.inputs[0]
        assert_ref = assert_op(x1, (x1 > x2).all())
        assert equal_computations([assert_var], [assert_ref])
        # Confirm the merge
        assert add_inputs[0] is add_inputs[1]
Example #18
def test_dirichlet_ShapeFeature():
    """Make sure `RandomVariable.infer_shape` works with `ShapeFeature`."""
    M_at = iscalar("M")
    M_at.tag.test_value = 2
    N_at = iscalar("N")
    N_at.tag.test_value = 3

    d_rv = dirichlet(at.ones((M_at, N_at)), name="Gamma")

    fg = FunctionGraph(
        outputs=[d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    s1, s2 = fg.shape_feature.shape_of[d_rv]

    assert M_at in graph_inputs([s1])
    assert N_at in graph_inputs([s2])
Example #19
    def test_remove_client(self):
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        assert fg.variables == {var1, var2, var3, var4, var5}
        assert fg.get_clients(var2) == [
            (var3.owner, 0),
            (var4.owner, 1),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        fg.remove_client(var2, (var4.owner, 1))

        assert fg.get_clients(var2) == [
            (var3.owner, 0),
            (var5.owner, 1),
            (var5.owner, 2),
        ]

        fg.remove_client(var1, (var3.owner, 1))

        assert fg.get_clients(var1) == []

        assert var4.owner in fg.apply_nodes

        # This next `remove_client` should trigger a complete removal of `var4`'s
        # variables and `Apply` node from the `FunctionGraph`.
        #
        # Also, notice that we already removed `var4` from `var2`'s client list
        # above, so, when we completely remove `var4`, `fg.remove_client` will
        # attempt to remove `(var4.owner, 1)` from `var2`'s client list again.
        # This attempt would previously raise a `ValueError` exception, because
        # the entry was not in the list.
        fg.remove_client(var4, (var5.owner, 0), reason="testing")

        assert var4.owner not in fg.apply_nodes
        assert var4.owner.tag.removed_by == ["testing"]
        assert not any(o in fg.variables for o in var4.owner.outputs)
Example #20
    def test_replace(self):

        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        with pytest.raises(TypeError):
            var0 = MyVariable2("var0")
            # The types don't match and one cannot be converted to the other
            fg.replace(var3, var0)

        # Test a basic replacement
        fg.replace_all([(var3, var1)])
        assert var3 not in fg.variables
        assert fg.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var1, var2]
Example #21
def test_Stack_updates():

    a = scalar("a")
    a_plus_1 = a + 1
    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)

    nodes = fg.toposort()
    input_storage, output_storage, storage_map = map_storage(
        fg, nodes, None, None, None)

    compute_map = {}
    for k in storage_map:
        compute_map[k] = [k.owner is None]

    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, [])
        for node in nodes
    ]

    assert a in storage_map

    update_vars = {a: a_plus_1}
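    # Mapping `a` to `a_plus_1` makes the VM write the value computed for
    # `a + 1` back into `a`'s storage after each call (checked at the end).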

    stack_vm = Stack(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        update_vars,
        compute_map,
        False,
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)

    res = stack_vm()

    assert res == [np.array(1.0), np.array(2.0)]
    assert storage_map[a][0] == np.array(2.0)
Example #22
    def test_one_assert_merge(self):
        # Merge two nodes when one has an `Assert` and the other does not.
        x1 = matrix("x1")
        x2 = matrix("x2")
        e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
        g = FunctionGraph([x1, x2], [e])
        MergeOptimizer().optimize(g)
        strg = aesara.printing.debugprint(g, file="str")
        strref = """Elemwise{add,no_inplace} [id A] ''   4
 |dot [id B] ''   3
 | |Assert{msg='Aesara Assert failed!'} [id C] ''   2
 | | |x1 [id D]
 | | |All [id E] ''   1
 | |   |Elemwise{gt,no_inplace} [id F] ''   0
 | |     |x1 [id D]
 | |     |x2 [id G]
 | |x2 [id G]
 |dot [id B] ''   3
"""
        assert strg == strref, (strg, strref)
Example #23
def test_random_unimplemented():
    class NonExistentRV(RandomVariable):
        name = "non-existent"
        ndim_supp = 0
        ndims_params = []
        dtype = "floatX"

        def __call__(self, size=None, **kwargs):
            return super().__call__(size=size, **kwargs)

        def rng_fn(cls, rng, size):
            return 0

    nonexistentrv = NonExistentRV()
    rng = shared(np.random.RandomState(123))
    out = nonexistentrv(rng=rng)
    fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)

    with pytest.raises(NotImplementedError):
        compare_jax_and_py(fgraph, [])
Example #24
    def test_remove_duplicates(self):
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        op1_out = op1(var2, var1)
        op2_out = op2(op1_out, var2)
        op3_out = op3(op2_out, var2, var2)
        fg = FunctionGraph([var1, var1, var2], [op1_out, op3_out, op3_out],
                           clone=False)

        fg.remove_output(2)
        fg.check_integrity()

        assert fg.outputs == [op1_out, op3_out]

        fg.remove_input(0)
        fg.check_integrity()

        assert var1 not in fg.variables
        assert fg.inputs == [var1, var2]
        assert fg.outputs == []
Example #25
def test_dirichlet_ShapeFeature():
    """Make sure `RandomVariable.infer_shape` works with `ShapeFeature`."""
    M_tt = iscalar("M")
    M_tt.tag.test_value = 2
    N_tt = iscalar("N")
    N_tt.tag.test_value = 3

    d_rv = dirichlet(aet.ones((M_tt, N_tt)), name="Gamma")

    fg = FunctionGraph(
        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    s1, s2 = fg.shape_feature.shape_of[d_rv]

    assert M_tt in graph_inputs([s1])
    assert N_tt in graph_inputs([s2])
Example #26
    def test_composite_printing(self):
        x, y, z = floats("xyz")
        e0 = x + y + z
        e1 = x + y * z
        e2 = x / y
        e3 = x // 5
        e4 = -x
        e5 = x - y
        e6 = x**y + (-z)
        e7 = x % 3
        C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7])
        c = C.make_node(x, y, z)
        g = FunctionGraph([x, y, z], c.outputs)
        DualLinker().accept(g).make_function()

        assert str(g) == ("FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
                          " (i0 + (i1 * i2)), (i0 / i1), "
                          "(i0 // ScalarConstant{5}), "
                          "(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
                          " (i0 % ScalarConstant{3})}(x, y, z), "
                          "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)")
Example #27
def test_duallinker_mismatch():
    x, y, z = inputs()
    # bad_sub is correct in C but erroneous in Python
    e = bad_sub(mul(x, y), mul(y, z))
    g = FunctionGraph([x, y, z], [e])
    lnk = DualLinker(checker=_my_checker).accept(g)
    fn = make_function(lnk)

    # good
    assert make_function(CLinker().accept(g))(1.0, 2.0, 3.0) == -4.0
    # good
    assert make_function(OpWiseCLinker().accept(g))(1.0, 2.0, 3.0) == -4.0

    # (purposely) wrong
    assert make_function(PerformLinker().accept(g))(1.0, 2.0, 3.0) == -10.0

    with pytest.raises(MyExc):
        # this runs OpWiseCLinker and PerformLinker in parallel and feeds
        # variables of matching operations to _my_checker to verify that they
        # are the same.
        fn(1.0, 2.0, 3.0)
Example #28
    def test_both_assert_merge_2_reverse(self):
        # Test case "test_both_assert_merge_2" but in reverse order
        x1 = matrix("x1")
        x2 = matrix("x2")
        x3 = matrix("x3")
        e = dot(x1, assert_op(x2, (x2 > x3).all())) + dot(
            assert_op(x1, (x1 > x3).all()), x2)
        g = FunctionGraph([x1, x2, x3], [e], clone=False)
        MergeOptimizer().optimize(g)

        assert g.outputs[0].owner.op == add
        add_inputs = g.outputs[0].owner.inputs
        assert isinstance(add_inputs[0].owner.op, Dot)
        # Confirm that the `Assert`s are correct
        assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
        assert_ref_1 = assert_op(x2, (x2 > x3).all())
        assert equal_computations([assert_var_1], [assert_ref_1])
        assert_ref_2 = assert_op(x1, (x1 > x3).all())
        assert equal_computations([assert_var_2], [assert_ref_2])
        # Confirm the merge
        assert add_inputs[0] is add_inputs[1]
Example #29
def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> Callable:
    """Compile an Aesara graph into an optimized JAX function."""

    graph = _replace_shared_variables(outputs)

    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    # We need to add a Supervisor to the fgraph to be able to run the
    # JAX sequential optimizer without warnings. We made sure there
    # are no mutable input variables, so we only need to check for
    # "destroyers". This should be automatically handled by Aesara
    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
    fgraph.attach_feature(
        Supervisor(input for input in fgraph.inputs if not (
            hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))))
    mode.JAX.optimizer.optimize(fgraph)

    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
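A minimal usage sketch for the helper above (the input variable, the expression, and the unpacking of the result are illustrative assumptions; the compiled function takes the graph inputs positionally and returns the outputs as a sequence):

import aesara.tensor as at
import numpy as np

x = at.vector("x")
y = at.exp(x).sum()  # any Aesara expression built from `x` (assumed example)

# Compile the Aesara graph into a JAX-backed callable and evaluate it.
jax_fn = get_jaxified_graph(inputs=[x], outputs=[y])
(result,) = jax_fn(np.arange(3.0))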
Example #30
def optimize_graph(
    fgraph: Union[Variable, FunctionGraph],
    include: Sequence[str] = ["canonicalize"],
    custom_opt=None,
    clone: bool = False,
    **kwargs
) -> Union[Variable, FunctionGraph]:
    """Easily optimize a graph.

    Parameters
    ==========
    fgraph:
        A ``FunctionGraph`` or ``Variable`` to be optimized.
    include:
        String names of the optimizations to be applied.  The default
        optimization is ``"canonicalize"``.
    custom_opt:
        A custom ``Optimization`` to also be applied.
    clone:
        Whether or not to clone the input graph before optimizing.
    **kwargs:
        Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object.
    """
    from aesara.compile import optdb

    return_only_out = False
    if not isinstance(fgraph, FunctionGraph):
        fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
        return_only_out = True

    canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
    _ = canonicalize_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    if return_only_out:
        return fgraph.outputs[0]
    else:
        return fgraph
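A minimal usage sketch (the expression is an illustrative assumption; per the docstring, passing a ``Variable`` returns a ``Variable``, while passing a ``FunctionGraph`` returns a ``FunctionGraph``):

import aesara.tensor as at

x = at.vector("x")
y = at.log(at.exp(x))  # an arbitrary expression to rewrite (assumed example)

# A bare Variable is wrapped in a FunctionGraph internally; the rewritten
# output Variable is returned.
y_opt = optimize_graph(y)

# An existing FunctionGraph is rewritten in place and returned.
fg = FunctionGraph(outputs=[y], clone=True)
fg_opt = optimize_graph(fg, include=["canonicalize"])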