Exemple #1
0
 def test_gpu_matrix_inverse_inplace_opt(self):
     """Check that the compiled GPU graph uses an in-place MAGMA matrix inverse.

     Compiles ``matrix_inverse`` of an ``fmatrix`` input with ``mode_with_gpu``
     and asserts that at least one ``GpuMagmaMatrixInverse`` node in the
     compiled graph has its ``inplace`` flag set.
     """
     A = fmatrix("A")
     fn = aesara.function([A], matrix_inverse(A), mode=mode_with_gpu)
     # Generator rather than a list: ``any`` stops at the first in-place node.
     assert any(
         node.op.inplace
         for node in fn.maker.fgraph.toposort()
         if isinstance(node.op, GpuMagmaMatrixInverse)
     )
Exemple #2
0
def test_inverse_singular():
    """Inverting a singular matrix must raise ``np.linalg.LinAlgError``."""
    # The last two rows are identical, so the matrix cannot be inverted.
    rows = [[1, 0, 0], [0, 1, 0], [0, 1, 0]]
    singular = np.array(rows, dtype=aesara.config.floatX)
    a = tensor.matrix()
    f = function([a], matrix_inverse(a))
    with pytest.raises(np.linalg.LinAlgError):
        f(singular)
Exemple #3
0
def test_transinv_to_invtrans():
    """Check that ``inv(X).T`` is compiled so the inverse consumes a transposed ``X``.

    After compilation, every ``MatrixInverse`` node must take a ``DimShuffle``
    output as input, and every ``DimShuffle`` must read directly from ``X``.
    The check is skipped in ``FAST_COMPILE`` mode (presumably because the
    rewrite is not applied there).
    """
    X = matrix("X")
    Y = matrix_inverse(X)
    Z = Y.transpose()
    f = aesara.function([X], Z)
    if config.mode != "FAST_COMPILE":
        found_inverse = False
        for node in f.maker.fgraph.toposort():
            if isinstance(node.op, MatrixInverse):
                found_inverse = True
                # A graph input has no owner; fail with an assertion, not an
                # AttributeError, when the transpose was not pushed inside.
                owner = node.inputs[0].owner
                assert owner is not None and isinstance(owner.op, DimShuffle)
            if isinstance(node.op, DimShuffle):
                assert node.inputs[0].name == "X"
        # Guard against a vacuous pass when no inverse node is present at all.
        assert found_inverse, "no MatrixInverse node found in the compiled graph"
Exemple #4
0
def test_rop_lop():
    """Check ``Rop``/``Lop`` of ``matrix_inverse(mx).sum(axis=0)`` against references.

    * The R-operator is compared with a product assembled row by row via
      ``aesara.scan`` from ``grad(y[i], mx) * mv``.
    * ``Rop`` of the same graph with ``mx`` routed through ``break_op`` must
      raise ``ValueError``.
    * The L-operator is compared with ``grad((v * y).sum(), mx)``.
    """
    mx = matrix("mx")
    mv = matrix("mv")
    v = vector("v")
    y = matrix_inverse(mx).sum(axis=0)

    yv = aesara.gradient.Rop(y, mx, mv)
    rop_f = function([mx, mv], yv)

    # Reference R-op: for each output element y[i], contract d y[i] / d mx with mv.
    sy, _ = aesara.scan(
        lambda i, y, x, v: (aesara.gradient.grad(y[i], x) * v).sum(),
        sequences=aet.arange(y.shape[0]),
        non_sequences=[y, mx, mv],
    )
    scan_f = function([mx, mv], sy)

    rng = np.random.default_rng(utt.fetch_seed())
    vx = np.asarray(rng.standard_normal((4, 4)), aesara.config.floatX)
    vv = np.asarray(rng.standard_normal((4, 4)), aesara.config.floatX)

    v1 = rop_f(vx, vv)
    v2 = scan_f(vx, vv)

    assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}"

    # Routing mx through ``break_op`` must make ``Rop`` raise ``ValueError``.
    try:
        aesara.gradient.Rop(
            aesara.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
        )
    except ValueError:
        pass
    else:
        raise Exception(
            "Op did not raise an error even though the function"
            " is not differentiable"
        )

    vv = np.asarray(rng.uniform(size=(4,)), aesara.config.floatX)
    yv = aesara.gradient.Lop(y, mx, v)
    lop_f = function([mx, v], yv)

    # Reference L-op: gradient of the v-weighted sum of y with respect to mx.
    sy = aesara.gradient.grad((v * y).sum(), mx)
    grad_f = function([mx, v], sy)

    v1 = lop_f(vx, vv)
    v2 = grad_f(vx, vv)
    assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
Exemple #5
0
 def grad(self, inputs, g_outputs):
     """Return the gradient with respect to this op's single input.

     The incoming output gradient is multiplied elementwise by the transpose
     of the input's matrix inverse; the result is returned in a one-item list.
     """
     (gz,) = g_outputs
     (x,) = inputs
     inverse_transposed = matrix_inverse(x).T
     return [gz * inverse_transposed]
Exemple #6
0
def test_matrix_inverse_solve():
    """Check that ``inv_as_solve`` turns ``matrix_inverse(A).dot(b)`` into a ``Solve``."""
    A = dmatrix("A")
    b = dmatrix("b")
    product = matrix_inverse(A).dot(b)
    # The rewrite is applied directly to the node producing the dot product.
    (out,) = inv_as_solve.transform(None, product.owner)
    assert isinstance(out.owner.op, Solve)
Exemple #7
0
def test_jax_basic():
    """Build small graphs of basic ops and check each with ``compare_jax_and_py``.

    Covers scalar ops, ``[Inc]Subtensor``, ``clip``, ``diagonal``, Cholesky
    (lower and upper), ``solve``, ``diag``, ``det`` and ``matrix_inverse``.
    """
    rng = np.random.default_rng(28494)

    def near_identity():
        # Identity plus small Gaussian noise, drawn from the shared ``rng`` so
        # every call consumes the generator in the same order as before.
        return (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX
        )

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct:
    # row 0 was set to -10, then element (0, 1) was incremented by 2.
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity()])

    # NOTE(review): an earlier comment said this case "isn't working yet with
    # lower=False", but the comparison below runs unconditionally — confirm
    # whether the upper-triangular Cholesky still needs attention.
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity()])

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity()])