예제 #1
0
파일: ops.py 프로젝트: vseledkin/t3f
def _add_matrix_cores(tt_a, tt_b):
    """Internal function to be called from add for two TT-matrices.

    Does the actual assembling of the TT-cores to add two TT-matrices.

    The first core of the sum is the concatenation of the summands' first
    cores along the right TT-rank axis, the last core is their concatenation
    along the left TT-rank axis, and every middle core is the block-diagonal
    diag(a_core, b_core) over the two TT-rank axes. With a single core
    (ndims == 1) both boundary TT-ranks must stay 1, so the single core of
    the sum is simply a_core + b_core.

    Args:
        tt_a: TT-matrix.
        tt_b: TT-matrix whose mode sizes match those of `tt_a` (required by
            the concatenations below).

    Returns:
        A list of `ndims` tf.Tensors: the TT-cores of tt_a + tt_b.
    """
    ndims = tt_a.ndims()
    dtype = tt_a.dtype
    shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    b_ranks = shapes.lazy_tt_ranks(tt_b)
    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_a.tt_cores[core_idx]
        b_core = tt_b.tt_cores[core_idx]
        if ndims == 1:
            # Concatenating would yield a boundary TT-rank of 2, which is not
            # a valid TT-matrix; a lone core is both first and last, so the
            # cores are added elementwise instead.
            curr_core = a_core + b_core
        elif core_idx == 0:
            # Cores are laid out as (left_rank, row_mode, col_mode, right_rank).
            curr_core = tf.concat((a_core, b_core), axis=3)
        elif core_idx == ndims - 1:
            curr_core = tf.concat((a_core, b_core), axis=0)
        else:
            # Build the block-diagonal core [[a_core, 0], [0, b_core]].
            upper_zeros = tf.zeros((a_ranks[core_idx], shape[0][core_idx],
                                    shape[1][core_idx], b_ranks[core_idx + 1]),
                                   dtype)
            lower_zeros = tf.zeros((b_ranks[core_idx], shape[0][core_idx],
                                    shape[1][core_idx], a_ranks[core_idx + 1]),
                                   dtype)
            upper = tf.concat((a_core, upper_zeros), axis=3)
            lower = tf.concat((lower_zeros, b_core), axis=3)
            curr_core = tf.concat((upper, lower), axis=0)
        tt_cores.append(curr_core)
    return tt_cores
예제 #2
0
파일: ops.py 프로젝트: vseledkin/t3f
def multiply(tt_left, right):
    """Returns a TensorTrain corresponding to element-wise product tt_left * right.

    The shapes of tt_left and right should coincide.

    Either argument may be a `TensorTrainBatch`. A batch of size 1 is
    broadcast against the other argument; otherwise two batches must have
    the same batch size. The result is a `TensorTrainBatch` whenever any of
    the arguments is a `TensorTrainBatch`.

    Args:
        tt_left: `TensorTrain` or `TensorTrainBatch`, TT-tensor or TT-matrix
        right: `TensorTrain` or `TensorTrainBatch`, TT-tensor or TT-matrix,
            OR a number.

    Returns:
        a `TensorTrain` (or `TensorTrainBatch`) object corresponding to the
        element-wise product of the arguments. Its TT-ranks are the products
        of the arguments' TT-ranks (the TT-ranks of `tt_left` when `right`
        is a number).

    Raises:
        ValueError if one argument is a TT-matrix and the other a TT-tensor,
        if the arguments raw shapes do not coincide, or if both are batches
        of different sizes and neither size is 1.
    """
    if not isinstance(right, TensorTrainBase):
        # Assume right is a number, not TensorTrain: scaling one core scales
        # the whole represented tensor.
        tt_cores = list(tt_left.tt_cores)
        tt_cores[0] = right * tt_cores[0]
        out_ranks = tt_left.get_tt_ranks()
        if isinstance(tt_left, TensorTrain):
            return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
        return TensorTrainBatch(tt_cores, tt_left.get_raw_shape(), out_ranks,
                                tt_left.batch_size)

    if tt_left.is_tt_matrix() != right.is_tt_matrix():
        raise ValueError('The arguments should be both TT-tensors or both '
                         'TT-matrices')

    # Compare raw shapes: get_shape() of a TensorTrainBatch includes the batch
    # dimension, and two TT-matrices of equal overall size can still have
    # different per-core mode sizes, which the einsum below cannot combine.
    if tt_left.get_raw_shape() != right.get_raw_shape():
        raise ValueError('The arguments should have the same shape.')

    if not shapes.is_batch_broadcasting_possible(tt_left, right):
        raise ValueError(
            'The batch sizes are different and not 1, broadcasting is '
            'not available.')

    was_batch = (isinstance(tt_left, TensorTrainBatch) or
                 isinstance(right, TensorTrainBatch))
    # Convert batch-size-1 batches into TT objects to simplify broadcasting.
    tt_left = shapes.squeeze_batch_dim(tt_left)
    right = shapes.squeeze_batch_dim(right)
    is_left_batch = isinstance(tt_left, TensorTrainBatch)
    is_right_batch = isinstance(right, TensorTrainBatch)
    is_res_batch = is_left_batch or is_right_batch

    ndims = tt_left.ndims()
    a_ranks = shapes.lazy_tt_ranks(tt_left)
    b_ranks = shapes.lazy_tt_ranks(right)
    shape = shapes.lazy_raw_shape(tt_left)
    is_matrix = tt_left.is_tt_matrix()

    # Without batches this is 'aijb,cijd->acijbd' (TT-matrix) or
    # 'aib,cid->acibd' (TT-tensor); 'o' marks the batch axis.
    mode_str = 'ij' if is_matrix else 'i'
    einsum_str = '{0}a{3}b,{1}c{3}d->{2}ac{3}bd'.format(
        'o' if is_left_batch else '', 'o' if is_right_batch else '',
        'o' if is_res_batch else '', mode_str)
    if is_res_batch:
        batch_source = tt_left if is_left_batch else right
        lazy_batch_size = shapes.lazy_batch_size(batch_source)
        static_batch_size = batch_source.batch_size

    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_left.tt_cores[core_idx]
        b_core = right.tt_cores[core_idx]
        # Each core of the product is the Kronecker product of the two cores
        # over the TT-rank axes, so the TT-ranks multiply.
        left_rank = a_ranks[core_idx] * b_ranks[core_idx]
        right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
        curr_core = tf.einsum(einsum_str, a_core, b_core)
        if is_matrix:
            core_shape = (left_rank, shape[0][core_idx], shape[1][core_idx],
                          right_rank)
        else:
            core_shape = (left_rank, shape[0][core_idx], right_rank)
        if is_res_batch:
            core_shape = (lazy_batch_size,) + core_shape
        tt_cores.append(tf.reshape(curr_core, core_shape))

    combined_ranks = zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
    out_ranks = [a * b for a, b in combined_ranks]
    raw_shape = tt_left.get_raw_shape()
    if is_res_batch:
        return TensorTrainBatch(tt_cores, raw_shape, out_ranks,
                                static_batch_size)
    result = TensorTrain(tt_cores, raw_shape, out_ranks)
    if was_batch:
        # At least one argument was a batch of size 1 (squeezed above); keep
        # the batch type in the output.
        result = shapes.expand_batch_dim(result)
    return result
