def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices):
    """Tests SVD with rectangular input.

    Builds a complex `m x n` matrix whose singular values form a linear ramp
    from `10**log_cond` down to 1. Then checks that `svd.svd`:
      * compiles and matches under jit (`_CompileAndCheck`),
      * returns `u` / `v` whose Gram matrices equal the identity (real part
        close to `I`, imaginary part close to 0),
      * returns the same singular values as `osp_linalg.svd`.
    """
    with jax.default_matmul_precision('float32'):
        a = np.random.uniform(
            low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
        # Keep the singular vectors of a random matrix but replace its
        # singular values with a ramp from `cond` to 1.
        u, s, v = osp_linalg.svd(a, full_matrices=False)
        cond = 10**log_cond
        s = jnp.linspace(cond, 1, min(m, n))
        a = (u * s) @ v
        a = a.astype(complex) * (1 + 1j)

        osp_linalg_fn = functools.partial(
            osp_linalg.svd, full_matrices=full_matrices)
        actual_u, actual_s, actual_v = svd.svd(a, full_matrices=full_matrices)

        k = min(m, n)
        # Each Gram matrix below should equal the identity of the matching
        # `unitary_*_size`.
        if m > n:
            gram_u = actual_u.T.conj() @ actual_u
            gram_v = actual_v.T.conj() @ actual_v
            unitary_u_size = m if full_matrices else k
            unitary_v_size = k
        else:
            gram_u = actual_u @ actual_u.T.conj()
            gram_v = actual_v @ actual_v.T.conj()
            unitary_u_size = k
            unitary_v_size = n if full_matrices else k
        unitary_u = jnp.real(gram_u)
        unitary_v = jnp.real(gram_v)

        _, expected_s, _ = osp_linalg_fn(a)

        def svd_fn(x):
            return svd.svd(x, full_matrices=full_matrices)

        args_maker = lambda: [a]

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(svd_fn, args_maker)

        with self.subTest('Test unitary u.'):
            self.assertAllClose(np.eye(unitary_u_size), unitary_u,
                                rtol=_SVD_RTOL, atol=2E-3)

        with self.subTest('Test unitary v.'):
            self.assertAllClose(np.eye(unitary_v_size), unitary_v,
                                rtol=_SVD_RTOL, atol=2E-3)

        # A Hermitian Gram matrix of the form I + iA (A real antisymmetric)
        # passes the real-part checks above without being the identity, so
        # also require the imaginary parts to vanish.
        with self.subTest('Test unitary u (imaginary part).'):
            self.assertAllClose(np.zeros((unitary_u_size, unitary_u_size)),
                                jnp.imag(gram_u), rtol=_SVD_RTOL, atol=2E-3)

        with self.subTest('Test unitary v (imaginary part).'):
            self.assertAllClose(np.zeros((unitary_v_size, unitary_v_size)),
                                jnp.imag(gram_v), rtol=_SVD_RTOL, atol=2E-3)

        with self.subTest('Test s.'):
            self.assertAllClose(
                expected_s, jnp.real(actual_s), rtol=_SVD_RTOL, atol=1E-6)
def testSvdWithSkinnyTallInput(self, m, n):
    """Tests SVD with skinny and tall input.

    Draws a reproducible random `m x n` matrix and checks that the reduced
    SVD from `svd.svd` reconstructs it with a relative Frobenius-norm error
    below 2.5e-6.
    """
    # Generates a skinny and tall input
    with jax.default_matmul_precision('float32'):
        # A private RandomState seeded with 1235 produces the same draws as
        # seeding the global NumPy RNG. Unlike `np.random.seed`, it leaves the
        # process-wide RNG alone, which other tests draw from unseeded.
        rng = np.random.RandomState(1235)
        a = rng.randn(m, n).astype(_SVD_TEST_DTYPE)
        u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
        relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)
        # Explicit upper bound on the relative reconstruction error.
        np.testing.assert_array_less(relative_diff, 2.5e-6)
def testSingularValues(self, m, n, log_cond, full_matrices):
    """Tests singular values.

    Runs `svd.svd` with `compute_uv=False` on a complex input whose singular
    values were set to a ramp from `10**log_cond` to 1. Checks that:
      * the singular-values-only call compiles and matches under jit,
      * the result agrees with `osp_linalg.svd(..., compute_uv=False)`,
      * the returned values are strictly decreasing and positive.
    """
    with jax.default_matmul_precision('float32'):
        a = np.random.uniform(
            low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
        u, s, v = osp_linalg.svd(a, full_matrices=False)
        cond = 10**log_cond
        s = np.linspace(cond, 1, min(m, n))
        a = (u * s) @ v
        a = a + 1j * a

        # Only computes singular values.
        compute_uv = False

        osp_linalg_fn = functools.partial(
            osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv)
        actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
        expected_s = osp_linalg_fn(a)

        # Forward `compute_uv` so the jit check covers the same
        # singular-values-only path that the assertions below validate.
        def svd_fn(x):
            return svd.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)

        args_maker = lambda: [a]

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(svd_fn, args_maker)

        with self.subTest('Test s.'):
            self.assertAllClose(expected_s, actual_s, rtol=_SVD_RTOL, atol=1E-6)

        with self.subTest('Test non-increasing order.'):
            # With `append=0`, `actual_diff[i] = s[i+1] - s[i]` for all but the
            # last entry, and the last entry is `-s[-1]`. Requiring every entry
            # to be strictly negative therefore checks that `s` is strictly
            # decreasing and that its smallest value is positive.
            actual_diff = jnp.diff(actual_s, append=0)
            np.testing.assert_array_less(actual_diff, np.zeros_like(actual_diff))
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
    """Tests SVD with rank-deficient input.

    Builds an `m x m` matrix of rank at most `r`: it keeps the singular
    vectors of an upper-triangular matrix of ones, sets the singular values to
    a ramp from `10**log_cond` to 1, and zeroes all but the first `r` of them.
    It then checks that the reduced SVD from `svd.svd` reconstructs the matrix
    with a Frobenius-norm error below 0.0151.
    """
    with jax.default_matmul_precision('float32'):
        a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

        # Generates a rank-deficient input.
        u, s, v = jnp.linalg.svd(a, full_matrices=False)
        cond = 10**log_cond
        s = jnp.linspace(cond, 1, m)
        # `s` has length m, so this zeroes entries r..m-1.
        s = s.at[r:].set(0)
        a = (u * s) @ v

        # Already inside the float32 matmul-precision context above.
        u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
        diff = np.linalg.norm(a - (u * s) @ v)

        # Explicit upper bound on the absolute reconstruction error.
        np.testing.assert_array_less(diff, 0.0151)
def lax_fun(a):
    """Return only the singular values of `a`, computed by `svd.svd`.

    Thin wrapper that pins `full_matrices=False`, `compute_uv=False` and
    `hermitian=False`, so callers pass just the input matrix.
    """
    singular_values = svd.svd(
        a,
        hermitian=False,
        compute_uv=False,
        full_matrices=False,
    )
    return singular_values