예제 #1
0
 def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
   """Emits a summary of the fraction of masked slots that exceed capacity.

   A slot counts as over capacity when `mask * position_in_expert` is greater
   than or equal to `capacity`. The ratio is that count divided by
   `reduce_sum(mask)` (the number of active slots, assuming a 0/1 mask).

   Args:
     mask: Tensor selecting the active slots. Its dtype is reused for the
       over-capacity indicator (presumably a floating-point 0/1 mask —
       TODO confirm; divide_no_nan requires a floating dtype).
     position_in_expert: Tensor multiplied element-wise with `mask`, compared
       against `capacity`.
     capacity: Threshold; positions >= capacity count as over capacity.
     name: Name under which the summary is recorded.
   """
   over_capacity = tf.reduce_sum(
       tf.cast(
           tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
   # divide_no_nan returns 0 instead of NaN when the mask is all zeros; a NaN
   # here would otherwise poison the 'mean' reduction of the summary.
   over_capacity_ratio = tf.math.divide_no_nan(over_capacity,
                                               tf.reduce_sum(mask))
   py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
   tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')
예제 #2
0
        def _MergeCandidates(tokens, candidates):
            """Merge in the reverse binary tree.

            Performs the single merge whose candidate value is smallest, then
            rebuilds only the candidate entries next to the merged position;
            all other candidates are copied from the old list.

            Args:
              tokens: 1-D tensor of current tokens.
              candidates: 1-D tensor where entry i is the token produced by
                merging tokens i and i + 1 (presumably one element shorter
                than `tokens` — TODO confirm).

            Returns:
              A (tokens, candidates) pair after the merge; each is one element
              shorter than its input.
            """
            merge_at = tf.argmin(candidates, output_type=tf.int32)
            # Replace the pair (merge_at, merge_at + 1) with the merged token.
            head_tokens = tokens[:merge_at]
            tail_tokens = tokens[merge_at + 2:]
            tokens = tf.concat(
                [head_tokens, [candidates[merge_at]], tail_tokens], axis=0)

            no_candidates = tf.zeros([0], dtype=candidates.dtype)

            def _LeftCandidates():
                # Untouched candidates left of the merge, followed by the pair
                # that now ends at the merged token.
                head = candidates[:merge_at - 1]
                fresh = _MergeOneToken(tokens, merge_at - 1)
                return tf.concat([head, fresh], axis=0)

            def _RightCandidates():
                # The pair that now starts at the merged token, followed by the
                # untouched candidates right of the merge.
                fresh = _MergeOneToken(tokens, merge_at)
                tail = candidates[merge_at + 2:]
                return tf.concat([fresh, tail], axis=0)

            # No left neighbour when the merge happened at the very start, and
            # no right neighbour when the merged token is now the last one.
            at_start = tf.equal(merge_at, 0)
            at_end = tf.greater_equal(merge_at, tf.size(tokens) - 1)
            left = tf.cond(at_start, lambda: no_candidates, _LeftCandidates)
            right = tf.cond(at_end, lambda: no_candidates, _RightCandidates)
            return tokens, tf.concat([left, right], axis=0)
예제 #3
0
def IsWithinBBox3D(points_3d, bboxes_3d):
    """Checks if points are within a 3-d bbox.

    A point is inside a box when its (x, y) lies inside the box's 2-d
    footprint (tested with IsWithinBBox on the first four corners returned by
    BBoxCorners) and its z lies in [z - dz / 2, z + dz / 2], both ends
    inclusive.

    Args:
      points_3d: [num_points, 3] float32 Tensor specifying points in 3-d space
        as [x, y, z] coordinates.
      bboxes_3d: [num_bboxes, 7] float32 Tensor specifying 3-d bboxes as
        [x, y, z, dx, dy, dz, phi] where x, y and z is the center of the box.

    Returns:
      boolean Tensor of shape [num_points, num_bboxes] indicating whether the
      points belong within each box.
    """
    points_3d = py_utils.HasRank(points_3d, 2)
    points_3d = py_utils.HasShape(points_3d, [-1, 3])
    num_points, _ = py_utils.GetShape(points_3d, 2)

    bboxes_3d = py_utils.HasRank(bboxes_3d, 2)
    bboxes_3d = py_utils.HasShape(bboxes_3d, [-1, 7])
    num_bboxes, _ = py_utils.GetShape(bboxes_3d, 2)

    # Compute the 3-D corners of the bounding boxes.
    bboxes_3d_b = tf.expand_dims(bboxes_3d, 0)
    bbox_corners = BBoxCorners(bboxes_3d_b)
    bbox_corners = py_utils.HasShape(bbox_corners, [1, -1, 8, 3])
    # First four points are the top of the bounding box.
    # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
    #   (x0, y1) <--- (x1, y1)
    #                    ^
    #                    |
    #                    |
    #   (x0, y0) ---> (x1, y0)
    bboxes_2d_corners = bbox_corners[0, :, 0:4, 0:2]
    bboxes_2d_corners = py_utils.HasShape(bboxes_2d_corners, [-1, 4, 2])
    # Determine if points lie within 2-D (x, y) plane for all bounding boxes.
    points_2d = points_3d[:, :2]
    is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)
    is_inside_2d = py_utils.HasShape(is_inside_2d, [num_points, num_bboxes])

    # Determine if points lie within the z-extent of all bounding boxes.
    [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)
    # z and dz are [num_bboxes, 1]; transposing to [1, num_bboxes] lets them
    # broadcast against the [num_points, 1] point heights below, so the
    # limits never need to be tiled to [num_points, num_bboxes].
    z_center = tf.transpose(z)
    z_half_extent = tf.transpose(dz) / 2.0
    z0 = z_center - z_half_extent
    z1 = z_center + z_half_extent
    z_points = tf.expand_dims(points_3d[:, 2], -1)

    is_inside_z = tf.logical_and(
        tf.less_equal(z_points, z1), tf.greater_equal(z_points, z0))
    is_inside_z = py_utils.HasShape(is_inside_z, [num_points, num_bboxes])

    return tf.logical_and(is_inside_z, is_inside_2d)
예제 #4
0
파일: pruning.py 프로젝트: snsun/lingvo
 def maybe_update_masks():
     """Returns a boolean tensor telling whether masks should be updated now.

     True when the global step is inside the configured pruning window and at
     least `pruning_frequency` steps have passed since the last update.
     """
     spec = self._spec
     step = self._global_step
     with tf.name_scope(spec.name):
         past_begin = tf.greater_equal(step, spec.begin_pruning_step)
         # A negative end_pruning_step means pruning never stops.
         before_end = tf.logical_or(
             tf.less_equal(step, spec.end_pruning_step),
             tf.less(spec.end_pruning_step, 0))
         within_window = tf.logical_and(past_begin, before_end)
         next_update_step = tf.add(self._last_update_step,
                                   spec.pruning_frequency)
         update_due = tf.less_equal(next_update_step, step)
         return tf.logical_and(within_window, update_due)
예제 #5
0
def KnnIndices(points, query_points, k, valid_num=None, max_distance=None):
    """k-nearest neighbors of query_points in points.

    The caller should ensure that points[i, :valid_num[i], :] are the
    non-padding points.

    Padding is returned alongside indices. Non-padded points are guaranteed to
    be unique (non-repeated) points from the original non-padded points.

    Padded points arise either from a lack of points (k exceeds valid_num) or
    from points being too far away (exceeding max_distance).

    TODO(weihan,jngiam): For backwards compatibility with PointCNN, if there
    are fewer than k points to select (possibly because of valid_num), the
    points selected will first consist of those in the non-padded points, and
    then those from the padded points. This assumes that the padded points are
    duplications of the original points. PointCNN should be updated to respect
    padding.

    The auxiliary input 'valid_num' marks the number of non-padding points in
    each sample. It is needed because points were randomly duplicated to make
    the input fixed-size; we want to search for k-NN among non-padding points
    first, otherwise the result may degenerate into k duplicates of the query
    point itself.

    Args:
      points: tensor of shape [N, P1, dims].
      query_points: tensor of shape [N, P2, dims]
      k: Integer.
      valid_num: tensor of shape [N,]
      max_distance: float representing the maximum distance that each neighbor
        can be. If there are no points within the distance, then the closest
        point is returned (regardless of distance). If this is set to None,
        then max_distance is not used.

    Returns:
      A pair of tensors:

      - indices: tensor of shape [N, P2, k].
      - padding: tensor of shape [N, P2, k] where 1 represents a padded point,
        and 0 represents an unpadded (real) point.
    """
    num_points = tf.shape(points)[1]
    # [N, P1] boolean mask: True for every slot at or beyond valid_num[i].
    points_padding = (None if valid_num is None else tf.greater_equal(
        tf.range(num_points), tf.expand_dims(valid_num, -1)))
    return NeighborhoodIndices(points,
                               query_points,
                               k,
                               points_padding=points_padding,
                               max_distance=max_distance)
예제 #6
0
파일: pruning.py 프로젝트: snsun/lingvo
    def _update_mask(self, weights, threshold):
        """Updates the mask for a given weight tensor.

        This function takes the k-th largest weight magnitude, with
        k = round(size * (1 - sparsity)) clamped to at least 1, as the current
        threshold. It then blends that with `threshold` using
        `threshold_decay` (an exponential moving average) and masks out
        weights whose magnitude is below the smoothed threshold.

        Args:
          weights: The weight tensor that needs to be masked.
          threshold: The current threshold value. The function will compute a
            new threshold and return the exponential moving average using the
            current value of threshold.

        Returns:
          new_threshold: The new value of the threshold based on weights, and
            sparsity at the current global_step.
          new_mask: A float32 tensor of the same shape as weights containing
            1 where |weight| >= new_threshold and 0 elsewhere.

        Raises:
          ValueError: if sparsity is not defined
        """
        if self._sparsity is None:
            raise ValueError('Sparsity variable undefined')

        sparsity = self._get_sparsity(weights.op.name)
        with tf.name_scope(weights.op.name + '_pruning_ops'):
            abs_weights = tf.abs(weights)
            num_weights = tf.size(abs_weights)
            # Number of weights to keep. Clamp to >= 1: when sparsity is close
            # to 1.0 the rounded count can be 0, and tf.gather(values, -1)
            # below would be out of range (an error on CPU, a silent 0 --
            # i.e. keep everything -- on accelerators).
            k = tf.cast(
                tf.round(tf.cast(num_weights, tf.float32) * (1 - sparsity)),
                tf.int32)
            k = tf.maximum(k, 1)
            # Sort the entire array (descending magnitude).
            values, _ = tf.nn.top_k(tf.reshape(abs_weights, [-1]),
                                    k=num_weights)
            # Grab the (k-1)-th value, i.e. the k-th largest magnitude.
            current_threshold = tf.gather(values, k - 1)
            smoothed_threshold = tf.add_n([
                tf.multiply(current_threshold, 1 - self._spec.threshold_decay),
                tf.multiply(threshold, self._spec.threshold_decay)
            ])

            new_mask = tf.cast(
                tf.greater_equal(abs_weights, smoothed_threshold), tf.float32)

        return smoothed_threshold, new_mask
예제 #7
0
  def Decode(self, input_batch):
    """Decode an input batch, computing predicted bboxes from residuals.

    Runs ComputePredictions, converts the classification logits to sigmoid
    scores (optionally rescaled by a 'score_scaler' entry from ComputeLoss's
    per-example outputs), runs per-class NMS under a CPU device scope and
    merges in the result of `self.output_decoder.ProcessOutputs`.

    Args:
      input_batch: NestedMap of inputs, passed to ComputePredictions,
        _BBoxesAndLogits, ComputeLoss and ProcessOutputs. When
        `p.decode_include_residuals` is set it must also provide
        `anchor_localization_residuals`, `assigned_gt_labels` and
        `anchor_bboxes`.

    Returns:
      NestedMap with 'per_class_predicted_bboxes',
      'per_class_predicted_bbox_scores' (multiplied by the NMS valid mask),
      'per_class_valid_mask', 'visualization_weights', the optional
      'per_class_*' residual/label/logit/anchor entries, the entries returned
      by ProcessOutputs, and 'global_step'.
    """
    p = self.params

    predictions = self.ComputePredictions(self.theta, input_batch)
    bboxes_and_logits = self._BBoxesAndLogits(input_batch, predictions)
    predicted_bboxes = bboxes_and_logits.predicted_bboxes
    batch_size, num_bboxes, _ = py_utils.GetShape(predicted_bboxes, 3)
    classification_logits = bboxes_and_logits.classification_logits
    classification_logits = py_utils.HasShape(
        classification_logits, [batch_size, num_bboxes, p.num_classes])

    classification_scores = tf.sigmoid(classification_logits)

    # The loss value itself is discarded; ComputeLoss is called here only for
    # its per-example outputs, which may carry a 'score_scaler'.
    _, per_example_dict = self.ComputeLoss(self.theta, predictions, input_batch)
    if 'score_scaler' in per_example_dict:
      classification_scores *= per_example_dict['score_scaler']

    with tf.device('/cpu:0'):
      # Decode the predicted bboxes, performing NMS.
      per_cls_idxs, per_cls_bboxes, per_cls_bbox_scores, per_cls_valid_mask = (
          detection_decoder.DecodeWithNMS(
              predicted_bboxes,
              classification_scores,
              nms_iou_threshold=p.nms_iou_threshold,
              score_threshold=p.nms_score_threshold,
              max_boxes_per_class=p.max_nms_boxes,
              use_oriented_per_class_nms=p.use_oriented_per_class_nms))

      # per_cls_valid_mask is [batch, num_classes, num_boxes] Tensor that
      # indicates which boxes were selected by NMS. Each example will have a
      # different number of chosen bboxes, so the mask is present to allow us
      # to keep the boxes as a batched dense Tensor.
      #
      # We mask the scores by the per_cls_valid_mask so that none of these boxes
      # will be interpreted as valid.
      per_cls_bbox_scores *= per_cls_valid_mask
      visualization_weights = py_utils.HasShape(
          per_cls_bbox_scores, [batch_size, p.num_classes, p.max_nms_boxes])

      # For top down visualization, filter boxes whose scores are not above the
      # visualization threshold.
      visualization_weights = tf.where(
          tf.greater_equal(visualization_weights,
                           p.visualization_classification_threshold),
          visualization_weights, tf.zeros_like(visualization_weights))

    # model_outputs is the (separate) NestedMap handed to the output decoder.
    model_outputs = py_utils.NestedMap()
    model_outputs.per_class_predicted_bboxes = per_cls_bboxes
    model_outputs.per_class_predicted_bbox_scores = per_cls_bbox_scores
    model_outputs.per_class_valid_mask = per_cls_valid_mask

    decoder_outputs = py_utils.NestedMap({
        'per_class_predicted_bboxes': per_cls_bboxes,
        'per_class_predicted_bbox_scores': per_cls_bbox_scores,
        'per_class_valid_mask': per_cls_valid_mask,
        'visualization_weights': visualization_weights,
    })

    if p.decode_include_residuals:
      # Including the residuals in the decoder output makes it possible to save
      # the outputs for further analysis. Note that we ensure that the outputs
      # match the per-class NMS output format of [batch, num_classes, ...].
      def _ReshapeGather(tensor):
        """Reshapes tensor and then gathers using the nms indices."""
        # NOTE(review): assumes per_cls_idxs is [batch, num_classes, boxes]
        # with oriented per-class NMS and [batch, boxes] otherwise, as the
        # comment below implies -- confirm against DecodeWithNMS.
        tensor = tf.gather(
            tf.reshape(tensor, [batch_size, num_bboxes, -1]),
            per_cls_idxs,
            batch_dims=1)
        if not p.use_oriented_per_class_nms:
          # Tile so that the data fits the expected per class shape of
          # [batch_size, num_classes, ...]. When *not* using oriented NMS, the
          # num_classes dimension will be missing since the indices will not
          # have it.
          tensor = tf.tile(tensor[:, tf.newaxis, :, :],
                           [1, p.num_classes, 1, 1])
        return tensor

      decoder_outputs.update({
          'per_class_gt_residuals':
              _ReshapeGather(input_batch.anchor_localization_residuals),
          'per_class_gt_labels':
              _ReshapeGather(input_batch.assigned_gt_labels),
          'per_class_residuals':
              _ReshapeGather(predictions.residuals),
          'per_class_logits':
              _ReshapeGather(predictions.classification_logits),
          'per_class_anchor_boxes':
              _ReshapeGather(input_batch.anchor_bboxes),
      })

    decoder_outputs.update(
        self.output_decoder.ProcessOutputs(input_batch, model_outputs))

    # Produce global step as an output (which is the step
    # of the checkpoint being decoded.)
    decoder_outputs.global_step = py_utils.GetGlobalStep()

    return decoder_outputs
예제 #8
0
def NeighborhoodIndices(points,
                        query_points,
                        k,
                        points_padding=None,
                        max_distance=None,
                        sample_neighbors_uniformly=False):
    """Get indices to k-neighbors of query_points in points.

    Padding is returned alongside indices. Non-padded points are guaranteed to
    be unique (non-repeated) points from original non-padded points.

    Padded points arise due to either a lack of points (k exceeds the number
    of original non-padded points) or points are too far away (exceeds max
    distance).

    Note: Padded point indices may refer to padded points from the original, or
    may be duplicates of the closest point.

    TODO(weihan,jngiam): PointCNN implementation makes an assumption that padded
    points are repeated points from the original points. This behavior is
    maintained here, but we should update PointCNN to respect indices paddings.

    Args:
      points: tensor of shape [N, P1, dims].
      query_points: tensor of shape [N, P2, dims]
      k: Integer.
      points_padding: optional tensor of shape [N, P1] containing True/1.0 iff
        the point is a padded point. If None, then all points are considered
        real points.
      max_distance: float representing the maximum distance that each neighbor
        can be. If there are no points within the distance, then the closest
        point is returned (regardless of distance). If this is set to None,
        then no filtering by distance is performed.
      sample_neighbors_uniformly: boolean specifying whether to sample
        neighbors uniformly if they are within max distance. Requires
        max_distance. Padded points are never sampled as in-range neighbors.
        In this mode the point used to fill out-of-range slots is the first
        selected neighbor, which is a random in-range point rather than the
        closest one.

    Returns:
      A pair of tensors:

      - indices: tensor of shape [N, P2, k].
      - padding: tensor of shape [N, P2, k] where 1 represents a padded point,
        and 0 represents an unpadded (real) point.

    Raises:
      ValueError: if sample_neighbors_uniformly is set without max_distance.
    """
    # Validate the argument combination before building any ops.
    if sample_neighbors_uniformly and max_distance is None:
        raise ValueError('Uniform sampling requires specifying max_distance.')

    n, p1 = py_utils.GetShape(points, 2)
    query_points = py_utils.HasShape(query_points, [n, -1, -1])
    _, p2 = py_utils.GetShape(query_points, 2)

    # Compute pair-wise squared distances.
    # Note that dist_mat contains the squared distance (without sqrt). Thus, when
    # using max_distance, we will need to square max_distance to make sure it's
    # in the same units.
    dist_mat = SquaredDistanceMatrix(query_points, points)
    dist_mat = py_utils.HasShape(dist_mat, [n, p2, p1])

    # Add a large scalar to the distances for padded points.
    # dist_mat[i, j, k] will be:
    #   if points[i, k] is real: distance between points[i, k] and
    #                            query_points[i, j]
    #   otherwise:               that distance plus large_scalar
    if points_padding is not None:
        points_padding = tf.cast(tf.expand_dims(points_padding, 1), tf.float32)
        points_padding = py_utils.HasShape(points_padding, [n, 1, p1])
        large_scalar = tf.reduce_max(dist_mat) + 1
        dist_mat += points_padding * large_scalar

    # To perform sampling neighbors uniformly efficiently, we set all neighbors
    # that are within the distance threshold to have distances be drawn uniformly
    # at random. Using top_k with this enables selecting a random set quickly
    # without replacement.
    if sample_neighbors_uniformly:
        mask_by_distance = tf.less_equal(dist_mat, max_distance**2)
        if points_padding is not None:
            # A padded point's distance is only offset by large_scalar, which
            # can still be <= max_distance**2 when max_distance is large
            # relative to the point cloud. Resampling such a point would drop
            # the offset and let it be reported as a real neighbor, so only
            # real points are eligible. [n, p2, p1] & [n, 1, p1] broadcasts.
            mask_by_distance = tf.logical_and(mask_by_distance,
                                              tf.equal(points_padding, 0.))
        dist_mat = tf.where(
            mask_by_distance,
            tf.square(max_distance) * tf.random_uniform(tf.shape(dist_mat)),
            dist_mat)

    top_k_dist, indices = tf.nn.top_k(-dist_mat, k=k,
                                      sorted=True)  # N x P2 x K

    # Set padding using top_k_dist; padded points will have distance exceeding
    # the large_scalar.
    if points_padding is not None:
        paddings = tf.greater_equal(-top_k_dist, large_scalar)
    else:
        paddings = tf.zeros_like(top_k_dist, dtype=tf.bool)

    # Filter by max_distances by setting all indices that exceed the max_distance
    # to the first-ranked (closest, unless sampling uniformly) point.
    if max_distance is not None:
        # Mask is true for points that are further than max_distance.
        mask_by_distance = tf.greater(-top_k_dist, tf.square(max_distance))
        closest_idx = tf.tile(indices[:, :, :1], [1, 1, k])
        indices = tf.where(mask_by_distance, closest_idx, indices)
        paddings |= mask_by_distance

    indices = tf.reshape(indices, [n, p2, k])
    paddings = tf.cast(paddings, tf.float32)

    return indices, paddings
예제 #9
0
    def Decode(self, input_batch):
        """Decode an input batch, computing predicted bboxes from residuals.

        Converts the model's classification logits into sigmoid scores
        (optionally multiplied by a 'score_scaler' entry of the
        _BBoxesAndLogits result), runs per-class NMS under a CPU device scope
        and merges in the output of `self.output_decoder.ProcessOutputs`.

        Args:
          input_batch: NestedMap of inputs, passed to _BBoxesAndLogits and to
            ProcessOutputs.

        Returns:
          NestedMap with 'per_class_predicted_bboxes',
          'per_class_predicted_bbox_scores' (multiplied by the NMS valid
          mask), 'per_class_valid_mask', 'visualization_weights', the entries
          returned by ProcessOutputs, and 'global_step'.
        """
        p = self.params

        model_out = self._BBoxesAndLogits(input_batch)
        predicted_bboxes = model_out.predicted_bboxes
        batch_size, num_bboxes, _ = py_utils.GetShape(predicted_bboxes, 3)
        logits = py_utils.HasShape(model_out.classification_logits,
                                   [batch_size, num_bboxes, p.num_classes])

        scores = tf.sigmoid(logits)
        # Optional per-box rescaling of the scores.
        if 'score_scaler' in model_out:
            scores *= model_out.score_scaler

        with tf.device('/cpu:0'):
            # Per-class NMS over the predicted boxes.
            nms_bboxes, nms_scores, nms_valid_mask = (
                detection_decoder.DecodeWithNMS(
                    predicted_bboxes,
                    scores,
                    nms_iou_threshold=p.nms_iou_threshold,
                    score_threshold=p.nms_score_threshold,
                    max_boxes_per_class=p.max_nms_boxes,
                    use_oriented_per_class_nms=p.use_oriented_per_class_nms))

            # The valid mask ([batch, num_classes, num_boxes]) marks which of
            # the densely batched boxes NMS actually kept; zeroing the other
            # scores ensures they are never treated as detections.
            nms_scores *= nms_valid_mask
            vis_weights = py_utils.HasShape(
                nms_scores, [batch_size, p.num_classes, p.max_nms_boxes])

            # Top-down visualization only shows boxes scoring at or above the
            # visualization threshold.
            show_box = tf.greater_equal(
                vis_weights, p.visualization_classification_threshold)
            vis_weights = tf.where(show_box, vis_weights,
                                   tf.zeros_like(vis_weights))

        nms_results = {
            'per_class_predicted_bboxes': nms_bboxes,
            'per_class_predicted_bbox_scores': nms_scores,
            'per_class_valid_mask': nms_valid_mask,
        }
        # Two independent NestedMaps: one handed to the output decoder, one
        # returned to the caller.
        model_outputs = py_utils.NestedMap(nms_results)
        decoder_outputs = py_utils.NestedMap(nms_results)
        decoder_outputs.visualization_weights = vis_weights

        decoder_outputs.update(
            self.output_decoder.ProcessOutputs(input_batch, model_outputs))

        # The step of the checkpoint being decoded.
        decoder_outputs.global_step = py_utils.GetGlobalStep()

        return decoder_outputs
예제 #10
0
    def AssignAnchors(self,
                      anchor_bboxes,
                      gt_bboxes,
                      gt_bboxes_labels,
                      gt_bboxes_mask,
                      foreground_assignment_threshold=0.5,
                      background_assignment_threshold=0.35,
                      background_class_id=0,
                      force_match=True,
                      similarity_fn=None):
        """Assigns anchors to bboxes using a similarity function (SSD-based).

        Each anchor box is assigned to the top matching ground truth box.
        Ground truth boxes can be assigned to multiple anchor boxes.

        Assignments can result in 3 outcomes:

          - Positive assignment (if score >= foreground_assignment_threshold):
            assigned_gt_labels will reflect the assigned box label and
            assigned_cls_mask will be set to 1.0
          - Background assignment (if score <= background_assignment_threshold):
            assigned_gt_labels will be background_class_id and
            assigned_cls_mask will be set to 1.0
          - Ignore assignment (otherwise):
            assigned_gt_labels will be background_class_id and
            assigned_cls_mask will be set to 0.0

        The detection loss function would usually:

          - Use assigned_cls_mask for weighting the classification loss. The
            mask is set such that the loss applies to foreground and
            background assignments only - ignored anchors will be set to 0.
          - Use assigned_reg_mask for weighting the regression loss. The mask
            is set such that the loss applies to foreground assignments only.

        The thresholds (foreground_assignment_threshold and
        background_assignment_threshold) should be tuned per dataset.

        TODO(jngiam): Consider having a separate threshold for regression
        boxes; a separate threshold is used in PointRCNN.

        Args:
          anchor_bboxes: tf.float32. [A, 7], where [..., :] corresponds to box
            parameters (x, y, z, dx, dy, dz, r).
          gt_bboxes: tf.float32. [G, 7], where [..., :] corresponds to ground
            truth box parameters (x, y, z, dx, dy, dz, r).
          gt_bboxes_labels: tensor with shape [G]. Ground truth labels for each
            bounding box.
          gt_bboxes_mask: tensor with shape [G]. Mask for ground truth boxes, 1
            iff the gt_bbox is a real bbox.
          foreground_assignment_threshold: Similarity score threshold for
            assigning foreground bounding boxes; scores need to be >=
            foreground_assignment_threshold to be assigned to foreground.
          background_assignment_threshold: Similarity score threshold for
            assigning background bounding boxes; scores need to be <=
            background_assignment_threshold to be assigned to background.
          background_class_id: class id to be assigned to anchors_gt_class if
            no anchor boxes match.
          force_match: Boolean specifying if force matching is enabled. If
            force matching is enabled, then matched anchors which are also the
            highest scoring with a ground-truth box are considered foreground
            matches as long as their similarity score > 0.
          similarity_fn: Function that computes a similarity score (e.g., IOU)
            between pairs of bounding boxes. This function should take in two
            tensors corresponding to anchor and ground-truth bboxes, and
            return a matrix [A, G] with the similarity score between each pair
            of bboxes. The score must be non-negative, with greater scores
            representing more similar. The fore/background_assignment
            thresholds will be applied to this score to determine if an anchor
            is foreground, background or ignored. If set to None, the function
            will default to IOU2DRotatedBoxes.

        Returns:
          NestedMap with the following keys

          - assigned_gt_idx: shape [A] index corresponding to the index of the
            assigned ground truth box. Anchors not assigned to a ground truth
            box will have the index set to -1.
          - assigned_gt_bbox: shape [A, 7] bbox parameters assigned to each
            anchor.
          - assigned_gt_similarity_score: shape [A] (iou) score between the
            anchor and the gt bbox.
          - assigned_gt_labels: shape [A] label assigned to bbox.
          - assigned_cls_mask: shape [A] mask for classification loss per
            anchor. This should be 1.0 if the anchor has a foreground or
            background assignment; otherwise, it will be assigned to 0.0.
          - assigned_reg_mask: shape [A] mask for regression loss per anchor.
            This should be 1.0 if the anchor has a foreground assignment;
            otherwise, it will be assigned to 0.0.
            Note: background anchors do not have regression targets.
        """
        if similarity_fn is None:
            similarity_fn = self.IOU2DRotatedBoxes

        # Shape validation.
        anchor_bboxes = py_utils.HasShape(anchor_bboxes, [-1, 7])
        num_anchor_bboxes, _ = py_utils.GetShape(anchor_bboxes, 2)
        gt_bboxes = py_utils.HasShape(gt_bboxes, [-1, 7])
        num_gt_bboxes, _ = py_utils.GetShape(gt_bboxes, 2)

        # Compute similarity score and reduce max by anchors and by ground-truth.
        similarity_score = similarity_fn(anchor_bboxes, gt_bboxes)
        similarity_score = py_utils.HasShape(
            similarity_score, [num_anchor_bboxes, num_gt_bboxes])

        # Reduce over ground-truth boxes, so we have the max score per anchor.
        anchor_max_score = tf.reduce_max(similarity_score, axis=1)
        anchor_max_idx = tf.argmax(similarity_score, axis=1)

        if force_match:
            # Reduce over anchors, so we have the max score per ground truth box.
            gt_max_score = tf.reduce_max(similarity_score,
                                         axis=0,
                                         keepdims=True)

            # Force matches occur when the top matching gt bbox for an anchor is
            # the top matching anchor for the gt bbox. When force matching, we
            # match these boxes as long as their similarity score exceeds 0.
            force_matches = (
                tf.equal(similarity_score, gt_max_score)
                & tf.equal(similarity_score, anchor_max_score[..., tf.newaxis])
                & tf.greater(similarity_score, 0.)
                & tf.cast(gt_bboxes_mask[tf.newaxis, ...], tf.bool))
            force_match_indicator = tf.reduce_any(force_matches, axis=1)
            force_match_idx = tf.argmax(tf.cast(force_matches, tf.int32),
                                        axis=1)

            # In assigning foreground/background anchors later,
            # force_match_indicator is used to determine which anchors are force
            # foreground, and the index assigned will be taken from
            # anchor_max_idx.

            # Force matchers must also be the max scoring gt bbox per anchor.
            # We overwrite anchor_max_idx to ensure that the right match is done.
            anchor_max_idx = tf.where(force_match_indicator, force_match_idx,
                                      anchor_max_idx)

        # Ensure that max score boxes are not padded boxes by setting score to 0
        # for boxes that are padded.
        gathered_mask = tf.batch_gather(gt_bboxes_mask, anchor_max_idx)
        anchor_max_score = tf.where(tf.equal(gathered_mask, 1),
                                    anchor_max_score,
                                    tf.zeros_like(anchor_max_score))

        # Boolean tensors corresponding to whether an anchor is background or
        # foreground based on thresholding.
        background_anchors = tf.less_equal(anchor_max_score,
                                           background_assignment_threshold)
        foreground_anchors = tf.greater_equal(anchor_max_score,
                                              foreground_assignment_threshold)
        if force_match:
            # Background anchors are below threshold and not force matches.
            background_anchors &= ~force_match_indicator
            # Foreground anchors are above thresholds or force matches.
            foreground_anchors |= force_match_indicator

        # Add dummy background bbox to gt_boxes to facilitate batch gather.
        dummy_bbox = tf.constant([[0, 0, 0, 1, 1, 1, 0]], dtype=tf.float32)

        # Since we are concatenating the dummy bbox, its index is the number of
        # gt boxes. py_utils.GetShape gives a Python int for a static dimension
        # but a scalar Tensor for a dynamic one, which tf.constant rejects; a
        # tf.cast works for both and matches anchor_max_idx's dtype for the
        # tf.where / tf.equal below.
        dummy_bbox_idx = tf.cast(num_gt_bboxes, anchor_max_idx.dtype)

        gt_bboxes = tf.concat([gt_bboxes, dummy_bbox], axis=0)
        gt_bboxes_labels = tf.concat([gt_bboxes_labels, [background_class_id]],
                                     axis=0)

        # Gather indices so that all foreground boxes are gathered from gt_bboxes,
        # while all background and ignore boxes gather the dummy_bbox. tf.fill
        # accepts a dynamic shape, unlike tf.constant(..., shape=...).
        anchor_gather_idx = tf.where(
            foreground_anchors, anchor_max_idx,
            tf.fill(tf.shape(anchor_max_idx), dummy_bbox_idx))

        # Gather the bboxes and weights.
        assigned_gt_bbox = tf.batch_gather(gt_bboxes, anchor_gather_idx)
        assigned_gt_labels = tf.batch_gather(gt_bboxes_labels,
                                             anchor_gather_idx)

        # Set masks for classification and regression losses.
        assigned_cls_mask = tf.cast(background_anchors | foreground_anchors,
                                    tf.float32)
        assigned_reg_mask = tf.cast(foreground_anchors, tf.float32)

        # Set assigned_gt_idx such that dummy boxes have idx = -1.
        assigned_gt_idx = tf.where(tf.equal(anchor_gather_idx, dummy_bbox_idx),
                                   tf.ones_like(anchor_gather_idx) * -1,
                                   anchor_gather_idx)
        assigned_gt_idx = tf.cast(assigned_gt_idx, tf.int32)

        return py_utils.NestedMap(
            assigned_gt_idx=assigned_gt_idx,
            assigned_gt_bbox=assigned_gt_bbox,
            assigned_gt_similarity_score=anchor_max_score,
            assigned_gt_labels=assigned_gt_labels,
            assigned_cls_mask=assigned_cls_mask,
            assigned_reg_mask=assigned_reg_mask)
예제 #11
0
    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 initial_accumulator_value=0.0,
                 start_preconditioning_steps=1000,
                 statistics_computation_frequency=1,
                 matrix_epsilon=1e-6,
                 synchronous_preconditioning=False,
                 second_moment_averaging=1.0,
                 fallback_to_diagonal_dim=4096,
                 max_any_dim=6656,
                 block_size=4096,
                 block_partition_threshold_size=1000000,
                 global_step=None,
                 exponent_multiplier=1.0,
                 name="DistributedShampoo"):
        """Construct a DistributedShampoo optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value. Momentum is not applied to
        sparse updates.
      initial_accumulator_value: A floating point value.
      start_preconditioning_steps: A int32 value which indicates when to start
        preconditioning. Also used as the length (in steps) of the linear
        warmup of the non-diagonal update once preconditioning starts.
      statistics_computation_frequency: A int32 step value which indicates how
        often to compute statistics for preconditioning.
      matrix_epsilon: An epsilon regularizer to make the matrices positive
        definite.
      synchronous_preconditioning: Whether to run preconditioning synchronously.
      second_moment_averaging: 1.0 means sum of gradients squares, while less
        than 1.0 switches to RMSProp style exponential moving averages of the
        second moments.
      fallback_to_diagonal_dim: Fallback to diagonal version of AFMA if any of
        the dimensions is larger than fallback_to_diagonal_dim.
      max_any_dim: If maximum value for any dimension is greater than this value
        we skip preconditioning and fall back to the diagonal.
      block_size: Dimension of the partitioned tensors.
      block_partition_threshold_size: Partitions dimensions beyond this size.
      global_step: Global step for training. Defaults to
        tf.train.get_or_create_global_step().
      exponent_multiplier: A multiplier `e` for the exponent for the inverse
        calculation. e * -1/(2*rank). Only applies when calculating inverses
        through svd.
      name: Optional name prefix for the operations created when applying
        gradients.
    """
        super().__init__(False, name)
        self._learning_rate = learning_rate
        self._momentum = momentum
        self._initial_accumulator_value = initial_accumulator_value
        self._start_preconditioning_steps = start_preconditioning_steps
        self._matrix_epsilon = matrix_epsilon
        self._synchronous_preconditioning = synchronous_preconditioning
        self._second_moment_averaging = second_moment_averaging
        self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
        self._max_any_dim = max_any_dim
        self._block_size = block_size
        # NOTE: On XLA - int64 is not handled properly.
        if global_step is not None:
            self._global_step = tf.cast(tf.identity(global_step), tf.int32)
        else:
            self._global_step = tf.cast(
                tf.identity(tf.train.get_or_create_global_step()), tf.int32)
        self._run_nondiagonal_update = tf.greater_equal(
            self._global_step, self._start_preconditioning_steps)
        start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
        global_step_f = tf.cast(self._global_step, tf.float32)
        # Linear 0 -> 1 warmup over `start_preconditioning_steps` steps after
        # preconditioning starts. The denominator is clamped to >= 1 so that
        # start_preconditioning_steps=0 does not produce 0/0 = NaN at step 0
        # (values for start_preconditioning_steps >= 1 are unchanged).
        warmup_steps_f = tf.maximum(start_steps_f, 1.0)
        self._run_nondiagonal_update_warmup = tf.minimum(
            1.0,
            tf.maximum((global_step_f - start_steps_f) / warmup_steps_f, 0.0))
        # Computes statistics every K steps.
        self._statistics_computation_frequency = statistics_computation_frequency
        self._run_statistics_computation = tf.equal(
            tf.math.floormod(self._global_step,
                             self._statistics_computation_frequency), 0)
        # All vars that are preconditioned.
        self._all_vars_for_preconditioning = []
        self._exponent_multiplier = exponent_multiplier
        self._partition_info = PartitionConfig(block_partition_threshold_size,
                                               block_size)
        self._partitioner_metadata = {}
