def gradients(func, x, name='t3f_gradients', runtime_check=True):
    """Riemannian autodiff: gradient of `func` projected on the tangent space.

    Computes the projection of the gradient df/dx onto the tangent space of
    the TT manifold at the point x.

    Warning: this is an experimental feature and it may not work for some
    functions, e.g. ones that include QR or SVD decomposition (t3f.project,
    t3f.round) or functions that work with TT-cores directly (in contrast to
    working with the TT-object only via t3f functions). In these cases this
    function can silently return wrong results!

    Example:
        # Scalar product with some predefined tensor squared 0.5 * <x, t>**2.
        # Its gradient is <x, t> t and its Riemannian gradient is
        #     t3f.project(<x, t> * t, x)
        f = lambda x: 0.5 * t3f.flat_inner(x, t)**2
        # Equals t3f.project(t3f.flat_inner(x, t) * t, x)
        projected_grad = t3f.gradients(f, x)

    Args:
        func: function that takes a TensorTrain object as input and outputs a
            number.
        x: point at which to compute the gradient and on whose tangent space
            the gradient is projected.
        name: string, name of the Op.
        runtime_check: [True] whether to do a sanity check that the passed
            function is invariant to different TT representations (otherwise
            the Riemannian gradient doesn't even exist). It makes things
            slower, but helps catching bugs, so turn it off during production
            deployment.

    Returns:
        `TensorTrain`, projection of the gradient df/dx onto the tangent space
        at point x.

    See also:
        t3f.hessian_vector_product
    """
    with tf.name_scope(name):
        left_orth = decompositions.orthogonalize_tt_cores(x)
        right_orth = decompositions.orthogonalize_tt_cores(
            left_orth, left_to_right=False)

        # Starting point of the differentiation: the first delta is the first
        # right-orthogonal core and all the other deltas are zero.
        start_deltas = [right_orth.tt_cores[0]]
        start_deltas += [
            tf.zeros_like(core) for core in right_orth.tt_cores[1:]
        ]

        def augmented_func(d):
            # Evaluate func on the tangent space element defined by deltas d.
            return func(
                riemannian.deltas_to_tangent_space(d, x, left_orth, right_orth))

        function_value, cores_grad = value_and_grad(augmented_func,
                                                    start_deltas)

        assert_op = (
            _is_invariant_to_input_transforms(function_value, func(x))
            if runtime_check else tf.no_op())

        # Tie the (optional) sanity check to the ops producing the result.
        with tf.control_dependencies([assert_op]):
            gauged_deltas = _enforce_gauge_conditions(cores_grad, left_orth)
        return riemannian.deltas_to_tangent_space(gauged_deltas, x, left_orth,
                                                  right_orth)
def testOrthogonalizeLeftToRight(self):
    """Left-to-right orthogonalization of a TT batch.

    Checks that the represented tensors are unchanged, that the TT-ranks get
    reduced to the expected values, and that every core except the last one
    is orthogonal for each element of the batch.
    """
    batch_size = 2
    shape = (2, 4, 3, 3)
    tt_ranks = (1, 5, 2, 17, 1)
    updated_tt_ranks = (1, 2, 2, 6, 1)
    tens = initializers.random_tensor_batch(shape, tt_rank=tt_ranks,
                                            batch_size=batch_size)
    orthogonal = decompositions.orthogonalize_tt_cores(tens)
    with self.test_session() as sess:
        tens_val, orthogonal_val = sess.run(
            [ops.full(tens), ops.full(orthogonal)])
        self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
        dynamic_tt_ranks = shapes.tt_ranks(orthogonal).eval()
        self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
        # Check that the TT-cores are orthogonal.
        for core_idx in range(len(shape) - 1):
            rank_in = updated_tt_ranks[core_idx]
            rank_out = updated_tt_ranks[core_idx + 1]
            unfolding_shape = (rank_in * shape[core_idx], rank_out)
            for batch_idx in range(batch_size):
                unfolded = tf.reshape(
                    orthogonal.tt_cores[core_idx][batch_idx], unfolding_shape)
                gram = tf.matmul(tf.transpose(unfolded), unfolded)
                gram_val = sess.run(gram)
                self.assertAllClose(np.eye(rank_out), gram_val)
