def L_op(self, inputs, outputs, gradients):
    """
    Cholesky decomposition reverse-mode gradient update.

    Symbolic expression for reverse-mode Cholesky gradient taken from [#]_

    References
    ----------
    .. [#] I. Murray, "Differentiation of the Cholesky decomposition",
       http://arxiv.org/abs/1602.07527
    """
    # Incoming gradient w.r.t. the output factor, and the factor itself.
    dz = gradients[0]
    chol_x = outputs[0]

    # Replace the cholesky decomposition with 1 if there are nans
    # or solve_upper_triangular will throw a ValueError.
    if self.on_error == "nan":
        ok = ~atm.any(atm.isnan(chol_x))
        chol_x = at.switch(ok, chol_x, 1)
        dz = at.switch(ok, dz, 1)

    # deal with upper triangular by converting to lower triangular
    if not self.lower:
        chol_x = chol_x.T
        dz = dz.T

    def tril_and_halve_diagonal(mtx):
        """Extracts lower triangle of square matrix and halves diagonal."""
        return at.tril(mtx) - at.diag(at.diagonal(mtx) / 2.0)

    def conjugate_solve_triangular(outer, inner):
        """Computes L^{-T} P L^{-1} for lower-triangular L."""
        # `outer` is lower triangular, so `outer.T` is upper triangular,
        # hence the use of solve_upper_triangular for both solves.
        return solve_upper_triangular(
            outer.T, solve_upper_triangular(outer.T, inner.T).T
        )

    s = conjugate_solve_triangular(
        chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz))
    )

    # Symmetrize and remove the double-counted diagonal, keeping only the
    # triangle that matches the op's orientation.
    if self.lower:
        grad = at.tril(s + s.T) - at.diag(at.diagonal(s))
    else:
        grad = at.triu(s + s.T) - at.diag(at.diagonal(s))

    # Propagate nan through the gradient when the decomposition failed.
    if self.on_error == "nan":
        return [at.switch(ok, grad, np.nan)]
    else:
        return [grad]
def L_op(self, inputs, outputs, gradients):
    """Reverse-mode gradient of the GPU Cholesky decomposition.

    Modified from aesara/tensor/slinalg.py. Unlike the CPU version,
    this variant has no handling for ``on_error = 'nan'`` (the nan-mode
    switch on the factor and incoming gradient is omitted here).
    """
    grad_output = gradients[0]
    chol_factor = outputs[0]

    # deal with upper triangular by converting to lower triangular
    if not self.lower:
        chol_factor = chol_factor.T
        grad_output = grad_output.T

    def _tril_halved_diag(matrix):
        """Extracts lower triangle of square matrix and halves diagonal."""
        halved = aet.diag(aet.diagonal(matrix) / 2.0)
        return aet.tril(matrix) - halved

    def _conjugate_solve(outer, inner):
        """Computes L^{-T} P L^{-1} for lower-triangular L."""
        # `outer` is lower triangular, so `outer.T` is upper triangular.
        partial = gpu_solve_upper_triangular(outer.T, inner.T).T
        return gpu_solve_upper_triangular(outer.T, partial)

    sym = _conjugate_solve(
        chol_factor, _tril_halved_diag(chol_factor.T.dot(grad_output))
    )

    # Symmetrize and subtract the double-counted diagonal, keeping the
    # triangle that matches the op's orientation.
    if self.lower:
        result = aet.tril(sym + sym.T) - aet.diag(aet.diagonal(sym))
    else:
        result = aet.triu(sym + sym.T) - aet.diag(aet.diagonal(sym))

    return [result]
def grad(self, inp, cost_grad):
    """Symbolic gradient w.r.t. the two inputs ``(a, val)``.

    Notes
    -----
    The gradient is currently implemented for matrices only.
    """
    matrix_in, fill_val = inp
    output_grad = cost_grad[0]

    # Complex-valued gradients are not supported.
    if matrix_in.dtype.startswith("complex"):
        return [None, None]
    if matrix_in.ndim > 2:
        raise NotImplementedError(
            "%s: gradient is currently implemented"
            " for matrices only" % self.__class__.__name__
        )

    # Zeroing the diagonal is valid for any number of dimensions.
    grad_wrt_matrix = fill_diagonal(output_grad, 0)
    # diag is only valid for matrices.
    grad_wrt_value = aet.diag(output_grad).sum()
    return [grad_wrt_matrix, grad_wrt_value]
def tril_and_halve_diagonal(mtx):
    """Extracts lower triangle of square matrix and halves diagonal."""
    halved_diagonal = aet.diag(aet.diagonal(mtx) / 2.0)
    return aet.tril(mtx) - halved_diagonal
def test_jax_basic():
    """Smoke-test JAX transpilation of a grab-bag of basic tensor ops."""
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    expr = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    expr = aet_subtensor.set_subtensor(expr[0], -10.0)
    expr = aet_subtensor.inc_subtensor(expr[0, 1], 2.0)
    expr = expr[:5, :3]
    fgraph = FunctionGraph([x, y], [expr])

    matrix_inputs = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(fgraph, matrix_inputs)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)
    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    fgraph = FunctionGraph([x, y], [clip(x, y, 5)])
    compare_jax_and_py(fgraph, matrix_inputs)

    fgraph = FunctionGraph([x], [aet.diagonal(x, 0)])
    compare_jax_and_py(
        fgraph, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
    )

    fgraph = FunctionGraph([x], [aet_slinalg.cholesky(x)])
    compare_jax_and_py(
        fgraph,
        [
            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
                config.floatX
            )
        ],
    )

    # not sure why this isn't working yet with lower=False
    fgraph = FunctionGraph([x], [aet_slinalg.Cholesky(lower=False)(x)])
    compare_jax_and_py(
        fgraph,
        [
            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
                config.floatX
            )
        ],
    )

    fgraph = FunctionGraph([x, b], [aet_slinalg.solve(x, b)])
    compare_jax_and_py(
        fgraph,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    fgraph = FunctionGraph([b], [aet.diag(b)])
    compare_jax_and_py(fgraph, [np.arange(10).astype(config.floatX)])

    fgraph = FunctionGraph([x], [aet_nlinalg.det(x)])
    compare_jax_and_py(
        fgraph, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
    )

    fgraph = FunctionGraph([x], [aet_nlinalg.matrix_inverse(x)])
    compare_jax_and_py(
        fgraph,
        [
            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
                config.floatX
            )
        ],
    )