def test_max_truncation_error_relative(self):
   """Compare absolute and relative readings of `max_truncation_error`.

   Both calls decompose diag(2.0, 1.0, 0.2, 0.1) with the same error bound
   of 0.2. With `relative=False` only 0.1 is expected to be truncated. With
   `relative=True` both 0.2 and 0.1 are expected to be truncated.
   NOTE(review): presumably the relative bound is scaled by the largest
   singular value (2.0); confirm against `decompositions.svd`.
   """
   diagonal = [2.0, 1.0, 0.2, 0.1]
   bound = 0.2
   # Each call receives its own freshly built matrix.
   _, _, _, truncated_abs = decompositions.svd(
       tf, np.diag(diagonal), 1, max_truncation_error=bound, relative=False)
   _, _, _, truncated_rel = decompositions.svd(
       tf, np.diag(diagonal), 1, max_truncation_error=bound, relative=True)
   np.testing.assert_almost_equal(truncated_abs, [0.1])
   np.testing.assert_almost_equal(truncated_rel, [0.2, 0.1])
 def test_expected_shapes(self):
   """Check the factor shapes produced when splitting at `pivot_axis`.

   A (2, 3, 4, 5) zero tensor is split at axis 2. The assertions expect six
   singular values, all zero. `u` should keep the leading (2, 3) axes and
   `vh` the trailing (4, 5) axes. Presumably the tensor is flattened to a
   6x20 matrix internally.
   """
   left_dim = 2 * 3  # product of the axes before the pivot
   zeros = tf.zeros((2, 3, 4, 5))
   u, s, vh, _ = decompositions.svd(tf, zeros, 2)
   self.assertEqual(u.shape, (2, 3, left_dim))
   self.assertEqual(s.shape, (left_dim,))
   self.assertAllClose(s, np.zeros(left_dim))
   self.assertEqual(vh.shape, (left_dim, 4, 5))
 def test_max_truncation_error(self):
   """Truncation by `max_truncation_error` drops the smallest singular values.

   A 10x10 matrix with singular values 0, 1, ..., 9 is assembled from the
   orthogonal factors of a random matrix. The allowed error is sqrt(5.1).
   The assertions expect 0, 1 and 2 to be discarded and the seven values
   9..3 to be kept. This is consistent with the error being the 2-norm of
   the discarded values: sqrt(0 + 1 + 4) <= sqrt(5.1) < sqrt(14).
   """
   # A local, seeded generator makes failures reproducible and leaves
   # numpy's global random state untouched. The singular values of `val`
   # are 0..9 for any draw, so the seed does not affect the expectations.
   rng = np.random.RandomState(0)
   random_matrix = rng.rand(10, 10)
   unitary1, _, unitary2 = np.linalg.svd(random_matrix)
   singular_values = np.array(range(10))
   # unitary2 is orthogonal, so its transpose is too. `val` therefore has
   # exactly `singular_values` as its spectrum.
   val = unitary1.dot(np.diag(singular_values).dot(unitary2.T))
   u, s, vh, trun = decompositions.svd(
       tf, val, 1, max_truncation_error=math.sqrt(5.1))
   self.assertEqual(u.shape, (10, 7))
   self.assertEqual(s.shape, (7,))
   self.assertAllClose(s, np.arange(9, 2, -1))
   self.assertEqual(vh.shape, (7, 10))
   self.assertAllClose(trun, np.arange(2, -1, -1))
Ejemplo n.º 4
0
 def svd(
     self,
     tensor: Tensor,
     pivot_axis: int = -1,
     max_singular_values: Optional[int] = None,
     max_truncation_error: Optional[float] = None,
     relative: Optional[bool] = False
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Compute a (possibly truncated) SVD of `tensor` split at `pivot_axis`.

     This wrapper delegates to `decompositions.svd`. It passes `tf`
     (presumably the TensorFlow module) as the backend library and forwards
     every argument unchanged.

     Args:
       tensor: The tensor to decompose.
       pivot_axis: Axis at which the tensor's indices are split. The visible
         tests show axes before it ending up on `u` and axes from it onward
         on `vh`.
       max_singular_values: Presumably a cap on how many singular values are
         kept; None means no cap (forwarded as-is).
       max_truncation_error: Allowed truncation error, or None (forwarded
         as-is).
       relative: Whether `max_truncation_error` is treated as relative rather
         than absolute (forwarded as-is).

     Returns:
       The 4-tuple `(u, s, vh, s_rest)` returned by `decompositions.svd`.
       The visible tests show `s_rest` holding the discarded singular values.
     """
     forwarded = (tensor, pivot_axis, max_singular_values, max_truncation_error)
     return decompositions.svd(tf, *forwarded, relative=relative)