Пример #1
0
def test_ops():
    """Elementwise ops on a compressed tensor must agree with the same ops on its dense form."""
    x, y, z, w = tn.meshgrid([32] * 4)
    # The +1 offset presumably keeps t away from zero so that 1/t is finite (assumes the grids start at 0).
    t = x + y + z + w + 1
    dense = t.torch()

    def reciprocal(a):
        return 1 / a

    # Each pair is (dense PyTorch op, tntorch op). Ops are evaluated lazily, one per assertion.
    for dense_op, tn_op in ((reciprocal, reciprocal), (torch.cos, tn.cos), (torch.exp, tn.exp)):
        assert tn.relative_error(dense_op(dense), tn_op(t)) < 1e-4
Пример #2
0
def check(t1, t2):
    """Assert that +, -, * and negation on two tntorch tensors match the dense results.

    :param t1: first tntorch tensor
    :param t2: second tntorch tensor, combinable with ``t1``
    """
    dense1 = t1.torch()
    dense2 = t2.torch()
    binary_ops = (
        lambda a, b: a + b,
        lambda a, b: a - b,
        lambda a, b: a * b,
        lambda a, b: -a + b,
    )
    for op in binary_ops:
        # The compressed result is computed first, then the dense reference, as in the original.
        assert tn.relative_error(op(t1, t2), op(dense1, dense2)) <= 1e-7
Пример #3
0
def test_orthogonalization():
    """Orthogonalizing a tensor must change only its representation, never its values."""
    tol = 1e-7
    for _ in range(100):
        # Shape: between 2 and 5 dimensions, each of size 1..7 (randint's upper bound is exclusive).
        reference = tn.rand(np.random.randint(1, 8, np.random.randint(2, 6)))
        work = reference.clone()
        assert tn.relative_error(reference, work) <= tol

        work.left_orthogonalize(0)
        assert tn.relative_error(reference, work) <= tol

        last_core = work.dim() - 1
        work.right_orthogonalize(last_core)
        assert tn.relative_error(reference, work) <= tol

        pivot = np.random.randint(work.dim())
        work.orthogonalize(pivot)
        assert tn.relative_error(reference, work) <= tol
Пример #4
0
 def store_metrics(self, full):
     """
     Compare a dense tensor with the compressed weight ``self.tn_weight`` and cache the results.

     Sets ``self.compression`` to ``(dense element count, weight.numel(), their ratio)``,
     and sets ``self.relative_error``, ``self.rmse`` and ``self.r_squared`` from the
     corresponding ``tn`` metrics.

     :param full: dense tensor. It is viewed with the weight's decompressed shape, so it
         must have the same number of elements.
     """
     weight = self.tn_weight
     # The weight is decompressed only to learn its shape. `full` is then reshaped to match.
     target_size = weight.torch().size()
     full = full.view(target_size)
     dense_count = full.numel()
     # NOTE(review): depending on the tntorch version, Tensor.numel() may count uncompressed
     # elements rather than stored coefficients. Confirm this ratio is a real compression factor.
     weight_count = weight.numel()
     self.compression = (dense_count, weight_count, dense_count / weight_count)
     self.relative_error = tn.relative_error(full, weight)
     self.rmse = tn.rmse(full, weight)
     self.r_squared = tn.r_squared(full, weight)
Пример #5
0
def test_round_tucker():
    """Tucker rounding with a random tolerance must stay within that tolerance."""
    for _ in range(100):
        # Squaring a uniform [0, 1) draw skews the tolerance toward small values.
        eps = np.random.rand() ** 2
        reference = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        rounded = reference.clone()
        rounded.round_tucker(eps=eps)
        assert tn.relative_error(reference, rounded) <= eps
Пример #6
0
def test_tensors():
    """Cross approximation of the identity function must reproduce its input tensor."""
    def identity(x):
        return x

    for _ in range(100):
        source = random_format([10] * 6)
        approx = tn.cross(function=identity, tensors=source, ranks_tt=15, verbose=False)
        assert tn.relative_error(source, approx) < 1e-6
