Example #1
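# Shared setup assumed by these excerpts (module-level in the original test
# files, not shown here):
#
#     import time
#     import numpy as np
#     import pytest
#     from pytest import raises
#     import torch
#     import tntorch as tn
#
# `check` and `random_format` are helper functions defined in those modules.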
def test_ops():

    for i in range(100):
        t1 = tn.rand(np.random.randint(1, 8, np.random.randint(1, 6)),
                     ranks_tt=3,
                     ranks_tucker=2)
        t2 = tn.rand(t1.shape)
        check(t1, t2)

    shape = [8] * 4

    t1 = tn.rand(shape,
                 ranks_tt=[3, None, None],
                 ranks_cp=[None, None, 2, 2],
                 ranks_tucker=5)
    t2 = tn.rand(shape, ranks_tt=[None, 2, None], ranks_cp=[4, None, None, 3])
    check(t1, t2)

    t2 = t1 * 2
    check(t1, t2)

    for i in range(100):
        t1 = random_format(shape)
        t2 = random_format(shape)
        check(t1, t2)
Example #2
def test_dot():

    def check():
        x1 = t1.torch()
        x2 = t2.torch()
        gt = torch.dot(x1.flatten(), x2.flatten())
        assert tn.relative_error(tn.dot(t1, t2), gt) <= 1e-7

    t1 = tn.rand(np.random.randint(1, 8, np.random.randint(1, 6)), ranks_tt=2, ranks_tucker=None)
    t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=None)
    check()

    t1 = tn.rand(np.random.randint(1, 8, np.random.randint(1, 6)), ranks_tt=2, ranks_tucker=4)
    t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=None)
    check()

    t1 = tn.rand(np.random.randint(1, 8, np.random.randint(1, 6)), ranks_tt=2, ranks_tucker=None)
    t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=4)
    check()

    t1 = tn.rand(np.random.randint(1, 8, np.random.randint(1, 6)), ranks_tt=2, ranks_tucker=3)
    t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=4)
    check()

    t1 = tn.rand([32] * 4, ranks_tt=[3, None, None], ranks_cp=[None, None, 10, 10], ranks_tucker=5)
    t2 = tn.rand([32]*4, ranks_tt=[None, 2, None], ranks_cp=[4, None, None, 5])
    check()

    shape = [8]*4
    for i in range(100):
        t1 = random_format(shape)
        t2 = random_format(shape)
        check()
Example #3
def test_round_tucker():
    a = tn.rand((10, 5, 6), ranks_tucker=3)
    b = a.clone()
    a.round_tucker(eps=1e-8)
    assert torch.norm(b.torch() - a.torch()) < 1e-8

    a = tn.rand((10, 5, 6), ranks_tucker=3, batch=True)
    b = a.clone()
    a.round_tucker(eps=1e-8)
    assert torch.norm(b.torch() - a.torch()) < 1e-8
Example #4
def test_cat():

    for i in range(100):
        N = np.random.randint(1, 4)
        shape1 = np.random.randint(1, 10, N)
        mode = np.random.randint(N)
        shape2 = shape1.copy()
        shape2[mode] = np.random.randint(1, 10)
        t1 = tn.rand(shape1, ranks_tt=2, ranks_tucker=2)
        t2 = tn.rand(shape2, ranks_tt=2)
        gt = np.concatenate([t1.numpy(), t2.numpy()], mode)
        assert np.linalg.norm(gt - tn.cat([t1, t2], dim=mode).numpy()) <= 1e-7
Example #5
def test_round_tucker():
    for i in range(100):
        eps = np.random.rand()**2
        gt = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        t = gt.clone()
        t.round_tucker(eps=eps)
        assert tn.relative_error(gt, t) <= eps
Example #6
def test_batch():
    def check_one_tensor(t):
        x = t.numpy()

        idxs = []
        idxs.append(([0, 0, 0], None, None, 3))
        idxs.append(([0, 0, 0, 0, 0], slice(None), None, 0))
        idxs.append((0, [0]))
        idxs.append(([0], None, None, None, 0, 1))
        idxs.append((slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append((slice(None), slice(None), slice(None), 0))
        idxs.append((slice(None), slice(None), [0, 1], 0))
        idxs.append((0, np.array([0]), None, 0))
        idxs.append((slice(None), slice(None), slice(None), slice(None), None))

        for idx in idxs:
            check(x, t, idx)

    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tucker=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_cp=3, batch=True))

    with raises(ValueError) as exc_info:
        tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)[None, ...]

    assert exc_info.type is ValueError

    with raises(ValueError) as exc_info2:
        tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)[[0], [0]]

    assert exc_info2.type is ValueError
Example #7
def test_round_tt_eig():

    for i in range(100):
        gt = tn.rand(np.random.randint(1, 8, np.random.randint(8, 10)),
                     ranks_tt=np.random.randint(1, 10))
        gt.round_tt(1e-8, algorithm='eig')
        t = gt + gt
        t.round_tt(1e-8, algorithm='eig')
        assert tn.relative_error(gt, t / 2) <= 1e-7
Example #8
def test_slicing():

    t = tn.rand([1, 3, 1, 2, 1], ranks_tt=3, ranks_tucker=2)
    x = t.numpy()
    idx = slice(None)
    check(x, t, idx)
    idx = (slice(None), slice(1, None))
    check(x, t, idx)
    idx = (slice(None), slice(0, 2, None), slice(0, 1))
    check(x, t, idx)
Example #9
def test_cumsum():

    for i in range(100):
        N = np.random.randint(1, 4)
        howmany = 1
        modes = np.random.choice(N, howmany, replace=False)
        shape = np.random.randint(1, 10, N)
        t = tn.rand(shape, ranks_tt=2, ranks_tucker=2)
        assert np.linalg.norm(
            tn.cumsum(t, modes).numpy() - np.cumsum(t.numpy(), *modes)) <= 1e-7
Example #10
def test_orthogonalization():

    for i in range(100):
        gt = tn.rand(np.random.randint(1, 8, np.random.randint(2, 6)))
        t = gt.clone()
        assert tn.relative_error(gt, t) <= 1e-7
        t.left_orthogonalize(0)
        assert tn.relative_error(gt, t) <= 1e-7
        t.right_orthogonalize(t.dim() - 1)
        assert tn.relative_error(gt, t) <= 1e-7
        t.orthogonalize(np.random.randint(t.dim()))
        assert tn.relative_error(gt, t) <= 1e-7
Example #11
def test_mixed():
    def check_one_tensor(t):

        x = t.numpy()

        idxs = []
        idxs.append(([0, 0, 0], None, None, 3))
        idxs.append(([0, 0, 0, 0, 0], slice(None), None, 0))
        idxs.append((0, [0]))
        idxs.append(([0], [0]))
        idxs.append(([0], None, None, None, 0, 1))
        idxs.append((slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append(([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append((slice(None), slice(None), slice(None), 0))
        idxs.append((slice(None), slice(None), [0, 1], 0))
        idxs.append((0, np.array([0]), None, 0))
        idxs.append((slice(None), slice(None), slice(None), slice(None), None))
        idxs.append(
            (None, slice(None), slice(None), slice(None), slice(None), None))
        idxs.append((None, slice(None), slice(None), slice(None), slice(None)))

        for idx in idxs:
            check(x, t, idx)

    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=3, ranks_tucker=2))
    check_one_tensor(
        tn.rand([6, 7, 8, 9], ranks_tt=None, ranks_tucker=2, ranks_cp=3))
    check_one_tensor(
        tn.rand([6, 7, 8, 9],
                ranks_tt=[4, None, None],
                ranks_tucker=2,
                ranks_cp=[None, None, 3, 3]))
    check_one_tensor(
        tn.rand([6, 7, 8, 9],
                ranks_tt=[4, None, None],
                ranks_tucker=[2, None, 2, None],
                ranks_cp=[None, None, 3, 3]))
    check_one_tensor(
        tn.rand([6, 7, 8, 9],
                ranks_tt=[None, 4, 4],
                ranks_tucker=2,
                ranks_cp=[3, None, None, None]))

    for i in range(100):
        check_one_tensor(random_format([6, 7, 8, 9]))

    t = tn.rand([6, 7, 8, 9], ranks_cp=[3, 3, 3, 3])
    t.cores[-1] = t.cores[-1].permute(1, 0)[:, :, None]
    check_one_tensor(t)

    t = tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)
    check(t.numpy(), t, 0)
    check(t.numpy(), t, [0, 1])
Example #12
def test_divergence():

    t = tn.rand([10] * 3 + [3], ranks_tt=3)
    d = tn.divergence([t[..., 0], t[..., 1], t[..., 2]])
    x = t.numpy()

    def partial(x, mode):
        return np.concatenate([np.diff(x, axis=mode),
                               np.zeros(list(x.shape[:mode]) + [1] + list(x.shape[mode + 1:]))],
                              axis=mode)

    gt = partial(x[..., 0], 0)
    gt += partial(x[..., 1], 1)
    gt += partial(x[..., 2], 2)
    assert np.linalg.norm(d.numpy() - gt) / np.linalg.norm(gt) <= 1e-7
Example #13
def test_tensors():

    for i in range(100):
        t = random_format([10] * 6)
        t2 = tn.cross(function=lambda x: x,
                      tensors=t,
                      ranks_tt=15,
                      verbose=False)
        assert tn.relative_error(t, t2) < 1e-6

    t = tn.rand([10] * 6, ranks_tt=10)
    _, info = tn.cross(function=lambda x: x,
                       tensors=[t],
                       ranks_tt=15,
                       verbose=False,
                       return_info=True)
    t2 = tn.cross_forward(info, function=lambda x: x, tensors=t)
    assert tn.relative_error(t, t2) < 1e-6
Example #14
def test_set_item():
    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = a.torch()
    a[5, 2, 3] = 6
    b[5, 2, 3] = 6

    assert a[5, 2, 3] == b[5, 2, 3] and b[5, 2, 3] == 6

    a[5, 2, :] = 7
    b[5, 2, :] = 7

    assert torch.allclose(a[5, 2, :].torch(), b[5, 2, :]) and b[5, 2, 0] == 7

    a[..., :] = 8
    b[..., :] = 8

    assert torch.allclose(a[..., :].torch(), b[..., :]) and b[5, 2, 0] == 8

    a = tn.rand((10, 5, 6), ranks_tt=3)
    c = torch.zeros_like(b[:, 2, 0])
    i = torch.rand(c.shape)
    a[:, 2, 0] = i
    b[:, 2, 0] = i

    assert torch.allclose(a[:, 2, 0].torch(), b[:, 2, 0])

    c = torch.zeros_like(b[:, :, 0])
    add = torch.rand(c.shape)
    a[:, :, 0] = add
    b[:, :, 0] = add

    assert torch.allclose(a[:, :, 0].torch(), b[:, :, 0])

    c = torch.zeros_like(b[..., 3:5])
    add = torch.rand(c.shape)
    a[..., 3:5] = add
    b[..., 3:5] = add

    assert torch.allclose(a[..., 3:5].torch(), b[..., 3:5])

    a = tn.rand((10, 5, 6), ranks_tt=3)
    c = torch.zeros_like(b[2, :, 3:5])
    i = torch.rand(c.shape)
    a[2, :, 3:5] = i
    b[2, :, 3:5] = i

    assert torch.allclose(a[2, :, 3:5].torch(), b[2, :, 3:5])

    # batch
    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = a.torch()
    a[5] = 6
    b[5] = 6

    assert torch.allclose(a[5].torch(), b[5]) and a[5, 0, 0] == 6

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = a.torch()
    a[5, 2, 3] = 6
    b[5, 2, 3] = 6

    assert a[5, 2, 3] == b[5, 2, 3] and b[5, 2, 3] == 6

    a[5, 2, :] = 7
    b[5, 2, :] = 7

    assert torch.allclose(a[5, 2, :].torch(), b[5, 2, :]) and b[5, 2, 0] == 7

    a[..., :] = 8
    b[..., :] = 8

    assert torch.allclose(a[..., :].torch(), b[..., :]) and b[5, 2, 0] == 8

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    c = torch.zeros_like(b[:, 2, 0])
    i = torch.rand(c.shape)
    a[:, 2, 0] = i
    b[:, 2, 0] = i

    assert torch.allclose(a[:, 2, 0], b[:, 2, 0])

    c = torch.zeros_like(b[:, :, 0])
    add = torch.rand(c.shape)
    a[:, :, 0] = add
    b[:, :, 0] = add

    assert torch.allclose(a[:, :, 0].torch(), b[:, :, 0])

    c = torch.zeros_like(b[..., 3:5])
    add = torch.rand(c.shape)
    a[..., 3:5] = add
    b[..., 3:5] = add

    assert torch.allclose(a[..., 3:5].torch(), b[..., 3:5])

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    c = torch.zeros_like(b[2, :, 3:5])
    i = torch.rand(c.shape)
    a[2, :, 3:5] = i
    b[2, :, 3:5] = i

    assert torch.allclose(a[2, :, 3:5].torch(), b[2, :, 3:5])
Example #15
    def closed_index(self, variables):
        return self._get_index(variables, self.cst)

    def total_index(self, variables):
        return self._get_index(variables, self.tst)

    # def directional_covariance(self, variables):
    #     return self._get_index(variables, self.dircov)


if __name__ == '__main__':  # For demonstration and testing purposes
    torch.manual_seed(0)
    I = 32
    N = 4
    R = 5
    t = tn.rand([I] * N, ranks_tt=R)
    ind = AllIndices(t)

    s = tn.symbols(N)
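    # s[n] is a Boolean symbol for the n-th input variable; masks built from
    # them (e.g. tn.only(s[0] & s[1])) select groups of Sobol indices.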

    print('Variance component example:')
    print(ind.variance_component([0, 1]))
    print(tn.sobol(t, mask=tn.only(s[0] & s[1])).item())
    print()

    print('Superset index example:')
    print(ind.superset_index([0, 1]))
    print(tn.sobol(t, mask=s[0] & s[1]).item())
    print()

    print('Closed index example:')
Example #16
def test_indexing():
    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = a.torch()

    assert torch.allclose(a[None].torch(), b[None])
    assert torch.allclose(a[None, ..., None].torch(), b[None, ..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[None, ..., 1].torch(), b[None, ..., 1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])

    a = tn.rand((10, 5, 6), ranks_cp=3)
    b = a.torch()

    assert torch.allclose(a[None].torch(), b[None])
    assert torch.allclose(a[None, ..., None].torch(), b[None, ..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[None, ..., 1].torch(), b[None, ..., 1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])

    a = tn.rand((10, 5, 6), ranks_tucker=3)
    b = a.torch()

    assert torch.allclose(a[None].torch(), b[None])
    assert torch.allclose(a[None, ..., None].torch(), b[None, ..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[None, ..., 1].torch(), b[None, ..., 1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])

    a = tn.rand((10, 5, 6), ranks_tt=3, ranks_tucker=3)
    b = a.torch()

    assert torch.allclose(a[None].torch(), b[None])
    assert torch.allclose(a[None, ..., None].torch(), b[None, ..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[None, ..., 1].torch(), b[None, ..., 1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])

    a = tn.rand((10, 5, 6), ranks_cp=3, ranks_tucker=3)
    b = a.torch()

    assert torch.allclose(a[None].torch(), b[None])
    assert torch.allclose(a[None, ..., None].torch(), b[None, ..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[None, ..., 1].torch(), b[None, ..., 1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])
    assert torch.allclose(a[None, ..., -1].torch(), b[None, ..., -1])

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = a.torch()

    with pytest.raises(ValueError) as exc_info:
        a[None].torch(), b[None]
    assert exc_info.value.args[0] == 'Cannot change batch dimension'

    assert torch.allclose(a[..., None].torch(), b[..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[..., 1].torch(), b[..., 1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])

    a = tn.rand((10, 5, 6), ranks_tucker=3, batch=True)
    b = a.torch()

    with pytest.raises(ValueError) as exc_info:
        a[None].torch(), b[None]
    assert exc_info.value.args[0] == 'Cannot change batch dimension'

    assert torch.allclose(a[..., None].torch(), b[..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[..., 1].torch(), b[..., 1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])

    a = tn.rand((10, 5, 6), ranks_cp=3, batch=True)
    b = a.torch()

    with pytest.raises(ValueError) as exc_info:
        a[None].torch(), b[None]
    assert exc_info.value.args[0] == 'Cannot change batch dimension'

    assert torch.allclose(a[..., None].torch(), b[..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[..., 1].torch(), b[..., 1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])

    a = tn.rand((10, 5, 6), ranks_cp=3, ranks_tucker=3, batch=True)
    b = a.torch()

    with pytest.raises(ValueError) as exc_info:
        a[None].torch(), b[None]
    assert exc_info.value.args[0] == 'Cannot change batch dimension'

    assert torch.allclose(a[..., None].torch(), b[..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[..., 1].torch(), b[..., 1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])

    a = tn.rand((10, 5, 6), ranks_tt=3, ranks_tucker=3, batch=True)
    b = a.torch()

    with pytest.raises(ValueError) as exc_info:
        a[None].torch(), b[None]
    assert exc_info.value.args[0] == 'Cannot change batch dimension'

    assert torch.allclose(a[..., None].torch(), b[..., None])
    assert torch.allclose(a[0, ..., 1].torch(), b[0, ..., 1])
    assert torch.allclose(a[..., 1].torch(), b[..., 1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
    assert torch.allclose(a[..., -1].torch(), b[..., -1])
Example #17
def test_mul():
    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = tn.rand((10, 5, 6), ranks_tt=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_cp=3)
    b = tn.rand((10, 5, 6), ranks_cp=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tucker=3)
    b = tn.rand((10, 5, 6), ranks_tucker=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tucker=3, ranks_cp=3)
    b = tn.rand((10, 5, 6), ranks_tucker=3, ranks_cp=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tucker=3, ranks_tt=3)
    b = tn.rand((10, 5, 6), ranks_tucker=3, ranks_tt=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = tn.rand((10, 5, 6), ranks_tucker=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = tn.rand((10, 5, 6), ranks_cp=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tucker=3)
    b = tn.rand((10, 5, 6), ranks_cp=3)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    with pytest.raises(ValueError) as exc_info:
        a = tn.rand((10, 5, 6), ranks_cp=3, ranks_tt=3)
    assert exc_info.value.args[
        0] == 'The ranks_tt and ranks_cp provided are incompatible'

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = tn.rand((10, 5, 6), ranks_tt=3, batch=True)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_cp=3, batch=True)
    b = tn.rand((10, 5, 6), ranks_cp=3, batch=True)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tucker=3, batch=True)
    b = tn.rand((10, 5, 6), ranks_tucker=3, batch=True)

    assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = 5
    assert torch.allclose((a + b).torch(), a.torch() + b)
Example #18
def als_completion(X,
                   y,
                   ranks_tt,
                   shape=None,
                   ws=None,
                   x0=None,
                   niter=10,
                   verbose=True):
    """
    Complete an N-dimensional TT from P samples using alternating least squares (ALS).
    We assume only low-rank structure, and no smoothness/spatial locality. Such assumption requires that there is at
    least one sample for each tensor hyperslice. This is usually the case for categorical variables, since it is
    meaningless to have a class or label for which no instances exist.

    Note that this method may not converge (or be extremely slow to do so) if the number of available samples is below
    or near a certain proportion of the overall tensor. Such proportion, unfortunately, depends on the data set and its
    true rank structure
    ("Riemannian optimization for high-dimensional tensor completion", M. Steinlechner, 2015)

    :param X: a P X N matrix of integers (tensor indices)
    :param y: a vector with P elements
    :param ranks_tt: an integer (or list). Ignored if x0 is given
    :param shape: list of N integers. If None, the smallest shape that accommodates `X` will be chosen
    :param ws: a vector with P elements, with the weight of each sample (if None, 1 is assumed)
    :param x0: initial solution (a TT tensor). If None, a random tensor will be used
    :param niter: number of ALS sweeps. Default is 10
    :param verbose:
    :return: a `tntorch.Tensor'
    """

    assert not X.dtype.is_floating_point
    assert X.dim() == 2
    assert y.dim() == 1
    if ws is None:
        ws = torch.ones(len(y))
    X = X.long()
    if shape is None:
        shape = [val.item() for val in torch.max(X, dim=0)[0] + 1]
    y = y.to(torch.get_default_dtype())
    P = X.shape[0]
    N = X.shape[1]
    if x0 is None:
        x0 = tn.rand(shape, ranks_tt=ranks_tt)
    # All tensor slices must contain at least one sample point
    for dim in range(N):
        if torch.unique(X[:, dim]).numel() != x0.shape[dim]:
            raise ValueError(
                'One groundtruth sample is needed for every tensor slice')

    if verbose:
        print('Completing a {}D tensor of size {} using {} samples...'.format(
            N, list(shape), P))

    normy = torch.norm(y)
    x0.orthogonalize(0)
    cores = x0.cores

    # Memoized product chains for all groundtruth points
    # lefts will be initialized on the go
    lefts = [torch.ones(1, P, x0.cores[n].shape[0]) for n in range(N)]
    # rights, however, needs to be initialized now
    rights = [None] * N
    rights[-1] = torch.ones(1, P, 1)
    for dim in range(N - 2, -1, -1):
        rights[dim] = torch.einsum(
            'ijk,kjl->ijl',
            (cores[dim + 1][:, X[:, dim + 1], :], rights[dim + 1]))
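    # After this loop, rights[dim][:, p, 0] holds the contraction of cores
    # dim+1 .. N-1 evaluated at sample p's indices; lefts[n] will analogously
    # accumulate the product of cores 0 .. n-1 during the left-to-right sweep.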

    def optimize_core(cores, mu, direction):
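        # For each slice `index` of core mu, solve a weighted least-squares
        # problem: each row of A is the flattened outer product of the right
        # and left interface vectors of one sample hitting that slice, and b
        # holds the corresponding groundtruth values.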
        sse = 0
        for index in range(cores[mu].shape[1]):
            idx = torch.where(X[:, mu] == index)[0]
            leftside = lefts[mu][0, idx, :]
            rightside = rights[mu][:, idx, 0]
            lhs = rightside.t()[:, :, None]
            rhs = leftside[:, None, :]
            A = torch.reshape(lhs * rhs, [len(idx), -1]) * ws[idx, None]
            b = y[idx] * ws[idx]
            sol = torch.lstsq(b, A)[0][:A.shape[1], :]
            residuals = torch.norm(A.matmul(sol)[:, 0] - b)**2
            cores[mu][:, index, :] = torch.reshape(
                sol, cores[mu][:, index, :].shape)  #.t()
            sse += residuals
        # Update product chains for next core
        if direction == 'right':
            x0.left_orthogonalize(mu)
            lefts[mu + 1] = torch.einsum(
                'ijk,kjl->ijl', (lefts[mu], cores[mu][:, X[:, mu], :]))
        else:
            x0.right_orthogonalize(mu)
            rights[mu - 1] = torch.einsum(
                'ijk,kjl->ijl', (cores[mu][:, X[:, mu], :], rights[mu]))
        return sse

    start = time.time()
    for swp in range(niter):

        # Sweep: left-to-right
        for mu in range(N - 1):
            optimize_core(cores, mu, direction="right")

        # Sweep: right-to-left
        for mu in range(N - 1, 0, -1):
            sse = optimize_core(cores, mu, direction="left")
        eps = torch.sqrt(sse) / normy

        if verbose:
            print('iter: {: <{}}'.format(swp,
                                         len('{}'.format(niter)) + 1),
                  end='')
            print('| eps: {:.3e}'.format(eps), end='')
            print(' | time: {:8.4f}'.format(time.time() - start))

    return x0
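
A minimal usage sketch for `als_completion` (hypothetical, assuming a PyTorch
version that still provides `torch.lstsq`, with `torch` and `tntorch as tn`
imported): complete a random rank-3 groundtruth TT tensor from a subset of its
entries.

def example_completion():
    torch.manual_seed(0)
    shape = [10, 10, 10]
    gt = tn.rand(shape, ranks_tt=3)  # groundtruth low-rank tensor
    P = 2000  # number of observed entries; enough to hit every slice w.h.p.
    X = torch.stack([torch.randint(0, sh, (P,)) for sh in shape], dim=1)
    y = gt.torch()[X[:, 0], X[:, 1], X[:, 2]]  # observed values at indices X
    t = als_completion(X, y, ranks_tt=3, shape=shape, niter=10, verbose=False)
    print(tn.relative_error(gt, t))  # expected to be small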