def debug_triangular_solve():
    """Demonstrate how JAX's low-level ``triangular_solve`` maps onto SciPy's.

    Builds a 2x2 symmetric positive-definite matrix ``A`` and its lower
    Cholesky factor ``L``. It then compares triangular solves from:

    * ``scipy.linalg.solve_triangular``,
    * the low-level JAX ``triangular_solve`` primitive, and
    * ``jax.scipy.linalg.solve_triangular``.

    Progress is printed. Any mismatch raises ``AssertionError``.

    Comparisons use ``jnp.allclose`` rather than ``==``. The results come from
    different floating-point algorithms (LU solve, LAPACK triangular solve,
    XLA triangular solve), so exact equality can fail from rounding alone.
    """
    try:
        from jax.lax.linalg import triangular_solve
    except ImportError:  # older JAX releases only expose ``jax.lax_linalg``
        from jax.lax_linalg import triangular_solve
    from scipy.linalg import solve_triangular
    from jax.scipy.linalg import solve_triangular as solve_triangular_jax
    import jax.numpy as jnp

    ndims = 2
    # Ones on the diagonal and 0.95 everywhere else.
    A = jnp.diag(jnp.ones(ndims))
    A = jnp.where(A == 0., 0.95, A)
    b = jnp.ones(ndims)
    L = jnp.linalg.cholesky(A)
    assert jnp.allclose(L @ L.T, A)

    x = jnp.linalg.solve(L, b)
    print("Solving L.x = b with scipy")
    print("x should be {}".format(x))
    scipy_x = solve_triangular(L, b, lower=True)
    assert jnp.allclose(scipy_x, x)
    print("Works as expected!")

    # The JAX primitive defaults to left_side=False, i.e. it solves
    # x @ op(L) = b. For a vector b that is op(L)^T x = b, which is why a
    # transpose flag appears to be "swapped" relative to SciPy.
    print("Now note JAX's solution to L^T.x = b corresponds to scipy's L.x = b")
    jax_x = triangular_solve(L, b, lower=True, transpose_a=True)
    assert jnp.allclose(jax_x, scipy_x)

    print("Likewise, JAX's solution to L.x = b corresponds to scipy's L^T.x = b")
    assert jnp.allclose(triangular_solve(L, b, lower=True),
                        solve_triangular(L, b, lower=True, trans=1))
    print("Note, I have not tested for the L^H.x=b case.")

    # The high-level jax.scipy wrapper follows SciPy's convention directly.
    jax_x = solve_triangular_jax(L, b, lower=True)
    assert jnp.allclose(scipy_x, jax_x)
def _cho_solve(c, b, lower):
    """Solve ``A @ x = b`` given a Cholesky factor ``c`` of ``A``.

    Args:
      c: Triangular Cholesky factor (or batch of factors). When ``lower`` is
        true it is used as L with A = L @ L^H; otherwise as U with A = U^H @ U.
      b: Right-hand side. It may be a matrix ``[..., m, k]`` or a vector
        ``[..., m]``, i.e. one fewer dimension than ``c``.
      lower: Whether ``c`` is lower (True) or upper (False) triangular.

    Returns:
      The solution ``x``, with the same shape as ``b``.
    """
    c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
    np_linalg._check_solve_shapes(c, b)
    # lax_linalg.triangular_solve with left_side=True expects a matrix 'b'.
    # Give vector right-hand sides a trailing unit axis for the solves and
    # strip it again at the end.
    b_is_vector = jnp.ndim(c) == jnp.ndim(b) + 1
    if b_is_vector:
        b = b[..., None]
    # First solve: lower -> L y = b; upper -> U^H y = b.
    b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
                                    transpose_a=not lower,
                                    conjugate_a=not lower)
    # Second solve: lower -> L^H x = y; upper -> U x = y.
    b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
                                    transpose_a=lower, conjugate_a=lower)
    return b[..., 0] if b_is_vector else b
def _solve_triangular(a, b, trans, lower, unit_diagonal): if trans == 0 or trans == "N": transpose_a, conjugate_a = False, False elif trans == 1 or trans == "T": transpose_a, conjugate_a = True, False elif trans == 2 or trans == "C": transpose_a, conjugate_a = True, True else: raise ValueError("Invalid 'trans' value {}".format(trans)) a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b)) # lax_linalg.triangular_solve only supports matrix 'b's at the moment. b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1 if b_is_vector: b = b[..., None] out = lax_linalg.triangular_solve(a, b, left_side=True, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) if b_is_vector: return out[..., 0] else: return out
def testLaxLinalgTriangularSolve(self):
    """Check vmap of lax_linalg.triangular_solve against an explicit loop.

    Three batching patterns are compared with a per-slice Python loop:

    * both operands mapped,
    * only ``b`` mapped (``a`` unbatched via ``None``),
    * only ``a`` mapped.
    """
    batch = 10
    a = onp.random.RandomState(0).randn(4, batch, 4).astype(onp.float32)
    # Add the 4x4 identity to every slice a[:, i, :] (broadcast over axis 1).
    # Use the same `onp` namespace as the rest of the setup code.
    a += onp.eye(4, dtype=onp.float32)[:, None, :]
    b = onp.random.RandomState(0).randn(5, 4, batch).astype(onp.float32)

    # Both operands batched: axis 1 of `a` paired with axis 2 of `b`.
    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(batch)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    # Only `b` batched; one `a` slice is shared across the batch.
    ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(batch)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    # Only `a` batched; one `b` slice is shared across the batch.
    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(batch)])
    self.assertAllClose(ans, expected, check_dtypes=True)
