def test_cholesky_indef():
    """An indefinite input must raise or produce NaNs, depending on ``on_error``."""
    x = matrix()
    # Symmetric but indefinite: det = -2 - 0.04 < 0, so one eigenvalue is negative.
    indefinite = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

    raising_f = function([x], Cholesky(lower=True, on_error="raise")(x))
    with pytest.raises(scipy.linalg.LinAlgError):
        raising_f(indefinite)

    nan_f = function([x], Cholesky(lower=True, on_error="nan")(x))
    assert np.all(np.isnan(nan_f(indefinite)))
def test_cholesky_grad_indef():
    """The gradient of ``sum(cholesky(x))`` must honour ``on_error`` on bad input."""
    scipy = pytest.importorskip("scipy")
    x = matrix()
    # Symmetric but indefinite input (negative determinant).
    indefinite = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

    def compile_grad(policy):
        # Compile d sum(chol(x)) / dx using the given `on_error` policy.
        op = Cholesky(lower=True, on_error=policy)
        return function([x], grad(op(x).sum(), [x]))

    raising_f = compile_grad("raise")
    with pytest.raises(scipy.linalg.LinAlgError):
        raising_f(indefinite)

    nan_f = compile_grad("nan")
    assert np.all(np.isnan(nan_f(indefinite)))
def test_cholesky_grad_indef():
    """The gradient graph of ``GpuCholesky`` raises on an indefinite input."""
    x = aesara.tensor.matrix()
    # Symmetric but indefinite input (negative determinant).
    indefinite = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
    op = GpuCholesky(lower=True)
    grad_of_sum = aesara.tensor.grad(op(x).sum(), [x])
    chol_grad_f = aesara.function([x], grad_of_sum)
    with pytest.raises(LinAlgError):
        chol_grad_f(indefinite)
def test_cholesky_grad():
    """Numerically verify the Cholesky gradient for default, lower and upper ops."""
    rng = np.random.default_rng(utt.fetch_seed())
    r = rng.standard_normal((5, 5)).astype(config.floatX)

    loose = {"abs_tol": 0.05, "rel_tol": 0.05}
    cases = [
        (cholesky, {}),  # the default op, with verify_grad's default tolerances
        (Cholesky(lower=True), loose),  # explicit lower-triangular
        (Cholesky(lower=False), loose),  # explicit upper-triangular
    ]
    for chol_op, tolerances in cases:
        # The product m.dot(m.T) is built inside the graph, so the op always
        # receives a symmetric matrix derived from the perturbed `m`.
        utt.verify_grad(lambda m: chol_op(m.dot(m.T)), [r], 3, rng, **tolerances)
def test_gpu_cholesky_opt(self):
    """Compiling a float64 ``cholesky`` in GPU mode yields a ``GpuCholesky`` node."""
    A = matrix("A", dtype="float64")
    fn = aesara.function([A], cholesky(A), mode=mode_with_gpu)
    compiled_ops = (node.op for node in fn.maker.fgraph.toposort())
    assert any(isinstance(op, GpuCholesky) for op in compiled_ops)
def test_gpu_cholesky_opt(self):
    """With ``cusolver`` rewrites excluded, a float32 ``cholesky`` uses Magma.

    Compiles in GPU mode and checks for a ``GpuMagmaCholesky`` node.
    """
    A = aesara.tensor.matrix("A", dtype="float32")
    chol_out = cholesky(A)
    magma_mode = mode_with_gpu.excluding("cusolver")
    fn = aesara.function([A], chol_out, mode=magma_mode)
    assert any(
        isinstance(node.op, GpuMagmaCholesky)
        for node in fn.maker.fgraph.toposort()
    )
