def test_ops():
    """Binary ops on random tensors of mixed TT/Tucker/CP formats must match dense results."""
    # Random-order TT-Tucker tensor paired with a default-format tensor of the same shape
    for _ in range(100):
        n_dims = np.random.randint(1, 6)
        first = tn.rand(np.random.randint(1, 8, n_dims), ranks_tt=3, ranks_tucker=2)
        second = tn.rand(first.shape)
        check(first, second)

    # Hand-picked mixed formats (per-mode TT / CP / Tucker ranks)
    shape = [8] * 4
    first = tn.rand(shape, ranks_tt=[3, None, None], ranks_cp=[None, None, 2, 2], ranks_tucker=5)
    second = tn.rand(shape, ranks_tt=[None, 2, None], ranks_cp=[4, None, None, 3])
    check(first, second)
    check(first, first * 2)

    # Fully random formats
    for _ in range(100):
        check(random_format(shape), random_format(shape))
def test_dot():
    """`tn.dot` must agree with `torch.dot` on the flattened dense tensors."""

    def check_dot(t1, t2):
        # Compare against the dense inner product of both tensors
        gt = torch.dot(t1.torch().flatten(), t2.torch().flatten())
        assert tn.relative_error(tn.dot(t1, t2), gt) <= 1e-7

    # TT tensors, with and without Tucker factors on either side
    for tucker_first, tucker_second in ((None, None), (4, None), (None, 4), (3, 4)):
        n_dims = np.random.randint(1, 6)
        t1 = tn.rand(np.random.randint(1, 8, n_dims), ranks_tt=2, ranks_tucker=tucker_first)
        t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=tucker_second)
        check_dot(t1, t2)

    # Mixed TT / CP / Tucker formats on a larger grid
    t1 = tn.rand([32] * 4, ranks_tt=[3, None, None], ranks_cp=[None, None, 10, 10], ranks_tucker=5)
    t2 = tn.rand([32] * 4, ranks_tt=[None, 2, None], ranks_cp=[4, None, None, 5])
    check_dot(t1, t2)

    shape = [8] * 4
    for _ in range(100):
        check_dot(random_format(shape), random_format(shape))
def test_round_tucker():
    """Tucker rounding at eps=1e-8 must leave a rank-3 Tucker tensor unchanged.

    NOTE(review): another `def test_round_tucker` appears later in this
    listing; if both live in the same module, this one is shadowed and never
    collected by pytest — consider renaming one of them.
    """
    # Same check for a plain tensor and for a batched one
    for extra_kwargs in ({}, {'batch': True}):
        rounded = tn.rand((10, 5, 6), ranks_tucker=3, **extra_kwargs)
        reference = rounded.clone()
        rounded.round_tucker(eps=1e-8)
        assert torch.norm(reference.torch() - rounded.torch()) < 1e-8
def test_cat():
    """`tn.cat` along a random mode must match `np.concatenate` of the dense tensors."""
    for _ in range(100):
        n_dims = np.random.randint(1, 4)
        shape1 = np.random.randint(1, 10, n_dims)
        mode = np.random.randint(n_dims)
        # Second operand differs from the first only along `mode`
        shape2 = shape1.copy()
        shape2[mode] = np.random.randint(1, 10)
        t1 = tn.rand(shape1, ranks_tt=2, ranks_tucker=2)
        t2 = tn.rand(shape2, ranks_tt=2)
        expected = np.concatenate([t1.numpy(), t2.numpy()], mode)
        got = tn.cat([t1, t2], dim=mode).numpy()
        assert np.linalg.norm(expected - got) <= 1e-7
def test_round_tucker():
    """Tucker rounding with a random tolerance must stay within that tolerance."""
    for _ in range(100):
        eps = np.random.rand() ** 2
        reference = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        approx = reference.clone()
        approx.round_tucker(eps=eps)
        assert tn.relative_error(reference, approx) <= eps