def testOrthogonalizeRightToLeft(self):
    """Right-to-left orthogonalization of a single TT tensor.

    Checks that the represented tensor is unchanged, that the TT-ranks get
    reduced to the expected values, and that every core except the first one
    has orthonormal rows after unfolding.
    """
    shape = (2, 4, 3, 3)
    tt_ranks = (1, 5, 2, 17, 1)
    updated_tt_ranks = (1, 5, 2, 3, 1)
    tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
                                      dtype=self.dtype)
    orthogonal = decompositions.orthogonalize_tt_cores(tens,
                                                       left_to_right=False)
    with self.test_session() as sess:
        tens_val, orthogonal_val = sess.run(
            [ops.full(tens), ops.full(orthogonal)])
        self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
        dynamic_tt_ranks = shapes.tt_ranks(orthogonal).eval()
        self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
        # Check that the TT-cores are orthogonal.
        for core_idx in range(1, len(shape)):
            rank_in = updated_tt_ranks[core_idx]
            rank_out = updated_tt_ranks[core_idx + 1]
            unfolded = tf.reshape(orthogonal.tt_cores[core_idx],
                                  (rank_in, shape[core_idx] * rank_out))
            gram = tf.matmul(unfolded, tf.transpose(unfolded))
            gram_val = sess.run(gram)
            self.assertAllClose(np.eye(rank_in), gram_val)
def frobenius_norm_squared(tt, differentiable=False,
                           name='t3f_frobenius_norm_squared'):
    """Frobenius norm squared of a `TensorTrain` or of each TT in a batch.

    The Frobenius norm squared is the sum of squares of all elements of a
    tensor.

    Args:
        tt: `TensorTrain` or `TensorTrainBatch` object.
        differentiable: bool, whether to use a differentiable implementation
            (einsum contractions of the cores) or a fast and stable
            implementation based on QR decomposition.
        name: string, name of the Op.

    Returns:
        a number which is the Frobenius norm squared of `tt`, if it is a
        `TensorTrain`
        OR
        a Tensor of size tt.batch_size, consisting of the Frobenius norms
        squared of each TensorTrain in `tt`, if it is a `TensorTrainBatch`.
    """
    with tf.name_scope(name, values=tt.tt_cores):
        is_batch = hasattr(tt, 'batch_size')

        if not differentiable:
            orth_tt = decompositions.orthogonalize_tt_cores(
                tt, left_to_right=True)
            # All the cores of orth_tt except the last one are orthogonal,
            # hence the Frobenius norm of orth_tt equals the norm of the last
            # core.
            if not is_batch:
                return tf.norm(orth_tt.tt_cores[-1]) ** 2
            batch_size = shapes.lazy_batch_size(tt)
            flat_last_core = tf.reshape(orth_tt.tt_cores[-1],
                                        (batch_size, -1))
            return tf.norm(flat_last_core, axis=1) ** 2

        # Differentiable path: contract the TT with itself core by core.
        batch_label = 'n' if is_batch else ''
        if tt.is_tt_matrix():
            first_template = '{0}aijb,{0}cijd->{0}bd'
            step_template = '{0}ac,{0}aijb,{0}cijd->{0}bd'
        else:
            first_template = '{0}aib,{0}cid->{0}bd'
            step_template = '{0}ac,{0}aib,{0}cid->{0}bd'

        first_core = tt.tt_cores[0]
        accumulated = tf.einsum(first_template.format(batch_label),
                                first_core, first_core)
        for core_idx in range(1, tt.ndims()):
            core = tt.tt_cores[core_idx]
            accumulated = tf.einsum(step_template.format(batch_label),
                                    accumulated, core, core)
        # Drop the two trailing rank axes left after the last contraction.
        return tf.squeeze(accumulated, [-1, -2])
