Example #1
    def Apply(self, lr, var_grad):
        p = self.params

        def _Acc(vg):
            """Updating accumulators."""

            v, g = vg
            with tf.variable_scope(v.op.name):
                _, a = py_utils.CreateVariable(
                    'grad_accumulator',
                    py_utils.WeightParams(v.get_shape(),
                                          py_utils.WeightInit.Constant(0.0),
                                          self.params.dtype),
                    trainable=False)
                a = tf.assign_add(a, g)

            return py_utils.VarGrad(v, a)

        var_grad = var_grad.Transform(_Acc)

        def _ApplyAndReset():
            with tf.control_dependencies([
                    self._opt.Apply(
                        lr,
                        py_utils.ApplyGradMultiplier(var_grad,
                                                     1. / p.accum_steps))
            ]):
                return tf.group(*[
                    tf.assign(a, tf.zeros_like(a))
                    for _, a in var_grad.Flatten()
                ])

        return tf.cond(
            tf.equal(tf.mod(self.theta.global_step, p.accum_steps),
                     p.accum_steps - 1), _ApplyAndReset,
            lambda: tf.group(tf.no_op()))
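A minimal standalone sketch of the same accumulate-then-apply pattern, written
against plain TF2 eager APIs rather than Lingvo's py_utils (the accum_steps
name mirrors the param above; everything else is illustrative):

import tensorflow as tf

accum_steps = 4
lr = 0.1
var = tf.Variable(1.0)
accum = tf.Variable(0.0, trainable=False)  # gradient accumulator

def step(global_step, grad):
    accum.assign_add(grad)  # always accumulate the incoming gradient

    def apply_and_reset():
        # Apply the averaged gradient, then zero the accumulator.
        var.assign_sub(lr * accum / accum_steps)
        accum.assign(tf.zeros_like(accum))
        return tf.constant(True)

    # Only apply on every accum_steps-th step, as in Apply() above.
    return tf.cond(
        tf.equal(tf.math.floormod(global_step, accum_steps), accum_steps - 1),
        apply_and_reset, lambda: tf.constant(False))

for i in range(8):
    step(tf.constant(i), grad=tf.constant(0.5))  # applies at steps 3 and 7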
Example #2
    def _GetBucketKey(self, features, filtered):
        """Returns a the bucket key for a given input."""
        # The token ids are not truncated if and only if it ends with padding
        # or the last id is EOS.
        src_fits = tf.math.logical_or(
            tf.math.equal(features.src.ids_indicator[-1], 0),
            tf.math.equal(features.src.ids[-1], self._src_tokenizer.eos_id))
        tgt_fits = tf.math.logical_or(
            tf.math.equal(features.tgt.ids_indicator[-1], 0),
            tf.math.equal(features.tgt.labels[-1], self._tgt_tokenizer.eos_id))

        # We return the max of the source and target sequence lengths if and
        # only if both src and tgt fit. Otherwise we return a key of -1 to
        # filter out this input.
        def _MaxLen():
            src_len = tf.cast(tf.math.reduce_sum(features.src.ids_indicator),
                              dtype=tf.int32)
            tgt_len = tf.cast(tf.math.reduce_sum(features.tgt.ids_indicator),
                              dtype=tf.int32)
            return tf.math.maximum(src_len, tgt_len)

        filtered = tf.math.logical_or(
            filtered,
            tf.math.logical_not(tf.math.logical_and(src_fits, tgt_fits)))
        return tf.cond(filtered, lambda: -1, _MaxLen)
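The bucketing contract used here, reduced to a runnable toy (names are
illustrative, not Lingvo APIs): a non-negative key buckets the example by
length, while -1 tells the input pipeline to drop it.

import tensorflow as tf

def bucket_key(ids_indicator, filtered):
    def max_len():
        return tf.cast(tf.reduce_sum(ids_indicator), tf.int32)
    # -1 filters the example out; otherwise bucket by sequence length.
    return tf.cond(filtered, lambda: tf.constant(-1, tf.int32), max_len)

print(bucket_key(tf.constant([1., 1., 1., 0.]), tf.constant(False)))  # 3
print(bucket_key(tf.constant([1., 1., 1., 0.]), tf.constant(True)))   # -1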
Example #3
    def conditional_mask_update_op(self):
        def maybe_update_masks():
            with tf.name_scope(self._spec.name):
                is_step_within_pruning_range = tf.logical_and(
                    tf.greater_equal(self._global_step,
                                     self._spec.begin_pruning_step),
                    # If end_pruning_step is negative, keep pruning forever!
                    tf.logical_or(
                        tf.less_equal(self._global_step,
                                      self._spec.end_pruning_step),
                        tf.less(self._spec.end_pruning_step, 0)))
                is_pruning_step = tf.less_equal(
                    tf.add(self._last_update_step,
                           self._spec.pruning_frequency), self._global_step)
                return tf.logical_and(is_step_within_pruning_range,
                                      is_pruning_step)

        def mask_update_op():
            return self.mask_update_op()

        def no_update_op():
            return tf.no_op()

        return tf.cond(maybe_update_masks(), mask_update_op, no_update_op)
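The pruning-window predicate above, extracted into a self-contained sketch
(the spec fields become plain arguments; this is illustrative, not the
pruning library's API):

import tensorflow as tf

def should_prune(global_step, begin_step, end_step, last_update, frequency):
    is_in_range = tf.logical_and(
        tf.greater_equal(global_step, begin_step),
        # A negative end_step means "keep pruning forever".
        tf.logical_or(tf.less_equal(global_step, end_step),
                      tf.less(end_step, 0)))
    is_pruning_step = tf.less_equal(last_update + frequency, global_step)
    return tf.logical_and(is_in_range, is_pruning_step)

print(should_prune(tf.constant(100), 10, -1, 90, 10))  # True
print(should_prune(tf.constant(95), 10, -1, 90, 10))   # False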
Example #4
    def Processor(source_id, record):
      """Parses a record, which is a line of text."""

      task_id = self._GetTaskIds(source_id)

      if self.params.input_file_type == 'tsv':

        def _ApplyMass(task_id):
          mass_task_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32)
          return tf.reduce_any(tf.equal(task_id, mass_task_ids))

        def _MASSInput():
          src, filtered = self._ReadRecordTsvSingleColumn(record)
          return self._ProcessMASSInput(source_id, src), filtered

        def _SingleInput():
          src, tgt, filtered = self._ReadRecordTsv(record)
          return self._ProcessSingleInput(source_id, src, tgt), filtered

        if self.params.single_column_input:
          # For monolingual input, MASS is applied by default.
          # If mass_task_ids is specified, only apply MASS to specified tasks.
          if self.params.mass_task_ids is not None:
            cond = _ApplyMass(task_id)
            features, filtered = tf.cond(cond, _MASSInput, _SingleInput)
          else:
            features, filtered = _MASSInput()
        else:
          features, filtered = _SingleInput()

      else:
        src, tgt = self._ReadRecordSentencePairProto(record)
        filtered = tf.constant(False, dtype=tf.bool)
        features = self._ProcessSingleInput(source_id, src, tgt)

      return features, self._GetBucketKey(features, filtered)
Example #5
 def bucket_fn(num):
     # Drops record if num[0] is odd.
     return tf.cond(tf.equal(tf.math.floormod(num[0], 2), 0), lambda: 1,
                    lambda: -tf.cast(num[0], tf.int32))
Example #6
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape
            [tgt_batch, vocab_size].
        out_states: a `.NestedMap` The updated states. The states relevant here
          are:
          time_step: A scalar indicating current step of decoder.  Must be
            provided and maintained by subclass.
          consistent: A boolean vector of shape [tgt_batch, ] which tracks
              whether each hypothesis has exactly matched
              encoder_outputs.targets
              so far.
      """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)
            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    # num_hyps_per_beam * src_batch
                    tgt_batch = tf.shape(step_ids)[0]
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from the previous step.
                # TODO(navari): Consider updating consistent only if
                # weights > 0. Then re-evaluate the need for
                # bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is
                # overridden later.
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                # Avoid exactly-zero probs, which may cause issues with log.
                uncertainty = tf.constant(1e-10, p.dtype)
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent

            def NoApplyBias():
                """No-op. Return original log_probs and consistent."""
                return bs_results.log_probs, states.consistent

            log_probs, consistent = tf.cond(
                tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
            bs_results.log_probs = log_probs
            out_states.consistent = consistent

            return bs_results, out_states
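A minimal sketch of the interpolation at the heart of ApplyBias(): blend the
model's predicted distribution with a slightly smoothed one-hot label
distribution, per example, then return to log space. Shapes and names here
are illustrative, not the Lingvo API:

import tensorflow as tf

def biased_log_probs(log_probs, label, weight):
    vocab_size = tf.shape(log_probs)[1]
    eps = tf.constant(1e-10, tf.float32)  # avoid exactly-zero probs
    # Smoothed one-hot so no class gets probability 0.
    label_probs = tf.one_hot(
        label, vocab_size,
        on_value=1.0 - eps,
        off_value=eps / tf.cast(vocab_size - 1, tf.float32))
    pred_probs = tf.exp(log_probs)
    weight = tf.expand_dims(weight, 1)  # [batch, 1], broadcasts over vocab
    probs = (1.0 - weight) * pred_probs + weight * label_probs
    return tf.math.log(probs)

lp = tf.math.log(tf.constant([[0.7, 0.2, 0.1]]))
# weight 0.5: halfway between the prediction and the label distribution.
print(tf.exp(biased_log_probs(lp, tf.constant([2]), tf.constant([0.5]))))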
Example #7
 def _Real():
     return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                    _GetFurthestPoint)
Example #8
 def _Seeded():
     return tf.cond(tf.less(curr_idx, num_seeded_points),
                    _GetSeededPoint, _GetFurthestPoint)
Example #9
    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random values per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
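The masking trick inside _GetFurthestPoint(), in isolation: forcing padded
points to distance -1 guarantees argmax never picks them while any real point
remains. The values below are illustrative:

import tensorflow as tf

distance = tf.constant([[0.5, 2.0, 0.9, 3.0]])
padding = tf.constant([[0.0, 0.0, 0.0, 1.0]])  # last point is padded

masked = tf.where(tf.equal(padding, 0.0), distance, -tf.ones_like(distance))
# Without masking, argmax would pick index 3 (the padded point, dist 3.0).
print(tf.argmax(masked, axis=-1, output_type=tf.int32))  # [1]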
Example #10
  def ProcessFeatures(self, features):
    """Process extracted features.

    Args:
      features: A dict of extracted Tensors from the records.

    Returns:
      A tuple of tensors:

      - bucket_id: A scalar int Tensor.
      - extracted: a NestedMap of Tensors extracted.
    """
    def ExtractAndFilter(e):
      with tf.name_scope(e.params.name):
        with tf.name_scope('extract'):
          # Filter out extracted features from other extractors.
          filtered_features = {}
          if self.params.record_type == 'TEXT':
            # Text extractors only produce {'line': record} and their
            # FeatureMap() is empty, so don't do any filtering.
            filtered_features = features
          else:
            filtered_keys = e.FeatureMap().keys() | e.ContextMap().keys()
            filtered_features = {
                k: v for k, v in features.items() if k in filtered_keys
            }
          try:
            if self.params.batched_input:
              extracted = e.ExtractBatch(filtered_features)
            else:
              extracted = e.Extract(filtered_features)
          except Exception as exc:  # pylint:disable=broad-except
            # Raise exception with context about which extractor failed.
            raise RuntimeError('Failed running extractor '
                               f'{e.params.name}. '
                               'See above exception for details.') from exc
        with tf.name_scope('filter'):
          if self.params.batched_input:
            bucket = e.FilterBatch(extracted)
          else:
            bucket = e.Filter(extracted)
      return bucket, extracted

    bucket_extracted = self._extractors.Transform(ExtractAndFilter)
    buckets = bucket_extracted.Transform(lambda x: x[0])
    extracted = bucket_extracted.Transform(lambda x: x[1])

    # Return the maximum bucket id so that any extractor can decide whether
    # to filter the entire example.
    max_bucket = tf.reduce_max(buckets.Flatten())

    def NullLike():
      """A function to return the same Tensor signature as Preprocess.

      This is necessary for the tf.cond() to avoid executing the preprocessor
      for examples that are going to be dropped because they exceed the bucket
      limit; tf.cond() requires that the outputs of both branches yield the
      same structure.

      Returns:
        A structure with the same Tensor dtype as the output of
        Preprocess.
      """
      shapes = self.Shape()
      rets = []
      for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
        if shape.is_fully_defined():
          rets += [tf.zeros(dtype=dtype, shape=shape)]
        else:
          rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
      return shapes.Pack(rets)

    def Preprocess(extracted):
      for key, preprocessor in zip(self.params.preprocessors_order,
                                   self.preprocessors):
        with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
          if self.params.batched_input:
            extracted = preprocessor.TransformBatchedFeatures(extracted)
          else:
            extracted = preprocessor.TransformFeatures(extracted)
      return extracted

    # If the extractor wants to filter the example, don't run the preprocessor.
    #
    # Preprocessors can then assume that only examples that pass filtering will
    # be executed.
    #
    # Note that the NullLike branch may return tensors with shapes different
    # from self.Shape().
    final_output = tf.cond(
        tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
        NullLike)

    return max_bucket, final_output
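Why NullLike() is needed at all: tf.cond() insists that both branches return
the same structure and dtypes, so the "drop this example" branch has to
fabricate zeros shaped like the real output. A toy reduction (the dict
contents and the BUCKET_UPPER_BOUND value are illustrative):

import tensorflow as tf

BUCKET_UPPER_BOUND = 5

def preprocess():
    return {'image': tf.ones([2, 2]), 'label': tf.constant(1)}

def null_like():
    # Must mirror preprocess(): same keys and dtypes (and here, shapes).
    return {'image': tf.zeros([2, 2]), 'label': tf.constant(0)}

max_bucket = tf.constant(10)  # pretend an extractor filtered this example
out = tf.cond(tf.less(max_bucket, BUCKET_UPPER_BOUND), preprocess, null_like)
print(out['label'].numpy())  # 0: preprocess() was skipped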
Example #11
    def ExtractUsingExtractors(self, record):
        """Extracts Tensors from a tf.Example record using self.extractors.

    Args:
      record: A tf.Example input to pass to tf.parse_single_example.

    Returns:
      A tuple of tensors:

      - bucket_id: A scalar int Tensor.
      - extracted: a NestedMap of Tensors extracted.
    """
        feature_map = {}
        context_map = {}
        self._extractors.Transform(
            lambda e: feature_map.update(e.FeatureMap()))
        if self.params.record_type == 'SEQUENCE_EXAMPLE':
            self._extractors.Transform(
                lambda e: context_map.update(e.ContextMap()))

        if self.params.record_type not in _PARSING_FUNCTIONS:
            raise ValueError('Invalid record_type: {}'.format(
                self.params.record_type))
        parsing_fn = _PARSING_FUNCTIONS[self.params.record_type]
        if self.params.record_type == 'SEQUENCE_EXAMPLE':
            features = parsing_fn(record, feature_map, context_map)
        else:
            features = parsing_fn(record, feature_map)

        def ExtractAndFilter(e):
            with tf.name_scope(e.params.name):
                with tf.name_scope('extract'):
                    extracted = e.Extract(features)
                with tf.name_scope('filter'):
                    bucket = e.Filter(extracted)
            return bucket, extracted

        bucket_extracted = self._extractors.Transform(ExtractAndFilter)
        buckets = bucket_extracted.Transform(lambda x: x[0])
        extracted = bucket_extracted.Transform(lambda x: x[1])

        # Return the maximum bucket id so that any extractor can decide whether
        # to filter the entire example.
        max_bucket = tf.reduce_max(buckets.Flatten())

        def NullLike():
            """A function to return the same Tensor signature as Preprocess.

      This is necessary for the tf.cond() to avoid executing the preprocessor
      for examples that are going to be dropped because it exceeds the bucket
      limit; tf.cond() requires that the output of both branches yields the same
      structure.

      Returns:
        A structure with the same Tensor dtype and shape as the output of
        Preprocess.
      """
            shapes = self.Shape()
            rets = [
                tf.zeros(dtype=dtype, shape=shape)
                for dtype, shape in zip(self.DType().Flatten(),
                                        shapes.Flatten())
            ]
            return shapes.Pack(rets)

        def Preprocess(extracted):
            for key, preprocessor in zip(self.params.preprocessors_order,
                                         self.preprocessors):
                with tf.name_scope(key), tf.name_scope(
                        preprocessor.params.name):
                    extracted = preprocessor.TransformFeatures(extracted)
            return extracted

        # If the extractor wants to filter the example, don't run the preprocessor.
        #
        # Preprocessors can then assume that only examples that pass filtering will
        # be executed.
        final_output = tf.cond(tf.less(max_bucket, BUCKET_UPPER_BOUND),
                               lambda: Preprocess(extracted), NullLike)

        return max_bucket, final_output
Example #12
    def _InputBatch(self):
        np.random.seed(1)
        bs, sl = 10, 7
        src_ids = tf.constant(
            np.random.randint(low=0,
                              high=8192 - 1,
                              size=[bs, sl],
                              dtype=np.int32))
        tgt_ids = tf.constant(
            np.random.randint(low=0,
                              high=8192 - 1,
                              size=[bs, sl],
                              dtype=np.int32))
        tgt_labels = tf.constant(
            np.random.randint(low=0,
                              high=8192 - 1,
                              size=[bs, sl],
                              dtype=np.int32))
        tgt_weights = tf.constant(np.ones(shape=[bs, sl], dtype=np.float32))

        src_paddings = tf.zeros([bs, sl])
        tgt_paddings = tf.zeros([bs, sl])

        ret = py_utils.NestedMap()
        ret.src = py_utils.NestedMap()
        ret.tgt = py_utils.NestedMap()

        if self.params.split:
            src_ids = tf.split(src_ids, 2, 0)
            src_paddings = tf.split(src_paddings, 2, 0)
            tgt_ids = tf.split(tgt_ids, 2, 0)
            tgt_labels = tf.split(tgt_labels, 2, 0)
            tgt_paddings = tf.split(tgt_paddings, 2, 0)
            tgt_weights = tf.split(tgt_weights, 2, 0)

            def _Pick(split):
                return tf.cond(
                    tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
                    lambda: split[0], lambda: split[1])

            ret.src.ids = _Pick(src_ids)
            ret.src.paddings = _Pick(src_paddings)
            ret.tgt.ids = _Pick(tgt_ids)
            ret.tgt.labels = _Pick(tgt_labels)
            ret.tgt.paddings = _Pick(tgt_paddings)
            ret.tgt.weights = _Pick(tgt_weights)
        else:
            ret.src.ids = src_ids
            ret.src.paddings = src_paddings
            ret.tgt.ids = tgt_ids
            ret.tgt.labels = tgt_labels
            ret.tgt.paddings = tgt_paddings
            ret.tgt.weights = tgt_weights

        return ret
Example #13
 def _Wrap(fn, x, y):
     if not self._cond_is_finite:
         return fn(x, y)
     return tf.cond(cond, lambda: fn(x, y), lambda: x)
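The guard in _Wrap, standalone: apply fn only when cond holds (e.g. when all
gradients are finite), otherwise pass x through untouched. A runnable toy:

import tensorflow as tf

def wrap(fn, x, y, cond):
    return tf.cond(cond, lambda: fn(x, y), lambda: x)

print(wrap(tf.add, tf.constant(1.0), tf.constant(2.0), tf.constant(True)))   # 3.0
print(wrap(tf.add, tf.constant(1.0), tf.constant(2.0), tf.constant(False)))  # 1.0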
Example #14
  def ProcessFeatures(self, features):
    """Process extracted features.

    Args:
      features: A dict of extracted Tensors from the records.

    Returns:
      A tuple of tensors:

      - bucket_id: A scalar int Tensor.
      - extracted: a NestedMap of Tensors extracted.
    """
    def ExtractAndFilter(e):
      with tf.name_scope(e.params.name):
        with tf.name_scope('extract'):
          extracted = e.Extract(features)
        with tf.name_scope('filter'):
          bucket = e.Filter(extracted)
      return bucket, extracted

    bucket_extracted = self._extractors.Transform(ExtractAndFilter)
    buckets = bucket_extracted.Transform(lambda x: x[0])
    extracted = bucket_extracted.Transform(lambda x: x[1])

    # Return the maximum bucket id so that any extractor can decide whether
    # to filter the entire example.
    max_bucket = tf.reduce_max(buckets.Flatten())

    def NullLike():
      """A function to return the same Tensor signature as Preprocess.

      This is necessary for the tf.cond() to avoid executing the preprocessor
      for examples that are going to be dropped because they exceed the bucket
      limit; tf.cond() requires that the outputs of both branches yield the
      same structure.

      Returns:
        A structure with the same Tensor dtype as the output of
        Preprocess.
      """
      shapes = self.Shape()
      rets = []
      for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
        if shape.is_fully_defined():
          rets += [tf.zeros(dtype=dtype, shape=shape)]
        else:
          rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
      return shapes.Pack(rets)

    def Preprocess(extracted):
      for key, preprocessor in zip(self.params.preprocessors_order,
                                   self.preprocessors):
        with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
          extracted = preprocessor.TransformFeatures(extracted)
      return extracted

    # If the extractor wants to filter the example, don't run the preprocessor.
    #
    # Preprocessors can then assume that only examples that pass filtering will
    # be executed.
    #
    # Note that the NullLike branch may return tensors with shapes different
    # from self.Shape().
    final_output = tf.cond(
        tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
        NullLike)

    return max_bucket, final_output
Example #15
        def Callback(theta, encoder_outputs, step_ids, states,
                     num_hyps_per_beam, *args, **kwargs):
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                # [num_hyps_per_beam, src_batch]
                tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                # num_hyps_per_beam * src_batch
                tgt_batch = tf.shape(step_ids)[0]
                return tf.reshape(tensor, [tgt_batch])

            if biased:
                labels = encoder_outputs.targets.labels
                weights = encoder_outputs.targets.weights

                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from the previous step.
                    # TODO(navari): Consider updating consistent only if
                    # weights > 0. Then re-evaluate the need for
                    # bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is
                    # overridden later.
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0,
                        tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    # [tgt_batch, vocab_size]
                    label_probs = tf.one_hot(label, vocab_size,
                                             dtype=py_utils.FPropDtype(p))
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent

                def NoApplyBias():
                    """No-op. Return original log_probs and consistent."""
                    return bs_results.log_probs, states.consistent

                log_probs, consistent = tf.cond(
                    tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias,
                    ApplyBias)
                bs_results.log_probs = log_probs
                out_states.consistent = consistent

            if stochastic:
                log_probs = bs_results.log_probs

                def PerturbedLogProbs():
                    # STEP 1: Perform top-k filtering. This is a performance
                    # optimization to avoid sorting the entire `log_probs`,
                    # which is prohibitively slow.
                    top_k = tf.math.top_k(log_probs, k, sorted=True)
                    # shape: [tgt_batch, k]
                    top_k_log_probs = top_k.values
                    # shape: [tgt_batch, k]
                    top_k_ids = top_k.indices

                    # STEP 2: Perform top-p filtering.
                    # shape: [tgt_batch]
                    top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
                    top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
                    top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
                    # shape: [tgt_batch, k]
                    filtered_top_k_log_probs = _KeepTopP(
                        top_k_log_probs, top_p_threshold)

                    # STEP 3: Perturb cumulative log-probs.
                    # shape: [tgt_batch, 1]
                    last_cumulative_log_probs = states.cumulative_log_probs
                    # shape: [tgt_batch, 1]
                    last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
                    # Compute cumulative log-probs of the current step.
                    # shape: [tgt_batch, k]
                    cumulative_log_probs = (last_cumulative_log_probs +
                                            filtered_top_k_log_probs)
                    # Perturb cumulative log-probs by Gumbel noises under the condition
                    # that the max of the new perturbed log-probs is equal to
                    # perturbed_cumulative_log_probs of the previous step.
                    # shape: [tgt_batch, k]
                    new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
                        cumulative_log_probs,
                        last_perturbed_cumulative_log_probs,
                        encoder_outputs.stochastic_beam_search.seed, time_step,
                        encoder_outputs.stochastic_beam_search.src_ids,
                        encoder_outputs.stochastic_beam_search.src_paddings)

                    # STEP 4: Compute updated log_probs. This step is necessary because
                    # the output of PreBeamSearchStepCallback must be "per-step"
                    # log-probs, whereas so far "cumulative" log-probs have been computed.
                    # shape: [tgt_batch, k]
                    updated_top_k_log_probs = (
                        new_perturbed_cumulative_log_probs -
                        last_perturbed_cumulative_log_probs)
                    # Convert to the shape [tgt_batch, vocab_size].
                    updated_log_probs = tf.fill(
                        tf.shape(log_probs),
                        tf.constant(LARGE_NEGATIVE_NUMBER,
                                    dtype=log_probs.dtype))
                    updated_log_probs = _BatchScatter(updated_log_probs,
                                                      top_k_ids,
                                                      updated_top_k_log_probs)

                    return (updated_log_probs,
                            py_utils.NestedMap(
                                new_perturbed_cumulative_log_probs=
                                new_perturbed_cumulative_log_probs,
                                top_k_log_probs=top_k_log_probs,
                                top_k_ids=top_k_ids,
                            ))

                (bs_results.log_probs, out_states.tmp_states) = tf.cond(
                    encoder_outputs.stochastic_beam_search.enable,
                    PerturbedLogProbs,
                    # No-op.
                    lambda: (bs_results.log_probs, states.tmp_states))
                # These states are not updated here but will be updated in
                # PostBeamSearchStepCallback since doing so requires the knowledge of
                # the next step IDs.
                out_states.cumulative_log_probs = states.cumulative_log_probs
                out_states.perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs

            return bs_results, out_states
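For reference, _SampleGumbelWithMax implements the "Gumbel with maximum"
trick from stochastic beam search (Kool et al., 2019): perturb each child's
cumulative log-prob with Gumbel noise, then shift so the maximum equals the
parent's perturbed score. A hedged standalone sketch, not the Lingvo
implementation (which also derives its seed from the step and source ids):

import tensorflow as tf

def sample_gumbel_with_max(log_probs, target_max, seed=None):
    # Standard Gumbel perturbation of each log-prob.
    u = tf.random.uniform(tf.shape(log_probs), minval=1e-9, maxval=1.0,
                          seed=seed)
    g = log_probs - tf.math.log(-tf.math.log(u))
    z = tf.reduce_max(g, axis=-1, keepdims=True)
    # Shift so that the max over the last axis equals target_max
    # (numerically stable form from Kool et al., 2019).
    v = target_max - g + tf.math.log1p(-tf.exp(g - z))
    return target_max - tf.math.softplus(v)

lp = tf.math.log(tf.constant([[0.5, 0.3, 0.2]]))
perturbed = sample_gumbel_with_max(lp, tf.constant([[0.0]]))
print(tf.reduce_max(perturbed, axis=-1))  # == target_max, i.e. [0.0]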