def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
               lengths, bs_stat_ta, predicted_ids):
    """Internal while_loop body for ensemble beam-search inference.

    Uses names from the enclosing scope: `num_models`, `decoders`,
    `decoder_output_removers`, `outputs_to_logits_fns`,
    `target_to_embedding_fns`, `helper`, `repeat_n_times`,
    `gather_states` and `BeamSearchStateSpec`.

    Args:
        time: Scalar int32 Tensor.
        inputs: A list of inputs Tensors (presumably one per model).
        caches: A sequence of decoder-state dicts, one per model; each must
            contain a "decoding_states" entry, which is re-gathered by the
            selected `beam_ids`.
        outputs_tas: A list of TensorArray structures.
        finished: A bool tensor (keeping track of what's finished).
        log_probs: The log probability Tensor.
        lengths: The decoding length Tensor.
        bs_stat_ta: Structure of TensorArray; a
            `BeamSearchStateSpec(log_probs, beam_ids)` is written at `time`.
        predicted_ids: Flat Tensor of the ids decoded so far; it is viewed
            as `[-1, time + 1]` and extended by one column per step.

    Returns:
        `(time + 1, next_inputs, next_caches, next_outputs_tas,
        next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
        next_predicted_ids)`.
    """
    # Step every model's decoder once.
    def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                  _outputs_ta, _outputs_to_logits_fn):
        # NOTE(review): the collapsed source hides indentation; the variable
        # scope is assumed to cover only `step` and `merge_top_features`.
        # Confirm before relying on variable names in checkpoints.
        with tf.variable_scope(_decoder.name):
            _output, _next_cache = _decoder.step(_input, _cache)
            _decoder_top_features = _decoder.merge_top_features(_output)
        _ta = nest.map_structure(
            lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
            _outputs_ta, _decoder_output_remover.apply(_output))
        _logit = _outputs_to_logits_fn(_decoder_top_features)
        return _output, _next_cache, _ta, _logit

    # The raw decoder outputs are not needed after this point.
    _, next_caches, next_outputs_tas, logits = repeat_n_times(
        num_models, _decoding, decoders, inputs, caches,
        decoder_output_removers, outputs_tas, outputs_to_logits_fns)

    # sample next symbols
    sample_ids, beam_ids, next_log_probs, next_lengths = helper.sample_symbols(
        logits, log_probs, finished, lengths, time=time)
    # Re-gather each model's decoding states by the selected beam ids.
    for c in next_caches:
        c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)
    infer_status = BeamSearchStateSpec(
        log_probs=next_log_probs, beam_ids=beam_ids)
    bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    bs_stat_ta, infer_status)
    # Re-gather the id history by beam, then append this step's ids.
    predicted_ids = gather_states(
        tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
    next_predicted_ids = tf.concat(
        [predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
    # Keep the loop variable flat. Its length grows every step, so the static
    # shape is relaxed to [None] (presumably to satisfy while_loop invariants).
    next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
    next_predicted_ids.set_shape([None])
    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=sample_ids)
    next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                 next_input_symbols, time + 1)
    # A beam that was already finished stays finished.
    next_finished = tf.logical_or(next_finished, finished)
    return time + 1, next_inputs, next_caches, next_outputs_tas, \
        next_finished, next_log_probs, next_lengths, bs_stat_ta, \
        next_predicted_ids
def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
               lengths, bs_stat_ta, predicted_ids):
    """One iteration of the ensemble beam-search `tf.while_loop`.

    Steps every decoder, lets `helper.sample_symbols` choose the next words
    and beams, re-gathers the per-model decoding states and the id history
    by the chosen beams, and records beam statistics for this step.

    Enclosing-scope names used: `num_models`, `decoders`,
    `decoder_output_removers`, `outputs_to_logits_fns`,
    `target_to_embedding_fns`, `helper`, `repeat_n_times`,
    `gather_states`, `BeamSearchStateSpec`.

    Args:
        time: Scalar int32 Tensor, the current step.
        inputs: A list of input Tensors, one entry per model.
        caches: A sequence of per-model state dicts, each holding a
            "decoding_states" entry.
        outputs_tas: A list of TensorArray structures.
        finished: Bool Tensor of per-beam finished flags.
        log_probs: Accumulated log-probability Tensor.
        lengths: Decoding length Tensor.
        bs_stat_ta: Structure of TensorArray that receives a
            `BeamSearchStateSpec(log_probs, beam_ids)` at `time`.
        predicted_ids: Flat Tensor of ids decoded so far, interpreted as
            `[-1, time + 1]`.

    Returns:
        `(time + 1, next_inputs, next_caches, next_outputs_tas,
        next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
        next_predicted_ids)`.
    """
    def _step_one_model(model_decoder, model_input, model_cache,
                        output_remover, model_outputs_ta, logits_fn):
        # NOTE(review): the scope extent is inferred from collapsed source;
        # it is assumed to wrap only the decoder step and feature merge.
        with tf.variable_scope(model_decoder.name):
            step_output, step_cache = model_decoder.step(model_input,
                                                         model_cache)
            top_features = model_decoder.merge_top_features(step_output)
        written_ta = nest.map_structure(
            lambda ta_item, out_item: ta_item.write(time, out_item),
            model_outputs_ta, output_remover.apply(step_output))
        step_logits = logits_fn(top_features)
        return step_output, step_cache, written_ta, step_logits

    step_results = repeat_n_times(
        num_models, _step_one_model, decoders, inputs, caches,
        decoder_output_removers, outputs_tas, outputs_to_logits_fns)
    _, next_caches, next_outputs_tas, logits = step_results

    word_ids, beam_ids, next_log_probs, next_lengths = helper.sample_symbols(
        logits, log_probs, finished, lengths, time=time)

    # Each model's decoding states follow the beam they now extend.
    for model_cache in next_caches:
        model_cache["decoding_states"] = gather_states(
            model_cache["decoding_states"], beam_ids)

    bs_stat_ta = nest.map_structure(
        lambda ta, value: ta.write(time, value), bs_stat_ta,
        BeamSearchStateSpec(log_probs=next_log_probs, beam_ids=beam_ids))

    # History viewed as [-1, time + 1], gathered by beam, then grown by one.
    history = gather_states(tf.reshape(predicted_ids, [-1, time + 1]),
                            beam_ids)
    extended = tf.concat([history, tf.expand_dims(word_ids, axis=1)], axis=1)
    next_predicted_ids = tf.reshape(extended, [-1])
    next_predicted_ids.set_shape([None])

    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=word_ids)
    next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                 next_input_symbols, time + 1)
    return (time + 1, next_inputs, next_caches, next_outputs_tas,
            tf.logical_or(next_finished, finished), next_log_probs,
            next_lengths, bs_stat_ta, next_predicted_ids)