def testOrthogonalizeLeftToRight(self):
    """Left-to-right orthogonalization of a single TT tensor.

    Checks that the represented tensor is unchanged, that the TT-ranks get
    reduced to the expected values, and that every core except the last one
    has orthonormal columns after unfolding.
    """
    shape = (2, 4, 3, 3)
    tt_ranks = (1, 5, 2, 17, 1)
    updated_tt_ranks = (1, 2, 2, 6, 1)
    tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
                                      dtype=self.dtype)
    orthogonal = decompositions.orthogonalize_tt_cores(tens)
    tens_val, orthogonal_val = self.evaluate(
        [ops.full(tens), ops.full(orthogonal)])
    self.assertAllClose(tens_val, orthogonal_val, atol=1e-5, rtol=1e-5)
    dynamic_tt_ranks = self.evaluate(shapes.tt_ranks(orthogonal))
    self.assertAllEqual(updated_tt_ranks, dynamic_tt_ranks)
    # Check that the TT-cores are orthogonal.
    for core_idx in range(len(shape) - 1):
        rank_in = updated_tt_ranks[core_idx]
        rank_out = updated_tt_ranks[core_idx + 1]
        unfolded = tf.reshape(orthogonal.tt_cores[core_idx],
                              (rank_in * shape[core_idx], rank_out))
        gram = tf.matmul(tf.transpose(unfolded), unfolded)
        gram_val = self.evaluate(gram)
        self.assertAllClose(np.eye(rank_out), gram_val)
def frobenius_norm_squared(tt, differentiable=False):
    """Frobenius norm squared (sum of squares of all elements) of a TT object.

    Both implementations accept a `TensorTrain` and a `TensorTrainBatch`.
    Previously only the QR-based branch handled batches; the differentiable
    branch used batch-less einsum strings and therefore failed on them.

    Args:
        tt: `TensorTrain` or `TensorTrainBatch` object.
        differentiable: bool, whether to use a differentiable implementation
            (einsum contractions of the cores) or a fast and stable
            implementation based on QR decomposition.

    Returns:
        a number, the sum of squares of all elements in `tt`, if it is a
        `TensorTrain`
        OR
        a Tensor with one such number per element of the batch, if `tt` is a
        `TensorTrainBatch`.
    """
    # Batch objects are recognized by the presence of a `batch_size`
    # attribute, the same test the QR-based branch below relies on.
    is_batch = hasattr(tt, 'batch_size')

    if differentiable:
        # An extra leading einsum label carries the batch axis when present.
        bs_str = 'n' if is_batch else ''
        if tt.is_tt_matrix():
            first_str = '{0}aijb,{0}cijd->{0}bd'.format(bs_str)
            step_str = '{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str)
        else:
            first_str = '{0}aib,{0}cid->{0}bd'.format(bs_str)
            step_str = '{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str)

        running_prod = tf.einsum(first_str, tt.tt_cores[0], tt.tt_cores[0])
        for core_idx in range(1, tt.ndims()):
            curr_core = tt.tt_cores[core_idx]
            running_prod = tf.einsum(step_str, running_prod, curr_core,
                                     curr_core)
        if is_batch:
            # Drop the two trailing rank axes, keep the batch axis.
            return tf.squeeze(running_prod, [-1, -2])
        return running_prod[0, 0]

    orth_tt = decompositions.orthogonalize_tt_cores(tt, left_to_right=True)
    # All the cores of orth_tt except the last one are orthogonal, hence
    # the Frobenius norm of orth_tt equals the norm of the last core.
    if is_batch:
        batch_size = shapes.lazy_batch_size(tt)
        last_core = tf.reshape(orth_tt.tt_cores[-1], (batch_size, -1))
        return tf.norm(last_core, axis=1)**2
    return tf.norm(orth_tt.tt_cores[-1])**2
