def _allocation(self, usage):
  r"""Computes allocation by sorting `usage`.

  This corresponds to the value a = a_t[\phi_t[j]] in the paper.

  Args:
    usage: tensor of shape `[batch_size, memory_size]` indicating current
        memory usage. This is equal to u_t in the paper when we only have one
        write head, but for multiple write heads, one should update the usage
        while iterating through the write heads to take into account the
        allocation returned by this function.

  Returns:
    Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
  """
  with tf.name_scope('allocation'):
    # Ensure values are not too small prior to cumprod.
    usage = _EPSILON + (1 - _EPSILON) * usage

    nonusage = 1 - usage
    sorted_nonusage, indices = tf.nn.top_k(
        nonusage, k=self._memory_size, name='sort')
    sorted_usage = 1 - sorted_nonusage
    prod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
    sorted_allocation = sorted_nonusage * prod_sorted_usage
    inverse_indices = util.batch_invert_permutation(indices)

    # This final line "unsorts" sorted_allocation, so that the indexing
    # corresponds to the original indexing of `usage`.
    return util.batch_gather(sorted_allocation, inverse_indices)
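# For intuition, here is a minimal NumPy sketch of the same allocation
# weighting for a single batch element (illustrative only; the helper name
# `allocation_reference` is hypothetical and not part of this module). It
# computes the paper's a_t[\phi_t[j]] = (1 - u_t[\phi_t[j]]) * prod_{i<j}
# u_t[\phi_t[i]], where \phi_t sorts `usage` ascending, so the freest slot
# receives the most allocation.
import numpy as np

def allocation_reference(usage):
  # Sorting usage ascending is equivalent to sorting nonusage descending,
  # which is what tf.nn.top_k on `nonusage` does above.
  phi = np.argsort(usage)
  sorted_usage = usage[phi]
  # Exclusive cumulative product, matching tf.cumprod(..., exclusive=True).
  prod_sorted_usage = np.concatenate(([1.0], np.cumprod(sorted_usage)[:-1]))
  sorted_allocation = (1 - sorted_usage) * prod_sorted_usage
  # "Unsort" back to the original memory-slot indexing.
  allocation = np.empty_like(usage)
  allocation[phi] = sorted_allocation
  return allocation

# Example: usage [0.9, 0.1, 0.5] -> the freest slot (index 1) gets the most
# allocation: allocation_reference(...) == [0.005, 0.9, 0.05].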
def test(self):
  # Tests that the batch_invert_permutation function correctly inverts a
  # batch of permutations.
  batch_size = 5
  length = 7

  permutations = np.empty([batch_size, length], dtype=int)
  for i in range(batch_size):
    permutations[i] = np.random.permutation(length)

  inverse = util.batch_invert_permutation(
      tf.constant(permutations, tf.int32))
  with self.test_session():
    inverse = inverse.eval()

  for i in range(batch_size):
    for j in range(length):
      self.assertEqual(permutations[i][inverse[i][j]], j)
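# To make the inversion property concrete, a small NumPy illustration
# (standalone sketch, not part of the test suite): if `perm` maps position
# j to value perm[j], its inverse satisfies inv[perm[j]] = j, and therefore
# perm[inv[k]] == k for every k -- exactly what the assertions above check.
import numpy as np

perm = np.array([2, 0, 3, 1])
inv = np.empty_like(perm)
inv[perm] = np.arange(len(perm))  # scatter j into position perm[j]
print(inv)        # [1 3 0 2]
print(perm[inv])  # [0 1 2 3] -- the identity, as asserted in the test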