def body_traininfer(time, inputs, cache, outputs_ta, finished, *args):
    """Single-decoder while_loop body used for both training and inference.

    Enclosing-scope names used: `decoder`, `decoder_output_remover`,
    `target_modality`, `helper`, `ModeKeys`, `_compute_logits`,
    `_embed_words`, `gather_states`, `BeamSearchStateSpec`.

    Args:
        time: scalar int32 Tensor.
        inputs: The inputs Tensor.
        cache: The decoder states (a dict; in INFER mode its
            "decoding_states" entry is re-gathered by beam).
        outputs_ta: structure of TensorArray.
        finished: A bool tensor (keeping track of what's finished).
        args: Only read when `decoder.mode == ModeKeys.INFER`, as
            `(log_probs, lengths, bs_stat_ta, predicted_ids)`.

    Returns:
        A list `[time + 1, next_inputs, next_cache, outputs_ta,
        next_finished]`; in INFER mode it is followed by
        `next_log_probs, next_lengths, bs_stat_ta, next_predicted_ids`.
    """
    next_time = time + 1
    # NOTE(review): scope extent inferred from collapsed source.
    with tf.variable_scope(decoder.name):
        outputs, next_cache = decoder.step(inputs, cache)
    outputs_ta = nest.map_structure(
        lambda ta, out: ta.write(time, out),
        outputs_ta, decoder_output_remover.apply(outputs))

    infer_loop_vars = []
    # Outside INFER mode, next_symbols is called with sample_ids=None.
    sample_ids = None
    if decoder.mode == ModeKeys.INFER:
        log_probs, lengths, bs_stat_ta, predicted_ids = \
            args[0], args[1], args[2], args[3]
        logits = _compute_logits(decoder, target_modality, outputs)
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths,
                                  time=time)
        history = gather_states(tf.reshape(predicted_ids, [-1, time + 1]),
                                beam_ids)
        next_cache["decoding_states"] = gather_states(
            next_cache["decoding_states"], beam_ids)
        bs_stat_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), bs_stat_ta,
            BeamSearchStateSpec(log_probs=next_log_probs, beam_ids=beam_ids))
        grown = tf.concat([history, tf.expand_dims(sample_ids, axis=1)],
                          axis=1)
        next_predicted_ids = tf.reshape(grown, [-1])
        next_predicted_ids.set_shape([None])
        infer_loop_vars = [next_log_probs, next_lengths, bs_stat_ta,
                           next_predicted_ids]

    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=sample_ids)
    next_inputs = _embed_words(target_modality, next_input_symbols, time + 1)
    next_finished = tf.logical_or(next_finished, finished)
    return [next_time, next_inputs, next_cache, outputs_ta,
            next_finished] + infer_loop_vars