def psd_solve_with_chol(fgraph, node):
    """Replace ``solve(A, b)`` by two triangular solves when ``is_psd(A)``.

    With ``L = cholesky(A)``, the system ``A x = b`` is solved as
    ``L y = b`` (lower-triangular) followed by ``L.T x = y`` (upper-triangular).

    Returns ``[x]`` as the replacement output, or ``None`` when the node is not
    a ``solve`` or ``A`` is not known to be PSD.
    """
    if not (node.op == solve):
        return None
    A, b = node.inputs  # the node computes x such that A x = b
    if not is_psd(A):
        return None

    L = cholesky(A)
    # N.B. this could be reduced further to a yet-unwritten cho_solve Op
    # __if__ no other Op makes use of the L matrix during the stabilization.
    y = Solve("lower_triangular")(L, b)
    x = Solve("upper_triangular")(L.T, y)
    return [x]
def psd_solve_with_chol(fgraph, node):
    """Rewrite ``Solve(A, b)`` through a Cholesky factorization of ``A``.

    This utilizes a boolean ``psd`` tag on matrices: the node is rewritten only
    when ``A.tag.psd is True``.

    With ``L = cholesky(A)`` (lower factor, ``A = L @ L.T``), ``A x = b`` is
    solved as two triangular systems: ``L y = b`` then ``L.T x = y``.

    Returns
    -------
    list or None
        ``[x]`` replacing the node's output, or ``None`` when the rewrite does
        not apply.
    """
    # Local import: only this rewrite needs the triangular solver.
    from aesara.tensor.slinalg import solve_triangular

    if isinstance(node.op, Solve):
        A, b = node.inputs  # result is solution Ax=b
        if getattr(A.tag, "psd", None) is True:
            L = cholesky(A)
            # N.B. this can be further reduced to a yet-unwritten cho_solve Op
            # __if__ no other Op makes use of the L matrix during the
            # stabilization.
            # `L` is triangular, not symmetric, so a triangular solve is
            # required. `Solve(assume_a="sym", ...)` would read only one
            # triangle and mirror it, i.e. solve against L + L.T - diag(L).
            Li_b = solve_triangular(L, b, lower=True)
            x = solve_triangular(L.T, Li_b, lower=False)
            return [x]
def test_cholesky_grad():
    """Numerically verify the Cholesky gradient for default, lower and upper ops."""
    pytest.importorskip("scipy")
    rng = np.random.RandomState(utt.fetch_seed())
    r = rng.randn(5, 5).astype(config.floatX)

    # Default op first, then explicit lower- and upper-triangular variants.
    for chol_op in (cholesky, Cholesky(lower=True), Cholesky(lower=False)):
        # m.dot(m.T) is formed inside the graph so the op always sees a
        # symmetric matrix derived from the perturbed `m`.
        utt.verify_grad(lambda m: chol_op(m.dot(m.T)), [r], 3, rng)
def test_cholesky_and_cholesky_grad_shape():
    """Check the shapes of ``cholesky(x)`` and its gradient.

    Outside FAST_COMPILE, the compiled shape graphs must contain no
    ``Cholesky`` / ``CholeskyGrad`` node.
    """
    rng = np.random.default_rng(utt.fetch_seed())
    x = matrix()
    variants = (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x))
    for chol_out in variants:
        f_chol = aesara.function([x], chol_out.shape)
        grad_x = aesara.gradient.grad(chol_out.sum(), x)
        f_cholgrad = aesara.function([x], grad_x.shape)

        chol_nodes = f_chol.maker.fgraph.toposort()
        grad_nodes = f_cholgrad.maker.fgraph.toposort()
        if config.mode != "FAST_COMPILE":
            assert not any(n.op.__class__ == Cholesky for n in chol_nodes)
            assert not any(n.op.__class__ == CholeskyGrad for n in grad_nodes)

        for shp in (2, 3, 5):
            # Sample covariance of shp variables over shp + 10 observations.
            m = np.cov(rng.standard_normal((shp, shp + 10))).astype(config.floatX)
            np.testing.assert_equal(f_chol(m), (shp, shp))
            np.testing.assert_equal(f_cholgrad(m), (shp, shp))
