コード例 #1
0
ファイル: qdwh_test.py プロジェクト: xueeinstein/jax
  def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    """Tests qdwh with upper triangular input of all ones.

    The all-ones upper-triangular matrix only supplies singular vectors: its
    singular values are replaced by values linearly spaced from
    `10**log_cond` down to 1, and the matrix is rebuilt from them. The qdwh
    polar factors are then compared against a reference polar decomposition.

    Args:
      m: Number of rows of the test matrix.
      n: Number of columns of the test matrix.
      log_cond: Base-10 logarithm of the largest prescribed singular value
        (the smallest is 1).
    """
    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    # expand_dims lets the 1-D spectrum broadcast across the last axis of `u`.
    s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
    a = (u * s) @ v
    is_hermitian = _check_symmetry(a)
    max_iterations = 10

    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
                                         max_iterations=max_iterations)
    # Reference factors (presumably scipy.linalg.polar -- confirm import).
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _QDWH_TEST_EPS

    # NOTE: assert_almost_equal(x, 1E-6, decimal=5) passes when
    # |x - 1e-6| < 1.5e-5, i.e. it bounds the relative difference from above.
    with self.subTest('Test u.'):
      relative_diff_u = _compute_relative_diff(actual_u, expected_u)
      np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)

    with self.subTest('Test h.'):
      relative_diff_h = _compute_relative_diff(actual_h, expected_h)
      np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = _dot(actual_u, actual_h)
      relative_diff_a = _compute_relative_diff(a_round_trip, a)
      np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
      # Use the conjugate transpose so u^H u == I is also the right check for
      # complex dtypes; for real dtypes it is identical to u.T @ u.
      actual_results = _dot(actual_u.T.conj(), actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-5)