def body_traininfer(time, inputs, decoder_states, outputs_ta, finished,
                    *args):
    """Single-decoder while_loop body with input pre/post-processing hooks.

    The inputs are passed through `inputs_preprocessing_fn` before the
    decoder step. The embedded next inputs are combined with the previous
    (beam-gathered in INFER mode) inputs via `inputs_postprocessing_fn`.

    Enclosing-scope names used: `decoder`, `decoding_params`,
    `inputs_preprocessing_fn`, `inputs_postprocessing_fn`,
    `decoder_output_remover`, `target_modality`, `helper`, `ModeKeys`,
    `_compute_logits`, `_embed_words`, `gather_states`,
    `BeamSearchStateSpec`.

    Args:
        time: scalar int32 Tensor.
        inputs: The inputs Tensor.
        decoder_states: The decoder states.
        outputs_ta: structure of TensorArray.
        finished: A bool tensor (keeping track of what's finished).
        args: Only read in INFER mode, as
            `(log_probs, lengths, infer_status_ta)`.

    Returns:
        A list `[time + 1, next_inputs, next_decoder_states, outputs_ta,
        next_finished]`; in INFER mode it is followed by
        `next_log_probs, next_lengths, infer_status_ta`.
    """
    next_time = time + 1
    # NOTE(review): scope extent inferred from collapsed source.
    with tf.variable_scope(decoder.name):
        step_inputs = inputs_preprocessing_fn(time, inputs)
        outputs, next_decoder_states = decoder.step(
            step_inputs, decoder_states, decoding_params)
    outputs_ta = nest.map_structure(
        lambda ta, out: ta.write(time, out),
        outputs_ta, decoder_output_remover.apply(outputs))

    extra_loop_vars = []
    sample_ids = None
    if decoder.mode == ModeKeys.INFER:
        log_probs, lengths, infer_status_ta = args[0], args[1], args[2]
        logits = _compute_logits(decoder, target_modality, outputs)
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths,
                                  time=time)
        next_decoder_states = gather_states(next_decoder_states, beam_ids)
        # The post-processing hook must see inputs aligned with new beams.
        prev_inputs = gather_states(step_inputs, beam_ids)
        infer_status_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), infer_status_ta,
            BeamSearchStateSpec(log_probs=next_log_probs,
                                predicted_ids=sample_ids,
                                beam_ids=beam_ids,
                                lengths=next_lengths))
        extra_loop_vars = [next_log_probs, next_lengths, infer_status_ta]
    else:
        prev_inputs = step_inputs

    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=sample_ids)
    embedded = _embed_words(target_modality, next_input_symbols, time + 1)
    with tf.variable_scope(decoder.name):
        next_inputs = inputs_postprocessing_fn(prev_inputs, embedded)
    next_finished = tf.logical_or(next_finished, finished)
    return [next_time, next_inputs, next_decoder_states, outputs_ta,
            next_finished] + extra_loop_vars
