def test_expected_shapes():
    """An all-zero rank-4 tensor split at axis 2 decomposes into the expected shapes.

    Axes of size (2, 3) land on the left factor and axes of size (4, 5) on the
    right factor. The shared bond dimension is therefore min(2*3, 4*5) = 6,
    and every singular value of a zero tensor must be zero.
    """
    zeros_tensor = torch.zeros((2, 3, 4, 5))
    left, spectrum, right, _ = decompositions.svd(torch, zeros_tensor, 2)
    bond_dim = 6
    assert left.shape == (2, 3, bond_dim)
    assert spectrum.shape == (bond_dim, )
    np.testing.assert_allclose(spectrum, np.zeros(bond_dim))
    assert right.shape == (bond_dim, 4, 5)
def test_max_truncation_error_relative():
    """`relative=True` loosens the truncation bound so more values are dropped.

    The singular values are [2.0, 1.0, 0.2, 0.1] and the bound is 0.2.

      * absolute: only 0.1 can be discarded. Discarding 0.2 as well would
        cost sqrt(0.2**2 + 0.1**2) ~= 0.224 > 0.2.
      * relative: the bound is rescaled enough that both 0.2 and 0.1 are
        discarded.
    """
    spectrum_matrix = np.diag([2.0, 1.0, 0.2, 0.1])
    bound = 0.2

    def truncated_values(is_relative):
        # Only the fourth return value (the discarded singular values) matters here.
        *_, dropped = decompositions.svd(
            torch,
            torch.Tensor(spectrum_matrix),
            1,
            max_truncation_error=bound,
            relative=is_relative)
        return dropped

    dropped_absolute = truncated_values(False)
    dropped_relative = truncated_values(True)
    np.testing.assert_almost_equal(dropped_absolute, [0.1])
    np.testing.assert_almost_equal(dropped_relative, [0.2, 0.1])
def svd(
    self,
    tensor: Tensor,
    pivot_axis: int = -1,
    max_singular_values: Optional[int] = None,
    max_truncation_error: Optional[float] = None,
    relative: Optional[bool] = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute a (possibly truncated) singular value decomposition of `tensor`.

    This forwards every argument to `decompositions.svd`, supplying
    `torchlib` as the backend library module.

    Args:
      tensor: The tensor to decompose.
      pivot_axis: Axis index at which the axes of `tensor` are split. Axes
        before it go to the left factor and the remaining axes go to the
        right factor.
      max_singular_values: Presumably an upper bound on how many singular
        values are kept (semantics live in `decompositions.svd` — confirm).
      max_truncation_error: Optional bound on the norm of the discarded
        singular values.
      relative: If True, `max_truncation_error` is interpreted relative to
        the spectrum rather than as an absolute bound. The exact reference
        scale is defined in `decompositions.svd` — confirm there.

    Returns:
      The 4-tuple `(u, s, vh, truncated_singular_values)` produced by
      `decompositions.svd`.
    """
    return decompositions.svd(
        torchlib,
        tensor,
        pivot_axis,
        max_singular_values,
        max_truncation_error=max_truncation_error,
        relative=relative,
    )
def test_max_truncation_error():
    """Truncation keeps exactly as many singular values as the error bound allows.

    A 10 x 10 matrix is built with singular values 0, 1, ..., 9.

      * Discarding {0, 1, 2} costs sqrt(0 + 1 + 4) = sqrt(5) <= sqrt(5.1).
      * Also discarding 3 would cost sqrt(14) > sqrt(5.1).

    So exactly 7 values (9 down to 3) must survive, and [2, 1, 0] must be
    reported as truncated.
    """
    # Local generator: this produces the same numbers as
    # np.random.seed(2019) + np.random.rand, but does not mutate numpy's
    # global RNG state for other tests.
    rng = np.random.RandomState(2019)
    random_matrix = rng.rand(10, 10)
    # Only the orthogonal factors are used. Sandwiching a chosen diagonal
    # between orthogonal matrices yields a matrix with exactly those
    # singular values.
    unitary1, _, unitary2 = np.linalg.svd(random_matrix)
    singular_values = np.arange(10)
    val = unitary1.dot(np.diag(singular_values).dot(unitary2.T))
    u, s, vh, trun = decompositions.svd(
        torch, torch.Tensor(val), 1, max_truncation_error=math.sqrt(5.1))
    assert u.shape == (10, 7)
    assert s.shape == (7, )
    # torch.Tensor(...) uses torch's default float dtype (float32 unless
    # changed). Both spectra are therefore compared at the same
    # float32-appropriate precision; the near-zero truncated value is the
    # most sensitive to it.
    np.testing.assert_array_almost_equal(s, np.arange(9, 2, -1), decimal=5)
    assert vh.shape == (7, 10)
    np.testing.assert_array_almost_equal(
        trun, np.arange(2, -1, -1), decimal=5)