def tt_tt_flat_inner(tt_a, tt_b):
  """Inner product between two TT-tensors or TT-matrices along all axes.

  The shapes of tt_a and tt_b should coincide.

  Args:
    tt_a: `TensorTrain` or `TensorTrainBatch` object
    tt_b: `TensorTrain` or `TensorTrainBatch` object

  Returns
    a number or a Tensor with one number for each element in the batch:
    the sum of products of all the elements of tt_a and tt_b.

  Raises:
    ValueError if the arguments are not `TensorTrain` objects, have different
      number of TT-cores, different underlying shape, or if you are trying to
      compute inner product between a TT-matrix and a TT-tensor.
  """
  if not isinstance(tt_a, TensorTrainBase) or not isinstance(
      tt_b, TensorTrainBase):
    raise ValueError('Arguments should be TensorTrains')

  if tt_a.is_tt_matrix() != tt_b.is_tt_matrix():
    raise ValueError('One of the arguments is a TT-tensor, the other is '
                     'a TT-matrix, disallowed')
  are_both_matrices = tt_a.is_tt_matrix() and tt_b.is_tt_matrix()

  if not shapes.is_batch_broadcasting_possible(tt_a, tt_b):
    raise ValueError('The batch sizes are different and not 1, broadcasting '
                     'is not available.')

  # TODO: compare shapes and raise if not consistent.

  ndims = tt_a.ndims()
  if tt_b.ndims() != ndims:
    raise ValueError('Arguments should have the same number of dimensions, '
                     'got %d and %d instead.' % (ndims, tt_b.ndims()))

  axes_str = 'ij' if are_both_matrices else 'i'
  # Convert batch-size-1 batches into plain TT objects to simplify
  # broadcasting.
  tt_a = shapes.squeeze_batch_dim(tt_a)
  tt_b = shapes.squeeze_batch_dim(tt_b)
  is_a_batch = isinstance(tt_a, TensorTrainBatch)
  is_b_batch = isinstance(tt_b, TensorTrainBatch)
  is_res_batch = is_a_batch or is_b_batch
  a_batch_str = 'o' if is_a_batch else ''
  b_batch_str = 'o' if is_b_batch else ''
  res_batch_str = 'o' if is_res_batch else ''
  init_einsum_str = '{1}a{0}b,{2}c{0}d->{3}bd'.format(axes_str, a_batch_str,
                                                      b_batch_str,
                                                      res_batch_str)
  a_core = tt_a.tt_cores[0]
  b_core = tt_b.tt_cores[0]
  # Simplest example of this operation:
  # if both arguments are TT-tensors, then it is
  # res = tf.einsum('aib,cid->bd', a_core, b_core)
  res = tf.einsum(init_einsum_str, a_core, b_core)
  # TODO: name the operation and the resulting tensor.

  einsum_str = '{3}ac,{1}a{0}b,{2}c{0}d->{3}bd'.format(axes_str, a_batch_str,
                                                       b_batch_str,
                                                       res_batch_str)
  for core_idx in range(1, ndims):
    a_core = tt_a.tt_cores[core_idx]
    b_core = tt_b.tt_cores[core_idx]
    # Simplest example of this operation:
    # if both arguments are TT-tensors, then it is
    # res = tf.einsum('ac,aib,cid->bd', res, a_core, b_core)
    res = tf.einsum(einsum_str, res, a_core, b_core)
  return tf.squeeze(res)
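
# Illustrative usage sketch (not part of the original module): it builds two
# small TT-tensors from explicit cores and takes their flat inner product.
# It assumes the `TensorTrain` constructor can infer the shape and TT-ranks
# from the list of cores and that `tf.random.normal` is available; the helper
# name `_tt_tt_flat_inner_example` is hypothetical.
def _tt_tt_flat_inner_example():
  """Sketch: <a, b> for two TT-tensors of shape (3, 4, 5)."""
  # TT-tensor cores have shape (left_rank, mode_size, right_rank).
  a_cores = [tf.random.normal((1, 3, 2)),
             tf.random.normal((2, 4, 2)),
             tf.random.normal((2, 5, 1))]
  b_cores = [tf.random.normal((1, 3, 3)),
             tf.random.normal((3, 4, 3)),
             tf.random.normal((3, 5, 1))]
  tt_a = TensorTrain(a_cores)
  tt_b = TensorTrain(b_cores)
  # Scalar tf.Tensor: the sum over all entries of the elementwise product,
  # computed without materializing the full 3 x 4 x 5 tensors.
  return tt_tt_flat_inner(tt_a, tt_b)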
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
  """Multiplies two TT-matrices and returns the TT-matrix of the result.

  Args:
    tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size M x N
    tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size N x P

  Returns
    `TensorTrain` object containing a TT-matrix of size M x P if both
      arguments are `TensorTrain`s
    `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`

  Raises:
    ValueError if the arguments are not TT-matrices or if their sizes are not
      appropriate for a matrix-by-matrix multiplication.
  """
  # Both TensorTrain and TensorTrainBatch are inherited from TensorTrainBase.
  if not isinstance(tt_matrix_a, TensorTrainBase) or \
      not isinstance(tt_matrix_b, TensorTrainBase) or \
      not tt_matrix_a.is_tt_matrix() or \
      not tt_matrix_b.is_tt_matrix():
    raise ValueError('Arguments should be TT-matrices')

  if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
    raise ValueError('The batch sizes are different and not 1, broadcasting '
                     'is not available.')

  ndims = tt_matrix_a.ndims()
  if tt_matrix_b.ndims() != ndims:
    raise ValueError('Arguments should have the same number of dimensions, '
                     'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

  # Convert batch-size-1 batches into plain TT objects to simplify
  # broadcasting.
  tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
  tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
  is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
  is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
  is_res_batch = is_a_batch or is_b_batch
  a_batch_str = 'o' if is_a_batch else ''
  b_batch_str = 'o' if is_b_batch else ''
  res_batch_str = 'o' if is_res_batch else ''
  einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                res_batch_str)
  result_cores = []
  # TODO: name the operation and the resulting tensor.
  a_shape = shapes.lazy_raw_shape(tt_matrix_a)
  a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
  b_shape = shapes.lazy_raw_shape(tt_matrix_b)
  b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
  if is_res_batch:
    if is_a_batch:
      batch_size = shapes.lazy_batch_size(tt_matrix_a)
    if is_b_batch:
      batch_size = shapes.lazy_batch_size(tt_matrix_b)
  for core_idx in range(ndims):
    a_core = tt_matrix_a.tt_cores[core_idx]
    b_core = tt_matrix_b.tt_cores[core_idx]
    curr_res_core = tf.einsum(einsum_str, a_core, b_core)

    res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
    res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
    left_mode = a_shape[0][core_idx]
    right_mode = b_shape[1][core_idx]
    if is_res_batch:
      core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                    res_right_rank)
    else:
      core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
    curr_res_core = tf.reshape(curr_res_core, core_shape)
    result_cores.append(curr_res_core)

  res_shape = (tt_matrix_a.get_raw_shape()[0], tt_matrix_b.get_raw_shape()[1])
  static_a_ranks = tt_matrix_a.get_tt_ranks()
  static_b_ranks = tt_matrix_b.get_tt_ranks()
  out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
  if is_res_batch:
    return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
  else:
    return TensorTrain(result_cores, res_shape, out_ranks)
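
# Illustrative usage sketch (not part of the original module): it multiplies
# two small TT-matrices built from explicit cores. It assumes the
# `TensorTrain` constructor can infer the raw shape and TT-ranks from the
# cores; the helper name `_tt_tt_matmul_example` is hypothetical.
def _tt_tt_matmul_example():
  """Sketch: (12 x 30 TT-matrix) times (30 x 6 TT-matrix)."""
  # TT-matrix cores have shape (left_rank, row_mode, col_mode, right_rank).
  # A has raw shape ((3, 4), (5, 6)), i.e. 12 x 30, and TT-ranks (1, 2, 1).
  a_cores = [tf.random.normal((1, 3, 5, 2)),
             tf.random.normal((2, 4, 6, 1))]
  # B has raw shape ((5, 6), (2, 3)), i.e. 30 x 6, and TT-ranks (1, 3, 1).
  b_cores = [tf.random.normal((1, 5, 2, 3)),
             tf.random.normal((3, 6, 3, 1))]
  tt_a = TensorTrain(a_cores)
  tt_b = TensorTrain(b_cores)
  # The result is a 12 x 6 TT-matrix; its TT-ranks are the elementwise
  # products of the input ranks, here (1, 6, 1).
  return tt_tt_matmul(tt_a, tt_b)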