def test_ops(): x, y, z, w = tn.meshgrid([32] * 4) t = x + y + z + w + 1 assert tn.relative_error(1 / t.torch(), 1 / t) < 1e-4 assert tn.relative_error(torch.cos(t.torch()), tn.cos(t)) < 1e-4 assert tn.relative_error(torch.exp(t.torch()), tn.exp(t)) < 1e-4
def cross_forward( info, function=lambda x: x, domain=None, tensors=None, function_arg='vectors', return_info=False): """ Given TT-cross indices and a black-box function (to be evaluated on an arbitrary grid), computes a differentiable TT tensor as given by the TT-cross interpolation formula. Reference: I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_ :param info: dictionary with the indices returned by `tntorch.cross()` :param function: a function $\mathbb{R}^M \to \mathbb{R}$, as in `tntorch.cross()` :param domain: domain where `function` will be evaluated on, as in `tntorch.cross()` :param tensors: list of $M$ TT tensors where `function` will be evaluated on :param function_arg: type of argument accepted by `function`. See `tntorch.cross()` :param return_info: Boolean, if True, will also return a dictionary with informative metrics about the algorithm's outcome :return: a TT :class:`Tensor`(if `return_info`=True, also a dictionary) """ assert domain is not None or tensors is not None assert function_arg in ('vectors', 'matrix') device = None if function_arg == 'matrix': def f(*args): return function(torch.cat([arg[:, None] for arg in args], dim=1)) else: f = function if tensors is None: tensors = tn.meshgrid(domain) device = domain[0].device if not hasattr(tensors, '__len__'): tensors = [tensors] Is = list(tensors[0].shape) N = len(Is) # Load index information from dictionary lsets = info['lsets'] rsets = info['rsets'] left_locals = info['left_locals'] Rs = info['Rs'] if return_info: info['Xs'] = torch.zeros(0, N) info['shapes'] = [] assert function_arg in ('vectors', 'matrix') if function_arg == 'matrix': def f(*args): return function(torch.cat([arg[:, None] for arg in args], dim=1)) else: f = function t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device) def evaluate_function(j): # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j] Xs = [] for k, t in enumerate(tensors): V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j]]) Xs.append(V.flatten()) evaluation = f(*Xs) if return_info: info['Xs'] = torch.cat((info['Xs'], torch.cat([x[:, None] for x in Xs], dim=1).detach().cpu()), dim=0) info['shapes'].append([Rs[j], Is[j], Rs[j + 1]]) V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]]) return V cores = [] # Cross-interpolation formula, left-to-right for j in range(0, N-1): # Update tensors for current indices V = evaluate_function(j) V = torch.reshape(V, [-1, V.shape[2]]) # Left unfolding A = V[left_locals[j], :] X = torch.linalg.lstsq(A.t(), V.t()).solution.t() cores.append(X.reshape(Rs[j], Is[j], Rs[j + 1])) # Map local indices to global ones local_r, local_i = np.unravel_index(left_locals[j], [Rs[j], Is[j]]) lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i] for k, t in enumerate(tensors): t_linterfaces[k][j + 1] = torch.einsum('ai,iaj->aj', [t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]]) # Leave the last core ready X = evaluate_function(N-1) cores.append(X) if return_info: return tn.Tensor(cores), info else: return tn.Tensor(cores)
def cross(function, domain=None, tensors=None, function_arg='vectors', ranks_tt=None, kickrank=3, rmax=100, eps=1e-6, max_iter=25, val_size=1000, verbose=True, return_info=False, _minimize=False): """ Cross-approximation routine that samples a black-box function and returns an N-dimensional tensor train approximating it. It accepts either: - A domain (tensor product of :math:`N` given arrays) and a function :math:`\\mathbb{R}^N \\to \\mathbb{R}` - A list of :math:`K` tensors of dimension :math:`N` and equal shape and a function :math:`\\mathbb{R}^K \\to \\mathbb{R}` :Examples: >>> tn.cross(function=lambda x: x**2, tensors=[t]) # Compute the element-wise square of `t` using 5 TT-ranks >>> domain = [torch.linspace(-1, 1, 32)]*5 >>> tn.cross(function=lambda x, y, z, t, w: x**2 + y*z + torch.cos(t + w), domain=domain) # Approximate a function over the rectangle :math:`[-1, 1]^5` >>> tn.cross(function=lambda x: torch.sum(x**2, dim=1), domain=domain, function_arg='matrix') # An example where the function accepts a matrix References: - I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_ - D. Savostyanov, I. Oseledets: `"Fast Adaptive Interpolation of Multi-dimensional Arrays in Tensor Train Format" (2011) <https://ieeexplore.ieee.org/document/6076873>`_ - S. Dolgov, R. Scheichl: `"A Hybrid Alternating Least Squares - TT Cross Algorithm for Parametric PDEs" (2018) <https://arxiv.org/pdf/1707.04562.pdf>`_ - A. Mikhalev's `maxvolpy package <https://bitbucket.org/muxas/maxvolpy>`_ - I. Oseledets (and others)'s `ttpy package <https://github.com/oseledets/ttpy>`_ :param function: should produce a vector of :math:`P` elements. Accepts either :math:`N` comma-separated vectors, or a matrix (see `function_arg`) :param domain: a list of :math:`N` vectors (incompatible with `tensors`) :param tensors: a :class:`Tensor` or list thereof (incompatible with `domain`) :param function_arg: if 'vectors', `function` accepts :math:`N` vectors of length :math:`P` each. If 'matrix', a matrix of shape :math:`P \\times N`. :param ranks_tt: int or list of :math:`N-1` ints. If None, will be determined adaptively :param kickrank: when adaptively found, ranks will be increased by this amount after every iteration (full sweep left-to-right and right-to-left) :param rmax: this rank will not be surpassed :param eps: the procedure will stop after this validation error is met (as measured after each iteration) :param max_iter: int :param val_size: size of the validation set :param verbose: default is True :param return_info: if True, will also return a dictionary with informative metrics about the algorithm's outcome :return: an N-dimensional TT :class:`Tensor` (if `return_info`=True, also a dictionary) """ try: import maxvolpy.maxvol except ModuleNotFoundError: raise ModuleNotFoundError( "Functions that require cross-approximation require the optional maxvolpy package, which can be installed by 'pip install maxvolpy'. More info is available at https://bitbucket.org/muxas/maxvolpy" ) assert domain is not None or tensors is not None assert function_arg in ('vectors', 'matrix') if function_arg == 'matrix': def f(*args): return function(torch.cat([arg[:, None] for arg in args], dim=1)) else: f = function if tensors is None: tensors = tn.meshgrid(domain) if not hasattr(tensors, '__len__'): tensors = [tensors] tensors = [t.decompress_tucker_factors(_clone=False) for t in tensors] Is = list(tensors[0].shape) N = len(Is) # Process ranks and cap them, if needed if ranks_tt is None: ranks_tt = 1 else: kickrank = None if not hasattr(ranks_tt, '__len__'): ranks_tt = [ranks_tt] * (N - 1) ranks_tt = [1] + list(ranks_tt) + [1] Rs = np.array(ranks_tt) for n in list(range(1, N)) + list(range(N - 1, -1, -1)): Rs[n] = min(Rs[n - 1] * Is[n - 1], Rs[n], Is[n] * Rs[n + 1]) # Initialize cores at random cores = [torch.randn(Rs[n], Is[n], Rs[n + 1]) for n in range(N)] # Prepare left and right sets lsets = [np.array([[0]])] + [None] * (N - 1) randint = np.hstack( [np.random.randint(0, Is[n + 1], [max(Rs), 1]) for n in range(N - 1)] + [np.zeros([max(Rs), 1], dtype=np.int)]) rsets = [randint[:Rs[n + 1], n:] for n in range(N - 1)] + [np.array([[0]])] # Initialize left and right interfaces for `tensors` def init_interfaces(): t_linterfaces = [] t_rinterfaces = [] for t in tensors: linterfaces = [torch.ones(1, t.ranks_tt[0])] + [None] * (N - 1) rinterfaces = [None] * (N - 1) + [ torch.ones(t.ranks_tt[t.dim()], 1) ] for j in range(N - 1): M = torch.ones(t.cores[-1].shape[-1], len(rsets[j])) for n in range(N - 1, j, -1): if t.cores[n].dim() == 3: # TT core M = torch.einsum( 'iaj,ja->ia', (t.cores[n][:, rsets[j][:, n - 1 - j], :], M)) else: # CP factor M = torch.einsum( 'ai,ia->ia', (t.cores[n][rsets[j][:, n - 1 - j], :], M)) rinterfaces[j] = M t_linterfaces.append(linterfaces) t_rinterfaces.append(rinterfaces) return t_linterfaces, t_rinterfaces t_linterfaces, t_rinterfaces = init_interfaces() # Create a validation set Xs_val = [torch.as_tensor(np.random.choice(I, val_size)) for I in Is] ys_val = f(*[t[Xs_val].torch() for t in tensors]) if ys_val.dim() > 1: assert ys_val.dim() == 2 assert ys_val.shape[1] == 1 ys_val = ys_val[:, 0] assert len(ys_val) == val_size norm_ys_val = torch.norm(ys_val) if verbose: print( 'Cross-approximation over a {}D domain containing {:g} grid points:' .format(N, tensors[0].numel())) start = time.time() converged = False info = { 'nsamples': 0, 'eval_time': 0, 'val_epss': [], 'min': 0, 'argmin': None } def evaluate_function( j ): # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j] Xs = [] for k, t in enumerate(tensors): if tensors[k].cores[j].dim() == 3: # TT core V = torch.einsum('ai,ibj,jc->abc', (t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j])) else: # CP factor V = torch.einsum('ai,bi,ic->abc', (t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j])) Xs.append(V.flatten()) eval_start = time.time() evaluation = f(*Xs) info['eval_time'] += time.time() - eval_start if _minimize: evaluation = np.pi / 2 - torch.atan( evaluation - info['min'] ) # Function used by I. Oseledets for TT minimization in ttpy evaluation_argmax = torch.argmax(evaluation) eval_min = torch.tan(np.pi / 2 - evaluation[evaluation_argmax]) + info['min'] if info['min'] == 0 or eval_min < info['min']: coords = np.unravel_index(evaluation_argmax, [Rs[j], Is[j], Rs[j + 1]]) info['min'] = eval_min info['argmin'] = tuple(lsets[j][coords[0]][1:]) + tuple( [coords[1]]) + tuple(rsets[j][coords[2]][:-1]) # Check for nan/inf values if evaluation.dim() == 2: evaluation = evaluation[:, 0] invalid = (torch.isnan(evaluation) | torch.isinf(evaluation)).nonzero() if len(invalid) > 0: invalid = invalid[0].item() raise ValueError( 'Invalid return value for function {}: f({}) = {}'.format( function, ', '.join('{:g}'.format(x[invalid].numpy()) for x in Xs), f(*[x[invalid:invalid + 1][:, None] for x in Xs]).item())) V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]]) info['nsamples'] += V.numel() return V # Sweeps for i in range(max_iter): if verbose: print('iter: {: <{}}'.format(i, len('{}'.format(max_iter)) + 1), end='') sys.stdout.flush() left_locals = [] # Left-to-right for j in range(0, N - 1): # Update tensors for current indices V = evaluate_function(j) # QR + maxvol towards the right V = torch.reshape(V, [-1, V.shape[2]]) # Left unfolding Q, R = torch.qr(V) if _minimize: local, _ = maxvolpy.maxvol.rect_maxvol(Q.detach().numpy(), maxK=Q.shape[1]) else: local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy()) V = torch.gels(Q.t(), Q[local, :].t())[0].t() cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]]) left_locals.append(local) # Map local indices to global ones local_r, local_i = np.unravel_index(local, [Rs[j], Is[j]]) lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i] for k, t in enumerate(tensors): if t.cores[j].dim() == 3: # TT core t_linterfaces[k][j + 1] = torch.einsum( 'ai,iaj->aj', (t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :])) else: # CP factor t_linterfaces[k][j + 1] = torch.einsum( 'ai,ai->ai', (t_linterfaces[k][j][local_r, :], t.cores[j][local_i, :])) # Right-to-left sweep for j in range(N - 1, 0, -1): # Update tensors for current indices V = evaluate_function(j) # QR + maxvol towards the left V = torch.reshape(V, [Rs[j], -1]) # Right unfolding Q, R = torch.qr(V.t()) if _minimize: local, _ = maxvolpy.maxvol.rect_maxvol(Q.detach().numpy(), maxK=Q.shape[1]) else: local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy()) V = torch.gels(Q.t(), Q[local, :].t())[0] cores[j] = torch.reshape(torch.as_tensor(V), [Rs[j], Is[j], Rs[j + 1]]) # Map local indices to global ones local_i, local_r = np.unravel_index(local, [Is[j], Rs[j + 1]]) rsets[j - 1] = np.c_[local_i, rsets[j][local_r, :]] for k, t in enumerate(tensors): if t.cores[j].dim() == 3: # TT core t_rinterfaces[k][j - 1] = torch.einsum( 'iaj,ja->ia', (t.cores[j][:, local_i, :], t_rinterfaces[k][j][:, local_r])) else: # CP factor t_rinterfaces[k][j - 1] = torch.einsum( 'ai,ia->ia', (t.cores[j][local_i, :], t_rinterfaces[k][j][:, local_r])) # Leave the first core ready V = evaluate_function(0) cores[0] = V # Evaluate validation error val_eps = torch.norm(ys_val - tn.Tensor(cores)[Xs_val].torch()) / norm_ys_val info['val_epss'].append(val_eps) if val_eps < eps: converged = True if verbose: # Print status print('| eps: {:.3e}'.format(val_eps), end='') print(' | total time: {:8.4f} | largest rank: {:3d}'.format( time.time() - start, max(Rs)), end='') if converged: print(' <- converged: eps < {}'.format(eps)) elif i == max_iter - 1: print(' <- max_iter was reached: {}'.format(max_iter)) else: print() if converged: break elif i < max_iter - 1 and kickrank is not None: # Augment ranks newRs = Rs.copy() newRs[1:-1] = np.minimum(rmax, newRs[1:-1] + kickrank) for n in list(range(1, N)) + list(range(N - 1, 0, -1)): newRs[n] = min(newRs[n - 1] * Is[n - 1], newRs[n], Is[n] * newRs[n + 1]) extra = np.hstack([ np.random.randint(0, Is[n + 1], [max(newRs), 1]) for n in range(N - 1) ] + [np.zeros([max(newRs), 1], dtype=np.int)]) for n in range(N - 1): if newRs[n + 1] > Rs[n + 1]: rsets[n] = np.vstack( [rsets[n], extra[:newRs[n + 1] - Rs[n + 1], n:]]) Rs = newRs t_linterfaces, t_rinterfaces = init_interfaces( ) # Recompute interfaces if val_eps > eps and not _minimize: logging.warning( 'eps={:g} (larger than {}) when cross-approximating {}'.format( val_eps, eps, function)) if verbose: print( 'Did {} function evaluations, which took {:.4g}s ({:.4g} evals/s)'. format(info['nsamples'], info['eval_time'], info['nsamples'] / info['eval_time'])) print() if return_info: info['lsets'] = lsets info['rsets'] = rsets info['left_locals'] = left_locals info['total_time'] = time.time() - start info['val_eps'] = val_eps return tn.Tensor([torch.Tensor(c) for c in cores]), info else: return tn.Tensor([torch.Tensor(c) for c in cores])
self.result = result def index(self, variables): """ Compute the covariance index w.r.t. given variables. :param variables: a list of integers :return: a real number """ if len(variables) == 0: return 0 if not all(np.unique(variables) == np.array(variables)): raise ValueError('There are repeated variables') # Read out the desired index idx = np.zeros(self.N, dtype=np.int) idx[np.array(variables)] = 1 return self.result[tuple(idx)].item() # Examples of use if __name__ == '__main__': x, y = tn.meshgrid(32, 32) model = x + y # This is the function f(x, y) = x+y dc = DirectionalCovariance(model) print('Index with respect to {x}:', dc.index([0])) print('Index with respect to {y}:', dc.index([1])) print('Index with respect to {x, y}:', dc.index([0, 1]))