def test_batch():
    """Indexing batched tensors must match numpy; invalid batch indexing must raise."""

    def check_one_tensor(t):
        # Compare `t` against its dense copy on a fixed set of index expressions
        dense = t.numpy()
        index_cases = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], slice(None), None, 0),
            (0, [0]),
            ([0], None, None, None, 0, 1),
            (slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            (slice(None), slice(None), slice(None), 0),
            (slice(None), slice(None), [0, 1], 0),
            (0, np.array([0]), None, 0),
            (slice(None), slice(None), slice(None), slice(None), None),
        ]
        for idx in index_cases:
            check(dense, t, idx)

    for ranks in ({'ranks_tt': 3}, {'ranks_tucker': 3}, {'ranks_cp': 3}):
        check_one_tensor(tn.rand([6, 7, 8, 9], batch=True, **ranks))

    # Both keys are expected to be rejected with ValueError on a batched tensor:
    # `[None, ...]` and `[[0], [0]]`
    for bad_key in ((None, Ellipsis), ([0], [0])):
        with raises(ValueError) as exc_info:
            tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)[bad_key]
        assert exc_info.type is ValueError
def test_round_tt_eig():
    """Eigen-based TT rounding of t + t must give back 2 * t."""
    for _ in range(100):
        n_dims = np.random.randint(8, 10)
        shape = np.random.randint(1, 8, n_dims)
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        gt.round_tt(1e-8, algorithm='eig')
        doubled = gt + gt
        doubled.round_tt(1e-8, algorithm='eig')
        assert tn.relative_error(gt, doubled / 2) <= 1e-7
def test_slicing():
    """Basic slices of a TT-Tucker tensor with singleton modes must match numpy."""
    t = tn.rand([1, 3, 1, 2, 1], ranks_tt=3, ranks_tucker=2)
    dense = t.numpy()
    slices = (
        slice(None),
        (slice(None), slice(1, None)),
        (slice(None), slice(0, 2, None), slice(0, 1)),
    )
    for idx in slices:
        check(dense, t, idx)
def test_cumsum():
    """`tn.cumsum` along one random mode must match `np.cumsum`."""
    for _ in range(100):
        n_dims = np.random.randint(1, 4)
        howmany = 1
        modes = np.random.choice(n_dims, howmany, replace=False)
        shape = np.random.randint(1, 10, n_dims)
        t = tn.rand(shape, ranks_tt=2, ranks_tucker=2)
        got = tn.cumsum(t, modes).numpy()
        # `modes` has exactly one entry, passed positionally as numpy's `axis`
        expected = np.cumsum(t.numpy(), *modes)
        assert np.linalg.norm(got - expected) <= 1e-7
def test_orthogonalization():
    """Left/right/general orthogonalization must not change the represented tensor."""
    for _ in range(100):
        gt = tn.rand(np.random.randint(1, 8, np.random.randint(2, 6)))
        t = gt.clone()
        assert tn.relative_error(gt, t) <= 1e-7
        # Each step is applied in sequence; the tensor must be unchanged after each one
        steps = (
            lambda: t.left_orthogonalize(0),
            lambda: t.right_orthogonalize(t.dim() - 1),
            lambda: t.orthogonalize(np.random.randint(t.dim())),
        )
        for step in steps:
            step()
            assert tn.relative_error(gt, t) <= 1e-7
