Ejemplo n.º 1
0
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition of ``a``.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted to an
        inexact dtype before decomposing.
      full_matrices: forwarded to ``lax_linalg.svd``; not consulted on the
        ``hermitian=True`` path.
      compute_uv: if true, return ``(u, s, vh)``; otherwise return only ``s``.
      hermitian: if true, treat ``a`` as Hermitian and build the SVD from its
        eigendecomposition (``lax_linalg.eigh``) rather than calling
        ``lax_linalg.svd``.

    Returns:
      ``(u, s, vh)`` when ``compute_uv`` is true, else ``s``. On the Hermitian
      path ``s`` is the absolute eigenvalues sorted in descending order.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if not hermitian:
        return lax_linalg.svd(a,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)

    # lax_linalg.eigh yields (eigenvector matrix, eigenvalues). For a Hermitian
    # matrix the singular values are the eigenvalue magnitudes.
    eigvecs, eigvals = lax_linalg.eigh(a)
    magnitudes = lax.abs(eigvals)
    if not compute_uv:
        # Ascending sort followed by a reversal gives descending order.
        return lax.rev(lax.sort(magnitudes, dimension=-1),
                       dimensions=[magnitudes.ndim - 1])

    last_axis = magnitudes.ndim - 1
    signs = lax.sign(eigvals)
    # Carry each eigenvalue's column index and sign through the sort (keyed on
    # magnitude only) so the eigenvectors can be permuted to match.
    order = lax.broadcasted_iota(np.int64, magnitudes.shape,
                                 dimension=last_axis)
    magnitudes, order, signs = lax.sort((magnitudes, order, signs),
                                        dimension=-1, num_keys=1)
    magnitudes, order, signs = tuple(
        lax.rev(arr, dimensions=[last_axis])
        for arr in (magnitudes, order, signs))
    u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
    # a = u diag(lambda) u^H = u diag(|lambda|) (u diag(sign(lambda)))^H since
    # Hermitian eigenvalues are real, so the sign is folded into vh.
    vh = _H(u * signs[..., None, :])
    return u, magnitudes, vh
Ejemplo n.º 2
0
def svd(a,
        full_matrices=True,
        compute_uv=True,
        overwrite_a=False,
        check_finite=True,
        lapack_driver='gesdd'):
    """SVD front end with a ``scipy.linalg.svd``-shaped signature.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted by
        ``np_linalg._promote_arg_dtypes`` before decomposing.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.
      overwrite_a, check_finite, lapack_driver: accepted (presumably for SciPy
        signature compatibility) and ignored.

    Returns:
      Whatever ``lax_linalg.svd`` returns for the given flags.
    """
    # These options have no effect in this implementation.
    del overwrite_a, check_finite, lapack_driver
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass the flags by keyword so the call does not depend on
    # lax_linalg.svd's positional parameter order (they are keyword-only in
    # newer JAX releases).
    return lax_linalg.svd(a,
                          full_matrices=full_matrices,
                          compute_uv=compute_uv)
Ejemplo n.º 3
0
def _svd(a, *, full_matrices, compute_uv):
    """Promote ``a`` to an inexact dtype and return ``lax_linalg.svd`` of it.

    Args:
      a: array-like input, converted with ``jnp.asarray``.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.
    """
    (promoted,) = _promote_dtypes_inexact(jnp.asarray(a))
    return lax_linalg.svd(promoted,
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
Ejemplo n.º 4
0
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
    r"""Computes the polar decomposition.

  Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
  decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
  :math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
  :math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
  where :math:`p` is positive semidefinite.  If :math:`a` is nonsingular,
  :math:`p` is positive definite and the
  decomposition is unique. :math:`u` has orthonormal columns unless
  :math:`n > m`, in which case it has orthonormal rows.

  Writing the SVD of :math:`a` as
  :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
  have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
  factor :math:`u` can be constructed as the application of the sign function to
  the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
  eigenvalues.

  Several methods exist to compute the polar decomposition. Currently two
  are supported:

  * ``method="svd"``:

    Computes the SVD of :math:`a` and then forms
    :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.

  * ``method="qdwh"``:

    Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
    Only ``m >= n`` with ``side="right"`` and ``m < n`` with ``side="left"``
    are supported.

  Args:
    a: The :math:`m \times n` input matrix.
    side: Determines whether a right or left polar decomposition is computed.
      If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
      then :math:`a = pu`. The default is ``"right"``.
    method: Determines the algorithm used, as described above.
    eps: The final result will satisfy
      :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
      where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
      ``"qdwh"``.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied.  Ignored if ``method`` is not ``"qdwh"``.

  Returns:
    A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
    (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
    ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
    whether ``side`` is ``"right"`` or ``"left"``, respectively.

  Raises:
    ValueError: if ``a`` is not 2-D, ``side`` is not ``"right"``/``"left"``,
      or ``method`` is unknown.
    NotImplementedError: if ``method="qdwh"`` is used with an unsupported
      combination of shape and ``side``.

  .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
  """
    a = jnp.asarray(a)
    if a.ndim != 2:
        raise ValueError("The input `a` must be a 2-D array.")

    if side not in ["right", "left"]:
        raise ValueError(
            "The argument `side` must be either 'right' or 'left'.")

    m, n = a.shape
    if method == "qdwh":
        # TODO(phawkins): return info also if the user opts in?
        if m >= n and side == "right":
            unitary, posdef, _, _ = qdwh.qdwh(
                a, is_hermitian=False, max_iterations=max_iterations, eps=eps)
        elif m < n and side == "left":
            # Reduce to a right decomposition of a^H:
            # a^H = u' p'  =>  a = p'^H u'^H, with p' Hermitian.
            a = a.T.conj()
            unitary, posdef, _, _ = qdwh.qdwh(
                a, is_hermitian=False, max_iterations=max_iterations, eps=eps)
            posdef = posdef.T.conj()
            unitary = unitary.T.conj()
        else:
            raise NotImplementedError(
                "method='qdwh' only supports mxn matrices "
                "where m >= n when side='right' and m < n when "
                f"side='left', got {a.shape} with side={side}")
    elif method == "svd":
        u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
        # Cast the singular values to u's dtype (possibly complex) before
        # scaling the factors below.
        s_svd = s_svd.astype(u_svd.dtype)
        unitary = u_svd @ vh_svd
        if side == "right":
            # a = u * p with p = V diag(s) V^H
            posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
        else:
            # a = p * u with p = U diag(s) U^H
            posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
    else:
        raise ValueError(f"Unknown polar decomposition method {method}.")

    return unitary, posdef
Ejemplo n.º 5
0
def svd(a, full_matrices=True, compute_uv=True):
    """NumPy-style SVD wrapper around ``lax_linalg.svd``.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted by
        ``_promote_arg_dtypes`` before decomposing.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.

    Returns:
      Whatever ``lax_linalg.svd`` returns for the given flags.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    # Pass the flags by keyword so the call does not depend on
    # lax_linalg.svd's positional parameter order (they are keyword-only in
    # newer JAX releases).
    return lax_linalg.svd(a,
                          full_matrices=full_matrices,
                          compute_uv=compute_uv)
Ejemplo n.º 6
0
def _svd(a, *, full_matrices, compute_uv):
    """Promote ``a`` and return ``lax_linalg.svd`` of it.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and promoted by
        ``np_linalg._promote_arg_dtypes``.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass the flags by keyword so the call does not depend on
    # lax_linalg.svd's positional parameter order (they are keyword-only in
    # newer JAX releases).
    return lax_linalg.svd(a,
                          full_matrices=full_matrices,
                          compute_uv=compute_uv)