def test_infer_shape(self, b_shape):
    """Check that the shape inferred for ``solve_triangular`` matches its output.

    ``b_shape`` is the shape of the right-hand side; its values come from the
    caller (presumably a pytest parametrization not visible here).
    """
    rng = np.random.default_rng(utt.fetch_seed())

    # Draw b's values first, then A's, so the RNG stream is consumed in a
    # fixed order.
    b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
    a_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)

    A = matrix()
    # A fresh symbolic variable whose type is derived from the concrete b_val.
    b = aet.as_tensor_variable(b_val).type()

    inputs = [A, b]
    outputs = [solve_triangular(A, b)]
    self._compile_and_check(
        inputs,
        outputs,
        [a_val, b_val],
        SolveTriangular,
        warn=False,
    )
def test_jax_SolveTriangular(trans, lower, check_finite):
    """Compare the JAX and Python implementations of ``solve_triangular``.

    ``trans``, ``lower`` and ``check_finite`` are forwarded unchanged to
    ``at_slinalg.solve_triangular``; their values come from the caller
    (presumably a pytest parametrization outside this view).
    """
    x = matrix("x")
    b = vector("b")

    out = at_slinalg.solve_triangular(
        x,
        b,
        trans=trans,
        lower=lower,
        check_finite=check_finite,
    )
    out_fg = FunctionGraph([x, b], [out])

    # Use a non-symmetric matrix whose upper and lower triangles differ, so the
    # solution actually depends on ``lower`` and ``trans``. An identity matrix
    # gives the same answer for every flag combination and would hide a
    # backend that ignores them. The dominant diagonal (about 10, versus
    # off-diagonal entries below 1) keeps every triangular solve well
    # conditioned in float32.
    x_val = (10 * np.eye(10) + np.arange(100).reshape(10, 10) / 100).astype(
        config.floatX
    )
    # Start at 1 so the right-hand side has no zero entries.
    b_val = np.arange(1, 11).astype(config.floatX)

    compare_jax_and_py(out_fg, [x_val, b_val])
def test_correctness(self, lower):
    """Check ``solve_triangular`` against SciPy using a Cholesky factor.

    ``lower`` selects whether the factor, and therefore the triangular solve,
    uses the lower or the upper triangle.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)

    A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
    # A^T A is symmetric positive semi-definite (positive definite for a
    # generic random draw), so it has a Cholesky factor.
    A_val = np.dot(A_val.transpose(), A_val)
    C_val = scipy.linalg.cholesky(A_val, lower=lower)

    # The triangular factor is supplied directly, so declare it as a plain
    # input. Building it as ``Cholesky(lower=lower)(A)`` only created an
    # unused graph: the Cholesky node is cut off when its output is itself
    # passed as a function input.
    C = matrix()
    b = matrix()

    y = solve_triangular(C, b, lower=lower)
    solve_func = aesara.function([C, b], y)

    assert np.allclose(
        scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
        solve_func(C_val, b_val),
    )