Example #1
0
def _use_cholesky(u, params):
    """Uses Cholesky decomposition."""
    a, b, c = params
    _, n = u.shape
    x = c * u.T.conj() @ u + jnp.eye(n)

    # `y` is lower triangular.
    y = lax_linalg.cholesky(x, symmetrize_input=False)

    z = lax_linalg.triangular_solve(y,
                                    u.T,
                                    left_side=True,
                                    lower=True,
                                    conjugate_a=True).conj()

    z = lax_linalg.triangular_solve(y,
                                    z,
                                    left_side=True,
                                    lower=True,
                                    transpose_a=True,
                                    conjugate_a=True).T.conj()

    e = b / c
    u = e * u + (a - e) * z
    return u
Example #2
0
def _use_cholesky(u, m, n, params):
    """QDWH iteration using Cholesky decomposition.

    Computes ``e * u + (a - e) * (u @ inv(x))`` with ``x = c * u^H u + I``
    and ``e = b / c``, applying ``inv(x)`` via two triangular solves against
    the Cholesky factor of ``x``.

    Args:
      u: a matrix, with static (padded) shape M x N.
      m, n: the dynamic shape of the matrix, where m <= M and n <= N.
        Only ``n`` is read by this function; ``m`` is accepted so the
        signature stays uniform.
      params: the QDWH parameters, unpacked as ``(a, b, c)``.

    Returns:
      The updated matrix, with the same static shape as ``u``.
    """
    a, b, c = params
    _, N = u.shape
    x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=u.dtype)
    # Pad the region outside the leading n x n block with the identity matrix.
    # This prevents the Cholesky decomposition from failing because the matrix
    # is not PSD when padded with zeros.
    # NOTE(review): confirm `_mask` keeps x[:n, :n] and fills the rest from
    # the third argument.
    x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype))

    # `y` is lower triangular with x = y @ y^H.
    y = lax_linalg.cholesky(x, symmetrize_input=False)

    # Solve conj(y) @ z = u^T, then conjugate: z = y^{-1} @ u^H.
    z = lax_linalg.triangular_solve(y,
                                    u.T,
                                    left_side=True,
                                    lower=True,
                                    conjugate_a=True).conj()

    # Solve y^H @ w = z, so w = x^{-1} @ u^H; .T.conj() yields u @ x^{-H}.
    z = lax_linalg.triangular_solve(y,
                                    z,
                                    left_side=True,
                                    lower=True,
                                    transpose_a=True,
                                    conjugate_a=True).T.conj()

    e = b / c
    u = e * u + (a - e) * z
    return u
Example #3
0
def cholesky(a):
    """Computes the Cholesky decomposition of `a` via `lax_linalg.cholesky`.

    The input is first converted with `jnp.asarray` and then passed through
    `_promote_arg_dtypes`. The factorization uses `lax_linalg.cholesky`'s
    default options.
    """
    return lax_linalg.cholesky(_promote_arg_dtypes(jnp.asarray(a)))
Example #4
0
def _cholesky(a, lower):
    """Cholesky-factors `a`, choosing the returned triangle with `lower`.

    If `lower` is truthy, the promoted input is factored directly. Otherwise,
    `jnp.conj(_T(a))` is factored, and the result is mapped back by applying
    `jnp.conj(_T(...))` to it. `_T` is presumably a transpose of the last two
    axes; confirm against its definition.

    `symmetrize_input=False` hands the matrix to `lax_linalg.cholesky`
    unchanged, without symmetrizing it first.
    """
    (arr,) = _promote_dtypes_inexact(jnp.asarray(a))
    if lower:
        return lax_linalg.cholesky(arr, symmetrize_input=False)
    factor = lax_linalg.cholesky(jnp.conj(_T(arr)), symmetrize_input=False)
    return jnp.conj(_T(factor))
Example #5
0
def cholesky(a):
  """Computes the Cholesky decomposition of `a` via `lax_linalg.cholesky`.

  The input is converted with `jnp.asarray` and passed through
  `_promote_dtypes_inexact`, which presumably promotes it to an inexact
  dtype, as its name suggests. The single result is then factored with
  `lax_linalg.cholesky`'s default options.
  """
  (arr,) = _promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.cholesky(arr)