def _lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU factorization primitive.

    Args:
      primals: 1-tuple holding the matrix ``a`` of shape ``[..., m, n]``.
      tangents: 1-tuple holding ``a_dot``, the tangent of ``a``.

    Returns:
      ``(primals_out, tangents_out)`` where ``primals_out`` is
      ``(lu, pivots, permutation)`` as returned by ``lu_p.bind(a)`` and
      ``tangents_out`` is ``(lu_dot, zero, zero)``; the pivot and permutation
      outputs get ``ad_util.Zero`` tangents.
    """
    (a,) = primals
    (a_dot,) = tangents
    lu, pivots, permutation = lu_p.bind(a)

    a_shape = jnp.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Reorder the rows of the tangent with the same permutation that the
    # factorization applied to ``a``. The iotas index the batch dimensions so
    # that each batch element uses its own permutation vector.
    batch_dims = a_shape[:-2]
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
    row_selector = iotas[:-1] + (permutation, slice(None))
    x = a_dot[row_selector]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    # lax.pad needs a scalar padding value with the operand's dtype.
    zero = jnp._constant_like(lu, 0)

    # Square (m, m) unit-lower-triangular L: the strictly lower part of the
    # first k columns of ``lu``, zero-padded on the right, plus the identity.
    l_padding = [(0, 0, 0)] * (ndims - 1) + [(0, m - k, 0)]
    l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + jnp.eye(m, m, dtype=dtype)

    # Square (n, n) upper-triangular U: the upper part of the first k rows of
    # ``lu``, zero-padded at the bottom, with an identity block placed in the
    # trailing (n - k, n - k) corner.
    u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * (ndims - 2) + [(0, n - k, 0), (0, 0, 0)]
    u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # lau = inv(L) . x . inv(U)
    la = triangular_solve(l, x, left_side=True, transpose_a=False,
                          lower=True, unit_diagonal=True)
    lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                           lower=False)

    l_dot = jnp.matmul(l, jnp.tril(lau, -1))
    u_dot = jnp.matmul(jnp.triu(lau), u)
    lu_dot = l_dot + u_dot

    primals_out = (lu, pivots, permutation)
    tangents_out = (lu_dot,
                    ad_util.Zero.from_value(pivots),
                    ad_util.Zero.from_value(permutation))
    return primals_out, tangents_out
def cholesky(x, symmetrize_input: bool = True):
    """Compute the Cholesky factor of a batch of matrices.

    Finds the lower-triangular :math:`L` satisfying

    .. math::
      A = L . L^H

    for each square matrix :math:`A` in the batch. Every matrix must be
    positive-definite and Hermitian (complex case) or symmetric (real case).

    Args:
      x: Batch of square, Hermitian (symmetric if real), positive-definite
        matrices of shape ``[..., n, n]``.
      symmetrize_input: When ``True``, the input is first replaced by
        :math:`\\frac{1}{2}(x + x^H)`. When ``False``, only the lower triangle
        of ``x`` is read; the upper triangle is ignored and not accessed.

    Returns:
      The Cholesky factor, with the dtype of ``x`` and shape ``[..., n, n]``.
      When the decomposition fails the result is a matrix full of NaNs; this
      failure behavior may change in the future.
    """
    operand = symmetrize(x) if symmetrize_input else x
    factor = cholesky_p.bind(operand)
    # Keep only the lower triangle of the primitive's output.
    return jnp.tril(factor)
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to the triangular operand ``a``.

    Args:
      g_a: Tangent of ``a``.
      ans: Primal result of the solve, trailing shape ``(m, n)``.
      a: Triangular coefficient matrix.
      b: Right-hand side; only its trailing ``(m, n)`` shape is read.
      left_side, lower, transpose_a, conjugate_a, unit_diagonal: The flags of
        the primal solve; they are forwarded unchanged to the inner solve.

    Returns:
      The tangent of the solution induced by ``g_a``.
    """
    m, n = b.shape[-2:]

    # Keep only the triangle selected by ``lower``; with ``unit_diagonal`` the
    # diagonal of the tangent is dropped as well.
    diag_offset = 1 if unit_diagonal else 0
    if lower:
        g_a = jnp.tril(g_a, k=-diag_offset)
    else:
        g_a = jnp.triu(g_a, k=diag_offset)
    g_a = lax.neg(g_a)
    if transpose_a:
        g_a = jnp.swapaxes(g_a, -1, -2)
    if conjugate_a:
        g_a = jnp.conj(g_a)

    matmul = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    dot = partial(matmul, precision=lax.Precision.HIGHEST)

    def a_inverse(rhs):
        return triangular_solve(a, rhs, left_side=left_side, lower=lower,
                                transpose_a=transpose_a,
                                conjugate_a=conjugate_a,
                                unit_diagonal=unit_diagonal)

    side = m if left_side else n
    assert (g_a.shape[-2:] == a.shape[-2:] == (side, side)
            and ans.shape[-2:] == (m, n))

    # triangular_solve is about the same cost as matrix multiplication
    # (~n^2 FLOPs for matrix/vector inputs), so apply the two operations in
    # whichever order is cheaper for the given shapes.
    if left_side:
        if m > n:
            return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
        return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
    if m < n:
        return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