def hessian_vector_product(func, x, vector,
                           name='t3f_hessian_vector_product',
                           runtime_check=True):
    """P_x [d^2f/dx^2] P_x vector, i.e. Riemannian hessian by vector product.

    Computes
        P_x [d^2f/dx^2] P_x vector
    where P_x is the projection onto the tangent space of TT at point x and
    d^2f/dx^2 is the Hessian of the function.

    Note that the true Riemannian hessian also includes the manifold curvature
    term, which is ignored here.

    Warning: this is an experimental feature and it may not work for some
    functions, e.g. ones that include QR or SVD decomposition (t3f.project,
    t3f.round) or functions that work with TT-cores directly (in contrast to
    working with the TT-object only via t3f functions). In these cases this
    function can silently return wrong results!

    Example:
        # Quadratic form with matrix A: <x, A x>.
        # Its gradient is (A + A.T) x, its Hessian is (A + A.T).
        # Its Riemannian Hessian by vector product is
        #     proj_vec = t3f.project(vector, x)
        #     t3f.project(t3f.matmul(A + t3f.transpose(A), proj_vec), x)
        f = lambda x: t3f.bilinear_form(A, x, x)
        res = t3f.hessian_vector_product(f, x, vector)

    Args:
        func: function that takes a TensorTrain object as input and outputs a
            number.
        x: point at which to compute the Hessian and on whose tangent space
            the result is projected.
        vector: `TensorTrain` object to multiply by the Hessian.
        name: string, name of the Op.
        runtime_check: [True] whether to do a sanity check that the passed
            function is invariant to different TT representations (otherwise
            the Riemannian gradient doesn't even exist). It makes things
            slower, but helps catching bugs, so turn it off during production
            deployment.

    Returns:
        `TensorTrain`, result of the Riemannian hessian by vector product.

    See also:
        t3f.gradients
    """
    scope_inputs = list(x.tt_cores) + list(vector.tt_cores)
    with tf.name_scope(name, values=scope_inputs):
        left_orth = decompositions.orthogonalize_tt_cores(x)
        right_orth = decompositions.orthogonalize_tt_cores(
            left_orth, left_to_right=False)

        # Starting point of the differentiation: the first delta is the first
        # right-orthogonal core and all the other deltas are zero.
        start_deltas = [right_orth.tt_cores[0]]
        start_deltas += [
            tf.zeros_like(core) for core in right_orth.tt_cores[1:]
        ]
        point_on_tangent_space = riemannian.deltas_to_tangent_space(
            start_deltas, x, left_orth, right_orth)
        function_value = func(point_on_tangent_space)

        assert_op = (
            _is_invariant_to_input_transforms(function_value, func(x))
            if runtime_check else tf.no_op())

        # Tie the (optional) sanity check to the ops producing the result.
        with tf.control_dependencies([assert_op]):
            vector_projected = riemannian.project(vector, x)

        # First derivative w.r.t. the deltas, contracted with the deltas of
        # the projected vector; differentiating this scalar once more gives
        # the Hessian-by-vector product in the deltas parametrization.
        first_grads = tf.gradients(function_value, start_deltas)
        vector_deltas = riemannian.tangent_space_to_deltas(vector_projected)
        per_core_products = [
            tf.reduce_sum(grad * delta)
            for grad, delta in zip(first_grads, vector_deltas)
        ]
        grad_times_vec = tf.add_n(per_core_products)
        second_grads = tf.gradients(grad_times_vec, start_deltas)
        gauged_deltas = _enforce_gauge_conditions(second_grads, left_orth)
        return riemannian.deltas_to_tangent_space(gauged_deltas, x, left_orth,
                                                  right_orth)