def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
               lengths, infer_status_ta):
    """Internal while_loop body for ensemble beam search (one step).

    Enclosing-scope names used: `decoders`, `decoder_output_removers`,
    `target_modalities`, `helper`, `_compute_logits`, `_embed_words`,
    `gather_states`, `BeamSearchStateSpec`.

    Args:
        time: Scalar int32 Tensor.
        inputs: A list of inputs Tensors, zipped with `decoders`.
        caches: A sequence of per-decoder state dicts, each holding a
            "decoding_states" entry that is re-gathered by beam.
        outputs_tas: A list of TensorArray structures.
        finished: A bool tensor (keeping track of what's finished).
        log_probs: The log probability Tensor.
        lengths: The decoding length Tensor.
        infer_status_ta: structure of TensorArray.

    Returns:
        `(time + 1, next_inputs, next_caches, next_outputs_tas,
        next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
    """
    def _step_model(dec, inp, cache):
        # Each decoder steps inside its own variable scope.
        with tf.variable_scope(dec.name):
            step_out, step_cache = dec.step(inp, cache)
        return step_out, step_cache

    stepped = [_step_model(dec, inp, cache)
               for dec, inp, cache in zip(decoders, inputs, caches)]
    outputs = [step_out for step_out, _ in stepped]
    next_caches = [step_cache for _, step_cache in stepped]

    next_outputs_tas = [
        nest.map_structure(lambda ta, value: ta.write(time, value),
                           out_ta, remover.apply(step_out))
        for out_ta, step_out, remover in zip(outputs_tas, outputs,
                                             decoder_output_removers)]

    logits = [_compute_logits(dec, modality, step_out)
              for dec, modality, step_out in zip(decoders, target_modalities,
                                                 outputs)]

    sample_ids, beam_ids, next_log_probs, next_lengths = helper.sample_symbols(
        logits, log_probs, finished, lengths, time=time)
    for state in next_caches:
        state["decoding_states"] = gather_states(state["decoding_states"],
                                                 beam_ids)
    status = BeamSearchStateSpec(log_probs=next_log_probs,
                                 predicted_ids=sample_ids,
                                 beam_ids=beam_ids,
                                 lengths=next_lengths)
    infer_status_ta = nest.map_structure(
        lambda ta, value: ta.write(time, value), infer_status_ta, status)

    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=sample_ids)
    next_inputs = nest.map_structure(
        lambda modality: _embed_words(modality, next_input_symbols, time + 1),
        target_modalities)
    next_finished = tf.logical_or(next_finished, finished)
    return (time + 1, next_inputs, next_caches, next_outputs_tas,
            next_finished, next_log_probs, next_lengths, infer_status_ta)
