def ExtractBlockContext(x,
                        block_size,
                        left_context,
                        right_context,
                        padding_val=0.0):
    """Extracts temporal context for every block.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    padding_val: float. Value used for the padded frames.

  Returns:
    A tensor of [batch, num_blocks, context_size, ...], with necessary paddings,
    where context_size = block_size + (left_context - 1) + right_context, and
    output[:, i, ...] = x[:, start - (left_context - 1):end + right_context, ...],
    with start = i * block_size and end = (i + 1) * block_size.
  """
    if block_size < 1:
        raise ValueError(
            'block_size must be at least 1, got {}'.format(block_size))
    if left_context < 1 or left_context > block_size + 1:
        raise ValueError(
            'left_context must be at least 1 and at most block_size + 1 = {}, '
            'got {}'.format(block_size + 1, left_context))
    if right_context < 0 or right_context > block_size:
        raise ValueError(
            'right_context must be at least 0 and at most block_size = {}, '
            'got {}'.format(block_size, right_context))

    block = ConvertToBlocks(x, block_size, padding_val)
    concat_list = [block]

    if left_context > 1:
        if block_size == left_context - 1:
            left_block = tf.roll(block, shift=1, axis=1)
        else:
            x_shift = tf.roll(x, shift=left_context - 1, axis=1)
            x_shift_block = ConvertToBlocks(x_shift, block_size, padding_val)
            left_block = x_shift_block[:, :, :left_context - 1, ...]
        concat_list = [left_block] + concat_list

    if right_context > 0:
        if block_size == right_context:
            right_block = tf.roll(block, shift=-1, axis=1)
        else:
            x_shift = tf.roll(x, shift=-right_context, axis=1)
            x_shift_block = ConvertToBlocks(x_shift, block_size, padding_val)
            right_block = x_shift_block[:, :, -right_context:, ...]
        concat_list += [right_block]

    return tf.concat(concat_list, axis=2)
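
# A minimal sketch of the left-context trick above, assuming a hypothetical
# _ToBlocks helper in place of ConvertToBlocks (which is not shown here).
# Rolling right by left_context - 1 and re-blocking moves the tail of the
# previous block to the front of each block. Note that tf.roll wraps around,
# so block 0's context comes from the end of the sequence and would need
# masking in real use.
import tensorflow as tf

def _ToBlocks(x, block_size, padding_val=0.0):
    # Pad time to a multiple of block_size, then reshape to
    # [batch, num_blocks, block_size].
    pad = (-tf.shape(x)[1]) % block_size
    x = tf.pad(x, [[0, 0], [0, pad]], constant_values=padding_val)
    return tf.reshape(x, [tf.shape(x)[0], -1, block_size])

x = tf.reshape(tf.range(8, dtype=tf.float32), [1, 8])
block = _ToBlocks(x, 4)                                     # [[0 1 2 3], [4 5 6 7]]
left = _ToBlocks(tf.roll(x, shift=2, axis=1), 4)[:, :, :2]  # left_context = 3
print(tf.concat([left, block], axis=2).numpy())
# [[[6. 7. 0. 1. 2. 3.]    block 0: [6 7] is wrap-around context
#   [2. 3. 4. 5. 6. 7.]]]  block 1: [2 3] immediately precedes frame 4
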
    def _InputBatch(self):
        """Returns a dummy input batch: all-ones target ids of length 1024."""
        targets = tf.ones([self.params.batch_size, 1024], dtype=tf.int32)
        input_batch = py_utils.NestedMap()
        input_batch.tgt = py_utils.NestedMap()
        input_batch.tgt.ids = tf.roll(targets, 1, axis=1)
        input_batch.tgt.labels = targets
        input_batch.tgt.segment_ids = tf.minimum(targets, 1)
        input_batch.tgt.segment_pos = targets
        input_batch = input_batch.Transform(
            lambda t: tf.ensure_shape(t, (self.params.batch_size, 1024)))
        return input_batch
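
# tf.roll(targets, 1, axis=1) above shifts the labels right by one step to
# form the decoder input ids. With the all-ones dummy targets the wrap-around
# is harmless; on real data the rolled-in last token would typically be
# replaced with a BOS id. A toy illustration:
import tensorflow as tf

labels = tf.constant([[11, 12, 13, 14]])
print(tf.roll(labels, 1, axis=1).numpy())  # [[14 11 12 13]] -- last token wraps to front
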
def _BBoxArea(bbox):
    """Computes the area of a 2-d bbox.

  Vertices must be ordered clockwise or counter-clockwise. This function can
  technically handle any kind of convex polygons.

  Args:
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last two
      dimensions contain the four corners of the bbox as (x, y) coordinates.
      The corners must be given in counter-clockwise order.

  Returns:
    Area of the bbox. Tensor of shape [..., 1].
  """
    bbox_roll = tf.roll(bbox, shift=1, axis=-2)
    det = tf.reduce_sum(
        bbox[..., 0] * bbox_roll[..., 1] - bbox[..., 1] * bbox_roll[..., 0],
        axis=-1,
        keepdims=True) / 2.0
    return tf.abs(det)
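
# _BBoxArea is the shoelace formula: tf.roll pairs each vertex with its
# neighbor so all cross-products are summed in one vectorized step. A quick
# sanity check on a unit square:
import tensorflow as tf

bbox = tf.constant([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])  # counter-clockwise
print(_BBoxArea(bbox).numpy())  # [1.]
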
    def ComputeLoss(self, theta, predictions, input_batch):
        p = self.params
        # Computes the loss for input_batch.
        with self._DecoderDevice():
            result = self.dec.ComputeLoss(theta.dec, predictions,
                                          input_batch.tgt)
            if self.do_eval:
                return result

        probs = result[1]['reshape_probs']
        probs_hard = result[1]['target_hard_probs']
        atten_probs = predictions.attention.probs

        if 'other_src' in input_batch and 'other_tgt' in input_batch:
            other_batch = py_utils.NestedMap()
            other_batch.src = input_batch.other_src.DeepCopy()
            other_batch.tgt = input_batch.other_tgt.DeepCopy()
        else:
            other_batch = py_utils.NestedMap()
            other_batch.src = input_batch.src.DeepCopy()
            other_batch.tgt = input_batch.tgt.DeepCopy()
            other_batch = other_batch.Transform(lambda x: tf.roll(x, 1, 0))
            other_atten_probs = tf.roll(atten_probs, 1, 0)
            other_probs = tf.roll(probs, 1, 0)
            other_probs_hard = tf.roll(probs_hard, 1, 0)
            other_predictions = py_utils.NestedMap()
            other_predictions.source_embs = tf.roll(predictions.source_embs, 1,
                                                    0)
            other_predictions.target_embs = tf.roll(predictions.target_embs, 1,
                                                    0)

        # Computes the loss for other_batch.
        if p.loss_mono_weight > 0:
            other_predictions = self.ComputePredictions(theta, other_batch)
            with self._DecoderDevice():
                other_result = self.dec.ComputeLoss(theta.dec,
                                                    other_predictions,
                                                    other_batch.tgt)
                other_atten_probs = other_predictions.attention.probs
                other_probs = other_result[1]['reshape_probs']
                other_probs_hard = other_result[1]['target_hard_probs']

        # Computes the xendec loss.
        if p.loss_mix_weight > 0:
            if p.atten_drop > 0:
                atten_probs = tf.nn.dropout(atten_probs, p.atten_drop)
                if other_atten_probs is not None:
                    other_atten_probs = tf.nn.dropout(other_atten_probs,
                                                      p.atten_drop)
            if other_atten_probs is not None:
                if p.use_prob_cl:
                    cur_step = py_utils.GetGlobalStep()
                    cur_ratio = tf.minimum(
                        tf.cast(cur_step, py_utils.FPropDtype(p)) / 20000, 1.0)
                    probs_hard = tf.cast(probs_hard, py_utils.FPropDtype(p))
                    other_probs_hard = tf.cast(other_probs_hard,
                                               py_utils.FPropDtype(p))
                    prob_ratio = tf.expand_dims(input_batch.tgt.weights,
                                                -1) * cur_ratio
                    probs = probs_hard * (1.0 -
                                          prob_ratio) + probs * prob_ratio
                    other_prob_ratio = tf.expand_dims(other_batch.tgt.weights,
                                                      -1) * cur_ratio
                    other_probs = other_probs_hard * (
                        1.0 -
                        other_prob_ratio) + other_probs * other_prob_ratio
                else:
                    probs = tf.cast(probs_hard, py_utils.FPropDtype(p))
                    other_probs = tf.cast(other_probs_hard,
                                          py_utils.FPropDtype(p))

            source_paddings_pair = [
                input_batch.src.paddings, other_batch.src.paddings
            ]
            target_paddings_pair = [
                input_batch.tgt.paddings, other_batch.tgt.paddings
            ]

            source_mask = input_batch.src.source_mask
            other_lambdas = source_mask * (1. - source_paddings_pair[1])
            source_lambdas = (1. - other_lambdas) * (1. -
                                                     source_paddings_pair[0])
            source_lambdas = [source_lambdas, other_lambdas]

            source_lambdas, input_lambdas, label_lambdas = self._CreateTargetLambdas(
                [atten_probs, other_atten_probs],
                source_lambdas,
                source_paddings_pair,
                target_paddings_pair,
                smooth=0.)

            mix_tgt = input_batch.tgt
            target_weights = input_batch.tgt.weights + other_batch.tgt.weights
            target_weights = tf.clip_by_value(target_weights, 0.0, 1.0)
            mix_tgt.weights = target_weights

            input_batch.src.embs = predictions.source_embs
            input_batch.tgt.embs = predictions.target_embs
            other_batch.src.embs = other_predictions.source_embs
            other_batch.tgt.embs = other_predictions.target_embs

            mix_predictions = self.ComputePredictions(theta, input_batch,
                                                      other_batch,
                                                      source_lambdas,
                                                      input_lambdas)

            target_probs = probs * tf.expand_dims(
                label_lambdas[0], -1) + other_probs * tf.expand_dims(
                    label_lambdas[1], -1)

            target_probs = target_probs + 1e-9
            target_probs = target_probs / tf.reduce_sum(
                target_probs, -1, keepdims=True)

            with self._DecoderDevice():
                mix_result = self.dec.ComputeLoss(theta.dec, mix_predictions,
                                                  mix_tgt, target_probs)

        losses = []
        loss_names = []
        loss_weights = []
        new_metrics = {}

        if p.loss_clean_weight > 0:
            losses.append(result)
            loss_weights.append(p.loss_clean_weight)
            loss_names.append('clean_loss')

        if p.loss_mono_weight > 0:
            losses.append(other_result)
            loss_weights.append(p.loss_mono_weight)
            loss_names.append('other_loss')

        if p.loss_mix_weight > 0.0:
            losses.append(mix_result)
            loss_weights.append(p.loss_mix_weight)
            loss_names.append('mix_loss')

        combined_loss = 0
        num_predictions = 1.
        # Combines three losses.
        for i in range(len(loss_names)):
            combined_loss += losses[i][0]['loss'][0] * loss_weights[i]
            if loss_names[i] == 'clean_loss':
                num_predictions = losses[i][0]['loss'][1]
            new_metrics[loss_names[i]] = (losses[i][0]['loss'][0] *
                                          loss_weights[i],
                                          losses[i][0]['loss'][1])
        new_metrics['loss'] = (combined_loss, num_predictions)
        return new_metrics, losses[0][1]
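
# The tf.roll(..., 1, 0) calls above implement in-batch pairing: each example
# is matched with its predecessor along the batch axis, giving every sentence
# a distinct mixing partner without extra data loading. In isolation:
import tensorflow as tf

batch = tf.constant([[1, 1], [2, 2], [3, 3]])
print(tf.roll(batch, shift=1, axis=0).numpy())  # [[3 3] [1 1] [2 2]]
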
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        # Create prefix of start tokens.
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # Rename prefix -> pfx for shorter variable names.
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ - - - - - - P P P P P P P* - - - ]   ^
        # [ - - P P P P P P P P P P P* - - - ]   | batch
        # [ - - - - - - - - - - - P P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
        pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                          tf.int32)  # Exclude final pfx token.
        pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
        pfx_last = pfx[:, -1]

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        # Remove padding.
        assert buf_size > pfx_max
        pfx_pad_long = tf.pad(pfx_pad, [(0, 0), (0, buf_size - pfx_max)],
                              constant_values=1)
        pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), fprop_dtype)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_indexes * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make the initial 't' constant, determined by the
        # shape of the prefix tensor 'pfx_max'. It could be made dynamic, as
        # t ~ max(pfx_len) / beam_size, which would leave more steps for beam
        # search; however 'max' results in a very slow all-to-all on 16x16,
        # and a variable number of decoder steps may result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
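
# A toy version of unmask() above, using plain tf.einsum in place of the
# einsum_i32 helper (assumed to be an int32-safe einsum wrapper). Tokens
# selected by a 0/1 mask are compacted to the front of the output buffer;
# the cumsum trick assigns slot -1 to unselected tokens, and tf.one_hot maps
# index -1 to an all-zero row, so those tokens are dropped.
import tensorflow as tf

h = tf.constant([[5., 6., 7., 8.]])    # flattened token history [batch, buf_size]
m = tf.constant([[[1., 0., 1., 0.]]])  # one hyp's mask [batch, k, buf_size]
t = tf.cumsum(m, -1) * m - 1           # output slot per token; -1 means drop
mh = m * h[:, None, :]
t2 = tf.one_hot(tf.cast(t, tf.int32), 4)
print(tf.einsum('bkt,bktT->bkT', mh, t2).numpy())  # [[[5. 7. 0. 0.]]]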