Example #1
import tensorflow as tf

# `box_ops.bbox_overlap` (pairwise IoU) and `NMS_TILE_SIZE` are provided by
# the enclosing module; a sketch of the former follows this snippet.
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
    """Zeroes boxes in `box_slice` suppressed by tile `inner_idx` of `boxes`."""
    batch_size = tf.shape(boxes)[0]
    # Tile `inner_idx` of the already-processed boxes: [batch, tile, 4].
    new_slice = tf.slice(boxes, [0, inner_idx * NMS_TILE_SIZE, 0],
                         [batch_size, NMS_TILE_SIZE, 4])
    iou = box_ops.bbox_overlap(new_slice, box_slice)
    # A box survives only if its IoU with every box in the tile stays below
    # the threshold.
    ret_slice = tf.expand_dims(
        tf.cast(tf.reduce_all(iou < iou_threshold, [1]), box_slice.dtype),
        2) * box_slice
    return boxes, ret_slice, iou_threshold, inner_idx + 1
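
Both NMS snippets lean on `box_ops.bbox_overlap`, which is imported from the surrounding module and not shown on this page. Below is a minimal sketch of a batched pairwise-IoU helper with the same call shape, written here for illustration only; the name `bbox_overlap_sketch` and the epsilon are assumptions, not the module's API.

import tensorflow as tf

def bbox_overlap_sketch(boxes_a, boxes_b):
  """Pairwise IoU: [batch, N, 4] x [batch, M, 4] -> [batch, N, M] (yxyx)."""
  ymin_a, xmin_a, ymax_a, xmax_a = tf.split(boxes_a, 4, axis=2)
  ymin_b, xmin_b, ymax_b, xmax_b = tf.split(boxes_b, 4, axis=2)
  # Intersection: broadcast [batch, N, 1] against [batch, 1, M].
  inter_ymin = tf.maximum(ymin_a, tf.transpose(ymin_b, [0, 2, 1]))
  inter_xmin = tf.maximum(xmin_a, tf.transpose(xmin_b, [0, 2, 1]))
  inter_ymax = tf.minimum(ymax_a, tf.transpose(ymax_b, [0, 2, 1]))
  inter_xmax = tf.minimum(xmax_a, tf.transpose(xmax_b, [0, 2, 1]))
  inter = tf.maximum(inter_ymax - inter_ymin, 0.0) * tf.maximum(
      inter_xmax - inter_xmin, 0.0)
  area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a)  # [batch, N, 1]
  area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b)  # [batch, M, 1]
  union = area_a + tf.transpose(area_b, [0, 2, 1]) - inter
  # Small epsilon (an assumption) guards against zero-area unions.
  return inter / (union + 1e-8)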
Example #2
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
    """Processes boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).

    Args:
      boxes: a tensor with a shape of [batch_size, anchors, 4].
      iou_threshold: a float representing the threshold for deciding whether
        boxes overlap too much with respect to IoU.
      output_size: an int32 tensor of size [batch_size] representing the
        number of selected boxes for each batch.
      idx: an integer scalar representing the induction variable.

    Returns:
      boxes: updated boxes.
      iou_threshold: pass down iou_threshold to the next iteration.
      output_size: the updated output_size.
      idx: the updated induction variable.
    """
    num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
    batch_size = tf.shape(boxes)[0]

    # Iterates over tiles that can possibly suppress the current tile.
    box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
                         [batch_size, NMS_TILE_SIZE, 4])
    _, box_slice, _, _ = tf.while_loop(
        lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
        _cross_suppression, [boxes, box_slice, iou_threshold,
                             tf.constant(0)])

    # Iterates over the current tile to compute self-suppression. The mask
    # keeps IoU entry (i, j) only when j > i, so a box can only be suppressed
    # by a higher-scoring box earlier in the (score-sorted) tile.
    iou = box_ops.bbox_overlap(box_slice, box_slice)
    mask = tf.expand_dims(
        tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
            tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
    iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
    suppressed_iou, _, _ = tf.while_loop(
        lambda _iou, loop_condition, _iou_sum: loop_condition,
        _self_suppression,
        [iou, tf.constant(True),
         tf.reduce_sum(iou, [1, 2])])
    suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
    box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype),
                                2)

    # Uses box_slice to update the input boxes.
    mask = tf.reshape(tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype),
                      [1, -1, 1, 1])
    boxes = tf.tile(tf.expand_dims(
        box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
            boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
    boxes = tf.reshape(boxes, [batch_size, -1, 4])

    # Updates output_size with the number of boxes in this tile that survived
    # suppression (any nonzero coordinate).
    output_size += tf.reduce_sum(
        tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
    return boxes, iou_threshold, output_size, idx + 1
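
For context, this loop body is meant to be driven by an outer `tf.while_loop` over score-sorted, tile-padded boxes. The driver below is a sketch modeled on that pattern, not the module's actual entry point; the tile-size value and the stopping rule are assumptions.

import tensorflow as tf

NMS_TILE_SIZE = 512  # assumed value; the enclosing module defines the real one

def sorted_nms_padded_sketch(boxes, iou_threshold, max_output_size):
  """Runs _suppression_loop_body over all tiles of score-sorted boxes."""
  batch_size = tf.shape(boxes)[0]
  num_boxes = tf.shape(boxes)[1]
  # Pads the box count up to a multiple of NMS_TILE_SIZE with zero boxes.
  pad = tf.cast(
      tf.math.ceil(num_boxes / NMS_TILE_SIZE) * NMS_TILE_SIZE,
      tf.int32) - num_boxes
  boxes = tf.pad(boxes, [[0, 0], [0, pad], [0, 0]])
  num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE

  def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
    # Stop once every batch element has enough boxes or tiles run out.
    return tf.logical_and(
        tf.reduce_min(output_size) < max_output_size, idx < num_tiles)

  boxes, _, output_size, _ = tf.while_loop(
      _loop_cond, _suppression_loop_body,
      [boxes, iou_threshold,
       tf.zeros([batch_size], tf.int32),
       tf.constant(0)])
  # Surviving boxes are nonzero rows; suppressed ones were zeroed in place.
  return boxes, output_size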
Example #3
def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
                                            aspect_ratio_range,
                                            min_overlap_params, max_retry):
  """Crops a random slice from the input image.

  The function will correspondingly recompute the bounding boxes and filter out
  outside boxes and their labels.

  References:
  [1] End-to-End Object Detection with Transformers
  https://arxiv.org/abs/2005.12872

  The preprocessing steps:
  1. Sample a minimum IoU overlap.
  2. For each trial, sample the new image width, height, and top-left corner.
  3. Compute the IoUs of bounding boxes with the cropped image and retry if
    the maximum IoU is below the sampled threshold.
  4. Find boxes whose centers are in the cropped image.
  5. Compute new bounding boxes in the cropped region and only select those
    boxes' labels.

  Args:
    img: a 'Tensor' of shape [height, width, 3] representing the input image.
    boxes: a 'Tensor' of shape [N, 4] representing the ground-truth bounding
      boxes with (ymin, xmin, ymax, xmax).
    labels: a 'Tensor' of shape [N,] representing the class labels of the boxes.
    min_scale: a 'float' in [0.0, 1.0) indicating the lower bound of the random
      scale variable.
    aspect_ratio_range: a list of two 'float' that specifies the lower and upper
      bound of the random aspect ratio.
    min_overlap_params: a list of four 'float' representing the min value, max
      value, step size, and offset for the minimum overlap sample.
    max_retry: an 'int' representing the number of trials for cropping. If it is
      exhausted, no cropping will be performed.

  Returns:
    img: a Tensor representing the randomly cropped image. Can be the
      original image if max_retry is exhausted.
    boxes: a Tensor representing the bounding boxes in the cropped image.
    labels: a Tensor representing the new bounding boxes' labels.
  """

  shape = tf.shape(img)
  original_h = shape[0]
  original_w = shape[1]

  minval, maxval, step, offset = min_overlap_params

  # Samples a minimum IoU threshold quantized to multiples of `step` and
  # shifted down by `offset`.
  min_overlap = tf.math.floordiv(
      tf.random.uniform([], minval=minval, maxval=maxval), step) * step - offset

  min_overlap = tf.clip_by_value(min_overlap, 0.0, 1.1)

  # A sampled threshold above 1.0 is unsatisfiable, so skip cropping entirely
  # for this sample and return the inputs unchanged.
  if min_overlap > 1.0:
    return img, boxes, labels

  aspect_ratio_low = aspect_ratio_range[0]
  aspect_ratio_high = aspect_ratio_range[1]

  for _ in tf.range(max_retry):
    scale_h = tf.random.uniform([], min_scale, 1.0)
    scale_w = tf.random.uniform([], min_scale, 1.0)
    new_h = tf.cast(
        scale_h * tf.cast(original_h, dtype=tf.float32), dtype=tf.int32)
    new_w = tf.cast(
        scale_w * tf.cast(original_w, dtype=tf.float32), dtype=tf.int32)

    # Aspect ratio has to be in the prespecified range
    aspect_ratio = new_h / new_w
    if aspect_ratio_low > aspect_ratio or aspect_ratio > aspect_ratio_high:
      continue

    left = tf.random.uniform([], 0, original_w - new_w, dtype=tf.int32)
    right = left + new_w
    top = tf.random.uniform([], 0, original_h - new_h, dtype=tf.int32)
    bottom = top + new_h

    normalized_left = tf.cast(
        left, dtype=tf.float32) / tf.cast(
            original_w, dtype=tf.float32)
    normalized_right = tf.cast(
        right, dtype=tf.float32) / tf.cast(
            original_w, dtype=tf.float32)
    normalized_top = tf.cast(
        top, dtype=tf.float32) / tf.cast(
            original_h, dtype=tf.float32)
    normalized_bottom = tf.cast(
        bottom, dtype=tf.float32) / tf.cast(
            original_h, dtype=tf.float32)

    cropped_box = tf.expand_dims(
        tf.stack([
            normalized_top,
            normalized_left,
            normalized_bottom,
            normalized_right,
        ]),
        axis=0)
    iou = box_ops.bbox_overlap(
        tf.expand_dims(cropped_box, axis=0),
        tf.expand_dims(boxes, axis=0))  # (1, 1, n_ground_truth)
    iou = tf.squeeze(iou, axis=[0, 1])

    # If no bounding box has a Jaccard (IoU) overlap greater than the sampled
    # minimum, try again.
    if tf.reduce_max(iou) < min_overlap:
      continue

    centroids = box_ops.yxyx_to_cycxhw(boxes)
    mask = tf.math.logical_and(
        tf.math.logical_and(centroids[:, 0] > normalized_top,
                            centroids[:, 0] < normalized_bottom),
        tf.math.logical_and(centroids[:, 1] > normalized_left,
                            centroids[:, 1] < normalized_right))
    # Proceed only when at least one bounding box has its center inside the
    # crop; otherwise the loop tries again.
    if tf.reduce_sum(tf.cast(mask, dtype=tf.int32)) > 0:
      indices = tf.squeeze(tf.where(mask), axis=1)

      filtered_boxes = tf.gather(boxes, indices)

      # De-normalizes boxes to pixels, shifts them into the crop frame,
      # re-normalizes by the crop size, and clips to [0, 1].
      boxes = tf.clip_by_value(
          (filtered_boxes[..., :] * tf.cast(
              tf.stack([original_h, original_w, original_h, original_w]),
              dtype=tf.float32) -
           tf.cast(tf.stack([top, left, top, left]), dtype=tf.float32)) /
          tf.cast(tf.stack([new_h, new_w, new_h, new_w]), dtype=tf.float32),
          0.0, 1.0)

      img = tf.image.crop_to_bounding_box(img, top, left, bottom - top,
                                          right - left)

      labels = tf.gather(labels, indices)
      break

  return img, boxes, labels
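
A minimal usage sketch, run eagerly. The image, boxes, and all parameter values below are made up for illustration; the `min_overlap_params` quadruple only has to follow the (minval, maxval, step, offset) convention from the docstring, and `box_ops` must be importable from the enclosing module.

import tensorflow as tf

image = tf.random.uniform([480, 640, 3])
gt_boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                        [0.3, 0.4, 0.9, 0.8]], tf.float32)  # (ymin, xmin, ymax, xmax)
gt_labels = tf.constant([1, 7], tf.int32)

cropped_image, cropped_boxes, cropped_labels = (
    random_crop_image_with_boxes_and_labels(
        image, gt_boxes, gt_labels,
        min_scale=0.3,
        aspect_ratio_range=[0.5, 2.0],
        min_overlap_params=[0.0, 1.4, 0.2, 0.1],
        max_retry=50))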