Пример #7
0
def test_tensors():
    """Cross approximation of the identity must reproduce its input, also through cross_forward."""
    def identity(x):
        return x

    for _ in range(100):
        source = random_format([10] * 6)
        approx = tn.cross(function=identity, tensors=source, ranks_tt=15, verbose=False)
        assert tn.relative_error(source, approx) < 1e-6

    # Request the `info` object from tn.cross, then pass it to tn.cross_forward with the same function.
    source = tn.rand([10] * 6, ranks_tt=10)
    _, info = tn.cross(
        function=identity,
        tensors=[source],
        ranks_tt=15,
        verbose=False,
        return_info=True,
    )
    forwarded = tn.cross_forward(info, function=identity, tensors=source)
    assert tn.relative_error(source, forwarded) < 1e-6
Пример #8
0
def test_round_tt_eig():
    """TT rounding with the 'eig' algorithm must preserve values (checked on gt + gt)."""
    tol = 1e-8
    for _ in range(100):
        # Draws happen in the same order as before: ndim, then shape, then TT rank.
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        rank = np.random.randint(1, 10)
        base = tn.rand(shape, ranks_tt=rank)
        base.round_tt(tol, algorithm='eig')

        doubled = base + base
        doubled.round_tt(tol, algorithm='eig')
        assert tn.relative_error(base, doubled / 2) <= 1e-7
Пример #9
0
def test_domain():
    """Cross approximation on an explicit 3-D grid should approximate 1 / (x + y + z)."""
    def reciprocal_sum(Xs):
        # With function_arg='matrix', Xs presumably holds one sample point per row. TODO confirm.
        return 1. / Xs.sum(dim=1)

    domain = [torch.linspace(1, 10, 10) for _ in range(3)]
    approx = tn.cross(
        function=reciprocal_sum,
        domain=domain,
        ranks_tt=3,
        function_arg='matrix',
    )
    exact = 1. / sum(torch.meshgrid(domain))

    assert tn.relative_error(exact, approx) < 5e-2
Пример #10
0
    def round(self, eps=1e-14, **kwargs):
        """
        General recompression: first shrink the TT ranks, then use whatever error budget is
        left for Tucker rounding.

        :param eps: relative error that the combined rounding will not exceed
        :param kwargs: forwarded unchanged to both `round_tt()` and `round_tucker()`
        """

        before = self.clone()
        self.round_tt(eps, **kwargs)
        spent = tn.relative_error(before, self)
        if spent < eps:
            # Remaining budget e2 = (1 + eps) / (1 + e1) - 1 = (eps - e1) / (1 + e1). It satisfies
            # e1 + e2 * (1 + e1) <= eps, the triangle-inequality bound, assuming each stage's
            # eps is relative to the norm of the tensor it rounds.
            remaining = (1 + eps) / (1 + spent) - 1
            self.round_tucker(remaining, **kwargs)
Пример #11
0
def test_als_completion():
    """ALS completion should reproduce all-ones observations at the points (i, i)."""
    n_points = 8
    index = torch.arange(n_points)
    # Each row is an (i, i) coordinate pair. Same values and dtype as arange[:, None].repeat(1, 2).
    train_x = torch.stack((index, index), dim=1)
    train_y = torch.ones(n_points)
    completed = tn.als_completion(train_x, train_y, ranks_tt=3)
    assert tn.relative_error(train_y, completed[train_x]) < 1e-5
Пример #12
0
 def check():
     """Compare tn.dot(t1, t2) with the dot product of the flattened dense tensors.

     ``t1`` and ``t2`` are not parameters; they are read from the enclosing scope.
     """
     flat1 = t1.torch().flatten()
     flat2 = t2.torch().flatten()
     expected = torch.dot(flat1, flat2)
     assert tn.relative_error(tn.dot(t1, t2), expected) <= 1e-7
Пример #13
0
 def check():
     """Compare the mean, variance and norm of ``t`` with the same statistics of its dense form.

     ``t`` is not a parameter; it is read from the enclosing scope.
     """
     dense = t.torch()
     statistics = ((tn.mean, torch.mean), (tn.var, torch.var), (tn.norm, torch.norm))
     for tn_stat, torch_stat in statistics:
         assert tn.relative_error(tn_stat(t), torch_stat(dense)) <= 1e-3