Ejemplo n.º 1
0
def eigh(H,
         *,
         precision="float32",
         termination_size=256,
         n=None,
         sort_eigenvalues=True):
    """Computes the eigendecomposition of the symmetric/Hermitian matrix H.

    Args:
      H: The `n x n` Hermitian input, padded to `N x N`.
      precision: Matmul precision passed to `jax.default_matmul_precision`
        (e.g. ``"float32"``) while the recursive solver runs. It is not applied
        when `N <= termination_size`, where the Jacobi solver is called
        directly.
      termination_size: Recursion ends once the blocks reach this linear size.
        Inputs no larger than this are handed straight to the Jacobi solver.
      n: the true (dynamic) size of the matrix. Rows/columns at index `>= n`
        are padding. Defaults to `N` (no padding).
      sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
        highest.

    Returns:
      vals: The eigenvalues of `H`, a length-`N` vector. Entries belonging to
        the padding (positions `>= n` before sorting) are set to NaN; when
        sorting, `jnp.argsort` follows NumPy and orders NaN last. If `n` is
        None and `N <= termination_size`, the Jacobi solver's output is
        returned unchanged.
      vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
        of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
        to numerical error.

    Raises:
      TypeError: If `H` is not square.
    """
    M, N = H.shape
    if M != N:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    if N <= termination_size and n is None:
        # No padding to strip: the Jacobi solver's output is already final.
        return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

    # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
    # between matrices.

    n = N if n is None else n
    if N <= termination_size:
        # Replace the padding with the identity (presumably `_mask` keeps the
        # top-left `n x n` block and takes the rest from the alternative), so
        # the padding decouples from the true block and contributes eigenvalue
        # 1. Sorting is disabled here so those padding eigenpairs stay at
        # positions >= n and are masked to NaN below, exactly as on the
        # recursive path. Sorting inside the solver would interleave the
        # spurious 1.0 values with the true eigenvalues.
        # NOTE(review): relies on eigh_jacobi keeping diagonal order when
        # sort_eigenvalues=False (the recursive solver relies on the same
        # property) — confirm for the target backend.
        H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
        eig_vals, eig_vecs = lax_linalg.eigh_jacobi(H, sort_eigenvalues=False)
    else:
        with jax.default_matmul_precision(precision):
            eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)

    # Positions >= n belong to the padding; mark them NaN so they are never
    # mistaken for eigenvalues of the true `n x n` matrix.
    eig_vals = _mask(jnp.real(eig_vals), (n, ), jnp.nan)
    if sort_eigenvalues:
        sort_idxs = jnp.argsort(eig_vals)
        eig_vals = eig_vals[sort_idxs]
        eig_vecs = eig_vecs[:, sort_idxs]
    return eig_vals, eig_vecs
Ejemplo n.º 2
0
def _qdwh_svd(a: Any,
              full_matrices: bool,
              compute_uv: bool = True,
              hermitian: bool = False,
              max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
    """Singular value decomposition.

    Args:
      a: A matrix of shape `m x n`.
      full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
        respectively. If False, the shapes are `m x k` and `k x n`, respectively,
        where `k = min(m, n)`.
      compute_uv: Whether to also compute `u` and `vh` in addition to `s`.
      hermitian: True if `a` is Hermitian.
      max_iterations: The predefined maximum number of iterations of QDWH.

    Returns:
      A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
      `s` is vector of length `k` containing the singular values in the
      non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
      depend on the value of `full_matrices`. For `compute_uv=False`,
      only `s` is returned.
    """
    m, n = a.shape

    # Work on a tall-or-square matrix: a wide input is replaced by its
    # conjugate transpose, and the factors are swapped back at the end.
    is_flip = False
    if m < n:
        a = a.T.conj()
        m, n = a.shape
        is_flip = True

    # Optionally reduce `a` to its `n x n` QR factor R, so the QDWH-based solver
    # only sees a square matrix; `q` then maps the left singular vectors of R
    # back to those of `a`.
    reduce_to_square = False
    if full_matrices and compute_uv:
        # The trailing `m - n` columns of the full Q complete `u` to `m x m`.
        q_full, a_full = lax.linalg.qr(a, full_matrices=True)
        q = q_full[:, :n]
        u_out_null = q_full[:, n:]
        a = a_full[:n, :]
        reduce_to_square = True
    elif full_matrices or m > 1.15 * n:
        # Without `u` the full `m x m` Q would be discarded, so the economy QR
        # (same R factor) is enough.
        # The constant `1.15` comes from Yuji Nakatsukasa's implementation
        # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
        q, a = lax.linalg.qr(a, full_matrices=False)
        reduce_to_square = True

    # The triangular factor R is generally not Hermitian even when the input
    # was, so the Hermitian hint is only valid if no reduction took place.
    hermitian = hermitian and not reduce_to_square

    with jax.default_matmul_precision('float32'):
        if not compute_uv:
            return _svd_tall_and_square_input(a, hermitian, compute_uv,
                                              max_iterations)
        u_out, s_out, v_out = _svd_tall_and_square_input(
            a, hermitian, compute_uv, max_iterations)
        if reduce_to_square:
            u_out = q @ u_out

    if full_matrices:
        u_out = jnp.hstack((u_out, u_out_null))

    # Undo the initial conjugate transpose by swapping the roles of u and v.
    if is_flip:
        return (v_out, s_out, u_out.T.conj())

    return (u_out, s_out, v_out.T.conj())