Example 1
def gradients(func, x, name='t3f_gradients', runtime_check=True):
    """Riemannian autodiff: returns gradient projected on tangent space of TT.

  Computes projection of the gradient df/dx onto the tangent space of TT tensor
  at point x.

  Warning: this is an experimental feature and it may not work for some
  functions, e.g. ones that include QR or SVD decompositions (t3f.project,
  t3f.round) or for functions that work with TT-cores directly (in contrast to
  working with the TT-object only via t3f functions). In these cases this
  function can silently return wrong results!

  Example:
      # Squared scalar product with a predefined tensor t: 0.5 * <x, t>**2.
      # Its gradient is <x, t> * t and its Riemannian gradient is
      #     t3f.project(<x, t> * t, x)
      f = lambda x: 0.5 * t3f.flat_inner(x, t)**2
      projected_grad = t3f.gradients(f, x) # t3f.project(t3f.flat_inner(x, t) * t, x)

  Args:
      func: function that takes a TensorTrain object as input and outputs a number.
      x: point at which to compute the gradient and on which tangent space to
        project the gradient.
      name: string, name of the Op.
      runtime_check: [True] whether to do a sanity check that the passed
        function is invariant to different TT representations (otherwise
        the Riemannian gradient doesn't even exist). It makes things slower,
        but helps catch bugs, so turn it off during production deployment.

  Returns:
      `TensorTrain`, projection of the gradient df/dx onto the tangent space at
      point x.

  See also:
      t3f.hessian_vector_product
  """
    with tf.name_scope(name):
        left = decompositions.orthogonalize_tt_cores(x)
        right = decompositions.orthogonalize_tt_cores(left,
                                                      left_to_right=False)
        deltas = [right.tt_cores[0]]
        deltas += [tf.zeros_like(cc) for cc in right.tt_cores[1:]]

        def augmented_func(d):
            x_projection = riemannian.deltas_to_tangent_space(
                d, x, left, right)
            return func(x_projection)

        function_value, cores_grad = value_and_grad(augmented_func, deltas)
        if runtime_check:
            assert_op = _is_invariant_to_input_transforms(
                function_value, func(x))
        else:
            assert_op = tf.no_op()
        with tf.control_dependencies([assert_op]):
            deltas = _enforce_gauge_conditions(cores_grad, left)
        return riemannian.deltas_to_tangent_space(deltas, x, left, right)
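
A minimal usage sketch of the function above, mirroring the docstring example; the TT-tensors `x` and `t`, their shapes and ranks are illustrative, and it assumes `gradients` and `flat_inner` are exposed at the top level of the t3f package:

import t3f

# Hypothetical inputs: two random TT-tensors of the same shape.
x = t3f.random_tensor((3, 4, 5), tt_rank=2)
t = t3f.random_tensor((3, 4, 5), tt_rank=2)

# f(x) = 0.5 * <x, t>**2; its Riemannian gradient is t3f.project(t3f.flat_inner(x, t) * t, x).
f = lambda x: 0.5 * t3f.flat_inner(x, t) ** 2
riemannian_grad = t3f.gradients(f, x)  # TensorTrain lying in the tangent space at x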
Example 2
 def testOrthogonalizeLeftToRight(self):
     shape = (2, 4, 3, 3)
     tt_ranks = (1, 5, 2, 17, 1)
     updated_tt_ranks = (1, 2, 2, 6, 1)
     tens = initializers.random_tensor_batch(shape,
                                             tt_rank=tt_ranks,
                                             batch_size=2)
     orthogonal = decompositions.orthogonalize_tt_cores(tens)
     with self.test_session() as sess:
         tens_val, orthogonal_val = sess.run(
             [ops.full(tens), ops.full(orthogonal)])
         self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
         dynamic_tt_ranks = shapes.tt_ranks(orthogonal).eval()
         self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
         # Check that the TT-cores are orthogonal.
         for core_idx in range(4 - 1):
             core_shape = (updated_tt_ranks[core_idx] * shape[core_idx],
                           updated_tt_ranks[core_idx + 1])
             for i in range(2):
                 core = tf.reshape(orthogonal.tt_cores[core_idx][i],
                                   core_shape)
                 should_be_eye = tf.matmul(tf.transpose(core), core)
                 should_be_eye_val = sess.run(should_be_eye)
                 self.assertAllClose(np.eye(updated_tt_ranks[core_idx + 1]),
                                     should_be_eye_val)
Example 3
 def testOrthogonalizeRightToLeft(self):
     shape = (2, 4, 3, 3)
     tt_ranks = (1, 5, 2, 17, 1)
     updated_tt_ranks = (1, 5, 2, 3, 1)
     tens = initializers.random_tensor(shape,
                                       tt_rank=tt_ranks,
                                       dtype=self.dtype)
     orthogonal = decompositions.orthogonalize_tt_cores(tens,
                                                        left_to_right=False)
     with self.test_session() as sess:
         tens_val, orthogonal_val = sess.run(
             [ops.full(tens), ops.full(orthogonal)])
         self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
         dynamic_tt_ranks = shapes.tt_ranks(orthogonal).eval()
         self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
         # Check that the TT-cores are orthogonal.
         for core_idx in range(1, 4):
             core = orthogonal.tt_cores[core_idx]
             core = tf.reshape(
                 core, (updated_tt_ranks[core_idx],
                        shape[core_idx] * updated_tt_ranks[core_idx + 1]))
             should_be_eye = tf.matmul(core, tf.transpose(core))
             should_be_eye_val = sess.run(should_be_eye)
             self.assertAllClose(np.eye(updated_tt_ranks[core_idx]),
                                 should_be_eye_val)
Example 4
File: ops.py Project: towadroid/t3f
def frobenius_norm_squared(tt, differentiable=False,
                           name='t3f_frobenius_norm_squared'):
  """Frobenius norm squared of `TensorTrain` or of each TT in `TensorTrainBatch`.

  Frobenius norm squared is the sum of squares of all elements in a tensor.

  Args:
    tt: `TensorTrain` or `TensorTrainBatch` object
    differentiable: bool, whether to use a differentiable implementation
      or a fast and stable implementation based on QR decomposition.
    name: string, name of the Op.

  Returns:
    a number which is the Frobenius norm squared of `tt`, if it is `TensorTrain`
    OR
    a Tensor of size tt.batch_size, consisting of the Frobenius norms squared of
    each TensorTrain in `tt`, if it is `TensorTrainBatch`
  """
  with tf.name_scope(name, values=tt.tt_cores):
    if differentiable:
      if hasattr(tt, 'batch_size'):
          bs_str = 'n'
      else:
          bs_str = ''
      if tt.is_tt_matrix():
        running_prod = tf.einsum('{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])
      else:
        running_prod = tf.einsum('{0}aib,{0}cid->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])

      for core_idx in range(1, tt.ndims()):
        curr_core = tt.tt_cores[core_idx]
        if tt.is_tt_matrix():
          running_prod = tf.einsum('{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)
        else:
          running_prod = tf.einsum('{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)

      return tf.squeeze(running_prod, [-1, -2])

    else:
      orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
      # All the cores of orth_tt except the last one are orthogonal, hence
      # the Frobenius norm of orth_tt equals the norm of the last core.
      if hasattr(tt, 'batch_size'):
        batch_size = shapes.lazy_batch_size(tt)
        last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
        return tf.norm(last_core, axis=1) ** 2
      else:
        return tf.norm(orth_tt.tt_cores[-1]) ** 2
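
A short usage sketch, assuming a random TT-tensor built with t3f (the shape and rank below are illustrative):

import t3f

a = t3f.random_tensor((3, 4, 5), tt_rank=2)
# Default path: fast and stable, based on QR orthogonalization of the TT-cores.
norm_sq = t3f.frobenius_norm_squared(a)
# Differentiable path, e.g. when the norm is part of a loss being optimized.
norm_sq_diff = t3f.frobenius_norm_squared(a, differentiable=True)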
Example 5
 def testOrthogonalizeLeftToRight(self):
   shape = (2, 4, 3, 3)
   tt_ranks = (1, 5, 2, 17, 1)
   updated_tt_ranks = (1, 2, 2, 6, 1)
   tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
                                     dtype=self.dtype)
   orthogonal = decompositions.orthogonalize_tt_cores(tens)
   tens_val, orthogonal_val = self.evaluate([ops.full(tens), ops.full(orthogonal)])
   self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
   dynamic_tt_ranks = self.evaluate(shapes.tt_ranks(orthogonal))
   self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
   # Check that the TT-cores are orthogonal.
   for core_idx in range(4 - 1):
     core = orthogonal.tt_cores[core_idx]
     core = tf.reshape(core, (updated_tt_ranks[core_idx] * shape[core_idx],
                              updated_tt_ranks[core_idx + 1]))
     should_be_eye = tf.matmul(tf.transpose(core), core)
     should_be_eye_val = self.evaluate(should_be_eye)
     self.assertAllClose(np.eye(updated_tt_ranks[core_idx + 1]),
                         should_be_eye_val)
Example 6
File: ops.py Project: vseledkin/t3f
def frobenius_norm_squared(tt, differentiable=False):
    """Frobenius norm squared of a TensorTrain (sum of squares of all elements).

  Args:
    tt: `TensorTrain` object
    differentiable: bool, whether to use a differentiable implementation
      or a fast and stable implementation based on QR decomposition.

  Returns:
    a number, the sum of squares of all elements in `tt`
  """
    if differentiable:
        if tt.is_tt_matrix():
            running_prod = tf.einsum('aijb,cijd->bd', tt.tt_cores[0],
                                     tt.tt_cores[0])
        else:
            running_prod = tf.einsum('aib,cid->bd', tt.tt_cores[0],
                                     tt.tt_cores[0])

        for core_idx in range(1, tt.ndims()):
            curr_core = tt.tt_cores[core_idx]
            if tt.is_tt_matrix():
                running_prod = tf.einsum('ac,aijb,cijd->bd', running_prod,
                                         curr_core, curr_core)
            else:
                running_prod = tf.einsum('ac,aib,cid->bd', running_prod,
                                         curr_core, curr_core)
        return running_prod[0, 0]
    else:
        orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
        # All the cores of orth_tt except the last one are orthogonal, hence
        # the Frobenius norm of orth_tt equals the norm of the last core.
        if hasattr(tt, 'batch_size'):
            batch_size = shapes.lazy_batch_size(tt)
            last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
            return tf.norm(last_core, axis=1)**2
        else:
            return tf.norm(orth_tt.tt_cores[-1])**2
Example 7
def hessian_vector_product(func,
                           x,
                           vector,
                           name='t3f_hessian_vector_product',
                           runtime_check=True):
    """P_x [d^2f/dx^2] P_x vector, i.e. Riemannian hessian by vector product.

    Computes
      P_x [d^2f/dx^2] P_x vector
    where P_x is projection onto the tangent space of TT at point x and
    d^2f/dx^2 is the Hessian of the function.

    Note that the true Riemannian hessian also includes the manifold curvature
    term which is ignored here.

    Warning: this is an experimental feature and it may not work for some
    functions, e.g. ones that include QR or SVD decompositions (t3f.project,
    t3f.round) or for functions that work with TT-cores directly (in contrast
    to working with the TT-object only via t3f functions). In these cases this
    function can silently return wrong results!

    Example:
        # Quadratic form with matrix A: <x, A x>.
        # Its gradient is (A + A.T) x and its Hessian is (A + A.T).
        # Its Riemannian Hessian-by-vector product is
        #     proj_vec = t3f.project(vector, x)
        #     t3f.project(t3f.matmul(A + t3f.transpose(A), proj_vec), x)
        f = lambda x: t3f.bilinear_form(A, x, x)
        res = t3f.hessian_vector_product(f, x, vector)

    Args:
        func: function that takes a TensorTrain object as input and outputs a number.
        x: point at which to compute the Hessian and on which tangent space to
          project the gradient.
        vector: `TensorTrain` object by which to multiply the Hessian.
        name: string, name of the Op.
        runtime_check: [True] whether to do a sanity check that the passed
          function is invariant to different TT representations (otherwise
          the Riemannian gradient doesn't even exist). It makes things slower,
          but helps catch bugs, so turn it off during production deployment.

    Returns:
        `TensorTrain`, result of the Riemannian Hessian-by-vector product.

    See also:
        t3f.gradients
    """
    all_cores = list(x.tt_cores) + list(vector.tt_cores)
    with tf.name_scope(name, values=all_cores):
        left = decompositions.orthogonalize_tt_cores(x)
        right = decompositions.orthogonalize_tt_cores(left,
                                                      left_to_right=False)
        deltas = [right.tt_cores[0]]
        deltas += [tf.zeros_like(cc) for cc in right.tt_cores[1:]]
        x_projection = riemannian.deltas_to_tangent_space(
            deltas, x, left, right)
        function_value = func(x_projection)
        if runtime_check:
            assert_op = _is_invariant_to_input_transforms(
                function_value, func(x))
        else:
            assert_op = tf.no_op()
        with tf.control_dependencies([assert_op]):
            vector_projected = riemannian.project(vector, x)
        cores_grad = tf.gradients(function_value, deltas)
        vec_deltas = riemannian.tangent_space_to_deltas(vector_projected)
        products = [
            tf.reduce_sum(a * b) for a, b in zip(cores_grad, vec_deltas)
        ]
        grad_times_vec = tf.add_n(products)
        second_cores_grad = tf.gradients(grad_times_vec, deltas)
        final_deltas = _enforce_gauge_conditions(second_cores_grad, left)
        return riemannian.deltas_to_tangent_space(final_deltas, x, left, right)
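
A sketch of the quadratic-form example from the docstring; the TT-matrix `A` and the TT "vectors" `x` and `vector` are hypothetical, with shapes chosen so that `t3f.bilinear_form(A, x, x)` is well defined:

import t3f

A = t3f.random_matrix(((2, 3), (2, 3)), tt_rank=2)
x = t3f.random_matrix(((2, 3), (1, 1)), tt_rank=2)
vector = t3f.random_matrix(((2, 3), (1, 1)), tt_rank=2)

f = lambda x: t3f.bilinear_form(A, x, x)  # <x, A x>
res = t3f.hessian_vector_product(f, x, vector)
# Up to the ignored curvature term, this should match
#   t3f.project(t3f.matmul(A + t3f.transpose(A), t3f.project(vector, x)), x)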
Example 8
def deltas_to_tangent_space(deltas,
                            tt,
                            left=None,
                            right=None,
                            name='t3f_deltas_to_tangent_space'):
    """Converts deltas representation of tangent space vector to TT object.

  Takes as input a list of [dP1, ..., dPd] and returns
    dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.

  This function is hard to use correctly because the deltas should obey the
  so-called gauge conditions. If they don't, the function will silently return
  an incorrect result. This is why this function is not imported in __init__.

  Args:
      deltas: a list of deltas (essentially TT-cores) obeying the gauge
        conditions.
      tt: `TensorTrain` object on which the tangent space tensor represented by
        delta is projected.
      left: t3f.orthogonalize_tt_cores(tt). If you have already computed it, you
        may pass it as an argument to avoid recomputing.
      right: t3f.orthogonalize_tt_cores(left, left_to_right=False). If you have
        already computed it, you may pass it as an argument to avoid recomputing.
      name: string, name of the Op.

  Returns:
      `TensorTrain` object constructed from deltas, that is from the tangent
        space at point `tt`.
  """
    cores = []
    dtype = tt.dtype
    num_dims = tt.ndims()
    # TODO: add cache instead of manually passing precomputed stuff?
    input_tensors = list(tt.tt_cores) + list(deltas)
    if left is not None:
        input_tensors += list(left.tt_cores)
    if right is not None:
        input_tensors += list(right.tt_cores)
    with tf.name_scope(name):
        if left is None:
            left = decompositions.orthogonalize_tt_cores(tt)
        if right is None:
            right = decompositions.orthogonalize_tt_cores(left,
                                                          left_to_right=False)
        left_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        right_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        raw_shape = shapes.lazy_raw_shape(left)
        right_rank_dim = left.right_tt_rank_dim
        left_rank_dim = left.left_tt_rank_dim
        is_batch_case = len(deltas[0].shape) > len(tt.tt_cores[0].shape)
        if is_batch_case:
            right_rank_dim += 1
            left_rank_dim += 1
            batch_size = deltas[0].shape.as_list()[0]
        for i in range(num_dims):
            left_tt_core = left.tt_cores[i]
            right_tt_core = right.tt_cores[i]
            if is_batch_case:
                tile = [1] * len(left_tt_core.shape)
                tile = [batch_size] + tile
                left_tt_core = tf.tile(left_tt_core[None, ...], tile)
                right_tt_core = tf.tile(right_tt_core[None, ...], tile)

            if i == 0:
                tangent_core = tf.concat((deltas[i], left_tt_core),
                                         axis=right_rank_dim)
            elif i == num_dims - 1:
                tangent_core = tf.concat((right_tt_core, deltas[i]),
                                         axis=left_rank_dim)
            else:
                rank_1 = right_tangent_tt_ranks[i]
                rank_2 = left_tangent_tt_ranks[i + 1]
                if tt.is_tt_matrix():
                    mode_size_n = raw_shape[0][i]
                    mode_size_m = raw_shape[1][i]
                    shape = [rank_1, mode_size_n, mode_size_m, rank_2]
                else:
                    mode_size_n = raw_shape[0][i]
                    shape = [rank_1, mode_size_n, rank_2]
                if is_batch_case:
                    shape = [batch_size] + shape
                zeros = tf.zeros(shape, dtype=dtype)
                upper = tf.concat((right_tt_core, zeros), axis=right_rank_dim)
                lower = tf.concat((deltas[i], left_tt_core),
                                  axis=right_rank_dim)
                tangent_core = tf.concat((upper, lower), axis=left_rank_dim)
            cores.append(tangent_core)
        if is_batch_case:
            tangent = TensorTrainBatch(cores, batch_size=batch_size)
        else:
            tangent = TensorTrain(cores)
        tangent.projection_on = tt
        return tangent
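
An illustrative round-trip, assuming `what` and `x` are TT-tensors of the same shape; it uses `tangent_space_to_deltas` from the same module (as in the hessian_vector_product code above) to recover the deltas of a projected tensor:

# Hypothetical setup: `what` and `x` are TensorTrain objects of matching shape.
projected = project(what, x)  # lies in the tangent space at x; projection_on is set to x
deltas = tangent_space_to_deltas(projected)
# Rebuilding from the deltas should reproduce the same tangent-space element.
reconstructed = deltas_to_tangent_space(deltas, x)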
Example 9
def project_sum(what, where, weights=None):
    """Project sum of `what` TTs on the tangent space of `where` TT.

  project_sum(what, x) = P_x(what)
  project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
  project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      projection of the sum of elements in the batch.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    weights: python list or tf.Tensor of numbers or None, weights of the sum

  Returns:
     a TensorTrain with the TT-ranks equal to 2 * tangent_space_tens.get_tt_ranks()

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d r_what r_where n (r_what + r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what max(what.get_tt_rank())
    r_where is the largest TT-rank of where
    n is the size of the axis dimension of what and where e.g.
      for a tensor of size 4 x 4 x 4, n is 4;
      for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
  """
    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    if weights is not None:
        weights = tf.convert_to_tensor(weights, dtype=where.dtype)

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = where.right_tt_rank_dim
    left_rank_dim = where.left_tt_rank_dim
    if weights is not None:
        weights_shape = weights.get_shape()
        output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
    else:
        output_is_batch = False
    output_batch_str = 'o' if output_is_batch else ''
    if output_is_batch:
        right_rank_dim += 1
        left_rank_dim += 1
        output_batch_size = weights.get_shape().as_list()[1]

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            if weights is None:
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
            else:
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, proj_core,
                                        rhs[core_idx + 1])
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if core_idx == ndims - 1:
            if weights is None:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            else:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(extended_left_tang_core,
                                                  [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)
    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
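
A usage sketch with a hypothetical batch of TT-tensors; shapes, ranks, and weights are illustrative:

import t3f

where = t3f.random_tensor((3, 4, 5), tt_rank=2)
batch_what = t3f.random_tensor_batch((3, 4, 5), tt_rank=3, batch_size=4)

# Project the sum (or a weighted sum) of the batch elements on the tangent space at `where`.
proj = t3f.project_sum(batch_what, where)
weighted_proj = t3f.project_sum(batch_what, where, weights=[0.5, 1.0, 1.0, 2.0])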
Example 10
def project_matmul(what, where, matrix):
    """Project `matrix` * `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    matrix: TensorTrain, TT-matrix to multiply by what

  Returns:
     a TensorTrain with the TT-ranks equal to 2 * tangent_space_tens.get_tt_ranks()
      
  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what max(what.get_tt_rank())
    r_where is the largest TT-rank of where
    matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
  """

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    right_rank_dim = what.right_tt_rank_dim
    left_rank_dim = what.left_tt_rank_dim
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                                  right_tang_core, rhs[core_idx + 1],
                                  tens_core)
    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        # TODO: bruteforce order of indices in lhs??
        lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef',
                                      matrix_core, left_tang_core,
                                      lhs[core_idx], tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                                  lhs[core_idx], matrix_core)
            proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                                   lhs[core_idx + 1])
            proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core,
                                  rhs[core_idx + 1])

        if core_idx == ndims - 1:
            # d and e dimensions take one value, since it's the last rank.
            # To make the result shape (?, ?, ?, 1), we are summing d and leaving e,
            # but we could have done the opposite -- sum e and leave d.
            proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx],
                                  matrix_core, tens_core)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            extended_left_tang_core = tf.tile(extended_left_tang_core,
                                              [output_batch_size, 1, 1, 1, 1])
            extended_right_tang_core = tf.tile(extended_right_tang_core,
                                               [output_batch_size, 1, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            mode_size_n = raw_shape[0][core_idx]
            mode_size_m = raw_shape[1][core_idx]
            shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
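
A usage sketch; the TT-matrix `A` and the TT "vectors" `what` and `where` are hypothetical, with a square `A` so that the shapes of `A @ what` and `where` match:

import t3f

A = t3f.random_matrix(((2, 3, 4), (2, 3, 4)), tt_rank=2)
what = t3f.random_matrix(((2, 3, 4), (1, 1, 1)), tt_rank=3)
where = t3f.random_matrix(((2, 3, 4), (1, 1, 1)), tt_rank=2)

# Mathematically equivalent to t3f.project(t3f.matmul(A, what), where), but it
# avoids forming the (potentially high-rank) product A @ what explicitly.
proj = t3f.project_matmul(what, where, A)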
Example 11
def project(what, where):
  """Project `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project

  Returns:
     a TensorTrain with the TT-ranks equal to 2 * tangent_space_tens.get_tt_ranks()
  """

  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(),
                      what.get_raw_shape()))

  if not where.dtype.is_compatible_with(what.dtype):
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.' %
                     (where.dtype,
                      what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  right_rank_dim = 3 if where.is_tt_matrix() else 2
  left_rank_dim = 0
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim = 1
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      if output_is_batch:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      if output_is_batch:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core), axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros), axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                            batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
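
A usage sketch; `what` may be a single TT-tensor or a batch (names, shapes, and ranks are illustrative):

import t3f

where = t3f.random_tensor((3, 4, 5), tt_rank=2)
what = t3f.random_tensor((3, 4, 5), tt_rank=3)
batch_what = t3f.random_tensor_batch((3, 4, 5), tt_rank=3, batch_size=4)

proj = t3f.project(what, where)              # TT-ranks of the result are at most 2 * ranks of `where`
batch_proj = t3f.project(batch_what, where)  # projects each element of the batch separately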