Пример #1
0
def test_expected_shapes():
    """Split a rank-4 zero tensor at axis 2 and check the factor shapes.

    Axes of size (2, 3) form the left side (2 * 3 = 6 rows) and axes of
    size (4, 5) the right side (4 * 5 = 20 columns), so the expected bond
    dimension here is the smaller side, 6.
    """
    zeros = torch.zeros((2, 3, 4, 5))
    left, sing, right, _ = decompositions.svd(torch, zeros, 2)
    bond_dim = 2 * 3
    assert left.shape == (2, 3, bond_dim)
    assert sing.shape == (bond_dim, )
    # Every singular value of an all-zero tensor is zero.
    np.testing.assert_allclose(sing, np.zeros(bond_dim))
    assert right.shape == (bond_dim, 4, 5)
Пример #2
0
def test_max_truncation_error_relative():
    """Compare absolute vs. relative interpretation of `max_truncation_error`.

    The same diagonal matrix (singular values 2.0, 1.0, 0.2, 0.1) is
    decomposed twice with the same threshold of 0.2; only the `relative`
    flag differs between the two calls.

    Taking the truncation error as the Euclidean norm of the discarded
    singular values: as an absolute bound only 0.1 may be dropped, because
    also dropping 0.2 gives sqrt(0.05) ~ 0.224 > 0.2. With relative=True the
    threshold is rescaled, which here permits dropping both 0.2 and 0.1.
    """
    # One shared fixture: the previous `absolute`/`relative` copies were
    # identical, and their names wrongly suggested different inputs.
    matrix = np.diag([2.0, 1.0, 0.2, 0.1])
    max_truncation_err = 0.2
    _, _, _, trunc_sv_absolute = decompositions.svd(
        torch,
        torch.Tensor(matrix),
        1,
        max_truncation_error=max_truncation_err,
        relative=False)
    _, _, _, trunc_sv_relative = decompositions.svd(
        torch,
        torch.Tensor(matrix),
        1,
        max_truncation_error=max_truncation_err,
        relative=True)
    np.testing.assert_almost_equal(trunc_sv_absolute, [0.1])
    np.testing.assert_almost_equal(trunc_sv_relative, [0.2, 0.1])
Пример #3
0
 def svd(
     self,
     tensor: Tensor,
     pivot_axis: int = -1,
     max_singular_values: Optional[int] = None,
     max_truncation_error: Optional[float] = None,
     relative: Optional[bool] = False
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Compute a truncated SVD of `tensor` via the shared decomposition routine.

     All arguments are forwarded unchanged to `decompositions.svd`, with this
     backend's `torchlib` module supplied as the array library.

     Args:
       tensor: The tensor to decompose.
       pivot_axis: Axis at which the tensor is split into a left and a right
         group of axes before the matrix SVD (default -1). Presumably axes
         before the pivot form the rows; confirm in `decompositions.svd`.
       max_singular_values: Forwarded as-is; presumably an upper limit on
         the number of singular values kept.
       max_truncation_error: Forwarded as-is; presumably a bound on the
         error introduced by discarding small singular values.
       relative: If True, `max_truncation_error` is treated as relative rather
         than absolute. The exact normalization is defined by
         `decompositions.svd`.

     Returns:
       The result of `decompositions.svd`, annotated as a 4-tuple of tensors.
       Presumably this is (u, s, vh, discarded singular values).
     """
     backend = torchlib
     return decompositions.svd(
         backend,
         tensor,
         pivot_axis,
         max_singular_values,
         max_truncation_error,
         relative=relative,
     )
Пример #4
0
def test_max_truncation_error():
    """Drop the smallest singular values whose combined norm fits the bound.

    A 10x10 matrix with singular values 0, 1, ..., 9 is built from the
    orthogonal factors of a random matrix. With max_truncation_error =
    sqrt(5.1), the three smallest values (2, 1, 0; squared norm 5) can be
    discarded. Also dropping 3 would give a squared norm of 14 > 5.1, so
    7 singular values remain.
    """
    # Local generator: it yields the same stream as np.random.seed(2019)
    # followed by np.random.rand, but leaves NumPy's global RNG state
    # untouched for other tests.
    rng = np.random.RandomState(2019)
    random_matrix = rng.rand(10, 10)
    # np.linalg.svd returns (U, S, Vh). Only the orthogonal factors are
    # used, to embed a hand-picked spectrum.
    left_orth, _, right_orth = np.linalg.svd(random_matrix)
    singular_values = np.arange(10)
    val = left_orth.dot(np.diag(singular_values).dot(right_orth.T))
    u, s, vh, trun = decompositions.svd(torch,
                                        torch.Tensor(val),
                                        1,
                                        max_truncation_error=math.sqrt(5.1))
    assert u.shape == (10, 7)
    assert s.shape == (7, )
    np.testing.assert_array_almost_equal(s, np.arange(9, 2, -1), decimal=5)
    assert vh.shape == (7, 10)
    # The computation runs in float32 (torch.Tensor), so use the same
    # tolerance as for the kept singular values above.
    np.testing.assert_array_almost_equal(trun,
                                         np.arange(2, -1, -1),
                                         decimal=5)