def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
  """Tests qdwh with upper triangular input of all ones.

  Starts from an ``m x n`` upper-triangular matrix of ones, keeps its singular
  vectors, and replaces its singular values with a linear ramp from
  ``10**log_cond`` down to 1. The factors returned by ``qdwh.qdwh`` are then
  compared against ``osp_linalg.polar`` on the same matrix.

  Args:
    m: Number of rows of the test matrix.
    n: Number of columns of the test matrix.
    log_cond: Base-10 logarithm of the largest singular value assigned to the
      test matrix (the smallest assigned value is 1).
  """
  a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
  # Only the singular vectors are reused; the spectrum is rebuilt below.
  u, _, v = jnp.linalg.svd(a, full_matrices=False)
  cond = 10**log_cond
  s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
  a = (u * s) @ v
  is_hermitian = _check_symmetry(a)
  max_iterations = 10

  actual_u, actual_h, _, _ = qdwh.qdwh(
      a, is_hermitian=is_hermitian, max_iterations=max_iterations)
  expected_u, expected_h = osp_linalg.polar(a)

  # Sets the test tolerance.
  rtol = 1E6 * _QDWH_TEST_EPS

  with self.subTest('Test u.'):
    relative_diff_u = _compute_relative_diff(actual_u, expected_u)
    np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)

  with self.subTest('Test h.'):
    relative_diff_h = _compute_relative_diff(actual_h, expected_h)
    np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

  with self.subTest('Test u.dot(h).'):
    a_round_trip = _dot(actual_u, actual_h)
    relative_diff_a = _compute_relative_diff(a_round_trip, a)
    np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

  with self.subTest('Test orthogonality.'):
    # Use the Hermitian adjoint, matching the rank-deficient test in this
    # file. For real dtypes `.conj()` leaves the values unchanged.
    actual_results = _dot(actual_u.T.conj(), actual_u)
    expected_results = np.eye(n)
    self.assertAllClose(
        actual_results, expected_results, rtol=rtol, atol=1E-5)
def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
  """Tests qdwh with rank-deficient input.

  The input is built from the singular vectors of an upper-triangular matrix
  of ones, with singular values running linearly from ``10**log_cond`` to 1
  and the last one overwritten with 0.

  Args:
    m: Number of rows of the test matrix.
    n: Number of columns of the test matrix.
    log_cond: Base-10 logarithm of the largest singular value assigned to the
      test matrix.
  """
  triangular_ones = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)

  # Generates a rank-deficient input.
  left_vectors, _, right_vectors_h = jnp.linalg.svd(
      triangular_ones, full_matrices=False)
  cond = 10**log_cond
  spectrum = jnp.linspace(cond, 1, min(m, n))
  # Zero the last singular value so the rebuilt matrix loses rank.
  spectrum = spectrum.at[-1].set(0)
  spectrum = jnp.expand_dims(spectrum, range(left_vectors.ndim - 1))
  a = (left_vectors * spectrum) @ right_vectors_h

  is_hermitian = _check_symmetry(a)
  max_iterations = 15
  qdwh_outputs = qdwh.qdwh(
      a, is_hermitian=is_hermitian, max_iterations=max_iterations)
  actual_u, actual_h, _, _ = qdwh_outputs
  _, expected_h = osp_linalg.polar(a)

  # Sets the test tolerance.
  rtol = 1E4 * _QDWH_TEST_EPS

  # For rank-deficient matrix, `u` is not unique.
  with self.subTest('Test h.'):
    np.testing.assert_almost_equal(
        _compute_relative_diff(actual_h, expected_h), 1E-6, decimal=5)

  with self.subTest('Test u.dot(h).'):
    reconstructed = _dot(actual_u, actual_h)
    np.testing.assert_almost_equal(
        _compute_relative_diff(reconstructed, a), 1E-6, decimal=5)

  with self.subTest('Test orthogonality.'):
    gram = _dot(actual_u.T.conj(), actual_u)
    identity = np.eye(n)
    self.assertAllClose(gram, identity, rtol=rtol, atol=1E-6)
