Пример #1
def _SampleGumbelWithMax(phi, target_max, batch_seed, time_step, src_ids,
                         src_paddings):
    """Samples a set of Gumbel noises with a specified maximum value.

    A set of values are sampled from Gumbel distributions with location
    parameters `phi` under the condition that their maximum is equal to
    `target_max`.

    The numerically stable implementation from Appendix B.3 of
    https://arxiv.org/pdf/1903.06059.pdf is used.

    Args:
      phi: A float tensor of shape [tgt_batch, k] that represents location
        parameters of Gumbel distributions.
      target_max: A float tensor of shape [tgt_batch, 1] that represents the
        target max values.
      batch_seed: An int tensor of shape [src_batch] that holds a seed value
        for each batch item. src_batch must be equal to
        tgt_batch / num_hyps_per_beam. The same seed is used within each
        consecutive num_hyps_per_beam items along the tgt_batch axis.
      time_step: A float tensor used as a secondary seed.
      src_ids: An int tensor of shape [src_batch, src_seq] that represents
        source IDs. Used for turning the random seed into a function of source
        IDs.
      src_paddings: A 0/1 float tensor of shape [src_batch, src_seq] where 1
        means that the corresponding element of src_ids is a padding.

    Returns:
      A float tensor like `phi` whose maximum values along the second axis
      are (almost) equal to `target_max`.
    """
    dtype = phi.dtype
    tgt_batch = tf.shape(phi)[0]
    k = tf.shape(phi)[1]
    src_batch = tf.shape(batch_seed)[0]
    num_hyps_per_beam = tgt_batch // src_batch

    # Sample noises from Gumbel distributions with location parameters `phi`.
    # NOTE(review): the final `dtype` argument was cut off in this copy of the
    # file and has been restored; confirm against _BatchSampleGumbel's
    # signature.
    # shape: [src_batch, num_hyps_per_beam, k]
    gumbel_noises = _BatchSampleGumbel(batch_seed, time_step, src_ids,
                                       src_paddings, [num_hyps_per_beam, k],
                                       dtype)
    # shape: [num_hyps_per_beam, src_batch, k]
    gumbel_noises = tf.transpose(gumbel_noises, perm=[1, 0, 2])
    # shape: [tgt_batch, k]
    gumbel_noises = tf.reshape(gumbel_noises, tf.shape(phi))
    # shape: [tgt_batch, k]
    g_phi = phi + gumbel_noises

    # Row-wise maximum of the unconstrained samples.
    # shape: [tgt_batch, 1]
    z = tf.reduce_max(g_phi, axis=1, keepdims=True)

    # Equation (23).
    # shape: [tgt_batch, k]
    v = target_max - g_phi + tf.math.log1p(
        # Without taking max, sometimes the result of log1p would become NaN on
        # TPU.
        tf.maximum(-tf.exp(g_phi - z), tf.constant(-1., dtype=dtype)))

    # Equation (24).
    return target_max - tf.nn.relu(v) - tf.math.log1p(tf.exp(-tf.abs(v)))
