def my_gather(tensor):
    """Gather `tensor` along `beam_dim` at the indices in `top_beam_index`.

    The output shape is the shape of `tensor` with every occurrence of
    `beam_dim` replaced by `double_beam`; all other dimensions are kept.

    Args:
      tensor: a mtf.Tensor whose shape is expected to contain `beam_dim`.

    Returns:
      The result of `mtf.gather` with the substituted output shape.
    """
    swapped_dims = [
        double_beam if dim == beam_dim else dim
        for dim in tensor.shape.dims
    ]
    return mtf.gather(
        tensor,
        top_beam_index,
        beam_dim,
        output_shape=mtf.Shape(swapped_dims))
def gather(tensor, name):
    """Gather `tensor` along `old_beam_dim` at the indices in `topk_indices`.

    The gather is created inside the TensorFlow name scope `prefix + name`.
    The output shape is the shape of `tensor` with `old_beam_dim` replaced by
    `beam_dim`; every other dimension is left as is.

    Args:
      tensor: a mtf.Tensor whose shape is expected to contain `old_beam_dim`.
      name: string appended to `prefix` to form the name scope.

    Returns:
      The result of `mtf.gather` with the substituted output shape.
    """
    with tf.name_scope(prefix + name):
        return mtf.gather(
            tensor,
            topk_indices,
            old_beam_dim,
            output_shape=mtf.Shape([
                beam_dim if dim == old_beam_dim else dim
                for dim in tensor.shape.dims
            ]))
def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                 finished_scores, finished_in_finished, *unused_states):
    """Check the termination condition of the decoding loop.

    Despite the name, the returned value is true while decoding should keep
    going: it is `i < decode_length`, additionally and-ed (when `stop_early`
    is set) with "the early-stopping bound is not yet met".

    The bound is met when, everywhere, the lowest score among the finished
    sequences is greater than the score of the alive sequence at index 0 of
    `beam_dim` divided by the maximum length penalty.

    Args:
      i: loop index.
      unused_alive_seq: unused.
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      unused_finished_seq: unused.
      finished_scores: scores for each of the finished sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]
      *unused_states: unused.

    Returns:
      Bool tensor produced by `mtf.less` / `mtf.logical_and`.
    """
    # TODO(noam): support a different decode length...
    if not stop_early:
        return mtf.less(i, decode_length)

    score_dtype = finished_scores.dtype

    # Length penalty evaluated at the maximum decode length.
    max_length_penalty = mtf.pow(
        (5. + mtf.cast(decode_length, score_dtype)) / 6., alpha)

    # Best score still reachable by the alive sequence at index 0 of beam_dim
    # (the original code treats that entry as the most likely alive sequence).
    first_alive_log_prob = mtf.gather(
        alive_log_probs, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
    lower_bound_alive_scores = first_alive_log_prob / max_length_penalty

    # Zero out the scores of unfinished entries. Scores are all negative, so
    # the minimum over the beam is then the lowest score of a finished item.
    finished_mask = mtf.cast(finished_in_finished, score_dtype)
    lowest_score_of_finished_in_finished = mtf.reduce_min(
        finished_scores * finished_mask, reduced_dim=beam_dim)

    # When nothing has finished the minimum above is 0; shift it to -INF so
    # that any alive score is higher and the bound cannot be met.
    any_finished = mtf.cast(
        mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim),
        score_dtype)
    lowest_score_of_finished_in_finished += (1. - any_finished) * -INF

    bound_is_met = mtf.reduce_all(
        mtf.greater(lowest_score_of_finished_in_finished,
                    lower_bound_alive_scores))

    within_length = mtf.less(i, decode_length)
    return mtf.logical_and(within_length, mtf.logical_not(bound_is_met))
def body_fn(step_num, ids, *states):
    """Run one step of the sampling loop.

    Calls `logits_fn` for the current step, samples new ids over the last
    dimension of the logits with `temperature`, optionally overrides them with
    `forced_ids`, and adds them into `ids` at position `step_num` of
    `length_dim`.

    Args:
      step_num: a mtf.Tensor
      ids: a mtf.Tensor
      *states: additional mtf.Tensors

    Returns:
      new_step_num, new_ids, *new_states (as a single list)
    """
    logits, new_states = logits_fn(step_num, ids, states)
    sampled_ids = mtf.sample_with_temperature(
        logits, logits.shape.dims[-1], temperature)

    if forced_ids is not None:
        # Wherever the forced ids are nonzero they replace the sampled ids;
        # the sampled ids survive only where the forced id is 0.
        forced = mtf.gather(forced_ids, step_num, length_dim)
        use_sampled = mtf.to_int32(mtf.equal(forced, 0))
        sampled_ids = forced + sampled_ids * use_sampled

    # The one-hot mask restricts the update to position step_num.
    ids += sampled_ids * mtf.one_hot(step_num, length_dim, dtype=tf.int32)

    # The concatenation relies on logits_fn returning new_states as a list.
    return [step_num + 1, ids] + new_states