def logpdf(x, mean, cov):
    """Log-density of a (multivariate) normal distribution.

    Three cases are handled:

    * scalar ``mean``: univariate normal with variance ``cov``;
    * vector ``mean`` with scalar ``cov``: isotropic covariance ``cov * I``;
    * vector ``mean`` with a matrix ``cov``: full covariance.

    In the full-covariance case, ``cov`` may carry leading batch dimensions
    ``[..., n, n]``. These broadcast against the leading dimensions of
    ``x - mean``.

    Raises:
      ValueError: if a full covariance does not end in shape ``(n, n)``.
    """
    from functools import partial

    x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
    if not mean.shape:
        return (-1 / 2 * jnp.square(x - mean) / cov
                - 1 / 2 * (np.log(2 * np.pi) + jnp.log(cov)))

    n = mean.shape[-1]
    if not np.shape(cov):
        y = x - mean
        return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov
                - n / 2 * (np.log(2 * np.pi) + jnp.log(cov)))

    if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError("multivariate_normal.logpdf got incompatible shapes")
    L = cholesky(cov)
    # y = L^{-1} (x - mean), so y . y is the Mahalanobis term. Vectorizing over
    # the core signature broadcasts batched factors against batched x - mean.
    # A direct call cannot, because triangular_solve requires equal batch dims.
    solve = jnp.vectorize(
        partial(triangular_solve, lower=True, transpose_a=True),
        signature="(n,n),(n)->(n)")
    y = solve(L, x - mean)
    # log det(cov) = 2 * sum(log diag L). Take the diagonal over the trailing
    # two axes so each covariance in a batch gets its own determinant.
    log_det_half = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)).sum(-1)
    return (-1 / 2 * jnp.einsum('...i,...i->...', y, y)
            - n / 2 * np.log(2 * np.pi)
            - log_det_half)
def multivariate_gaussian_logp(x, mu, cov):
    """Compute the log probability of a multivariate Gaussian.

    Args:
      x: jnp.array(N), position in latent space.
      mu: jnp.array(N), mean of the Gaussian.
      cov: jnp.array(N, N), covariance of the Gaussian.

    Returns:
      logp: float, the log probability of N(x | mu, cov).
    """
    import jax.numpy as jnp
    try:
        from jax.lax.linalg import cholesky, triangular_solve
    except ImportError:  # older JAX releases only expose ``jax.lax_linalg``
        from jax.lax_linalg import cholesky, triangular_solve

    n = mu.shape[0]
    L = cholesky(cov)
    # With triangular_solve's default left_side=False this solves
    # y @ L^T = x - mu, i.e. L y = x - mu. So y . y equals
    # (x - mu)^T cov^{-1} (x - mu).
    y = triangular_solve(L, x - mu, lower=True, transpose_a=True)
    # -1/2 log det(cov) == -sum(log diag L) for the Cholesky factor L.
    return (-1. / 2. * jnp.einsum('...i,...i->...', y, y)
            - n / 2. * jnp.log(2 * jnp.pi)
            - jnp.log(L.diagonal()).sum())
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

    Intermediate function used for jvp and vjp of det.
    This function borrows heavily from jax.numpy.linalg.solve and
    jax.numpy.linalg.slogdet to compute the gradient of the determinant
    in a way that is well defined even for low rank matrices.

    This function handles two different cases:

    * rank(a) == n or n-1
    * rank(a) < n-1

    For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
    Rather than computing det(a)*solve(a, b), which would return NaN, we work
    directly with the LU decomposition. If a = p @ l @ u, then::

        det(a)*solve(a, b)
          = prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b
          = prod(diag(u)) * triangular_solve(u, solve(p @ l, b))

    If a is rank n-1, then the lower right corner of u will be zero and the
    triangular_solve will fail.

    Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then::

        y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
              = x_{n} * prod_{i=1...n-1}(u_{ii})

    So by replacing the lower-right corner of u with
    prod_{i=1...n-1}(u_{ii})^-1 we can avoid the triangular_solve failing.
    To correctly compute the rest of y_{i} for i != n, we simply multiply
    x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
    (The permutation sign of p is folded into both products.)

    For the second case, a check is done on the matrix to see if `solve`
    returns NaN or Inf, and gives a matrix of zeros as a result, as the
    gradient of the determinant of a matrix with rank less than n-1 is 0.
    This will still return the correct value for rank n-1 matrices, as the
    check is applied *after* the lower right corner of u has been updated.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices of the same dimension as a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        # For 1x1 matrices det(a) = a_11 and adjugate(a) = [[1]]. Index the
        # trailing two axes so a batch of (1, 1) matrices yields one
        # determinant per matrix; a[0, 0] would slice the batch axes instead.
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[..., -1] contains the full determinant and
    # partial_det[..., -2] contains det(u) / u_{nn} (both times the sign).
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = ops.index_update(lu, ops.index[..., -1, -1],
                          1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, size)
                      for size in batch_dims + (1,)))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1,) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # Apply p^-1 to the rows of x, then solve with the unit-diagonal l.
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by det(a); the last row is scaled
    # implicitly through the replaced lower-right corner of u.
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x