def sample_symbols_new(logits, log_probs, finished, lengths, time):
    """Selects the top `beam_size` continuations for each batch entry.

    Enclosing-scope names used: `batch_size`, `beam_size`, `vocab_size`,
    `eos_id`, `alpha`.

    Args:
        logits: [batch_size * beam_size, target_vocab_size]
        log_probs: accumulated log probabilities, [batch_size * beam_size, ]
        finished: bool finished flags, [batch_size * beam_size, ]
        lengths: decoding lengths, [batch_size * beam_size, ]
        time: scalar, the current decoding step.

    Returns:
        `(word_ids, beam_ids, next_log_probs, next_lengths, ret_log_probs,
        ret_sample_ids, length_penalty)`:
          - `word_ids`, `beam_ids`, `next_log_probs`, `next_lengths`: flat
            [batch_size * beam_size, ] results for the selected candidates.
          - `ret_log_probs`: candidate log-probabilities reshaped to
            [batch_size, beam_size * vocab_size].
          - `ret_sample_ids`: the flattened raw top-k indices into each
            [beam_size * vocab_size] row.
          - `length_penalty`: the per-beam length penalty gathered by
            `beam_ids`.
    """
    # [batch_size * beam_size,]
    prev_finished_float = math_ops.to_float(finished)
    # [batch_size * beam_size, ]
    prev_log_probs = log_probs
    # [batch_size * beam_size, target_vocab_size]; log-probabilities (<= 0)
    probs = advanced_log_softmax(logits)
    # Mask finished beams except one entry (eos_id); per the original note the
    # bias row is [float_min, ..., 0], forcing finished beams to emit EOS.
    finished_beam_bias = finished_beam_one_entry_bias(
        on_entry=eos_id, num_entries=vocab_size)
    # [batch_size * beam_size, target_vocab_size]
    finished_beam_bias = expand_to_beam_size(
        finished_beam_bias, beam_size * batch_size, axis=0)
    finished_beam_bias *= array_ops.expand_dims(prev_finished_float, 1)
    # Finished beams drop the model distribution and keep only the bias row.
    probs = probs * array_ops.expand_dims(1. - prev_finished_float, 1) \
        + finished_beam_bias
    # [batch_size * beam_size, target_vocab_size]
    log_probs = probs + array_ops.expand_dims(prev_log_probs, 1)
    # Finished beams do not grow: [batch_size * beam_size]
    lengths = lengths + 1 - math_ops.to_int32(finished)
    # length_penalty = ((5 + len) / 6) ** -alpha: [batch_size * beam_size,]
    length_penalty = math_ops.pow(
        ((5.0 + math_ops.to_float(lengths)) / 6.0), -alpha)
    scores = log_probs * array_ops.expand_dims(length_penalty, axis=1)
    # [batch_size, beam_size * target_vocab_size]
    scores = array_ops.reshape(array_ops.reshape(scores, [-1]),
                               [batch_size, -1])
    ret_log_probs = array_ops.reshape(array_ops.reshape(log_probs, [-1]),
                                      [batch_size, -1])
    # At time 0 only the first beam's vocab block per batch entry is used
    # (presumably all beams are identical at the first step).
    scores_flat = control_flow_ops.cond(
        ops.convert_to_tensor(time) > 0,
        lambda: scores,
        lambda: array_ops.slice(scores, [0, 0], [-1, vocab_size]))
    # Only the indices are needed: log-probs are re-gathered below.
    _, sample_ids = nn_ops.top_k(scores_flat, k=beam_size)
    # flatten: [batch_size * beam_size,]
    sample_ids = array_ops.reshape(sample_ids, [-1])
    ret_sample_ids = sample_ids
    # Indices run over [beam * vocab]; recover the true word ids.
    word_ids = math_ops.mod(sample_ids, vocab_size)
    # batch_pos, [batch_size, beam_size]: [[0, 0, ...], [1, 1, ...], ...]
    batch_pos = compute_batch_indices(batch_size, beam_size)
    # Global source-beam index: [batch_size * beam_size, ]
    beam_ids = math_ops.div(sample_ids, vocab_size) \
        + array_ops.reshape(batch_pos * beam_size, [-1])
    length_penalty = gather_states(length_penalty, beam_ids)
    next_lengths = gather_states(lengths, beam_ids)
    # Read the un-penalized accumulated log-probs of the selected candidates
    # from the flat [batch_size * beam_size * vocab_size] tensor.
    log_probs_flat = array_ops.reshape(log_probs, [-1])
    log_probs_index = array_ops.reshape(
        batch_pos, [-1]) * beam_size * vocab_size + sample_ids
    next_log_probs = array_ops.gather(log_probs_flat, log_probs_index)
    return (word_ids, beam_ids, next_log_probs, next_lengths, ret_log_probs,
            ret_sample_ids, length_penalty)
