Example #1
File: ops.py Project: vseledkin/t3f
def tt_tt_flat_inner(tt_a, tt_b):
    """Inner product between two TT-tensors or TT-matrices along all axis.

  The shapes of tt_a and tt_b should coincide.

  Args:
    tt_a: `TensorTrain` or `TensorTrainBatch` object
    tt_b: `TensorTrain` or `TensorTrainBatch` object

  Returns
    a number or a Tensor with numbers for each element in the batch.
    sum of products of all the elements of tt_a and tt_b

  Raises:
    ValueError if the arguments are not `TensorTrain` objects, have different
      number of TT-cores, different underlying shape, or if you are trying to
      compute inner product between a TT-matrix and a TT-tensor.
  """
    if not isinstance(tt_a, TensorTrainBase) or not isinstance(
            tt_b, TensorTrainBase):
        raise ValueError('Arguments should be TensorTrains')

    if tt_a.is_tt_matrix() != tt_b.is_tt_matrix():
        raise ValueError('One of the arguments is a TT-tensor, the other is '
                         'a TT-matrix; mixing the two is disallowed')
    are_both_matrices = tt_a.is_tt_matrix() and tt_b.is_tt_matrix()

    if not shapes.is_batch_broadcasting_possible(tt_a, tt_b):
        raise ValueError(
            'The batch sizes are different and not equal to 1; '
            'broadcasting is not available.')

    # TODO: compare shapes and raise if not consistent.

    ndims = tt_a.ndims()
    if tt_b.ndims() != ndims:
        raise ValueError(
            'Arguments should have the same number of dimensions, '
            'got %d and %d instead.' % (ndims, tt_b.ndims()))

    axes_str = 'ij' if are_both_matrices else 'i'
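    # (TT-matrix cores carry two mode axes, hence 'ij'; TT-tensor cores
    # carry one, hence 'i'.)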
    # Convert a batch of size 1 into a plain TT object to simplify broadcasting.
    tt_a = shapes.squeeze_batch_dim(tt_a)
    tt_b = shapes.squeeze_batch_dim(tt_b)
    is_a_batch = isinstance(tt_a, TensorTrainBatch)
    is_b_batch = isinstance(tt_b, TensorTrainBatch)
    is_res_batch = is_a_batch or is_b_batch
    a_batch_str = 'o' if is_a_batch else ''
    b_batch_str = 'o' if is_b_batch else ''
    res_batch_str = 'o' if is_res_batch else ''
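    # The optional 'o' axis is the batch dimension: it appears on an input
    # iff that input is a TensorTrainBatch, and on the output iff at least
    # one input is a batch.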
    init_einsum_str = '{1}a{0}b,{2}c{0}d->{3}bd'.format(
        axes_str, a_batch_str, b_batch_str, res_batch_str)
    a_core = tt_a.tt_cores[0]
    b_core = tt_b.tt_cores[0]
    # Simplest example of this operation:
    # if both arguments are TT-tensors, then it is
    # res = tf.einsum('aib,cid->bd', a_core, b_core)
    res = tf.einsum(init_einsum_str, a_core, b_core)
    # TODO: name the operation and the resulting tensor.

    einsum_str = '{3}ac,{1}a{0}b,{2}c{0}d->{3}bd'.format(
        axes_str, a_batch_str, b_batch_str, res_batch_str)
    for core_idx in range(1, ndims):
        a_core = tt_a.tt_cores[core_idx]
        b_core = tt_b.tt_cores[core_idx]
        # Simplest example of this operation:
        # if both arguments are TT-tensors, then it is
        # res = tf.einsum('ac,aib,cid->bd', res, a_core, b_core)
        res = tf.einsum(einsum_str, res, a_core, b_core)
    return tf.squeeze(res)
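
As a sanity check on the contraction above, here is a minimal NumPy sketch of the same sweep in the non-batch TT-tensor case. The helpers (random_tt_cores, tt_to_full, np_tt_flat_inner) are made up for this illustration and are not part of t3f; the assertion compares the result against the dense inner product.

import numpy as np

def random_tt_cores(modes, rank, seed=0):
    # Hypothetical helper: random TT-cores G_k of shape (r_{k-1}, n_k, r_k)
    # with boundary ranks r_0 = r_d = 1.
    rng = np.random.default_rng(seed)
    ranks = [1] + [rank] * (len(modes) - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1]))
            for k, n in enumerate(modes)]

def tt_to_full(cores):
    # Contract all cores into the dense tensor (for verification only).
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=[[-1], [0]])
    return res.squeeze(axis=(0, -1))

def np_tt_flat_inner(a_cores, b_cores):
    # The same core-by-core sweep as tt_tt_flat_inner, TT-tensor case.
    res = np.einsum('aib,cid->bd', a_cores[0], b_cores[0])
    for a_core, b_core in zip(a_cores[1:], b_cores[1:]):
        res = np.einsum('ac,aib,cid->bd', res, a_core, b_core)
    return res.squeeze()

a_cores = random_tt_cores((3, 4, 5), rank=2, seed=0)
b_cores = random_tt_cores((3, 4, 5), rank=3, seed=1)
dense = np.sum(tt_to_full(a_cores) * tt_to_full(b_cores))
assert np.allclose(np_tt_flat_inner(a_cores, b_cores), dense)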
Example #2
File: ops.py Project: vseledkin/t3f
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
    """Multiplies two TT-matrices and returns the TT-matrix of the result.

  Args:
    tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size M x N
    tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size N x P

  Returns
    `TensorTrain` object containing a TT-matrix of size M x P if both arguments
      are `TensorTrain`s
    `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`

  Raises:
    ValueError is the arguments are not TT matrices or if their sizes are not
    appropriate for a matrix-by-matrix multiplication.
  """
    # Both TensorTrain and TensorTrainBatch are inherited from TensorTrainBase.
    if not isinstance(tt_matrix_a, TensorTrainBase) or \
        not isinstance(tt_matrix_b, TensorTrainBase) or \
        not tt_matrix_a.is_tt_matrix() or \
        not tt_matrix_b.is_tt_matrix():
        raise ValueError('Arguments should be TT-matrices')

    if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
        raise ValueError(
            'The batch sizes are different and not equal to 1; '
            'broadcasting is not available.')

    ndims = tt_matrix_a.ndims()
    if tt_matrix_b.ndims() != ndims:
        raise ValueError(
            'Arguments should have the same number of dimensions, '
            'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

    # Convert a batch of size 1 into a plain TT object to simplify broadcasting.
    tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
    tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
    is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
    is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
    is_res_batch = is_a_batch or is_b_batch
    a_batch_str = 'o' if is_a_batch else ''
    b_batch_str = 'o' if is_b_batch else ''
    res_batch_str = 'o' if is_res_batch else ''
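    # The contraction below sums over the shared inner mode j and keeps the
    # rank pairs (a, c) and (b, d); the reshape further down folds each pair
    # into a single TT-rank axis of the result.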
    einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                  res_batch_str)
    result_cores = []
    # TODO: name the operation and the resulting tensor.
    a_shape = shapes.lazy_raw_shape(tt_matrix_a)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    b_shape = shapes.lazy_raw_shape(tt_matrix_b)
    b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
    if is_res_batch:
        if is_a_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_a)
        if is_b_batch:
            batch_size = shapes.lazy_batch_size(tt_matrix_b)
    for core_idx in range(ndims):
        a_core = tt_matrix_a.tt_cores[core_idx]
        b_core = tt_matrix_b.tt_cores[core_idx]
        curr_res_core = tf.einsum(einsum_str, a_core, b_core)

        res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
        res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
        left_mode = a_shape[0][core_idx]
        right_mode = b_shape[1][core_idx]
        if is_res_batch:
            core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                          res_right_rank)
        else:
            core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
        curr_res_core = tf.reshape(curr_res_core, core_shape)
        result_cores.append(curr_res_core)

    res_shape = (tt_matrix_a.get_raw_shape()[0],
                 tt_matrix_b.get_raw_shape()[1])
    static_a_ranks = tt_matrix_a.get_tt_ranks()
    static_b_ranks = tt_matrix_b.get_tt_ranks()
    out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
    if is_res_batch:
        return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
    else:
        return TensorTrain(result_cores, res_shape, out_ranks)
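
Again as a sanity check, a minimal NumPy sketch of the non-batch case: the per-core einsum merges the rank pairs, so the TT-ranks of the product are element-wise products of the input ranks, exactly as out_ranks computes above. The helpers (random_ttm_cores, ttm_to_full, np_tt_matmul_cores) are hypothetical names used only for this illustration; the assertion compares against dense matrix multiplication.

import numpy as np

def random_ttm_cores(row_modes, col_modes, rank, seed=0):
    # Hypothetical helper: random TT-matrix cores of shape
    # (r_{k-1}, m_k, n_k, r_k) with boundary ranks 1.
    rng = np.random.default_rng(seed)
    d = len(row_modes)
    ranks = [1] + [rank] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], row_modes[k], col_modes[k],
                                 ranks[k + 1])) for k in range(d)]

def ttm_to_full(cores):
    # Contract a TT-matrix into a dense M x N matrix (verification only).
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=[[-1], [0]])
    res = res.squeeze(axis=(0, -1))  # shape (m_1, n_1, ..., m_d, n_d)
    d = len(cores)
    res = res.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))
    m = np.prod([c.shape[1] for c in cores])
    n = np.prod([c.shape[2] for c in cores])
    return res.reshape(m, n)

def np_tt_matmul_cores(a_cores, b_cores):
    # The same per-core contraction as tt_tt_matmul, non-batch case.
    res_cores = []
    for a_core, b_core in zip(a_cores, b_cores):
        curr = np.einsum('aijb,cjkd->acikbd', a_core, b_core)
        ra, m, _, rb = a_core.shape
        rc, _, p, rd = b_core.shape
        res_cores.append(curr.reshape(ra * rc, m, p, rb * rd))
    return res_cores

a_cores = random_ttm_cores((2, 3), (3, 4), rank=2, seed=0)  # a 6 x 12 matrix
b_cores = random_ttm_cores((3, 4), (2, 2), rank=3, seed=1)  # a 12 x 4 matrix
product = ttm_to_full(np_tt_matmul_cores(a_cores, b_cores))
assert np.allclose(product, ttm_to_full(a_cores) @ ttm_to_full(b_cores))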