def deltas_to_tangent_space(deltas, tt, left=None, right=None,
                            name='t3f_deltas_to_tangent_space'):
    """Converts deltas representation of tangent space vector to TT object.

    Takes as input a list of [dP1, ..., dPd] and returns
        dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.

    This function is hard to use correctly because deltas should obey the
    so-called gauge conditions. If they don't, the function will silently
    return an incorrect result. This is why this function is not imported in
    __init__.

    Args:
        deltas: a list of deltas (essentially TT-cores) obeying the gauge
            conditions. If the deltas have one more dimension than the cores
            of `tt`, the leading dimension is treated as a batch dimension.
        tt: `TensorTrain` object on which the tangent space tensor represented
            by delta is projected.
        left: t3f.orthogonalize_tt_cores(tt). If you have it already computed,
            you may pass it as an argument to avoid recomputing.
        right: t3f.orthogonalize_tt_cores(left, left_to_right=False). If you
            have it already computed, you may pass it as an argument to avoid
            recomputing.
        name: string, name of the Op.

    Returns:
        `TensorTrain` object (or `TensorTrainBatch` for batched deltas)
        constructed from deltas, that is from the tangent space at point `tt`.
        Its `projection_on` attribute is set to `tt`.
    """
    cores = []
    dtype = tt.dtype
    num_dims = tt.ndims()
    # TODO: add cache instead of manually passing precomputed stuff?
    with tf.name_scope(name):
        if left is None:
            left = decompositions.orthogonalize_tt_cores(tt)
        if right is None:
            right = decompositions.orthogonalize_tt_cores(
                left, left_to_right=False)
        left_tangent_tt_ranks = shapes.lazy_tt_ranks(left)
        # The zero blocks below sit next to the *right*-orthogonal cores, so
        # their left rank has to come from `right` (it used to be read from
        # `left`, which is only correct when both rank tuples coincide).
        right_tangent_tt_ranks = shapes.lazy_tt_ranks(right)
        raw_shape = shapes.lazy_raw_shape(left)
        right_rank_dim = left.right_tt_rank_dim
        left_rank_dim = left.left_tt_rank_dim

        # Batched deltas carry one extra leading axis compared to tt's cores.
        is_batch_case = len(deltas[0].shape) > len(tt.tt_cores[0].shape)
        if is_batch_case:
            right_rank_dim += 1
            left_rank_dim += 1
            batch_size = deltas[0].shape.as_list()[0]

        for i in range(num_dims):
            left_tt_core = left.tt_cores[i]
            right_tt_core = right.tt_cores[i]

            if is_batch_case:
                # Replicate the orthogonal cores along the batch axis so they
                # can be concatenated with the batched deltas.
                tile = [batch_size] + [1] * len(left_tt_core.shape)
                left_tt_core = tf.tile(left_tt_core[None, ...], tile)
                right_tt_core = tf.tile(right_tt_core[None, ...], tile)

            if i == 0:
                # First core: the row block (dP1, U1).
                tangent_core = tf.concat((deltas[i], left_tt_core),
                                         axis=right_rank_dim)
            elif i == num_dims - 1:
                # Last core: the column block (Vd; dPd).
                tangent_core = tf.concat((right_tt_core, deltas[i]),
                                         axis=left_rank_dim)
            else:
                # Middle cores: the 2 x 2 block [[Vi, 0], [dPi, Ui]].
                rank_1 = right_tangent_tt_ranks[i]
                rank_2 = left_tangent_tt_ranks[i + 1]
                if tt.is_tt_matrix():
                    mode_size_n = raw_shape[0][i]
                    mode_size_m = raw_shape[1][i]
                    shape = [rank_1, mode_size_n, mode_size_m, rank_2]
                else:
                    mode_size_n = raw_shape[0][i]
                    shape = [rank_1, mode_size_n, rank_2]
                if is_batch_case:
                    shape = [batch_size] + shape
                zeros = tf.zeros(shape, dtype=dtype)
                upper = tf.concat((right_tt_core, zeros), axis=right_rank_dim)
                lower = tf.concat((deltas[i], left_tt_core),
                                  axis=right_rank_dim)
                tangent_core = tf.concat((upper, lower), axis=left_rank_dim)
            cores.append(tangent_core)

        if is_batch_case:
            tangent = TensorTrainBatch(cores, batch_size=batch_size)
        else:
            tangent = TensorTrain(cores)
        tangent.projection_on = tt
        return tangent
