コード例 #1
0
ファイル: beam_search.py プロジェクト: trantorznh/mesh
 def my_gather(tensor):
     """Gather `tensor` along `beam_dim` at `top_beam_index`.

     The output keeps every dimension of `tensor` in its original order,
     except that `beam_dim` is replaced by `double_beam`.
     """
     result_dims = []
     for dim in tensor.shape.dims:
         if dim == beam_dim:
             result_dims.append(double_beam)
         else:
             result_dims.append(dim)
     result_shape = mtf.Shape(result_dims)
     return mtf.gather(tensor, top_beam_index, beam_dim,
                       output_shape=result_shape)
コード例 #2
0
ファイル: beam_search.py プロジェクト: trantorznh/mesh
 def gather(tensor, name):
     """Gather `tensor` along `old_beam_dim` at `topk_indices`.

     The gather is built inside the TF name scope ``prefix + name``. The
     output shape is the shape of `tensor` with `old_beam_dim` swapped for
     `beam_dim`; all other dimensions keep their original order.
     """
     scope = prefix + name
     with tf.name_scope(scope):
         # Copy the dims so the substitution below never touches `tensor`.
         dims = list(tensor.shape.dims)
         for pos, dim in enumerate(dims):
             if dim == old_beam_dim:
                 dims[pos] = beam_dim
         return mtf.gather(tensor, topk_indices, old_beam_dim,
                           output_shape=mtf.Shape(dims))
コード例 #3
0
ファイル: beam_search.py プロジェクト: trantorznh/mesh
    def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                     finished_scores, finished_in_finished, *unused_states):
        """Compute the loop-continuation condition for beam search.

        Without `stop_early`, decoding continues while `i < decode_length`.
        With `stop_early`, it also stops once the early-stopping bound is met
        everywhere. The bound is met when the lowest score among finished
        sequences is greater than the log probability of the most likely
        alive sequence divided by the length penalty
        ``((5 + decode_length) / 6) ** alpha``.

        Args:
          i: loop index.
          alive_log_probs: log probabilities of the alive beams.
            [batch_size, beam_size]
          finished_scores: scores for each of the finished sequences.
            [batch_size, beam_size]
          finished_in_finished: bools marking which finished entries hold an
            actually finished sequence. [batch_size, beam_size]

        Returns:
          A boolean mtf.Tensor that is True while decoding should continue.
        """
        # TODO(noam): support a different decode length.
        if not stop_early:
            return mtf.less(i, decode_length)

        score_dtype = finished_scores.dtype

        # Length penalty evaluated at decode_length: ((5 + L) / 6) ** alpha.
        penalty_base = (5. + mtf.cast(decode_length, score_dtype)) / 6.
        max_length_penalty = mtf.pow(penalty_base, alpha)

        # Best possible score of the most likely alive sequence, taken at
        # index 0 along beam_dim.
        first_beam = mtf.constant(mesh, 0, dtype=tf.int32)
        best_alive_score = mtf.gather(
            alive_log_probs, first_beam, beam_dim) / max_length_penalty

        # Lowest score among finished sequences. Unfinished entries are
        # multiplied by 0; since scores are all negative, the min then picks
        # out the lowest finished score.
        finished_mask = mtf.cast(finished_in_finished, score_dtype)
        lowest_finished_score = mtf.reduce_min(
            finished_scores * finished_mask, reduced_dim=beam_dim)

        # If nothing has finished, the min above is 0. Shift it by -INF there
        # so that no alive score can fall below it and the bound is not met.
        # NOTE(review): relies on INF being finite -- with float('inf'),
        # 0 * -INF would be NaN wherever something has finished. Confirm.
        none_finished = 1. - mtf.cast(
            mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim),
            score_dtype)
        lowest_finished_score += none_finished * -INF

        bound_is_met = mtf.reduce_all(
            mtf.greater(lowest_finished_score, best_alive_score))
        below_max_length = mtf.less(i, decode_length)
        return mtf.logical_and(below_max_length, mtf.logical_not(bound_is_met))
コード例 #4
0
ファイル: beam_search.py プロジェクト: trantorznh/mesh
    def body_fn(step_num, ids, *states):
        """Run one step of greedy / temperature-sampled decoding.

        Steps:
          1. Call `logits_fn` for the current step.
          2. Draw new ids from the returned logits with
             `mtf.sample_with_temperature`.
          3. If `forced_ids` is set, override them with `forced_ids`.
          4. Add the result into `ids` at position `step_num` along
             `length_dim`.

        Args:
          step_num: a mtf.Tensor, the current decoding position.
          ids: a mtf.Tensor of token ids along `length_dim`.
          *states: additional mtf.Tensors, passed to `logits_fn`.

        Returns:
          A list ``[new_step_num, new_ids, *new_states]`` where
          ``new_step_num`` is ``step_num + 1``.
        """
        logits, new_states = logits_fn(step_num, ids, states)
        # The last dimension of the logits is used as the vocabulary dim.
        vocab_dim = logits.shape.dims[-1]
        new_ids = mtf.sample_with_temperature(logits, vocab_dim, temperature)
        if forced_ids is not None:
            # Force the new ids to equal the partial targets where specified
            # (positions where forced_ids is nonzero). Where the forced value
            # is 0, the sampled id is kept.
            forced = mtf.gather(forced_ids, step_num, length_dim)
            new_ids = forced + new_ids * mtf.to_int32(mtf.equal(forced, 0))
        # Place the new ids at position `step_num` with a one-hot mask.
        # NOTE(review): this adds rather than assigns, so it assumes `ids` is 0
        # at `step_num` -- confirm against the caller's initial ids.
        ids += new_ids * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
        new_step_num = step_num + 1
        # `logits_fn` may hand back its states as any sequence (e.g. a tuple).
        # Normalize to a list so the concatenation cannot raise TypeError.
        return [new_step_num, ids] + list(new_states)