Beispiel #1
0
 def test_infer_shape(self, b_shape):
     """Check shape inference of ``solve_triangular`` via ``_compile_and_check``.

     ``b_shape`` is the shape of the right-hand side. The square coefficient
     matrix is sized from its leading dimension so that the two inputs always
     agree, whatever shapes the test is parametrized with.
     """
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()
     b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
     # Fresh symbolic variable with the same type as the constant built from
     # ``b_val``, so it accepts ``b_val`` as its test value.
     b = aet.as_tensor_variable(b_val).type()
     # A must be (n, n) to be compatible with a right-hand side of n rows.
     n = b_shape[0]
     self._compile_and_check(
         [A, b],
         [solve_triangular(A, b)],
         [
             np.asarray(rng.random((n, n)), dtype=config.floatX),
             b_val,
         ],
         SolveTriangular,
         warn=False,
     )
Beispiel #2
0
def test_jax_SolveTriangular(trans, lower, check_finite):
    """Compare the JAX and Python backends of ``solve_triangular``.

    The coefficient matrix is a non-symmetric triangular matrix, so the
    ``trans`` and ``lower`` settings change the expected solution. An identity
    matrix would give the same result for every combination and could not
    reveal a backend that mishandles either flag.
    """
    x = matrix("x")
    b = vector("b")

    out = at_slinalg.solve_triangular(
        x,
        b,
        trans=trans,
        lower=lower,
        check_finite=check_finite,
    )
    out_fg = FunctionGraph([x, b], [out])

    n = 10
    # Distinct positive off-diagonal entries in (0, 0.1] plus a diagonal of at
    # least n: strictly diagonally dominant (rows and columns), hence well
    # conditioned in both float32 and float64, and no near-zero cancellations.
    x_val = np.arange(1, n * n + 1).reshape(n, n) / n**3 + n * np.eye(n)
    # Zero the triangle the Op is not asked to read, so the expected result
    # cannot depend on how either backend treats the ignored entries.
    x_val = np.tril(x_val) if lower else np.triu(x_val)

    compare_jax_and_py(
        out_fg,
        [
            x_val.astype(config.floatX),
            np.arange(n).astype(config.floatX),
        ],
    )
Beispiel #3
0
    def test_correctness(self, lower):
        """Compare ``solve_triangular`` against SciPy on a Cholesky factor.

        The triangular coefficient matrix is the (``lower`` or upper) Cholesky
        factor of a random symmetric matrix, computed with SciPy. The compiled
        Aesara graph therefore contains only the triangular solve.
        """
        rng = np.random.default_rng(utt.fetch_seed())

        b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)

        A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
        # A^T A is symmetric positive semi-definite (definite whenever the
        # random A is nonsingular), as scipy.linalg.cholesky requires.
        A_val = np.dot(A_val.transpose(), A_val)

        # Lower factor when ``lower`` is True, upper factor otherwise.
        C_val = scipy.linalg.cholesky(A_val, lower=lower)

        # The factor is precomputed above, so ``C`` is a plain symbolic input;
        # no Cholesky Op needs to appear in the compiled function.
        C = matrix()
        b = matrix()

        y = solve_triangular(C, b, lower=lower)
        solve_func = aesara.function([C, b], y)

        assert np.allclose(
            scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
            solve_func(C_val, b_val),
        )