def project_sum(what, where, weights=None):
    r"""Project sum of `what` TTs on the tangent space of `where` TT.

    project_sum(what, x) = P_x(what)
    project_sum(batch_what, x) = P_x(\sum_i batch_what[i])
    project_sum(batch_what, x, weights) = P_x(\sum_j weights[j] * batch_what[j])

    This function implements the algorithm from the paper [1], theorem 3.1.

    [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
        Tensor Trains.

    Args:
        what: TensorTrain or TensorTrainBatch. In the case of batch returns
            projection of the sum of elements in the batch.
        where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
            project
        weights: python list or tf.Tensor of numbers or None, weights of the
            sum. 1-D weights (or 2-D weights with a single column) produce one
            projection; 2-D weights with several columns produce a batch with
            one projected weighted sum per column.

    Returns:
        a TensorTrain (a TensorTrainBatch when `weights` has more than one
        column) with the TT-ranks equal
        2 * tangent_space_tens.get_tt_ranks()

    Raises:
        ValueError: if `where` is not a TensorTrain, if the raw shapes of
            `what` and `where` differ, or if their dtypes are incompatible.

    Complexity:
        O(d r_where^3 m) for orthogonalizing the TT-cores of where
        +O(batch_size d r_what r_where n (r_what + r_where))
        d is the number of TT-cores (what.ndims());
        r_what is the largest TT-rank of what max(what.get_tt_rank())
        r_where is the largest TT-rank of where
        n is the size of the axis dimension of what and where e.g.
          for a tensor of size 4 x 4 x 4, n is 4;
          for a 9 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4) n is 12
    """
    # Validate `where` before touching any of its attributes, so that a wrong
    # argument type yields the documented ValueError.
    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    if weights is not None:
        weights = tf.convert_to_tensor(weights, dtype=where.dtype)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype) or
                         what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = where.right_tt_rank_dim
    left_rank_dim = where.left_tt_rank_dim
    if weights is not None:
        weights_shape = weights.get_shape()
        output_is_batch = len(weights_shape) > 1 and weights_shape[1] > 1
        if len(weights_shape) == 2 and not output_is_batch:
            # Single-column weights are handled as plain 1-D weights; without
            # dropping the column axis the rank-1 einsum label 's' used for
            # the weights below would not match the rank-2 tensor.
            weights = tf.squeeze(weights, axis=1)
    else:
        output_is_batch = False
    output_batch_str = 'o' if output_is_batch else ''
    if output_is_batch:
        right_rank_dim += 1
        left_rank_dim += 1
        output_batch_size = weights.get_shape().as_list()[1]

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            if weights is None:
                # Unweighted sum: contracting 's' away adds up the batch.
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, proj_core,
                                      rhs[core_idx + 1])
            else:
                # Keep the batch axis first, then mix it with the weights.
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
                proj_core_s = tf.einsum(einsum_str, proj_core,
                                        rhs[core_idx + 1])
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if core_idx == ndims - 1:
            if weights is None:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
                proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            else:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
                proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core)
                einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(
                    mode_str, output_batch_str)
                proj_core = tf.einsum(einsum_str, weights, proj_core_s)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core
            # and right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core,
                    [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core,
                    [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
def project_matmul(what, where, matrix):
    """Project `matrix` * `what` TTs on the tangent space of `where` TT.

    project_matmul(what, x, A) = P_x(A what)
    project_matmul(batch_what, x, A) = batch(P_x(A batch_what[0]), ...,
                                             P_x(A batch_what[N]))

    This function implements the algorithm from the paper [1], theorem 3.1.

    [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
        Tensor Trains.

    Args:
        what: TensorTrain or TensorTrainBatch. In the case of batch returns
            batch with projection of each individual tensor.
        where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
            project
        matrix: TensorTrain, TT-matrix to multiply by what

    Returns:
        a TensorTrain (a TensorTrainBatch if `what` is a batch) with the
        TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

    Raises:
        ValueError: if `where` is not a TensorTrain, if the raw shapes of
            `what` and `where` differ, or if their dtypes are incompatible.

    Complexity:
        O(d r_where^3 m) for orthogonalizing the TT-cores of where
        +O(batch_size d R r_what r_where (n r_what + n m R + m r_where))
        d is the number of TT-cores (what.ndims());
        r_what is the largest TT-rank of what max(what.get_tt_rank())
        r_where is the largest TT-rank of where
        matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n).
    """
    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype) or
                         what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # Axes of the result cores along which the blocks are concatenated.
    right_rank_dim = what.right_tt_rank_dim
    left_rank_dim = what.left_tt_rank_dim
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # The internal computations always carry a batch axis 's'. For a batch
    # output it is kept; for a single TensorTrain it is contracted away so
    # that the projected cores have the same number of dimensions as the
    # cores of `where` they are concatenated with below (otherwise tf.concat
    # gets tensors of different ranks).
    if output_is_batch:
        middle_einsum_str = 'saikcb,sbcd->saikd'
        last_einsum_str = 'sabc,bijd,scjke->saike'
    else:
        middle_einsum_str = 'saikcb,sbcd->aikd'
        last_einsum_str = 'sabc,bijd,scjke->aike'

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x matrix_tt_ranks[core_idx]
    #   x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core,
                                  right_tang_core, rhs[core_idx + 1],
                                  tens_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx]
    #   x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        # TODO: brute-force the order of indices in lhs??
        lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef',
                                      matrix_core, left_tang_core,
                                      lhs[core_idx], tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        matrix_core = matrix.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core,
                                  lhs[core_idx], matrix_core)
            proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core,
                                   lhs[core_idx + 1])
            proj_core = tf.einsum(middle_einsum_str, proj_core,
                                  rhs[core_idx + 1])

        if core_idx == ndims - 1:
            # d and e dimensions take 1 value, since it's the last rank.
            # To make the result shape (?, ?, ?, 1), we are summing d and
            # leaving e, but we could have done the opposite -- sum e and
            # leave d.
            proj_core = tf.einsum(last_einsum_str, lhs[core_idx],
                                  matrix_core, tens_core)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core
            # and right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            extended_left_tang_core = tf.tile(
                extended_left_tang_core, [output_batch_size, 1, 1, 1, 1])
            extended_right_tang_core = tf.tile(
                extended_right_tang_core, [output_batch_size, 1, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            mode_size_n = raw_shape[0][core_idx]
            mode_size_m = raw_shape[1][core_idx]
            shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res
def project(what, where):
    """Project `what` TTs on the tangent space of `where` TT.

    project(what, x) = P_x(what)
    project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

    This function implements the algorithm from the paper [1], theorem 3.1.

    [1] C. Lubich, I. Oseledets and B. Vandereycken, Time integration of
        Tensor Trains.

    Args:
        what: TensorTrain or TensorTrainBatch. In the case of batch returns
            batch with projection of each individual tensor.
        where: TensorTrain, TT-tensor or TT-matrix on which tangent space to
            project

    Returns:
        a TensorTrain (a TensorTrainBatch if `what` is a batch) with the
        TT-ranks equal 2 * tangent_space_tens.get_tt_ranks()

    Raises:
        ValueError: if `where` is not a TensorTrain, if the raw shapes of
            `what` and `where` differ, or if their dtypes are incompatible.
    """
    if not isinstance(where, TensorTrain):
        raise ValueError(
            'The first argument should be a TensorTrain object, got '
            '"%s".' % where)

    if where.get_raw_shape() != what.get_raw_shape():
        raise ValueError(
            'The shapes of the tensor we want to project and of the '
            'tensor on which tangent space we want to project should '
            'match, got %s and %s.' %
            (where.get_raw_shape(), what.get_raw_shape()))

    # Check compatibility in both directions, exactly as project_sum and
    # project_matmul do, so that the three projection functions accept the
    # same dtype combinations.
    dtypes_compatible = (where.dtype.is_compatible_with(what.dtype) or
                         what.dtype.is_compatible_with(where.dtype))
    if not dtypes_compatible:
        raise ValueError(
            'Dtypes of the arguments should coincide, got %s and %s.' %
            (where.dtype, what.dtype))

    left_tangent_space_tens = decompositions.orthogonalize_tt_cores(where)
    right_tangent_space_tens = decompositions.orthogonalize_tt_cores(
        left_tangent_space_tens, left_to_right=False)

    ndims = where.ndims()
    dtype = where.dtype
    raw_shape = shapes.lazy_raw_shape(where)
    batch_size = shapes.lazy_batch_size(what)
    right_tangent_tt_ranks = shapes.lazy_tt_ranks(right_tangent_space_tens)
    left_tangent_tt_ranks = shapes.lazy_tt_ranks(left_tangent_space_tens)

    # For einsum notation.
    mode_str = 'ij' if where.is_tt_matrix() else 'i'
    right_rank_dim = 3 if where.is_tt_matrix() else 2
    left_rank_dim = 0
    output_is_batch = isinstance(what, TensorTrainBatch)
    if output_is_batch:
        # A leading batch axis shifts both rank axes by one.
        right_rank_dim += 1
        left_rank_dim = 1
        output_batch_size = what.batch_size

    # Always work with batch of TT objects for simplicity.
    what = shapes.expand_batch_dim(what)

    # Prepare rhs vectors.
    # rhs[core_idx] is of size
    #   batch_size x tensor_tt_ranks[core_idx] x tangent_tt_ranks[core_idx]
    rhs = [None] * (ndims + 1)
    rhs[ndims] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1, 0, -1):
        tens_core = what.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str)
        rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1],
                                  right_tang_core)

    # Prepare lhs vectors.
    # lhs[core_idx] is of size
    #   batch_size x tangent_tt_ranks[core_idx] x tensor_tt_ranks[core_idx]
    lhs = [None] * (ndims + 1)
    lhs[0] = tf.ones((batch_size, 1, 1), dtype=dtype)
    for core_idx in range(ndims - 1):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str)
        lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx],
                                      left_tang_core, tens_core)

    # Left to right sweep.
    res_cores_list = []
    for core_idx in range(ndims):
        tens_core = what.tt_cores[core_idx]
        left_tang_core = left_tangent_space_tens.tt_cores[core_idx]
        right_tang_core = right_tangent_space_tens.tt_cores[core_idx]

        if core_idx < ndims - 1:
            einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)
            einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str)
            proj_core -= tf.einsum(einsum_str, left_tang_core,
                                   lhs[core_idx + 1])
            # For a non-batch output the batch axis 's' is contracted away.
            if output_is_batch:
                einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str)
            else:
                einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1])

        if core_idx == ndims - 1:
            if output_is_batch:
                einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str)
            else:
                einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str)
            proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core)

        if output_is_batch:
            # Add batch dimension of size output_batch_size to left_tang_core
            # and right_tang_core
            extended_left_tang_core = tf.expand_dims(left_tang_core, 0)
            extended_right_tang_core = tf.expand_dims(right_tang_core, 0)
            if where.is_tt_matrix():
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core,
                    [output_batch_size, 1, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core,
                    [output_batch_size, 1, 1, 1, 1])
            else:
                extended_left_tang_core = tf.tile(
                    extended_left_tang_core, [output_batch_size, 1, 1, 1])
                extended_right_tang_core = tf.tile(
                    extended_right_tang_core, [output_batch_size, 1, 1, 1])
        else:
            extended_left_tang_core = left_tang_core
            extended_right_tang_core = right_tang_core

        if core_idx == 0:
            res_core = tf.concat((proj_core, extended_left_tang_core),
                                 axis=right_rank_dim)
        elif core_idx == ndims - 1:
            res_core = tf.concat((extended_right_tang_core, proj_core),
                                 axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[core_idx]
            rank_2 = left_tangent_tt_ranks[core_idx + 1]
            if where.is_tt_matrix():
                mode_size_n = raw_shape[0][core_idx]
                mode_size_m = raw_shape[1][core_idx]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size = raw_shape[0][core_idx]
                shape = [rank_1, mode_size, rank_2]
            if output_is_batch:
                shape = [output_batch_size] + shape
            zeros = tf.zeros(shape, dtype)
            upper = tf.concat((extended_right_tang_core, zeros),
                              axis=right_rank_dim)
            lower = tf.concat((proj_core, extended_left_tang_core),
                              axis=right_rank_dim)
            res_core = tf.concat((upper, lower), axis=left_rank_dim)
        res_cores_list.append(res_core)

    # TODO: TT-ranks.
    if output_is_batch:
        res = TensorTrainBatch(res_cores_list, where.get_raw_shape(),
                               batch_size=output_batch_size)
    else:
        res = TensorTrain(res_cores_list, where.get_raw_shape())

    res.projection_on = where
    return res