def test_ops():
    """Unary ops (reciprocal, cos, exp) on a TT tensor must match dense torch results."""
    axes = tn.meshgrid([32] * 4)
    t = axes[0] + axes[1] + axes[2] + axes[3] + 1
    dense = t.torch()  # decompress once; presumably a pure conversion — TODO confirm
    assert tn.relative_error(1 / dense, 1 / t) < 1e-4
    assert tn.relative_error(torch.cos(dense), tn.cos(t)) < 1e-4
    assert tn.relative_error(torch.exp(dense), tn.exp(t)) < 1e-4
def check(t1, t2):
    """Binary arithmetic between two compressed tensors must match dense torch arithmetic."""
    dense1 = t1.torch()
    dense2 = t2.torch()
    assert tn.relative_error(t1 + t2, dense1 + dense2) <= 1e-7
    assert tn.relative_error(t1 - t2, dense1 - dense2) <= 1e-7
    assert tn.relative_error(t1 * t2, dense1 * dense2) <= 1e-7
    assert tn.relative_error(-t1 + t2, -dense1 + dense2) <= 1e-7
def test_orthogonalization():
    """Orthogonalizing (left, right, or at an arbitrary core) must not change tensor values."""
    for _ in range(100):
        shape = np.random.randint(1, 8, np.random.randint(2, 6))
        gt = tn.rand(shape)
        t = gt.clone()
        assert tn.relative_error(gt, t) <= 1e-7
        t.left_orthogonalize(0)
        assert tn.relative_error(gt, t) <= 1e-7
        t.right_orthogonalize(t.dim() - 1)
        assert tn.relative_error(gt, t) <= 1e-7
        t.orthogonalize(np.random.randint(t.dim()))
        assert tn.relative_error(gt, t) <= 1e-7
def store_metrics(self, full):
    """Record compression and accuracy metrics of `self.tn_weight` against the dense tensor `full`.

    Sets `self.compression` (dense size, compressed size, ratio), `self.relative_error`,
    `self.rmse`, and `self.r_squared`.
    """
    weight = self.tn_weight
    # Reshape the dense reference to the decompressed weight's shape before comparing
    full = full.view(weight.torch().size())
    dense_size = full.numel()
    compressed_size = weight.numel()
    self.compression = (dense_size, compressed_size, dense_size / compressed_size)
    self.relative_error = tn.relative_error(full, weight)
    self.rmse = tn.rmse(full, weight)
    self.r_squared = tn.r_squared(full, weight)
def test_round_tucker():
    """Tucker rounding must stay within the requested relative-error budget."""
    for _ in range(100):
        budget = np.random.rand() ** 2
        gt = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        rounded = gt.clone()
        rounded.round_tucker(eps=budget)
        assert tn.relative_error(gt, rounded) <= budget
def test_tensors():
    """Cross-approximating the identity function on a random tensor should reproduce it."""
    for _ in range(100):
        source = random_format([10] * 6)
        approx = tn.cross(function=lambda x: x, tensors=source, ranks_tt=15, verbose=False)
        assert tn.relative_error(source, approx) < 1e-6
def test_tensors():
    """Cross approximation of the identity should reproduce the input, including when
    replayed through `cross_forward` with a recorded info structure."""
    for _ in range(100):
        source = random_format([10] * 6)
        approx = tn.cross(function=lambda x: x, tensors=source, ranks_tt=15, verbose=False)
        assert tn.relative_error(source, approx) < 1e-6
    # One-off check: record the cross-approximation plan, then replay it
    t = tn.rand([10] * 6, ranks_tt=10)
    _, info = tn.cross(function=lambda x: x, tensors=[t], ranks_tt=15, verbose=False,
                       return_info=True)
    replay = tn.cross_forward(info, function=lambda x: x, tensors=t)
    assert tn.relative_error(t, replay) < 1e-6
def test_round_tt_eig():
    """TT rounding ('eig' algorithm): rounding gt+gt should yield exactly twice gt."""
    for _ in range(100):
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        gt.round_tt(1e-8, algorithm='eig')
        doubled = gt + gt
        doubled.round_tt(1e-8, algorithm='eig')
        assert tn.relative_error(gt, doubled / 2) <= 1e-7
def test_domain():
    """Cross-approximate 1/sum(x) over a gridded domain and compare to dense evaluation."""
    def function(Xs):
        return 1. / torch.sum(Xs, dim=1)

    domain = [torch.linspace(1, 10, 10) for _ in range(3)]
    t = tn.cross(function=function, domain=domain, ranks_tt=3, function_arg='matrix')
    grids = torch.meshgrid(domain)
    gt = 1. / sum(grids)
    assert tn.relative_error(gt, t) < 5e-2
def round(self, eps=1e-14, **kwargs):
    """
    General recompression. Attempts to reduce TT ranks first; then does Tucker
    rounding with the remaining error budget.

    :param eps: this relative error will not be exceeded
    :param kwargs: passed to `round_tt()` and `round_tucker()`
    """
    original = self.clone()
    self.round_tt(eps, **kwargs)
    spent = tn.relative_error(original, self)
    if spent < eps:
        # Relative errors compose multiplicatively: (1 + eps) = (1 + spent) * (1 + remaining),
        # so the leftover Tucker budget is (1 + eps) / (1 + spent) - 1
        self.round_tucker((1 + eps) / (1 + spent) - 1, **kwargs)
def test_als_completion():
    """ALS completion should reproduce the known entries of a constant tensor."""
    n = 8
    train_x = torch.arange(n)[:, None].repeat(1, 2)  # diagonal coordinates (i, i)
    train_y = torch.ones(n)
    completed = tn.als_completion(train_x, train_y, ranks_tt=3)
    assert tn.relative_error(train_y, completed[train_x]) < 1e-5
def check():
    """Compare tn.dot against torch's dot product of the flattened dense tensors.

    Uses `t1` and `t2` from the enclosing scope.
    """
    dense1 = t1.torch()
    dense2 = t2.torch()
    expected = torch.dot(dense1.flatten(), dense2.flatten())
    assert tn.relative_error(tn.dot(t1, t2), expected) <= 1e-7
def check():
    """Statistics (mean, variance, norm) of the compressed tensor must match dense torch.

    Uses `t` from the enclosing scope.
    """
    dense = t.torch()
    assert tn.relative_error(tn.mean(t), torch.mean(dense)) <= 1e-3
    assert tn.relative_error(tn.var(t), torch.var(dense)) <= 1e-3
    assert tn.relative_error(tn.norm(t), torch.norm(dense)) <= 1e-3