예제 #12
0
def IsWithinBBox3D(points_3d, bboxes_3d):
    """Tests every point against every 3-d bbox.

  A point counts as inside a box when its (x, y) projection is inside the
  box's top face (decided by IsWithinBBox) and its z coordinate lies in
  [z - dz / 2, z + dz / 2] of that box, with both z bounds inclusive.

  Args:
    points_3d: [..., num_points, 3] float32 Tensor of [x, y, z] point
      coordinates.
    bboxes_3d: [..., num_bboxes, 7] float32 Tensor of boxes given as
      [x, y, z, dx, dy, dz, phi], where (x, y, z) is the box center. Its
      leading dimensions must match those of points_3d.

  Returns:
    boolean Tensor of shape [..., num_points, num_bboxes] whose entry
    [..., p, b] says whether point p lies within box b.
  """
    # Both inputs must agree on rank and on their leading (batch) dims.
    rank = py_utils.GetRank(bboxes_3d)
    points_3d = py_utils.HasRank(points_3d, rank)
    batch_dims = py_utils.GetShape(bboxes_3d)[:-2]

    points_3d = py_utils.HasShape(points_3d, batch_dims + [-1, 3])
    bboxes_3d = py_utils.HasShape(bboxes_3d, batch_dims + [-1, 7])

    n_pts = py_utils.GetShape(points_3d)[-2]
    n_boxes = py_utils.GetShape(bboxes_3d)[-2]

    # ---- Planar (x, y) containment, using the top face of each box. ----
    corners = py_utils.HasShape(
        BBoxCorners(bboxes_3d), batch_dims + [n_boxes, 8, 3])
    # Corners 0..3 are the top face, ordered counter-clockwise:
    #   (x0, y1) <--- (x1, y1)
    #                    ^
    #                    |
    #                    |
    #   (x0, y0) ---> (x1, y0)
    top_face_xy = corners[..., 0:4, 0:2]
    inside_xy = IsWithinBBox(points_3d[..., :2], top_face_xy)
    inside_xy = py_utils.HasShape(inside_xy, batch_dims + [n_pts, n_boxes])

    # ---- Vertical containment: center_z +/- dz / 2 for every box. ----
    center_z = bboxes_3d[..., 2]
    extent_z = bboxes_3d[..., 5]
    z_low = center_z - extent_z / 2.0
    z_high = center_z + extent_z / 2.0
    # [..., n_pts, 1] against [..., 1, n_boxes] broadcasts to the full grid.
    pts_z = points_3d[..., 2:]
    below_top = tf.less_equal(pts_z, z_high[..., tf.newaxis, :])
    above_bottom = tf.greater_equal(pts_z, z_low[..., tf.newaxis, :])
    inside_z = py_utils.HasShape(
        tf.math.logical_and(below_top, above_bottom),
        batch_dims + [n_pts, n_boxes])

    return tf.math.logical_and(inside_z, inside_xy)