def test_cholesky():
    """Factor ``r @ r.T`` with each Cholesky variant and check the triangle."""
    rng = np.random.default_rng(utt.fetch_seed())
    r = rng.standard_normal((5, 5)).astype(config.floatX)
    pd = np.dot(r, r.T)
    x = matrix()

    cases = [
        (cholesky, check_lower_triangular),  # the default op
        (Cholesky(lower=True), check_lower_triangular),
        (Cholesky(lower=False), check_upper_triangular),
        (Cholesky(lower=False, on_error="nan"), check_upper_triangular),
    ]
    for chol_op, check_triangle in cases:
        check_triangle(pd, function([x], chol_op(x)))
def test_cholesky_and_cholesky_grad_shape():
    """Check the shapes of ``cholesky(x)`` and its gradient.

    Outside FAST_COMPILE, the compiled shape graphs must contain no
    ``Cholesky`` / ``CholeskyGrad`` node.
    """
    pytest.importorskip("scipy")
    rng = np.random.RandomState(utt.fetch_seed())
    x = tensor.matrix()
    variants = (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x))
    for chol_out in variants:
        f_chol = aesara.function([x], chol_out.shape)
        grad_x = tensor.grad(chol_out.sum(), x)
        f_cholgrad = aesara.function([x], grad_x.shape)

        chol_nodes = f_chol.maker.fgraph.toposort()
        grad_nodes = f_cholgrad.maker.fgraph.toposort()
        if config.mode != "FAST_COMPILE":
            assert not any(n.op.__class__ == Cholesky for n in chol_nodes)
            assert not any(n.op.__class__ == CholeskyGrad for n in grad_nodes)

        for shp in (2, 3, 5):
            # Sample covariance of shp variables over shp + 10 observations.
            m = np.cov(rng.randn(shp, shp + 10)).astype(config.floatX)
            np.testing.assert_equal(f_chol(m), (shp, shp))
            np.testing.assert_equal(f_cholgrad(m), (shp, shp))
def test_cholesky():
    """Factor ``r @ r.T`` with each Cholesky variant and check the triangle."""
    pytest.importorskip("scipy")
    rng = np.random.RandomState(utt.fetch_seed())
    r = rng.randn(5, 5).astype(config.floatX)
    pd = np.dot(r, r.T)
    x = matrix()

    cases = [
        (cholesky, check_lower_triangular),  # the default op
        (Cholesky(lower=True), check_lower_triangular),
        (Cholesky(lower=False), check_upper_triangular),
        (Cholesky(lower=False, on_error="nan"), check_upper_triangular),
    ]
    for chol_op, check_triangle in cases:
        check_triangle(pd, function([x], chol_op(x)))
def test_correctness(self, lower):
    """``solve_triangular`` on a Cholesky factor agrees with SciPy for ``lower``."""
    rng = np.random.default_rng(utt.fetch_seed())
    b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
    A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
    A_val = np.dot(A_val.transpose(), A_val)  # A^T A is symmetric PSD
    C_val = scipy.linalg.cholesky(A_val, lower=lower)

    A = matrix()
    b = matrix()
    C = Cholesky(lower=lower)(A)
    y = solve_triangular(C, b, lower=lower)
    # The intermediate C is an explicit input (A is not), so the compiled
    # function takes the factor C_val directly.
    solve_f = aesara.function([C, b], y)

    expected = scipy.linalg.solve_triangular(C_val, b_val, lower=lower)
    assert np.allclose(expected, solve_f(C_val, b_val))
def test_jax_basic():
    """Compare the JAX backend with the Python backend on a mix of basic ops."""
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    def near_identity():
        # The 10x10 identity plus a small random perturbation (well-conditioned).
        return (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX
        )

    def near_identity_spd():
        # Cholesky is only defined for symmetric positive-definite input.
        # Symmetrize so the expected result does not depend on which triangle
        # each backend reads; the eigenvalues stay close to 1.
        m = near_identity()
        return ((m + m.T) / 2).astype(config.floatX)

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    # (row 0 is set to -10, then element [0, 1] is incremented by 2).
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
    )

    # `cholesky` with its default settings
    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity_spd()])

    # Upper-triangular Cholesky factor
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity_spd()])

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
    )

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [near_identity()])