Пример #2
def _KeepTopP(sorted_log_probs, p):
    """Keeps the top-p probability mass of `sorted_log_probs`.

    For each row, elements that are not included in the first `p` probability
    mass are set to `LARGE_NEGATIVE_NUMBER`. The first element is always kept
    as-is.

    Args:
      sorted_log_probs: A float tensor of shape [batch, k] that represents
        log-probabilities sorted in descending order. The probabilities do not
        need to sum to 1.
      p: A float tensor of shape [batch] that represents a probability
        threshold for each batch item.

    Returns:
      A tensor like `sorted_log_probs` where elements outside the top-p
      probability mass are set to `LARGE_NEGATIVE_NUMBER`.
    """
    # NOTE(review): the keyword arguments of this call were cut off in this
    # copy of the file. An exclusive cumulative sum along the last axis is
    # assumed (mass strictly before each element); confirm against upstream.
    sorted_cum_probs = tf.math.cumsum(tf.exp(sorted_log_probs),
                                      exclusive=True,
                                      axis=-1)
    mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
    # Set mask[:, 0] = True to always keep the first element.
    batch_size = tf.shape(mask)[0]
    true = tf.ones([batch_size, 1], dtype=tf.bool)
    mask = tf.concat([true, mask[:, 1:]], axis=1)
    # The replacement value is broadcast to the full shape explicitly so that
    # both branches of tf.where have identical shapes.
    filtered_sorted_log_probs = tf.where(
        mask, sorted_log_probs,
        tf.fill(
            tf.shape(sorted_log_probs),
            tf.constant(LARGE_NEGATIVE_NUMBER, dtype=sorted_log_probs.dtype)))
    return filtered_sorted_log_probs
    def GetSequenceInfo(self, ids, enc_out):
        """Returns logits, masked average entropy and attention probs for `ids`.

        NOTE(review): three statements in this method are truncated in this
        copy of the file (the calls to `ComputePredictions`, `GetDecoderLoss`
        and the last `tf.reduce_sum`). Their missing arguments cannot be
        recovered from what is visible here; restore them from the upstream
        source before using this method.

        Side effect: sets `self.decoder.params.per_example_tensors = True`.

        Args:
          ids: Token ids; passed to `self._AddStartToken` and
            `self._GetPaddings`.
          enc_out: Encoder output forwarded to the decoder.

        Returns:
          A tuple `(logits, ave_entropy, attention_probs)` where `logits` is
          `per_example_tensors["logits"]`, `ave_entropy` is the masked entropy
          summed over axis 0 and divided by a reduction of the mask, and
          `attention_probs` is `dummy_pred.attention["probs"]`.
        """
        inp_ids = self._AddStartToken(ids)
        # NOTE(review): truncated call -- at least one argument line and the
        # closing parenthesis are missing.
        dummy_pred = self.decoder.ComputePredictions(
            self.theta.decoder, enc_out,
                tf.ones_like(inp_ids, dtype=tf.float32),
        # 'softmax_input' in dummy_pred is not the logits; the predictions have
        # to go through the loss layer to obtain them. The loss layer only
        # exposes per-example tensors when this flag is enabled.
        self.decoder.params.per_example_tensors = True
        # NOTE(review): truncated call -- remaining arguments are missing.
        _, per_example_tensors = self.GetDecoderLoss(self.theta, dummy_pred,

        # 1 where the position is not padding (assumes 0/1 paddings -- TODO
        # confirm), transposed to line up with the logits layout.
        mask = tf.transpose(1 - self._GetPaddings(ids))
        logits = per_example_tensors["logits"]

        log_p = tf.nn.log_softmax(logits)
        prob = tf.exp(log_p)
        # Entropy over the last (class) axis, zeroed at masked-out positions.
        entropy = -tf.reduce_sum(log_p * prob, axis=-1) * mask
        # NOTE(review): truncated call -- the reduction axis of the denominator
        # and the closing parenthesis are missing.
        ave_entropy = tf.reduce_sum(entropy, axis=0) / tf.reduce_sum(mask,

        return logits, ave_entropy, dummy_pred.attention["probs"]
Пример #4
            def ApplyBias():
                """Bias and update log_probs and consistent.

                Reads `labels`, `weights`, `states`, `step_ids`, `time_step`,
                `num_hyps_per_beam`, `bs_results` and `p` from the enclosing
                scope.

                Returns:
                  A tuple `(log_probs, consistent)`: the log of the
                  interpolation between predicted and label probabilities, and
                  the updated per-hypothesis consistency flags.
                """
                def TileForBeamAndFlatten(tensor):
                    """Repeats a [src_batch] tensor per hyp, giving [tgt_batch]."""
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    tensor = tf.tile(tensor,
                                     [num_hyps_per_beam, 1
                                      ])  # [num_hyps_per_beam, src_batch]
                    tgt_batch = tf.shape(step_ids)[
                        0]  # num_hyps_per_beam*src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                # A hypothesis stays consistent only if it was consistent so
                # far and also matches at this step.
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                # NOTE(review): the numeric value was cut off in this copy of
                # the file; 1e-10 is assumed -- confirm against upstream.
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
    def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
        """Returns the largest relative change used as the stopping criterion.

        `score_`, `a` and `b` are read from the enclosing scope. All
        quantities are computed under `tf.stop_gradient`.

        Args:
          eps: Epsilon value before the update.
          u: Dual variable `u` before the update.
          v: Dual variable `v` before the update.
          w: Dual variable `w` before the update.
          new_eps: Epsilon value after the update.
          new_u: Dual variable `u` after the update.
          new_v: Dual variable `v` after the update.
          new_w: Dual variable `w` after the update.

        Returns:
          A scalar tensor: the maximum over the relative margin errors
          (against `a` and `b`) and the relative change of exp(scaled dual
          sum) between the two iterates.
        """
        old_scaled_duals = tf.stop_gradient((u + v + w) / eps)
        cur_scaled_duals = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

        # Relative error on the margins of the current assignment. The margin
        # constraint `c` is not checked: the last update (on w) already
        # enforces it.
        plan = tf.exp(tf.stop_gradient(score_ / new_eps + cur_scaled_duals))
        margin_a = tf.reduce_sum(plan, axis=-1, keepdims=True)
        margin_b = tf.reduce_sum(plan, axis=-2, keepdims=True)
        err_a = tf.reduce_max(tf.abs(a - margin_a) / (a + 1e-6))
        err_b = tf.reduce_max(tf.abs(b - margin_b) / (b + 1e-6))
        margin_err = tf.maximum(err_a, err_b)

        # Relative change of the assignment solution between iterates.
        cur_scale = tf.exp(cur_scaled_duals)
        rel_change = tf.abs(tf.exp(old_scaled_duals) -
                            cur_scale) / (cur_scale + 1e-6)
        return tf.maximum(margin_err, tf.reduce_max(rel_change))
Пример #6
                def ApplyBias():
                    """Bias and update log_probs and consistent.

                    Reads `labels`, `weights`, `states`, `step_ids`,
                    `time_step`, `bs_results`, `p` and the helper
                    `TileForBeamAndFlatten` from the enclosing scope.

                    Returns:
                      A tuple `(log_probs, consistent)`: the log of the
                      interpolation between predicted and label probabilities
                      (clamped to at least 1e-12 before the log), and the
                      updated per-hypothesis consistency flags.
                    """

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is overridden
                    # later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    # A hypothesis stays consistent only if it was consistent
                    # so far and also matches at this step.
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    # NOTE(review): the dtype expressions below were cut off in
                    # this copy of the file; `py_utils.FPropDtype(p)` is
                    # inferred from the surviving `p))` fragment -- confirm
                    # against upstream.
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent
Пример #7
 def Exp(x):
   """Returns tf.exp of `self.linear.Value(x)`; `self` comes from the closure."""
   linear_value = self.linear.Value(x)
   return tf.exp(linear_value)
Пример #8
    def ResidualsToBBoxes(self,
                          anchor_bboxes,
                          residuals,
                          min_angle_rad=-np.pi,
                          max_angle_rad=np.pi):
        r"""Converts anchor_boxes and residuals to predicted bboxes.

        This converts predicted residuals into bboxes using the following
        formulae::

          x_predicted = x_a + x_residual * diagonal_xy
          y_predicted = y_a + y_residual * diagonal_xy
          z_predicted = z_a + z_residual * dz_a

          dx_predicted = dx_a * exp(dx_residual)
          dy_predicted = dy_a * exp(dy_residual)
          dz_predicted = dz_a * exp(dz_residual)

          # Adding the residual, and bounding it between
          # [min_angle_rad, max_angle_rad]
          phi_predicted = NormalizeAngleRad(phi_a + phi_residual,
                                            min_angle_rad, max_angle_rad)

        These equations follow from those in LocalizationResiduals, where we
        solve for the \*_gt variables.

        NOTE(review): the parameter list was cut off in this copy of the file.
        The parameter names are taken from the body and the docstring; the
        default values (-pi, pi) are an assumption -- confirm against upstream.

        Args:
          anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy,
            dz, phi), corresponding to each anchor bbox parameters.
          residuals: tf.float32 of the same shape as anchor_bboxes containing
            predicted residuals at each anchor location.
          min_angle_rad: Scalar with the minimum angle allowed (before
            wrapping) in radians.
          max_angle_rad: Scalar with the maximum angle allowed (before
            wrapping) in radians. This value usually should be pi.

        Returns:
          A tf.float32 tensor of the same shape as anchor_bboxes with predicted
          bboxes.
        """
        anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
        anchor_bboxes = py_utils.with_dependencies(
            [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
        residuals = py_utils.HasShape(residuals, anchor_bboxes_shape)

        x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor_bboxes,
                                                            num=7,
                                                            axis=-1)
        (x_residual, y_residual, z_residual, dx_residual, dy_residual,
         dz_residual, phi_residual) = tf.unstack(residuals, num=7, axis=-1)

        diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

        x_predicted = x_a + x_residual * diagonal_xy
        y_predicted = y_a + y_residual * diagonal_xy
        z_predicted = z_a + z_residual * dz_a

        dx_predicted = dx_a * tf.exp(dx_residual)
        dy_predicted = dy_a * tf.exp(dy_residual)
        dz_predicted = dz_a * tf.exp(dz_residual)

        # We bound the angle between [min_angle_rad, max_angle_rad], which should
        # be passed in depending on the heading handling in the calling model.
        # If the model uses a sine(delta_phi) transformation in the loss, then it
        # cannot distinguish direction and a [0, np.pi]
        # [min_angle_rad, max_angle_rad] should be used.
        # If there is a heading encoding that is directional, most likely you
        # should use a [-np.pi, np.pi] [min_angle_rad, max_angle_rad].
        phi_predicted = phi_a + phi_residual
        phi_predicted = geometry.WrapAngleRad(phi_predicted, min_angle_rad,
                                              max_angle_rad)

        # Seven channels, in the same order as the anchor box layout.
        return tf.stack([
            x_predicted, y_predicted, z_predicted,
            dx_predicted, dy_predicted, dz_predicted,
            phi_predicted,
        ], axis=-1)  # pyformat: disable
Пример #9
 def Value(self):
     """Returns tf.exp applied to the result of `self.linear.Value()`."""
     linear_value = self.linear.Value()
     return tf.exp(linear_value)
Пример #10
def max_assignment(score: tf.Tensor,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
    """Differentiable max assignment with margin and upper bound constraints.

    Args:
      score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
        denotes the weight if the assignment on this entry is non-zero.
      elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
        n_columns]. Each entry denotes the maximum value assignment[i, j, k]
        can take and must be a non-negative value. For example,
        upper_bound[i, j, k]=1.0 for binary assignment problem.
      row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum
        constraint. The output assignment p[i, j, :] must sum to
        row_sums[i, j].
      col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
        constraint. The output assignment p[i, :, k] must sum to
        col_sums[i, k].
      epsilon: the epsilon coefficient of entropy regularization. The value
        should be within the range (0, 1]. `0.01` might work better than
        `0.1`. `0.1` may not make the assignment close enough to 0 or 1.
      num_iterations: the maximum number of iterations to perform.
      use_epsilon_scaling: whether to use epsilon scaling. In practice, the
        convergence of the iterative algorithm is much better if we start by
        solving the optimization with a larger epsilon value and re-use the
        solution (i.e. dual variables) for the instance with a smaller
        epsilon. This is called the epsilon scaling trick. See
        [Schmitzer 2019] (https://arxiv.org/pdf/1610.06519.pdf) as a
        reference. Here if use_epsilon_scaling=True, after each iteration we
        decrease the running epsilon by a constant factor until it reaches the
        target epsilon value. We found this to work well for gradient backward
        propagation, while the original scaling trick doesn't.

    Returns:
      A tuple with the following values.
        - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
          The output assignment.
        - used_iter: the number of iterations used (equal to
          `num_iterations`).
        - eps: a scalar tensor indicating the stopping epsilon value.
        - delta: a scalar tensor indicating the stopping delta value (the
          relative change on the margins of assignment p in the last
          iteration).
    """
    # Check if all shapes are correct
    score_shape = score.shape
    bsz = score_shape[0]
    n = score_shape[1]
    m = score_shape[2]
    score = tf.ensure_shape(score, [bsz, n, m])
    elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                              [bsz, n, m])
    row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
    col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

    # the total sum of row sums must be equal to total sum of column sums.
    # row_sums is [bsz, n, 1] and col_sums is [bsz, 1, m] at this point, so
    # both reductions yield a [bsz, 1] tensor of per-batch totals.
    sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums,
                                                               axis=2)
    sum_diff = tf.abs(sum_diff)
    tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

    # Convert upper_bound constraint into another margin constraint
    # by adding auxiliary variables & scores. Tensor `a`, `b` and `c`
    # represent the margins (i.e. reduced sum) of 3 axes respectively.
    # keepdims=True keeps the results stackable with row_sums / col_sums.
    max_row_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-1,
                                 keepdims=True)  # (bsz, n, 1)
    max_col_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-2,
                                 keepdims=True)  # (bsz, 1, m)
    score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
    a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
    b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
    c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

    # Clip log(0) to a large negative values -1e+36 to avoid
    # getting inf or NaN values in computation. Cannot use larger
    # values because float32 would use `-inf` automatically.
    tf.Assert(tf.reduce_all(a >= 0), [a])
    tf.Assert(tf.reduce_all(b >= 0), [b])
    tf.Assert(tf.reduce_all(c >= 0), [c])
    log_a = tf.maximum(tf.math.log(a), -1e+36)
    log_b = tf.maximum(tf.math.log(b), -1e+36)
    log_c = tf.maximum(tf.math.log(c), -1e+36)

    # Initialize the dual variables of margin constraints
    u = tf.zeros_like(a)
    v = tf.zeros_like(b)
    w = tf.zeros_like(c)

    # Running epsilon starts at 1.0 when scaling is enabled and is annealed
    # towards the target `epsilon` inside do_updates.
    eps = tf.constant(1.0 if use_epsilon_scaling else epsilon,
                      dtype=score.dtype)
    epsilon = tf.constant(epsilon, dtype=score.dtype)

    def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
        """Runs one round of dual updates (u, then v, then w)."""
        # Epsilon scaling, i.e. gradually decreasing `eps` until it
        # reaches the target `epsilon` value
        cur_iter = tf.cast(cur_iter, u.dtype)
        scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
        eps = tf.maximum(epsilon, eps * scaling)
        score_div_eps = score_ / eps

        # Update u
        log_q_1 = score_div_eps + (w + v) / eps
        log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
        new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

        # Update v
        log_q_2 = score_div_eps + (w + new_u) / eps
        log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
        new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

        # Update w
        log_q_3 = score_div_eps + (new_u + new_v) / eps
        log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
        new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
        return eps, new_u, new_v, new_w

    def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
        """Returns the max relative change between two consecutive iterates."""
        prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
        sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

        # Compute the relative changes on margins of P.
        # This will be used for stopping criteria.
        # Note the last update on w would guarantee the
        # margin constraint c is satisfied, so we don't
        # need to check it here.
        p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
        p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
        p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
        delta_a = tf.abs(a - p_a) / (a + 1e-6)
        delta_b = tf.abs(b - p_b) / (b + 1e-6)
        new_delta = tf.reduce_max(delta_a)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

        # Compute the relative changes on assignment solution P.
        # This will be used for stopping criteria.
        delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                         tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
        return new_delta

    # Seed the "previous" iterate so that `delta` below is well defined even
    # if the loop body never runs (num_iterations == 0).
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    for cur_iter in tf.range(num_iterations):
        prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
        eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
    cur_iter = num_iterations
    assignment = tf.exp((score_ + u + v + w) / eps)
    # Slice 0 along the auxiliary axis is the real assignment; slice 1 holds
    # the auxiliary variables introduced for the upper-bound constraint.
    assignment = assignment[:, 0]
    return assignment, cur_iter, eps, delta
Пример #11
 def Value(self, step=None):
   """Returns tf.exp applied to the result of `self.linear.Value(step)`."""
   linear_value = self.linear.Value(step)
   return tf.exp(linear_value)
Пример #12
  def ResidualsToBBoxes(self, anchor_bboxes, residuals):
    r"""Converts anchor_boxes and residuals to predicted bboxes.

    This converts predicted residuals into bboxes using the following formulae:

      x_predicted = x_a + x_residual \* diagonal_xy
      y_predicted = y_a + y_residual \* diagonal_xy
      z_predicted = z_a + z_residual \* dz_a

      dx_predicted = dx_a \* exp(dx_residual)
      dy_predicted = dy_a \* exp(dy_residual)
      dz_predicted = dz_a \* exp(dz_residual)

      phi_predicted = phi_a + phi_residual

    These equations follow from those in LocalizationResiduals, where we solve
    for the \*_gt variables.

    Args:
      anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy, dz,
        phi), corresponding to each anchor bbox parameters.
      residuals: tf.float32 of the same shape as anchor_bboxes containing
        predicted residuals at each anchor location.

    Returns:
      A tf.float32 tensor of the same shape as anchor_bboxes with predicted
      bboxes.
    """
    anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
    anchor_bboxes = py_utils.with_dependencies(
        [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
    residuals = py_utils.HasShape(residuals, anchor_bboxes_shape)

    x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(
        anchor_bboxes, num=7, axis=-1)
    (x_residual, y_residual, z_residual, dx_residual, dy_residual, dz_residual,
     phi_residual) = tf.unstack(
         residuals, num=7, axis=-1)

    diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

    x_predicted = x_a + x_residual * diagonal_xy
    y_predicted = y_a + y_residual * diagonal_xy
    z_predicted = z_a + z_residual * dz_a

    dx_predicted = dx_a * tf.exp(dx_residual)
    dy_predicted = dy_a * tf.exp(dy_residual)
    dz_predicted = dz_a * tf.exp(dz_residual)

    # Assuming a sine(delta_phi) transformation is used in the loss, then, it
    # is not possible to distinguish direction, hence, we use floormod here to
    # ensure that the predicted_phi is always in [0, np.pi) for consistency.
    # A separate direction classifier should be added the model if needed.
    phi_predicted = phi_a + phi_residual
    phi_predicted = tf.floormod(phi_predicted, np.pi)

    # phi_predicted must be included so the output has the same 7 channels as
    # anchor_bboxes; without it the result would only have 6.
    return tf.stack([
        x_predicted, y_predicted, z_predicted,
        dx_predicted, dy_predicted, dz_predicted,
        phi_predicted,
    ], axis=-1)  # pyformat: disable
Пример #13
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

            Biases results.log_probs towards provided encoder_outputs.targets.

            Args:
              theta: a NestedMap of parameters.
              encoder_outputs: a NestedMap computed by encoder.
              step_ids: A tensor of shape [tgt_batch, 1].
              states: A `.NestedMap` of tensors representing states that the
                clients would like to keep track of for each of the active
                hyps.
              num_hyps_per_beam: Beam size.
              *args: additional arguments to _PreBeamSearchStepCallback.
              **kwargs: additional arguments to _PreBeamSearchStepCallback.

            Returns:
              A tuple (results, out_states).
              results: A `.NestedMap` of beam search results, including:
                - The updated attention probs, of shape [tgt_batch, src_len].
                - log_probs: Log prob for each of the tokens in the target
                  vocab. This is of shape [tgt_batch, vocab_size].
              out_states: a `.NestedMap` The updated states. The states
                relevant here are:
                time_step: A scalar indicating current step of decoder.  Must
                  be provided and maintained by subclass.
                consistent: A boolean vector of shape [tgt_batch, ] which
                  tracks whether each hypothesis has exactly matched the
                  provided labels so far.
            """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def TileForBeamAndFlatten(tensor):
                """Repeats a [src_batch] tensor per hyp, giving [tgt_batch]."""
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                tensor = tf.tile(
                    tensor,
                    [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                tgt_batch = tf.shape(step_ids)[
                    0]  # num_hyps_per_beam*src_batch
                return tf.reshape(tensor, [tgt_batch])

            # Consistent if step_ids == labels from previous step
            # TODO(navari): Consider updating consistent only if weights > 0. Then
            # re-evaluate the need for bias_only_if_consistent=True.
            # Note that prev_label is incorrect for step 0 but is overridden later
            prev_label = TileForBeamAndFlatten(
                tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
            is_step0 = tf.equal(time_step, 0)
            local_consistence = tf.logical_or(
                is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
            # A hypothesis stays consistent only if it was consistent so far
            # and also matches at this step.
            out_states.consistent = tf.logical_and(states.consistent,
                                                   local_consistence)

            # get label, weight slices corresponding to current time_step
            label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
            weight = TileForBeamAndFlatten(
                tf.gather(weights, time_step, axis=1))
            if p.bias_only_if_consistent:
                weight = weight * tf.cast(out_states.consistent, p.dtype)

            # convert from dense label to sparse label probs
            vocab_size = tf.shape(bs_results.log_probs)[1]
            # NOTE(review): the numeric value was cut off in this copy of the
            # file; 1e-10 is assumed -- confirm against upstream.
            uncertainty = tf.constant(
                1e-10,
                p.dtype)  # avoid 0 probs which may cause issues with log
            label_probs = tf.one_hot(label,
                                     vocab_size,
                                     on_value=1 - uncertainty,
                                     off_value=uncertainty /
                                     tf.cast(vocab_size - 1, p.dtype),
                                     dtype=p.dtype)  # [tgt_batch, vocab_size]
            pred_probs = tf.exp(bs_results.log_probs)

            # interpolate predicted probs and label probs
            weight = tf.expand_dims(weight, 1)
            probs = py_utils.with_dependencies([
                py_utils.assert_less_equal(weight, 1.),
                py_utils.assert_greater_equal(weight, 0.)
            ], (1.0 - weight) * pred_probs + weight * label_probs)

            bs_results.log_probs = tf.log(probs)

            return bs_results, out_states