def lsp_linalg_fn(a):
  """Runs ``qdwh.qdwh`` on `a`, optionally through a NaN-padded buffer.

  Reads `padding`, `m`, `n`, `is_hermitian` and `max_iterations` from the
  enclosing scope. When `padding` is not None, `a` is extended with NaN rows
  and columns, ``dynamic_shape=(m, n)`` is passed to ``qdwh.qdwh``, and the
  outputs are cropped back to ``m x n`` and ``n x n``.

  Args:
    a: The input matrix.

  Returns:
    The first two outputs of ``qdwh.qdwh`` as a ``(u, h)`` tuple.
  """
  has_padding = padding is not None
  if has_padding:
    extra_rows, extra_cols = padding
    # NaN fill makes any accidental read of the padded region visible.
    a = jnp.pad(
        a, [(0, extra_rows), (0, extra_cols)], constant_values=jnp.nan)

  # The truthiness test on `padding` is kept exactly as the call site had it.
  dynamic_shape = (m, n) if padding else None
  unitary, posdef, _, _ = qdwh.qdwh(
      a,
      is_hermitian=is_hermitian,
      max_iterations=max_iterations,
      dynamic_shape=dynamic_shape)

  if not has_padding:
    return unitary, posdef
  # Drop the padded rows/columns again.
  return unitary[:m, :n], posdef[:n, :n]
def testQdwhUnconvergedAfterMaxNumberIterations(self, m, n, log_cond):
  """Tests unconvergence after maximum number of iterations.

  Runs ``qdwh.qdwh`` with an iteration cap of 2 and checks that the reported
  iteration count equals the cap and that the convergence flag is false.

  Args:
    m: Number of rows of the test matrix.
    n: Number of columns of the test matrix.
    log_cond: Base-10 logarithm of the largest singular value assigned to the
      test matrix (the smallest assigned value is 1).
  """
  triangular_ones = jnp.triu(jnp.ones((m, n)))
  left_vectors, _, right_vectors_h = jnp.linalg.svd(
      triangular_ones, full_matrices=False)
  cond = 10**log_cond
  spectrum = jnp.linspace(cond, 1, min(m, n))
  a = (left_vectors * spectrum) @ right_vectors_h

  is_symmetric = _check_symmetry(a)
  max_iterations = 2

  # Arguments are passed positionally, as in the original call.
  outputs = qdwh.qdwh(a, is_symmetric, max_iterations)
  _, _, actual_num_iterations, is_converged = outputs

  with self.subTest('Number of iterations.'):
    self.assertEqual(max_iterations, actual_num_iterations)

  with self.subTest('Converged.'):
    self.assertFalse(is_converged)
