def _use_cholesky(u, params):
    """Applies the Cholesky-based update to `u`.

    Computes ``e * u + (a - e) * u @ inv(x)`` where ``x = c * u^H @ u + I``
    and ``e = b / c``. The inverse is never formed explicitly: `x` is factored
    as ``y @ y^H`` and applied through two triangular solves.

    Args:
      u: a 2-D matrix; only its column count is read from `u.shape`.
      params: a triple `(a, b, c)` of scalar coefficients.

    Returns:
      The updated matrix, with the same shape as `u`.
    """
    a, b, c = params
    _, n = u.shape

    # Build the identity in `u`'s dtype so `x` (and therefore `y`) keeps the
    # same dtype as `u`. A default-dtype identity can promote `x`, and the
    # triangular solves below would then see mismatched operand dtypes.
    x = c * u.T.conj() @ u + jnp.eye(n, dtype=u.dtype)

    # `y` is lower triangular.
    y = lax_linalg.cholesky(x, symmetrize_input=False)

    # First solve: conj(y) @ z = u^T; conjugating gives z = inv(y) @ u^H.
    z = lax_linalg.triangular_solve(
        y, u.T, left_side=True, lower=True, conjugate_a=True).conj()

    # Second solve: y^H @ z' = z, i.e. z' = inv(x) @ u^H. Its conjugate
    # transpose is u @ inv(x), because `x` is Hermitian.
    z = lax_linalg.triangular_solve(
        y, z, left_side=True, lower=True, transpose_a=True,
        conjugate_a=True).T.conj()

    e = b / c
    return e * u + (a - e) * z
def _use_cholesky(u, m, n, params):
    """Runs one Cholesky-based QDWH iteration on a padded matrix.

    Args:
      u: a matrix, with static (padded) shape M x N.
      m, n: the dynamic shape of the matrix, where m <= M and n <= N. Only
        `n` is read in this function.
      params: the QDWH parameters, unpacked as `(a, b, c)`.

    Returns:
      The updated matrix ``(b / c) * u + (a - b / c) * z``, where `z` is
      obtained from two triangular solves against the Cholesky factor.
    """
    a, b, c = params
    _, num_cols = u.shape

    gram = u.T.conj() @ u
    x = c * gram + jnp.eye(num_cols, dtype=jnp.dtype(u))

    # Put the identity in the padded lower-right corner: if that region were
    # left as zeros the matrix would not be PSD and the Cholesky
    # decomposition could fail.
    identity = jnp.eye(num_cols, dtype=x.dtype)
    x = _mask(x, (n, n), identity)

    # Lower-triangular Cholesky factor of `x`.
    chol = lax_linalg.cholesky(x, symmetrize_input=False)

    first = lax_linalg.triangular_solve(
        chol, u.T, left_side=True, lower=True, conjugate_a=True)
    first = first.conj()

    second = lax_linalg.triangular_solve(
        chol, first, left_side=True, lower=True, transpose_a=True,
        conjugate_a=True)
    z = second.T.conj()

    ratio = b / c
    return ratio * u + (a - ratio) * z
def cholesky(a):
    """Returns `lax_linalg.cholesky` applied to the dtype-promoted input.

    `a` is first converted with `jnp.asarray` and then passed through
    `_promote_arg_dtypes` before the factorization is computed.
    """
    as_array = jnp.asarray(a)
    promoted = _promote_arg_dtypes(as_array)
    return lax_linalg.cholesky(promoted)
def _cholesky(a, lower):
    """Computes a Cholesky factor of `a` in lower or upper form.

    Args:
      a: the input matrix; converted with `jnp.asarray` and passed through
        `_promote_dtypes_inexact`.
      lower: if truthy, return the factor from `lax_linalg.cholesky`
        directly; otherwise factor ``conj(_T(a))`` and return ``conj(_T(.))``
        of that factor.

    Returns:
      The requested Cholesky factor.
    """
    (a,) = _promote_dtypes_inexact(jnp.asarray(a))

    if lower:
        return lax_linalg.cholesky(a, symmetrize_input=False)

    # Upper form: factor the conjugate transpose, then conjugate-transpose
    # the resulting lower factor back.
    flipped = jnp.conj(_T(a))
    factor = lax_linalg.cholesky(flipped, symmetrize_input=False)
    return jnp.conj(_T(factor))
def cholesky(a):
    """Returns `lax_linalg.cholesky` applied to the dtype-promoted input.

    `a` is converted with `jnp.asarray`, passed through
    `_promote_dtypes_inexact` (which yields a one-element sequence here), and
    the single promoted array is then factored.
    """
    as_array = jnp.asarray(a)
    (promoted,) = _promote_dtypes_inexact(as_array)
    return lax_linalg.cholesky(promoted)