Exemplo n.º 1
0
def test_expected_shapes():
    """Check output shapes of the SVD of an all-zero (2, 3, 4, 5) tensor.

    The tensor is split at axis 2. The left factor keeps legs (2, 3), the
    right factor keeps legs (4, 5), and the test expects 6 singular values,
    all zero.
    """
    zero_tensor = torch.zeros((2, 3, 4, 5))
    left, spectrum, right, _ = decompositions.svd_decomposition(
        torch, zero_tensor, 2)
    assert left.shape == (2, 3, 6)
    assert spectrum.shape == (6, )
    # An all-zero input must produce an all-zero spectrum.
    np.testing.assert_allclose(spectrum, np.zeros(6))
    assert right.shape == (6, 4, 5)
Exemplo n.º 2
0
def test_max_truncation_error_relative():
    """Compare absolute vs. relative interpretation of max_truncation_error.

    The same diagonal matrix is decomposed twice with the same error bound.
    The first call uses ``relative=False`` and the second ``relative=True``.
    The test then checks which singular values each call discards.
    """
    diagonal = np.diag([2.0, 1.0, 0.2, 0.1])
    error_bound = 0.2

    def _discarded(use_relative):
        # Only the fourth output (the truncated singular values) is checked.
        _, _, _, dropped = decompositions.svd_decomposition(
            torch,
            torch.Tensor(diagonal),
            1,
            max_truncation_error=error_bound,
            relative=use_relative)
        return dropped

    dropped_absolute = _discarded(False)
    dropped_relative = _discarded(True)
    # Absolute bound: only 0.1 can be dropped. With relative=True the test
    # expects 0.2 and 0.1 dropped. Presumably the bound is scaled by the
    # largest singular value (2.0); confirm against decompositions.
    np.testing.assert_almost_equal(dropped_absolute, [0.1])
    np.testing.assert_almost_equal(dropped_relative, [0.2, 0.1])
Exemplo n.º 3
0
 def svd_decomposition(self,
                       tensor: Tensor,
                       split_axis: int,
                       max_singular_values: Optional[int] = None,
                       max_truncation_error: Optional[float] = None,
                       relative: Optional[bool] = False
                      ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
   """Split `tensor` into two factors via a (possibly truncated) SVD.

   Thin wrapper around `decompositions.svd_decomposition`. It passes the
   backend module stored on `self.torch` as the first argument.

   Args:
     tensor: The tensor to decompose.
     split_axis: Legs before this axis go to the left factor and the
       remaining legs to the right factor.
     max_singular_values: Optional cap on the number of singular values kept.
     max_truncation_error: Optional bound on the truncation error.
     relative: Forwarded unchanged. When True, `max_truncation_error` is
       interpreted relative to the spectrum rather than as an absolute value
       (exact semantics are defined by `decompositions.svd_decomposition`).
       Defaults to False, which preserves the previous behavior.

   Returns:
     A 4-tuple `(u, s, vh, truncated_singular_values)` as returned by
     `decompositions.svd_decomposition`.
   """
   # `relative` was previously not exposed, so callers could not request
   # relative truncation through this backend method.
   return decompositions.svd_decomposition(self.torch, tensor, split_axis,
                                           max_singular_values,
                                           max_truncation_error,
                                           relative=relative)
Exemplo n.º 4
0
def test_max_truncation_error():
    """Truncate a 10x10 matrix with known singular values 0..9.

    Dropping 0, 1 and 2 gives an error norm of sqrt(5) <= sqrt(5.1).
    Also dropping 3 would give sqrt(14), which exceeds the bound. So the seven
    values 9..3 are kept and 2, 1, 0 are reported as truncated.
    """
    np.random.seed(2019)
    # Borrow two orthogonal matrices from the SVD of a random matrix.
    left_basis, _, right_basis = np.linalg.svd(np.random.rand(10, 10))
    spectrum = np.arange(10)
    matrix = left_basis @ (np.diag(spectrum) @ right_basis.T)
    kept_u, kept_s, kept_vh, dropped = decompositions.svd_decomposition(
        torch, torch.Tensor(matrix), 1, max_truncation_error=math.sqrt(5.1))
    assert kept_u.shape == (10, 7)
    assert kept_s.shape == (7, )
    np.testing.assert_array_almost_equal(kept_s, np.arange(9, 2, -1),
                                         decimal=5)
    assert kept_vh.shape == (7, 10)
    np.testing.assert_array_almost_equal(dropped, np.arange(2, -1, -1))
Exemplo n.º 5
0
 def svd_decomposition(
     self,
     tensor: Tensor,
     split_axis: int,
     max_singular_values: Optional[int] = None,
     max_truncation_error: Optional[float] = None,
     relative: Optional[bool] = False
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Split `tensor` into two factors via a (possibly truncated) SVD.

     Delegates to `decompositions.svd_decomposition`. It passes the
     module-level `torchlib` as the backend module and forwards every other
     argument unchanged.

     Args:
       tensor: The tensor to decompose.
       split_axis: Legs before this axis form the left factor, the rest the
         right factor.
       max_singular_values: Optional cap on the number of singular values.
       max_truncation_error: Optional bound on the truncation error.
       relative: Forwarded as-is; selects relative rather than absolute
         interpretation of `max_truncation_error`.

     Returns:
       The 4-tuple produced by `decompositions.svd_decomposition`.
     """
     positional = (torchlib, tensor, split_axis, max_singular_values,
                   max_truncation_error)
     return decompositions.svd_decomposition(*positional, relative=relative)