def Filter(self, outputs):
        """Optionally filters the data based on context info."""
        p = self.params
        if p.equality_filters is None:
            return 1

        allowed_example = tf.convert_to_tensor(True)
        for filter_key, filter_values in p.equality_filters:
            if filter_key not in outputs:
                raise ValueError(
                    'Filter key `{}` not found in extracted data.'.format(
                        filter_key))
            has_allowed_data = tf.reduce_any(
                tf.equal(outputs[filter_key], filter_values))
            allowed_example = tf.logical_and(allowed_example, has_allowed_data)

        # Bucket key 1 keeps the example; filtered examples are pushed past
        # BUCKET_UPPER_BOUND so the bucketing input pipeline drops them.
        not_allowed_example = 1 - tf.cast(allowed_example, tf.int32)
        return 1 + (not_allowed_example * input_extractor.BUCKET_UPPER_BOUND)
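A minimal, self-contained sketch of the membership test Filter relies on (the
tensor values and filter key contents here are made up): tf.equal broadcasts a
feature tensor against the allowed values, and tf.reduce_any collapses the
result to a single boolean.

import tensorflow as tf

context_feature = tf.constant([3, 7, 7])   # hypothetical extracted feature
allowed_values = tf.constant([2, 7])       # hypothetical filter values
has_allowed_data = tf.reduce_any(
    tf.equal(context_feature[:, None], allowed_values))
# has_allowed_data == True, since 7 appears among the allowed values.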
    def _check_paddings(self, paddings):
        with tf.name_scope('check_paddings'):
            unpacked_paddings = tf.unstack(paddings)

            non_decr = []
            for t in unpacked_paddings:
                non_d = tf.is_non_decreasing(t)
                non_decr.append(non_d)
            all_non_decr = tf.stack(non_decr)

            paddings = py_utils.with_dependencies([
                tf.assert_equal(tf.reduce_any(tf.equal(paddings, 0.0)),
                                True,
                                message='must have at least one zero value.'),
                tf.assert_equal(
                    all_non_decr, True, message='must be non-decreasing')
            ], paddings)
            return paddings
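A toy illustration of the invariant _check_paddings asserts, assuming the usual
convention that paddings are 0.0 for real positions and 1.0 for padded ones:
every row must contain at least one zero and must be non-decreasing, so all
real tokens precede all padding.

import tensorflow as tf

paddings = tf.constant([[0., 0., 1., 1.],
                        [0., 1., 1., 1.]])
print(tf.math.is_non_decreasing(paddings[0]))   # True
print(tf.reduce_any(tf.equal(paddings, 0.0)))   # True: at least one zero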
Example #3
  def __call__(self, inputs: ExamplePairs) -> tf.Tensor:
    drop_masks = []
    for feature_name in self._drop_on_match:
      query_ids = inputs.query_examples.get(feature_name)
      result_ids = inputs.result_examples.get(feature_name)
      if query_ids is None:
        raise ValueError('No feature {} in query batch'.format(feature_name))
      if result_ids is None:
        raise ValueError('No feature {} in result batch'.format(feature_name))
      query_ids.shape.assert_is_compatible_with([inputs.query_batch_size])
      result_ids.shape.assert_is_compatible_with([inputs.result_batch_size])
      drop_masks.append(tf.equal(query_ids[:, None], result_ids[None, :]))

    labels = tf.cast(inputs.correspondences, dtype=tf.int32)
    if drop_masks:
      any_drop_condition_met = tf.reduce_any(drop_masks, axis=0)
      labels = _IgnorePairsWhere(
          ~inputs.correspondences & any_drop_condition_met, labels)
    return labels
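The drop masks above are built with the same broadcasting trick; a standalone
sketch with made-up ids (tf.reduce_any over axis 0 drops a pair if ANY of the
per-feature masks matches):

import tensorflow as tf

query_ids = tf.constant([1, 2])        # shape [query_batch_size]
result_ids = tf.constant([2, 3, 1])    # shape [result_batch_size]
mask = tf.equal(query_ids[:, None], result_ids[None, :])   # shape [2, 3]
# mask == [[False, False, True],
#          [True,  False, False]]
any_drop_condition_met = tf.reduce_any(tf.stack([mask]), axis=0)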
Example #4
    def IsSpecialExample(task_ids, special_task_ids):
      """A utility function indicates whether inputs belong to specific tasks.

      Args:
        task_ids: Task ids for the input batch. Tensor of shape [batch].
        special_task_ids: A list of specified task ids.

      Returns:
        A boolean tensor of shape [batch] indicating whether each sample in
        the batch belongs to one of the specified tasks.
      """
      batch_size = py_utils.GetShape(task_ids)[0]
      return tf.reduce_any(
          tf.equal(
              tf.expand_dims(task_ids, -1),
              tf.cast(
                  tf.broadcast_to(
                      special_task_ids,
                      [batch_size, len(special_task_ids)]), tf.int32)), -1)
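The same test works standalone with implicit broadcasting in place of the
explicit tf.broadcast_to (the task ids below are made up):

import tensorflow as tf

task_ids = tf.constant([0, 3, 5], dtype=tf.int32)
special_task_ids = tf.constant([3, 4], dtype=tf.int32)
is_special = tf.reduce_any(
    tf.equal(tf.expand_dims(task_ids, -1), special_task_ids), -1)
# is_special == [False, True, False]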
Example #5
 def _ShouldMerge(unused_tokens, candidates):
     """Merge until not possible, or we abort early according to merge_prob."""
     return tf.logical_and(
         tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
         tf.random.uniform([]) < self._merge_prob)
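Standalone, the loop condition reads: keep merging while any merge candidate
remains AND a uniform draw falls below merge_prob. NO_TOKEN and merge_prob are
assumed values here; the real constants live in the surrounding class.

import tensorflow as tf

NO_TOKEN = -1      # assumed sentinel for "no candidate"
merge_prob = 0.9   # assumed merge probability
candidates = tf.constant([NO_TOKEN, 5, NO_TOKEN])
should_merge = tf.logical_and(
    tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
    tf.random.uniform([]) < merge_prob)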
Example #6
    def AssignAnchors(self,
                      anchor_bboxes,
                      gt_bboxes,
                      gt_bboxes_labels,
                      gt_bboxes_mask,
                      foreground_assignment_threshold=0.5,
                      background_assignment_threshold=0.35,
                      background_class_id=0,
                      force_match=True,
                      similarity_fn=None):
        """Assigns anchors to bboxes using a similarity function (SSD-based).

    Each anchor box is assigned to the top matching ground truth box.
    Ground truth boxes can be assigned to multiple anchor boxes.

    Assignments can result in 3 outcomes:

      - Positive assignment (if score >= foreground_assignment_threshold):
        assigned_gt_labels will reflect the assigned box label and
        assigned_cls_mask will be set to 1.0
      - Background assignment (if score <= background_assignment_threshold):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 1.0
      - Ignore assignment (otherwise):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 0.0

    The detection loss function would usually:

      - Use assigned_cls_mask for weighting the classification loss. The mask
        is set such that the loss applies to foreground and background
        assignments only - ignored anchors will be set to 0.
      - Use assigned_reg_mask for weighting the regression loss. The mask is set
        such that the loss applies to foreground assignments only.

    The thresholds (foreground_assignment_threshold and
    background_assignment_threshold) should be tuned per dataset.

    TODO(jngiam): Consider having a separate threshold for regression boxes; a
    separate threshold is used in PointRCNN.

    Args:
      anchor_bboxes: tf.float32. [A, 7], where [..., :] corresponds to box
        parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes: tf.float32. [G, 7], where [..., :] corresponds to ground truth
        box parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes_labels: tensor with shape [G]. Ground truth labels for each
        bounding box.
      gt_bboxes_mask: tensor with shape [G]. Mask for ground truth boxes, 1 iff
        the gt_bbox is a real bbox.
      foreground_assignment_threshold: Similarity score threshold for assigning
        foreground bounding boxes; scores need to be >=
        foreground_assignment_threshold to be assigned to foreground.
      background_assignment_threshold: Similarity score threshold for assigning
        background bounding boxes; scores need to be <=
        background_assignment_threshold to be assigned to background.
      background_class_id: class id to be assigned to anchors_gt_class if no
        anchor boxes match.
      force_match: Boolean specifying if force matching is enabled. If
        force matching is enabled, then anchors that are the highest-scoring
        match for a ground-truth box are considered foreground matches as
        long as their similarity score is > 0.
      similarity_fn: Function that computes a similarity score (e.g., IOU)
        between pairs of bounding boxes. This function should take in two
        tensors corresponding to anchor and ground-truth bboxes, and return a
        matrix [A, G] with the similarity score between each pair of bboxes.
        The score must be non-negative, with greater scores representing more
        similar boxes. The fore/background_assignment_thresholds will be
        applied to this score to determine if an anchor is foreground,
        background or ignored. If set to None, the function will default to
        IOU2DRotatedBoxes.

    Returns:
      NestedMap with the following keys

      - assigned_gt_idx: shape [A]. Index of the ground truth box assigned to
        each anchor. Anchors not assigned to a ground truth box will have the
        index set to -1.
      - assigned_gt_bbox: shape [A, 7] bbox parameters assigned to each anchor.
      - assigned_gt_similarity_score: shape [A] (iou) score between the anchor
        and the gt bbox.
      - assigned_gt_labels: shape [A] label assigned to bbox.
      - assigned_cls_mask: shape [A] mask for classification loss per anchor.
        This should be 1.0 if the anchor has a foreground or background
        assignment; otherwise, it will be assigned to 0.0.
      - assigned_reg_mask: shape [A] mask for regression loss per anchor.
        This should be 1.0 if the anchor has a foreground assignment;
        otherwise, it will be assigned to 0.0.
        Note: background anchors do not have regression targets.
    """
        if similarity_fn is None:
            similarity_fn = self.IOU2DRotatedBoxes

        # Shape validation.
        anchor_bboxes = py_utils.HasShape(anchor_bboxes, [-1, 7])
        num_anchor_bboxes, _ = py_utils.GetShape(anchor_bboxes, 2)
        gt_bboxes = py_utils.HasShape(gt_bboxes, [-1, 7])
        num_gt_bboxes, _ = py_utils.GetShape(gt_bboxes, 2)

        # Compute similarity score and reduce max by anchors and by ground-truth.
        similarity_score = similarity_fn(anchor_bboxes, gt_bboxes)
        similarity_score = py_utils.HasShape(
            similarity_score, [num_anchor_bboxes, num_gt_bboxes])

        # Reduce over ground-truth boxes, so we have the max score per anchor.
        anchor_max_score = tf.reduce_max(similarity_score, axis=1)
        anchor_max_idx = tf.argmax(similarity_score, axis=1)

        if force_match:
            # Reduce over anchors, so we have the max score per ground truth box.
            gt_max_score = tf.reduce_max(similarity_score,
                                         axis=0,
                                         keepdims=True)

            # Force matches occur when the top matching gt bbox for an anchor is the
            # top matching anchor for the gt bbox. When force matching, we match
            # these boxes as long as their similarity score exceeds 0.
            force_matches = (
                tf.equal(similarity_score, gt_max_score)
                & tf.equal(similarity_score, anchor_max_score[..., tf.newaxis])
                & tf.greater(similarity_score, 0.)
                & tf.cast(gt_bboxes_mask[tf.newaxis, ...], tf.bool))
            force_match_indicator = tf.reduce_any(force_matches, axis=1)
            force_match_idx = tf.argmax(tf.cast(force_matches, tf.int32),
                                        axis=1)

            # In assigning foreground/background anchors later, force_match_indicator
            # is used to determine which anchors are force foreground, and the index
            # assigned will be taken from anchor_max_idx.

            # Force matches must also be the max scoring gt bbox per anchor.
            # We overwrite anchor_max_idx to ensure that the right match is done.
            anchor_max_idx = tf.where(force_match_indicator, force_match_idx,
                                      anchor_max_idx)

        # Ensure that max score boxes are not padded boxes by setting score to 0
        # for boxes that are padded.
        gathered_mask = tf.batch_gather(gt_bboxes_mask, anchor_max_idx)
        anchor_max_score = tf.where(tf.equal(gathered_mask, 1),
                                    anchor_max_score,
                                    tf.zeros_like(anchor_max_score))

        # Boolean tensors corresponding to whether an anchor is background or
        # foreground based on thresholding.
        background_anchors = tf.less_equal(anchor_max_score,
                                           background_assignment_threshold)
        foreground_anchors = tf.greater_equal(anchor_max_score,
                                              foreground_assignment_threshold)
        if force_match:
            # Background anchors are below threshold and not force matches.
            background_anchors &= ~force_match_indicator
            # Foreground anchors are above thresholds or force matches.
            foreground_anchors |= force_match_indicator

        # Add dummy background bbox to gt_boxes to facilitate batch gather.
        dummy_bbox = tf.constant([[0, 0, 0, 1, 1, 1, 0]], dtype=tf.float32)

        # Since we are concatenating the dummy bbox, the index corresponds to the
        # number of boxes.
        dummy_bbox_idx = py_utils.GetShape(gt_bboxes, 1)[0]

        gt_bboxes = tf.concat([gt_bboxes, dummy_bbox], axis=0)
        gt_bboxes_labels = tf.concat([gt_bboxes_labels, [background_class_id]],
                                     axis=0)

        # Gather indices so that all foreground boxes are gathered from gt_bboxes,
        # while all background and ignore boxes gather the dummy_bbox.
        anchor_gather_idx = tf.where(
            foreground_anchors, anchor_max_idx,
            tf.constant(dummy_bbox_idx,
                        shape=py_utils.GetShape(anchor_max_idx),
                        dtype=anchor_max_idx.dtype))

        # Gather the bboxes and weights.
        assigned_gt_bbox = tf.batch_gather(gt_bboxes, anchor_gather_idx)
        assigned_gt_labels = tf.batch_gather(gt_bboxes_labels,
                                             anchor_gather_idx)

        # Set masks for classification and regression losses.
        assigned_cls_mask = tf.cast(background_anchors | foreground_anchors,
                                    tf.float32)
        assigned_reg_mask = tf.cast(foreground_anchors, tf.float32)

        # Set assigned_gt_idx such that dummy boxes have idx = -1.
        assigned_gt_idx = tf.where(tf.equal(anchor_gather_idx, dummy_bbox_idx),
                                   tf.ones_like(anchor_gather_idx) * -1,
                                   anchor_gather_idx)
        assigned_gt_idx = tf.cast(assigned_gt_idx, tf.int32)

        return py_utils.NestedMap(
            assigned_gt_idx=assigned_gt_idx,
            assigned_gt_bbox=assigned_gt_bbox,
            assigned_gt_similarity_score=anchor_max_score,
            assigned_gt_labels=assigned_gt_labels,
            assigned_cls_mask=assigned_cls_mask,
            assigned_reg_mask=assigned_reg_mask)
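A toy walk-through of the three assignment outcomes, using the default
thresholds from the signature above and made-up per-anchor max similarity
scores:

import tensorflow as tf

anchor_max_score = tf.constant([0.7, 0.4, 0.2])
foreground_anchors = tf.greater_equal(anchor_max_score, 0.5)   # [T, F, F]
background_anchors = tf.less_equal(anchor_max_score, 0.35)     # [F, F, T]
assigned_cls_mask = tf.cast(foreground_anchors | background_anchors,
                            tf.float32)                        # [1., 0., 1.]
assigned_reg_mask = tf.cast(foreground_anchors, tf.float32)    # [1., 0., 0.]
# The middle anchor (score 0.4) lies between the thresholds and is ignored:
# it contributes to neither the classification nor the regression loss.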
Example #7
    def _Extract(self, features):
        p = self.params

        source_id = py_utils.HasShape(features['image/source_id'], [])
        xmin = _Dense(features['object/image/bbox/xmin'])
        xmax = _Dense(features['object/image/bbox/xmax'])
        ymin = _Dense(features['object/image/bbox/ymin'])
        ymax = _Dense(features['object/image/bbox/ymax'])

        # 2d bounding box in image coordinates.
        bboxes = tf.stack([ymin, xmin, ymax, xmax], axis=1)
        bboxes_count = tf.shape(bboxes)[0]
        bboxes = py_utils.PadOrTrimTo(bboxes, [p.max_num_objects, 4])

        bboxes_padding = 1.0 - py_utils.PadOrTrimTo(tf.ones([bboxes_count]),
                                                    [p.max_num_objects])

        dim_xyz = tf.reshape(_Dense(features['object/velo/bbox/dim_xyz']),
                             [-1, 3])
        loc_xyz = tf.reshape(_Dense(features['object/velo/bbox/xyz']), [-1, 3])
        phi = tf.reshape(_Dense(features['object/velo/bbox/phi']), [-1, 1])
        # bboxes_3d is in [x, y, z, dx, dy, dz, phi].
        bboxes_3d = tf.concat([loc_xyz, dim_xyz, phi], axis=1)

        cx, cy, _, dx, dy, _, _ = tf.unstack(bboxes_3d, num=7, axis=-1)
        bboxes_td = tf.stack(
            [cy - dy / 2, cx - dx / 2, cy + dy / 2, cx + dx / 2], axis=-1)
        bboxes_td = py_utils.PadOrTrimTo(bboxes_td, [p.max_num_objects, 4])

        has_3d_info = tf.cast(_Dense(features['object/has_3d_info']),
                              tf.float32)
        bboxes_3d_mask = py_utils.PadOrTrimTo(has_3d_info, [p.max_num_objects])
        bboxes_td_mask = bboxes_3d_mask

        # Fill in difficulties from bounding box height, truncation and occlusion.
        bb_height = ymax - ymin
        box_image_height = py_utils.PadOrTrimTo(bb_height, [p.max_num_objects])
        box_image_height *= bboxes_3d_mask

        # Occlusion: 0 to 3 indicating occlusion level. 0 means fully visible,
        # 1 means partly occluded, 2 means largely occluded, and 3 means
        # unknown.
        occlusion = tf.reshape(_Dense(features['object/occlusion']), [-1])
        occlusion = tf.cast(occlusion, tf.float32)
        occlusion = py_utils.PadOrTrimTo(occlusion, [p.max_num_objects])
        occlusion *= bboxes_3d_mask

        # Truncation: 0 -> not truncated, 1.0 -> truncated
        truncation = tf.reshape(_Dense(features['object/truncation']), [-1])
        truncation = py_utils.PadOrTrimTo(truncation, [p.max_num_objects])
        truncation *= bboxes_3d_mask

        difficulties = ComputeKITTIDifficulties(box_image_height, occlusion,
                                                truncation)
        difficulties = py_utils.PadOrTrimTo(difficulties, [p.max_num_objects])

        # Make a batch axis to call BBoxCorners, and take the first result back.
        bbox3d_corners = geometry.BBoxCorners(bboxes_3d[tf.newaxis, ...])[0]

        # Project the 3D bbox to the image plane.
        velo_to_image_plane = features['transform/velo_to_image_plane']
        bboxes3d_proj_to_image_plane = geometry.PointsToImagePlane(
            tf.reshape(bbox3d_corners, [-1, 3]), velo_to_image_plane)

        # Output is [num_objects, 8 corners per object, (x, y)].
        bboxes3d_proj_to_image_plane = tf.reshape(bboxes3d_proj_to_image_plane,
                                                  [-1, 8, 2])
        bboxes3d_proj_to_image_plane = py_utils.PadOrTrimTo(
            bboxes3d_proj_to_image_plane, [p.max_num_objects, 8, 2])

        texts = features['object/label'].values
        labels = ops.static_map_string_int(x=texts,
                                           keys=self.KITTI_CLASS_NAMES)

        labels = py_utils.PadOrTrimTo(labels, [p.max_num_objects])
        texts = py_utils.PadOrTrimTo(texts, [p.max_num_objects])

        # Filter labels by setting bboxes_padding, bboxes_3d_mask, and
        # bboxes_td_mask appropriately.
        if p.filter_labels is not None:
            valid_labels = tf.constant([p.filter_labels])
            bbox_mask = tf.reduce_any(tf.equal(tf.expand_dims(labels, 1),
                                               valid_labels),
                                      axis=1)
            bbox_mask = tf.cast(bbox_mask, tf.float32)
            bboxes_padding = 1 - bbox_mask * (1 - bboxes_padding)
            filtered_bboxes_3d_mask = bboxes_3d_mask * bbox_mask
            bboxes_td_mask *= bbox_mask
        else:
            filtered_bboxes_3d_mask = bboxes_3d_mask

        # Placeholder for counting the number of laser points that reside within
        # each 3-d bounding box. This must be filled in outside of this function
        # based on the loaded 3-d laser points.
        bboxes_3d_num_points = tf.zeros([p.max_num_objects], dtype=tf.int32)
        bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,
                                                    [p.max_num_objects])

        # Pad bboxes_3d.
        bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [p.max_num_objects, 7])

        return py_utils.NestedMap(
            source_id=source_id,
            bboxes_count=bboxes_count,
            bboxes=bboxes,
            bboxes_padding=bboxes_padding,
            bboxes_3d=bboxes_3d,
            bboxes_3d_mask=filtered_bboxes_3d_mask,
            unfiltered_bboxes_3d_mask=bboxes_3d_mask,
            bboxes3d_proj_to_image_plane=bboxes3d_proj_to_image_plane,
            bboxes_td=bboxes_td,
            bboxes_td_mask=bboxes_td_mask,
            bboxes_3d_num_points=bboxes_3d_num_points,
            labels=labels,
            texts=texts,
            box_image_height=box_image_height,
            occlusion=occlusion,
            truncation=truncation,
            difficulties=difficulties)
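The top-down boxes above come from a simple center/dimension-to-corners
conversion (the rotation phi is ignored here, as in the snippet); a standalone
check with one made-up box:

import tensorflow as tf

bboxes_3d = tf.constant([[10., 4., 1., 2., 6., 1.5, 0.]])  # x,y,z,dx,dy,dz,phi
cx, cy, _, dx, dy, _, _ = tf.unstack(bboxes_3d, num=7, axis=-1)
bboxes_td = tf.stack(
    [cy - dy / 2, cx - dx / 2, cy + dy / 2, cx + dx / 2], axis=-1)
# bboxes_td == [[1., 9., 7., 11.]]  (ymin, xmin, ymax, xmax)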
Example #8

  def _Extract(self, features):
    p = self.params
    # Label values match the proto enum car.open_dataset.Label.Type. The value
    # range is [1..4] for non-background labels.
    labels = tf.cast(_Dense(features['labels']), tf.int32)
    labels = py_utils.PadOrTrimTo(labels, [p.max_num_objects])
    label_ids = tf.reshape(_Dense(features['label_ids'], ''), [-1])
    label_ids = py_utils.PadOrTrimTo(label_ids, [p.max_num_objects], '')
    bboxes_3d = tf.reshape(_Dense(features['bboxes_3d']), [-1, 7])
    bboxes_3d_mask = tf.ones([tf.shape(bboxes_3d)[0]])
    bboxes_3d_num_points = tf.cast(
        _Dense(features['bboxes_3d_num_points']), tf.int32)
    bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [p.max_num_objects, 7])
    bboxes_3d_mask = py_utils.PadOrTrimTo(bboxes_3d_mask, [p.max_num_objects])
    bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,
                                                [p.max_num_objects])
    label_metadata = tf.reshape(_Dense(features['label_metadata']), [-1, 4])
    label_metadata = py_utils.PadOrTrimTo(label_metadata,
                                          [p.max_num_objects, 4])

    detection_difficulties = py_utils.PadOrTrimTo(
        tf.cast(_Dense(features['detection_difficulties']), tf.int32),
        [p.max_num_objects])
    single_frame_detection_difficulties = py_utils.PadOrTrimTo(
        tf.cast(
            _Dense(features['single_frame_detection_difficulties']), tf.int32),
        [p.max_num_objects])
    tracking_difficulties = py_utils.PadOrTrimTo(
        tf.cast(_Dense(features['tracking_difficulties']), tf.int32),
        [p.max_num_objects])
    unfiltered_bboxes_3d_mask = bboxes_3d_mask

    if p.filter_labels:
      valid_labels = tf.constant([p.filter_labels])
      bbox_mask = tf.reduce_any(
          tf.equal(tf.expand_dims(labels, 1), valid_labels), axis=1)
      bboxes_3d_mask *= tf.cast(bbox_mask, tf.float32)

    outputs = {
        'labels':
            labels,
        'label_ids':
            label_ids,
        'detection_difficulties':
            detection_difficulties,
        'single_frame_detection_difficulties':
            single_frame_detection_difficulties,
        'tracking_difficulties':
            tracking_difficulties,
        'bboxes_3d':
            bboxes_3d,
        'bboxes_3d_mask':
            bboxes_3d_mask,
        'bboxes_3d_num_points':
            bboxes_3d_num_points,
        'unfiltered_bboxes_3d_mask':
            unfiltered_bboxes_3d_mask,
        'speed':
            label_metadata[:, :2],
        'acceleration':
            label_metadata[:, 2:],
    }

    return py_utils.NestedMap(outputs)
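The label filtering in both extractors uses the same any-match pattern; a
standalone sketch with hypothetical label ids:

import tensorflow as tf

labels = tf.constant([1, 2, 0, 1])
filter_labels = [1]                       # assumed: keep only class 1
valid_labels = tf.constant([filter_labels])
bbox_mask = tf.reduce_any(
    tf.equal(tf.expand_dims(labels, 1), valid_labels), axis=1)
# bbox_mask == [True, False, False, True]; cast to float, it zeroes out the
# 3-d box mask of filtered labels.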
Example #9
 def _ApplyMass(task_id):
     mass_task_ids = tf.constant(self.params.mass_task_ids,
                                 dtype=tf.int32)
     return tf.reduce_any(tf.equal(task_id, mass_task_ids))
Example #10
    def StochasticBeamSearchDecodeBiased(self,
                                         encoder_outputs,
                                         biased,
                                         stochastic,
                                         num_hyps_per_beam_override=0):
        """Performs beam search based decoding with optional advanced features.

    If `biased` is true, the target biasing feature is added. `encoder_outputs`
    must include the following auxiliary inputs:

    - targets.labels: An int tensor of shape [batch, seq] that represents target
      labels to bias beam search towards.
    - targets.paddings: A 0/1 float tensor of shape [batch, seq] where 1 means
      that the corresponding element of targets.labels is a padding.
    - targets.weights: A float tensor of shape [batch, seq] that represents
      biasing weights. 1.0 means forced-decoding.

    If `stochastic` is true, the stochastic beam search feature
    (https://arxiv.org/pdf/1903.06059.pdf) is added. Also, top-p filtering (i.e.
    sampling only from the top-p probability mass of the token distribution) is
    performed to ensure the quality of samples. Note that there are slight
    differences from the implementation in the original paper, e.g., length
    normalization and coverage penalty are applied to the perturbed
    probabilities. `encoder_outputs` must include the following auxiliary
    inputs:

    - stochastic_beam_search.top_p_threshold: A float tensor of shape [batch]
      that represents the thresholds of top-p filtering. Must satisfy
      0 < top_p_threshold <= 1. If the value is low, the quality of samples will
      be high but the diversity will be low. If the value is high, the quality
      of samples will be low but the diversity will be high. Stochastic beam
      search is performed only if top_p_threshold > 0 for some batch items.
    - stochastic_beam_search.seed: An int tensor of shape [batch] that
      represents the random seeds. If the seeds are the same, the same samples
      are drawn.
    - stochastic_beam_search.src_ids: An int tensor of shape [batch, src_seq]
      that represents source IDs. Used for turning the random seed into a
      function of source IDs.
    - stochastic_beam_search.src_paddings: A 0/1 float tensor of shape [batch,
      src_seq] where 1 means that the corresponding element of
      stochastic_beam_search.src_ids is a padding.

    Args:
      encoder_outputs: a NestedMap computed by encoder.
      biased: If true, add the target decoding feature.
      stochastic: If true, add the stochastic beam search feature.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.

    Returns:
      BeamSearchDecodeOutput, a namedtuple containing the decode results.
    """
        p = self.params

        if biased:
            targets = encoder_outputs.targets
            targets.weights *= (1.0 - targets.paddings)

            def PadToTargetSeqLen(tensor, constant):
                length = tf.shape(tensor)[1]
                pad = tf.maximum(0, p.beam_search.target_seq_len - length)
                return tf.pad(tensor, [[0, 0], [0, pad]],
                              constant_values=constant)

            targets.labels = PadToTargetSeqLen(targets.labels, 0)
            targets.weights = PadToTargetSeqLen(targets.weights, 0)

        if stochastic:
            # Determine whether to perform stochastic beam search.
            stochastic_beam_search = encoder_outputs.stochastic_beam_search
            stochastic_beam_search.enable = tf.reduce_any(
                tf.greater(stochastic_beam_search.top_p_threshold, 0.0))

        return self.beam_search.BeamSearchDecode(
            self.theta, encoder_outputs, num_hyps_per_beam_override,
            self._WrapInitBeamSearchStateCallback(biased, stochastic),
            self._WrapPreBeamSearchStepCallback(biased, stochastic),
            self._WrapPostBeamSearchStepCallback(stochastic))
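A standalone sketch of the enable flag computed above: stochastic beam search
is switched on only when at least one batch item has a positive top-p
threshold (the values below are made up):

import tensorflow as tf

top_p_threshold = tf.constant([0.0, 0.9])   # per-batch-item thresholds
enable = tf.reduce_any(tf.greater(top_p_threshold, 0.0))   # True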