def tt_sparse_flat_inner(tt_a, sparse_b):
    """Inner product between a TT-tensor (or TT-matrix) and a tf.SparseTensor.

    The product is taken along all axes: for every entry stored in sparse_b,
    the corresponding element of tt_a is multiplied by that entry's value, and
    the products are summed. The shapes of tt_a and sparse_b should coincide.

    Args:
      tt_a: `TensorTrain` object.
      sparse_b: tf.SparseTensor. When tt_a is a TT-matrix, column 0 of
        sparse_b.indices is treated as a linear row index and column 1 as a
        linear column index into the full matrix. sparse_b.values is assumed
        to have the same dtype as tt_a (tf.matmul requires matching dtypes);
        this is not checked here.

    Returns:
      A scalar tensor: the sum of products of the elements of tt_a and
      sparse_b at the indices stored in sparse_b.
    """
    if sparse_b.indices.get_shape().is_fully_defined():
        num_elements = sparse_b.indices.get_shape()[0]
    else:
        num_elements = tf.shape(sparse_b.indices)[0]
    a_shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    if tt_a.is_tt_matrix():
        # Running product of core slices: one (1 x r) matrix per stored entry.
        # tf.ones defaults to float32, so build it in the cores' dtype;
        # otherwise the tf.matmul below fails for e.g. float64 TT-matrices.
        tt_a_elements = tf.ones((num_elements, 1, 1),
                                dtype=tt_a.tt_cores[0].dtype)
        # TODO: use t3f.shape is safer??
        tensor_shape = tt_a.get_raw_shape()
        # Split each linear row / column index into one index per core.
        row_idx_linear = tf.cast(sparse_b.indices[:, 0], tf.int64)
        row_idx = utils.unravel_index(row_idx_linear,
                                      tf.cast(tensor_shape[0], tf.int64))
        col_idx_linear = tf.cast(sparse_b.indices[:, 1], tf.int64)
        col_idx = utils.unravel_index(col_idx_linear,
                                      tf.cast(tensor_shape[1], tf.int64))
        for core_idx in range(tt_a.ndims()):
            curr_core = tt_a.tt_cores[core_idx]
            left_rank = a_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1]
            # Bring the row/column mode axes (1 and 2) to the front and merge
            # them, so each leading slice is a (left_rank, right_rank) matrix.
            curr_core = tf.transpose(curr_core, (1, 2, 0, 3))
            curr_core_shape = (a_shape[0][core_idx] * a_shape[1][core_idx],
                               left_rank, right_rank)
            curr_core = tf.reshape(curr_core, curr_core_shape)
            # Ravel multiindex (row_idx[:, core_idx], col_idx[:, core_idx]) into
            # a linear index to use tf.gather that supports only first
            # dimensional gather.
            # TODO: use gather_nd instead.
            curr_elements_idx = row_idx[:, core_idx] * tensor_shape[1][core_idx]
            curr_elements_idx += col_idx[:, core_idx]
            core_slices = tf.gather(curr_core, curr_elements_idx)
            # Batched product: (n, 1, left_rank) x (n, left_rank, right_rank).
            tt_a_elements = tf.matmul(tt_a_elements, core_slices)
    else:
        # gather_nd (defined elsewhere in this module) is expected to return
        # the elements of tt_a at sparse_b.indices.
        tt_a_elements = gather_nd(tt_a, sparse_b.indices)
    tt_a_elements = tf.reshape(tt_a_elements, (1, -1))
    sparse_b_elements = tf.reshape(sparse_b.values, (-1, 1))
    # (1 x n) row times (n x 1) column = sum of element-wise products.
    result = tf.matmul(tt_a_elements, sparse_b_elements)
    # Convert a 1x1 matrix into a number.
    result = result[0, 0]
    return result
def testUnravelIndex(self):
    """utils.unravel_index maps linear indices to row-major multi-indices."""
    # Each case is (shape, linear indices, expected multi-indices).
    cases = (
        # 2D.
        ((7, 6), [22, 41, 37], [[3, 4], [6, 5], [6, 1]]),
        # 3D.
        ((2, 3, 4), [19, 17, 0, 23],
         [[1, 1, 3], [1, 1, 1], [0, 0, 0], [1, 2, 3]]),
    )
    for dims, flat, expected in cases:
        unraveled = utils.unravel_index(flat, dims)
        self.assertAllEqual(expected, self.evaluate(unraveled))