def cond_fn(step_num, prev_ids, *unused_states):
    """Tell the decoding loop whether to execute one more step.

    Looping stops when either `step_num` equals `num_steps`, or `eos_id` has
    shown up somewhere along `length_dim` for everything in `prev_ids`.

    Args:
      step_num: current step counter, compared for equality with `num_steps`.
      prev_ids: ids produced so far; searched for `eos_id` along `length_dim`.
      *unused_states: remaining loop state, ignored here.

    Returns:
      The negation of "step limit reached OR every entry contains eos_id".
    """
    # Stop condition 1: the step counter has hit the configured limit.
    hit_step_limit = mtf.equal(step_num, num_steps)

    # Stop condition 2: an EOS id appears along length_dim for every entry.
    eos_positions = mtf.equal(prev_ids, eos_id)
    entry_has_eos = mtf.reduce_any(eos_positions, reduced_dim=length_dim)
    every_entry_done = mtf.reduce_all(entry_has_eos)

    should_stop = mtf.logical_or(hit_step_limit, every_entry_done)
    return mtf.logical_not(should_stop)
def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                 finished_scores, finished_in_finished, *unused_states):
    """Decide whether the beam-search loop should keep iterating.

    The loop never runs past `decode_length`. When `stop_early` is set it
    also stops once the lowest score among the finished sequences is greater
    than the log-prob of the alive beam at index 0 divided by the maximum
    length penalty, and this holds for every entry.

    Args:
      i: loop index
      unused_alive_seq: ignored.
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      unused_finished_seq: ignored.
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]
      *unused_states: ignored.

    Returns:
      Bool. True while decoding should continue.
    """
    # TODO(noam): support a different decode length...
    if not stop_early:
        return mtf.less(i, decode_length)

    score_dtype = finished_scores.dtype

    # Length penalty evaluated at decode_length: ((5 + length) / 6) ** alpha.
    penalty_base = (5. + mtf.cast(decode_length, score_dtype)) / 6.
    max_length_penalty = mtf.pow(penalty_base, alpha)

    # Best score the alive beam at index 0 could still reach; index 0 is
    # treated here as the most likely alive sequence.
    first_beam = mtf.constant(mesh, 0, dtype=tf.int32)
    first_alive_log_probs = mtf.gather(alive_log_probs, first_beam, beam_dim)
    alive_bound = first_alive_log_probs / max_length_penalty

    # Zero out the scores of unfinished slots. Scores are expected to be
    # negative, so the min over the beam is then the worst finished score.
    finished_mask = mtf.cast(finished_in_finished, score_dtype)
    worst_finished = mtf.reduce_min(
        finished_scores * finished_mask, reduced_dim=beam_dim)

    # With nothing finished the min above is 0; push it down to -INF so the
    # comparison against the alive bound cannot succeed.
    has_finished = mtf.cast(
        mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim),
        score_dtype)
    worst_finished += (1. - has_finished) * -INF

    bound_is_met = mtf.reduce_all(mtf.greater(worst_finished, alive_bound))

    within_length = mtf.less(i, decode_length)
    bound_not_met = mtf.logical_not(bound_is_met)
    return mtf.logical_and(within_length, bound_not_met)