def cholesky_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the Cholesky primitive.

    Implements ``L_dot = L . phi(L^{-1} . sigma_dot . L^{-H})`` following the
    forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf, where ``phi``
    keeps the lower triangle of its argument and halves the diagonal.

    Args:
      primals: 1-tuple holding the matrix ``x`` being factorized.
      tangents: 1-tuple holding ``sigma_dot``, the tangent of ``x``.

    Returns:
      ``(L, L_dot)``: the lower-triangular Cholesky factor and its tangent.
    """
    (x,) = primals
    (sigma_dot,) = tangents
    L = jnp.tril(cholesky_p.bind(x))

    def phi(X):
        # Lower triangle of X with the diagonal divided by two. The Python
        # literal ``1`` is weakly typed, so the denominator keeps X's dtype
        # without going through a private constant-building helper.
        lower = jnp.tril(X)
        return lower / (1 + jnp.eye(X.shape[-1], dtype=X.dtype))

    # tmp = sigma_dot . L^{-H}   (right-side solve against L^H)
    tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                           conjugate_a=True, lower=True)
    # inner = L^{-1} . tmp       (left-side solve against L)
    inner = triangular_solve(L, tmp, left_side=True, transpose_a=False,
                             lower=True)
    L_dot = lax.batch_matmul(L, phi(inner), precision=lax.Precision.HIGHEST)
    return L, L_dot
def _lu(a, permute_l):
    """LU-factorize a single matrix and unpack the factors.

    Args:
      a: Array-like 2-D matrix of shape ``(m, n)``.
      permute_l: If true, return ``(p @ l, u)``; otherwise ``(p, l, u)``.

    Returns:
      ``(p @ l, u)`` or ``(p, l, u)`` where ``p`` is an ``(m, m)`` matrix with
      ``p[i, j] == 1`` exactly when ``permutation[j] == i``, ``l`` is
      ``(m, k)`` with a unit diagonal, ``u`` is ``(k, n)`` upper triangular and
      ``k = min(m, n)``.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)

    # Build the permutation matrix by comparing the permutation vector against
    # row indices. The indices use the permutation's own integer dtype and the
    # broadcast axes are spelled out, so the comparison does not rely on
    # implicit dtype promotion.
    row_ids = jnp.arange(m, dtype=permutation.dtype)
    selected = permutation[None, :] == row_ids[:, None]
    # jnp.real keeps p real-valued even when ``dtype`` is complex.
    p = jnp.real(jnp.array(selected, dtype=dtype))

    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    return p, l, u