예제 #3
0
def _orthogonalize_tt_cores_right_to_left(tt):
    """Orthogonalize TT-cores of a TT-object in the right to left order.

    Sweeps from the last core down to the second one. Each core is reshaped
    into a prev_rank x (mode * curr_rank) matrix and replaced by the
    transposed Q factor of the QR decomposition of its transpose, so its rows
    become orthonormal. The transposed R factor is multiplied into the core to
    its left.

    Args:
        tt: `TensorTrain` (TT-tensor or TT-matrix). Every core is reshaped
            without a batch axis, so this implementation is written for
            non-batch objects only.

    Returns:
        A `TensorTrain` with the same raw shape as `tt`. Its TT-ranks are not
        passed explicitly and may be smaller than those of `tt`.
    """
    # Right to left orthogonalization.
    ndims = tt.ndims()
    is_matrix = tt.is_tt_matrix()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    prev_rank = tt_ranks[ndims]
    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1, 0, -1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = prev_rank
        prev_rank = tt_ranks[core_idx]
        if is_matrix:
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (prev_rank, curr_mode * curr_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(tf.transpose(curr_core))
        curr_core = tf.transpose(curr_core)
        triang = tf.transpose(triang)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank to the left of this core could shrink: if the reshaped
        # core is e.g. 10 x 4, its transpose is 4 x 10, whose (reduced) QR
        # gives q of size 4 x 4 and r of size 4 x 10. After transposing back
        # the core is 4 x 4 and triang is 10 x 4, so the rank becomes 4.
        prev_rank = triang_shape[1]
        if is_matrix:
            new_core_shape = (prev_rank, curr_mode_left, curr_mode_right,
                              curr_rank)
        else:
            new_core_shape = (prev_rank, curr_mode, curr_rank)
        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        # Absorb the triangular factor into the core on the left.
        prev_core = tf.reshape(tt_cores[core_idx - 1], (-1, triang_shape[0]))
        tt_cores[core_idx - 1] = tf.matmul(prev_core, triang)

    if is_matrix:
        first_core_shape = (1, raw_shape[0][0], raw_shape[1][0], prev_rank)
    else:
        first_core_shape = (1, raw_shape[0][0], prev_rank)
    tt_cores[0] = tf.reshape(tt_cores[0], first_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrain(tt_cores, tt.get_raw_shape())
예제 #4
0
파일: ops.py 프로젝트: vseledkin/t3f
def _full_tt(tt):
    """Assemble the dense tf.Tensor represented by a (non-batch) TensorTrain.

    Args:
        tt: `TensorTrain` object (TT-tensor or TT-matrix).

    Returns:
        tf.Tensor reshaped to `shapes.lazy_shape(tt)`.
    """
    n_cores = tt.ndims()
    tt_ranks = shapes.lazy_tt_ranks(tt)
    full_shape = shapes.lazy_shape(tt)
    raw = shapes.lazy_raw_shape(tt)

    # Contract the cores left to right. Before multiplying by core k, the
    # accumulated product is flattened into a matrix whose columns run over
    # the k-th TT-rank.
    acc = tt.tt_cores[0]
    for k in range(1, n_cores):
        lhs = tf.reshape(acc, (-1, tt_ranks[k]))
        rhs = tf.reshape(tt.tt_cores[k], (tt_ranks[k], -1))
        acc = tf.matmul(lhs, rhs)

    if tt.is_tt_matrix():
        # The contraction leaves the modes interleaved as (i0, j0, i1, j1, ...);
        # move all row modes in front of all column modes.
        interleaved = [dim for k in range(n_cores)
                       for dim in (raw[0][k], raw[1][k])]
        acc = tf.reshape(acc, interleaved)
        perm = list(range(0, 2 * n_cores, 2)) + list(range(1, 2 * n_cores, 2))
        acc = tf.transpose(acc, perm)
    return tf.reshape(acc, full_shape)
예제 #5
0
파일: ops.py 프로젝트: vseledkin/t3f
def _full_tt_batch(tt):
    """Assemble the dense tf.Tensor represented by a TensorTrainBatch.

    Args:
        tt: `TensorTrainBatch` object (TT-tensors or TT-matrices).

    Returns:
        tf.Tensor reshaped to `shapes.lazy_shape(tt)`, with the batch axis
        first.
    """
    n_cores = tt.ndims()
    tt_ranks = shapes.lazy_tt_ranks(tt)
    full_shape = shapes.lazy_shape(tt)
    raw = shapes.lazy_raw_shape(tt)

    acc = tt.tt_cores[0]
    batch_size = shapes.lazy_batch_size(tt)
    # Batched left-to-right contraction: for every batch element, multiply
    # the accumulated (-1 x rank_k) matrix by core k reshaped to (rank_k x -1).
    for k in range(1, n_cores):
        lhs = tf.reshape(acc, (batch_size, -1, tt_ranks[k]))
        rhs = tf.reshape(tt.tt_cores[k], (batch_size, tt_ranks[k], -1))
        acc = tf.einsum('oqb,obw->oqw', lhs, rhs)

    if tt.is_tt_matrix():
        # Modes come out interleaved after the batch axis as
        # (batch, i0, j0, i1, j1, ...); group rows before columns.
        interleaved = [batch_size]
        for k in range(n_cores):
            interleaved.extend((raw[0][k], raw[1][k]))
        acc = tf.reshape(acc, interleaved)
        row_axes = list(range(1, 2 * n_cores, 2))
        col_axes = list(range(2, 2 * n_cores + 1, 2))
        acc = tf.transpose(acc, [0] + row_axes + col_axes)
    return tf.reshape(acc, full_shape)
예제 #6
0
파일: ops.py 프로젝트: vseledkin/t3f
def _add_batch_matrix_cores(tt_a, tt_b):
    """Internal function to be called from add for two batches of TT-matrices.

    Does the actual assembling of the TT-cores to add two batches of
    TT-matrices. Both arguments go through `shapes.expand_batch_dim`. After
    that, an argument whose batch size is 1 has its cores tiled along the
    leading axis to the result batch size.

    The first core of the sum is a concatenation along the right TT-rank
    axis, the last core a concatenation along the left TT-rank axis, and the
    middle cores are block-diagonal. With a single core (ndims == 1) both
    boundary TT-ranks must stay 1, so the cores are added instead.

    Args:
        tt_a: batch of TT-matrices (possibly of batch size 1).
        tt_b: batch of TT-matrices with the same mode sizes as `tt_a`.

    Returns:
        A tuple (tt_cores, batch_size): the TT-cores of the sum, each with a
        leading batch axis, and the batch size of the result.
    """
    ndims = tt_a.ndims()
    dtype = tt_a.dtype
    shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    b_ranks = shapes.lazy_tt_ranks(tt_b)
    if isinstance(tt_a, TensorTrainBatch) and tt_a.batch_size == 1:
        # We add 1 element batch tt_a to a batch_size element batch tt_b to get
        # the answer TensorTrainBatch of batch_size == tt_b.batch_size.
        batch_size = shapes.lazy_batch_size(tt_b)
    else:
        batch_size = shapes.lazy_batch_size(tt_a)
    tt_a = shapes.expand_batch_dim(tt_a)
    tt_b = shapes.expand_batch_dim(tt_b)
    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_a.tt_cores[core_idx]
        if tt_a.batch_size == 1:
            a_core = tf.tile(a_core, (batch_size, 1, 1, 1, 1))
        b_core = tt_b.tt_cores[core_idx]
        if tt_b.batch_size == 1:
            b_core = tf.tile(b_core, (batch_size, 1, 1, 1, 1))
        if ndims == 1:
            # Concatenating would yield a boundary TT-rank of 2; a lone core is
            # both first and last, so the (tiled) cores are added instead.
            curr_core = a_core + b_core
        elif core_idx == 0:
            # Batch cores are (batch, left_rank, row_mode, col_mode, right_rank).
            curr_core = tf.concat((a_core, b_core), axis=4)
        elif core_idx == ndims - 1:
            curr_core = tf.concat((a_core, b_core), axis=1)
        else:
            # Build the block-diagonal core [[a_core, 0], [0, b_core]].
            upper_zeros = tf.zeros(
                (batch_size, a_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], b_ranks[core_idx + 1]), dtype)
            lower_zeros = tf.zeros(
                (batch_size, b_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], a_ranks[core_idx + 1]), dtype)
            upper = tf.concat((a_core, upper_zeros), axis=4)
            lower = tf.concat((lower_zeros, b_core), axis=4)
            curr_core = tf.concat((upper, lower), axis=1)
        tt_cores.append(curr_core)
    return tt_cores, batch_size
예제 #7
0
파일: ops.py 프로젝트: zhanglang1860/t3f
def tt_sparse_flat_inner(tt_a, sparse_b):
    """Inner product between a TT-tensor (or TT-matrix) and tf.SparseTensor along all axis.

    The shapes of tt_a and sparse_b should coincide.

    Only the elements of tt_a at the indices stored in sparse_b are computed.
    For a TT-matrix, each row and column index is unraveled into per-core
    mode indices, and the element is the product of the matching core
    slices. For a TT-tensor the elements are fetched with `gather_nd`.

    Args:
        tt_a: `TensorTrain` object
        sparse_b: tf.SparseTensor

    Returns:
        a number (scalar tf.Tensor): the sum of products of all the elements
        of tt_a and sparse_b
    """
    if sparse_b.indices.get_shape().is_fully_defined():
        num_elements = sparse_b.indices.get_shape()[0]
    else:
        num_elements = tf.shape(sparse_b.indices)[0]
    a_shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    if tt_a.is_tt_matrix():
        # One running 1 x rank row-product per nonzero of sparse_b. Use
        # tt_a's dtype: tf.ones defaults to float32, which would make the
        # matmul below fail for e.g. float64 TT-matrices.
        tt_a_elements = tf.ones((num_elements, 1, 1), dtype=tt_a.dtype)
        # TODO: use t3f.shape is safer??
        tensor_shape = tt_a.get_raw_shape()
        row_idx_linear = tf.cast(sparse_b.indices[:, 0], tf.int64)
        row_idx = utils.unravel_index(row_idx_linear,
                                      tf.cast(tensor_shape[0], tf.int64))
        col_idx_linear = tf.cast(sparse_b.indices[:, 1], tf.int64)
        col_idx = utils.unravel_index(col_idx_linear,
                                      tf.cast(tensor_shape[1], tf.int64))
        for core_idx in range(tt_a.ndims()):
            curr_core = tt_a.tt_cores[core_idx]
            left_rank = a_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1]
            # (left_rank, rows, cols, right_rank) -> (rows * cols, left, right)
            curr_core = tf.transpose(curr_core, (1, 2, 0, 3))
            curr_core_shape = (a_shape[0][core_idx] * a_shape[1][core_idx],
                               left_rank, right_rank)
            curr_core = tf.reshape(curr_core, curr_core_shape)
            # Ravel multiindex (row_idx[:, core_idx], col_idx[:, core_idx]) into
            # a linear index to use tf.gather that supports only first dimensional
            # gather.
            # TODO: use gather_nd instead.
            curr_elements_idx = row_idx[:,
                                        core_idx] * tensor_shape[1][core_idx]
            curr_elements_idx += col_idx[:, core_idx]
            core_slices = tf.gather(curr_core, curr_elements_idx)
            tt_a_elements = tf.matmul(tt_a_elements, core_slices)
    else:
        tt_a_elements = gather_nd(tt_a, sparse_b.indices)
    tt_a_elements = tf.reshape(tt_a_elements, (1, -1))
    sparse_b_elements = tf.reshape(sparse_b.values, (-1, 1))
    result = tf.matmul(tt_a_elements, sparse_b_elements)
    # Convert a 1x1 matrix into a number.
    result = result[0, 0]
    return result
예제 #8
0
파일: autodiff.py 프로젝트: towadroid/t3f
def _enforce_gauge_conditions(deltas, left):
    """Project deltas that define tangent space vec onto the gauge conditions.

    For every core except the last, the delta is reshaped to
    (-1, right_rank) and has Q Q^T delta subtracted from it, where Q is the
    matching core of `left` reshaped the same way. The last delta is returned
    unchanged. (This is an orthogonal projection only if Q has orthonormal
    columns -- presumably `left` is left-orthogonalized; confirm at callers.)

    Args:
        deltas: list of delta-cores, one per core of `left`.
        left: TT object whose cores define the gauge.

    Returns:
        A list of projected delta-cores with the shapes of `left`'s cores.
    """
    tt_ranks = shapes.lazy_tt_ranks(left)
    num_cores = left.ndims()
    projected = []
    for core_idx in range(num_cores):
        core = left.tt_cores[core_idx]
        rank_right = tt_ranks[core_idx + 1]
        basis = tf.reshape(core, (-1, rank_right))
        delta = deltas[core_idx]
        if core_idx == num_cores - 1:
            projected.append(delta)
            continue
        flat = tf.reshape(delta, (-1, rank_right))
        coeffs = tf.matmul(tf.transpose(basis), flat)
        flat = flat - tf.matmul(basis, coeffs)
        projected.append(tf.reshape(flat, core.shape))
    return projected
예제 #9
0
파일: ops.py 프로젝트: vseledkin/t3f
def tt_dense_matmul(tt_matrix_a, matrix_b):
    """Multiplies a TT-matrix by a regular matrix, returns a regular matrix.

    Args:
        tt_matrix_a: `TensorTrain` object containing a TT-matrix of size M x N
        matrix_b: tf.Tensor of size N x P

    Returns:
        tf.Tensor of size M x P

    Raises:
        ValueError if `tt_matrix_a` is not a (non-batch) TT-matrix, or if N is
        statically known for both arguments and does not match.
    """
    if not isinstance(tt_matrix_a,
                      TensorTrain) or not tt_matrix_a.is_tt_matrix():
        raise ValueError('The first argument should be a TT-matrix')

    ndims = tt_matrix_a.ndims()
    a_columns = tt_matrix_a.get_shape()[1].value
    b_rows = matrix_b.get_shape()[0].value
    if a_columns is not None and b_rows is not None:
        if a_columns != b_rows:
            # The shapes are TensorShape objects: format them with %s
            # (%d raises TypeError and hides the intended ValueError).
            raise ValueError(
                'Arguments shapes should align got %s and %s instead.' %
                (tt_matrix_a.get_shape(), matrix_b.get_shape()))

    a_shape = shapes.lazy_shape(tt_matrix_a)
    a_raw_shape = shapes.lazy_raw_shape(tt_matrix_a)
    if matrix_b.get_shape().is_fully_defined():
        b_shape = matrix_b.get_shape().as_list()
    else:
        b_shape = tf.shape(matrix_b)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    # If A is (i0, ..., id-1) x (j0, ..., jd-1) and B is (j0, ..., jd-1) x K,
    # data is (K, j0, ..., jd-2) x jd-1 x 1
    data = tf.transpose(matrix_b)
    data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
    for core_idx in reversed(range(ndims)):
        curr_core = tt_matrix_a.tt_cores[core_idx]
        # On the k = core_idx iteration, after applying einsum the shape of data
        # becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k
        data = tf.einsum('aijb,rjb->ira', curr_core, data)
        if core_idx > 0:
            # After reshape the shape of data becomes
            # (ik, ..., id-1, K, j0, ..., jk-2) x jk-1 x rank_k
            new_data_shape = (-1, a_raw_shape[1][core_idx - 1],
                              a_ranks[core_idx])
            data = tf.reshape(data, new_data_shape)
    # At the end the shape of the data is (i0, ..., id-1) x K
    return tf.reshape(data, (a_shape[0], b_shape[1]))
예제 #10
0
파일: utils.py 프로젝트: Dean-Go-kr/QTTNet
def _full_tt(tt, verbose=True):
    """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

    Every TT-core is passed through the module-level `fw` function before
    contraction, and the running product is passed through `fw` again after
    each core multiplication. (The debug prints label `fw`'s output as
    "8bit", presumably a quantizer -- confirm against `fw`'s definition.)

    Args:
        tt: `TensorTrain` object.
        verbose: if True (the default, matching the historical behaviour),
            print every core and every intermediate product before and after
            `fw` is applied.

    Returns:
        tf.Tensor.
    """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    # Quantize every core; boundary and middle cores are treated the same.
    quan_core_list = []
    for i in range(num_dims):
        if verbose:
            print('FP core:', tt.tt_cores[i])
        quan_core_list.append(fw(tt.tt_cores[i]))
        if verbose:
            print('8bit core:', quan_core_list[i])

    # Contract the quantized cores left to right, quantizing every
    # intermediate product.
    res = quan_core_list[0]
    for i in range(1, num_dims):
        res = tf.reshape(res, (-1, ranks[i]))
        curr_core = tf.reshape(quan_core_list[i], (ranks[i], -1))
        res = tf.matmul(res, curr_core)
        if verbose:
            print('core multi FP: ', res)
        res = fw(res)
        if verbose:
            print('core multi 8bit: ', res)
    if tt.is_tt_matrix():
        # Modes are interleaved (i0, j0, i1, j1, ...); group rows before
        # columns.
        intermediate_shape = []
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = []
        for i in range(0, 2 * num_dims, 2):
            transpose.append(i)
        for i in range(1, 2 * num_dims, 2):
            transpose.append(i)
        res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
예제 #11
0
def tangent_space_to_deltas(tt, name='t3f_tangent_space_to_deltas'):
    """Convert an element of the tangent space to deltas representation.

    Tangent space elements (outputs of t3f.project) look like:
      dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.

    This function takes as input an element of the tangent space and converts
    it to the list of deltas [dP1, ..., dPd].

    Args:
        tt: `TensorTrain` or `TensorTrainBatch` that is a result of t3f.project,
          t3f.project_matmul, or other similar functions.
        name: string, name of the Op.

    Returns:
        A list of delta-cores (tf.Tensors).

    Raises:
        ValueError: if `tt` lacks a `projection_on` field, or if any of its
          internal TT-ranks (tt_ranks[1] ... tt_ranks[ndims - 1]) is odd.
    """
    if not hasattr(tt, 'projection_on') or tt.projection_on is None:
        raise ValueError('tt argument is supposed to be a projection, but it '
                         'lacks projection_on field')
    num_dims = tt.ndims()
    left_tt_rank_dim = tt.left_tt_rank_dim
    right_tt_rank_dim = tt.right_tt_rank_dim
    deltas = [None] * num_dims
    tt_ranks = shapes.lazy_tt_ranks(tt)
    # Every internal TT-rank is halved below, including tt_ranks[num_dims - 1]
    # which is used to slice the last core, so all of them must be even.
    for i in range(1, num_dims):
        if int(tt_ranks[i]) % 2 != 0:
            raise ValueError(
                'tt argument is supposed to be a projection, but its '
                'ranks are not even.')
    with tf.name_scope(name):
        for i in range(1, num_dims - 1):
            r1, r2 = tt_ranks[i], tt_ranks[i + 1]
            curr_core = tt.tt_cores[i]
            # Middle delta: second half of the left rank, first half of the
            # right rank.
            slc = [slice(None)] * len(curr_core.shape)
            slc[left_tt_rank_dim] = slice(int(r1) // 2, None)
            slc[right_tt_rank_dim] = slice(0, int(r2) // 2)
            deltas[i] = curr_core[tuple(slc)]
        # First delta: first half of the right rank.
        first_core = tt.tt_cores[0]
        slc = [slice(None)] * len(first_core.shape)
        slc[right_tt_rank_dim] = slice(0, int(tt_ranks[1]) // 2)
        deltas[0] = first_core[tuple(slc)]
        # Last delta: second half of the left rank.
        last_core = tt.tt_cores[num_dims - 1]
        slc = [slice(None)] * len(last_core.shape)
        slc[left_tt_rank_dim] = slice(int(tt_ranks[-2]) // 2, None)
        deltas[num_dims - 1] = last_core[tuple(slc)]
    return deltas
예제 #12
0
def project_sum(what, where, weights=None):
    r"""Project sum of `what` TTs on the tangent space of `where` TT.

    project_sum(what, x) = P_x(what)
    project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
    project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

    This function implements the algorithm from the paper [1], theorem 3.1.

    [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
      Tensor Trains.

    Args:
      what: TensorTrain or TensorTrainBatch. In the case of batch returns
        projection of the sum of elements in the batch.
      where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
      weights: python list or tf.Tensor of numbers or None, weights of the sum.
        If `weights` is 2-D and its second dimension is larger than 1, one
        projection is computed per column and a batch is returned.

    Returns:
       a TensorTrain (a TensorTrainBatch of size weights.shape[1] when
       `weights` is 2-D with more than one column) with the TT-ranks equal
       2 * tangent_space_tens.get_tt_ranks(). The result's `projection_on`
       attribute is set to `where`.

    Raises:
      ValueError if `where` is not a TensorTrain, if the raw shapes of `what`
      and `where` differ, or if their dtypes are not compatible.

    Complexity:
         O(d r_where^3 m) for orthogonalizing the TT-cores of where
        +O(batch_size d r_what r_where n (r_what + r_where))
      d is the number of TT-cores (what.ndims());
      r_what is the largest TT-rank of what max(what.get_tt_rank())
      r_where is the largest TT-rank of where
      n is the size of the axis dimension of what and where e.g.
        for a tensor of size 4 x 4 x 4, n is 4;
        for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
    """
    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Validate `where` before anything reads its attributes (the weights
    # conversion below uses where.dtype).
    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if weights is not None:
        weights = tf.convert_to_tensor(weights, dtype=where.dtype)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = where.right_tt_rank_dim
    left_rank_dim = where.left_tt_rank_dim
    if weights is not None:
        weights_shape = weights.get_shape()
        output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
    else:
        output_is_batch = False
    output_batch_str = 'o' if output_is_batch else ''
    if output_is_batch:
        # Output cores get a leading batch axis, shifting the rank axes.
        right_rank_dim += 1
        left_rank_dim += 1
        output_batch_size = weights.get_shape().as_list()[1]

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            if weights is None:
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
            else:
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, proj_core,
                                        rhs[core_idx + 1])
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if core_idx == ndims - 1:
            if weights is None:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            else:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(extended_left_tang_core,
                                                  [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            # Middle cores are [[right_tang, 0], [proj, left_tang]].
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)
    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
예제 #13
0
파일: ops.py 프로젝트: vseledkin/t3f
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
    """Multiplies two TT-matrices and returns the TT-matrix of the result.

    Args:
        tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
            a TT-matrix (a batch of TT-matrices) of size M x N
        tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
            a TT-matrix (a batch of TT-matrices) of size N x P

    Returns:
        `TensorTrain` object containing a TT-matrix of size M x P if both
            arguments are `TensorTrain`s (or batches of size 1, which are
            squeezed)
        `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`
            of size other than 1

    Raises:
        ValueError if the arguments are not TT matrices, if their batch sizes
        cannot be broadcast, if their numbers of cores differ, or if the
        column modes of A are statically incompatible with the row modes of B.
    """
    # Both TensorTrain and TensorTrainBatch are inherited from TensorTrainBase.
    if (not isinstance(tt_matrix_a, TensorTrainBase) or
            not isinstance(tt_matrix_b, TensorTrainBase) or
            not tt_matrix_a.is_tt_matrix() or
            not tt_matrix_b.is_tt_matrix()):
        raise ValueError('Arguments should be TT-matrices')

    if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
        raise ValueError(
            'The batch sizes are different and not 1, broadcasting is '
            'not available.')

    ndims = tt_matrix_a.ndims()
    if tt_matrix_b.ndims() != ndims:
        raise ValueError(
            'Arguments should have the same number of dimensions, '
            'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

    # The column modes of A must match the row modes of B core by core.
    # Statically unknown dimensions pass this check and are left to the ops.
    a_col_shape = tt_matrix_a.get_raw_shape()[1]
    b_row_shape = tt_matrix_b.get_raw_shape()[0]
    if not a_col_shape.is_compatible_with(b_row_shape):
        raise ValueError(
            'Arguments shapes should align, got %s and %s instead.' %
            (tt_matrix_a.get_shape(), tt_matrix_b.get_shape()))

    # Convert BatchSize 1 batch into TT object to simplify broadcasting.
    tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
    tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
    is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
    is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
    is_res_batch = is_a_batch or is_b_batch
    a_batch_str = 'o' if is_a_batch else ''
    b_batch_str = 'o' if is_b_batch else ''
    res_batch_str = 'o' if is_res_batch else ''
    # Contract over the shared mode j; the rank axes form Kronecker products.
    einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                  res_batch_str)
    result_cores = []
    # TODO: name the operation and the resulting tensor.
    a_shape = shapes.lazy_raw_shape(tt_matrix_a)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    b_shape = shapes.lazy_raw_shape(tt_matrix_b)
    b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
    if is_res_batch:
        if is_a_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_a)
        if is_b_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_b)
    for core_idx in range(ndims):
        a_core = tt_matrix_a.tt_cores[core_idx]
        b_core = tt_matrix_b.tt_cores[core_idx]
        curr_res_core = tf.einsum(einsum_str, a_core, b_core)

        res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
        res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
        left_mode = a_shape[0][core_idx]
        right_mode = b_shape[1][core_idx]
        if is_res_batch:
            core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                          res_right_rank)
        else:
            core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
        curr_res_core = tf.reshape(curr_res_core, core_shape)
        result_cores.append(curr_res_core)

    res_shape = (tt_matrix_a.get_raw_shape()[0],
                 tt_matrix_b.get_raw_shape()[1])
    static_a_ranks = tt_matrix_a.get_tt_ranks()
    static_b_ranks = tt_matrix_b.get_tt_ranks()
    out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
    if is_res_batch:
        return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
    else:
        return TensorTrain(result_cores, res_shape, out_ranks)
예제 #14
0
def _orthogonalize_batch_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a batch TT-object in the left to right order.

  For every core except the last one, the core is reshaped into a batch of
  (left_rank * mode) x right_rank matrices and QR-decomposed. The Q factor
  becomes the new core and the R factor is multiplied into the next core, so
  the represented tensors are unchanged. The TT-ranks may shrink in the process
  (see the comment inside the loop).

  Args:
    tt: TensorTrainBatch.

  Returns:
    TensorTrainBatch with the same raw shape and batch size as `tt`. The
    TT-ranks are not passed to the constructor (see the TODO at the end).
  """
    # Left to right orthogonalization.
    ndims = tt.ndims()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    next_rank = tt_ranks[0]
    batch_size = shapes.lazy_batch_size(tt)

    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = next_rank
        next_rank = tt_ranks[core_idx + 1]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        # Flatten the left rank and the mode axes into rows: one matrix per batch
        # element, ready for a batched QR.
        qr_shape = (batch_size, curr_rank * curr_mode, next_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        # Batched QR: for a (batch, M, N) input, q is (batch, M, K) and r
        # ("triang") is (batch, K, N), with K = min(M, N).
        curr_core, triang = tf.qr(curr_core)
        # Prefer the static shape (Python ints); fall back to a dynamic shape tensor.
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        next_rank = triang_shape[1]
        if tt.is_tt_matrix():
            new_core_shape = (batch_size, curr_rank, curr_mode_left,
                              curr_mode_right, next_rank)
        else:
            new_core_shape = (batch_size, curr_rank, curr_mode, next_rank)

        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        # Absorb R into the next core. triang_shape[2] is the next core's old left
        # rank (the N of the QR input); the remaining axes are flattened.
        next_core = tf.reshape(tt_cores[core_idx + 1],
                               (batch_size, triang_shape[2], -1))
        tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

    # The last core was flattened by the matmul above; restore its TT-core layout
    # with the (possibly reduced) left rank `next_rank` and right rank 1.
    if tt.is_tt_matrix():
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1],
                           raw_shape[1][-1], 1)
    else:
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1], 1)
    tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrainBatch(tt_cores,
                            tt.get_raw_shape(),
                            batch_size=batch_size)
예제 #15
0
def add_n_projected(tt_objects, coef=None):
    """Adds all input TT-objects that are projections on the same tangent space.

    add_n_projected((a, b)) is equivalent to add(a, b) for a and b that are
    from the same tangent space, but doesn't increase the TT-ranks.

  Args:
    tt_objects: a non-empty list of TT-objects that are projections on the same
      tangent space.
    coef: a list of numbers or anything else convertable to tf.Tensor.
      If provided, computes weighted sum. Either a vector of size
      len(tt_objects) (one weight per object), or, for batches, an array of
      size len(tt_objects) x tt_objects[0].batch_size.

  Returns:
    TT-objects representing the sum of the tt_objects (weighted sum if coef is
    provided). The TT-rank of the result equals to the TT-ranks of the arguments.
    The result keeps the `projection_on` attribute of the arguments.

  Raises:
    ValueError: if `tt_objects` is empty, if some argument is not a projection,
      if the arguments are projections on different TT-objects, or if the
      arguments are neither TensorTrain nor TensorTrainBatch.
  """
    if not tt_objects:
        raise ValueError('tt_objects should be a non-empty list of TT-objects.')

    for tt in tt_objects:
        if not hasattr(tt, 'projection_on'):
            raise ValueError(
                'Both arguments should be projections on the tangent '
                'space of some other TT-object. All projection* functions '
                'leave .projection_on field in the resulting TT-object '
                'which is not present in the argument you\'ve provided.')

    first_obj = tt_objects[0]
    projection_on = first_obj.projection_on
    for tt in tt_objects[1:]:
        if tt.projection_on != projection_on:
            raise ValueError(
                'All tt_objects should be projections on the tangent '
                'space of the same TT-object. The provided arguments are '
                'projections on different TT-objects (%s and %s). Or at '
                'least the pointers are different.' %
                (tt.projection_on, projection_on))
    if coef is not None:
        coef = tf.convert_to_tensor(coef, dtype=first_obj.dtype)
        if coef.get_shape().ndims > 1:
            # In batch case we will need to multiply each core by this coefficients
            # along the first axis. To do it need to reshape the coefs to match
            # the TT-cores number of dimensions.
            some_core = first_obj.tt_cores[0]
            dim_array = [1] * (some_core.get_shape().ndims + 1)
            dim_array[0] = coef.get_shape().as_list()[0]
            dim_array[1] = coef.get_shape().as_list()[1]
            coef = tf.reshape(coef, dim_array)

    ndims = first_obj.ndims()
    tt_ranks = shapes.lazy_tt_ranks(first_obj)
    left_rank_dim = first_obj.left_tt_rank_dim
    right_rank_dim = first_obj.right_tt_rank_dim
    res_cores = []

    def slice_tt_core(tt_core, left_idx, right_idx):
        # Slice `tt_core` along its two TT-rank axes, keeping all other axes.
        idx = [slice(None)] * len(tt_core.get_shape())
        idx[left_rank_dim] = left_idx
        idx[right_rank_dim] = right_idx
        return tt_core[tuple(idx)]

    def weighted_sum(core_idx, left_idx, right_idx):
        # Sum the given rank-block of core `core_idx` over all objects,
        # multiplying each object's block by its coefficient when coef is given.
        chunks = []
        for obj_idx, tt in enumerate(tt_objects):
            chunk = slice_tt_core(tt.tt_cores[core_idx], left_idx, right_idx)
            if coef is not None:
                chunk *= coef[obj_idx]
            chunks.append(chunk)
        return tf.add_n(chunks)

    # The blocks that are not summed are taken from the first object only: the
    # code treats them as identical across all arguments, since all arguments
    # are projections on the same TT-object.

    # First core: [summed | shared] along the right rank.
    right_half_rank = tt_ranks[1] // 2
    left_part = weighted_sum(0, slice(None), slice(0, right_half_rank))
    right_part = slice_tt_core(first_obj.tt_cores[0], slice(None),
                               slice(right_half_rank, None))
    res_cores.append(tf.concat((left_part, right_part), axis=right_rank_dim))

    # Middle cores: [[shared upper row], [summed | shared]].
    for core_idx in range(1, ndims - 1):
        first_obj_core = first_obj.tt_cores[core_idx]
        left_half_rank = tt_ranks[core_idx] // 2
        right_half_rank = tt_ranks[core_idx + 1] // 2

        upper_part = slice_tt_core(first_obj_core, slice(0, left_half_rank),
                                   slice(None))
        lower_right_part = slice_tt_core(first_obj_core,
                                         slice(left_half_rank, None),
                                         slice(right_half_rank, None))
        lower_left_part = weighted_sum(core_idx, slice(left_half_rank, None),
                                       slice(0, right_half_rank))
        lower_part = tf.concat((lower_left_part, lower_right_part),
                               axis=right_rank_dim)
        res_cores.append(
            tf.concat((upper_part, lower_part), axis=left_rank_dim))

    # Last core: [shared; summed] along the left rank.
    left_half_rank = tt_ranks[ndims - 1] // 2
    upper_part = slice_tt_core(first_obj.tt_cores[-1], slice(0, left_half_rank),
                               slice(None))
    lower_part = weighted_sum(-1, slice(left_half_rank, None), slice(None))
    res_cores.append(tf.concat((upper_part, lower_part), axis=left_rank_dim))

    raw_shape = first_obj.get_raw_shape()
    static_tt_ranks = first_obj.get_tt_ranks()
    if isinstance(first_obj, TensorTrain):
        res = TensorTrain(res_cores, raw_shape, static_tt_ranks)
    elif isinstance(first_obj, TensorTrainBatch):
        res = TensorTrainBatch(res_cores, raw_shape, static_tt_ranks,
                               first_obj.batch_size)
    else:
        raise ValueError('Expected TensorTrain or TensorTrainBatch objects, '
                         'got %s.' % type(first_obj))
    # Maintain the projection_on property.
    res.projection_on = first_obj.projection_on
    return res
예제 #16
0
def pairwise_flat_inner_projected(projected_tt_vectors_1,
                                  projected_tt_vectors_2):
    """Scalar products between two batches of TTs from the same tangent space.

    res[i, j] = t3f.flat_inner(projected_tt_vectors_1[i], projected_tt_vectors_2[j]).

  pairwise_flat_inner_projected(projected_tt_vectors_1, projected_tt_vectors_2)
  is equivalent to
    pairwise_flat_inner(projected_tt_vectors_1, projected_tt_vectors_2)
  , but works only on objects from the same tangent space and is much faster
  than general pairwise_flat_inner.

  The result is accumulated core by core from one rank-block of each TT-core:
  the first half of the right TT-rank for the first core; the second half of
  the left TT-rank times the first half of the right TT-rank for the middle
  cores; and the second half of the left TT-rank for the last core.

  Args:
    projected_tt_vectors_1: TensorTrainBatch of tensors projected on the same
      tangent space as projected_tt_vectors_2.
    projected_tt_vectors_2: TensorTrainBatch.

  Returns:
    tf.tensor with the scalar product matrix.

  Raises:
    ValueError: if an argument has no `projection_on` attribute, or if the
      arguments are projections on different TT-objects.

  Complexity:
      O(batch_size^2 d r^2 n), where
    d is the number of TT-cores (projected_tt_vectors_1.ndims());
    r is the largest TT-rank max(projected_tt_vectors_1.get_tt_rank())
      (i.e. 2 * {the TT-rank of the object we projected vectors onto});
    and n is the size of the axis dimension, e.g.
      for a tensor of size 4 x 4 x 4, n is 4;
      for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12.
  """
    if (not hasattr(projected_tt_vectors_1, 'projection_on') or
            not hasattr(projected_tt_vectors_2, 'projection_on')):
        raise ValueError(
            'Both arguments should be projections on the tangent '
            'space of some other TT-object. All projection* functions '
            'leave .projection_on field in the resulting TT-object '
            'which is not present in the arguments you\'ve provided')

    if projected_tt_vectors_1.projection_on != projected_tt_vectors_2.projection_on:
        raise ValueError(
            'Both arguments should be projections on the tangent '
            'space of the same TT-object. The provided arguments are '
            'projections on different TT-objects (%s and %s). Or at '
            'least the pointers are different.' %
            (projected_tt_vectors_1.projection_on,
             projected_tt_vectors_2.projection_on))

    # Always work with batches of objects for simplicity.
    projected_tt_vectors_1 = shapes.expand_batch_dim(projected_tt_vectors_1)
    projected_tt_vectors_2 = shapes.expand_batch_dim(projected_tt_vectors_2)

    ndims = projected_tt_vectors_1.ndims()
    tt_ranks = shapes.lazy_tt_ranks(projected_tt_vectors_1)

    # One mode axis per core for TT-tensors, two for TT-matrices. `p` and `q`
    # are the batch axes of the first and second argument respectively.
    mode_str = 'ij' if projected_tt_vectors_1.is_tt_matrix() else 'i'
    einsum_str = 'pa{0}b,qa{0}b->pq'.format(mode_str)

    def rank_block_inner(core_idx, left_start, right_stop):
        # Pairwise inner products of the [left_start:, :right_stop] rank-block of
        # core `core_idx`; `None` keeps the whole rank axis. The Ellipsis spans
        # the mode axes, so the same slice works for TT-tensors and TT-matrices.
        curr_core_1 = projected_tt_vectors_1.tt_cores[core_idx]
        curr_core_2 = projected_tt_vectors_2.tt_cores[core_idx]
        curr_du_1 = curr_core_1[:, left_start:, ..., :right_stop]
        curr_du_2 = curr_core_2[:, left_start:, ..., :right_stop]
        return tf.einsum(einsum_str, curr_du_1, curr_du_2)

    # First core: whole left rank, first half of the right rank.
    res = rank_block_inner(0, None, tt_ranks[1] // 2)
    # Middle cores. The last core is excluded: its right TT-rank is 1, so
    # `1 // 2 == 0` would select an empty block; it is handled below instead.
    for core_idx in range(1, ndims - 1):
        res += rank_block_inner(core_idx, tt_ranks[core_idx] // 2,
                                tt_ranks[core_idx + 1] // 2)
    # Last core: second half of the left rank, whole right rank.
    res += rank_block_inner(-1, tt_ranks[-2] // 2, None)
    return res
예제 #17
0
def project_matmul(what, where, matrix):
    """Project `matrix` * `what` TTs on the tangent space of `where` TT.

  project_matmul(what, x, A) = P_x(A * what)
  project_matmul(batch_what, x, A) = batch(P_x(A * batch_what[0]), ...,
                                           P_x(A * batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    matrix: TensorTrain, TT-matrix to multiply by what

  Returns:
     a TensorTrain (a TensorTrainBatch if `what` is a TensorTrainBatch) with
     the TT-ranks equal 2 * where.get_tt_ranks(). Its `projection_on`
     attribute is set to `where`.

  Raises:
    ValueError: if `where` is not a TensorTrain, if the raw shapes of `where`
      and `what` differ, or if their dtypes are not compatible.

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what max(what.get_tt_rank())
    r_where is the largest TT-rank of where
    matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
  """

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # Concatenation axes of the result cores. They are read from the original
    # `what`, i.e. before a batch axis is added to it below.
    right_rank_dim = what.right_tt_rank_dim
    left_rank_dim = what.left_tt_rank_dim
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                                  right_tang_core, rhs[core_idx + 1],
                                  tens_core)
    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        # TODO: brute-force the order of indices in lhs?
        lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef',
                                      matrix_core, left_tang_core,
                                      lhs[core_idx], tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                                  lhs[core_idx], matrix_core)
            proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                                   lhs[core_idx + 1])
            proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core,
                                  rhs[core_idx + 1])
        else:
            # d and e dimensions take 1 value, since its the last rank.
            # To make the result shape (?, ?, ?, 1), we are summing d and leaving e,
            # but we could have done the opposite -- sum e and leave d.
            proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx],
                                  matrix_core, tens_core)

        if not output_is_batch:
            # `what` was expanded to a batch of size 1 above, so proj_core still
            # has that leading axis. Drop it so proj_core has the same rank as
            # the non-batch tangent cores it is concatenated with below.
            proj_core = tf.squeeze(proj_core, axis=0)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            extended_left_tang_core = tf.tile(extended_left_tang_core,
                                              [output_batch_size, 1, 1, 1, 1])
            extended_right_tang_core = tf.tile(extended_right_tang_core,
                                               [output_batch_size, 1, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        # Assemble the tangent-space core:
        #   first core:  [proj | U]
        #   middle core: [[V, 0], [proj, U]]
        #   last core:   [V; proj]
        # where U / V are the left- / right-orthogonalized cores of `where`.
        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            mode_size_n = raw_shape[0][core_idx]
            mode_size_m = raw_shape[1][core_idx]
            shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
예제 #18
0
파일: riemannian.py 프로젝트: vseledkin/t3f
def project(what, where):
  """Project `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project

  Returns:
     a TensorTrain (a TensorTrainBatch if `what` is a TensorTrainBatch) with
     the TT-ranks equal 2 * where.get_tt_ranks(). Its `projection_on`
     attribute is set to `where`.

  Raises:
    ValueError: if `where` is not a TensorTrain, if the raw shapes of `where`
      and `what` differ, or if their dtypes are not compatible.
  """

  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(),
                      what.get_raw_shape()))

  if not where.dtype.is_compatible_with(what.dtype):
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.' %
                     (where.dtype,
                      what.dtype))

  # Left-orthogonal version of `where` (U cores), and a right-orthogonal
  # version (V cores) obtained from it.
  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  # Core layout: (left rank, mode(s)..., right rank), shifted by one when a
  # leading batch axis is present.
  right_rank_dim = 3 if where.is_tt_matrix() else 2
  left_rank_dim = 0
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim = 1
    # NOTE(review): used as a tf.tile multiple and a tf.zeros dimension below;
    # presumably must be statically known (not None) — TODO confirm.
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      # For non-batch input the size-1 batch axis `s` is summed out here, so
      # proj_core matches the rank of the non-batch tangent cores.
      if output_is_batch:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      if output_is_batch:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    # Assemble the tangent-space core:
    #   first core:  [proj | U]
    #   middle core: [[V, 0], [proj, U]]
    #   last core:   [V; proj]
    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core), axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros), axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                            batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
예제 #19
0
파일: ops.py 프로젝트: zhanglang1860/t3f
def multiply(tt_left, right):
    """Returns a TensorTrain corresponding to element-wise product tt_left * right.

  Broadcasting over the batch axis is supported:
    * multiply(TensorTrainBatch, TensorTrain) multiplies every TT of the batch
      by the single TensorTrain and returns a TensorTrainBatch;
    * multiply(TensorTrainBatch_a, TensorTrainBatch_b) multiplies the batches
      element by element (or broadcasts when one of the batch sizes is 1) and
      returns a TensorTrainBatch.

  Args:
    tt_left: `TensorTrain` OR `TensorTrainBatch`.
    right: `TensorTrain` OR `TensorTrainBatch` OR a number.

  Returns:
    a `TensorTrain`, or a `TensorTrainBatch` if either argument is a batch,
    holding the element-wise product of the arguments.

  Raises:
    ValueError: if one argument is a TT-tensor and the other is a TT-matrix,
      if the raw shapes differ, or if the batch sizes are known, different and
      cannot be broadcast.
  """
    is_left_batch = isinstance(tt_left, TensorTrainBatch)
    is_right_batch = isinstance(right, TensorTrainBatch)
    is_batch_case = is_left_batch or is_right_batch
    ndims = tt_left.ndims()

    if not isinstance(right, TensorTrainBase):
        # `right` is treated as a scalar. Its absolute value is spread evenly
        # over all cores (each core is scaled by |right| ** (1 / ndims)) and its
        # sign is applied to the first core only.
        factor = tf.pow(tf.cast(tf.abs(right), tt_left.dtype), 1.0 / ndims)
        sign = tf.cast(tf.sign(right), tt_left.dtype)
        tt_cores = [factor * core for core in tt_left.tt_cores]
        tt_cores[0] = tt_cores[0] * sign
        out_ranks = tt_left.get_tt_ranks()
        if is_left_batch:
            out_batch_size = tt_left.batch_size
    else:
        if tt_left.is_tt_matrix() != right.is_tt_matrix():
            raise ValueError('The arguments should be both TT-tensors or both '
                             'TT-matrices')
        if tt_left.get_raw_shape() != right.get_raw_shape():
            raise ValueError('The arguments should have the same shape.')

        out_batch_size = 1
        dependencies = []

        # Whether broadcasting can be decided at graph-construction time. It
        # cannot when both batch sizes are unknown, or when one is unknown and
        # the other one is known to be larger than 1.
        can_determine_if_broadcast = True
        if is_left_batch and is_right_batch:
            left_bs = tt_left.batch_size
            right_bs = right.batch_size
            if left_bs is None and right_bs is None:
                can_determine_if_broadcast = False
            elif left_bs is None:
                can_determine_if_broadcast = not (right_bs > 1)
            elif right_bs is None:
                can_determine_if_broadcast = not (left_bs > 1)

        if not can_determine_if_broadcast:
            # Fall back to plain element-wise multiplication and guard it with a
            # run-time check that gives a readable error if the batch sizes
            # turn out to differ.
            message = (
                'The batch sizes were unknown on compilation stage, so '
                'assumed elementwise multiplication (i.e. no broadcasting). '
                'Now it seems that they are different after all :')
            data = [
                message,
                shapes.lazy_batch_size(tt_left), ' x ',
                shapes.lazy_batch_size(right)
            ]
            dependencies.append(
                tf.assert_equal(shapes.lazy_batch_size(tt_left),
                                shapes.lazy_batch_size(right),
                                data=data))

        broadcast_possible = shapes.is_batch_broadcasting_possible(
            tt_left, right)
        if can_determine_if_broadcast and not broadcast_possible:
            raise ValueError(
                'The batch sizes are different and not 1, broadcasting '
                'is not available.')

        a_ranks = shapes.lazy_tt_ranks(tt_left)
        b_ranks = shapes.lazy_tt_ranks(right)
        shape = shapes.lazy_raw_shape(tt_left)

        # einsum labels for the batch axes of the left operand, right operand
        # and result.
        left_bs_str = right_bs_str = out_bs_str = ''
        if is_left_batch and is_right_batch:
            if (tt_left.batch_size == right.batch_size
                    or not can_determine_if_broadcast):
                # Element-wise over a shared batch axis.
                left_bs_str = right_bs_str = out_bs_str = 'n'
                if can_determine_if_broadcast:
                    out_batch_size = tt_left.batch_size
                else:
                    out_batch_size = None
            else:
                # Broadcasting (e.g. batch sizes 1 and n > 1): outer product of
                # the two batch axes, flattened by the reshape below.
                left_bs_str, right_bs_str, out_bs_str = 'n', 'm', 'nm'
                if tt_left.batch_size is None or tt_left.batch_size > 1:
                    out_batch_size = tt_left.batch_size
                else:
                    out_batch_size = right.batch_size
        elif is_left_batch:
            left_bs_str = out_bs_str = 'n'
            out_batch_size = tt_left.batch_size
        elif is_right_batch:
            right_bs_str = out_bs_str = 'n'
            out_batch_size = right.batch_size

        is_matrix = tt_left.is_tt_matrix()
        mode_str = 'ij' if is_matrix else 'i'
        einsum_str = '{0}a{3}b,{1}c{3}d->{2}ac{3}bd'.format(
            left_bs_str, right_bs_str, out_bs_str, mode_str)

        tt_cores = []
        for core_idx in range(ndims):
            a_core = tt_left.tt_cores[core_idx]
            b_core = right.tt_cores[core_idx]
            left_rank = a_ranks[core_idx] * b_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
            with tf.control_dependencies(dependencies):
                if is_matrix:
                    mode_dims = (shape[0][core_idx], shape[1][core_idx])
                else:
                    mode_dims = (shape[0][core_idx],)
                product = tf.einsum(einsum_str, a_core, b_core)
                # Merge the rank pairs (a, c) and (b, d) into single ranks; the
                # leading -1 absorbs the batch axis (or axes).
                product = tf.reshape(
                    product, (-1, left_rank) + mode_dims + (right_rank,))
                if not is_batch_case:
                    product = tf.squeeze(product, axis=0)
            tt_cores.append(product)

        out_ranks = [
            a * b for a, b in zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
        ]

    if is_batch_case:
        return TensorTrainBatch(tt_cores,
                                tt_left.get_raw_shape(),
                                out_ranks,
                                batch_size=out_batch_size)
    return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
예제 #20
0
def deltas_to_tangent_space(deltas,
                            tt,
                            left=None,
                            right=None,
                            name='t3f_deltas_to_tangent_space'):
    """Converts deltas representation of tangent space vector to TT object.

  Takes as input a list of [dP1, ..., dPd] and returns
    dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd,
  where U1, ..., Ud are the cores of `left` (the left-to-right
  orthogonalization of `tt`) and V1, ..., Vd are the cores of `right` (the
  right-to-left orthogonalization of `left`).

  The result is assembled as a TT object whose cores are
    [dP1, U1], [[V2, 0], [dP2, U2]], ..., [[Vd], [dPd]],
  i.e. with TT-ranks equal to the sum of the `right` and `left` TT-ranks.

  This function is hard to use correctly because deltas should obey the
  so-called gauge conditions. If they don't, the function will silently return
  an incorrect result. This is why this function is not imported in __init__.

  Args:
      deltas: a list of deltas (essentially TT-cores) obeying the gauge
        conditions. If deltas[0] has one more dimension than tt.tt_cores[0],
        the deltas are treated as a batch whose leading axis is the batch axis,
        and a `TensorTrainBatch` is returned.
      tt: `TensorTrain` object on which the tangent space tensor represented by
        delta is projected.
      left: t3f.orthogonalize_tt_cores(tt). If you have already computed it,
        you may pass it as argument to avoid recomputing.
      right: t3f.orthogonalize_tt_cores(left, left_to_right=False). If you have
        already computed it, you may pass it as argument to avoid recomputing.
      name: string, name of the Op.

  Returns:
      `TensorTrain` (or `TensorTrainBatch` in the batch case) object
        constructed from deltas, that is from the tangent space at point `tt`.
        Its `projection_on` attribute is set to `tt`.
  """
    cores = []
    dtype = tt.dtype
    num_dims = tt.ndims()
    is_matrix = tt.is_tt_matrix()
    # TODO: add a cache instead of manually passing precomputed stuff?
    with tf.name_scope(name):
        if left is None:
            left = decompositions.orthogonalize_tt_cores(tt)
        if right is None:
            right = decompositions.orthogonalize_tt_cores(left,
                                                          left_to_right=False)
        left_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        # The zero block of a middle core sits in the same block-row as the
        # `right` core, so its leading rank must come from `right`, not `left`.
        # The two orthogonalizations can shrink TT-ranks differently, and a
        # caller may pass its own `right`.
        right_tangent_tt_ranks = shapes.lazy_tt_ranks(right)
        raw_shape = shapes.lazy_raw_shape(left)
        right_rank_dim = left.right_tt_rank_dim
        left_rank_dim = left.left_tt_rank_dim
        is_batch_case = len(deltas[0].shape) > len(tt.tt_cores[0].shape)
        if is_batch_case:
            # Deltas carry a leading batch axis, which shifts the rank axes.
            right_rank_dim += 1
            left_rank_dim += 1
            batch_size = deltas[0].shape.as_list()[0]
        for i in range(num_dims):
            left_tt_core = left.tt_cores[i]
            right_tt_core = right.tt_cores[i]
            if is_batch_case:
                # Replicate the (shared) orthogonal cores along the batch axis
                # so they can be concatenated with the batched deltas.
                tile = [batch_size] + [1] * len(left_tt_core.shape)
                left_tt_core = tf.tile(left_tt_core[None, ...], tile)
                right_tt_core = tf.tile(right_tt_core[None, ...], tile)

            if i == 0:
                tangent_core = tf.concat((deltas[i], left_tt_core),
                                         axis=right_rank_dim)
            elif i == num_dims - 1:
                tangent_core = tf.concat((right_tt_core, deltas[i]),
                                         axis=left_rank_dim)
            else:
                # Zero block of shape [right_rank_i, modes..., left_rank_{i+1}].
                rank_1 = right_tangent_tt_ranks[i]
                rank_2 = left_tangent_tt_ranks[i + 1]
                if is_matrix:
                    mode_size_n = raw_shape[0][i]
                    mode_size_m = raw_shape[1][i]
                    shape = [rank_1, mode_size_n, mode_size_m, rank_2]
                else:
                    mode_size_n = raw_shape[0][i]
                    shape = [rank_1, mode_size_n, rank_2]
                if is_batch_case:
                    shape = [batch_size] + shape
                zeros = tf.zeros(shape, dtype=dtype)
                upper = tf.concat((right_tt_core, zeros), axis=right_rank_dim)
                lower = tf.concat((deltas[i], left_tt_core),
                                  axis=right_rank_dim)
                tangent_core = tf.concat((upper, lower), axis=left_rank_dim)
            cores.append(tangent_core)
        if is_batch_case:
            tangent = TensorTrainBatch(cores, batch_size=batch_size)
        else:
            tangent = TensorTrain(cores)
        tangent.projection_on = tt
        return tangent
예제 #21
0
def _orthogonalize_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a TT-object in the left to right order.

  Each core except the last is replaced by the Q factor of a QR decomposition
  of the core reshaped to a (left_rank * mode_size) x right_rank matrix. The R
  factor is multiplied into the next core, so the represented tensor does not
  change.

  Args:
    tt: `TensorTrain` object.
      NOTE(review): the cores are reshaped without a batch axis, so a
      `TensorTrainBatch` input does not appear to be supported here. Confirm
      before passing batches.

  Returns:
    A `TensorTrain` with the same raw shape as `tt`. TT-ranks may become
    smaller than in `tt`: if (left_rank * mode_size) < right_rank, the QR
    factors are narrower. Explicit TT-ranks are not passed to the constructor.

  Complexity:
    O(d r^3 n)
    where
      d is the number of TT-cores (tt.ndims());
      r is the largest TT-rank of tt max(tt.get_tt_ranks())
      n is the size of the axis dimension, e.g.
        for a tensor of size 4 x 4 x 4, n is 4;
        for a 27 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
  """
    # Left to right orthogonalization.
    ndims = tt.ndims()
    is_matrix = tt.is_tt_matrix()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    next_rank = tt_ranks[0]
    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = next_rank
        next_rank = tt_ranks[core_idx + 1]
        if is_matrix:
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (curr_rank * curr_mode, next_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(curr_core)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        next_rank = triang_shape[0]
        if is_matrix:
            new_core_shape = (curr_rank, curr_mode_left, curr_mode_right,
                              next_rank)
        else:
            new_core_shape = (curr_rank, curr_mode, next_rank)
        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        # Absorb the triangular factor into the next core so the product of
        # all cores (i.e. the represented tensor) stays the same.
        next_core = tf.reshape(tt_cores[core_idx + 1], (triang_shape[1], -1))
        tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

    # The last core is stored as a flat matrix after the matmul above; restore
    # its TT-core layout using the (possibly reduced) left rank.
    if is_matrix:
        last_core_shape = (next_rank, raw_shape[0][-1], raw_shape[1][-1], 1)
    else:
        last_core_shape = (next_rank, raw_shape[0][-1], 1)
    tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrain(tt_cores, tt.get_raw_shape())