def sample_symbols(self, logits, log_probs, finished, lengths, time):
    """Runs one beam-search selection step and returns the chosen symbols.

    Args:
        logits: The logits Tensor with shape
            [beam_size * batch_size, vocab_size], or a list of logits
            Tensors (combined by `self._compute_log_probs`).
        log_probs: Accumulated log probabilities, float32 Tensor of shape
            [beam_size * batch_size, ].
        finished: Per-beam finished flags, bool Tensor of shape
            [beam_size * batch_size, ].
        lengths: Per-beam decoding lengths, int32 Tensor of shape
            [beam_size * batch_size, ].
        time: int32 scalar, the current step.

    Returns:
        `(word_ids, beam_ids, next_log_probs, next_lengths)`:
        `word_ids` holds the sampled word ids; `beam_ids` gives, for each
        position, the index of the beam it extends; `next_log_probs` holds
        the accumulated log probabilities; `next_lengths` holds the decoding
        lengths. All have shape [batch_size * beam_size, ].
    """
    vocab_size = self._vocab.vocab_size
    finished_mask = tf.to_float(finished)
    step_log_probs = self._compute_log_probs(logits)

    # Bias row that, per the original note, is [float_min, ..., 0] with 0 at
    # EOS; applied only to finished beams so they keep emitting EOS.
    eos_only_bias = finished_beam_one_entry_bias(
        on_entry=self._vocab.eos_id, num_entries=vocab_size)
    eos_only_bias = expand_to_beam_size(
        eos_only_bias, self._beam_size * self._batch_size, axis=0)
    eos_only_bias *= tf.expand_dims(finished_mask, 1)
    step_log_probs = (step_log_probs * tf.expand_dims(1. - finished_mask, 1)
                      + eos_only_bias)

    candidate_log_probs = step_log_probs + tf.expand_dims(log_probs, 1)
    # Finished beams keep their length.
    new_lengths = lengths + 1 - tf.to_int32(finished)
    penalty = compute_length_penalty(new_lengths, self._alpha)
    penalized = candidate_log_probs * tf.expand_dims(penalty, axis=1)
    # [batch_size, beam_size * vocab_size]
    per_batch = tf.reshape(tf.reshape(penalized, [-1]),
                           [self._batch_size, -1])
    # At step 0 only the first beam's vocab block is searched.
    searchable = tf.cond(
        tf.convert_to_tensor(time) > 0,
        lambda: per_batch,
        lambda: tf.slice(per_batch, [0, 0], [-1, vocab_size]))

    _, top_ids = tf.nn.top_k(searchable, k=self._beam_size)
    top_ids = tf.reshape(top_ids, [-1])
    word_ids = tf.mod(top_ids, vocab_size)

    # Offset of each batch entry's first beam in the flat beam axis.
    batch_pos = compute_batch_indices(self._batch_size, self._beam_size)
    beam_base_pos = tf.reshape(batch_pos * self._beam_size, [-1])
    beam_ids = tf.div(top_ids, vocab_size) + beam_base_pos

    next_lengths = gather_states(new_lengths, beam_ids)
    # Un-penalized accumulated log-probs of the selected candidates.
    flat_candidates = tf.reshape(candidate_log_probs, [-1])
    next_log_probs = tf.gather(flat_candidates,
                               beam_base_pos * vocab_size + top_ids)
    return word_ids, beam_ids, next_log_probs, next_lengths
