Exemplo n.º 1
0
    def _allocation(self, usage):
        r"""Computes allocation by sorting `usage`.

        This corresponds to the value a = a_t[\phi_t[j]] in the paper. Memory
        slots are ranked from least to most used. Each slot's weight is its
        free fraction times the product of the usages of every slot ranked
        ahead of it. The weights are then permuted back into the original slot
        order.

        Args:
          usage: tensor of shape `[batch_size, memory_size]` indicating current
              memory usage. This is equal to u_t in the paper when there is a
              single write head. With multiple write heads, the usage should be
              updated while iterating through the heads so that it accounts for
              the allocation returned by this function.

        Returns:
          Tensor of shape `[batch_size, memory_size]` corresponding to
          allocation.
        """
        with tf.name_scope('allocation'):
            # Affine rescale (0 -> _EPSILON, 1 -> 1) so no entry is too small
            # before the cumulative product below.
            usage = _EPSILON + (1 - _EPSILON) * usage

            # top_k over the free fraction sorts in descending order, so slots
            # come out least-used first. Assumes self._memory_size equals
            # usage's last dimension, which makes this a full sort.
            free_desc, order = tf.nn.top_k(1 - usage,
                                           k=self._memory_size,
                                           name='sort')

            # Exclusive cumprod: entry j is the product of the usages of all
            # slots ranked before j (1 for the first-ranked slot).
            usage_ahead = tf.cumprod(1 - free_desc, axis=1, exclusive=True)
            allocation_desc = free_desc * usage_ahead

            # Undo the sort so each weight lines up with the original slot
            # indexing of `usage`.
            unsort = util.batch_invert_permutation(order)
            return util.batch_gather(allocation_desc, unsort)
Exemplo n.º 2
0
    def test(self):
        """Checks that util.batch_invert_permutation inverts each permutation.

        The test builds a batch of random permutations and inverts them with
        util.batch_invert_permutation. It then verifies that indexing each
        permutation with its computed inverse yields the identity: for every
        row i and position j, permutations[i][inverse[i][j]] == j.
        """
        batch_size = 5
        length = 7

        # `range` rather than the Python 2-only `xrange`, so the test runs on
        # Python 3 without depending on a six.moves import.
        permutations = np.empty([batch_size, length], dtype=int)
        for i in range(batch_size):
            permutations[i] = np.random.permutation(length)

        inverse = util.batch_invert_permutation(
            tf.constant(permutations, tf.int32))
        with self.test_session():
            inverse = inverse.eval()

        for i in range(batch_size):
            for j in range(length):
                self.assertEqual(permutations[i][inverse[i][j]], j)