コード例 #2
0
ファイル: qdwh_test.py プロジェクト: xueeinstein/jax
  def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
    """Checks qdwh on an input whose smallest singular value is zero.

    Singular vectors come from an all-ones upper-triangular matrix; the
    spectrum is replaced by values linearly spaced from `10**log_cond` down
    to 1, with the last one forced to 0 so the input is rank-deficient.
    Because the unitary factor is not unique in that case, only `h`, the
    reconstruction `u @ h`, and the orthonormality of `u` are checked.
    """
    base = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)

    # Rebuild the matrix from its singular vectors and a zero-terminated
    # spectrum; expand_dims makes the spectrum broadcast over `left`.
    left, _, right = jnp.linalg.svd(base, full_matrices=False)
    spectrum = jnp.linspace(10**log_cond, 1, min(m, n)).at[-1].set(0)
    spectrum = jnp.expand_dims(spectrum, range(left.ndim - 1))
    a = (left * spectrum) @ right

    u_qdwh, h_qdwh, _, _ = qdwh.qdwh(
        a, is_hermitian=_check_symmetry(a), max_iterations=15)
    _, h_reference = osp_linalg.polar(a)

    tolerance = 1E4 * _QDWH_TEST_EPS

    with self.subTest('Test h.'):
      np.testing.assert_almost_equal(
          _compute_relative_diff(h_qdwh, h_reference), 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
      reconstructed = _dot(u_qdwh, h_qdwh)
      np.testing.assert_almost_equal(
          _compute_relative_diff(reconstructed, a), 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
      gram = _dot(u_qdwh.T.conj(), u_qdwh)
      self.assertAllClose(gram, np.eye(n), rtol=tolerance, atol=1E-6)
コード例 #3
0
ファイル: qdwh_test.py プロジェクト: xueeinstein/jax
 def lsp_linalg_fn(a):
   """Returns the qdwh polar factors `(u, h)` of `a`.

   Reads `padding`, `is_hermitian`, `max_iterations`, `m` and `n` from the
   enclosing scope. When `padding` is not None, `a` is padded with NaNs by
   `padding = (extra_rows, extra_cols)`, qdwh is told the logical shape via
   `dynamic_shape=(m, n)`, and the factors are cropped back to `u[:m, :n]`
   and `h[:n, :n]`.
   """
   # Decide once, with the same `is not None` test everywhere; a truthiness
   # test would raise for array-valued padding.
   use_padding = padding is not None
   if use_padding:
     pm, pn = padding
     # NaN fill presumably ensures any use of the padded region by qdwh
     # surfaces as NaNs in the result.
     a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
   u, h, _, _ = qdwh.qdwh(
       a, is_hermitian=is_hermitian, max_iterations=max_iterations,
       dynamic_shape=(m, n) if use_padding else None)
   if use_padding:
     u = u[:m, :n]
     h = h[:n, :n]
   return u, h
コード例 #4
0
    def testQdwhUnconvergedAfterMaxNumberIterations(self, m, n, log_cond):
        """Checks that qdwh stops unconverged once the iteration cap is hit.

        A matrix with singular values linearly spaced from ``10**log_cond``
        down to 1 is decomposed with only two iterations allowed. The reported
        iteration count must equal that cap and the convergence flag must be
        false. qdwh is called positionally as ``(x, is_symmetric,
        max_iterations)``.
        """
        left, _, right = jnp.linalg.svd(jnp.triu(jnp.ones((m, n))),
                                        full_matrices=False)
        a = (left * jnp.linspace(10**log_cond, 1, min(m, n))) @ right
        iteration_cap = 2

        _, _, num_iterations, converged = qdwh.qdwh(
            a, _check_symmetry(a), iteration_cap)

        with self.subTest('Number of iterations.'):
            self.assertEqual(iteration_cap, num_iterations)

        with self.subTest('Converged.'):
            self.assertFalse(converged)
コード例 #5
0
def split_spectrum(H, n, split_point, V0=None):
    """Splits the Hermitian matrix `H` about the value `split_point`.

    `H` is split into two Hermitian matrices, `H_minus` and `H_plus`, that
    respectively share its eigenspaces with eigenvalues below and above
    `split_point`.

    Returns, in addition, `V_minus` and `V_plus`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
    returned instead; this allows the overall isometries mapping from
    an initial input matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split, of shape `(N, N)`.
      n: The logical size of `H`. Only the leading `(n, n)` block is treated
        as data: it is passed to qdwh as `dynamic_shape` and used to mask the
        identity, so anything beyond it presumably acts as padding.
      split_point: The value (not an index) about which the eigenvalues are
        split. May be a scalar array or a Python number.
      V0: Optional matrix of isometries to be updated.
    Returns:
      H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      V_minus: An isometry from the input space of `V0` to `H_minus`.
      H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      V_plus: An isometry from the input space of `V0` to `H_plus`.
      rank: The dynamic size of the `H_minus` subspace, computed as the
        rounded trace of the projector `P = (I - U) / 2`.
    """
    N, _ = H.shape
    # Accept Python scalars too: `split_point.dtype` is read below. asarray is
    # a no-op for inputs that are already arrays.
    split_point = jnp.asarray(split_point)
    H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
        H.dtype)
    # For Hermitian input the polar factor U acts as the sign function of the
    # eigenvalues of H_shift, so P = -(U - I) / 2 projects onto the eigenspace
    # below `split_point`. The identity is masked to the logical (n, n) block.
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
    P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is not None:
        V_minus = jnp.dot(V0, V_minus)
        V_plus = jnp.dot(V0, V_plus)
    return H_minus, V_minus, H_plus, V_plus, rank
コード例 #6
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
    r"""Computes the polar decomposition.

    Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
    decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
    :math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
    :math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
    where :math:`p` is positive semidefinite.  If :math:`a` is nonsingular,
    :math:`p` is positive definite and the
    decomposition is unique. :math:`u` has orthonormal columns unless
    :math:`n > m`, in which case it has orthonormal rows.

    Writing the SVD of :math:`a` as
    :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
    have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
    factor :math:`u` can be constructed as the application of the sign function to
    the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
    eigenvalues.

    Several methods exist to compute the polar decomposition. Currently two
    are supported:

    * ``method="svd"``:

      Computes the SVD of :math:`a` and then forms
      :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.

    * ``method="qdwh"``:

      Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
      Supports :math:`m \ge n` with ``side="right"`` and :math:`m \le n` with
      ``side="left"``.

    Args:
      a: The :math:`m \times n` input matrix.
      side: Determines whether a right or left polar decomposition is computed.
        If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
        then :math:`a = pu`. The default is ``"right"``.
      method: Determines the algorithm used, as described above.
      eps: The final result will satisfy
        :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
        where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
        ``"qdwh"``.
      max_iterations: Iterations will terminate after this many steps even if the
        above is unsatisfied. ``None`` (the default) defers to ``qdwh``'s own
        default. Ignored if ``method`` is not ``"qdwh"``.

    Returns:
      A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
      (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
      ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
      whether ``side`` is ``"right"`` or ``"left"``, respectively.

    Raises:
      ValueError: If ``a`` is not 2-D, ``side`` is not ``"right"`` or
        ``"left"``, or ``method`` is unknown.
      NotImplementedError: If ``method="qdwh"`` and :math:`m < n` with
        ``side="right"`` or :math:`m > n` with ``side="left"``.

    .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
    """
    a = jnp.asarray(a)
    if a.ndim != 2:
        raise ValueError("The input `a` must be a 2-D array.")

    if side not in ["right", "left"]:
        raise ValueError(
            "The argument `side` must be either 'right' or 'left'.")

    m, n = a.shape
    if method == "qdwh":
        # TODO(phawkins): return info also if the user opts in?
        if m >= n and side == "right":
            unitary, posdef, _, _ = qdwh.qdwh(
                a, is_hermitian=False, max_iterations=max_iterations, eps=eps)
        elif m <= n and side == "left":
            # QDWH yields a right decomposition of a tall-or-square matrix, so
            # decompose a^H = u' p' and use a = p'^H u'^H = p' u'^H (p' is
            # Hermitian). Square inputs work here exactly as wide ones do.
            a_h = a.T.conj()
            unitary, posdef, _, _ = qdwh.qdwh(
                a_h, is_hermitian=False, max_iterations=max_iterations,
                eps=eps)
            posdef = posdef.T.conj()
            unitary = unitary.T.conj()
        else:
            raise NotImplementedError(
                "method='qdwh' only supports mxn matrices "
                "where m >= n with side='right' or m <= n "
                f"with side='left', got {a.shape} with side={side}")
    elif method == "svd":
        u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
        s_svd = s_svd.astype(u_svd.dtype)
        unitary = u_svd @ vh_svd
        if side == "right":
            # a = u * p
            posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
        else:
            # a = p * u
            posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
    else:
        raise ValueError(f"Unknown polar decomposition method {method}.")

    return unitary, posdef
コード例 #7
0
 def lsp_linalg_fn(a):
     """Returns the qdwh polar factors `(u, h)` of `a`.

     Uses `is_symmetric` and `max_iterations` from the enclosing scope; the
     iteration count and convergence flag returned by qdwh are discarded.
     """
     unitary, posdef, _num_iters, _converged = qdwh.qdwh(
         a, is_symmetric=is_symmetric, max_iterations=max_iterations)
     return unitary, posdef
コード例 #8
0
ファイル: qdwh_test.py プロジェクト: xueeinstein/jax
 def lsp_linalg_fn(a):
   """Returns the qdwh polar factors `(u, h)` of `a`.

   Uses `is_hermitian` and `max_iterations` from the enclosing scope; the
   iteration count and convergence flag returned by qdwh are discarded.
   """
   unitary, posdef, _num_iters, _converged = qdwh.qdwh(
       a, max_iterations=max_iterations, is_hermitian=is_hermitian)
   return unitary, posdef