def sample_symbols(self, logits, log_probs, finished, lengths, time):
    """Selects the next `beam_size` hypotheses for every batch entry.

    Args:
        logits: Logits Tensor of shape [beam_size * batch_size, vocab_size],
            or a list of such Tensors.
        log_probs: float32 accumulated log probabilities,
            [beam_size * batch_size, ].
        finished: bool finished flags, [beam_size * batch_size, ].
        lengths: int32 decoding lengths, [beam_size * batch_size, ].
        time: int32 scalar, the current step.

    Returns:
        Tuple `(word_ids, beam_ids, next_log_probs, next_lengths)`, each of
        shape [batch_size * beam_size, ]: the chosen word ids, the flat index
        of the beam each choice extends, the accumulated log probabilities,
        and the updated decoding lengths.
    """
    batch_size, beam_size = self._batch_size, self._beam_size
    vocab = self._vocab

    def _per_batch_rows(tensor):
        # Flatten [batch * beam, vocab] into [batch, beam * vocab].
        return tf.reshape(tf.reshape(tensor, [-1]), [batch_size, -1])

    done = tf.to_float(finished)
    word_log_probs = self._compute_log_probs(logits)

    # Finished beams: discard the model distribution and use only the EOS
    # bias row.
    eos_bias = finished_beam_one_entry_bias(
        on_entry=vocab.eos_id, num_entries=vocab.vocab_size)
    eos_bias = expand_to_beam_size(eos_bias, beam_size * batch_size, axis=0)
    eos_bias *= tf.expand_dims(done, 1)
    word_log_probs = word_log_probs * tf.expand_dims(1. - done, 1) + eos_bias

    total_log_probs = word_log_probs + tf.expand_dims(log_probs, 1)
    new_lengths = lengths + 1 - tf.to_int32(finished)
    penalty = compute_length_penalty(new_lengths, self._alpha)
    scores = _per_batch_rows(total_log_probs * tf.expand_dims(penalty, axis=1))
    candidates = tf.cond(
        tf.convert_to_tensor(time) > 0,
        lambda: scores,
        lambda: tf.slice(scores, [0, 0], [-1, vocab.vocab_size]))

    _, flat_ids = tf.nn.top_k(candidates, k=beam_size)
    flat_ids = tf.reshape(flat_ids, [-1])
    word_ids = tf.mod(flat_ids, vocab.vocab_size)

    row_offsets = tf.reshape(
        compute_batch_indices(batch_size, beam_size) * beam_size, [-1])
    beam_ids = tf.div(flat_ids, vocab.vocab_size) + row_offsets

    next_lengths = gather_states(new_lengths, beam_ids)
    next_log_probs = tf.gather(tf.reshape(total_log_probs, [-1]),
                               row_offsets * vocab.vocab_size + flat_ids)
    return word_ids, beam_ids, next_log_probs, next_lengths
