Example no. 1
0
def sample_top_p(logits, top_p):
    """Chooses most probable logits with cumulative probability up to top_p.

    Sets the remaining logits to negative infinity so they can never be
    sampled after a softmax.

    Args:
      logits: Input logits for the next token; assumed shape
        [batch_size, vocab_size] — TODO confirm against callers.
      top_p: Float tensor with a value >= 0 and < 1.0.

    Returns:
      Logits of the same shape with top-p (nucleus) filtering applied.
    """
    sorted_indices = tf.argsort(logits, direction="DESCENDING")
    # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
    logits_shape = decoding_module.shape_list(logits)
    range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
    range_for_gather = tf.tile(range_for_gather * logits_shape[1],
                               [1, logits_shape[1]]) + sorted_indices
    flattened_logits = tf.reshape(logits, [-1])
    flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
    sorted_logits = tf.reshape(
        tf.gather(flattened_logits, flattened_sorted_indices),
        [logits_shape[0], logits_shape[1]])
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1),
                                 axis=-1)

    # Remove tokens with cumulative probability above the threshold.
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the indices to the right to keep the first token above threshold.
    sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
    sorted_indices_to_remove = tf.concat([
        tf.zeros_like(sorted_indices_to_remove[:, :1]),
        sorted_indices_to_remove[:, 1:]
    ], -1)

    # Scatter sorted indices to original indexes.
    indices_to_remove = scatter_values_on_batch_indices(
        sorted_indices_to_remove, sorted_indices)
    # Use -np.inf rather than np.NINF: the NINF alias was removed in NumPy 2.0.
    top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
                                                  -np.inf)
    return top_p_logits
Example no. 2
0
def scatter_values_on_batch_indices(values, batch_indices):
    """Scatter `values` into a fresh tensor at positions `batch_indices`.

    Args:
      values: tensor of shape [batch_size, vocab_size] containing the values
        to scatter.
      batch_indices: tensor of shape [batch_size, vocab_size] containing the
        target column for each value (should be a permutation of
        range(0, vocab_size) per row).

    Returns:
      Tensor of shape [batch_size, vocab_size] with `values` placed at
      `batch_indices` within each row.
    """
    out_shape = decoding_module.shape_list(batch_indices)
    # Row id for every element: [[0,0,...],[1,1,...],...] flattened to [N].
    row_ids = tf.reshape(
        tf.broadcast_to(
            tf.expand_dims(tf.range(out_shape[0]), axis=-1), out_shape),
        [-1])
    col_ids = tf.reshape(batch_indices, [-1])
    # Pair each flat value with its (row, column) destination: shape [N, 2].
    scatter_coords = tf.stack([row_ids, col_ids], axis=1)
    flat_values = tf.reshape(values, [-1])
    return tf.scatter_nd(scatter_coords, flat_values, out_shape)
 def test_shape_list(self):
     """shape_list reports the dimensions of a rank-2 tensor as a list."""
     tensor = tf.ones([7, 1])
     self.assertAllEqual([7, 1], decoding_module.shape_list(tensor))