Example #1
import torch
import tntorch as tn


def test_ops():

    x, y, z, w = tn.meshgrid([32] * 4)
    t = x + y + z + w + 1
    assert tn.relative_error(1 / t.torch(), 1 / t) < 1e-4
    assert tn.relative_error(torch.cos(t.torch()), tn.cos(t)) < 1e-4
    assert tn.relative_error(torch.exp(t.torch()), tn.exp(t)) < 1e-4
Example #2
import numpy as np
import torch
import tntorch as tn


def cross_forward(
        info,
        function=lambda x: x,
        domain=None,
        tensors=None,
        function_arg='vectors',
        return_info=False):
    """
    Given TT-cross indices and a black-box function (to be evaluated on an arbitrary grid), computes a differentiable TT tensor as given by the TT-cross interpolation formula.
    Reference: I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
    :param info: dictionary with the indices returned by `tntorch.cross()`
    :param function: a function :math:`\\mathbb{R}^M \\to \\mathbb{R}`, as in `tntorch.cross()`
    :param domain: domain on which `function` will be evaluated, as in `tntorch.cross()`
    :param tensors: list of :math:`M` TT tensors on which `function` will be evaluated
    :param function_arg: type of argument accepted by `function`. See `tntorch.cross()`
    :param return_info: Boolean; if True, also return a dictionary with informative metrics about the algorithm's outcome
    :return: a TT :class:`Tensor` (if `return_info=True`, also a dictionary)
    """

    assert domain is not None or tensors is not None
    assert function_arg in ('vectors', 'matrix')
    device = None
    if function_arg == 'matrix':
        def f(*args):
            return function(torch.cat([arg[:, None] for arg in args], dim=1))
    else:
        f = function
    if tensors is None:
        tensors = tn.meshgrid(domain)
        device = domain[0].device
    if not hasattr(tensors, '__len__'):
        tensors = [tensors]

    Is = list(tensors[0].shape)
    N = len(Is)

    # Load index information from dictionary
    lsets = info['lsets']
    rsets = info['rsets']
    left_locals = info['left_locals']
    Rs = info['Rs']

    if return_info:
        info['Xs'] = torch.zeros(0, N)
        info['shapes'] = []

    t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)

    def evaluate_function(j):  # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size Is[j]
        Xs = []
        for k, t in enumerate(tensors):
            V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j]])
            Xs.append(V.flatten())

        evaluation = f(*Xs)

        if return_info:
            info['Xs'] = torch.cat((info['Xs'], torch.cat([x[:, None] for x in Xs], dim=1).detach().cpu()), dim=0)
            info['shapes'].append([Rs[j], Is[j], Rs[j + 1]])

        V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
        return V

    cores = []

    # Cross-interpolation formula, left-to-right
    for j in range(0, N-1):

        # Update tensors for current indices
        V = evaluate_function(j)
        V = torch.reshape(V, [-1, V.shape[2]])  # Left unfolding
        A = V[left_locals[j], :]
        X = torch.linalg.lstsq(A.t(), V.t()).solution.t()

        cores.append(X.reshape(Rs[j], Is[j], Rs[j + 1]))

        # Map local indices to global ones
        local_r, local_i = np.unravel_index(left_locals[j], [Rs[j], Is[j]])
        lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i]
        for k, t in enumerate(tensors):
            t_linterfaces[k][j + 1] = torch.einsum('ai,iaj->aj',
                                                       [t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]])

    # Leave the last core ready
    X = evaluate_function(N-1)
    cores.append(X)

    if return_info:
        return tn.Tensor(cores), info
    else:
        return tn.Tensor(cores)
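
A minimal usage sketch (hypothetical target function; it assumes the `info` dictionary produced by `tn.cross(..., return_info=True)`, as in Example #3 below):

import torch
import tntorch as tn

f = lambda x, y, z: torch.exp(-x**2 - y**2 - z**2)  # hypothetical black-box function
domain = [torch.linspace(-1, 1, 32)] * 3

# Run TT-cross once to obtain interpolation indices, then rebuild the
# tensor differentiably from those same indices
t, info = tn.cross(function=f, domain=domain, return_info=True)
t2 = cross_forward(info, function=f, domain=domain)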
Example #3
import logging
import sys
import time

import numpy as np
import torch
import tntorch as tn


def cross(function,
          domain=None,
          tensors=None,
          function_arg='vectors',
          ranks_tt=None,
          kickrank=3,
          rmax=100,
          eps=1e-6,
          max_iter=25,
          val_size=1000,
          verbose=True,
          return_info=False,
          _minimize=False):
    """
    Cross-approximation routine that samples a black-box function and returns an N-dimensional tensor train approximating it. It accepts either:

    - A domain (tensor product of :math:`N` given arrays) and a function :math:`\\mathbb{R}^N \\to \\mathbb{R}`
    - A list of :math:`K` tensors of dimension :math:`N` and equal shape and a function :math:`\\mathbb{R}^K \\to \\mathbb{R}`

    :Examples:

    >>> tn.cross(function=lambda x: x**2, tensors=[t])  # Compute the element-wise square of `t`, with ranks chosen adaptively

    >>> domain = [torch.linspace(-1, 1, 32)]*5
    >>> tn.cross(function=lambda x, y, z, t, w: x**2 + y*z + torch.cos(t + w), domain=domain)  # Approximate a function over the rectangle :math:`[-1, 1]^5`

    >>> tn.cross(function=lambda x: torch.sum(x**2, dim=1), domain=domain, function_arg='matrix')  # An example where the function accepts a matrix

    References:

    - I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
    - D. Savostyanov, I. Oseledets: `"Fast Adaptive Interpolation of Multi-dimensional Arrays in Tensor Train Format" (2011) <https://ieeexplore.ieee.org/document/6076873>`_
    - S. Dolgov, R. Scheichl: `"A Hybrid Alternating Least Squares - TT Cross Algorithm for Parametric PDEs" (2018) <https://arxiv.org/pdf/1707.04562.pdf>`_
    - A. Mikhalev's `maxvolpy package <https://bitbucket.org/muxas/maxvolpy>`_
    - I. Oseledets (and others)'s `ttpy package <https://github.com/oseledets/ttpy>`_

    :param function: should produce a vector of :math:`P` elements. Accepts either :math:`N` comma-separated vectors, or a matrix (see `function_arg`)
    :param domain: a list of :math:`N` vectors (incompatible with `tensors`)
    :param tensors: a :class:`Tensor` or list thereof (incompatible with `domain`)
    :param function_arg: if 'vectors', `function` accepts :math:`N` vectors of length :math:`P` each. If 'matrix', a matrix of shape :math:`P \\times N`.
    :param ranks_tt: int or list of :math:`N-1` ints. If None, will be determined adaptively
    :param kickrank: when ranks are chosen adaptively, they will be increased by this amount after every iteration (one full left-to-right plus right-to-left sweep)
    :param rmax: this rank will not be surpassed
    :param eps: the procedure will stop once the validation error falls below this threshold (as measured after each iteration)
    :param max_iter: maximum number of iterations (each is a full left-to-right plus right-to-left sweep)
    :param val_size: size of the validation set
    :param verbose: Boolean; default is True
    :param return_info: if True, will also return a dictionary with informative metrics about the algorithm's outcome

    :return: an N-dimensional TT :class:`Tensor` (if `return_info=True`, also a dictionary)
    """

    try:
        import maxvolpy.maxvol
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "Functions that require cross-approximation require the optional maxvolpy package, which can be installed by 'pip install maxvolpy'. More info is available at https://bitbucket.org/muxas/maxvolpy"
        )

    assert domain is not None or tensors is not None
    assert function_arg in ('vectors', 'matrix')
    if function_arg == 'matrix':

        def f(*args):
            return function(torch.cat([arg[:, None] for arg in args], dim=1))
    else:
        f = function
    if tensors is None:
        tensors = tn.meshgrid(domain)
    if not hasattr(tensors, '__len__'):
        tensors = [tensors]
    tensors = [t.decompress_tucker_factors(_clone=False) for t in tensors]
    Is = list(tensors[0].shape)
    N = len(Is)

    # Process ranks and cap them, if needed
    if ranks_tt is None:
        ranks_tt = 1
    else:
        kickrank = None
    if not hasattr(ranks_tt, '__len__'):
        ranks_tt = [ranks_tt] * (N - 1)
    ranks_tt = [1] + list(ranks_tt) + [1]
    Rs = np.array(ranks_tt)
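    # (A TT-rank can never exceed the size of its adjacent unfoldings, so sweep
    # both ways to enforce Rs[n] <= min(Rs[n-1]*Is[n-1], Is[n]*Rs[n+1]))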
    for n in list(range(1, N)) + list(range(N - 1, -1, -1)):
        Rs[n] = min(Rs[n - 1] * Is[n - 1], Rs[n], Is[n] * Rs[n + 1])

    # Initialize cores at random
    cores = [torch.randn(Rs[n], Is[n], Rs[n + 1]) for n in range(N)]

    # Prepare left and right sets
    lsets = [np.array([[0]])] + [None] * (N - 1)
    randint = np.hstack(
        [np.random.randint(0, Is[n + 1], [max(Rs), 1])
         for n in range(N - 1)] + [np.zeros([max(Rs), 1], dtype=np.int64)])
    rsets = [randint[:Rs[n + 1], n:] for n in range(N - 1)] + [np.array([[0]])]

    # Initialize left and right interfaces for `tensors`
    def init_interfaces():
        t_linterfaces = []
        t_rinterfaces = []
        for t in tensors:
            linterfaces = [torch.ones(1, t.ranks_tt[0])] + [None] * (N - 1)
            rinterfaces = [None] * (N - 1) + [
                torch.ones(t.ranks_tt[t.dim()], 1)
            ]
            for j in range(N - 1):
                M = torch.ones(t.cores[-1].shape[-1], len(rsets[j]))
                for n in range(N - 1, j, -1):
                    if t.cores[n].dim() == 3:  # TT core
                        M = torch.einsum(
                            'iaj,ja->ia',
                            (t.cores[n][:, rsets[j][:, n - 1 - j], :], M))
                    else:  # CP factor
                        M = torch.einsum(
                            'ai,ia->ia',
                            (t.cores[n][rsets[j][:, n - 1 - j], :], M))
                rinterfaces[j] = M
            t_linterfaces.append(linterfaces)
            t_rinterfaces.append(rinterfaces)
        return t_linterfaces, t_rinterfaces

    t_linterfaces, t_rinterfaces = init_interfaces()

    # Create a validation set
    Xs_val = [torch.as_tensor(np.random.choice(I, val_size)) for I in Is]
    ys_val = f(*[t[Xs_val].torch() for t in tensors])
    if ys_val.dim() > 1:
        assert ys_val.dim() == 2
        assert ys_val.shape[1] == 1
        ys_val = ys_val[:, 0]
    assert len(ys_val) == val_size
    norm_ys_val = torch.norm(ys_val)

    if verbose:
        print(
            'Cross-approximation over a {}D domain containing {:g} grid points:'
            .format(N, tensors[0].numel()))
    start = time.time()
    converged = False

    info = {
        'nsamples': 0,
        'eval_time': 0,
        'val_epss': [],
        'min': 0,
        'argmin': None
    }

    # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size Is[j]
    def evaluate_function(j):
        Xs = []
        for k, t in enumerate(tensors):
            if tensors[k].cores[j].dim() == 3:  # TT core
                V = torch.einsum('ai,ibj,jc->abc',
                                 (t_linterfaces[k][j], tensors[k].cores[j],
                                  t_rinterfaces[k][j]))
            else:  # CP factor
                V = torch.einsum('ai,bi,ic->abc',
                                 (t_linterfaces[k][j], tensors[k].cores[j],
                                  t_rinterfaces[k][j]))
            Xs.append(V.flatten())

        eval_start = time.time()
        evaluation = f(*Xs)
        info['eval_time'] += time.time() - eval_start
        if _minimize:
            # Smooth decreasing transformation used by I. Oseledets for TT minimization in ttpy
            evaluation = np.pi / 2 - torch.atan(evaluation - info['min'])
            evaluation_argmax = torch.argmax(evaluation)
            eval_min = torch.tan(np.pi / 2 -
                                 evaluation[evaluation_argmax]) + info['min']
            if info['min'] == 0 or eval_min < info['min']:
                coords = np.unravel_index(evaluation_argmax,
                                          [Rs[j], Is[j], Rs[j + 1]])
                info['min'] = eval_min
                info['argmin'] = tuple(lsets[j][coords[0]][1:]) + (coords[1],) + tuple(rsets[j][coords[2]][:-1])

        # Check for nan/inf values
        if evaluation.dim() == 2:
            evaluation = evaluation[:, 0]
        invalid = (torch.isnan(evaluation) | torch.isinf(evaluation)).nonzero()
        if len(invalid) > 0:
            invalid = invalid[0].item()
            raise ValueError(
                'Invalid return value for function {}: f({}) = {}'.format(
                    function,
                    ', '.join('{:g}'.format(x[invalid].numpy()) for x in Xs),
                    f(*[x[invalid:invalid + 1][:, None] for x in Xs]).item()))

        V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
        info['nsamples'] += V.numel()
        return V

    # Sweeps
    for i in range(max_iter):

        if verbose:
            print('iter: {: <{}}'.format(i,
                                         len('{}'.format(max_iter)) + 1),
                  end='')
            sys.stdout.flush()

        left_locals = []

        # Left-to-right
        for j in range(0, N - 1):

            # Update tensors for current indices
            V = evaluate_function(j)

            # QR + maxvol towards the right
            V = torch.reshape(V, [-1, V.shape[2]])  # Left unfolding
            Q, R = torch.linalg.qr(V)
            if _minimize:
                local, _ = maxvolpy.maxvol.rect_maxvol(Q.detach().numpy(),
                                                       maxK=Q.shape[1])
            else:
                local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution.t()  # replaces the removed torch.gels
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]])
            left_locals.append(local)

            # Map local indices to global ones
            local_r, local_i = np.unravel_index(local, [Rs[j], Is[j]])
            lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i]
            for k, t in enumerate(tensors):
                if t.cores[j].dim() == 3:  # TT core
                    t_linterfaces[k][j + 1] = torch.einsum(
                        'ai,iaj->aj', (t_linterfaces[k][j][local_r, :],
                                       t.cores[j][:, local_i, :]))
                else:  # CP factor
                    t_linterfaces[k][j + 1] = torch.einsum(
                        'ai,ai->ai', (t_linterfaces[k][j][local_r, :],
                                      t.cores[j][local_i, :]))

        # Right-to-left sweep
        for j in range(N - 1, 0, -1):

            # Update tensors for current indices
            V = evaluate_function(j)

            # QR + maxvol towards the left
            V = torch.reshape(V, [Rs[j], -1])  # Right unfolding
            Q, R = torch.linalg.qr(V.t())
            if _minimize:
                local, _ = maxvolpy.maxvol.rect_maxvol(Q.detach().numpy(),
                                                       maxK=Q.shape[1])
            else:
                local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution  # replaces the removed torch.gels
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]])

            # Map local indices to global ones
            local_i, local_r = np.unravel_index(local, [Is[j], Rs[j + 1]])
            rsets[j - 1] = np.c_[local_i, rsets[j][local_r, :]]
            for k, t in enumerate(tensors):
                if t.cores[j].dim() == 3:  # TT core
                    t_rinterfaces[k][j - 1] = torch.einsum(
                        'iaj,ja->ia', (t.cores[j][:, local_i, :],
                                       t_rinterfaces[k][j][:, local_r]))
                else:  # CP factor
                    t_rinterfaces[k][j - 1] = torch.einsum(
                        'ai,ia->ia',
                        (t.cores[j][local_i, :], t_rinterfaces[k][j][:,
                                                                     local_r]))

        # Leave the first core ready
        V = evaluate_function(0)
        cores[0] = V

        # Evaluate validation error
        val_eps = torch.norm(ys_val -
                             tn.Tensor(cores)[Xs_val].torch()) / norm_ys_val
        info['val_epss'].append(val_eps)
        if val_eps < eps:
            converged = True

        if verbose:  # Print status
            print('| eps: {:.3e}'.format(val_eps), end='')
            print(' | total time: {:8.4f} | largest rank: {:3d}'.format(
                time.time() - start, max(Rs)),
                  end='')
            if converged:
                print(' <- converged: eps < {}'.format(eps))
            elif i == max_iter - 1:
                print(' <- max_iter was reached: {}'.format(max_iter))
            else:
                print()
        if converged:
            break
        elif i < max_iter - 1 and kickrank is not None:  # Augment ranks
            newRs = Rs.copy()
            newRs[1:-1] = np.minimum(rmax, newRs[1:-1] + kickrank)
            for n in list(range(1, N)) + list(range(N - 1, 0, -1)):
                newRs[n] = min(newRs[n - 1] * Is[n - 1], newRs[n],
                               Is[n] * newRs[n + 1])
            extra = np.hstack([
                np.random.randint(0, Is[n + 1], [max(newRs), 1])
                for n in range(N - 1)
            ] + [np.zeros([max(newRs), 1], dtype=np.int64)])
            for n in range(N - 1):
                if newRs[n + 1] > Rs[n + 1]:
                    rsets[n] = np.vstack(
                        [rsets[n], extra[:newRs[n + 1] - Rs[n + 1], n:]])
            Rs = newRs
            # Recompute interfaces
            t_linterfaces, t_rinterfaces = init_interfaces()

    if val_eps > eps and not _minimize:
        logging.warning(
            'eps={:g} (larger than {}) when cross-approximating {}'.format(
                val_eps, eps, function))

    if verbose:
        print(
            'Did {} function evaluations, which took {:.4g}s ({:.4g} evals/s)'.
            format(info['nsamples'], info['eval_time'],
                   info['nsamples'] / info['eval_time']))
        print()

    if return_info:
        info['lsets'] = lsets
        info['rsets'] = rsets
        info['left_locals'] = left_locals
        info['total_time'] = time.time() - start
        info['val_eps'] = val_eps
        return tn.Tensor([torch.Tensor(c) for c in cores]), info
    else:
        return tn.Tensor([torch.Tensor(c) for c in cores])
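
A short usage sketch of the `return_info` flag (hypothetical target function; the dictionary keys are the ones filled in by the routine above):

>>> domain = [torch.linspace(0, 1, 64)] * 2
>>> t, info = tn.cross(function=lambda x, y: x * y, domain=domain, return_info=True)
>>> info['val_eps']   # final validation error
>>> info['nsamples']  # total number of function evaluations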
Example #4
        # (The excerpt begins inside `DirectionalCovariance.__init__`; the omitted code
        # computes `result`, a tensor of covariance indices, and sets `self.N` to the
        # number of input variables.)
        self.result = result

    def index(self, variables):
        """
        Compute the covariance index w.r.t. given variables.

        :param variables: a list of integers
        :return: a real number
        """

        if len(variables) == 0:
            return 0
        if len(np.unique(variables)) != len(variables):
            raise ValueError('There are repeated variables')

        # Read out the desired index
        idx = np.zeros(self.N, dtype=np.int64)
        idx[np.array(variables)] = 1
        return self.result[tuple(idx)].item()


# Examples of use
if __name__ == '__main__':
    x, y = tn.meshgrid(32, 32)
    model = x + y  # This is the function f(x, y) = x+y
    dc = DirectionalCovariance(model)
    print('Index with respect to {x}:', dc.index([0]))
    print('Index with respect to {y}:', dc.index([1]))
    print('Index with respect to {x, y}:', dc.index([0, 1]))