def sample_symbols_new(logits, log_probs, finished, lengths, time):
    """Selects the top `beam_size` continuations for each batch entry.

    Enclosing-scope names used: `batch_size`, `beam_size`, `vocab_size`,
    `eos_id`, `alpha`.

    Args:
        logits: [batch_size * beam_size, target_vocab_size]
        log_probs: accumulated log probabilities, [batch_size * beam_size, ]
        finished: bool finished flags, [batch_size * beam_size, ]
        lengths: decoding lengths, [batch_size * beam_size, ]
        time: scalar, the current decoding step.

    Returns:
        `(word_ids, beam_ids, next_log_probs, next_lengths, ret_log_probs,
        ret_sample_ids, length_penalty)`:
          - `word_ids`, `beam_ids`, `next_log_probs`, `next_lengths`: flat
            [batch_size * beam_size, ] results for the selected candidates.
          - `ret_log_probs`: candidate log-probabilities reshaped to
            [batch_size, beam_size * vocab_size].
          - `ret_sample_ids`: the flattened raw top-k indices into each
            [beam_size * vocab_size] row.
          - `length_penalty`: the per-beam length penalty gathered by
            `beam_ids`.
    """
    # [batch_size * beam_size,]
    prev_finished_float = math_ops.to_float(finished)
    # [batch_size * beam_size, ]
    prev_log_probs = log_probs
    # [batch_size * beam_size, target_vocab_size]; log-probabilities (<= 0)
    probs = advanced_log_softmax(logits)
    # Mask finished beams except one entry (eos_id); per the original note the
    # bias row is [float_min, ..., 0], forcing finished beams to emit EOS.
    finished_beam_bias = finished_beam_one_entry_bias(
        on_entry=eos_id, num_entries=vocab_size)
    # [batch_size * beam_size, target_vocab_size]
    finished_beam_bias = expand_to_beam_size(
        finished_beam_bias, beam_size * batch_size, axis=0)
    finished_beam_bias *= array_ops.expand_dims(prev_finished_float, 1)
    # Finished beams drop the model distribution and keep only the bias row.
    probs = probs * array_ops.expand_dims(1. - prev_finished_float, 1) \
        + finished_beam_bias
    # [batch_size * beam_size, target_vocab_size]
    log_probs = probs + array_ops.expand_dims(prev_log_probs, 1)
    # Finished beams do not grow: [batch_size * beam_size]
    lengths = lengths + 1 - math_ops.to_int32(finished)
    # length_penalty = ((5 + len) / 6) ** -alpha: [batch_size * beam_size,]
    length_penalty = math_ops.pow(
        ((5.0 + math_ops.to_float(lengths)) / 6.0), -alpha)
    scores = log_probs * array_ops.expand_dims(length_penalty, axis=1)
    # [batch_size, beam_size * target_vocab_size]
    scores = array_ops.reshape(array_ops.reshape(scores, [-1]),
                               [batch_size, -1])
    ret_log_probs = array_ops.reshape(array_ops.reshape(log_probs, [-1]),
                                      [batch_size, -1])
    # At time 0 only the first beam's vocab block per batch entry is used
    # (presumably all beams are identical at the first step).
    scores_flat = control_flow_ops.cond(
        ops.convert_to_tensor(time) > 0,
        lambda: scores,
        lambda: array_ops.slice(scores, [0, 0], [-1, vocab_size]))
    # Only the indices are needed: log-probs are re-gathered below.
    _, sample_ids = nn_ops.top_k(scores_flat, k=beam_size)
    # flatten: [batch_size * beam_size,]
    sample_ids = array_ops.reshape(sample_ids, [-1])
    ret_sample_ids = sample_ids
    # Indices run over [beam * vocab]; recover the true word ids.
    word_ids = math_ops.mod(sample_ids, vocab_size)
    # batch_pos, [batch_size, beam_size]: [[0, 0, ...], [1, 1, ...], ...]
    batch_pos = compute_batch_indices(batch_size, beam_size)
    # Global source-beam index: [batch_size * beam_size, ]
    beam_ids = math_ops.div(sample_ids, vocab_size) \
        + array_ops.reshape(batch_pos * beam_size, [-1])
    length_penalty = gather_states(length_penalty, beam_ids)
    next_lengths = gather_states(lengths, beam_ids)
    # Read the un-penalized accumulated log-probs of the selected candidates
    # from the flat [batch_size * beam_size * vocab_size] tensor.
    log_probs_flat = array_ops.reshape(log_probs, [-1])
    log_probs_index = array_ops.reshape(
        batch_pos, [-1]) * beam_size * vocab_size + sample_ids
    next_log_probs = array_ops.gather(log_probs_flat, log_probs_index)
    return (word_ids, beam_ids, next_log_probs, next_lengths, ret_log_probs,
            ret_sample_ids, length_penalty)