def split_spectrum(H, n, split_point, V0=None):
  """Splits a Hermitian matrix around one of its eigenvalues.

  The Hermitian matrix `H` is split into two matrices `H_minus` and `H_plus`,
  respectively sharing its eigenspaces beneath and above its `split_point`th
  eigenvalue.

  Returns, in addition, `V_minus` and `V_plus`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are returned
  instead; this allows the overall isometries mapping from an initial input
  matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    n: Size forwarded as ``dynamic_shape=(n, n)`` to ``qdwh.qdwh``, as the mask
      shape for the identity, and to ``_projector_subspace``. Presumably the
      logical size of `H` inside a possibly larger buffer -- confirm against
      callers.
    split_point: The eigenvalue to split along. Must expose a ``dtype``
      attribute.
    V0: Matrix of isometries to be updated.

  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` to `H_plus`.
    rank: The rounded real trace of the projector, as int32. The original
      documentation calls this the dynamic size of the "m" subblock
      (presumably the `H_minus` block -- confirm).
  """
  buffer_size, _ = H.shape

  # Shift the spectrum so that `split_point` sits at zero.
  shift = split_point * jnp.eye(buffer_size, dtype=split_point.dtype)
  shifted = H - shift.astype(H.dtype)

  unitary, _, _, _ = qdwh.qdwh(
      shifted, is_hermitian=True, dynamic_shape=(n, n))

  masked_identity = _mask(jnp.eye(buffer_size, dtype=H.dtype), (n, n))
  projector = -0.5 * (unitary - masked_identity)
  rank = jnp.round(jnp.trace(jnp.real(projector))).astype(jnp.int32)

  V_minus, V_plus = _projector_subspace(projector, H, n, rank)
  H_minus = (V_minus.conj().T @ H) @ V_minus
  H_plus = (V_plus.conj().T @ H) @ V_plus

  if V0 is None:
    return H_minus, V_minus, H_plus, V_plus, rank

  # Compose with the accumulated isometries from earlier splits.
  return H_minus, jnp.dot(V0, V_minus), H_plus, jnp.dot(V0, V_plus), rank
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
  r"""Computes the polar decomposition.

  Given the :math:`m \times n` matrix :math:`a`, returns the factors of the
  polar decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such
  that :math:`a = up` (if side is ``"right"``; :math:`p` is
  :math:`n \times n`) or :math:`a = pu` (if side is ``"left"``; :math:`p` is
  :math:`m \times m`), where :math:`p` is positive semidefinite. If :math:`a`
  is nonsingular, :math:`p` is positive definite and the decomposition is
  unique. :math:`u` has orthonormal columns unless :math:`n > m`, in which
  case it has orthonormal rows.

  Writing the SVD of :math:`a` as
  :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
  have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
  factor :math:`u` can be constructed as the application of the sign function
  to the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
  eigenvalues.

  Several methods exist to compute the polar decomposition. Currently two are
  supported:

  * ``method="svd"``:

    Computes the SVD of :math:`a` and then forms
    :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.

  * ``method="qdwh"``:

    Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
    Only supported for :math:`m \ge n` with ``side="right"`` and for
    :math:`m < n` with ``side="left"``.

  Args:
    a: The :math:`m \times n` input matrix.
    side: Determines whether a right or left polar decomposition is computed.
      If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
      then :math:`a = pu`. The default is ``"right"``.
    method: Determines the algorithm used, as described above.
    eps: The final result will satisfy
      :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
      where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
      ``"qdwh"``.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied. When ``None`` (the default) no value is forwarded
      and the QDWH routine uses its own default. Ignored if ``method`` is not
      ``"qdwh"``.

  Returns:
    A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
    (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
    ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
    whether ``side`` is ``"right"`` or ``"left"``, respectively.

  Raises:
    ValueError: If ``a`` is not 2-D, ``side`` is not ``"right"`` or ``"left"``,
      or ``method`` is not recognized.
    NotImplementedError: If ``method="qdwh"`` is used with a shape/``side``
      combination other than the two listed above.

  .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
  """
  a = jnp.asarray(a)
  if a.ndim != 2:
    raise ValueError("The input `a` must be a 2-D array.")

  if side not in ["right", "left"]:
    raise ValueError(
        "The argument `side` must be either 'right' or 'left'.")

  m, n = a.shape
  if method == "qdwh":
    # Forward `max_iterations` only when the caller supplied it, so the
    # default call into qdwh is unchanged.
    qdwh_kwargs = {"is_hermitian": False, "eps": eps}
    if max_iterations is not None:
      qdwh_kwargs["max_iterations"] = max_iterations

    # TODO(phawkins): return info also if the user opts in?
    if m >= n and side == "right":
      unitary, posdef, _, _ = qdwh.qdwh(a, **qdwh_kwargs)
    elif m < n and side == "left":
      # Decompose the adjoint (a^H = u' p') and take adjoints of both
      # factors to recover a = p u.
      a = a.T.conj()
      unitary, posdef, _, _ = qdwh.qdwh(a, **qdwh_kwargs)
      posdef = posdef.T.conj()
      unitary = unitary.T.conj()
    else:
      raise NotImplementedError(
          "method='qdwh' only supports mxn matrices "
          "where m >= n with side='right' or m < n with "
          f"side='left', got {a.shape} with side={side}")
  elif method == "svd":
    u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
    s_svd = s_svd.astype(u_svd.dtype)
    unitary = u_svd @ vh_svd
    if side == "right":
      # a = u * p
      posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
    else:
      # a = p * u
      posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
  else:
    raise ValueError(f"Unknown polar decomposition method {method}.")

  return unitary, posdef
def lsp_linalg_fn(a):
  """Returns the first two outputs of ``qdwh.qdwh`` for `a`.

  `is_symmetric` and `max_iterations` are read from the enclosing scope.

  Args:
    a: The input matrix.

  Returns:
    A ``(u, h)`` tuple taken from the four values returned by ``qdwh.qdwh``.
  """
  # NOTE(review): other call sites in this file pass `is_hermitian=`; confirm
  # which keyword the qdwh version in use accepts.
  outputs = qdwh.qdwh(
      a, is_symmetric=is_symmetric, max_iterations=max_iterations)
  unitary, posdef, _, _ = outputs
  return unitary, posdef
def lsp_linalg_fn(a):
  """Returns the first two outputs of ``qdwh.qdwh`` for `a`.

  `is_hermitian` and `max_iterations` are read from the enclosing scope.

  Args:
    a: The input matrix.

  Returns:
    A ``(u, h)`` tuple taken from the four values returned by ``qdwh.qdwh``.
  """
  outputs = qdwh.qdwh(
      a, is_hermitian=is_hermitian, max_iterations=max_iterations)
  unitary, posdef, _, _ = outputs
  return unitary, posdef