예제 #13
0
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max number of decoder steps; when a prefix is given, the prefix
      occupies the first prefix_max // beam_size of them
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable. The loop stops once,
      in every batch row, the worst n-best score is at least beam_gap above
      the best score among the current hyps.
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]; prefix_max must
      be a multiple of beam_size
    prefix_len: (optional) int32 tensor [batch_size]; required when prefix is
      given
    fprop_dtype: fprop dtype
    ext_size: extension buffer size; 0 disables the extension buffer,
      otherwise it is expected to be >= beam_size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        # Create prefix of start tokens.
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert prefix_len is not None, 'prefix_len is required with prefix'
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        # tgt_id / tgt_pos / tgt_mask / hyp_score are initialized from the
        # prefix in 'init_prefix' below; only 't' (used by the prefix fprop)
        # and the n-best buffers are set up here.
        t = tf.constant(0)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ - - - - - - P P P P P P P* - - - ]   ^
        # [ - - P P P P P P P P P P P* - - - ]   | batch
        # [ - - - - - - - - - - - P P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
        pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                          tf.int32)  # Exclude final pfx token.
        pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
        pfx_last = pfx[:, -1]

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        # Remove padding.
        assert buf_size > pfx_max
        pfx_pad_long = tf.pad(pfx_pad, [(0, 0), (0, buf_size - pfx_max)],
                              constant_values=1)
        # Cast to fprop_dtype (not float32) so the in-place multiply also
        # works for non-float32 fprop dtypes such as bfloat16.
        pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), fprop_dtype)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_indexes * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        # penalize all hyps except the first
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make initial 't' constant and determined by the
        # shape of the prefix tensor 'pfx_max'. It could be made dynamic as
        # t ~ max(pfx_len) / beam_size, which would leave more steps for beam
        # search; however 'max' results in a very slow all-to-all on 16x16,
        # and a variable number of decoder steps may result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(length):
            length = tf.cast(length, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((length + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        # Each hyp's mask is repeated once per top-k candidate, matching the
        # [batch, beam_size * k] layout of top_score / top_id above.
        top_mask = tf.repeat(tgt_mask, k, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            # Read the *current* hyp scores from loop_vars. The enclosing
            # scope's hyp_score is only the initial value; comparing against
            # it made the early-stopping test use stale scores.
            (t, _, _, _, cur_hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) -
                tf.reduce_max(cur_hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        """Packs the history ids in `h` selected by each row of mask `m`."""
        with tf.name_scope('unmask'):
            if debug:
                tpu_summary.tensor('unmask_h', h)
                tpu_summary.tensor('unmask_m', m)
            # Output position of each selected buffer slot (0-based); slots
            # not selected by the mask get -1, which one_hot maps to zeros.
            pos = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(pos, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest