Example #1
def zeros_like(tt, name='t3f_zeros_like'):
    """Constructs t3f.zeros with the shape of `tt`.

  In the case when `tt` is a TensorTrainBatch constructs t3f.zeros with
  the shape of a TensorTrain in `tt`.

  Args:
    tt: TensorTrain object
    name: string, name of the Op.

  Returns:
    TensorTrain object of the same shape as `tt` but with all entries equal to
    0.

  """
    if not isinstance(tt, TensorTrainBase):
        raise ValueError("`tt` has to be a Tensor Train object")
    else:
        shape = shapes.lazy_raw_shape(tt)
        # I guess variables=tt.tt_cores is not needed here since the output of
        # the function doesn't depend on the values of the TT-cores, only on their
        # shapes etc. But I'm not 100% sure.
        with tf.name_scope(name):
            if tt.is_tt_matrix():
                return matrix_zeros(shape, dtype=tt.dtype)
            else:
                return tensor_zeros(shape[0, :], dtype=tt.dtype)
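A minimal usage sketch for the snippet above, assuming the upstream t3f package and TF 1.x graph mode (t3f.random_tensor, t3f.full and tf.Session are assumptions, not part of the example itself):

import tensorflow as tf
import t3f

tt = t3f.random_tensor((3, 4, 5), tt_rank=2)  # random TT-tensor of shape 3 x 4 x 5
zeros = t3f.zeros_like(tt)                    # same shape, every entry equals 0
with tf.Session() as sess:
    print(sess.run(tf.reduce_sum(t3f.full(zeros))))  # prints 0.0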
Example #2
def _orthogonalize_tt_cores_right_to_left(tt):
    """Orthogonalize TT-cores of a TT-object in the right to left order.

  Args:
    tt: TensorTrain or a TensorTrainBatch.

  Returns:
    The same type as the input `tt` (TensorTrain or a TensorTrainBatch).
  """
    # Right to left orthogonalization.
    ndims = tt.ndims()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    prev_rank = tt_ranks[ndims]
    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1, 0, -1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = prev_rank
        prev_rank = tt_ranks[core_idx]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (prev_rank, curr_mode * curr_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(tf.transpose(curr_core))
        curr_core = tf.transpose(curr_core)
        triang = tf.transpose(triang)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        prev_rank = triang_shape[1]
        if tt.is_tt_matrix():
            new_core_shape = (prev_rank, curr_mode_left, curr_mode_right,
                              curr_rank)
        else:
            new_core_shape = (prev_rank, curr_mode, curr_rank)
        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        prev_core = tf.reshape(tt_cores[core_idx - 1], (-1, triang_shape[0]))
        tt_cores[core_idx - 1] = tf.matmul(prev_core, triang)

    if tt.is_tt_matrix():
        first_core_shape = (1, raw_shape[0][0], raw_shape[1][0], prev_rank)
    else:
        first_core_shape = (1, raw_shape[0][0], prev_rank)
    tt_cores[0] = tf.reshape(tt_cores[0], first_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrain(tt_cores, tt.get_raw_shape())
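A hedged sanity check, assuming the public wrapper t3f.orthogonalize_tt_cores dispatches here when left_to_right=False (as in upstream t3f): orthogonalization must not change the tensor that the TT-cores represent.

import tensorflow as tf
import t3f

a = t3f.random_tensor((4, 4, 4), tt_rank=3)
a_orth = t3f.orthogonalize_tt_cores(a, left_to_right=False)
diff = tf.reduce_max(tf.abs(t3f.full(a) - t3f.full(a_orth)))
with tf.Session() as sess:
    print(sess.run(diff))  # ~0 up to floating point error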
Example #3
def _full_tt_batch(tt):
    """Converts a TensorTrainBatch into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrainBatch` object.

  Returns:
    tf.Tensor.
  """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    res = tt.tt_cores[0]
    batch_size = shapes.lazy_batch_size(tt)
    for i in range(1, num_dims):
        res = tf.reshape(res, (batch_size, -1, ranks[i]))
        curr_core = tf.reshape(tt.tt_cores[i], (batch_size, ranks[i], -1))
        res = tf.einsum('oqb,obw->oqw', res, curr_core)
    if tt.is_tt_matrix():
        intermediate_shape = [batch_size]
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = [0]
        for i in range(0, 2 * num_dims, 2):
            transpose.append(i + 1)
        for i in range(1, 2 * num_dims, 2):
            transpose.append(i + 1)
        res = tf.transpose(res, transpose)
        return tf.reshape(res, shape)
    else:
        return tf.reshape(res, shape)
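A short sketch of how this helper is typically reached, assuming the public t3f.full routes TensorTrainBatch inputs here (as in upstream t3f):

import t3f

batch = t3f.random_tensor_batch((3, 4, 5), tt_rank=2, batch_size=7)
dense = t3f.full(batch)  # tf.Tensor of shape (7, 3, 4, 5)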
Example #4
def multiply(tt_left, right):
    """Returns a TensorTrain corresponding to element-wise product tt_left * right.

  The shapes of tt_left and right should coincide.

  Args:
    tt_left: `TensorTrain`, TT-tensor or TT-matrix
    right: `TensorTrain`, TT-tensor or TT-matrix, OR a number.

  Returns:
    a `TensorTrain` object corresponding to the element-wise product of the
    arguments.

  Raises:
    ValueError if the arguments' shapes do not coincide.
  """
    if not isinstance(right, TensorTrainBase):
        # Assume right is a number, not TensorTrain.
        tt_cores = list(tt_left.tt_cores)
        tt_cores[0] = right * tt_cores[0]
        out_ranks = tt_left.get_tt_ranks()
    else:
        ndims = tt_left.ndims()
        if tt_left.is_tt_matrix() != right.is_tt_matrix():
            raise ValueError('The arguments should be both TT-tensors or both '
                             'TT-matrices')

        if tt_left.get_shape() != right.get_shape():
            raise ValueError('The arguments should have the same shape.')

        a_ranks = shapes.lazy_tt_ranks(tt_left)
        b_ranks = shapes.lazy_tt_ranks(right)
        shape = shapes.lazy_raw_shape(tt_left)

        is_matrix = tt_left.is_tt_matrix()
        tt_cores = []
        for core_idx in range(ndims):
            a_core = tt_left.tt_cores[core_idx]
            b_core = right.tt_cores[core_idx]
            left_rank = a_ranks[core_idx] * b_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
            if is_matrix:
                curr_core = tf.einsum('aijb,cijd->acijbd', a_core, b_core)
                curr_core = tf.reshape(curr_core,
                                       (left_rank, shape[0][core_idx],
                                        shape[1][core_idx], right_rank))
            else:
                curr_core = tf.einsum('aib,cid->acibd', a_core, b_core)
                curr_core = tf.reshape(
                    curr_core, (left_rank, shape[0][core_idx], right_rank))
            tt_cores.append(curr_core)

        combined_ranks = zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
        out_ranks = [a * b for a, b in combined_ranks]

    if isinstance(tt_left, TensorTrain):
        return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
    else:
        return TensorTrainBatch(tt_cores, tt_left.get_raw_shape(), out_ranks,
                                tt_left.batch_size)
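A minimal sketch of the element-wise product, assuming upstream t3f; note that the TT-ranks of the two factors multiply:

import t3f

a = t3f.random_tensor((3, 4, 5), tt_rank=2)  # ranks (1, 2, 2, 1)
b = t3f.random_tensor((3, 4, 5), tt_rank=3)  # ranks (1, 3, 3, 1)
c = t3f.multiply(a, b)                       # ranks (1, 6, 6, 1)
d = t3f.multiply(a, 2.0)                     # scaling by a number keeps the ranks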
Example #5
def _add_matrix_cores(tt_a, tt_b):
    """Internal function to be called from add for two TT-matrices.

  Does the actual assembling of the TT-cores to add two TT-matrices.
  """
    ndims = tt_a.ndims()
    dtype = tt_a.dtype
    shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    b_ranks = shapes.lazy_tt_ranks(tt_b)
    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_a.tt_cores[core_idx]
        b_core = tt_b.tt_cores[core_idx]
        if core_idx == 0:
            curr_core = tf.concat((a_core, b_core), axis=3)
        elif core_idx == ndims - 1:
            curr_core = tf.concat((a_core, b_core), axis=0)
        else:
            upper_zeros = tf.zeros((a_ranks[core_idx], shape[0][core_idx],
                                    shape[1][core_idx], b_ranks[core_idx + 1]),
                                   dtype)
            lower_zeros = tf.zeros((b_ranks[core_idx], shape[0][core_idx],
                                    shape[1][core_idx], a_ranks[core_idx + 1]),
                                   dtype)
            upper = tf.concat((a_core, upper_zeros), axis=3)
            lower = tf.concat((lower_zeros, b_core), axis=3)
            curr_core = tf.concat((upper, lower), axis=0)
        tt_cores.append(curr_core)
    return tt_cores
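The block structure above makes the middle TT-ranks add up. A hedged sketch through the public t3f.add (an assumption; this helper itself is internal):

import t3f

a = t3f.random_matrix(((2, 3), (4, 5)), tt_rank=2)  # 6 x 20 TT-matrix, ranks (1, 2, 1)
b = t3f.random_matrix(((2, 3), (4, 5)), tt_rank=3)  # ranks (1, 3, 1)
s = t3f.add(a, b)                                   # ranks (1, 5, 1)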
Example #6
def _full_tt(tt):
    """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrain` object.

  Returns:
    tf.Tensor.
  """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    res = tt.tt_cores[0]
    for i in range(1, num_dims):
        res = tf.reshape(res, (-1, ranks[i]))
        curr_core = tf.reshape(tt.tt_cores[i], (ranks[i], -1))
        res = tf.matmul(res, curr_core)
    if tt.is_tt_matrix():
        intermediate_shape = []
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = []
        for i in range(0, 2 * num_dims, 2):
            transpose.append(i)
        for i in range(1, 2 * num_dims, 2):
            transpose.append(i)
        res = tf.transpose(res, transpose)
        return tf.reshape(res, shape)
    else:
        return tf.reshape(res, shape)
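For a TT-matrix, the transpose step above interleaves the row and column modes back into matrix order. A small sketch, assuming the public t3f.full wrapper:

import t3f

m = t3f.random_matrix(((2, 3), (4, 5)), tt_rank=2)
dense = t3f.full(m)  # tf.Tensor of shape (6, 20): rows 2 * 3, columns 4 * 5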
Example #7
def tt_sparse_flat_inner(tt_a, sparse_b):
    """Inner product between a TT-tensor (or TT-matrix) and tf.SparseTensor along all axis.

  The shapes of tt_a and sparse_b should coincide.

  Args:
    tt_a: `TensorTrain` object
    sparse_b: tf.SparseTensor

  Returns:
    a number: the sum of products of all the elements of tt_a and sparse_b.
  """
    if sparse_b.indices.get_shape().is_fully_defined():
        num_elements = sparse_b.indices.get_shape()[0]
    else:
        num_elements = tf.shape(sparse_b.indices)[0]
    a_shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    if tt_a.is_tt_matrix():
        tt_a_elements = tf.ones((num_elements, 1, 1))
        # TODO: would using t3f.shape be safer?
        tensor_shape = tt_a.get_raw_shape()
        row_idx_linear = tf.cast(sparse_b.indices[:, 0], tf.int64)
        row_idx = utils.unravel_index(row_idx_linear,
                                      tf.cast(tensor_shape[0], tf.int64))
        col_idx_linear = tf.cast(sparse_b.indices[:, 1], tf.int64)
        col_idx = utils.unravel_index(col_idx_linear,
                                      tf.cast(tensor_shape[1], tf.int64))
        for core_idx in range(tt_a.ndims()):
            curr_core = tt_a.tt_cores[core_idx]
            left_rank = a_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1]
            curr_core = tf.transpose(curr_core, (1, 2, 0, 3))
            curr_core_shape = (a_shape[0][core_idx] * a_shape[1][core_idx],
                               left_rank, right_rank)
            curr_core = tf.reshape(curr_core, curr_core_shape)
            # Ravel the multi-index (row_idx[:, core_idx], col_idx[:, core_idx]) into
            # a linear index, since tf.gather only supports gathering along the
            # first dimension.
            # TODO: use gather_nd instead.
            curr_elements_idx = row_idx[:,
                                        core_idx] * tensor_shape[1][core_idx]
            curr_elements_idx += col_idx[:, core_idx]
            core_slices = tf.gather(curr_core, curr_elements_idx)
            tt_a_elements = tf.matmul(tt_a_elements, core_slices)
    else:
        tt_a_elements = gather_nd(tt_a, sparse_b.indices)
    tt_a_elements = tf.reshape(tt_a_elements, (1, -1))
    sparse_b_elements = tf.reshape(sparse_b.values, (-1, 1))
    result = tf.matmul(tt_a_elements, sparse_b_elements)
    # Convert a 1x1 matrix into a number.
    result = result[0, 0]
    return result
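A hedged usage sketch, assuming the public t3f.flat_inner wrapper dispatches tf.SparseTensor arguments to this routine (as in upstream t3f):

import tensorflow as tf
import t3f

tt = t3f.random_tensor((3, 4, 5), tt_rank=2)
sparse = tf.SparseTensor(indices=[[0, 0, 0], [2, 3, 4]], values=[1.0, 2.0],
                         dense_shape=(3, 4, 5))
dot = t3f.flat_inner(tt, sparse)  # scalar: tt[0, 0, 0] * 1.0 + tt[2, 3, 4] * 2.0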
Example #8
def tt_dense_matmul(tt_matrix_a, matrix_b):
    """Multiplies a TT-matrix by a regular matrix, returns a regular matrix.

  Args:
    tt_matrix_a: `TensorTrain` object containing a TT-matrix of size M x N
    matrix_b: tf.Tensor of size N x P

  Returns:
    tf.Tensor of size M x P
  """
    if not isinstance(tt_matrix_a,
                      TensorTrain) or not tt_matrix_a.is_tt_matrix():
        raise ValueError('The first argument should be a TT-matrix')

    ndims = tt_matrix_a.ndims()
    a_columns = tt_matrix_a.get_shape()[1].value
    b_rows = matrix_b.get_shape()[0].value
    if a_columns is not None and b_rows is not None:
        if a_columns != b_rows:
            raise ValueError(
                'Arguments shapes should align, got %s and %s instead.' %
                (tt_matrix_a.get_shape(), matrix_b.get_shape()))

    a_shape = shapes.lazy_shape(tt_matrix_a)
    a_raw_shape = shapes.lazy_raw_shape(tt_matrix_a)
    if matrix_b.get_shape().is_fully_defined():
        b_shape = matrix_b.get_shape().as_list()
    else:
        b_shape = tf.shape(matrix_b)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    # If A is (i0, ..., id-1) x (j0, ..., jd-1) and B is (j0, ..., jd-1) x K,
    # data is (K, j0, ..., jd-2) x jd-1 x 1
    data = tf.transpose(matrix_b)
    data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
    for core_idx in reversed(range(ndims)):
        curr_core = tt_matrix_a.tt_cores[core_idx]
        # On the k = core_idx iteration, after applying einsum the shape of data
        # becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k
        data = tf.einsum('aijb,rjb->ira', curr_core, data)
        if core_idx > 0:
            # After reshape the shape of data becomes
            # (ik, ..., id-1, K, j0, ..., jk-2) x jk-1 x rank_k
            new_data_shape = (-1, a_raw_shape[1][core_idx - 1],
                              a_ranks[core_idx])
            data = tf.reshape(data, new_data_shape)
    # At the end the shape of the data is (i0, ..., id-1) x K
    return tf.reshape(data, (a_shape[0], b_shape[1]))
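A short sketch, assuming the public t3f.matmul wrapper, which calls this routine when the second argument is a dense tf.Tensor:

import tensorflow as tf
import t3f

A = t3f.random_matrix(((2, 3), (4, 5)), tt_rank=2)  # TT-matrix of size 6 x 20
x = tf.random_normal((20, 7))
y = t3f.matmul(A, x)                                # tf.Tensor of size 6 x 7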
Example #9
def _full_tt(tt):
    """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrain` object.

  Returns:
    tf.Tensor.
  """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    quan_core_list = []
    for i in range(num_dims):
        print('FP core:', tt.tt_cores[i])
        # Both branches quantize the core the same way; `fw` is a quantization
        # op defined elsewhere in the surrounding project (not shown here).
        quan_core_list.append(fw(tt.tt_cores[i]))
        print('8bit core:', quan_core_list[i])
    res = quan_core_list[0]
    # Quan first core
    for i in range(1, num_dims):
        res = tf.reshape(res, (-1, ranks[i]))
        curr_core = tf.reshape(quan_core_list[i], (ranks[i], -1))
        res = tf.matmul(res, curr_core)
        print('core multi FP: ', res)
        # Quan mult cores
        res = fw(res)
        print('core multi 8bit: ', res)
    if tt.is_tt_matrix():
        intermediate_shape = []
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = []
        for i in range(0, 2 * num_dims, 2):
            transpose.append(i)
        for i in range(1, 2 * num_dims, 2):
            transpose.append(i)
        res = tf.transpose(res, transpose)
        return tf.reshape(res, shape)
    else:
        return tf.reshape(res, shape)
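The snippet relies on an externally defined quantization op `fw`. A purely hypothetical stand-in (not the original project's implementation) that makes the snippet runnable is a symmetric 8-bit fake-quantizer:

import tensorflow as tf

def fw(x, num_bits=8):
    # Hypothetical fake-quantization: snap values to a symmetric integer grid and back.
    max_abs = tf.reduce_max(tf.abs(x)) + 1e-12
    scale = (2.0 ** (num_bits - 1) - 1) / max_abs
    return tf.round(x * scale) / scale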
Example #10
def _add_batch_matrix_cores(tt_a, tt_b):
    """Internal function to be called from add for two batches of TT-matrices.

  Does the actual assembling of the TT-cores to add two batches of TT-matrices.
  """
    ndims = tt_a.ndims()
    dtype = tt_a.dtype
    shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    b_ranks = shapes.lazy_tt_ranks(tt_b)
    if isinstance(tt_a, TensorTrainBatch) and tt_a.batch_size == 1:
        # We add 1 element batch tt_a to a batch_size element batch tt_b to get
        # the answer TensorTrainBatch of batch_size == tt_b.batch_size.
        batch_size = shapes.lazy_batch_size(tt_b)
    else:
        batch_size = shapes.lazy_batch_size(tt_a)
    tt_a = shapes.expand_batch_dim(tt_a)
    tt_b = shapes.expand_batch_dim(tt_b)
    tt_cores = []
    for core_idx in range(ndims):
        a_core = tt_a.tt_cores[core_idx]
        if tt_a.batch_size == 1:
            a_core = tf.tile(a_core, (batch_size, 1, 1, 1, 1))
        b_core = tt_b.tt_cores[core_idx]
        if tt_b.batch_size == 1:
            b_core = tf.tile(b_core, (batch_size, 1, 1, 1, 1))
        if core_idx == 0:
            curr_core = tf.concat((a_core, b_core), axis=4)
        elif core_idx == ndims - 1:
            curr_core = tf.concat((a_core, b_core), axis=1)
        else:
            upper_zeros = tf.zeros(
                (batch_size, a_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], b_ranks[core_idx + 1]), dtype)
            lower_zeros = tf.zeros(
                (batch_size, b_ranks[core_idx], shape[0][core_idx],
                 shape[1][core_idx], a_ranks[core_idx + 1]), dtype)
            upper = tf.concat((a_core, upper_zeros), axis=4)
            lower = tf.concat((lower_zeros, b_core), axis=4)
            curr_core = tf.concat((upper, lower), axis=1)
        tt_cores.append(curr_core)
    return tt_cores, batch_size
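As the comment in the snippet indicates, a batch of size 1 is tiled so it can be added to a larger batch. A hedged sketch via the public t3f.add (an assumption; this helper itself is internal):

import t3f

one = t3f.random_matrix_batch(((2, 3), (4, 5)), tt_rank=2, batch_size=1)
many = t3f.random_matrix_batch(((2, 3), (4, 5)), tt_rank=3, batch_size=4)
s = t3f.add(one, many)  # TensorTrainBatch with batch_size 4, middle rank 2 + 3 = 5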
Example #11
def zeros_like(tt):
    """Constructs t3f.zeros with the shape of `tt`.

  In the case when `tt` is a TensorTrainBatch constructs t3f.zeros with
  the shape of a TensorTrain in `tt`.

  Args:
    tt: TensorTrain object

  Returns:
    TensorTrain object of the same shape as `tt` but with all entries equal to
    0.

  """
    if not isinstance(tt, TensorTrainBase):
        raise ValueError("`tt` has to be a Tensor Train object")
    else:
        shape = shapes.lazy_raw_shape(tt)
        if tt.is_tt_matrix():
            return matrix_zeros(shape)
        else:
            return tensor_zeros(shape[0, :])
Example #12
def _round_batch_tt(tt, max_tt_rank, epsilon):
    """Internal function that rounds a TensorTrainBatch.

  See t3f.round for details.
  """
    ndims = tt.ndims()
    max_tt_rank = np.array(max_tt_rank).astype(np.int32)
    if max_tt_rank < 1:
        raise ValueError('Maximum TT-rank should be greater or equal to 1.')
    if epsilon is not None and epsilon < 0:
        raise ValueError('Epsilon should be non-negative.')
    if max_tt_rank.size == 1:
        max_tt_rank = (max_tt_rank * np.ones(ndims + 1)).astype(np.int32)
    elif max_tt_rank.size != ndims + 1:
        raise ValueError(
            'max_tt_rank should be a number or a vector of size (d+1) '
            'where d is the number of dimensions (rank) of the tensor.')
    raw_shape = shapes.lazy_raw_shape(tt)
    batch_size = shapes.lazy_batch_size(tt)

    tt_cores = orthogonalize_tt_cores(tt).tt_cores
    # Copy cores references so we can change the cores.
    tt_cores = list(tt_cores)

    ranks = [1] * (ndims + 1)
    are_tt_ranks_defined = True
    # Right to left SVD compression.
    for core_idx in range(ndims - 1, 0, -1):
        curr_core = tt_cores[core_idx]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        columns = curr_mode * ranks[core_idx + 1]
        curr_core = tf.reshape(curr_core, (batch_size, -1, columns))
        rows = curr_core.get_shape()[1].value
        if rows is None:
            rows = tf.shape(curr_core)[1]
        if max_tt_rank[core_idx] == 1:
            ranks[core_idx] = 1
        else:
            try:
                ranks[core_idx] = min(max_tt_rank[core_idx], rows, columns)
            except TypeError:
                # Some of the values are undefined on the compilation stage and thus
                # they are tf.tensors instead of values.
                min_dim = tf.minimum(rows, columns)
                ranks[core_idx] = tf.minimum(max_tt_rank[core_idx], min_dim)
                are_tt_ranks_defined = False
        s, u, v = tf.svd(curr_core, full_matrices=False)
        u = u[:, :, 0:ranks[core_idx]]
        s = s[:, 0:ranks[core_idx]]
        v = v[:, :, 0:ranks[core_idx]]
        if tt.is_tt_matrix():
            core_shape = (batch_size, ranks[core_idx], curr_mode_left,
                          curr_mode_right, ranks[core_idx + 1])
        else:
            core_shape = (batch_size, ranks[core_idx], curr_mode,
                          ranks[core_idx + 1])
        tt_cores[core_idx] = tf.reshape(tf.transpose(v, (0, 2, 1)), core_shape)
        prev_core_shape = (batch_size, -1, rows)
        tt_cores[core_idx - 1] = tf.reshape(tt_cores[core_idx - 1],
                                            prev_core_shape)
        tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1], u)
        tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1],
                                           tf.matrix_diag(s))

    if tt.is_tt_matrix():
        core_shape = (batch_size, ranks[0], raw_shape[0][0], raw_shape[1][0],
                      ranks[1])
    else:
        core_shape = (batch_size, ranks[0], raw_shape[0][0], ranks[1])
    tt_cores[0] = tf.reshape(tt_cores[0], core_shape)
    if not are_tt_ranks_defined:
        ranks = None
    return TensorTrainBatch(tt_cores,
                            tt.get_raw_shape(),
                            ranks,
                            batch_size=tt.batch_size)
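A minimal sketch, assuming the public t3f.round wrapper, which calls this routine for TensorTrainBatch inputs:

import t3f

batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=6, batch_size=3)
rounded = t3f.round(batch, max_tt_rank=3)  # TT-ranks truncated to at most 3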
Example #13
def _orthogonalize_batch_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a batch TT-object in the left to right order.

  Args:
    tt: TensorTrainBatch.

  Returns:
    TensorTrainBatch
  """
    # Left to right orthogonalization.
    ndims = tt.ndims()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    next_rank = tt_ranks[0]
    batch_size = shapes.lazy_batch_size(tt)

    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = next_rank
        next_rank = tt_ranks[core_idx + 1]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (batch_size, curr_rank * curr_mode, next_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(curr_core)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        next_rank = triang_shape[1]
        if tt.is_tt_matrix():
            new_core_shape = (batch_size, curr_rank, curr_mode_left,
                              curr_mode_right, next_rank)
        else:
            new_core_shape = (batch_size, curr_rank, curr_mode, next_rank)

        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        next_core = tf.reshape(tt_cores[core_idx + 1],
                               (batch_size, triang_shape[2], -1))
        tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

    if tt.is_tt_matrix():
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1],
                           raw_shape[1][-1], 1)
    else:
        last_core_shape = (batch_size, next_rank, raw_shape[0][-1], 1)
    tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrainBatch(tt_cores,
                            tt.get_raw_shape(),
                            batch_size=batch_size)
Example #14
def deltas_to_tangent_space(deltas,
                            tt,
                            left=None,
                            right=None,
                            name='t3f_deltas_to_tangent_space'):
    """Converts deltas representation of tangent space vector to TT object.

  Takes as input a list of [dP1, ..., dPd] and returns
    dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.

  This function is hard to use correctly because the deltas should obey the
  so-called gauge conditions. If they don't, the function will silently return
  an incorrect result. This is why this function is not imported in __init__.

  Args:
      deltas: a list of deltas (essentially TT-cores) obeying the gauge
        conditions.
      tt: `TensorTrain` object on which the tangent space tensor represented by
        delta is projected.
      left: t3f.orthogonalize_tt_cores(tt). If you have already computed it, you
        may pass it as an argument to avoid recomputing.
      right: t3f.orthogonalize_tt_cores(left, left_to_right=False). If you have
        already computed it, you may pass it as an argument to avoid recomputing.
      name: string, name of the Op.

  Returns:
      `TensorTrain` object constructed from deltas, that is from the tangent
        space at point `tt`.
  """
    cores = []
    dtype = tt.dtype
    num_dims = tt.ndims()
    # TODO: add a cache instead of manually passing precomputed stuff?
    input_tensors = list(tt.tt_cores) + list(deltas)
    if left is not None:
        input_tensors += list(left.tt_cores)
    if right is not None:
        input_tensors += list(right.tt_cores)
    with tf.name_scope(name):
        if left is None:
            left = decompositions.orthogonalize_tt_cores(tt)
        if right is None:
            right = decompositions.orthogonalize_tt_cores(left,
                                                          left_to_right=False)
        left_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        right_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        raw_shape = shapes.lazy_raw_shape(left)
        right_rank_dim = left.right_tt_rank_dim
        left_rank_dim = left.left_tt_rank_dim
        is_batch_case = len(deltas[0].shape) > len(tt.tt_cores[0].shape)
        if is_batch_case:
            right_rank_dim += 1
            left_rank_dim += 1
            batch_size = deltas[0].shape.as_list()[0]
        for i in range(num_dims):
            left_tt_core = left.tt_cores[i]
            right_tt_core = right.tt_cores[i]
            if is_batch_case:
                tile = [1] * len(left_tt_core.shape)
                tile = [batch_size] + tile
                left_tt_core = tf.tile(left_tt_core[None, ...], tile)
                right_tt_core = tf.tile(right_tt_core[None, ...], tile)

            if i == 0:
                tangent_core = tf.concat((deltas[i], left_tt_core),
                                         axis=right_rank_dim)
            elif i == num_dims - 1:
                tangent_core = tf.concat((right_tt_core, deltas[i]),
                                         axis=left_rank_dim)
            else:
                rank_1 = right_tangent_tt_ranks[i]
                rank_2 = left_tangent_tt_ranks[i + 1]
                if tt.is_tt_matrix():
                    mode_size_n = raw_shape[0][i]
                    mode_size_m = raw_shape[1][i]
                    shape = [rank_1, mode_size_n, mode_size_m, rank_2]
                else:
                    mode_size_n = raw_shape[0][i]
                    shape = [rank_1, mode_size_n, rank_2]
                if is_batch_case:
                    shape = [batch_size] + shape
                zeros = tf.zeros(shape, dtype=dtype)
                upper = tf.concat((right_tt_core, zeros), axis=right_rank_dim)
                lower = tf.concat((deltas[i], left_tt_core),
                                  axis=right_rank_dim)
                tangent_core = tf.concat((upper, lower), axis=left_rank_dim)
            cores.append(tangent_core)
        if is_batch_case:
            tangent = TensorTrainBatch(cores, batch_size=batch_size)
        else:
            tangent = TensorTrain(cores)
        tangent.projection_on = tt
        return tangent
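A hedged round-trip sketch, assuming this function and tangent_space_to_deltas live in t3f.riemannian as in upstream t3f (neither is exported in __init__); deltas produced by tangent_space_to_deltas satisfy the gauge conditions by construction:

import t3f
from t3f import riemannian

x = t3f.random_tensor((4, 4, 4), tt_rank=2)
z = t3f.random_tensor((4, 4, 4), tt_rank=2)
p = riemannian.project(z, x)                             # lies in the tangent space at x
deltas = riemannian.tangent_space_to_deltas(p)           # gauge conditions hold
p_again = riemannian.deltas_to_tangent_space(deltas, x)  # reconstructs p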
Example #15
def project_sum(what, where, weights=None):
    """Project sum of `what` TTs on the tangent space of `where` TT.

  project_sum(what, x) = P_x(what)
  project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
  project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      projection of the sum of elements in the batch.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    weights: python list or tf.Tensor of numbers or None, weights of the sum

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d r_what r_where n (r_what + r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what: max(what.get_tt_ranks())
    r_where is the largest TT-rank of where
    n is the size of the axis dimension of what and where e.g.
      for a tensor of size 4 x 4 x 4, n is 4;
      for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
  """
    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    if weights is not None:
        weights = tf.convert_to_tensor(weights, dtype=where.dtype)

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = where.right_tt_rank_dim
    left_rank_dim = where.left_tt_rank_dim
    if weights is not None:
        weights_shape = weights.get_shape()
        output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
    else:
        output_is_batch = False
    output_batch_str = 'o' if output_is_batch else ''
    if output_is_batch:
        right_rank_dim += 1
        left_rank_dim += 1
        output_batch_size = weights.get_shape().as_list()[1]

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            if weights is None:
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
            else:
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, proj_core,
                                        rhs[core_idx + 1])
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if core_idx == ndims - 1:
            if weights is None:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            else:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(
                    mode_str, output_batch_str)
                proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(extended_left_tang_core,
                                                  [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)
    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
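A short usage sketch, assuming t3f.project_sum is re-exported at the top level as in upstream t3f:

import t3f

what = t3f.random_tensor_batch((4, 4, 4), tt_rank=2, batch_size=5)
where = t3f.random_tensor((4, 4, 4), tt_rank=3)
proj = t3f.project_sum(what, where)  # P_where(sum of the 5 tensors); TT-ranks are 2 * where's ranks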
Example #16
def project_matmul(what, where, matrix):
    """Project `matrix` * `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project
    matrix: TensorTrain, TT-matrix to multiply by what

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()
      
  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what: max(what.get_tt_ranks())
    r_where is the largest TT-rank of where
    matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
  """

    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype)
                         or what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    right_rank_dim = what.right_tt_rank_dim
    left_rank_dim = what.left_tt_rank_dim
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                                  right_tang_core, rhs[core_idx + 1],
                                  tens_core)
    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        # TODO: brute-force the order of indices in lhs??
        lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef',
                                      matrix_core, left_tang_core,
                                      lhs[core_idx], tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                                  lhs[core_idx], matrix_core)
            proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                                   lhs[core_idx + 1])
            proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core,
                                  rhs[core_idx + 1])

        if core_idx == ndims - 1:
            # d and e dimensions take 1 value, since it's the last rank.
            # To make the result shape (?, ?, ?, 1), we are summing d and leaving e,
            # but we could have done the opposite -- sum e and leave d.
            proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx],
                                  matrix_core, tens_core)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core and
            # right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            extended_left_tang_core = tf.tile(extended_left_tang_core,
                                              [output_batch_size, 1, 1, 1, 1])
            extended_right_tang_core = tf.tile(extended_right_tang_core,
                                               [output_batch_size, 1, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            mode_size_n = raw_shape[0][core_idx]
            mode_size_m = raw_shape[1][core_idx]
            shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list,
                               where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
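A hedged sketch, assuming t3f.project_matmul is re-exported at the top level as in upstream t3f; here `x` is a TT-"vector" stored as a TT-matrix with trivial column modes and `A` is a square TT-matrix, so the result is P_x(A x):

import t3f

A = t3f.random_matrix(((3, 3), (3, 3)), tt_rank=2)  # 9 x 9 TT-matrix
x = t3f.random_matrix(((3, 3), (1, 1)), tt_rank=2)  # 9 x 1 TT-"vector"
proj = t3f.project_matmul(x, x, A)                  # P_x(A x)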
Example #17
def project(what, where):
  """Project `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to project

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()
  """

  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(),
                      what.get_raw_shape()))

  if not where.dtype.is_compatible_with(what.dtype):
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.' %
                     (where.dtype,
                      what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
    left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  right_rank_dim = 3 if where.is_tt_matrix() else 2
  left_rank_dim = 0
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim = 1
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      if output_is_batch:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      if output_is_batch:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core), axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros), axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                            batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
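A small sanity-check sketch, assuming the top-level t3f.project: a TT-tensor lies in its own tangent space, so projecting it onto itself returns it unchanged (represented with doubled TT-ranks):

import tensorflow as tf
import t3f

x = t3f.random_tensor((4, 4, 4), tt_rank=3)
px = t3f.project(x, x)  # P_x(x) == x, with TT-ranks 2 * 3
err = tf.reduce_max(tf.abs(t3f.full(px) - t3f.full(x)))
with tf.Session() as sess:
    print(sess.run(err))  # ~0 up to floating point error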
Example #18
def multiply(tt_left, right):
    """Returns a TensorTrain corresponding to element-wise product tt_left * right.

  Supports broadcasting:
    multiply(TensorTrainBatch, TensorTrain) returns TensorTrainBatch consisting
    of element-wise products of TT in TensorTrainBatch and TensorTrain

    multiply(TensorTrainBatch_a, TensorTrainBatch_b) returns TensorTrainBatch
    consisting of element-wise products of TT in TensorTrainBatch_a and
    TT in TensorTrainBatch_b

    Batch sizes should support broadcasting.

  Args:
    tt_left: `TensorTrain` OR `TensorTrainBatch`
    right: `TensorTrain` OR `TensorTrainBatch` OR a number.

  Returns:
    a `TensorTrain` or `TensorTrainBatch` object corresponding to the
    element-wise product of the arguments.

  Raises:
    ValueError if the arguments' shapes do not coincide or broadcasting is not
    possible.
  """

    is_left_batch = isinstance(tt_left, TensorTrainBatch)
    is_right_batch = isinstance(right, TensorTrainBatch)

    is_batch_case = is_left_batch or is_right_batch
    ndims = tt_left.ndims()
    if not isinstance(right, TensorTrainBase):
        # Assume right is a number, not TensorTrain.
        # To squash right uniformly across TT-cores we pull its absolute value
        # and raise to the power 1/ndims. First TT-core is multiplied by the sign
        # of right.
        tt_cores = list(tt_left.tt_cores)
        fact = tf.pow(tf.cast(tf.abs(right), tt_left.dtype), 1.0 / ndims)
        sign = tf.cast(tf.sign(right), tt_left.dtype)
        for i in range(len(tt_cores)):
            tt_cores[i] = fact * tt_cores[i]

        tt_cores[0] = tt_cores[0] * sign
        out_ranks = tt_left.get_tt_ranks()
        if is_left_batch:
            out_batch_size = tt_left.batch_size
    else:

        if tt_left.is_tt_matrix() != right.is_tt_matrix():
            raise ValueError('The arguments should be both TT-tensors or both '
                             'TT-matrices')

        if tt_left.get_raw_shape() != right.get_raw_shape():
            raise ValueError('The arguments should have the same shape.')

        out_batch_size = 1
        dependencies = []
        can_determine_if_broadcast = True
        if is_left_batch and is_right_batch:
            if tt_left.batch_size is None and right.batch_size is None:
                can_determine_if_broadcast = False
            elif tt_left.batch_size is None and right.batch_size is not None:
                if right.batch_size > 1:
                    can_determine_if_broadcast = False
            elif tt_left.batch_size is not None and right.batch_size is None:
                if tt_left.batch_size > 1:
                    can_determine_if_broadcast = False

        if not can_determine_if_broadcast:
            # Cannot determine if broadcasting is needed. Avoid broadcasting and
            # assume elementwise multiplication AND add execution time assert to print
            # a better error message if the batch sizes turn out to be different.

            message = (
                'The batch sizes were unknown on compilation stage, so '
                'assumed elementwise multiplication (i.e. no broadcasting). '
                'Now it seems that they are different after all :')

            data = [
                message,
                shapes.lazy_batch_size(tt_left), ' x ',
                shapes.lazy_batch_size(right)
            ]
            bs_eq = tf.assert_equal(shapes.lazy_batch_size(tt_left),
                                    shapes.lazy_batch_size(right),
                                    data=data)

            dependencies.append(bs_eq)

        do_broadcast = shapes.is_batch_broadcasting_possible(tt_left, right)
        if not can_determine_if_broadcast:
            # Assume elementwise multiplication if broadcasting cannot be determined
            # on compilation stage.
            do_broadcast = False
        if not do_broadcast and can_determine_if_broadcast:
            raise ValueError(
                'The batch sizes are different and not 1, broadcasting '
                'is not available.')

        a_ranks = shapes.lazy_tt_ranks(tt_left)
        b_ranks = shapes.lazy_tt_ranks(right)
        shape = shapes.lazy_raw_shape(tt_left)

        output_str = ''
        bs_str_left = ''
        bs_str_right = ''

        if is_batch_case:
            if is_left_batch and is_right_batch:
                # Both arguments are batches of equal size.
                if tt_left.batch_size == right.batch_size or not can_determine_if_broadcast:
                    bs_str_left = 'n'
                    bs_str_right = 'n'
                    output_str = 'n'
                    if not can_determine_if_broadcast:
                        out_batch_size = None
                    else:
                        out_batch_size = tt_left.batch_size
                else:
                    # Broadcasting (e.g batch_sizes are 1 and n>1).
                    bs_str_left = 'n'
                    bs_str_right = 'm'
                    output_str = 'nm'
                    if tt_left.batch_size is None or tt_left.batch_size > 1:
                        out_batch_size = tt_left.batch_size
                    else:
                        out_batch_size = right.batch_size
            else:
                # One of the arguments is TensorTrain.
                if is_left_batch:
                    bs_str_left = 'n'
                    bs_str_right = ''
                    out_batch_size = tt_left.batch_size
                else:
                    bs_str_left = ''
                    bs_str_right = 'n'
                    out_batch_size = right.batch_size
                output_str = 'n'

        is_matrix = tt_left.is_tt_matrix()
        tt_cores = []

        for core_idx in range(ndims):
            a_core = tt_left.tt_cores[core_idx]
            b_core = right.tt_cores[core_idx]
            left_rank = a_ranks[core_idx] * b_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
            if is_matrix:
                with tf.control_dependencies(dependencies):
                    curr_core = tf.einsum(
                        '{0}aijb,{1}cijd->{2}acijbd'.format(
                            bs_str_left, bs_str_right, output_str), a_core,
                        b_core)
                    curr_core = tf.reshape(curr_core,
                                           (-1, left_rank, shape[0][core_idx],
                                            shape[1][core_idx], right_rank))
                    if not is_batch_case:
                        curr_core = tf.squeeze(curr_core, axis=0)
            else:
                with tf.control_dependencies(dependencies):
                    curr_core = tf.einsum(
                        '{0}aib,{1}cid->{2}acibd'.format(
                            bs_str_left, bs_str_right, output_str), a_core,
                        b_core)
                    curr_core = tf.reshape(
                        curr_core,
                        (-1, left_rank, shape[0][core_idx], right_rank))
                    if not is_batch_case:
                        curr_core = tf.squeeze(curr_core, axis=0)

            tt_cores.append(curr_core)

        combined_ranks = zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
        out_ranks = [a * b for a, b in combined_ranks]

    if not is_batch_case:
        return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
    else:
        return TensorTrainBatch(tt_cores,
                                tt_left.get_raw_shape(),
                                out_ranks,
                                batch_size=out_batch_size)
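The function above builds each core of the elementwise product as a slice-wise Kronecker product of the two input cores (the einsum followed by a reshape), so the TT-ranks of the result are the products of the input ranks. Below is a minimal standalone NumPy sketch of that core construction, not taken from the library; tt_to_full is a hypothetical helper written only for the final check.

import numpy as np

def tt_to_full(cores):
    """Contract a list of 3D TT-cores (r_left, n, r_right) into a full tensor."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res.squeeze(axis=(0, -1))

rng = np.random.RandomState(0)
# Two random 3-dimensional TT-tensors with mode sizes (3, 4, 5).
a_cores = [rng.randn(1, 3, 2), rng.randn(2, 4, 2), rng.randn(2, 5, 1)]
b_cores = [rng.randn(1, 3, 3), rng.randn(3, 4, 3), rng.randn(3, 5, 1)]

prod_cores = []
for g, h in zip(a_cores, b_cores):
    # Same contraction as the einsum 'aib,cid->acibd' above, then the reshape
    # that merges the two left ranks and the two right ranks.
    core = np.einsum('aib,cid->acibd', g, h)
    r1, s1, n, r2, s2 = core.shape
    prod_cores.append(core.reshape(r1 * s1, n, r2 * s2))

# The contracted result equals the elementwise (Hadamard) product of the inputs.
assert np.allclose(tt_to_full(prod_cores),
                   tt_to_full(a_cores) * tt_to_full(b_cores))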
Exemple #19
0
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
    """Multiplies two TT-matrices and returns the TT-matrix of the result.

  Args:
    tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size M x N
    tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size N x P

  Returns:
    `TensorTrain` object containing a TT-matrix of size M x P if both arguments
      are `TensorTrain`s
    `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`

  Raises:
    ValueError if the arguments are not TT-matrices or if their sizes are not
    appropriate for a matrix-by-matrix multiplication.
  """
    # Both TensorTrain and TensorTrainBatch inherit from TensorTrainBase.
    if not isinstance(tt_matrix_a, TensorTrainBase) or \
        not isinstance(tt_matrix_b, TensorTrainBase) or \
        not tt_matrix_a.is_tt_matrix() or \
        not tt_matrix_b.is_tt_matrix():
        raise ValueError('Arguments should be TT-matrices')

    if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
        raise ValueError(
            'The batch sizes are different and not 1, broadcasting is '
            'not available.')

    ndims = tt_matrix_a.ndims()
    if tt_matrix_b.ndims() != ndims:
        raise ValueError(
            'Arguments should have the same number of dimensions, '
            'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

    # Convert a batch of size 1 into a TensorTrain object to simplify broadcasting.
    tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
    tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
    is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
    is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
    is_res_batch = is_a_batch or is_b_batch
    a_batch_str = 'o' if is_a_batch else ''
    b_batch_str = 'o' if is_b_batch else ''
    res_batch_str = 'o' if is_res_batch else ''
    einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                  res_batch_str)
    result_cores = []
    # TODO: name the operation and the resulting tensor.
    a_shape = shapes.lazy_raw_shape(tt_matrix_a)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    b_shape = shapes.lazy_raw_shape(tt_matrix_b)
    b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
    if is_res_batch:
        if is_a_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_a)
        if is_b_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_b)
    for core_idx in range(ndims):
        a_core = tt_matrix_a.tt_cores[core_idx]
        b_core = tt_matrix_b.tt_cores[core_idx]
        curr_res_core = tf.einsum(einsum_str, a_core, b_core)

        res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
        res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
        left_mode = a_shape[0][core_idx]
        right_mode = b_shape[1][core_idx]
        if is_res_batch:
            core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                          res_right_rank)
        else:
            core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
        curr_res_core = tf.reshape(curr_res_core, core_shape)
        result_cores.append(curr_res_core)

    res_shape = (tt_matrix_a.get_raw_shape()[0],
                 tt_matrix_b.get_raw_shape()[1])
    static_a_ranks = tt_matrix_a.get_tt_ranks()
    static_b_ranks = tt_matrix_b.get_tt_ranks()
    out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
    if is_res_batch:
        return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
    else:
        return TensorTrain(result_cores, res_shape, out_ranks)
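To see what the per-core contraction in tt_tt_matmul computes, the following standalone NumPy sketch (not from the library) reproduces it on two small 2-core TT-matrices and checks the result against a dense matrix product; tt_matrix_to_full is a hypothetical helper written only for this check.

import numpy as np

def tt_matrix_to_full(cores):
    """Contract 4D TT-matrix cores (r_left, i, j, r_right) into a dense matrix."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    res = res.squeeze(axis=(0, -1))  # shape (i1, j1, i2, j2, ...)
    d = len(cores)
    row_modes, col_modes = res.shape[0::2], res.shape[1::2]
    perm = list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2))
    return res.transpose(perm).reshape(int(np.prod(row_modes)),
                                       int(np.prod(col_modes)))

rng = np.random.RandomState(0)
# A is a 6 x 6 TT-matrix of raw shape (2, 3) x (2, 3),
# B is a 6 x 4 TT-matrix of raw shape (2, 3) x (2, 2).
a_cores = [rng.randn(1, 2, 2, 3), rng.randn(3, 3, 3, 1)]
b_cores = [rng.randn(1, 2, 2, 2), rng.randn(2, 3, 2, 1)]

res_cores = []
for g, h in zip(a_cores, b_cores):
    # Same einsum as above: sum over the shared mode j, keep both rank pairs.
    core = np.einsum('aijb,cjkd->acikbd', g, h)
    r1, s1, i, k, r2, s2 = core.shape
    res_cores.append(core.reshape(r1 * s1, i, k, r2 * s2))

# The TT-cores built this way represent exactly the matrix product A @ B.
assert np.allclose(tt_matrix_to_full(res_cores),
                   tt_matrix_to_full(a_cores) @ tt_matrix_to_full(b_cores))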
Exemple #20
0
def _orthogonalize_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a TT-object in the left to right order.
  Args:
    tt: TenosorTrain or a TensorTrainBatch.
  Returns:
    The same type as the input `tt` (TenosorTrain or a TensorTrainBatch).
  
  Complexity:
    for a single TT-object:
      O(d r^3 n)
    for a batch of TT-objects:
      O(batch_size d r^3 n)
    where
      d is the number of TT-cores (tt.ndims());
      r is the largest TT-rank of tt, i.e. max(tt.get_tt_ranks());
      n is the size of the axis dimension, e.g.
        for a tensor of size 4 x 4 x 4, n is 4;
        for a 27 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4), n is 12.
  """
    # Left to right orthogonalization.
    ndims = tt.ndims()
    raw_shape = shapes.lazy_raw_shape(tt)
    tt_ranks = shapes.lazy_tt_ranks(tt)
    next_rank = tt_ranks[0]
    # Copy cores references so we can change the cores.
    tt_cores = list(tt.tt_cores)
    for core_idx in range(ndims - 1):
        curr_core = tt_cores[core_idx]
        # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
        # be outdated for the current TT-rank, but should be valid for the next
        # TT-rank.
        curr_rank = next_rank
        next_rank = tt_ranks[core_idx + 1]
        if tt.is_tt_matrix():
            curr_mode_left = raw_shape[0][core_idx]
            curr_mode_right = raw_shape[1][core_idx]
            curr_mode = curr_mode_left * curr_mode_right
        else:
            curr_mode = raw_shape[0][core_idx]

        qr_shape = (curr_rank * curr_mode, next_rank)
        curr_core = tf.reshape(curr_core, qr_shape)
        curr_core, triang = tf.qr(curr_core)
        if triang.get_shape().is_fully_defined():
            triang_shape = triang.get_shape().as_list()
        else:
            triang_shape = tf.shape(triang)
        # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
        # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
        # should be changed to 4.
        next_rank = triang_shape[0]
        if tt.is_tt_matrix():
            new_core_shape = (curr_rank, curr_mode_left, curr_mode_right,
                              next_rank)
        else:
            new_core_shape = (curr_rank, curr_mode, next_rank)
        tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

        next_core = tf.reshape(tt_cores[core_idx + 1], (triang_shape[1], -1))
        tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

    if tt.is_tt_matrix():
        last_core_shape = (next_rank, raw_shape[0][-1], raw_shape[1][-1], 1)
    else:
        last_core_shape = (next_rank, raw_shape[0][-1], 1)
    tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
    # TODO: infer the tt_ranks.
    return TensorTrain(tt_cores, tt.get_raw_shape())
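The same sweep is easy to prototype outside of TensorFlow. The standalone NumPy sketch below (not from the library) mirrors the loop above: each core is reshaped to a matrix and QR-factorized, the orthogonal factor becomes the new core, and the triangular factor is pushed into the next core. The represented tensor is unchanged while every non-last core becomes left-orthogonal; tt_to_full is the same hypothetical helper as in the earlier sketch, used only for the checks.

import numpy as np

def tt_to_full(cores):
    """Contract a list of 3D TT-cores (r_left, n, r_right) into a full tensor."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res.squeeze(axis=(0, -1))

def orthogonalize_left_to_right(cores):
    cores = list(cores)
    for k in range(len(cores) - 1):
        r1, n, r2 = cores[k].shape
        # QR of the (r1 * n) x r2 unfolding, as in the TensorFlow code above.
        q, r = np.linalg.qr(cores[k].reshape(r1 * n, r2))
        new_r2 = q.shape[1]  # the TT-rank may shrink here
        cores[k] = q.reshape(r1, n, new_r2)
        # Absorb the triangular factor into the next core.
        cores[k + 1] = np.tensordot(r, cores[k + 1], axes=([1], [0]))
    return cores

rng = np.random.RandomState(0)
cores = [rng.randn(1, 3, 4), rng.randn(4, 4, 3), rng.randn(3, 5, 1)]
orth = orthogonalize_left_to_right(cores)

# The full tensor is preserved by the sweep ...
assert np.allclose(tt_to_full(orth), tt_to_full(cores))
# ... and each non-last core is left-orthogonal: the columns of its
# (r_left * n) x r_right unfolding are orthonormal.
for core in orth[:-1]:
    mat = core.reshape(-1, core.shape[-1])
    assert np.allclose(mat.T @ mat, np.eye(mat.shape[-1]))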