Example #1
File: ops.py Project: vseledkin/t3f
def _full_tt_batch(tt):
    """Converts a TensorTrainBatch into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrainBatch` object.

  Returns:
    tf.Tensor.
  """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    res = tt.tt_cores[0]
    batch_size = shapes.lazy_batch_size(tt)
    for i in range(1, num_dims):
        res = tf.reshape(res, (batch_size, -1, ranks[i]))
        curr_core = tf.reshape(tt.tt_cores[i], (batch_size, ranks[i], -1))
        res = tf.einsum('oqb,obw->oqw', res, curr_core)
    if tt.is_tt_matrix():
        intermediate_shape = [batch_size]
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = [0]
        for i in range(0, 2 * num_dims, 2):
            transpose.append(i + 1)
        for i in range(1, 2 * num_dims, 2):
            transpose.append(i + 1)
        res = tf.transpose(res, transpose)
        return tf.reshape(res, shape)
    else:
        return tf.reshape(res, shape)
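A minimal usage sketch, under the assumption that the public t3f.full wrapper in the TF1-era API dispatches to this helper when given a TensorTrainBatch:

import tensorflow as tf
import t3f

# Hypothetical usage: a batch of 8 random TT-tensors of shape 4 x 5 x 6, TT-rank 3.
tt_batch = t3f.random_tensor_batch((4, 5, 6), tt_rank=3, batch_size=8)
dense = t3f.full(tt_batch)  # tf.Tensor of shape (8, 4, 5, 6)

with tf.Session() as sess:
    print(sess.run(tf.shape(dense)))  # [8 4 5 6]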
Example #2
File: ops.py Project: vseledkin/t3f
def _add_batch_matrix_cores(tt_a, tt_b):
    """Internal function to be called from add for two batches of TT-matrices.

  Does the actual assembling of the TT-cores to add two batches of TT-matrices.
  """
    ndims = tt_a.ndims()
    dtype = tt_a.dtype
    shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    b_ranks = shapes.lazy_tt_ranks(tt_b)
    if isinstance(tt_a, TensorTrainBatch) and tt_a.batch_size == 1:
        # We add a 1-element batch tt_a to a batch_size-element batch tt_b to get
        # a resulting TensorTrainBatch with batch_size == tt_b.batch_size.
        batch_size = shapes.lazy_batch_size(tt_b)
    else:
        batch_size = shapes.lazy_batch_size(tt_a)
    tt_a = shapes.expand_batch_dim(tt_a)
    tt_b = shapes.expand_batch_dim(tt_b)
    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_a.tt_cores[core_idx]
        if tt_a.batch_size == 1:
            a_core = tf.tile(a_core, (batch_size, 1, 1, 1, 1))
        b_core = tt_b.tt_cores[core_idx]
        if tt_b.batch_size == 1:
            b_core = tf.tile(b_core, (batch_size, 1, 1, 1, 1))
        if core_idx == 0:
            curr_core = tf.concat((a_core, b_core), axis=4)
        elif core_idx == ndims - 1:
            curr_core = tf.concat((a_core, b_core), axis=1)
        else:
            upper_zeros = tf.zeros(
                (batch_size, a_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], b_ranks[core_idx + 1]), dtype)
            lower_zeros = tf.zeros(
                (batch_size, b_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], a_ranks[core_idx + 1]), dtype)
            upper = tf.concat((a_core, upper_zeros), axis=4)
            lower = tf.concat((lower_zeros, b_core), axis=4)
            curr_core = tf.concat((upper, lower), axis=1)
        tt_cores.append(curr_core)
    return tt_cores, batch_size
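A usage sketch, assuming the public t3f.add calls this core-assembly routine when both arguments are batches of TT-matrices (the intermediate TT-ranks add up):

import t3f

# Hypothetical usage: two batches of 4 random 9 x 16 TT-matrices, TT-rank 2.
a = t3f.random_matrix_batch(((3, 3), (4, 4)), tt_rank=2, batch_size=4)
b = t3f.random_matrix_batch(((3, 3), (4, 4)), tt_rank=2, batch_size=4)
c = t3f.add(a, b)           # TensorTrainBatch of elementwise sums
print(c.get_tt_ranks())     # intermediate ranks add: [1, 4, 1]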
Example #3
File: ops.py Project: towadroid/t3f
def frobenius_norm_squared(tt, differentiable=False,
                           name='t3f_frobenius_norm_squared'):
  """Frobenius norm squared of `TensorTrain` or of each TT in `TensorTrainBatch`.

  Frobenius norm squared is the sum of squares of all elements in a tensor.

  Args:
    tt: `TensorTrain` or `TensorTrainBatch` object
    differentiable: bool, whether to use a differentiable implementation
      or a fast and stable implementation based on QR decomposition.
    name: string, name of the Op.

  Returns:
    a number which is the Frobenius norm squared of `tt`, if it is `TensorTrain`
    OR
    a Tensor of size tt.batch_size, consisting of the Frobenius norms squared of
    each TensorTrain in `tt`, if it is `TensorTrainBatch`
  """
  with tf.name_scope(name, values=tt.tt_cores):
    if differentiable:
      if hasattr(tt, 'batch_size'):
          bs_str = 'n'
      else:
          bs_str = ''
      if tt.is_tt_matrix():
        running_prod = tf.einsum('{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])
      else:
        running_prod = tf.einsum('{0}aib,{0}cid->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])

      for core_idx in range(1, tt.ndims()):
        curr_core = tt.tt_cores[core_idx]
        if tt.is_tt_matrix():
          running_prod = tf.einsum('{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)
        else:
          running_prod = tf.einsum('{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)

      return tf.squeeze(running_prod, [-1, -2])

    else:
      orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
      # All the cores of orth_tt except the last one are orthogonal, hence
      # the Frobenius norm of orth_tt equals the norm of the last core.
      if hasattr(tt, 'batch_size'):
        batch_size = shapes.lazy_batch_size(tt)
        last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
        return tf.norm(last_core, axis=1) ** 2
      else:
        return tf.norm(orth_tt.tt_cores[-1]) ** 2
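A usage sketch assuming the public t3f.frobenius_norm_squared wrapper; for a TensorTrainBatch both code paths return one squared norm per batch element:

import tensorflow as tf
import t3f

# Hypothetical usage with a batch of 5 random TT-tensors.
tt_batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=3, batch_size=5)
norms_sq = t3f.frobenius_norm_squared(tt_batch)                      # QR-based path
norms_sq_diff = t3f.frobenius_norm_squared(tt_batch, differentiable=True)

with tf.Session() as sess:
    qr_vals, diff_vals = sess.run([norms_sq, norms_sq_diff])
    print(qr_vals.shape)  # (5,) -- one squared norm per TT in the batch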
Example #4
File: ops.py Project: vseledkin/t3f
def frobenius_norm_squared(tt, differentiable=False):
    """Frobenius norm squared of a TensorTrain (sum of squares of all elements).

  Args:
    tt: `TensorTrain` object
    differentiable: bool, whether to use a differentiable implementation
      or a fast and stable implementation based on QR decomposition.

  Returns:
    a number, the sum of squares of all elements in `tt`
  """
    if differentiable:
        if tt.is_tt_matrix():
            running_prod = tf.einsum('aijb,cijd->bd', tt.tt_cores[0],
                                     tt.tt_cores[0])
        else:
            running_prod = tf.einsum('aib,cid->bd', tt.tt_cores[0],
                                     tt.tt_cores[0])

        for core_idx in range(1, tt.ndims()):
            curr_core = tt.tt_cores[core_idx]
            if tt.is_tt_matrix():
                running_prod = tf.einsum('ac,aijb,cijd->bd', running_prod,
                                         curr_core, curr_core)
            else:
                running_prod = tf.einsum('ac,aib,cid->bd', running_prod,
                                         curr_core, curr_core)
        return running_prod[0, 0]
    else:
        orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
        # All the cores of orth_tt except the last one are orthogonal, hence
        # the Frobenius norm of orth_tt equals the norm of the last core.
        if hasattr(tt, 'batch_size'):
            batch_size = shapes.lazy_batch_size(tt)
            last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
            return tf.norm(last_core, axis=1)**2
        else:
            return tf.norm(orth_tt.tt_cores[-1])**2
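A sanity-check sketch for the single-TensorTrain case (assumed setup): the squared TT norm should agree with the Frobenius norm of the dense tensor obtained via t3f.full.

import tensorflow as tf
import t3f

# Hypothetical check of the TT norm against the dense norm.
tt = t3f.random_tensor((3, 4, 5), tt_rank=2)
tt_norm_sq = t3f.frobenius_norm_squared(tt)
dense_norm_sq = tf.norm(t3f.full(tt)) ** 2

with tf.Session() as sess:
    print(sess.run([tt_norm_sq, dense_norm_sq]))  # should match up to float error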
Example #5
def _orthogonalize_batch_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a batch TT-object in the left to right order.

  Args:
    tt: TensorTrainBatch.

  Returns:
    TensorTrainBatch
  """
    # Left to right orthogonalization.
    ndims = tt.ndims()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    next_rank = tt_ranks[0]
    batch_size = shapes.lazy_batch_size(tt)

    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = next_rank
        next_rank = tt_ranks[core_idx + 1]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (batch_size, curr_rank * curr_mode, next_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(curr_core)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        next_rank = triang_shape[1]
        if tt.is_tt_matrix():
            new_core_shape = (batch_size, curr_rank, curr_mode_left,
                              curr_mode_right, next_rank)
        else:
            new_core_shape = (batch_size, curr_rank, curr_mode, next_rank)

        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        next_core = tf.reshape(tt_cores[core_idx + 1],
                               (batch_size, triang_shape[2], -1))
        tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

    if tt.is_tt_matrix():
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1],
                           raw_shape[1][-1], 1)
    else:
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1], 1)
    tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrainBatch(tt_cores,
                            tt.get_raw_shape(),
                            batch_size=batch_size)
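A usage sketch, assuming t3f.orthogonalize_tt_cores is exposed at the package level and dispatches here for TensorTrainBatch inputs; orthogonalization changes the cores but not the tensors they represent:

import tensorflow as tf
import t3f

# Hypothetical usage: orthogonalize a batch and verify the tensors are unchanged.
tt_batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=3, batch_size=2)
orth = t3f.orthogonalize_tt_cores(tt_batch)   # left-to-right by default
diff = tf.norm(t3f.full(tt_batch) - t3f.full(orth))

with tf.Session() as sess:
    print(sess.run(diff))  # ~0: the represented tensors are identical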
Example #6
def _round_batch_tt(tt, max_tt_rank, epsilon):
    """Internal function that rounds a TensorTrainBatch.

  See t3f.round for details.
  """
    ndims = tt.ndims()
    max_tt_rank = np.array(max_tt_rank).astype(np.int32)
    if max_tt_rank < 1:
        raise ValueError('Maximum TT-rank should be greater or equal to 1.')
    if epsilon is not None and epsilon < 0:
        raise ValueError('Epsilon should be non-negative.')
    if max_tt_rank.size == 1:
        max_tt_rank = (max_tt_rank * np.ones(ndims + 1)).astype(np.int32)
    elif max_tt_rank.size != ndims + 1:
        raise ValueError(
            'max_tt_rank should be a number or a vector of size (d+1) '
            'where d is the number of dimensions (rank) of the tensor.')
    raw_shape = shapes.lazy_raw_shape(tt)
    batch_size = shapes.lazy_batch_size(tt)

    tt_cores = orthogonalize_tt_cores(tt).tt_cores
    # Copy cores references so we can change the cores.
    tt_cores = list(tt_cores)

    ranks = [1] * (ndims + 1)
    are_tt_ranks_defined = True
    # Right to left SVD compression.
    for core_idx in range(ndims - 1, 0, -1):
        curr_core = tt_cores[core_idx]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        columns = curr_mode * ranks[core_idx + 1]
        curr_core = tf.reshape(curr_core, (batch_size, -1, columns))
        rows = curr_core.get_shape()[1].value
        if rows is None:
            rows = tf.shape(curr_core)[1]
        if max_tt_rank[core_idx] == 1:
            ranks[core_idx] = 1
        else:
            try:
                ranks[core_idx] = min(max_tt_rank[core_idx], rows, columns)
            except TypeError:
                # Some of the values are undefined at graph-construction time and
                # are therefore tf.Tensors rather than concrete values.
                min_dim = tf.minimum(rows, columns)
                ranks[core_idx] = tf.minimum(max_tt_rank[core_idx], min_dim)
                are_tt_ranks_defined = False
        s, u, v = tf.svd(curr_core, full_matrices=False)
        u = u[:, :, 0:ranks[core_idx]]
        s = s[:, 0:ranks[core_idx]]
        v = v[:, :, 0:ranks[core_idx]]
        if tt.is_tt_matrix():
            core_shape = (batch_size, ranks[core_idx], curr_mode_left,
                          curr_mode_right, ranks[core_idx + 1])
        else:
            core_shape = (batch_size, ranks[core_idx], curr_mode,
                          ranks[core_idx + 1])
        tt_cores[core_idx] = tf.reshape(tf.transpose(v, (0, 2, 1)), core_shape)
        prev_core_shape = (batch_size, -1, rows)
        tt_cores[core_idx - 1] = tf.reshape(tt_cores[core_idx - 1],
                                            prev_core_shape)
        tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1], u)
        tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1],
                                           tf.matrix_diag(s))

    if tt.is_tt_matrix():
        core_shape = (batch_size, ranks[0], raw_shape[0][0], raw_shape[1][0],
                      ranks[1])
    else:
        core_shape = (batch_size, ranks[0], raw_shape[0][0], ranks[1])
    tt_cores[0] = tf.reshape(tt_cores[0], core_shape)
    if not are_tt_ranks_defined:
        ranks = None
    return TensorTrainBatch(tt_cores,
                            tt.get_raw_shape(),
                            ranks,
                            batch_size=tt.batch_size)
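A usage sketch, assuming the public t3f.round wrapper calls this routine for batches; rounding truncates the TT-ranks of every element of the batch to at most max_tt_rank:

import tensorflow as tf
import t3f

# Hypothetical usage: truncate a rank-4 batch down to TT-rank 2.
tt_batch = t3f.random_tensor_batch((5, 5, 5), tt_rank=4, batch_size=3)
rounded = t3f.round(tt_batch, max_tt_rank=2)
print(rounded.get_tt_ranks())  # at most [1, 2, 2, 1]

with tf.Session() as sess:
    err = sess.run(tf.norm(t3f.full(tt_batch) - t3f.full(rounded)))
    print(err)  # approximation error introduced by the truncation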
Example #7
File: ops.py Project: vseledkin/t3f
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
    """Multiplies two TT-matrices and returns the TT-matrix of the result.

  Args:
    tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size M x N
    tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size N x P

  Returns:
    `TensorTrain` object containing a TT-matrix of size M x P if both arguments
      are `TensorTrain`s
    `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`

  Raises:
    ValueError if the arguments are not TT-matrices or if their sizes are not
    appropriate for a matrix-by-matrix multiplication.
  """
    # Both TensorTrain and TensorTrainBatch are inherited from TensorTrainBase.
    if not isinstance(tt_matrix_a, TensorTrainBase) or \
        not isinstance(tt_matrix_b, TensorTrainBase) or \
        not tt_matrix_a.is_tt_matrix() or \
        not tt_matrix_b.is_tt_matrix():
        raise ValueError('Arguments should be TT-matrices')

    if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
        raise ValueError(
            'The batch sizes are different and not 1, broadcasting is '
            'not available.')

    ndims = tt_matrix_a.ndims()
    if tt_matrix_b.ndims() != ndims:
        raise ValueError(
            'Arguments should have the same number of dimensions, '
            'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

    # Convert batch-size-1 batches into TT objects to simplify broadcasting.
    tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
    tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
    is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
    is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
    is_res_batch = is_a_batch or is_b_batch
    a_batch_str = 'o' if is_a_batch else ''
    b_batch_str = 'o' if is_b_batch else ''
    res_batch_str = 'o' if is_res_batch else ''
    einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                  res_batch_str)
    result_cores = []
    # TODO: name the operation and the resulting tensor.
    a_shape = shapes.lazy_raw_shape(tt_matrix_a)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    b_shape = shapes.lazy_raw_shape(tt_matrix_b)
    b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
    if is_res_batch:
        if is_a_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_a)
        if is_b_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_b)
    for core_idx in range(ndims):
        a_core = tt_matrix_a.tt_cores[core_idx]
        b_core = tt_matrix_b.tt_cores[core_idx]
        curr_res_core = tf.einsum(einsum_str, a_core, b_core)

        res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
        res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
        left_mode = a_shape[0][core_idx]
        right_mode = b_shape[1][core_idx]
        if is_res_batch:
            core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                          res_right_rank)
        else:
            core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
        curr_res_core = tf.reshape(curr_res_core, core_shape)
        result_cores.append(curr_res_core)

    res_shape = (tt_matrix_a.get_raw_shape()[0],
                 tt_matrix_b.get_raw_shape()[1])
    static_a_ranks = tt_matrix_a.get_tt_ranks()
    static_b_ranks = tt_matrix_b.get_tt_ranks()
    out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
    if is_res_batch:
        return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
    else:
        return TensorTrain(result_cores, res_shape, out_ranks)
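A usage sketch via the public t3f.matmul, which presumably dispatches to tt_tt_matmul when both arguments are TT-matrices; the output TT-ranks are products of the input ranks:

import tensorflow as tf
import t3f

# Hypothetical usage: A is 9 x 16 with raw shape (3, 3) x (4, 4), B is 16 x 25.
a = t3f.random_matrix(((3, 3), (4, 4)), tt_rank=2)
b = t3f.random_matrix(((4, 4), (5, 5)), tt_rank=3)
c = t3f.matmul(a, b)        # TT-matrix of size 9 x 25
print(c.get_tt_ranks())     # ranks multiply: [1, 6, 1]

with tf.Session() as sess:
    print(sess.run(t3f.full(c)).shape)  # (9, 25)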
Example #8
def project_sum(what, where, weights=None):
    """Project sum of `what` TTs on the tangent space of `where` TT.

  project_sum(what, x) = P_x(what)
  project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
  project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      projection of the sum of elements in the batch.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    weights: python list or tf.Tensor of numbers or None, weights of the sum

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d r_what r_where n (r_what + r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what max(what.get_tt_rank())
    r_where is the largest TT-rank of where
    n is the size of the axis dimension of what and where e.g.
      for a tensor of size 4 x 4 x 4, n is 4;
      for a 27 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
  """
    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    if weights is not None:
        weights = tf.convert_to_tensor(weights, dtype=where.dtype)

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = where.right_tt_rank_dim
    left_rank_dim = where.left_tt_rank_dim
    if weights is not None:
        weights_shape = weights.get_shape()
        output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
    else:
        output_is_batch = False
    output_batch_str = 'o' if output_is_batch else ''
    if output_is_batch:
        right_rank_dim += 1
        left_rank_dim += 1
        output_batch_size = weights.get_shape().as_list()[1]

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            if weights is None:
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
            else:
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, proj_core,
                                        rhs[core_idx + 1])
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if core_idx == ndims - 1:
            if weights is None:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            else:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(extended_left_tang_core,
                                                  [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)
    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
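A usage sketch with assumed shapes, via the public t3f.project_sum: projecting the (possibly weighted) sum of a batch of TT-tensors onto the tangent space at a TT-tensor x gives a single TensorTrain whose TT-ranks are about twice those of x.

import t3f

# Hypothetical usage: project the sum of a batch onto the tangent space at x.
x = t3f.random_tensor((4, 4, 4), tt_rank=3)
batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=2, batch_size=5)

proj = t3f.project_sum(batch, x)                         # P_x(sum_i batch[i])
weighted = t3f.project_sum(batch, x, weights=[0.5] * 5)  # P_x(sum_i 0.5 * batch[i])
print(proj.get_tt_ranks())  # roughly 2 * the TT-ranks of x, e.g. [1, 6, 6, 1]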
Example #9
def project_matmul(what, where, matrix):
    """Project `matrix` * `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    matrix: TensorTrain, TT-matrix to multiply by what

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()
      
  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what max(what.get_tt_rank())
    r_where is the largest TT-rank of where
    matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
  """

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    right_rank_dim = what.right_tt_rank_dim
    left_rank_dim = what.left_tt_rank_dim
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                                  right_tang_core, rhs[core_idx + 1],
                                  tens_core)
    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        # TODO: brute-force order of indices in lhs??
        lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef',
                                      matrix_core, left_tang_core,
                                      lhs[core_idx], tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                                  lhs[core_idx], matrix_core)
            proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                                   lhs[core_idx + 1])
            proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core,
                                  rhs[core_idx + 1])

        if core_idx == ndims - 1:
            # d and e dimensions take 1 value, since it's the last rank.
            # To make the result shape (?, ?, ?, 1), we are summing d and leaving e,
            # but we could have done the opposite -- sum e and leave d.
            proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx],
                                  matrix_core, tens_core)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            extended_left_tang_core = tf.tile(extended_left_tang_core,
                                              [output_batch_size, 1, 1, 1, 1])
            extended_right_tang_core = tf.tile(extended_right_tang_core,
                                               [output_batch_size, 1, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            mode_size_n = raw_shape[0][core_idx]
            mode_size_m = raw_shape[1][core_idx]
            shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
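A usage sketch with assumed shapes, via the public t3f.project_matmul: here what and where are TT "vectors" encoded as N x 1 TT-matrices and A is an N x N TT-matrix, so the batch of products A * what[i] is projected onto the tangent space at x without ever being formed explicitly.

import t3f

# Hypothetical usage: project A * what[i] for each element of a batch.
x = t3f.random_matrix(((4, 4, 4), (1, 1, 1)), tt_rank=3)        # tangent-space point
what = t3f.random_matrix_batch(((4, 4, 4), (1, 1, 1)), tt_rank=2, batch_size=3)
A = t3f.random_matrix(((4, 4, 4), (4, 4, 4)), tt_rank=2)

proj = t3f.project_matmul(what, x, A)   # TensorTrainBatch with batch_size 3
print(proj.get_tt_ranks())              # roughly 2 * the TT-ranks of x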
Example #10
File: riemannian.py Project: vseledkin/t3f
def project(what, where):
  """Project `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()
  """

  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(),
                      what.get_raw_shape()))

  if not where.dtype.is_compatible_with(what.dtype):
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.' %
                     (where.dtype,
                      what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  right_rank_dim = 3 if where.is_tt_matrix() else 2
  left_rank_dim = 0
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim = 1
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      if output_is_batch:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      if output_is_batch:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core), axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros), axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                            batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
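A usage sketch with assumed shapes: t3f.project applied to a TensorTrainBatch returns the batch of individual projections P_x(batch[i]), each with TT-ranks about twice those of x.

import t3f

# Hypothetical usage: project every TT in a batch onto the tangent space at x.
x = t3f.random_tensor((4, 4, 4), tt_rank=3)
batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=2, batch_size=7)

proj = t3f.project(batch, x)   # TensorTrainBatch: P_x(batch[i]) for each i
print(proj.get_tt_ranks())     # roughly 2 * the TT-ranks of x, e.g. [1, 6, 6, 1]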
Example #11
File: ops.py Project: zhanglang1860/t3f
def multiply(tt_left, right):
    """Returns a TensorTrain corresponding to element-wise product tt_left * right.

  Supports broadcasting:
    multiply(TensorTrainBatch, TensorTrain) returns TensorTrainBatch consisting
    of element-wise products of TT in TensorTrainBatch and TensorTrain

    multiply(TensorTrainBatch_a, TensorTrainBatch_b) returns TensorTrainBatch
    consisting of element-wise products of TT in TensorTrainBatch_a and
    TT in TensorTrainBatch_b

    Batch sizes should support broadcasting
  Args:
    tt_left: `TensorTrain` OR `TensorTrainBatch`
    right: `TensorTrain` OR `TensorTrainBatch` OR a number.

  Returns:
    a `TensorTrain` or `TensorTrainBatch` object corresponding to the
    element-wise product of the arguments.

  Raises:
    ValueError if the arguments' shapes do not coincide or broadcasting is not
    possible.
  """

    is_left_batch = isinstance(tt_left, TensorTrainBatch)
    is_right_batch = isinstance(right, TensorTrainBatch)

    is_batch_case = is_left_batch or is_right_batch
    ndims = tt_left.ndims()
    if not isinstance(right, TensorTrainBase):
        # Assume right is a number, not TensorTrain.
        # To squash right uniformly across TT-cores we pull its absolute value
        # and raise to the power 1/ndims. First TT-core is multiplied by the sign
        # of right.
        tt_cores = list(tt_left.tt_cores)
        fact = tf.pow(tf.cast(tf.abs(right), tt_left.dtype), 1.0 / ndims)
        sign = tf.cast(tf.sign(right), tt_left.dtype)
        for i in range(len(tt_cores)):
            tt_cores[i] = fact * tt_cores[i]

        tt_cores[0] = tt_cores[0] * sign
        out_ranks = tt_left.get_tt_ranks()
        if is_left_batch:
            out_batch_size = tt_left.batch_size
    else:

        if tt_left.is_tt_matrix() != right.is_tt_matrix():
            raise ValueError('The arguments should be both TT-tensors or both '
                             'TT-matrices')

        if tt_left.get_raw_shape() != right.get_raw_shape():
            raise ValueError('The arguments should have the same shape.')

        out_batch_size = 1
        dependencies = []
        can_determine_if_broadcast = True
        if is_left_batch and is_right_batch:
            if tt_left.batch_size is None and right.batch_size is None:
                can_determine_if_broadcast = False
            elif tt_left.batch_size is None and right.batch_size is not None:
                if right.batch_size > 1:
                    can_determine_if_broadcast = False
            elif tt_left.batch_size is not None and right.batch_size is None:
                if tt_left.batch_size > 1:
                    can_determine_if_broadcast = False

        if not can_determine_if_broadcast:
            # Cannot determine if broadcasting is needed. Avoid broadcasting and
            # assume elementwise multiplication AND add an execution-time assert to print
            # a better error message if the batch sizes turn out to be different.

            message = (
                'The batch sizes were unknown on compilation stage, so '
                'assumed elementwise multiplication (i.e. no broadcasting). '
                'Now it seems that they are different after all :')

            data = [
                message,
                shapes.lazy_batch_size(tt_left), ' x ',
                shapes.lazy_batch_size(right)
            ]
            bs_eq = tf.assert_equal(shapes.lazy_batch_size(tt_left),
                                    shapes.lazy_batch_size(right),
                                    data=data)

            dependencies.append(bs_eq)

        do_broadcast = shapes.is_batch_broadcasting_possible(tt_left, right)
        if not can_determine_if_broadcast:
            # Assume elementwise multiplication if broadcasting cannot be determined
            # on compilation stage.
            do_broadcast = False
        if not do_broadcast and can_determine_if_broadcast:
            raise ValueError(
                'The batch sizes are different and not 1, broadcasting '
                'is not available.')

        a_ranks = shapes.lazy_tt_ranks(tt_left)
        b_ranks = shapes.lazy_tt_ranks(right)
        shape = shapes.lazy_raw_shape(tt_left)

        output_str = ''
        bs_str_left = ''
        bs_str_right = ''

        if is_batch_case:
            if is_left_batch and is_right_batch:
                # Both arguments are batches of equal size.
                if tt_left.batch_size == right.batch_size or not can_determine_if_broadcast:
                    bs_str_left = 'n'
                    bs_str_right = 'n'
                    output_str = 'n'
                    if not can_determine_if_broadcast:
                        out_batch_size = None
                    else:
                        out_batch_size = tt_left.batch_size
                else:
                    # Broadcasting (e.g batch_sizes are 1 and n>1).
                    bs_str_left = 'n'
                    bs_str_right = 'm'
                    output_str = 'nm'
                    if tt_left.batch_size is None or tt_left.batch_size > 1:
                        out_batch_size = tt_left.batch_size
                    else:
                        out_batch_size = right.batch_size
            else:
                # One of the arguments is TensorTrain.
                if is_left_batch:
                    bs_str_left = 'n'
                    bs_str_right = ''
                    out_batch_size = tt_left.batch_size
                else:
                    bs_str_left = ''
                    bs_str_right = 'n'
                    out_batch_size = right.batch_size
                output_str = 'n'

        is_matrix = tt_left.is_tt_matrix()
        tt_cores = []

        for core_idx in range(ndims):
            a_core = tt_left.tt_cores[core_idx]
            b_core = right.tt_cores[core_idx]
            left_rank = a_ranks[core_idx] * b_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
            if is_matrix:
                with tf.control_dependencies(dependencies):
                    curr_core = tf.einsum(
                        '{0}aijb,{1}cijd->{2}acijbd'.format(
                            bs_str_left, bs_str_right, output_str), a_core,
                        b_core)
                    curr_core = tf.reshape(curr_core,
                                           (-1, left_rank, shape[0][core_idx],
                                            shape[1][core_idx], right_rank))
                    if not is_batch_case:
                        curr_core = tf.squeeze(curr_core, axis=0)
            else:
                with tf.control_dependencies(dependencies):
                    curr_core = tf.einsum(
                        '{0}aib,{1}cid->{2}acibd'.format(
                            bs_str_left, bs_str_right, output_str), a_core,
                        b_core)
                    curr_core = tf.reshape(
                        curr_core,
                        (-1, left_rank, shape[0][core_idx], right_rank))
                    if not is_batch_case:
                        curr_core = tf.squeeze(curr_core, axis=0)

            tt_cores.append(curr_core)

        combined_ranks = zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
        out_ranks = [a * b for a, b in combined_ranks]

    if not is_batch_case:
        return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
    else:
        return TensorTrainBatch(tt_cores,
                                tt_left.get_raw_shape(),
                                out_ranks,
                                batch_size=out_batch_size)
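A usage sketch of the broadcasting behaviour described in the docstring (shapes and names are assumptions):

import t3f

# Hypothetical usage: multiply a batch by a single TT and by a plain number.
a = t3f.random_tensor_batch((3, 4, 5), tt_rank=2, batch_size=6)
b = t3f.random_tensor((3, 4, 5), tt_rank=3)

prod_batch = t3f.multiply(a, b)    # TensorTrainBatch of 6 elementwise products
scaled = t3f.multiply(b, 2.0)      # multiplying by a number is also supported
print(prod_batch.get_tt_ranks())   # ranks multiply: [1, 6, 6, 1]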