def _sampling_step(time, func, state, min_length, max_length, pad_id, eos_id):
    # Compute log probabilities
    seqs = state.inputs
    seq_probs = state.seq_probs
    # [batch_size * num_samples, vocab_size]
    step_log_probs, next_state = func(seqs, state.state)

    # Suppress <eos> if needed
    batch_size = tf.shape(step_log_probs)[0]
    vocab_size = step_log_probs.shape[-1].value or tf.shape(step_log_probs)[1]
    add_mask = tf.one_hot(eos_id, vocab_size, dtype=step_log_probs.dtype,
                          on_value=step_log_probs.dtype.min, off_value=0.0)
    add_mask = tf.tile(tf.reshape(add_mask, [1, -1]), [batch_size, 1])
    add_mask = tf.where(time < min_length, add_mask, tf.zeros_like(add_mask))
    step_log_probs = step_log_probs + add_mask

    # Sample from the distribution
    symbol_indices = tf.multinomial(step_log_probs, 1, output_dtype=tf.int32)
    symbol_scores = tf.squeeze(utils.gather_2d(step_log_probs, symbol_indices))
    curr_flags = tf.squeeze(tf.equal(symbol_indices, eos_id), axis=1)
    curr_flags = tf.logical_or(state.flags, curr_flags)

    # Append <pad> to finished samples
    symbol_indices = tf.where(state.flags, tf.fill([batch_size, 1], pad_id),
                              symbol_indices)
    symbol_scores = tf.where(state.flags, tf.zeros([batch_size]),
                             symbol_scores)

    # Force the sampler to generate <eos> if the length exceeds max_length
    eos_flags = tf.where(time > max_length, tf.ones([batch_size], tf.bool),
                         tf.zeros([batch_size], tf.bool))
    eos_scores = tf.squeeze(
        utils.gather_2d(step_log_probs, tf.fill([batch_size, 1], eos_id)))
    eos_indices = tf.fill([batch_size, 1], eos_id)
    cond = tf.logical_and(tf.logical_not(curr_flags), eos_flags)
    curr_flags = tf.logical_or(curr_flags, eos_flags)
    symbol_indices = tf.where(cond, eos_indices, symbol_indices)
    symbol_scores = tf.where(cond, eos_scores, symbol_scores)

    step_symbol_scores = tf.expand_dims(symbol_scores, -1)

    new_state = SamplerState(
        inputs=tf.concat([seqs, symbol_indices], axis=1),
        state=next_state,
        scores=state.scores + symbol_scores,
        flags=curr_flags,
        seq_probs=tf.concat([seq_probs, step_symbol_scores], axis=1))

    return time + 1, new_state
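# The sketch below is a hypothetical usage example, not part of the original
# module: it drives _sampling_step with tf.while_loop. It assumes SamplerState
# is the namedtuple used above with fields (inputs, state, scores, flags,
# seq_probs), that `func` scores the current prefixes exactly as expected by
# _sampling_step, and that the decoder state in `init_state` keeps a fixed
# shape across steps. `bos_id` and `init_state` are illustrative arguments.
def _example_sampling_loop(func, init_state, batch_size, bos_id, min_length,
                           max_length, pad_id, eos_id):
    init = SamplerState(
        inputs=tf.fill([batch_size, 1], bos_id),
        state=init_state,
        scores=tf.zeros([batch_size]),
        flags=tf.zeros([batch_size], tf.bool),
        seq_probs=tf.zeros([batch_size, 0]))

    def _cond(time, state):
        # _sampling_step forces <eos> once time exceeds max_length, so the
        # loop terminates as soon as every sample has finished.
        return tf.logical_not(tf.reduce_all(state.flags))

    def _body(time, state):
        return _sampling_step(time, func, state, min_length, max_length,
                              pad_id, eos_id)

    # Sequences and per-step scores grow along the time axis, so those
    # dimensions are left unspecified in the shape invariants.
    invariants = SamplerState(
        inputs=tf.TensorShape([None, None]),
        state=nest.map_structure(lambda x: x.get_shape(), init_state),
        scores=tf.TensorShape([None]),
        flags=tf.TensorShape([None]),
        seq_probs=tf.TensorShape([None, None]))

    _, final_state = tf.while_loop(
        _cond, _body,
        loop_vars=[tf.constant(0, tf.int32), init],
        shape_invariants=[tf.TensorShape([]), invariants])

    return final_state.inputs, final_state.scores, final_state.seq_probs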
def _beam_search_step(time, func, state, batch_size, beam_size, alpha,
                      pad_id, eos_id):
    # Compute log probabilities
    seqs, log_probs = state.inputs[:2]
    flat_seqs = utils.merge_first_two_dims(seqs)
    flat_state = nest.map_structure(lambda x: utils.merge_first_two_dims(x),
                                    state.state)
    step_log_probs, next_state = func(flat_seqs, flat_state)
    step_log_probs = utils.split_first_two_dims(step_log_probs, batch_size,
                                                beam_size)
    next_state = nest.map_structure(
        lambda x: utils.split_first_two_dims(x, batch_size, beam_size),
        next_state)
    curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs

    # Apply length penalty
    length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha)
    curr_scores = curr_log_probs / length_penalty
    vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1]

    # Select top-k candidates
    # [batch_size, beam_size * vocab_size]
    curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
    # [batch_size, 2 * beam_size]
    top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
    # Shape: [batch_size, 2 * beam_size]
    beam_indices = top_indices // vocab_size
    symbol_indices = top_indices % vocab_size

    # Expand sequences
    # [batch_size, 2 * beam_size, time]
    candidate_seqs = utils.gather_2d(seqs, beam_indices)
    candidate_seqs = tf.concat(
        [candidate_seqs, tf.expand_dims(symbol_indices, 2)], 2)

    # Suppress finished sequences
    flags = tf.equal(symbol_indices, eos_id)
    # [batch, 2 * beam_size]
    alive_scores = top_scores + tf.to_float(flags) * tf.float32.min
    # [batch, beam_size]
    alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
    alive_symbols = utils.gather_2d(symbol_indices, alive_indices)
    alive_indices = utils.gather_2d(beam_indices, alive_indices)
    alive_seqs = utils.gather_2d(seqs, alive_indices)
    # [batch_size, beam_size, time + 1]
    alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2)
    alive_state = nest.map_structure(
        lambda x: utils.gather_2d(x, alive_indices), next_state)
    alive_log_probs = alive_scores * length_penalty

    # Select finished sequences
    prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
    # [batch, 2 * beam_size]
    step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min
    # [batch, 3 * beam_size]
    fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
    fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
    # [batch, beam_size]
    fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
    fin_flags = utils.gather_2d(fin_flags, fin_indices)
    pad_seqs = tf.fill([batch_size, beam_size, 1],
                       tf.constant(pad_id, tf.int32))
    prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
    fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
    fin_seqs = utils.gather_2d(fin_seqs, fin_indices)

    new_state = BeamSearchState(
        inputs=(alive_seqs, alive_log_probs, alive_scores),
        state=alive_state,
        finish=(fin_flags, fin_seqs, fin_scores),
    )

    return time + 1, new_state
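# The sketch below is a hypothetical usage example, not part of the original
# module: it drives _beam_search_step with tf.while_loop. It assumes
# BeamSearchState is the namedtuple used above with fields (inputs, state,
# finish), that beam_size is a Python integer and max_length a scalar, and
# that the decoder state in `init_state` keeps a fixed shape across steps.
# `bos_id`, `init_state` and `max_length` are illustrative arguments.
def _example_beam_search(func, init_state, batch_size, beam_size, alpha,
                         bos_id, max_length, pad_id, eos_id):
    # Every beam starts from <bos>; only the first beam gets log-prob 0 so
    # that the first top-k step does not select duplicated hypotheses.
    init_seqs = tf.fill([batch_size, beam_size, 1], bos_id)
    init_log_probs = tf.tile(
        tf.constant([[0.0] + [tf.float32.min] * (beam_size - 1)]),
        [batch_size, 1])
    init = BeamSearchState(
        inputs=(init_seqs, init_log_probs, init_log_probs),
        state=init_state,
        finish=(tf.zeros([batch_size, beam_size], tf.bool),
                tf.fill([batch_size, beam_size, 1], pad_id),
                tf.fill([batch_size, beam_size], tf.float32.min)))

    def _body(time, state):
        return _beam_search_step(time, func, state, batch_size, beam_size,
                                 alpha, pad_id, eos_id)

    # Alive and finished sequences both grow by one symbol per step, so the
    # time dimension is left unspecified in the shape invariants.
    invariants = BeamSearchState(
        inputs=(tf.TensorShape([None, None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None])),
        state=nest.map_structure(lambda x: x.get_shape(), init_state),
        finish=(tf.TensorShape([None, None]),
                tf.TensorShape([None, None, None]),
                tf.TensorShape([None, None])))

    _, final_state = tf.while_loop(
        lambda time, state: time < max_length, _body,
        loop_vars=[tf.constant(0, tf.int32), init],
        shape_invariants=[tf.TensorShape([]), invariants])

    # Fall back to the alive beam for batch entries where no hypothesis
    # reached <eos> within max_length steps.
    fin_flags, fin_seqs, fin_scores = final_state.finish
    alive_seqs, _, alive_scores = final_state.inputs
    any_finished = tf.reduce_any(fin_flags, axis=1)
    final_seqs = tf.where(any_finished, fin_seqs, alive_seqs)
    final_scores = tf.where(any_finished, fin_scores, alive_scores)

    return final_seqs, final_scores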