Esempio n. 1
0
    def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices):
        """Checks `svd.svd` on a complex rectangular input matrix.

        A random real `(m, n)` matrix is factored with SciPy, its singular
        values are swapped for a linear ramp from `10**log_cond` down to 1,
        and the rebuilt matrix is cast to complex and scaled by `(1 + 1j)`.
        Sub-tests then cover JIT compatibility, closeness of the factors'
        Gram matrices to the identity, and agreement of the singular values
        with SciPy's.

        Args:
            m: Number of rows of the input matrix.
            n: Number of columns of the input matrix.
            log_cond: Base-10 logarithm of the largest value in the ramp.
            full_matrices: Forwarded to both SVD implementations.
        """
        with jax.default_matmul_precision('float32'):
            seed_matrix = np.random.uniform(0.3, 0.9, (m, n))
            seed_matrix = seed_matrix.astype(_SVD_TEST_DTYPE)
            u, _, v = osp_linalg.svd(seed_matrix, full_matrices=False)
            ramp = jnp.linspace(10**log_cond, 1, min(m, n))
            a = ((u * ramp) @ v).astype(complex) * (1 + 1j)

            actual_u, actual_s, actual_v = svd.svd(
                a, full_matrices=full_matrices)

            # The product order is chosen per shape so that the expected
            # result is an identity of the size computed alongside it.
            rank = min(m, n)
            u_adjoint = actual_u.T.conj()
            v_adjoint = actual_v.T.conj()
            if m > n:
                gram_u = jnp.real(u_adjoint @ actual_u)
                gram_v = jnp.real(v_adjoint @ actual_v)
                gram_u_dim = m if full_matrices else rank
                gram_v_dim = rank
            else:
                gram_u = jnp.real(actual_u @ u_adjoint)
                gram_v = jnp.real(actual_v @ v_adjoint)
                gram_u_dim = rank
                gram_v_dim = n if full_matrices else rank

            # SciPy provides the reference singular values.
            expected_s = osp_linalg.svd(a, full_matrices=full_matrices)[1]

            def svd_fn(x):
                return svd.svd(x, full_matrices=full_matrices)

            def args_maker():
                return [a]

            with self.subTest('Test JIT compatibility'):
                self._CompileAndCheck(svd_fn, args_maker)

            unitary_cases = (
                ('Test unitary u.', gram_u_dim, gram_u),
                ('Test unitary v.', gram_v_dim, gram_v),
            )
            for label, dim, gram in unitary_cases:
                with self.subTest(label):
                    self.assertAllClose(np.eye(dim),
                                        gram,
                                        rtol=_SVD_RTOL,
                                        atol=2E-3)

            with self.subTest('Test s.'):
                self.assertAllClose(expected_s,
                                    jnp.real(actual_s),
                                    rtol=_SVD_RTOL,
                                    atol=1E-6)
Esempio n. 2
0
    def testSvdWithSkinnyTallInput(self, m, n):
        """Checks that `svd.svd` factors reconstruct a random `(m, n)` matrix.

        Args:
            m: Number of rows of the generated matrix.
            n: Number of columns of the generated matrix.
        """
        with jax.default_matmul_precision('float32'):
            # Fixed seed: the same matrix is generated on every run.
            np.random.seed(1235)
            matrix = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
            left, sigma, right = svd.svd(matrix,
                                         full_matrices=False,
                                         hermitian=False)

            # Frobenius norm of the reconstruction error, relative to the
            # norm of the input.
            reconstruction = (left * sigma) @ right
            residual_norm = np.linalg.norm(matrix - reconstruction)
            relative_diff = residual_norm / np.linalg.norm(matrix)

            # With decimal=6 this passes when the relative error lies within
            # 1.5e-6 of the 1e-6 target.
            np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)
Esempio n. 3
0
    def testSingularValues(self, m, n, log_cond, full_matrices):
        """Tests the singular-values-only path (`compute_uv=False`).

        A random real `(m, n)` matrix is factored with SciPy, its singular
        values are replaced by a linear ramp from `10**log_cond` down to 1,
        and the rebuilt matrix `a` is turned into the complex input
        `a + 1j * a`.  The singular values returned by `svd.svd` are then
        compared with SciPy's and checked to be strictly decreasing.

        Args:
            m: Number of rows of the input matrix.
            n: Number of columns of the input matrix.
            log_cond: Base-10 logarithm of the largest value in the ramp.
            full_matrices: Forwarded to both SVD implementations.
        """
        with jax.default_matmul_precision('float32'):
            a = np.random.uniform(low=0.3, high=0.9,
                                  size=(m, n)).astype(_SVD_TEST_DTYPE)
            # Only the factors are reused; the singular values are replaced.
            u, _, v = osp_linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = np.linspace(cond, 1, min(m, n))
            a = (u * s) @ v
            a = a + 1j * a

            # Only computes singular values.
            compute_uv = False

            osp_linalg_fn = functools.partial(osp_linalg.svd,
                                              full_matrices=full_matrices,
                                              compute_uv=compute_uv)
            actual_s = svd.svd(a,
                               full_matrices=full_matrices,
                               compute_uv=compute_uv)

            expected_s = osp_linalg_fn(a)

            # Forward `compute_uv` so the JIT check compiles the same
            # singular-values-only path that the other sub-tests exercise,
            # rather than the default full decomposition.
            def svd_fn(x):
                return svd.svd(x,
                               full_matrices=full_matrices,
                               compute_uv=compute_uv)

            def args_maker():
                return [a]

            with self.subTest('Test JIT compatibility'):
                self._CompileAndCheck(svd_fn, args_maker)

            with self.subTest('Test s.'):
                self.assertAllClose(expected_s,
                                    actual_s,
                                    rtol=_SVD_RTOL,
                                    atol=1E-6)

            with self.subTest('Test non-increasing order.'):
                # Computes `actual_diff[i] = s[i+1] - s[i]`; the appended 0
                # makes the last entry `-s[-1]`.
                actual_diff = jnp.diff(actual_s, append=0)
                np.testing.assert_array_less(actual_diff,
                                             np.zeros_like(actual_diff))
Esempio n. 4
0
    def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
        """Checks `svd.svd` reconstruction on a rank-deficient square matrix.

        Args:
            m: Side length of the square input matrix.
            r: Index from which the prescribed singular values are zeroed.
            log_cond: Base-10 logarithm of the largest prescribed value.
        """
        with jax.default_matmul_precision('float32'):
            upper_ones = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

            # Build the rank-deficient input: keep the factors of the
            # upper-triangular matrix, but use a linear ramp of singular
            # values whose entries from index `r` onward are zero.
            left, _, right = jnp.linalg.svd(upper_ones, full_matrices=False)
            spectrum = jnp.linspace(10**log_cond, 1, m)
            zero_tail = jnp.zeros((m - r,))
            spectrum = spectrum.at[r:m].set(zero_tail)
            a = (left * spectrum) @ right

            with jax.default_matmul_precision('float32'):
                u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
            residual = a - (u * s) @ v
            diff = np.linalg.norm(residual)

            np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
Esempio n. 5
0
 def lax_fun(a):
     """Call `svd.svd` on `a` with `compute_uv=False`.

     The call also passes `full_matrices=False` and `hermitian=False`, and
     its result is returned unchanged.
     """
     svd_options = {
         'full_matrices': False,
         'compute_uv': False,
         'hermitian': False,
     }
     return svd.svd(a, **svd_options)