def sample_top_p(logits, top_p):
  """Chooses most probable logits with cumulative probabilities up to top_p.

  Sets the remaining logits to negative infinity so they are never sampled.

  Args:
    logits: Input logits for next token, shape [batch_size, vocab_size].
    top_p: Float tensor with a value >=0 and < 1.0; nucleus probability mass.

  Returns:
    Logits with top_p filtering applied (filtered entries set to -inf).
  """
  sorted_indices = tf.argsort(logits, direction="DESCENDING")
  # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
  logits_shape = decoding_module.shape_list(logits)
  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
  # Offset each row's sorted indices into the flattened [batch*vocab] space.
  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
                             [1, logits_shape[1]]) + sorted_indices
  flattened_logits = tf.reshape(logits, [-1])
  flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
  sorted_logits = tf.reshape(
      tf.gather(flattened_logits, flattened_sorted_indices),
      [logits_shape[0], logits_shape[1]])
  cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

  # Remove tokens with cumulative probability above the threshold.
  sorted_indices_to_remove = cumulative_probs > top_p

  # Shift the indices to the right to keep the first token above threshold.
  sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
  sorted_indices_to_remove = tf.concat([
      tf.zeros_like(sorted_indices_to_remove[:, :1]),
      sorted_indices_to_remove[:, 1:]
  ], -1)

  # Scatter sorted indices to original indexes.
  indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
                                                      sorted_indices)
  # float("-inf") instead of np.NINF: np.NINF was removed in NumPy 2.0 and
  # the replacement is value-identical.
  top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                float("-inf"))
  return top_p_logits
def scatter_values_on_batch_indices(values, batch_indices):
  """Scatter `values` into a tensor using `batch_indices`.

  Args:
    values: tensor of shape [batch_size, vocab_size] containing the values to
      scatter
    batch_indices: tensor of shape [batch_size, vocab_size] containing the
      indices to insert (should be a permutation in range(0, n))

  Returns:
    Tensor of shape [batch_size, vocab_size] with values inserted at
    batch_indices
  """
  shape = decoding_module.shape_list(batch_indices)
  # Row id for every element: row r repeated vocab_size times.
  row_ids = tf.broadcast_to(
      tf.expand_dims(tf.range(shape[0]), axis=-1), shape)
  # Build [batch*vocab, 2] (row, column) coordinate pairs for tf.scatter_nd.
  pair_indices = tf.stack(
      [tf.reshape(row_ids, [-1]),
       tf.reshape(batch_indices, [-1])], axis=1)
  return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
def test_shape_list(self):
  """shape_list returns the static dims of a fully-defined tensor."""
  tensor = tf.ones([7, 1])
  self.assertAllEqual([7, 1], decoding_module.shape_list(tensor))