Example 1
 def __init__(self, params):
     super(TestInputGenerator, self).__init__(params)
     p = self.params
     self._bprop_variable_filters = ['']
     self._bprop_onehot = tf.constant([1], dtype=tf.float32)
     if p.target_key and not p.target_key_target_shape:
         raise ValueError('target_key_target_shape must be set when '
                          'target_key (%s) is not empty.' % p.target_key)
     if (p.set_tgt_and_additional_tgts
             and (p.target_key_target_shape[0] != p.target_shape[0])):
         raise ValueError(
             'The first dimension of target_key_target_shape (%d) '
             'should match the first dimension of target_shape '
             '(%d) when both have to be set.' %
             (p.target_key_target_shape[0], p.target_shape[0]))
     self._cur_iter = 0
     if p.bprop_filters and p.number_sources:
         raise ValueError(
              'Number of sources will be set to the length of bprop_filters; the '
              'param number_sources should not be used when bprop_filters is set.')
     number_sources = p.number_sources
     if p.bprop_filters:
         self._bprop_variable_filters = p.bprop_filters
         number_sources = len(p.bprop_filters)
     if number_sources and number_sources > 1:
         self._bprop_onehot = tf.one_hot(p.source_selected,
                                         number_sources,
                                         dtype=tf.float32)
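A minimal standalone sketch (values made up for illustration) of the selector built at the end of __init__ above: tf.one_hot turns the selected source index into a float vector with a single 1.0, which can later be used to keep only that source's gradients.

import tensorflow as tf

number_sources = 3
source_selected = 1  # hypothetical value of p.source_selected
bprop_onehot = tf.one_hot(source_selected, number_sources, dtype=tf.float32)
# bprop_onehot -> [0., 1., 0.]: only gradients from source 1 are kept.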
Example 2
    def ComputeLoss(self, theta, predictions, input_batch):
        p = self.params
        batch = tf.shape(input_batch.data)[0]
        act = predictions.act
        with tf.colocate_with(act):
            tf.logging.info("{}'s device: {}".format(act, act.device))
            # Softmax
            labels = tf.to_int64(input_batch.label)
            onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
            if p.label_smoothing > 0:
                smooth_positives = 1.0 - p.label_smoothing
                smooth_negatives = p.label_smoothing / p.softmax.num_classes
                onehot_labels = onehot_labels * smooth_positives + smooth_negatives

            xent = self.softmax.FProp(theta=theta.softmax,
                                      inputs=act,
                                      class_weights=input_batch.weight,
                                      class_probabilities=onehot_labels)

        self._AddSummary(input_batch, xent.per_example_argmax)

        rets = {
            'loss': (xent.avg_xent, batch),
            'log_pplx': (xent.avg_xent, batch),
            'num_preds': (batch, 1),
        }
        if p.is_eval or p.compute_accuracy_for_training:
            acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
            acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
            rets.update(accuracy=(acc1, batch), acc5=(acc5, batch))
        return rets, {}
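The label-smoothing step above can be sketched standalone as follows (num_classes and label_smoothing are made-up stand-ins for p.softmax.num_classes and p.label_smoothing); it shows how the one-hot targets are softened before being fed to the softmax layer as class_probabilities.

import tensorflow as tf

num_classes = 4
label_smoothing = 0.1
labels = tf.constant([2], dtype=tf.int64)
onehot_labels = tf.one_hot(labels, num_classes)  # [[0., 0., 1., 0.]]
smooth_positives = 1.0 - label_smoothing
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
# -> [[0.025, 0.025, 0.925, 0.025]]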
Example 3
 def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                               unused_step_ids, states,
                               unused_num_hyps_per_beam):
     # Same probs for each id.
     logits = tf.zeros([tgt_batch_size, vocab_size])
     # Except eoc has slightly lower score.
     logits = logits - 1.0 * tf.expand_dims(
         tf.one_hot(p.target_eoc_id, vocab_size), 0)
      # eos has a very low score (cannot terminate by eos).
     logits = logits + eos_score * tf.expand_dims(
         tf.one_hot(p.target_eos_id, vocab_size), 0)
     return py_utils.NestedMap(
         atten_probs=tf.zeros([tgt_batch_size, 0]),
         log_probs=logits,
         is_last_chunk=tf.fill([tgt_batch_size],
                               value=is_last_chunk)), states
Example 4
    def _ComputeClassificationLoss(self, predictions, input_batch,
                                   class_weights):
        """Compute classification loss for the given predictions.

    Args:
      predictions: The output of `ComputePredictions`, contains: logits - [b,
        nx, ny, nz, na, 7 + num_classes]. na is the number of anchor
        boxes per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).
      input_batch: The input batch from which we access the groundtruth.
      class_weights: Per-class weights to use in loss computation.

    Returns:
      Classification loss.

    """
        p = self.params
        predicted_class_logits = py_utils.HasShape(
            predictions.classification_logits,
            [-1, -1, -1, -1, p.num_anchors, p.num_classes])
        bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)
        assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                               [bs, nx, ny, nz, na])
        class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_class_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)
        class_loss *= class_weights[..., tf.newaxis]
        class_loss_sum = tf.reduce_sum(class_loss)
        return class_loss_sum
Example 5
 def _ConcatOnehotFn(input_data):
   """Concat the input features with a onehot version of the label ids."""
   features = input_data.features
   label = input_data.label
   num_pts = tf.shape(features)[1]
   label_one_hot = tf.one_hot(tf.cast(label, tf.int32), depth=16)
   label_one_hot = tf.tile(tf.expand_dims(label_one_hot, 1), [1, num_pts, 1])
   input_data.features = tf.concat([features, label_one_hot], axis=-1)
   return input_data
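Below is a self-contained sketch of the same concatenation with made-up shapes (2 point clouds, 5 points each, 3 features, 16 classes), showing how the per-cloud label is broadcast to every point before being appended to the feature channels.

import tensorflow as tf

features = tf.zeros([2, 5, 3])
label = tf.constant([4, 9])
num_pts = tf.shape(features)[1]
label_one_hot = tf.one_hot(tf.cast(label, tf.int32), depth=16)  # [2, 16]
label_one_hot = tf.tile(tf.expand_dims(label_one_hot, 1), [1, num_pts, 1])  # [2, 5, 16]
augmented = tf.concat([features, label_one_hot], axis=-1)  # [2, 5, 19]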
Example 6
    def _PostprocessSample(self, sample, is_tpu):
        """Add topk_hyps, topk_ids, topk_lens, topk_scores tensors to `sample`.

    These features are required by `.BeamSearchDecodeOutput`.

    Args:
      sample: a NestedMap with `id`, `paddings`, and `logits` fields.
      is_tpu: whether inference is being run on TPU.

    Returns:
      sample with additional features that match `.BeamSearchDecodeOutput`
      requirements. `topk_hyps` is empty.
    """
        p = self.params
        bs = tf.shape(sample.ids)[0]
        num_hyps_per_beam = p.target_sequence_sampler.num_hyps_per_beam
        vocab_size = tf.shape(sample.logits)[2]

        # tf.string is not supported on tpu.
        sample.topk_hyps = tf.zeros([bs],
                                    dtype=tf.int32 if is_tpu else tf.string)
        sample.topk_hyps = tf.reshape(sample.topk_hyps,
                                      [-1, num_hyps_per_beam])

        sample.topk_ids = sample.ids
        weights = 1 - sample.paddings
        sample.topk_lens = tf.cast(tf.reduce_sum(weights, axis=1),
                                   dtype=tf.int32)
        # Computing the hypothesis scores based on the returned ids
        mask = tf.one_hot(sample.topk_ids,
                          depth=vocab_size,
                          axis=-1,
                          dtype=sample.logits.dtype)
        token_log_probs = tf.einsum('ijk,ijk->ij',
                                    tf.nn.log_softmax(sample.logits), mask)
        sample.topk_scores = tf.reduce_sum(token_log_probs * weights, axis=1)
        # At this point batch dimension is (batch_size*num_hyps_per_beam),
        # interleaved as [num_hyps_per_beam, batch_size].
        # This does not match the order expected by beam search post-processing.
        # Must transpose to [batch_size, num_hyps_per_beam] and flatten back.
        max_len = tf.shape(sample.topk_ids)[1]
        sample.topk_ids = tf.reshape(sample.topk_ids,
                                     [num_hyps_per_beam, -1, max_len])
        sample.topk_ids = tf.transpose(sample.topk_ids, perm=[1, 0, 2])
        sample.topk_ids = tf.reshape(sample.topk_ids, [bs, max_len])

        # The same for topk_lens and topk_scores
        sample.topk_lens = tf.reshape(sample.topk_lens,
                                      [num_hyps_per_beam, -1])
        sample.topk_lens = tf.transpose(sample.topk_lens, [1, 0])
        sample.topk_lens = tf.reshape(sample.topk_lens, [-1])

        sample.topk_scores = tf.reshape(sample.topk_scores,
                                        [num_hyps_per_beam, -1])
        sample.topk_scores = tf.transpose(sample.topk_scores, [1, 0])
        sample.topk_scores = tf.reshape(sample.topk_scores, [-1])
        return sample
Example 7
 def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                               unused_step_ids, states,
                               unused_num_hyps_per_beam):
   # Same probs for each id.
   logits = tf.zeros([tgt_batch_size, vocab_size])
   # Except eos is slightly lower prob.
   logits = logits - 1.0 * tf.expand_dims(
       tf.one_hot(p.target_eos_id, vocab_size), 0)
   return py_utils.NestedMap(
       atten_probs=tf.zeros([tgt_batch_size, 0]), log_probs=logits), states
Example 8
  def FProp(self, theta, ids, segment_pos):
    p = self.params
    fprop_dtype = py_utils.FPropDtype(p)

    ids = self._MaybeSplit(ids)
    segment_pos = self._MaybeSplit(segment_pos)

    one_hot_ids = tf.one_hot(ids, p.vocab_size, dtype=fprop_dtype)
    one_hot_ids = self._MaybeSplit(one_hot_ids)

    one_hot_pos = tf.one_hot(segment_pos, p.max_len, dtype=fprop_dtype)
    one_hot_pos = self._MaybeSplit(one_hot_pos)

    token_emb = tf.einsum('VH,BLV->BLH', theta.embedding, one_hot_ids)
    token_emb = self._MaybeSplit(token_emb)

    pos_emb = tf.einsum('VH,BLV->BLH', theta.pos_emb, one_hot_pos)
    pos_emb = self._MaybeSplit(pos_emb)
    return self._MaybeSplit(token_emb + pos_emb)
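The einsum over one-hot ids above is an embedding lookup expressed as a matrix product, a formulation commonly preferred on TPUs over a gather. A small standalone sketch with made-up sizes shows the equivalence:

import tensorflow as tf

vocab_size, hidden = 8, 4
embedding = tf.random.normal([vocab_size, hidden])
ids = tf.constant([[1, 3, 5]])  # [B, L]
one_hot_ids = tf.one_hot(ids, vocab_size, dtype=tf.float32)  # [B, L, V]
emb_via_einsum = tf.einsum('VH,BLV->BLH', embedding, one_hot_ids)
emb_via_gather = tf.gather(embedding, ids)  # same values, shape [B, L, H]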
Example 9
 def test_entmax_loss_generate_right_loss(self):
     inputs = tf.constant([[[0.5, 1.0, 2.0]] * 3], dtype='bfloat16')
     labels = tf.constant([[0, 1, 2]])
     # Convert to the matrix with given depth, e.g. the vocabulary size.
     labels = tf.one_hot(labels, depth=3)
     expected_loss = tf.constant([[1.5642307, 1.0642307, 0.06423065]],
                                 dtype='bfloat16')
     entmax_loss_val = entmax.entmax_loss(labels, inputs, alpha=1.5)
     with self.session(use_gpu=False) as sess:
         output = sess.run(entmax_loss_val)
         self.assertAllClose(expected_loss, output)
Example 10
 def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                               unused_step_ids, states, num_hyps_per_beam):
   self.assertEqual(1, num_hyps_per_beam)
   logits = tf.random.stateless_normal([batch_size, vocab_size],
                                       seed=[8273747, 9])
   # Make it never predict <eos>.
   logits -= tf.one_hot([p.target_eos_id], vocab_size, 1e30)
   is_last_chunk = tf.equal(states.src_step, src_len - 1)
   result = py_utils.NestedMap(
       log_probs=logits, is_last_chunk=is_last_chunk)
   return result, states
Example 11
 def unmask(h, m):
     with tf.name_scope('unmask'):
         tpu_summary.tensor('unmask_h', h)
         tpu_summary.tensor('unmask_m', m)
         t = tf.cumsum(m, -1) * m - 1
         mh = einsum_i32('bkt,bt->bkt', m, h)
         t2 = tf.one_hot(tf.cast(t, tf.int32),
                         output_len,
                         dtype=fprop_dtype)
         x = einsum_i32('bkt,bktT->bkT', mh, t2)
         return tf.cast(x, h.dtype)
Example 12
  def testOneHotLabels(self):
    """Tests that the loss equals softmax CE when the labels are one hot."""
    num_classes = 400
    batch_size = 7
    label_indices = np.random.randint(0, num_classes, size=(batch_size, 3))
    labels = tf.one_hot(label_indices, depth=num_classes, dtype=tf.float32)
    logits = np.random.uniform(size=(batch_size, 3, num_classes)) * 10 + 1e7
    logits_tensor = tf.convert_to_tensor(logits, dtype=tf.float32)

    losses = label_lib.MultiLabelContrastiveLoss(labels, logits_tensor)
    expected = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits_tensor)
    self.assertAllClose(expected, losses)
Example 13
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    tensor = tf.tile(tensor,
                                     [num_hyps_per_beam, 1
                                      ])  # [num_hyps_per_beam, src_batch]
                    tgt_batch = tf.shape(step_ids)[
                        0]  # num_hyps_per_beam*src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
Example 14
    def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask,
                     dec_state, t):
        del tgt_pos, tgt_segment_id

        [buf] = dec_state
        if tgt_id.shape == (self.batch_size, self.beam_size):
            buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
        else:
            div = int(tgt_id.shape[1] // self.beam_size)
            for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
                buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

        buf1 = tf.transpose(buf, [1, 0, 2])
        buf1 = tf.reshape(buf1,
                          [self.batch_size, self.max_steps * self.beam_size])

        # select next_tgt_id as a function of previous target tokens
        if self.rule == '+1':
            next_tgt_id = (tgt_id + 1)
            next_tgt_id %= self.vocab_size
        elif self.rule == 'sum':
            # sum over all previous tokens in tgt_mask
            next_tgt_id = tf.einsum('BT,BKT->BK', buf1,
                                    tf.cast(tgt_mask, tf.int32))
            next_tgt_id %= self.vocab_size
        elif self.rule == 'fib':
            # select last token according to tgt_mask
            m = tgt_mask
            m *= tf.cast(
                tf.equal(tf.cumsum(m, -1),
                         tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
            last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
            next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

        # with a lower probability, add an extra +1 to the correct next_tgt_id
        n = self.vocab_size
        logits = 5 * tf.one_hot(next_tgt_id % n, n)
        logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
        logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
        logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
        logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

        # increase eos_score if current tgt_id contains 9
        eos_id = 0
        tgt_id_contains_9 = tf.logical_or(tf.equal(tgt_id % 10, 9),
                                          tf.equal((tgt_id // 10) % 10, 9))
        logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(
            eos_id, self.vocab_size), tf.cast(tgt_id_contains_9, tf.float32))

        # tie-breaking -- lower token id wins a little bit
        tie = np.arange(0., 1., 1. / n)
        tie /= tie.sum()
        logits -= tie

        logits = tf.nn.log_softmax(logits)

        dec_state = [buf]
        return logits, dec_state
Example 15
def update_nbest(nbest_hyps, cur_hyps):
    """Updates nbest hyps from cur_hyps. Returns new values for nbest_hyps."""
    with tf.name_scope('update_nbest'):
        (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
        (cur_mask, cur_score, cur_score_norm) = cur_hyps
        k = int(nbest_mask.shape[1])
        m = int(cur_mask.shape[1])
        mask = tf.concat([nbest_mask, cur_mask], 1)
        score = tf.concat([nbest_score, cur_score], 1)
        score_norm = tf.concat([nbest_score_norm, cur_score_norm], 1)
        nbest_score_norm, i = tf.math.top_k(score_norm, k)
        i_one_hot = tf.one_hot(i, k + m, dtype=mask.dtype)
        nbest_mask = tf.einsum('bkt,bjk->bjt', mask, i_one_hot)
        nbest_score = tf.einsum('bk,bjk->bj', score, i_one_hot)
        return (nbest_mask, nbest_score, nbest_score_norm)
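The einsums above gather the top-k hypotheses through a one-hot matrix instead of tf.gather, a pattern often used to avoid gathers on TPU. A standalone sketch with made-up shapes (k=2 kept hyps out of k + m = 4 candidates, t=5):

import tensorflow as tf

score_norm = tf.constant([[0.1, 0.9, 0.4, 0.7]])  # [b, k + m]
score = score_norm * 2.0
mask = tf.random.normal([1, 4, 5])  # [b, k + m, t]
nbest_score_norm, i = tf.math.top_k(score_norm, k=2)
i_one_hot = tf.one_hot(i, 4, dtype=mask.dtype)  # [b, 2, k + m]
nbest_mask = tf.einsum('bkt,bjk->bjt', mask, i_one_hot)  # rows 1 and 3 of mask
nbest_score = tf.einsum('bk,bjk->bj', score, i_one_hot)  # [[1.8, 1.4]]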
Example 16
 def _Slice(tensor):
   """Return a slice of this tensor at time=state0.t."""
   shape = py_utils.GetShape(tensor)
   # All zeros except for t in the time dimension.
   # e.g. if params.axis=1, begin is [0, t, 0, 0, 0, ...]
   begin = tf.one_hot(self.params.axis, tf.rank(tensor), on_value=state0.t)
   # Same as shape, but with a 1 in the time dimension.
   # e.g. if params.axis=1, shape is [shape[0], 1, shape[2], shape[3], ...]
   size = tf.concat([
       shape[0:self.params.axis],
       tf.constant([1], dtype=tf.int32), shape[self.params.axis + 1:]
   ],
                    axis=0)
   # Make a slice where the time dimension is fixed at state0.t.
   time_slice = tf.slice(tensor, begin, size)
   # Remove the time dimension.
   return tf.squeeze(time_slice, axis=self.params.axis)
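A self-contained sketch of the slicing trick above with made-up shapes: tf.one_hot with on_value builds the begin offsets for tf.slice so that only the time dimension is offset by t.

import tensorflow as tf

tensor = tf.reshape(tf.range(2 * 5 * 3), [2, 5, 3])  # [B, T, D]; axis=1 is time
axis, t = 1, 2
begin = tf.one_hot(axis, tf.rank(tensor), on_value=t)  # [0, 2, 0]
size = tf.concat([tf.shape(tensor)[:axis], [1], tf.shape(tensor)[axis + 1:]], axis=0)
time_slice = tf.slice(tensor, begin, size)  # [2, 1, 3]
sliced = tf.squeeze(time_slice, axis=axis)  # [2, 3], values at time step 2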
Example 17
                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is overridden
                    # later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent
Example 18
    def test_entmax_loss_generate_right_gradient(self):
        inputs = tf.constant([[0.5, 1.0, 2.0]] * 3)
        labels = tf.constant([0, 1, 2])
        expected_loss_gradient = tf.constant(
            [[[-0.97671956, 0.16207013, 0.8146494],
              [0.02328045, -0.83792984, 0.8146494],
              [0.02328045, 0.16207013, -0.1853506]]])
        # Convert to the matrix with given depth, e.g. the vocabulary size.
        labels = tf.one_hot(labels, depth=3)
        expected_loss = tf.constant(2.692692)
        entmax_loss_val = tf.reduce_sum(entmax.entmax_loss(
            labels, inputs, 1.5))
        entmax_loss_gradient_val = tf.gradients(entmax_loss_val, inputs)

        with self.session(use_gpu=False) as sess:
            loss_output = sess.run(entmax_loss_val)
            gradient_output = sess.run(entmax_loss_gradient_val)
            self.assertAllClose(expected_loss, loss_output)
            self.assertAllClose(expected_loss_gradient, gradient_output)
Example 19
  def ComputeLoss(self, theta, predictions, input_batch):
    p = self.params
    batch = tf.shape(input_batch.data)[0]
    act = predictions.act
    with tf.ops.colocate_with(act):
      tf.logging.info("{}'s device: {}".format(act, act.device))
      # Softmax
      if py_utils.GetRank(input_batch.label) == 1:
        # Create one_hot labels if rank is 1.
        labels = tf.cast(input_batch.label, tf.int64)
        onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
      else:
        onehot_labels = input_batch.label
        labels = tf.math.argmax(onehot_labels, axis=-1)
      if p.label_smoothing > 0:
        smooth_positives = 1.0 - p.label_smoothing
        smooth_negatives = p.label_smoothing / p.softmax.num_classes
        onehot_labels = onehot_labels * smooth_positives + smooth_negatives

      xent = self.softmax.FProp(
          theta=theta.softmax,
          inputs=act,
          class_weights=input_batch.weight,
          class_probabilities=onehot_labels)

    self._AddSummary(input_batch, xent.per_example_argmax)

    rets = {
        'loss': (xent.avg_xent, batch),
        'log_pplx': (xent.avg_xent, batch),
        'num_preds': (batch, 1),
    }
    if self.do_eval or p.compute_accuracy_for_training:
      acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
      acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
      rets.update(
          accuracy=(acc1, batch),
          acc5=(acc5, batch),
          error=(1. - acc1, batch),
          error5=(1. - acc5, batch))
    return rets, {'loss': xent.per_example_xent}
Example 20
  def _ComputeGradientMask(self, bprop_variable_filters):
    """Compute gradient mask for each variable and bprop_variable_filters.

    Note that per_input_gradient_mask[var][i] will be 1 if var matches
    bprop_variable_filter[i], 0 otherwise.

    Args:
      bprop_variable_filters: A list of regex bprop_variable_filters for each
        file pattern.
    """
    self._per_input_gradient_mask = py_utils.NestedMap()
    all_vars = set(self.vars.Flatten())
    for var in all_vars:
      self._per_input_gradient_mask[var.name] = (
          tf.zeros(len(bprop_variable_filters), dtype=tf.float32))
      for i in range(len(bprop_variable_filters)):
        if re.search(bprop_variable_filters[i], var.name):
          tf.logging.info('Keep gradient after filtering, regex: %s var: %s' %
                          (bprop_variable_filters[i], var.name))
          self._per_input_gradient_mask[var.name] += (
              tf.one_hot(i, len(bprop_variable_filters), dtype=tf.float32))
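A rough standalone sketch of the mask construction above, with hypothetical filters and a hypothetical variable name: each regex that matches contributes a one-hot row, so entry i of the mask is 1 exactly when the variable should receive gradients from input source i.

import re
import tensorflow as tf

bprop_variable_filters = ['encoder/.*', 'decoder/.*']  # made-up regexes
var_name = 'encoder/layer_0/w/var:0'  # made-up variable name
mask = tf.zeros(len(bprop_variable_filters), dtype=tf.float32)
for i, regex in enumerate(bprop_variable_filters):
  if re.search(regex, var_name):
    mask += tf.one_hot(i, len(bprop_variable_filters), dtype=tf.float32)
# mask -> [1., 0.]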
Example 21
  def FProp(self, theta, x):
    p = self.params
    if not hasattr(theta, 't'):
      return x
    t = theta.t
    if t is None:
      return x
    assert hasattr(theta, 'state')
    state = theta.state

    tf.logging.info('p.name=%r', p.name)
    tf.logging.info('state=%r', state)
    tf.logging.info('x=%r', x)
    tf.logging.info('t=%r', t)

    with tf.name_scope(p.name):
      if not self._for_flat_beam_search:
        # For tpu_beam_search_helper
        z = tf.one_hot(t, tf.shape(state)[1])
        z = tf.expand_dims(z, 0)
        while len(z.shape) < len(x.shape):
          z = tf.expand_dims(z, -1)
        y = state = (1 - z) * state + z * x
      if self._for_flat_beam_search:
        # For flat beam search
        state = tf.InplaceUpdate(state, t, x)
        # [T,B,L,...]
        y = state
        # [T, B, L, ...] -> [B, T, L, ...]
        perm = list(range(len(y.shape)))
        perm[:2] = [1, 0]
        y = tf.transpose(y, perm)
        # [B, T, L, ...] -> [B, T*L, ...]
        y_shape = list(y.shape)
        y_shape[1:3] = [int(y_shape[1]) * int(y_shape[2])]
        y = tf.reshape(y, y_shape)
    theta.state = state

    tf.logging.info('y=%r', y)
    return y
Example 22
def MatmulGather(source, indices):
    """Drop in replacement for tf.gather_nd() optimized for speed on TPU.

  TODO(weihan): tf.gather_nd() is supposed to be implemented in the same way
  on TPU. Investigate why it's much slower.

  Args:
    source: tensor of shape [N, P1, C]
    indices: tensor of shape [N, P2, K]

  Returns:
    tensor of shape [N, P2, K, C]
  """
    source = py_utils.HasRank(source, 3)
    n, p1, c = py_utils.GetShape(source)
    indices = py_utils.HasShape(indices, [n, -1, -1])
    _, p2, k = py_utils.GetShape(indices)

    onehot = tf.one_hot(indices, depth=p1)  # N x P2 x K x P1
    reshaped = tf.reshape(onehot, [n, -1, p1])  # N x (P2 x K) x P1
    target = tf.matmul(reshaped, source)  # N x (P2 x K) x C
    return tf.reshape(target, [n, p2, k, c])
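A quick check of the equivalence MatmulGather relies on, with made-up small shapes; on TPU the one-hot matmul variant is the one reported as faster, but the values match a batched tf.gather.

import tensorflow as tf

source = tf.random.normal([2, 6, 3])  # [N, P1, C]
indices = tf.constant([[[0, 5], [2, 2]], [[1, 1], [4, 3]]])  # [N, P2, K]
onehot = tf.one_hot(indices, depth=6)  # [N, P2, K, P1]
via_matmul = tf.reshape(
    tf.matmul(tf.reshape(onehot, [2, -1, 6]), source), [2, 2, 2, 3])
via_gather = tf.gather(source, indices, batch_dims=1)  # same values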
Example 23
    def ComputeLoss(self, theta, predictions, input_batch):
        """Compute loss for the sparse detector model v1.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      predictions: A `.NestedMap` object containing residuals and
        classification_logits.
      input_batch: A `.NestedMap` expected to contain cell_center_xyz,
        cell_points_xyz, cell_feature, anchor_bboxes,
        anchor_localization_residuals, assigned_gt_labels, and
        assigned_cls_mask. See class doc string for details.

    Returns:
      Two dicts:

      - A dict containing str keys and (metric, weight) pairs as values, where
        one of the keys is expected to be 'loss'.
      - A dict containing arbitrary tensors describing something about each
        training example, where the first dimension of each tensor is the batch
        index.
    """
        p = self.params

        batch_size, num_centers = py_utils.GetShape(
            input_batch.cell_center_xyz, 2)

        # Assert shapes of inputs.
        anchor_bboxes = py_utils.HasShape(
            input_batch.anchor_bboxes,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        predicted_residuals = py_utils.HasShape(
            predictions.residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        assigned_gt_labels = py_utils.HasShape(
            input_batch.assigned_gt_labels,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        predicted_classification_logits = py_utils.HasShape(
            predictions.classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        # assigned_cls_mask is for weighting the classification loss.
        # Ignored targets will have their mask = 0; this happens when their IOU is
        # not high enough to be a foreground object and not low enough to be
        # background.
        class_weights = py_utils.HasShape(
            input_batch.assigned_cls_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        class_weights = tf.reshape(
            class_weights,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        # Broadcast per class loss weights. For each anchor, there are num_classes
        # prediction heads, we weight the outputs of these heads by the per class
        # loss weights.
        per_class_loss_weight = tf.constant([[[p.per_class_loss_weight]]],
                                            dtype=tf.float32)
        per_class_loss_weight = py_utils.HasShape(per_class_loss_weight,
                                                  [1, 1, 1, p.num_classes])
        class_weights *= per_class_loss_weight
        class_weights = py_utils.HasShape(class_weights, [
            batch_size, num_centers, p.num_anchor_bboxes_per_center,
            p.num_classes
        ])

        # We use assigned_reg_mask for masking the regression loss.
        # Only foreground objects will have assigned_reg_mask = 1.
        reg_weights = py_utils.HasShape(
            input_batch.assigned_reg_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        reg_weights = tf.reshape(
            reg_weights,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
            # Compute number of positive anchors per example.
            foreground_mask = py_utils.HasShape(
                input_batch.assigned_reg_mask,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center])

            # Sum to get the number of foreground anchors for each example.
            loss_normalization = tf.reduce_sum(foreground_mask, axis=2)
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))

            # Reshape for broadcasting.
            loss_normalization = tf.reshape(loss_normalization,
                                            [batch_size, num_centers, 1, 1])

            # Normalize so that the loss is independent of # centers.
            loss_normalization *= num_centers
            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        classification_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_classification_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)

        # Apply mask.
        classification_loss *= class_weights

        # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
        # example instead. This would match the 1/N_positive normalization in
        # point pillars.

        # Reduce sum over centers, boxes and classes.
        classification_loss = tf.reduce_sum(classification_loss,
                                            axis=[1, 2, 3])

        # Reduce mean over batch.
        classification_loss = tf.reduce_mean(classification_loss)

        # Localization regression loss with Huber loss (SmoothL1).
        regression_loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
            labels=anchor_localization_residuals[..., :6],
            predictions=predicted_residuals[..., :6],
            delta=p.huber_loss_delta)

        # Rotation loss is computed on a transform on rotation_delta. For a
        # direction aware loss, we simply wrap the angles to -pi to pi; for a loss
        # that is symmetric to direction (i.e., rotating by pi), we use a sin
        # transform.
        rotation_delta_transform = tf.sin
        if p.direction_aware_rot_loss:
            rotation_delta_transform = functools.partial(geometry.WrapAngleRad,
                                                         min_val=-np.pi,
                                                         max_val=np.pi)
        rotation_delta = (predicted_residuals[..., 6:] -
                          anchor_localization_residuals[..., 6:])
        regression_rotation_loss = self._utils_3d.ScaledHuberLoss(
            labels=tf.zeros_like(rotation_delta),
            predictions=rotation_delta_transform(rotation_delta),
            delta=p.huber_loss_delta)

        reg_loc_loss = regression_loc_and_dims_loss[..., :3]
        reg_dim_loss = regression_loc_and_dims_loss[..., 3:6]

        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            predicted_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)

        # Apply mask to individual losses.
        #
        # And then reduce sum over centers, boxes, residuals, and batch
        # and divide by the batch_size.
        regression_rotation_loss *= reg_weights
        reg_rot_loss = tf.reduce_sum(regression_rotation_loss) / batch_size

        reg_loc_loss *= reg_weights
        reg_loc_loss = tf.reduce_sum(reg_loc_loss) / batch_size

        reg_dim_loss *= reg_weights
        reg_dim_loss = tf.reduce_sum(reg_dim_loss) / batch_size

        # Do not create corner loss graph if weight is 0.0
        # TODO(bcyang): Remove condition after fixing corner loss NaN issue
        if p.corner_loss_weight != 0.0:
            reg_corner_loss = self._utils_3d.CornerLoss(
                gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
            reg_corner_loss = tf.expand_dims(reg_corner_loss, axis=-1)

            reg_corner_loss *= reg_weights
            reg_corner_loss = tf.reduce_sum(reg_corner_loss) / batch_size
        else:
            reg_corner_loss = 0.0

        # Sum components of regression loss.
        regression_loss = (p.location_loss_weight * reg_loc_loss +
                           p.dimension_loss_weight * reg_dim_loss +
                           p.rotation_loss_weight * reg_rot_loss +
                           p.corner_loss_weight * reg_corner_loss)

        # Compute total loss.
        total_loss = (p.loss_weight_localization * regression_loss +
                      p.loss_weight_classification * classification_loss)

        metrics_dict = py_utils.NestedMap({
            'loss': (total_loss, batch_size),
            'loss/regression': (regression_loss, batch_size),
            'loss/regression/loc': (reg_loc_loss, batch_size),
            'loss/regression/dim': (reg_dim_loss, batch_size),
            'loss/regression/rot': (reg_rot_loss, batch_size),
            'loss/regression/corner': (reg_corner_loss, batch_size),
            'loss/classification': (classification_loss, batch_size),
        })

        # Calculate dimension errors
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
            'predicted_bboxes': predicted_bboxes,
            'gt_bboxes': gt_bboxes,
            'reg_weights': reg_weights,
        })

        return metrics_dict, per_example_dict
Example 24
 def Sum(theta, state, inputs):
     next_state = py_utils.NestedMap()
     v = tf.reduce_sum(tf.one_hot(inputs.one_hot, depth=2) * theta.x,
                       axis=0)
     next_state.sum = state.sum + v
     return next_state, py_utils.NestedMap()
Example 25
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: A `.NestedMap` expected to contain lasers.points_xyz,
        lasers.points_feature, lasers.points_padding, cell_center_xyz,
        cell_points_xyz, cell_feature, anchor_bboxes,
        anchor_localization_residuals, assigned_gt_labels, and
        assigned_cls_mask. See class doc string for details.

    Returns:
      A `.NestedMap` object containing residuals and classification_logits.
    """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  1, 'input_batch shapes: ')
        cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
        batch_size, num_centers = py_utils.GetShape(cell_feature, 2)

        featurized_cell = self._CellFeaturizer(theta, input_batch)

        # Project each featurized_cell features to each bbox per center.
        featurized_anchors = self.cell_feature_projector.FProp(
            theta.cell_feature_projector, featurized_cell)

        # Reshape output so that we have features per offset.
        featurized_anchors = tf.reshape(
            featurized_anchors,
            [batch_size, num_centers, p.num_anchor_bboxes_offsets, -1])

        # Predict localization residuals.
        predicted_residuals = self.localization_regressor.FProp(
            theta.localization_regressor, featurized_anchors)
        predicted_residuals = tf.reshape(
            predicted_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        if any([p.oracle_location, p.oracle_dimension, p.oracle_rotation]):
            gt_residuals = py_utils.HasShape(
                input_batch.anchor_localization_residuals,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
            residuals = []
            if p.oracle_location:
                residuals.append(gt_residuals[..., 0:3])
            else:
                residuals.append(predicted_residuals[..., 0:3])

            if p.oracle_dimension:
                residuals.append(gt_residuals[..., 3:6])
            else:
                residuals.append(predicted_residuals[..., 3:6])

            if p.oracle_rotation:
                residuals.append(gt_residuals[..., 6:])
            else:
                residuals.append(predicted_residuals[..., 6:])
            predicted_residuals = tf.concat(residuals, axis=-1)

        if p.squash_rotation_predictions:
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        # Predict object classification at each bbox.
        predicted_classification_logits = self.classifier.FProp(
            theta.classifier, featurized_anchors)
        predicted_classification_logits = tf.reshape(
            predicted_classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        if p.oracle_classification:
            assigned_gt_labels = py_utils.HasShape(
                input_batch.assigned_gt_labels,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center])
            predicted_classification_logits = tf.one_hot(
                assigned_gt_labels, p.num_classes)

        return py_utils.NestedMap({
            'residuals':
            predicted_residuals,
            'classification_logits':
            predicted_classification_logits,
        })
Example 26
def MaxPool3D(points, point_features, pooling_idx, closest_idx):
    """Apply max pooling to a point cloud with computed sampling indices.

  pooling_idx and closest_idx are the outputs of a sampler such as
  FurthestPointSampler.

  The pooling operation results in a point cloud with fewer points, where the
  pooled points are specified by pooling_idx. Each element of pooling_idx
  contains an integer in the range [0, P1) containing the index of the point in
  points/points_features.

  Max pooling is performed by assigning each point to its closest pooled point,
  and then taking a max over the features of points assigned. We assume that
  this mapping is provided by closest_idx, where each element should contain
  an integer in the range [0, P2) containing the index of the pooled point that
  each point is assigned to.

  Note: This logic for pooling assumes that there will be at least
  one value > 0 per sampled region for each feature, otherwise it will return 0.
  Additionally, it does a reduce over a masked version of the features, so
  mean and min would not work without a change in the logic.

  Args:
    points: a floating point tf.Tensor with shape [N, P1, 3]
    point_features: a floating point tf.Tensor with shape [N, P1, C]
    pooling_idx: A tf.int32 tf.Tensor of shape [N, P2] with the index of which
      points we want to keep. Each value should be in the range [0, P1].
    closest_idx: A tf.int32 tf.Tensor of shape [N, P1] representing which
      sampled point is closest to each original point. Each value should be in
      the range of [0, P2].

  Returns:
    A tuple of tf.Tensors (pooled_points, pooled_features).

    pooled_points has shape [N, P2, 3] representing the locations of each
    selected point. P2 corresponds to num_pooled_points.

    pooled_features has shape [N, P2, C] representing the pooled features at
    each point.
  """
    batch_size, num_points = py_utils.GetShape(points, 2)
    point_features = py_utils.HasShape(point_features,
                                       [batch_size, num_points, -1])
    pooling_idx = py_utils.HasShape(pooling_idx, [batch_size, -1])
    _, num_output_points = py_utils.GetShape(pooling_idx)
    _, _, feature_dims = py_utils.GetShape(point_features, 3)

    # Gather new point locations.
    pooled_points = tf.batch_gather(points, pooling_idx)

    mask = tf.one_hot(closest_idx, num_output_points)  # [N, P1, P2]
    mask = tf.transpose(mask, [2, 0, 1])  # [P2, N, P1]

    def _PartialPoolFeaturesFn(partial_mask):
        partial_mask = tf.tile(
            tf.reshape(partial_mask, [batch_size, num_points, 1]),
            [1, 1, feature_dims])
        # Note: this method of pooling assumes there will be a value > 0,
        # and will only work with max under this condition.
        return tf.reduce_max(partial_mask * point_features, axis=1)

    # Performing a map_fn over the pooled points is more memory efficient.
    pooled_point_features = tf.map_fn(_PartialPoolFeaturesFn,
                                      mask)  # [P2, N, C]
    pooled_point_features = tf.transpose(pooled_point_features, [1, 0, 2])

    return pooled_points, pooled_point_features
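The mask-based pooling above can be illustrated with a tiny standalone example (made-up values, one batch element): each original point contributes its features only to the pooled point named by closest_idx, and reduce_max picks the per-feature maximum. As the docstring notes, this relies on the contributing features being > 0.

import tensorflow as tf

point_features = tf.constant([[[1., 2.], [5., 0.], [3., 4.]]])  # [N=1, P1=3, C=2]
closest_idx = tf.constant([[0, 1, 0]])  # [N, P1], values in [0, P2)
mask = tf.one_hot(closest_idx, 2)  # [N, P1, P2=2]
pooled = tf.reduce_max(
    tf.expand_dims(mask, -1) * tf.expand_dims(point_features, 2), axis=1)
# pooled -> [[[3., 4.], [5., 0.]]]  # [N, P2, C]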
Example 27
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss, which, when backpropagated through, encourages the
    # 'x' values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    means_norm = tf.norm(theta.means)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent with backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
    # , except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use an exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per-cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
Example 28
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ P P P P P P - - - - - - P* - - - ]   ^
        # [ P P P P P P P P P P - - P* - - - ]   | batch
        # [ P - - - - - - - - - - - P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_pad = tf.cast(
            tf.less(tf.expand_dims(pfx_time, 0),
                    tf.expand_dims(pfx_len - 1, 1)), tf.int32)
        pfx_id = pfx * pfx_pad
        pfx_last = einsum_i32(
            'BT,BT->B', pfx, tf.one_hot(pfx_len - 1,
                                        pfx_max,
                                        dtype=fprop_dtype))

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_time * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make the initial 't' constant, determined by the
        # shape of the prefix tensor ('pfx_max'). It is possible to make it
        # dynamic, t ~ max(pfx_len) / beam_size, which would allow more steps
        # for beam search; however, computing 'max' results in a very slow
        # all-to-all on 16x16, and a variable number of decoder steps may
        # result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)
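        # Illustrative example (added comment, toy values): with beam_size=2,
        # pfx_max=4, buf_size=8 and pfx_pad=[1, 1, 0, 0], the initial t is 2 and
        # each of the two beam rows attends to prefix slots 0..1 plus its own
        # slot (4 for beam 0, 5 for beam 1):
        #   tgt_mask[b, 0] = [1, 1, 0, 0, 1, 0, 0, 0]
        #   tgt_mask[b, 1] = [1, 1, 0, 0, 0, 1, 0, 0]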

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)
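        # (added note) This is a variant of the GNMT-style length penalty,
        # ((len + 5) / 5) ** alpha, so that summed log-probs of longer
        # hypotheses are not unfairly penalized against shorter ones.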

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
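            # (added comment) cumsum(m) * m - 1 gives, for every selected buffer
            # slot, its 0-based position in the compacted output sequence;
            # unselected slots get -1, which the one_hot below maps to an
            # all-zero row and therefore drops.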
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
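
# Added illustrative sketch (not part of the original example above): a
# minimal NumPy analogue of the unmask() helper, showing how the cumsum /
# one_hot trick compacts the tokens selected by a mask into a dense
# per-hypothesis sequence. The helper name unmask_np and the toy shapes are
# hypothetical.
import numpy as np


def unmask_np(h, m, output_len):
    # h: [batch, buf_size] flattened token history.
    # m: [batch, k, buf_size] 0/1 mask of buffer slots belonging to each of
    #    the k hypotheses.
    t = np.cumsum(m, axis=-1) * m - 1      # compacted position of each selected slot
    mh = m * h[:, None, :]                 # tokens at the selected slots
    out = np.zeros(m.shape[:2] + (output_len,), dtype=h.dtype)
    bi, ki, si = np.nonzero(m)
    out[bi, ki, t[bi, ki, si]] = mh[bi, ki, si]   # scatter into compacted positions
    return out


# unmask_np(np.array([[7, 8, 9, 10]]), np.array([[[1, 0, 1, 1]]]), 4)
# -> [[[7, 9, 10, 0]]]
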
Ejemplo n.º 29
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes loss and other metrics for the given predictions.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      predictions: The output of `ComputePredictions`, contains: logits - [b,
        nx, ny, nz, na, 7 + num_classes]. na is the number of anchor
        boxes per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).
      input_batch: The input batch from which we access the groundtruth.

    Returns:
      Two dicts defined as BaseTask.ComputeLoss.
    """
        p = self.params
        predicted_residuals = py_utils.HasShape(
            predictions.residuals, [-1, -1, -1, -1, p.num_anchors, 7])
        predicted_class_logits = py_utils.HasShape(
            predictions.classification_logits,
            [-1, -1, -1, -1, p.num_anchors, p.num_classes])
        bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

        # Compute class and regression weights.
        class_weights = input_batch.assigned_cls_mask
        class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
        reg_weights = input_batch.assigned_reg_mask
        reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
        reg_weights = tf.expand_dims(reg_weights, -1)

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
            # Compute number of positive anchors per example.
            foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                                [bs, nx, ny, nz, na])
            # Sum to get the number of foreground anchors for each example.
            loss_normalization = tf.reduce_sum(foreground_mask,
                                               axis=[1, 2, 3, 4])
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))
            # Reshape for broadcasting.
            loss_normalization = tf.reshape(loss_normalization,
                                            [bs, 1, 1, 1, 1, 1])

            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        # Classification loss.
        assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                               [bs, nx, ny, nz, na])
        class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_class_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)
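        # (added note) SigmoidCrossEntropyFocalLoss implements the focal loss
        # of Lin et al. (2017), FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t),
        # which down-weights easy, well-classified anchors relative to hard ones.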
        class_loss *= class_weights[..., tf.newaxis]
        class_loss_sum = tf.reduce_sum(class_loss)

        # Regression loss.
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

        # Location and dimensions loss.
        reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
            predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                          [bs, nx, ny, nz, na, 6]),
            labels=anchor_localization_residuals[..., :6],
            delta=1 / (3.**2))

        # Rotation (heading) loss: either SmoothL1(atan2-wrapped delta) or
        # SmoothL1(sin(delta)), depending on p.use_atan2_heading_loss.
        rot_delta = (predicted_residuals[..., 6:] -
                     input_batch.anchor_localization_residuals[..., 6:])

        if p.use_atan2_heading_loss:
            atan2_of_delta = tf.atan2(tf.sin(rot_delta), tf.cos(rot_delta))
            reg_rot_loss = self._utils.ScaledHuberLoss(
                predictions=atan2_of_delta,
                labels=tf.zeros_like(atan2_of_delta),
                delta=1 / (3.**2))
        else:
            # Rotation loss with SmoothL1(sin(delta)).
            reg_rot_loss = self._utils.ScaledHuberLoss(
                predictions=tf.sin(rot_delta),
                labels=tf.zeros_like(rot_delta),
                delta=1 / (3.**2))

        # Direction loss
        if p.direction_classifier_weight > 0.0:
            # The target rotations are in the assigned_gt_bbox tensor,
            # which has already assigned a gt bounding box to every anchor.
            rot_target = input_batch.assigned_gt_bbox[..., 6]
            # If rotation is > 0, the class is 1, else it is 0.
            rot_dir = tf.cast(rot_target > 0., tf.int32)

            # Compute one-hot labels as a target.
            rot_dir_onehot = tf.one_hot(rot_dir, 2)

            # Manually handle loss reduction.
            dir_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=rot_dir_onehot,
                logits=predictions.predicted_dir,
                weights=tf.squeeze(reg_weights, axis=-1),
                reduction=tf.losses.Reduction.NONE)
            # Reduce across all dimensions (we'll divide by the batch size below).
            dir_loss_sum = tf.reduce_sum(dir_loss)
        else:
            dir_loss_sum = 0.0

        # Compute loss contribution from location and dimension separately.
        reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
        reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

        reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
        reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

        # Compute rotation loss contribution.
        reg_rot_loss *= reg_weights
        reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

        # Num. predictions.
        # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
        preds = tf.cast(bs, class_loss_sum.dtype)

        # Normalize all of the components by batch size.
        reg_loc_loss = reg_loc_loss_sum / preds
        reg_dim_loss = reg_dim_loss_sum / preds
        reg_rot_loss = reg_rot_loss_sum / preds
        class_loss = class_loss_sum / preds
        dir_loss = dir_loss_sum / preds

        # Compute total localization regression loss.
        reg_loss = (p.location_loss_weight * reg_loc_loss +
                    p.dimension_loss_weight * reg_dim_loss +
                    p.rotation_loss_weight * reg_rot_loss)

        # Apply weights to normalized class losses.
        loss = (class_loss * p.classification_loss_weight +
                reg_loss * p.localization_loss_weight +
                dir_loss * p.direction_classifier_weight)

        metrics_dict = {
            'loss': (loss, preds),
            'loss/class': (class_loss, preds),
            'loss/reg': (reg_loss, preds),
            'loss/reg/rot': (reg_rot_loss, preds),
            'loss/reg/loc': (reg_loc_loss, preds),
            'loss/reg/dim': (reg_dim_loss, preds),
            'loss/dir': (dir_loss, preds),
        }

        # Calculate dimension errors
        min_angle_rad = -np.pi if p.use_atan2_heading_loss else 0
        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            predicted_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = {
            'residuals': predicted_residuals,
            'classification_logits': predicted_class_logits,
        }

        return metrics_dict, per_example_dict
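
# Added illustrative sketch (not part of the original example above): when
# p.use_atan2_heading_loss is set, atan2(sin(d), cos(d)) wraps the raw heading
# delta into (-pi, pi] before the Huber loss, so deltas that differ by full
# turns are penalized identically. Toy values only.
import numpy as np

deltas = np.array([0.1, np.pi + 0.1, -3.5 * np.pi])
wrapped = np.arctan2(np.sin(deltas), np.cos(deltas))
print(wrapped)  # approx. [0.1, -3.0416, 1.5708] -- all within (-pi, pi]
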
Ejemplo n.º 30
0
    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state
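
# Added illustrative sketch (not part of the original example above): the
# 'update_ext' block reorders ext_id / ext_mask to follow the sorted scores by
# building a one-hot matrix from the top_k indices and contracting it with
# einsum -- effectively a batched gather. A minimal NumPy version, toy values:
import numpy as np

scores = np.array([[0.1, 0.9, 0.4, 0.7]])
ids = np.array([[11, 22, 33, 44]])
k = 2
idx = np.argsort(-scores, axis=-1)[:, :k]                # top-k indices per row
onehot = np.eye(scores.shape[1], dtype=ids.dtype)[idx]   # [batch, k, n]
gathered = np.einsum('bk,bjk->bj', ids, onehot)
print(gathered)  # [[22 44]] -- ids reordered to match the sorted scores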