# Shared imports assumed by the functions below (module paths follow the t3f
# package layout; these helpers originally live in t3f.ops, t3f.decompositions
# and t3f.riemannian, so adjust the imports to the file you paste them into).
import numpy as np
import tensorflow as tf

from t3f import decompositions
from t3f import shapes
from t3f.decompositions import orthogonalize_tt_cores
from t3f.tensor_train import TensorTrain
from t3f.tensor_train_base import TensorTrainBase
from t3f.tensor_train_batch import TensorTrainBatch


def _full_tt_batch(tt):
  """Converts a TensorTrainBatch into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrainBatch` object.

  Returns:
    tf.Tensor.
  """
  num_dims = tt.ndims()
  ranks = shapes.lazy_tt_ranks(tt)
  shape = shapes.lazy_shape(tt)
  raw_shape = shapes.lazy_raw_shape(tt)

  res = tt.tt_cores[0]
  batch_size = shapes.lazy_batch_size(tt)
  for i in range(1, num_dims):
    res = tf.reshape(res, (batch_size, -1, ranks[i]))
    curr_core = tf.reshape(tt.tt_cores[i], (batch_size, ranks[i], -1))
    res = tf.einsum('oqb,obw->oqw', res, curr_core)
  if tt.is_tt_matrix():
    intermediate_shape = [batch_size]
    for i in range(num_dims):
      intermediate_shape.append(raw_shape[0][i])
      intermediate_shape.append(raw_shape[1][i])
    res = tf.reshape(res, intermediate_shape)
    transpose = [0]
    for i in range(0, 2 * num_dims, 2):
      transpose.append(i + 1)
    for i in range(1, 2 * num_dims, 2):
      transpose.append(i + 1)
    res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
  else:
    return tf.reshape(res, shape)
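# Illustrative usage sketch (not part of the library): converting a batch of
# random TT-tensors into a dense tf.Tensor. Assumes the public t3f API exposes
# `random_tensor_batch` and `full`, the latter dispatching to `_full_tt_batch`
# for `TensorTrainBatch` inputs; adjust the names if your version differs.
def _example_full_tt_batch():
  import t3f
  batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=2, batch_size=8)
  return t3f.full(batch)  # tf.Tensor of shape (8, 4, 4, 4).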
def _add_batch_matrix_cores(tt_a, tt_b):
  """Internal function to be called from add for two batches of TT-matrices.

  Does the actual assembling of the TT-cores to add two batches of TT-matrices.
  """
  ndims = tt_a.ndims()
  dtype = tt_a.dtype
  shape = shapes.lazy_raw_shape(tt_a)
  a_ranks = shapes.lazy_tt_ranks(tt_a)
  b_ranks = shapes.lazy_tt_ranks(tt_b)
  if isinstance(tt_a, TensorTrainBatch) and tt_a.batch_size == 1:
    # We add 1 element batch tt_a to a batch_size element batch tt_b to get
    # the answer TensorTrainBatch of batch_size == tt_b.batch_size.
    batch_size = shapes.lazy_batch_size(tt_b)
  else:
    batch_size = shapes.lazy_batch_size(tt_a)
  tt_a = shapes.expand_batch_dim(tt_a)
  tt_b = shapes.expand_batch_dim(tt_b)
  tt_cores = []
  for core_idx in range(ndims):
    a_core = tt_a.tt_cores[core_idx]
    if tt_a.batch_size == 1:
      a_core = tf.tile(a_core, (batch_size, 1, 1, 1, 1))
    b_core = tt_b.tt_cores[core_idx]
    if tt_b.batch_size == 1:
      b_core = tf.tile(b_core, (batch_size, 1, 1, 1, 1))
    if core_idx == 0:
      curr_core = tf.concat((a_core, b_core), axis=4)
    elif core_idx == ndims - 1:
      curr_core = tf.concat((a_core, b_core), axis=1)
    else:
      upper_zeros = tf.zeros((batch_size, a_ranks[core_idx], shape[0][core_idx],
                              shape[1][core_idx], b_ranks[core_idx + 1]),
                             dtype)
      lower_zeros = tf.zeros((batch_size, b_ranks[core_idx], shape[0][core_idx],
                              shape[1][core_idx], a_ranks[core_idx + 1]),
                             dtype)
      upper = tf.concat((a_core, upper_zeros), axis=4)
      lower = tf.concat((lower_zeros, b_core), axis=4)
      curr_core = tf.concat((upper, lower), axis=1)
    tt_cores.append(curr_core)
  return tt_cores, batch_size
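# Illustrative usage sketch (not part of the library): adding two batches of
# TT-matrices. Assumes the public API exposes `t3f.random_matrix_batch` and
# `t3f.add`, the latter assembling the cores with `_add_batch_matrix_cores`;
# interior TT-ranks of the sum are the sums of the arguments' TT-ranks.
def _example_add_matrix_batches():
  import t3f
  a = t3f.random_matrix_batch(((2, 2), (3, 3)), tt_rank=2, batch_size=5)
  b = t3f.random_matrix_batch(((2, 2), (3, 3)), tt_rank=3, batch_size=5)
  return t3f.add(a, b)  # Batch of 5 TT-matrices with interior TT-rank 2 + 3.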
def frobenius_norm_squared(tt, differentiable=False,
                           name='t3f_frobenius_norm_squared'):
  """Frobenius norm squared of `TensorTrain` or of each TT in `TensorTrainBatch`.

  Frobenius norm squared is the sum of squares of all elements in a tensor.

  Args:
    tt: `TensorTrain` or `TensorTrainBatch` object
    differentiable: bool, whether to use a differentiable implementation or
      a fast and stable implementation based on QR decomposition.
    name: string, name of the Op.

  Returns
    a number which is the Frobenius norm squared of `tt`, if it is
      `TensorTrain`
    OR
    a Tensor of size tt.batch_size, consisting of the Frobenius norms squared
      of each TensorTrain in `tt`, if it is `TensorTrainBatch`
  """
  with tf.name_scope(name, values=tt.tt_cores):
    if differentiable:
      if hasattr(tt, 'batch_size'):
        bs_str = 'n'
      else:
        bs_str = ''
      if tt.is_tt_matrix():
        running_prod = tf.einsum('{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])
      else:
        running_prod = tf.einsum('{0}aib,{0}cid->{0}bd'.format(bs_str),
                                 tt.tt_cores[0], tt.tt_cores[0])

      for core_idx in range(1, tt.ndims()):
        curr_core = tt.tt_cores[core_idx]
        if tt.is_tt_matrix():
          running_prod = tf.einsum('{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)
        else:
          running_prod = tf.einsum('{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str),
                                   running_prod, curr_core, curr_core)
      return tf.squeeze(running_prod, [-1, -2])
    else:
      orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
      # All the cores of orth_tt except the last one are orthogonal, hence
      # the Frobenius norm of orth_tt equals to the norm of the last core.
      if hasattr(tt, 'batch_size'):
        batch_size = shapes.lazy_batch_size(tt)
        last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
        return tf.norm(last_core, axis=1) ** 2
      else:
        return tf.norm(orth_tt.tt_cores[-1]) ** 2
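# Illustrative usage sketch (not part of the library): the two code paths of
# the batch-aware `frobenius_norm_squared` above. Assumes the public t3f API
# re-exports this version together with `random_tensor_batch`.
def _example_frobenius_norm_squared():
  import t3f
  batch = t3f.random_tensor_batch((3, 3, 3), tt_rank=2, batch_size=4)
  differentiable = t3f.frobenius_norm_squared(batch, differentiable=True)
  qr_based = t3f.frobenius_norm_squared(batch, differentiable=False)
  return differentiable, qr_based  # Both are tf.Tensors of shape (4,).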
def frobenius_norm_squared(tt, differentiable=False):
  """Frobenius norm squared of a TensorTrain (sum of squares of all elements).

  Args:
    tt: `TensorTrain` object
    differentiable: bool, whether to use a differentiable implementation or
      a fast and stable implementation based on QR decomposition.

  Returns
    a number
    sum of squares of all elements in `tt`
  """
  if differentiable:
    if tt.is_tt_matrix():
      running_prod = tf.einsum('aijb,cijd->bd', tt.tt_cores[0], tt.tt_cores[0])
    else:
      running_prod = tf.einsum('aib,cid->bd', tt.tt_cores[0], tt.tt_cores[0])

    for core_idx in range(1, tt.ndims()):
      curr_core = tt.tt_cores[core_idx]
      if tt.is_tt_matrix():
        running_prod = tf.einsum('ac,aijb,cijd->bd', running_prod, curr_core,
                                 curr_core)
      else:
        running_prod = tf.einsum('ac,aib,cid->bd', running_prod, curr_core,
                                 curr_core)
    return running_prod[0, 0]
  else:
    orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
    # All the cores of orth_tt except the last one are orthogonal, hence
    # the Frobenius norm of orth_tt equals to the norm of the last core.
    if hasattr(tt, 'batch_size'):
      batch_size = shapes.lazy_batch_size(tt)
      last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
      return tf.norm(last_core, axis=1) ** 2
    else:
      return tf.norm(orth_tt.tt_cores[-1]) ** 2
def _orthogonalize_batch_tt_cores_left_to_right(tt):
  """Orthogonalize TT-cores of a batch TT-object in the left to right order.

  Args:
    tt: TensorTrainBatch.

  Returns:
    TensorTrainBatch
  """
  # Left to right orthogonalization.
  ndims = tt.ndims()
  raw_shape = shapes.lazy_raw_shape(tt)
  tt_ranks = shapes.lazy_tt_ranks(tt)
  next_rank = tt_ranks[0]
  batch_size = shapes.lazy_batch_size(tt)

  # Copy cores references so we can change the cores.
  tt_cores = list(tt.tt_cores)
  for core_idx in range(ndims - 1):
    curr_core = tt_cores[core_idx]
    # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
    # be outdated for the current TT-rank, but should be valid for the next
    # TT-rank.
    curr_rank = next_rank
    next_rank = tt_ranks[core_idx + 1]
    if tt.is_tt_matrix():
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[0][core_idx]

    qr_shape = (batch_size, curr_rank * curr_mode, next_rank)
    curr_core = tf.reshape(curr_core, qr_shape)
    curr_core, triang = tf.qr(curr_core)
    if triang.get_shape().is_fully_defined():
      triang_shape = triang.get_shape().as_list()
    else:
      triang_shape = tf.shape(triang)
    # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
    # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
    # should be changed to 4.
    next_rank = triang_shape[1]
    if tt.is_tt_matrix():
      new_core_shape = (batch_size, curr_rank, curr_mode_left, curr_mode_right,
                        next_rank)
    else:
      new_core_shape = (batch_size, curr_rank, curr_mode, next_rank)

    tt_cores[core_idx] = tf.reshape(curr_core, new_core_shape)

    next_core = tf.reshape(tt_cores[core_idx + 1],
                           (batch_size, triang_shape[2], -1))
    tt_cores[core_idx + 1] = tf.matmul(triang, next_core)

  if tt.is_tt_matrix():
    last_core_shape = (batch_size, next_rank, raw_shape[0][-1],
                       raw_shape[1][-1], 1)
  else:
    last_core_shape = (batch_size, next_rank, raw_shape[0][-1], 1)
  tt_cores[-1] = tf.reshape(tt_cores[-1], last_core_shape)
  # TODO: infer the tt_ranks.
  return TensorTrainBatch(tt_cores, tt.get_raw_shape(),
                          batch_size=batch_size)
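# Illustrative usage sketch (not part of the library): left-to-right
# orthogonalization of a batch of TT-tensors via the public wrapper, which is
# assumed to dispatch to `_orthogonalize_batch_tt_cores_left_to_right` for
# `TensorTrainBatch` inputs. After orthogonalization all cores but the last
# are orthogonal, so each element's squared Frobenius norm equals the squared
# norm of its last core.
def _example_batch_orthogonalization():
  import t3f
  batch = t3f.random_tensor_batch((3, 4, 5), tt_rank=2, batch_size=6)
  orth = decompositions.orthogonalize_tt_cores(batch)
  last_core = tf.reshape(orth.tt_cores[-1], (6, -1))
  return tf.norm(last_core, axis=1) ** 2  # One squared norm per batch element.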
def _round_batch_tt(tt, max_tt_rank, epsilon):
  """Internal function that rounds a TensorTrainBatch.

  See t3f.round for details.
  """
  ndims = tt.ndims()
  max_tt_rank = np.array(max_tt_rank).astype(np.int32)
  # np.any handles both a scalar and a per-mode vector of maximal TT-ranks.
  if np.any(max_tt_rank < 1):
    raise ValueError('Maximum TT-rank should be greater or equal to 1.')
  if epsilon is not None and epsilon < 0:
    raise ValueError('Epsilon should be non-negative.')
  if max_tt_rank.size == 1:
    max_tt_rank = (max_tt_rank * np.ones(ndims + 1)).astype(np.int32)
  elif max_tt_rank.size != ndims + 1:
    raise ValueError('max_tt_rank should be a number or a vector of size '
                     '(d+1) where d is the number of dimensions (rank) of '
                     'the tensor.')
  raw_shape = shapes.lazy_raw_shape(tt)
  batch_size = shapes.lazy_batch_size(tt)

  tt_cores = orthogonalize_tt_cores(tt).tt_cores
  # Copy cores references so we can change the cores.
  tt_cores = list(tt_cores)

  ranks = [1] * (ndims + 1)
  are_tt_ranks_defined = True
  # Right to left SVD compression.
  for core_idx in range(ndims - 1, 0, -1):
    curr_core = tt_cores[core_idx]
    if tt.is_tt_matrix():
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[0][core_idx]

    columns = curr_mode * ranks[core_idx + 1]
    curr_core = tf.reshape(curr_core, (batch_size, -1, columns))
    rows = curr_core.get_shape()[1].value
    if rows is None:
      rows = tf.shape(curr_core)[1]
    if max_tt_rank[core_idx] == 1:
      ranks[core_idx] = 1
    else:
      try:
        ranks[core_idx] = min(max_tt_rank[core_idx], rows, columns)
      except TypeError:
        # Some of the values are undefined on the compilation stage and thus
        # they are tf.tensors instead of values.
        min_dim = tf.minimum(rows, columns)
        ranks[core_idx] = tf.minimum(max_tt_rank[core_idx], min_dim)
        are_tt_ranks_defined = False
    s, u, v = tf.svd(curr_core, full_matrices=False)
    u = u[:, :, 0:ranks[core_idx]]
    s = s[:, 0:ranks[core_idx]]
    v = v[:, :, 0:ranks[core_idx]]
    if tt.is_tt_matrix():
      core_shape = (batch_size, ranks[core_idx], curr_mode_left,
                    curr_mode_right, ranks[core_idx + 1])
    else:
      core_shape = (batch_size, ranks[core_idx], curr_mode,
                    ranks[core_idx + 1])
    tt_cores[core_idx] = tf.reshape(tf.transpose(v, (0, 2, 1)), core_shape)
    prev_core_shape = (batch_size, -1, rows)
    tt_cores[core_idx - 1] = tf.reshape(tt_cores[core_idx - 1],
                                        prev_core_shape)
    tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1], u)
    tt_cores[core_idx - 1] = tf.matmul(tt_cores[core_idx - 1],
                                       tf.matrix_diag(s))

  if tt.is_tt_matrix():
    core_shape = (batch_size, ranks[0], raw_shape[0][0], raw_shape[1][0],
                  ranks[1])
  else:
    core_shape = (batch_size, ranks[0], raw_shape[0][0], ranks[1])
  tt_cores[0] = tf.reshape(tt_cores[0], core_shape)
  if not are_tt_ranks_defined:
    ranks = None
  return TensorTrainBatch(tt_cores, tt.get_raw_shape(), ranks,
                          batch_size=tt.batch_size)
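# Illustrative usage sketch (not part of the library): rounding (TT-rank
# truncation) of a batch of TT-tensors. Assumes the public API exposes
# `t3f.round`, which is expected to call `_round_batch_tt` for
# `TensorTrainBatch` inputs.
def _example_round_batch():
  import t3f
  batch = t3f.random_tensor_batch((4, 4, 4), tt_rank=6, batch_size=3)
  return t3f.round(batch, max_tt_rank=2)  # Interior TT-ranks truncated to <= 2.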
def tt_tt_matmul(tt_matrix_a, tt_matrix_b):
  """Multiplies two TT-matrices and returns the TT-matrix of the result.

  Args:
    tt_matrix_a: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size M x N
    tt_matrix_b: `TensorTrain` or `TensorTrainBatch` object containing
      a TT-matrix (a batch of TT-matrices) of size N x P

  Returns
    `TensorTrain` object containing a TT-matrix of size M x P if both arguments
      are `TensorTrain`s
    `TensorTrainBatch` if any of the arguments is a `TensorTrainBatch`

  Raises:
    ValueError if the arguments are not TT-matrices or if their sizes are not
    appropriate for a matrix-by-matrix multiplication.
  """
  # Both TensorTrain and TensorTrainBatch are inherited from TensorTrainBase.
  if not isinstance(tt_matrix_a, TensorTrainBase) or \
      not isinstance(tt_matrix_b, TensorTrainBase) or \
      not tt_matrix_a.is_tt_matrix() or \
      not tt_matrix_b.is_tt_matrix():
    raise ValueError('Arguments should be TT-matrices')

  if not shapes.is_batch_broadcasting_possible(tt_matrix_a, tt_matrix_b):
    raise ValueError('The batch sizes are different and not 1, broadcasting '
                     'is not available.')

  ndims = tt_matrix_a.ndims()
  if tt_matrix_b.ndims() != ndims:
    raise ValueError('Arguments should have the same number of dimensions, '
                     'got %d and %d instead.' % (ndims, tt_matrix_b.ndims()))

  # Convert a batch of size 1 into a plain TT object to simplify broadcasting.
  tt_matrix_a = shapes.squeeze_batch_dim(tt_matrix_a)
  tt_matrix_b = shapes.squeeze_batch_dim(tt_matrix_b)
  is_a_batch = isinstance(tt_matrix_a, TensorTrainBatch)
  is_b_batch = isinstance(tt_matrix_b, TensorTrainBatch)
  is_res_batch = is_a_batch or is_b_batch
  a_batch_str = 'o' if is_a_batch else ''
  b_batch_str = 'o' if is_b_batch else ''
  res_batch_str = 'o' if is_res_batch else ''
  einsum_str = '{}aijb,{}cjkd->{}acikbd'.format(a_batch_str, b_batch_str,
                                                res_batch_str)
  result_cores = []
  # TODO: name the operation and the resulting tensor.
  a_shape = shapes.lazy_raw_shape(tt_matrix_a)
  a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
  b_shape = shapes.lazy_raw_shape(tt_matrix_b)
  b_ranks = shapes.lazy_tt_ranks(tt_matrix_b)
  if is_res_batch:
    if is_a_batch:
      batch_size = shapes.lazy_batch_size(tt_matrix_a)
    if is_b_batch:
      batch_size = shapes.lazy_batch_size(tt_matrix_b)
  for core_idx in range(ndims):
    a_core = tt_matrix_a.tt_cores[core_idx]
    b_core = tt_matrix_b.tt_cores[core_idx]
    curr_res_core = tf.einsum(einsum_str, a_core, b_core)

    res_left_rank = a_ranks[core_idx] * b_ranks[core_idx]
    res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
    left_mode = a_shape[0][core_idx]
    right_mode = b_shape[1][core_idx]
    if is_res_batch:
      core_shape = (batch_size, res_left_rank, left_mode, right_mode,
                    res_right_rank)
    else:
      core_shape = (res_left_rank, left_mode, right_mode, res_right_rank)
    curr_res_core = tf.reshape(curr_res_core, core_shape)
    result_cores.append(curr_res_core)

  res_shape = (tt_matrix_a.get_raw_shape()[0], tt_matrix_b.get_raw_shape()[1])
  static_a_ranks = tt_matrix_a.get_tt_ranks()
  static_b_ranks = tt_matrix_b.get_tt_ranks()
  out_ranks = [a_r * b_r for a_r, b_r in zip(static_a_ranks, static_b_ranks)]
  if is_res_batch:
    return TensorTrainBatch(result_cores, res_shape, out_ranks, batch_size)
  else:
    return TensorTrain(result_cores, res_shape, out_ranks)
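# Illustrative usage sketch (not part of the library): multiplying two
# TT-matrices. Assumes the public API exposes `t3f.random_matrix` and
# `t3f.matmul`, the latter dispatching to `tt_tt_matmul` when both arguments
# are TT objects.
def _example_tt_tt_matmul():
  import t3f
  a = t3f.random_matrix(((3, 3, 3), (4, 4, 4)), tt_rank=2)  # 27 x 64.
  b = t3f.random_matrix(((4, 4, 4), (5, 5, 5)), tt_rank=2)  # 64 x 125.
  return t3f.matmul(a, b)  # TT-matrix of size 27 x 125; TT-ranks multiply.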
def project_sum(what, where, weights=None):
  r"""Project sum of `what` TTs on the tangent space of `where` TT.

  project_sum(what, x) = P_x(what)
  project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
  project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      projection of the sum of elements in the batch.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
      project.
    weights: python list or tf.Tensor of numbers or None, weights of the sum

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d r_what r_where n (r_what + r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what: max(what.get_tt_ranks());
    r_where is the largest TT-rank of where;
    n is the size of the axis dimension of what and where e.g.
      for a tensor of size 4 x 4 x 4, n is 4;
      for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12.
  """
  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  if weights is not None:
    weights = tf.convert_to_tensor(weights, dtype=where.dtype)

  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(), what.get_raw_shape()))

  dtypes_compatible = (where.dtype.is_compatible_with(what.dtype) or
                       what.dtype.is_compatible_with(where.dtype))
  if not dtypes_compatible:
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.'
                     % (where.dtype, what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
      left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  right_rank_dim = where.right_tt_rank_dim
  left_rank_dim = where.left_tt_rank_dim
  if weights is not None:
    weights_shape = weights.get_shape()
    output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
  else:
    output_is_batch = False
  output_batch_str = 'o' if output_is_batch else ''
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim += 1
    output_batch_size = weights.get_shape().as_list()[1]

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      if weights is None:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
        proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
      else:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str, output_batch_str)
        proj_core_s = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])
        einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(mode_str, output_batch_str)
        proj_core = tf.einsum(einsum_str, weights, proj_core_s)

    if core_idx == ndims - 1:
      if weights is None:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
        proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      else:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str, output_batch_str)
        proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
        einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(mode_str, output_batch_str)
        proj_core = tf.einsum(einsum_str, weights, proj_core_s)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core.
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core),
                           axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros),
                        axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                           batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
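# Illustrative usage sketch (not part of the library): projecting a weighted
# sum of a batch of TT-tensors onto the tangent space of another TT-tensor.
# Assumes `t3f.random_tensor` and `t3f.random_tensor_batch` exist and that
# this module's `project_sum` is the function re-exported as `t3f.project_sum`.
def _example_project_sum():
  import t3f
  where = t3f.random_tensor((4, 4, 4), tt_rank=3)
  what = t3f.random_tensor_batch((4, 4, 4), tt_rank=2, batch_size=5)
  weights = tf.ones((5,), dtype=where.dtype)
  # Interior TT-ranks of the result are 2 * TT-ranks of `where`.
  return project_sum(what, where, weights)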
def project_matmul(what, where, matrix):
  """Project `matrix` * `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
      project.
    matrix: TensorTrain, TT-matrix to multiply by what

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

  Complexity:
       O(d r_where^3 m) for orthogonalizing the TT-cores of where
      +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
    d is the number of TT-cores (what.ndims());
    r_what is the largest TT-rank of what: max(what.get_tt_ranks());
    r_where is the largest TT-rank of where;
    matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
  """
  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(), what.get_raw_shape()))

  dtypes_compatible = (where.dtype.is_compatible_with(what.dtype) or
                       what.dtype.is_compatible_with(where.dtype))
  if not dtypes_compatible:
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.'
                     % (where.dtype, what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
      left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  right_rank_dim = what.right_tt_rank_dim
  left_rank_dim = what.left_tt_rank_dim
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x
  #   tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    matrix_core = matrix.tt_cores[core_idx]
    rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                              right_tang_core, rhs[core_idx + 1], tens_core)
  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x
  #   tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    matrix_core = matrix.tt_cores[core_idx]
    # TODO: brute-force the order of indices in lhs??
    lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef', matrix_core,
                                  left_tang_core, lhs[core_idx], tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    matrix_core = matrix.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                            lhs[core_idx], matrix_core)
      proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                             lhs[core_idx + 1])
      proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      # d and e dimensions take 1 value, since it's the last rank.
      # To make the result shape (?, ?, ?, 1), we are summing d and leaving e,
      # but we could have done the opposite -- sum e and leave d.
      proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx],
                            matrix_core, tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core.
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      extended_left_tang_core = tf.tile(extended_left_tang_core,
                                        [output_batch_size, 1, 1, 1, 1])
      extended_right_tang_core = tf.tile(extended_right_tang_core,
                                         [output_batch_size, 1, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core),
                           axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      mode_size_n = raw_shape[0][core_idx]
      mode_size_m = raw_shape[1][core_idx]
      shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros),
                        axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                           batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
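# Illustrative usage sketch (not part of the library): projecting
# `matrix @ what` onto the tangent space of `where` without forming the
# product explicitly. The shapes below are an assumption for the sake of the
# example: `what` and `where` are TT-matrices representing column vectors
# (raw column shape of all ones), and `matrix` maps such vectors to vectors
# of the same raw row shape; assumes `t3f.random_matrix` exists.
def _example_project_matmul():
  import t3f
  matrix = t3f.random_matrix(((2, 3, 4), (2, 3, 4)), tt_rank=2)
  where = t3f.random_matrix(((2, 3, 4), (1, 1, 1)), tt_rank=3)
  what = t3f.random_matrix(((2, 3, 4), (1, 1, 1)), tt_rank=2)
  return project_matmul(what, where, matrix)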
def project(what, where):
  """Project `what` TTs on the tangent space of `where` TT.

  project(what, x) = P_x(what)
  project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

  This function implements the algorithm from the paper [1], theorem 3.1.

  [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
    Tensor Trains.

  Args:
    what: TensorTrain or TensorTrainBatch. In the case of batch returns
      batch with projection of each individual tensor.
    where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
      project.

  Returns:
     a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()
  """
  if not isinstance(where, TensorTrain):
    raise ValueError('The first argument should be a TensorTrain object, got '
                     '"%s".' % where)

  if where.get_raw_shape() != what.get_raw_shape():
    raise ValueError('The shapes of the tensor we want to project and of the '
                     'tensor on which tangent space we want to project should '
                     'match, got %s and %s.' %
                     (where.get_raw_shape(), what.get_raw_shape()))

  if not where.dtype.is_compatible_with(what.dtype):
    raise ValueError('Dtypes of the arguments should coincide, got %s and %s.'
                     % (where.dtype, what.dtype))

  left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
  right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
      left_tangent_space_tens, left_to_right=False)

  ndims = where.ndims()
  dtype = where.dtype
  raw_shape = shapes.lazy_raw_shape(where)
  batch_size = shapes.lazy_batch_size(what)
  right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
  left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

  # For einsum notation.
  mode_str = 'ij' if where.is_tt_matrix() else 'i'
  right_rank_dim = 3 if where.is_tt_matrix() else 2
  left_rank_dim = 0
  output_is_batch = isinstance(what, TensorTrainBatch)
  if output_is_batch:
    right_rank_dim += 1
    left_rank_dim = 1
    output_batch_size = what.batch_size

  # Always work with batch of TT objects for simplicity.
  what = shapes.expand_batch_dim(what)

  # Prepare rhs vectors.
  # rhs[core_idx] is of size
  #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
  rhs = [None] * (ndims + 1)
  rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1, 0, -1):
    tens_core = what.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
    rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                              right_tang_core)

  # Prepare lhs vectors.
  # lhs[core_idx] is of size
  #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
  lhs = [None] * (ndims + 1)
  lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
  for core_idx in range(ndims - 1):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
    lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core,
                                  tens_core)

  # Left to right sweep.
  res_cores_list = []
  for core_idx in range(ndims):
    tens_core = what.tt_cores[core_idx]
    left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
    right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

    if core_idx < ndims - 1:
      einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
      einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
      proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1])
      if output_is_batch:
        einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

    if core_idx == ndims - 1:
      if output_is_batch:
        einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
      else:
        einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
      proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

    if output_is_batch:
      # Add batch dimension of size output_batch_size to left_tang_core and
      # right_tang_core.
      extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
      extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
      if where.is_tt_matrix():
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1, 1])
      else:
        extended_left_tang_core = tf.tile(extended_left_tang_core,
                                          [output_batch_size, 1, 1, 1])
        extended_right_tang_core = tf.tile(extended_right_tang_core,
                                           [output_batch_size, 1, 1, 1])
    else:
      extended_left_tang_core = left_tang_core
      extended_right_tang_core = right_tang_core

    if core_idx == 0:
      res_core = tf.concat((proj_core, extended_left_tang_core),
                           axis=right_rank_dim)
    elif core_idx == ndims - 1:
      res_core = tf.concat((extended_right_tang_core, proj_core),
                           axis=left_rank_dim)
    else:
      rank_1 = right_tangent_tt_ranks[core_idx]
      rank_2 = left_tangent_tt_ranks[core_idx + 1]
      if where.is_tt_matrix():
        mode_size_n = raw_shape[0][core_idx]
        mode_size_m = raw_shape[1][core_idx]
        shape = [rank_1, mode_size_n, mode_size_m, rank_2]
      else:
        mode_size = raw_shape[0][core_idx]
        shape = [rank_1, mode_size, rank_2]
      if output_is_batch:
        shape = [output_batch_size] + shape
      zeros = tf.zeros(shape, dtype)
      upper = tf.concat((extended_right_tang_core, zeros),
                        axis=right_rank_dim)
      lower = tf.concat((proj_core, extended_left_tang_core),
                        axis=right_rank_dim)
      res_core = tf.concat((upper, lower), axis=left_rank_dim)
    res_cores_list.append(res_core)
  # TODO: TT-ranks.
  if output_is_batch:
    res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                           batch_size=output_batch_size)
  else:
    res = TensorTrain(res_cores_list, where.get_raw_shape())

  res.projection_on = where
  return res
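# Illustrative usage sketch (not part of the library): projecting one TT-tensor
# onto the tangent space of another, e.g. to obtain a tangent-space direction
# for Riemannian optimization. Assumes `t3f.random_tensor` exists.
def _example_project():
  import t3f
  where = t3f.random_tensor((4, 4, 4), tt_rank=3)
  what = t3f.random_tensor((4, 4, 4), tt_rank=5)
  return project(what, where)  # Interior TT-ranks are 2 * TT-ranks of `where`.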
def multiply(tt_left, right):
  """Returns a TensorTrain corresponding to element-wise product tt_left * right.

  Supports broadcasting:
    multiply(TensorTrainBatch, TensorTrain) returns TensorTrainBatch consisting
    of element-wise products of TT in TensorTrainBatch and TensorTrain

    multiply(TensorTrainBatch_a, TensorTrainBatch_b) returns TensorTrainBatch
    consisting of element-wise products of TT in TensorTrainBatch_a and
    TT in TensorTrainBatch_b

    Batch sizes should support broadcasting.

  Args:
    tt_left: `TensorTrain` OR `TensorTrainBatch`
    right: `TensorTrain` OR `TensorTrainBatch` OR a number.

  Returns
    a `TensorTrain` or `TensorTrainBatch` object corresponding to the
    element-wise product of the arguments.

  Raises
    ValueError if the arguments' shapes do not coincide or broadcasting is
    not possible.
  """
  is_left_batch = isinstance(tt_left, TensorTrainBatch)
  is_right_batch = isinstance(right, TensorTrainBatch)

  is_batch_case = is_left_batch or is_right_batch
  ndims = tt_left.ndims()
  if not isinstance(right, TensorTrainBase):
    # Assume right is a number, not TensorTrain.
    # To squash right uniformly across TT-cores we pull its absolute value
    # and raise to the power 1/ndims. First TT-core is multiplied by the sign
    # of right.
    tt_cores = list(tt_left.tt_cores)
    fact = tf.pow(tf.cast(tf.abs(right), tt_left.dtype), 1.0 / ndims)
    sign = tf.cast(tf.sign(right), tt_left.dtype)
    for i in range(len(tt_cores)):
      tt_cores[i] = fact * tt_cores[i]
    tt_cores[0] = tt_cores[0] * sign
    out_ranks = tt_left.get_tt_ranks()
    if is_left_batch:
      out_batch_size = tt_left.batch_size
  else:
    if tt_left.is_tt_matrix() != right.is_tt_matrix():
      raise ValueError('The arguments should be both TT-tensors or both '
                       'TT-matrices')

    if tt_left.get_raw_shape() != right.get_raw_shape():
      raise ValueError('The arguments should have the same shape.')

    out_batch_size = 1
    dependencies = []
    can_determine_if_broadcast = True
    if is_left_batch and is_right_batch:
      if tt_left.batch_size is None and right.batch_size is None:
        can_determine_if_broadcast = False
      elif tt_left.batch_size is None and right.batch_size is not None:
        if right.batch_size > 1:
          can_determine_if_broadcast = False
      elif tt_left.batch_size is not None and right.batch_size is None:
        if tt_left.batch_size > 1:
          can_determine_if_broadcast = False

    if not can_determine_if_broadcast:
      # Cannot determine if broadcasting is needed. Avoid broadcasting and
      # assume elementwise multiplication AND add execution time assert to
      # print a better error message if the batch sizes turn out to be
      # different.
      message = ('The batch sizes were unknown on compilation stage, so '
                 'assumed elementwise multiplication (i.e. no broadcasting). '
                 'Now it seems that they are different after all :')

      data = [message, shapes.lazy_batch_size(tt_left), ' x ',
              shapes.lazy_batch_size(right)]
      bs_eq = tf.assert_equal(shapes.lazy_batch_size(tt_left),
                              shapes.lazy_batch_size(right), data=data)

      dependencies.append(bs_eq)

    do_broadcast = shapes.is_batch_broadcasting_possible(tt_left, right)
    if not can_determine_if_broadcast:
      # Assume elementwise multiplication if broadcasting cannot be determined
      # on compilation stage.
      do_broadcast = False
    if not do_broadcast and can_determine_if_broadcast:
      raise ValueError('The batch sizes are different and not 1, broadcasting '
                       'is not available.')

    a_ranks = shapes.lazy_tt_ranks(tt_left)
    b_ranks = shapes.lazy_tt_ranks(right)
    shape = shapes.lazy_raw_shape(tt_left)

    output_str = ''
    bs_str_left = ''
    bs_str_right = ''

    if is_batch_case:
      if is_left_batch and is_right_batch:
        # Both arguments are batches of equal size.
        if (tt_left.batch_size == right.batch_size or
            not can_determine_if_broadcast):
          bs_str_left = 'n'
          bs_str_right = 'n'
          output_str = 'n'
          if not can_determine_if_broadcast:
            out_batch_size = None
          else:
            out_batch_size = tt_left.batch_size
        else:
          # Broadcasting (e.g. batch sizes are 1 and n > 1).
          bs_str_left = 'n'
          bs_str_right = 'm'
          output_str = 'nm'
          if tt_left.batch_size is None or tt_left.batch_size > 1:
            out_batch_size = tt_left.batch_size
          else:
            out_batch_size = right.batch_size
      else:
        # One of the arguments is TensorTrain.
        if is_left_batch:
          bs_str_left = 'n'
          bs_str_right = ''
          out_batch_size = tt_left.batch_size
        else:
          bs_str_left = ''
          bs_str_right = 'n'
          out_batch_size = right.batch_size
        output_str = 'n'

    is_matrix = tt_left.is_tt_matrix()
    tt_cores = []

    for core_idx in range(ndims):
      a_core = tt_left.tt_cores[core_idx]
      b_core = right.tt_cores[core_idx]
      left_rank = a_ranks[core_idx] * b_ranks[core_idx]
      right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1]
      if is_matrix:
        with tf.control_dependencies(dependencies):
          curr_core = tf.einsum('{0}aijb,{1}cijd->{2}acijbd'.format(
              bs_str_left, bs_str_right, output_str), a_core, b_core)
          curr_core = tf.reshape(curr_core, (-1, left_rank,
                                             shape[0][core_idx],
                                             shape[1][core_idx],
                                             right_rank))
          if not is_batch_case:
            curr_core = tf.squeeze(curr_core, axis=0)
      else:
        with tf.control_dependencies(dependencies):
          curr_core = tf.einsum('{0}aib,{1}cid->{2}acibd'.format(
              bs_str_left, bs_str_right, output_str), a_core, b_core)
          curr_core = tf.reshape(curr_core, (-1, left_rank,
                                             shape[0][core_idx], right_rank))
          if not is_batch_case:
            curr_core = tf.squeeze(curr_core, axis=0)

      tt_cores.append(curr_core)

    combined_ranks = zip(tt_left.get_tt_ranks(), right.get_tt_ranks())
    out_ranks = [a * b for a, b in combined_ranks]

  if not is_batch_case:
    return TensorTrain(tt_cores, tt_left.get_raw_shape(), out_ranks)
  else:
    return TensorTrainBatch(tt_cores, tt_left.get_raw_shape(), out_ranks,
                            batch_size=out_batch_size)
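# Illustrative usage sketch (not part of the library): element-wise
# multiplication, including scaling by a number and batch broadcasting.
# Assumes `t3f.random_tensor` and `t3f.random_tensor_batch` exist; interior
# TT-ranks of the product are the products of the arguments' TT-ranks.
def _example_multiply():
  import t3f
  single = t3f.random_tensor((3, 3, 3), tt_rank=2)
  batch = t3f.random_tensor_batch((3, 3, 3), tt_rank=3, batch_size=4)
  scaled = multiply(single, 2.0)         # TensorTrain with unchanged TT-ranks.
  broadcasted = multiply(batch, single)  # TensorTrainBatch of size 4, rank 6.
  return scaled, broadcasted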