# Example #1
def tt_sparse_flat_inner(tt_a, sparse_b):
    """Inner product between a TT-tensor (or TT-matrix) and a tf.SparseTensor.

    Evaluates tt_a only at the positions where sparse_b has stored values and
    returns the sum of products tt_a[idx] * sparse_b[idx] over those positions.

    The (raw) shapes of tt_a and sparse_b should coincide.

    Args:
      tt_a: `TensorTrain` object (a TT-tensor or a TT-matrix).
      sparse_b: tf.SparseTensor of the same shape as tt_a.

    Returns:
      a number: the sum of products of all the elements of tt_a and sparse_b.
    """
    # Number of stored (non-zero) entries of sparse_b; use the static value
    # when the shape is fully known, otherwise fall back to a dynamic tf.shape.
    if sparse_b.indices.get_shape().is_fully_defined():
        num_elements = sparse_b.indices.get_shape()[0]
    else:
        num_elements = tf.shape(sparse_b.indices)[0]
    a_shape = shapes.lazy_raw_shape(tt_a)
    a_ranks = shapes.lazy_tt_ranks(tt_a)
    if tt_a.is_tt_matrix():
        # Running per-element product of core slices. Each stored element
        # carries a (left_rank x right_rank) matrix; initialized as 1x1 ones.
        tt_a_elements = tf.ones((num_elements, 1, 1))
        # TODO: use t3f.shape is safer??
        tensor_shape = tt_a.get_raw_shape()
        # sparse_b indexes a matrix, so unravel its linear row/col indices
        # into per-core multi-indices against the TT-matrix raw shape.
        row_idx_linear = tf.cast(sparse_b.indices[:, 0], tf.int64)
        row_idx = utils.unravel_index(row_idx_linear,
                                      tf.cast(tensor_shape[0], tf.int64))
        col_idx_linear = tf.cast(sparse_b.indices[:, 1], tf.int64)
        col_idx = utils.unravel_index(col_idx_linear,
                                      tf.cast(tensor_shape[1], tf.int64))
        for core_idx in range(tt_a.ndims()):
            curr_core = tt_a.tt_cores[core_idx]
            left_rank = a_ranks[core_idx]
            right_rank = a_ranks[core_idx + 1]
            # Bring the two mode dimensions first so that a single (row, col)
            # pair selects one (left_rank x right_rank) slice of the core.
            curr_core = tf.transpose(curr_core, (1, 2, 0, 3))
            curr_core_shape = (a_shape[0][core_idx] * a_shape[1][core_idx],
                               left_rank, right_rank)
            curr_core = tf.reshape(curr_core, curr_core_shape)
            # Ravel multiindex (row_idx[:, core_idx], col_idx[:, core_idx]) into
            # a linear index to use tf.gather that supports only first dimensional
            # gather.
            # TODO: use gather_nd instead.
            curr_elements_idx = row_idx[:,
                                        core_idx] * tensor_shape[1][core_idx]
            curr_elements_idx += col_idx[:, core_idx]
            core_slices = tf.gather(curr_core, curr_elements_idx)
            # Chain the per-element matrix products across cores; after the
            # last core each element's product is a 1x1 matrix (its TT value).
            tt_a_elements = tf.matmul(tt_a_elements, core_slices)
    else:
        # TT-tensor case: gather tt_a values directly at sparse_b's indices.
        tt_a_elements = gather_nd(tt_a, sparse_b.indices)
    # Dot product of the gathered tt_a values with sparse_b's stored values.
    tt_a_elements = tf.reshape(tt_a_elements, (1, -1))
    sparse_b_elements = tf.reshape(sparse_b.values, (-1, 1))
    result = tf.matmul(tt_a_elements, sparse_b_elements)
    # Convert a 1x1 matrix into a number.
    result = result[0, 0]
    return result
# Example #2
 def testUnravelIndex(self):
     """unravel_index converts linear indices to multi-indices (2D and 3D)."""
     cases = (
         # (shape, linear indices, expected multi-indices)
         ((7, 6), [22, 41, 37], [[3, 4], [6, 5], [6, 1]]),
         ((2, 3, 4), [19, 17, 0, 23],
          [[1, 1, 3], [1, 1, 1], [0, 0, 0], [1, 2, 3]]),
     )
     for shape, linear_idx, expected in cases:
         multi_idx = utils.unravel_index(linear_idx, shape)
         self.assertAllEqual(expected, self.evaluate(multi_idx))