def test_beam_search(self, start, alpha, min_len, max_len, targets, expected):
  batch_size = 2
  beam_size = 3
  max_decode_len = len(targets)
  vocab_size = 7
  targets = tf.one_hot(tf.constant(targets), vocab_size, dtype=tf.float32)
  length_norm_fn = beam_search.length_normalization(start, alpha, min_len,
                                                    max_len, -1e3)

  def symbols_to_logits_fn(unused_decodes, unused_states, i):
    # Scale logits so the argmax choice is not biased by the length penalty.
    logits = targets[i:i + 1, :] * 1e2
    return tf.tile(logits, [batch_size * beam_size, 1]), unused_states

  states = {"empty": tf.ones([batch_size, 1], tf.float32)}
  beams, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      tf.zeros([batch_size, max_decode_len], dtype=tf.int32),
      states, vocab_size, beam_size, length_norm_fn)
  self.assertAllEqual(expected, beams[0, 0, :])
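# A hypothetical parameterized case for the test above, assuming absl's
# `parameterized` runner (common in Google-style TF test suites); the values
# are illustrative, not from the original suite. With alpha=0.0 the length
# penalty is neutral, so the top beam should reproduce the scaled one-hot
# targets exactly:
#
#   @parameterized.parameters(
#       (5, 0.0, 0, -1, [2, 3, 1], [2, 3, 1]),
#   )
#   def test_beam_search(self, start, alpha, min_len, max_len, targets,
#                        expected):
#     ...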
def left2right_decode(symbols_to_logits_fn,
                      context_BxU_dict,
                      batch_size,
                      max_decode_len,
                      vocab_size,
                      beam_size=1,
                      beam_start=5,
                      beam_alpha=0.6,
                      beam_min=0,
                      beam_max=-1,
                      temperature=0.0,
                      top_k=0,
                      top_p=0.0,
                      eos_id=EOS_ID):
  """left to right decode.

  Notations:
    B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions

  Args:
    symbols_to_logits_fn: logits = fn(decodes, context, i). Should take
      [batch_size, decoded_ids] and return [batch_size, vocab_size].
    context_BxU_dict: dict of Tensors.
    batch_size: int, decode batch size.
    max_decode_len: int, maximum number of steps to decode.
    vocab_size: int, output vocab size.
    beam_size: Number of beams to decode.
    beam_start: start length for scaling, default to 5.
    beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and
      1 (longer), default to 0.6.
    beam_min: Minimum beam search lengths.
    beam_max: Maximum beam search lengths. Set -1 to use unlimited.
    temperature: Sampling temp for next token (0 for argmax), default to 0.0.
    top_k: Number of top symbols to consider at each time step, default to 0
      (consider all symbols).
    top_p: Nucleus sampling probability.
    eos_id: end of sequence token id, default to 1.

  Returns:
    decodes: Tensor[batch, decode_len]
  """
  dtype = tf.int64
  # When beam_size=1, beam_search does not behave exactly like greedy.
  # This is due to using 2 * beam_size in grow_topk, and keeping the top
  # beam_size ones that haven't reached EOS in the alive set.
  # In this case, the alpha value for the length penalty will take effect.
  if beam_size == 1:

    def decode_loop(i, decodes_BxT, cache_BxU_dict):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      decodes_BxT = inplace_update_i(decodes_BxT,
                                     tf.argmax(input=logits_BxV, axis=-1), i)
      return i + 1, decodes_BxT, cache_BxU_dict

    def loop_cond(i, decodes_BxT, unused_cache_BxU_dict):
      finished_B = tf.reduce_any(
          input_tensor=tf.equal(decodes_BxT, EOS_ID), axis=1)
      return tf.logical_and(
          i < max_decode_len,
          tf.logical_not(tf.reduce_all(input_tensor=finished_B)))

    init_dec_BxT = tf.zeros([batch_size, max_decode_len], dtype=dtype)
    _, decodes, _ = tf.while_loop(
        cond=loop_cond,
        body=decode_loop,
        loop_vars=[tf.constant(0, dtype=dtype), init_dec_BxT,
                   context_BxU_dict])
    return decodes

  else:

    def symbols_to_logits_fn_with_sampling(decodes_BxT, states_BxU_dict, i):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      return logits_BxV, states_BxU_dict

    length_norm_fn = beam_search.length_normalization(beam_start, beam_alpha,
                                                      beam_min, beam_max,
                                                      -1e3)
    beams, _ = beam_search.beam_search(
        symbols_to_logits_fn_with_sampling,
        tf.zeros([batch_size, max_decode_len], dtype=tf.int32),
        context_BxU_dict, vocab_size, beam_size, length_norm_fn, eos_id)
    return tf.cast(beams[:, 0, :], dtype)
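# A minimal usage sketch for left2right_decode (illustrative only, not part
# of the original code). The toy symbols_to_logits_fn ignores the decoded
# prefix and the context and always puts all mass on token 3; a real model
# would condition on both. This assumes the defaults top_k=0, top_p=0.0,
# temperature=0.0 make process_logits a pass-through, and that EOS_ID is 1
# as the docstring states.
def _example_greedy_decode():
  batch_size, max_decode_len, vocab_size = 2, 4, 7

  def toy_symbols_to_logits_fn(unused_decodes_BxT, unused_context, unused_i):
    # Heavily scaled one-hot logits, mirroring the test above.
    return tf.one_hot(tf.fill([batch_size], 3), vocab_size) * 1e2

  context = {"empty": tf.ones([batch_size, 1], tf.float32)}
  # beam_size=1 takes the greedy while_loop path; the result is an int64
  # Tensor of shape [2, 4], here [[3, 3, 3, 3], [3, 3, 3, 3]] since token 3
  # never equals EOS_ID, so the loop runs the full max_decode_len steps.
  return left2right_decode(toy_symbols_to_logits_fn, context, batch_size,
                           max_decode_len, vocab_size, beam_size=1)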
# Variant of left2right_decode extended with a `training` flag and per-step
# logits tracking.
def left2right_decode(symbols_to_logits_fn,
                      context_BxU_dict,
                      batch_size,
                      max_decode_len,
                      vocab_size,
                      beam_size=1,
                      beam_start=5,
                      beam_alpha=0.6,
                      beam_min=0,
                      beam_max=-1,
                      temperature=0.0,
                      top_k=0,
                      top_p=0.0,
                      eos_id=EOS_ID,
                      training=False):
  """left to right decode.

  Notations:
    B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions

  Args:
    symbols_to_logits_fn: logits = fn(decodes, context, i). Should take
      [batch_size, decoded_ids] and return [batch_size, vocab_size].
    context_BxU_dict: dict of Tensors.
    batch_size: int, decode batch size.
    max_decode_len: int, maximum number of steps to decode.
    vocab_size: int, output vocab size.
    beam_size: Number of beams to decode.
    beam_start: start length for scaling, default to 5.
    beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and
      1 (longer), default to 0.6.
    beam_min: Minimum beam search lengths.
    beam_max: Maximum beam search lengths. Set -1 to use unlimited.
    temperature: Sampling temp for next token (0 for argmax), default to 0.0.
    top_k: Number of top symbols to consider at each time step, default to 0
      (consider all symbols).
    top_p: Nucleus sampling probability.
    eos_id: end of sequence token id, default to 1.
    training: enables sampling during training; default to False for
      predictions.

  Returns:
    When beam_size=1: (decodes_BxT, None, logits_BxTxV).
    Otherwise: (dict mapping beam index -> Tensor[batch, decode_len],
      beam_scores, beam_dict).
  """
  dtype = tf.int64
  # When beam_size=1, beam_search does not behave exactly like greedy.
  # This is due to using 2 * beam_size in grow_topk, and keeping the top
  # beam_size ones that haven't reached EOS in the alive set.
  # In this case, the alpha value for the length penalty will take effect.
  if beam_size == 1:

    def decode_loop(i, decodes_BxT, cache_BxU_dict, logits_BxTxV):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      # Write the ids of argmax(logits) at position i.
      decodes_BxT = inplace_update_i(decodes_BxT, tf.argmax(logits_BxV, -1),
                                     i)
      if training:
        # Detach the sampled ids from the graph; gradients flow through the
        # accumulated logits instead.
        decodes_BxT = tf.cast(tf.stop_gradient(decodes_BxT), dtype)
      # Alternative kept for reference: accumulate log-probs instead.
      # logp_BxV = tf.math.log(
      #     tf.clip_by_value(tf.math.softmax(logits_BxV, axis=1), 1e-8, 1.0))
      # logp_BxT = inplace_update_i(
      #     logp_BxT, tf.broadcast_to(tf.reduce_max(logp_BxV), [1]), i)
      # logp_BxTxV = inplace_update_i(logp_BxTxV, logp_BxV, i)
      # Accumulate the full per-step logits: [batch, decode_len, vocab].
      logits_BxTxV = inplace_update_i2(logits_BxTxV, logits_BxV, i)
      return i + 1, decodes_BxT, cache_BxU_dict, logits_BxTxV

    def loop_cond(i, decodes_BxT, unused_cache_BxU_dict, unused_logits_BxTxV):
      finished_B = tf.reduce_any(tf.equal(decodes_BxT, EOS_ID), axis=1)
      return tf.logical_and(i < max_decode_len,
                            tf.logical_not(tf.reduce_all(finished_B)))

    dtype = tf.int32 if training else dtype
    init_dec_BxT = tf.zeros([batch_size, max_decode_len], dtype=dtype)
    # Placeholder tensors the loop writes per-step values into.
    # init_logp_BxT = tf.zeros([batch_size, max_decode_len], dtype=tf.float32)
    # init_logp_BxTxV = tf.zeros([batch_size, max_decode_len, vocab_size],
    #                            dtype=tf.float32)
    init_logits_BxTxV = tf.zeros([batch_size, max_decode_len, vocab_size],
                                 dtype=tf.float32)
    _, decodes, _, logits_BxTxV = tf.while_loop(
        loop_cond, decode_loop,
        [tf.constant(0, dtype=dtype), init_dec_BxT, context_BxU_dict,
         init_logits_BxTxV])
    # (argmax ids, placeholder for per-step log-probs, per-step logits).
    return decodes, None, logits_BxTxV

  else:

    def symbols_to_logits_fn_with_sampling(decodes_BxT, states_BxU_dict, i):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
      logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
      return logits_BxV, states_BxU_dict

    length_norm_fn = beam_search.length_normalization(beam_start, beam_alpha,
                                                      beam_min, beam_max,
                                                      -1e3)
    beams, beam_scores, beam_dict = beam_search.beam_search(
        symbols_to_logits_fn_with_sampling,
        tf.zeros([batch_size, max_decode_len], dtype=tf.int32),
        context_BxU_dict, vocab_size, beam_size, length_norm_fn, eos_id,
        training)
    final_beams = {}
    for i in range(beam_size):
      final_beams[i] = tf.cast(beams[:, i, :], dtype)
    return final_beams, beam_scores, beam_dict
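# An illustrative training-time call for the variant above. The loss wiring
# and labels_BxT are assumptions for the sketch, not part of the original
# code. With beam_size=1 and training=True the function returns the greedy
# ids, a None placeholder, and the full per-step logits; the logits stay
# differentiable even though the sampled ids are detached via
# tf.stop_gradient.
def _example_training_step(symbols_to_logits_fn, context_BxU_dict, labels_BxT,
                           batch_size, max_decode_len, vocab_size):
  decodes_BxT, _, logits_BxTxV = left2right_decode(
      symbols_to_logits_fn, context_BxU_dict, batch_size, max_decode_len,
      vocab_size, beam_size=1, training=True)
  # Sequence cross-entropy against hypothetical gold labels [batch, T].
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels_BxT, logits=logits_BxTxV))
  return decodes_BxT, loss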