def _lu(a, permute_l):
    """LU-factorize a single matrix and unpack the factors.

    Args:
      a: Array-like 2-D matrix of shape ``(m, n)``; promoted to an inexact
        dtype before factorization.
      permute_l: If true, return ``(p @ l, u)``; otherwise ``(p, l, u)``.

    Returns:
      ``(p @ l, u)`` or ``(p, l, u)`` where ``p`` is an ``(m, m)`` matrix with
      ``p[i, j] == 1`` exactly when ``permutation[j] == i``, ``l`` is
      ``(m, k)`` with a unit diagonal, ``u`` is ``(k, n)`` upper triangular and
      ``k = min(m, n)``.
    """
    (a,) = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)

    # One-hot encode the permutation vector into an (m, m) matrix.
    row_ids = jnp.arange(m, dtype=permutation.dtype)
    selected = permutation[None, :] == row_ids[:, None]
    # jnp.real keeps p real-valued even when ``dtype`` is complex.
    p = jnp.real(jnp.array(selected, dtype=dtype))

    k = min(m, n)
    strict_lower = jnp.tril(lu, -1)[:, :k]
    l = strict_lower + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]

    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
def qr_jvp_rule(primals, tangents, full_matrices):
    """Forward-mode (JVP) rule for the QR decomposition.

    See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.

    Args:
      primals: 1-tuple holding the matrix ``x`` with trailing shape ``(m, n)``.
      tangents: 1-tuple holding ``dx``, the tangent of ``x``.
      full_matrices: Must be false; the full-matrices case is not supported.

    Returns:
      ``((q, r), (dq, dr))``: the reduced QR factors and their tangents.

    Raises:
      NotImplementedError: If ``full_matrices`` is true or ``m < n``.
    """
    (x,) = primals
    (dx,) = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    m, n = x.shape[-2:]
    if full_matrices or m < n:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")

    # dx . inv(r): triangular_solve performs a right-side solve by default.
    dx_rinv = triangular_solve(r, dx)
    qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)

    # Strictly lower part minus its `_H` image (presumably the conjugate
    # transpose): skew-symmetric by construction.
    strict_lower = jnp.tril(qt_dx_rinv, -1)
    do = strict_lower - _H(strict_lower)
    # Correction needed for complex inputs: add the non-real part of the
    # diagonal of qt_dx_rinv (this term vanishes for real inputs).
    non_real_part = qt_dx_rinv - jnp.real(qt_dx_rinv)
    do = do + jnp.eye(n, dtype=do.dtype) * non_real_part

    dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
    dr = jnp.matmul(qt_dx_rinv - do, r)
    return (q, r), (dq, dr)
def tril(m, k=0):
    """Return the lower triangle of ``m``; thin wrapper around ``jnp.tril``.

    Args:
      m: Input array.
      k: Diagonal offset forwarded to ``jnp.tril`` (0 is the main diagonal).

    Returns:
      The result of ``jnp.tril(m, k)``.
    """
    lower = jnp.tril(m, k=k)
    return lower
def phi(X):
    """Return the lower triangle of ``X`` with its diagonal halved.

    Args:
      X: Array whose last two dimensions form square matrices; leading
        dimensions are treated as batch dimensions by broadcasting.

    Returns:
      ``tril(X)`` divided elementwise by ``1 + I``: strictly-lower entries are
      unchanged, diagonal entries are divided by two, and entries above the
      diagonal are zero.
    """
    lower = jnp.tril(X)
    # The Python literal ``1`` is weakly typed, so the denominator keeps X's
    # dtype without going through a private constant-building helper.
    denominator = 1 + jnp.eye(X.shape[-1], dtype=X.dtype)
    return lower / denominator
def cholesky(x, symmetrize_input=True):
    """Return the lower-triangular Cholesky factor of ``x``.

    Args:
      x: Input matrix (or batch of matrices) passed to ``cholesky_p``.
      symmetrize_input: When true, ``x`` is first passed through
        ``symmetrize``; otherwise it is used as given.

    Returns:
      The output of ``cholesky_p`` with everything above the main diagonal
      zeroed by ``jnp.tril``.
    """
    operand = x
    if symmetrize_input:
        operand = symmetrize(operand)
    factor = cholesky_p.bind(operand)
    return jnp.tril(factor)