def test_mixed():
    """Indexing mixed-format tensors (TT / Tucker / CP per mode) must match numpy."""

    def check_one_tensor(t):
        # Compare `t` against its dense copy on a fixed set of index expressions
        dense = t.numpy()
        index_cases = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], slice(None), None, 0),
            (0, [0]),
            ([0], [0]),
            ([0], None, None, None, 0, 1),
            (slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            ([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            (slice(None), slice(None), slice(None), 0),
            (slice(None), slice(None), [0, 1], 0),
            (0, np.array([0]), None, 0),
            (slice(None), slice(None), slice(None), slice(None), None),
            (None, slice(None), slice(None), slice(None), slice(None), None),
            (None, slice(None), slice(None), slice(None), slice(None)),
        ]
        for idx in index_cases:
            check(dense, t, idx)

    format_specs = [
        dict(ranks_tt=3, ranks_tucker=2),
        dict(ranks_tt=None, ranks_tucker=2, ranks_cp=3),
        dict(ranks_tt=[4, None, None], ranks_tucker=2, ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[4, None, None], ranks_tucker=[2, None, 2, None], ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[None, 4, 4], ranks_tucker=2, ranks_cp=[3, None, None, None]),
    ]
    for spec in format_specs:
        check_one_tensor(tn.rand([6, 7, 8, 9], **spec))

    for _ in range(100):
        check_one_tensor(random_format([6, 7, 8, 9]))

    # CP tensor whose last (2-D) core gets its axes swapped and a trailing
    # singleton axis appended before indexing
    t = tn.rand([6, 7, 8, 9], ranks_cp=[3, 3, 3, 3])
    t.cores[-1] = t.cores[-1].permute(1, 0)[:, :, None]
    check_one_tensor(t)

    # Batched TT tensor indexed along the batch dimension only
    t = tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)
    check(t.numpy(), t, 0)
    check(t.numpy(), t, [0, 1])
def test_divergence():
    """`tn.divergence` must equal the sum of zero-padded forward differences."""
    t = tn.rand([10] * 3 + [3], ranks_tt=3)
    d = tn.divergence([t[..., k] for k in range(3)])
    dense = t.numpy()

    def forward_diff(x, mode):
        # np.diff along `mode`, padded with one trailing zero slice to keep x's shape
        pad_shape = list(x.shape)
        pad_shape[mode] = 1
        return np.concatenate([np.diff(x, axis=mode), np.zeros(pad_shape)], axis=mode)

    gt = sum(forward_diff(dense[..., k], k) for k in range(3))
    assert np.linalg.norm(d.numpy() - gt) / np.linalg.norm(gt) <= 1e-7
def test_tensors():
    """Cross approximation of the identity map must reproduce its input tensor."""

    def identity(x):
        return x

    for _ in range(100):
        t = random_format([10] * 6)
        t2 = tn.cross(function=identity, tensors=t, ranks_tt=15, verbose=False)
        assert tn.relative_error(t, t2) < 1e-6

    # Re-run a recorded cross approximation through `cross_forward`
    t = tn.rand([10] * 6, ranks_tt=10)
    _, info = tn.cross(function=identity, tensors=[t], ranks_tt=15, verbose=False, return_info=True)
    t2 = tn.cross_forward(info, function=identity, tensors=t)
    assert tn.relative_error(t, t2) < 1e-6
def test_set_item():
    """Item assignment on TT tensors (plain and batched) must mirror torch semantics."""
    full = slice(None)

    def set_both(tn_tensor, dense, key, value):
        # Apply the same item assignment to the tntorch tensor and the torch tensor
        tn_tensor[key] = value
        dense[key] = value

    def set_random(tn_tensor, dense, key):
        # Assign one random block, shaped like dense[key], to both tensors
        set_both(tn_tensor, dense, key, torch.rand(dense[key].shape))

    def check_constant_writes(a, b):
        # Scalar, row and whole-tensor constant writes
        set_both(a, b, (5, 2, 3), 6)
        assert a[5, 2, 3] == b[5, 2, 3] and b[5, 2, 3] == 6
        set_both(a, b, (5, 2, full), 7)
        assert torch.allclose(a[5, 2, :].torch(), b[5, 2, :]) and b[5, 2, 0] == 7
        set_both(a, b, (Ellipsis, full), 8)
        assert torch.allclose(a[..., :].torch(), b[..., :]) and b[5, 2, 0] == 8

    # --- non-batch ---
    a = tn.rand((10, 5, 6), ranks_tt=3)
    b = a.torch()
    check_constant_writes(a, b)

    # `a` is redrawn below without refreshing `b`; every assertion compares
    # only the region written immediately before it.
    a = tn.rand((10, 5, 6), ranks_tt=3)
    set_random(a, b, (full, 2, 0))
    assert torch.allclose(a[:, 2, 0].torch(), b[:, 2, 0])
    set_random(a, b, (full, full, 0))
    assert torch.allclose(a[:, :, 0].torch(), b[:, :, 0])
    set_random(a, b, (Ellipsis, slice(3, 5)))
    assert torch.allclose(a[..., 3:5].torch(), b[..., 3:5])
    a = tn.rand((10, 5, 6), ranks_tt=3)
    set_random(a, b, (2, full, slice(3, 5)))
    assert torch.allclose(a[2, :, 3:5].torch(), b[2, :, 3:5])

    # --- batch ---
    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = a.torch()
    set_both(a, b, 5, 6)
    assert torch.allclose(a[5].torch(), b[5]) and a[5, 0, 0] == 6

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    b = a.torch()
    check_constant_writes(a, b)

    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    set_random(a, b, (full, 2, 0))
    # NOTE(review): compared without `.torch()`, unlike the other checks —
    # presumably this index already yields a torch tensor for batch inputs; confirm.
    assert torch.allclose(a[:, 2, 0], b[:, 2, 0])
    set_random(a, b, (full, full, 0))
    assert torch.allclose(a[:, :, 0].torch(), b[:, :, 0])
    set_random(a, b, (Ellipsis, slice(3, 5)))
    assert torch.allclose(a[..., 3:5].torch(), b[..., 3:5])
    a = tn.rand((10, 5, 6), ranks_tt=3, batch=True)
    set_random(a, b, (2, full, slice(3, 5)))
    assert torch.allclose(a[2, :, 3:5].torch(), b[2, :, 3:5])
# NOTE(review): this span starts inside a class whose header is outside the
# visible text; `closed_index` and `total_index` are methods of that class.
    def closed_index(self, variables):
        """Return the closed index of `variables`, looked up in `self.cst` via `self._get_index`.

        NOTE(review): presumably the closed Sobol index (the demo below prints a
        'Closed index example') — confirm against `_get_index`.
        """
        return self._get_index(variables, self.cst)

    def total_index(self, variables):
        """Return the total index of `variables`, looked up in `self.tst` via `self._get_index`.

        NOTE(review): presumably the total Sobol index — confirm against `_get_index`.
        """
        return self._get_index(variables, self.tst)

    # def directional_covariance(self, variables):
    #     return self._get_index(variables, self.dircov)


if __name__ == '__main__':
    # For example and testing purposes
    # Each example prints an AllIndices quantity next to a `tn.sobol` query
    # with a corresponding mask (presumably the two values should agree).
    torch.manual_seed(0)
    I = 32
    N = 4
    R = 5
    t = tn.rand([I] * N, ranks_tt=R)
    ind = AllIndices(t)
    s = tn.symbols(N)
    print('Variance component example:')
    print(ind.variance_component([0, 1]))
    print(tn.sobol(t, mask=tn.only(s[0] & s[1])).item())
    print()
    print('Superset index example:')
    print(ind.superset_index([0, 1]))
    print(tn.sobol(t, mask=s[0] & s[1]).item())
    print()
    print('Closed index example:')
def test_indexing():
    """`None` / `...` / integer indexing must match torch for every tensor format."""
    # NOTE(review): the (None, ..., -1) key is checked twice, as in the original
    # test; the duplicate may have been meant to be a different index.
    plain_keys = [
        None,
        (None, Ellipsis, None),
        (0, Ellipsis, 1),
        (None, Ellipsis, 1),
        (None, Ellipsis, -1),
        (None, Ellipsis, -1),
    ]
    plain_formats = [
        dict(ranks_tt=3),
        dict(ranks_cp=3),
        dict(ranks_tucker=3),
        dict(ranks_tt=3, ranks_tucker=3),
        dict(ranks_cp=3, ranks_tucker=3),
    ]
    for fmt in plain_formats:
        a = tn.rand((10, 5, 6), **fmt)
        b = a.torch()
        for key in plain_keys:
            assert torch.allclose(a[key].torch(), b[key])

    # Batched tensors: a leading `None` must be rejected
    batch_keys = [
        (Ellipsis, None),
        (0, Ellipsis, 1),
        (Ellipsis, 1),
        (Ellipsis, -1),
        (Ellipsis, -1),
    ]
    batch_formats = [
        dict(ranks_tt=3),
        dict(ranks_tucker=3),
        dict(ranks_cp=3),
        dict(ranks_cp=3, ranks_tucker=3),
        dict(ranks_tt=3, ranks_tucker=3),
    ]
    for fmt in batch_formats:
        a = tn.rand((10, 5, 6), batch=True, **fmt)
        b = a.torch()
        with pytest.raises(ValueError) as exc_info:
            a[None].torch()
        assert exc_info.value.args[0] == 'Cannot change batch dimension'
        for key in batch_keys:
            assert torch.allclose(a[key].torch(), b[key])
def test_mul():
    """Element-wise product (and scalar sum) of tensors must match dense torch results."""
    shape = (10, 5, 6)
    operand_formats = [
        (dict(ranks_tt=3), dict(ranks_tt=3)),
        (dict(ranks_cp=3), dict(ranks_cp=3)),
        (dict(ranks_tucker=3), dict(ranks_tucker=3)),
        (dict(ranks_tucker=3, ranks_cp=3), dict(ranks_tucker=3, ranks_cp=3)),
        (dict(ranks_tucker=3, ranks_tt=3), dict(ranks_tucker=3, ranks_tt=3)),
        (dict(ranks_tt=3), dict(ranks_tucker=3)),
        (dict(ranks_tt=3), dict(ranks_cp=3)),
        (dict(ranks_tucker=3), dict(ranks_cp=3)),
    ]
    for fmt_a, fmt_b in operand_formats:
        a = tn.rand(shape, **fmt_a)
        b = tn.rand(shape, **fmt_b)
        assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    # Global TT and CP ranks on the same tensor are rejected
    with pytest.raises(ValueError) as exc_info:
        tn.rand(shape, ranks_cp=3, ranks_tt=3)
    assert exc_info.value.args[0] == 'The ranks_tt and ranks_cp provided are incompatible'

    for fmt in (dict(ranks_tt=3), dict(ranks_cp=3), dict(ranks_tucker=3)):
        a = tn.rand(shape, batch=True, **fmt)
        b = tn.rand(shape, batch=True, **fmt)
        assert torch.allclose((a * b).torch(), a.torch() * b.torch())

    # Adding a Python scalar to a batched tensor
    a = tn.rand(shape, ranks_tt=3, batch=True)
    scalar = 5
    assert torch.allclose((a + scalar).torch(), a.torch() + scalar)
def als_completion(X, y, ranks_tt, shape=None, ws=None, x0=None, niter=10, verbose=True):
    """
    Complete an N-dimensional TT from P samples using alternating least squares (ALS).

    We assume only low-rank structure, and no smoothness/spatial locality. This
    assumption requires that there is at least one sample for each tensor
    hyperslice. This is usually the case for categorical variables, since it is
    meaningless to have a class or label for which no instances exist.

    Note that this method may not converge (or be extremely slow to do so) if
    the number of available samples is below or near a certain proportion of
    the overall tensor. Such proportion, unfortunately, depends on the data set
    and its true rank structure ("Riemannian optimization for high-dimensional
    tensor completion", M. Steinlechner, 2015)

    :param X: a P X N matrix of integers (tensor indices)
    :param y: a vector with P elements
    :param ranks_tt: an integer (or list). Ignored if x0 is given
    :param shape: list of N integers. If None, the smallest shape that
        accommodates `X` will be chosen
    :param ws: a vector with P elements, with the weight of each sample (if
        None, 1 is assumed)
    :param x0: initial solution (a TT tensor). If None, a random tensor will be
        used. If given, it is modified in place and returned
    :param niter: number of ALS sweeps. Default is 10
    :param verbose: if True, print after every sweep the relative (weighted)
        residual of the last core update and the elapsed time
    :return: a `tntorch.Tensor`
    :raises ValueError: if some tensor slice contains no sample
    """

    assert not X.dtype.is_floating_point
    assert X.dim() == 2
    assert y.dim() == 1
    if ws is None:
        ws = torch.ones(len(y))
    X = X.long()
    if shape is None:
        shape = [val.item() for val in torch.max(X, dim=0)[0] + 1]
    y = y.to(torch.get_default_dtype())
    P = X.shape[0]
    N = X.shape[1]
    if x0 is None:
        x0 = tn.rand(shape, ranks_tt=ranks_tt)

    # All tensor slices must contain at least one sample point
    for dim in range(N):
        if torch.unique(X[:, dim]).numel() != x0.shape[dim]:
            raise ValueError('One groundtruth sample is needed for every tensor slice')

    if verbose:
        print('Completing a {}D tensor of size {} using {} samples...'.format(N, list(shape), P))

    normy = torch.norm(y)
    x0.orthogonalize(0)
    cores = x0.cores

    # Memoized product chains for all groundtruth points.
    # lefts[n] has shape (1, P, left rank of core n): product of cores 0..n-1
    # at every sample. It is filled in on the go during left-to-right sweeps.
    lefts = [torch.ones(1, P, x0.cores[n].shape[0]) for n in range(N)]
    # rights[n] has shape (right rank of core n, P, 1): product of cores
    # n+1..N-1 at every sample. It must be initialized now.
    rights = [None] * N
    rights[-1] = torch.ones(1, P, 1)
    for dim in range(N - 2, -1, -1):
        rights[dim] = torch.einsum(
            'ijk,kjl->ijl', (cores[dim + 1][:, X[:, dim + 1], :], rights[dim + 1]))

    def optimize_core(cores, mu, direction=None):
        # Solve one weighted least-squares problem per slice of core `mu` and
        # return the summed squared (weighted) residual. `direction` selects
        # which chain is refreshed afterwards: 'right' advances `lefts`,
        # 'left' advances `rights`, None (single-core tensor) updates neither.
        sse = 0
        for index in range(cores[mu].shape[1]):
            idx = torch.where(X[:, mu] == index)[0]
            leftside = lefts[mu][0, idx, :]  # (len(idx), r_left)
            rightside = rights[mu][:, idx, 0]  # (r_right, len(idx))
            # Column i * r_right + j of the design matrix multiplies core entry
            # (i, index, j). This left-major order matches the reshape into
            # cores[mu][:, index, :] (shape (r_left, r_right)) below; the
            # previous right-major order stored the transposed solution
            # whenever both ranks were > 1.
            design = leftside[:, :, None] * rightside.t()[:, None, :]
            A = torch.reshape(design, [len(idx), -1]) * ws[idx, None]
            b = y[idx] * ws[idx]
            # torch.lstsq was removed from PyTorch; torch.linalg.lstsq takes
            # (A, B) and returns exactly A.shape[1] solution rows
            sol = torch.linalg.lstsq(A, b[:, None]).solution
            residuals = torch.norm(A.matmul(sol)[:, 0] - b) ** 2
            cores[mu][:, index, :] = torch.reshape(sol, cores[mu][:, index, :].shape)
            sse += residuals

        # Update product chains for next core
        if direction == 'right':
            x0.left_orthogonalize(mu)
            lefts[mu + 1] = torch.einsum(
                'ijk,kjl->ijl', (lefts[mu], cores[mu][:, X[:, mu], :]))
        elif direction == 'left':
            x0.right_orthogonalize(mu)
            rights[mu - 1] = torch.einsum(
                'ijk,kjl->ijl', (cores[mu][:, X[:, mu], :], rights[mu]))
        return sse

    start = time.time()
    for swp in range(niter):
        if N == 1:
            # A single core has no neighbours: there is nothing to sweep over
            # (both sweep loops below would be empty and leave `sse` unbound)
            sse = optimize_core(cores, 0)
        else:
            # Sweep: left-to-right
            for mu in range(N - 1):
                optimize_core(cores, mu, direction='right')

            # Sweep: right-to-left
            for mu in range(N - 1, 0, -1):
                sse = optimize_core(cores, mu, direction='left')
        eps = torch.sqrt(sse) / normy

        if verbose:
            print('iter: {: <{}}'.format(swp, len('{}'.format(niter)) + 1), end='')
            print('| eps: {:.3e}'.format(eps), end='')
            print(' | time: {:8.4f}'.format(time.time() - start))

    return x0