Example #1
    def test_beam_search(self, start, alpha, min_len, max_len, targets,
                         expected):
        batch_size = 2
        beam_size = 3
        max_decode_len = len(targets)
        vocab_size = 7
        targets = tf.one_hot(tf.constant(targets),
                             vocab_size,
                             dtype=tf.float32)
        length_norm_fn = beam_search.length_normalization(
            start, alpha, min_len, max_len, -1e3)

        def symbols_to_logits_fn(unused_decodes, unused_states, i):
            # Scale logits up so the chosen token is not overridden by the length penalty.
            logits = targets[i:i + 1, :] * 1e2
            return tf.tile(logits, [batch_size * beam_size, 1]), unused_states

        states = {"empty": tf.ones([batch_size, 1], tf.float32)}

        beams, _ = beam_search.beam_search(
            symbols_to_logits_fn,
            tf.zeros([batch_size, max_decode_len], dtype=tf.int32), states,
            vocab_size, beam_size, length_norm_fn)
        self.assertAllEqual(expected, beams[0, 0, :])
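
For reference, the score that length_norm_fn assigns to a hypothesis in this test can be sketched in plain Python. This is a minimal sketch assuming a GNMT-style penalty of the form ((start + length) / (1 + start)) ** alpha plus a large out-of-range penalty; the helper name below is hypothetical and not part of beam_search.

def length_normalized_score(log_prob_sum, length, start, alpha,
                            min_len, max_len, out_of_range_penalty=-1e3):
  # Assumed GNMT-style length penalty; larger alpha rewards longer hypotheses.
  penalty = ((start + length) / (1.0 + start)) ** alpha
  score = log_prob_sum / penalty
  # Hypotheses outside [min_len, max_len] are pushed out of contention.
  if length < min_len or (max_len > 0 and length > max_len):
    score += out_of_range_penalty
  return score

# Example: a 4-token hypothesis with total log-probability -2.0.
print(length_normalized_score(-2.0, 4, start=5, alpha=0.6, min_len=0, max_len=-1))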
Example #2
def left2right_decode(symbols_to_logits_fn,
                      context_BxU_dict,
                      batch_size,
                      max_decode_len,
                      vocab_size,
                      beam_size=1,
                      beam_start=5,
                      beam_alpha=0.6,
                      beam_min=0,
                      beam_max=-1,
                      temperature=0.0,
                      top_k=0,
                      top_p=0.0,
                      eos_id=EOS_ID):
  """left to right decode.

  Notations:
    B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions

  Args:
    symbols_to_logits_fn: logits = fn(decodes, context, i). Should take
      [batch_size, decoded_ids] and return [batch_size, vocab_size].
    context_BxU_dict: dict of Tensors.
    batch_size: int, decode batch size.
    max_decode_len: int, maximum number of steps to decode.
    vocab_size: int, output vocab size.
    beam_size: Number of beams to decode.
    beam_start: start length for scaling, default to 5.
    beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and 1
      (longer), default to 0.6.
    beam_min: Minimum beam search length.
    beam_max: Maximum beam search length. Set to -1 for unlimited.
    temperature: Sampling temp for next token (0 for argmax), default to 0.0.
    top_k: Number of top symbols to consider at each time step, default to 0
      (consider all symbols).
    top_p: Nucleus sampling probability.
    eos_id: end-of-sequence token id, default to 1.

  Returns:
    decodes: Tensor[batch, decode_len]
  """
  dtype = tf.int64
  # When beam_size=1, beam_search does not behave exactly like greedy.
  # This is due to using 2 * beam_size in grow_topk and keeping only the top
  # beam_size candidates that have not reached EOS in the alive set.
  # In this case, the alpha value for the length penalty will take effect.
  if beam_size == 1:

    def decode_loop(i, decodes_BxT, cache_BxU_dict):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      decodes_BxT = inplace_update_i(
          decodes_BxT, tf.argmax(input=logits_BxV, axis=-1), i)
      return i + 1, decodes_BxT, cache_BxU_dict

    def loop_cond(i, decodes_BxT, unused_cache_BxU_dict):
      finished_B = tf.reduce_any(
          input_tensor=tf.equal(decodes_BxT, eos_id), axis=1)
      return tf.logical_and(i < max_decode_len,
                            tf.logical_not(tf.reduce_all(input_tensor=finished_B)))

    init_dec_BxT = tf.zeros([batch_size, max_decode_len], dtype=dtype)
    _, decodes, _ = tf.while_loop(
        cond=loop_cond, body=decode_loop,
        loop_vars=[tf.constant(0, dtype=dtype), init_dec_BxT, context_BxU_dict])
    return decodes

  else:

    def symbols_to_logits_fn_with_sampling(decodes_BxT, states_BxU_dict, i):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      return logits_BxV, states_BxU_dict

    length_norm_fn = beam_search.length_normalization(beam_start, beam_alpha,
                                                      beam_min, beam_max, -1e3)
    beams, _ = beam_search.beam_search(
        symbols_to_logits_fn_with_sampling,
        tf.zeros([batch_size, max_decode_len], dtype=tf.int32),
        context_BxU_dict, vocab_size, beam_size, length_norm_fn, eos_id)
    return tf.cast(beams[:, 0, :], dtype)
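
A minimal usage sketch for the function above. The toy symbols_to_logits_fn and the dummy context entry are hypothetical stand-ins for a real decoder; only beam_size=1 (the greedy branch) is exercised, and process_logits, inplace_update_i, and EOS_ID are assumed to come from the same module as the original source.

import tensorflow as tf

batch_size, max_decode_len, vocab_size = 2, 8, 11

def toy_symbols_to_logits_fn(decodes_BxT, context_BxU_dict, i):
  # Ignore the partial decode and the context; put all probability mass on
  # token 3, so the greedy loop fills the output with 3s.
  del decodes_BxT, context_BxU_dict, i  # unused in this toy stand-in
  return tf.one_hot(tf.fill([batch_size], 3), vocab_size,
                    on_value=10.0, off_value=0.0)

decodes_BxT = left2right_decode(
    toy_symbols_to_logits_fn,
    context_BxU_dict={"dummy": tf.zeros([batch_size, 1])},
    batch_size=batch_size,
    max_decode_len=max_decode_len,
    vocab_size=vocab_size,
    beam_size=1)
# With the default eos_id the toy logits never emit EOS, so the greedy loop
# runs for the full max_decode_len steps.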
Example #3
def left2right_decode(symbols_to_logits_fn,
                      context_BxU_dict,
                      batch_size,
                      max_decode_len,
                      vocab_size,
                      beam_size=1,
                      beam_start=5,
                      beam_alpha=0.6,
                      beam_min=0,
                      beam_max=-1,
                      temperature=0.0,
                      top_k=0,
                      top_p=0.0,
                      eos_id=EOS_ID,
                      training=False):
  """left to right decode.

  Notations:
    B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions

  Args:
    symbols_to_logits_fn: logits = fn(decodes, context, i). Should take
      [batch_size, decoded_ids] and return [batch_size, vocab_size].
    context_BxU_dict: dict of Tensors.
    batch_size: int, decode batch size.
    max_decode_len: int, maximum number of steps to decode.
    vocab_size: int, output vocab size.
    beam_size: Number of beams to decode.
    beam_start: start length for scaling, default to 5.
    beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and 1
      (longer), default to 0.6.
    beam_min: Minimum beam search length.
    beam_max: Maximum beam search length. Set to -1 for unlimited.
    temperature: Sampling temp for next token (0 for argmax), default to 0.0.
    top_k: Number of top symbols to consider at each time step, default to 0
      (consider all symbols).
    top_p: Nucleus sampling probability.
    eos_id: end-of-sequence token id, default to 1.
    training: if True, also collect and return per-step logits for training;
      default to False for prediction.

  Returns:
    If beam_size == 1: a tuple (decodes Tensor[batch, decode_len], None,
      logits Tensor[batch, decode_len, vocab_size]).
    Otherwise: a tuple (dict mapping beam index to Tensor[batch, decode_len],
      beam scores, beam dict from beam_search).
  """
  dtype = tf.int64
  # When beam_size=1, beam_search does not behave exactly like greedy.
  # This is due to using 2 * beam_size in grow_topk and keeping only the top
  # beam_size candidates that have not reached EOS in the alive set.
  # In this case, the alpha value for the length penalty will take effect.
  if beam_size == 1:

    def decode_loop(i, decodes_BxT, cache_BxU_dict, logits_BxTxV):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
      # Apply top_k / top_p / temperature processing to the raw logits.
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      # Write the argmax token id into position i of the decode.
      decodes_BxT = inplace_update_i(decodes_BxT, tf.argmax(logits_BxV, -1), i)
      if training:
        # Stop gradients through the sampled ids and cast to the loop dtype.
        decodes_BxT = tf.cast(tf.stop_gradient(decodes_BxT), dtype)
        # logp_BxV = tf.log(tf.clip_by_value(tf.math.softmax(logits_BxV, axis=1), 1e-8, 1.0))  # logits -> logp
        # logp_BxT = inplace_update_i(logp_BxT, tf.broadcast_to(tf.reduce_max(logp_BxV), [1, ]), i)  # logp sequence
        # logp_BxTxV = inplace_update_i(logp_BxTxV, logp_BxV, i)  # logp sequence x vocab
        # Record the full logits for step i into the [B, T, V] buffer.
        logits_BxTxV = inplace_update_i2(logits_BxTxV, logits_BxV, i)

      return i + 1, decodes_BxT, cache_BxU_dict, logits_BxTxV

    def loop_cond(i, decodes_BxT, unused_cache_BxU_dict, unused_logits_BxTxV):
      finished_B = tf.reduce_any(tf.equal(decodes_BxT, eos_id), axis=1)
      return tf.logical_and(i < max_decode_len,
                            tf.logical_not(tf.reduce_all(finished_B)))

    dtype = tf.int32 if training else dtype
    init_dec_BxT = tf.zeros([batch_size, max_decode_len], dtype=dtype)

    # Placeholder tensors that the loop writes per-step values into.
    # init_logp_BxT = tf.zeros([batch_size, max_decode_len], dtype=tf.float32)  # logp sequence
    # init_logp_BxTxV = tf.zeros([batch_size, max_decode_len, vocab_size], dtype=tf.float32)  # logp sequence x vocab
    init_logits_BxTxV = tf.zeros([batch_size, max_decode_len, vocab_size], dtype=tf.float32)  # logits sequence x vocab
    # swap_mem = True if training else False

    _, decodes, _, logits_BxTxV = tf.while_loop(
        loop_cond, decode_loop,
        [tf.constant(0, dtype=dtype), init_dec_BxT, context_BxU_dict,
         init_logits_BxTxV])

    # Returns the argmax token ids, a placeholder (None) for sequence
    # log-probs, and the per-step logits collected when training=True.
    return decodes, None, logits_BxTxV

  else:

    def symbols_to_logits_fn_with_sampling(decodes_BxT, states_BxU_dict, i):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      return logits_BxV, states_BxU_dict

    length_norm_fn = beam_search.length_normalization(beam_start, beam_alpha,
                                                      beam_min, beam_max, -1e3)
    beams, beam_scores, beam_dict = beam_search.beam_search(
        symbols_to_logits_fn_with_sampling,
        tf.zeros([batch_size, max_decode_len], dtype=tf.int32),
        context_BxU_dict, vocab_size, beam_size, length_norm_fn, eos_id,
        training)

    final_beams = {}
    for i in range(beam_size):
      final_beams[i] = tf.cast(beams[:, i, :], dtype)

    return final_beams, beam_scores, beam_dict
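
A hedged sketch of how the per-step logits returned by the training=True greedy branch above could be consumed: they are scored against reference token ids with a cross-entropy loss. The tensors, the padding convention (id 0), and the masking are illustrative assumptions, not part of the original source.

import tensorflow as tf

# Illustrative shapes: B=2 sequences, T=6 steps, V=11 vocab entries.
batch_size, max_decode_len, vocab_size = 2, 6, 11

# Stand-ins for logits_BxTxV returned by the training branch and for the
# reference token ids the loss is computed against.
logits_BxTxV = tf.random.normal([batch_size, max_decode_len, vocab_size])
targets_BxT = tf.random.uniform(
    [batch_size, max_decode_len], maxval=vocab_size, dtype=tf.int32)

# Cross-entropy between collected logits and reference ids, averaging only
# over non-padding positions (padding assumed to be id 0).
loss_BxT = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=targets_BxT, logits=logits_BxTxV)
mask_BxT = tf.cast(tf.not_equal(targets_BxT, 0), tf.float32)
loss = tf.reduce_sum(loss_BxT * mask_BxT) / tf.maximum(
    tf.reduce_sum(mask_BxT), 1.0)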