def iou(box_mask_list1, box_mask_list2):
    """Returns the pairwise mask IOU matrix for two BoxMaskLists.

    Only the masks of each collection are used; the boxes do not take part
    in the computation.

    Args:
      box_mask_list1: BoxMaskList holding N boxes and masks
      box_mask_list2: BoxMaskList holding M boxes and masks

    Returns:
      a numpy array with shape [N, M] representing pairwise iou scores.
    """
    masks1 = box_mask_list1.get_masks()
    masks2 = box_mask_list2.get_masks()
    return np_mask_ops.iou(masks1, masks2)
def non_max_suppression(box_mask_list,
                        max_output_size=10000,
                        iou_threshold=1.0,
                        score_threshold=-10.0):
    """Non maximum suppression.

    This op greedily selects a subset of detections, pruning away entries
    whose mask IOU (intersection over union) with an already selected entry
    is high (> iou_threshold). In each iteration, the detection with the
    highest score in the available pool is selected.

    Args:
      box_mask_list: np_box_mask_list.BoxMaskList holding N boxes.  Must
        contain a 'scores' field representing detection scores. All scores
        belong to the same class.
      max_output_size: maximum number of retained boxes
      iou_threshold: intersection over union threshold. A value of exactly
        1.0 disables suppression; only the max_output_size cap is applied.
      score_threshold: minimum score threshold. Remove the boxes with scores
        less than this value. Default value is set to -10. A very low
        threshold to pass pretty much all the boxes, unless the user sets a
        different score threshold.

    Returns:
      an np_box_mask_list.BoxMaskList holding M boxes where
      M <= max_output_size

    Raises:
      ValueError: if 'scores' field does not exist
      ValueError: if threshold is not in [0, 1]
      ValueError: if max_output_size < 0
    """
    if not box_mask_list.has_field('scores'):
        raise ValueError('Field scores does not exist')
    if iou_threshold < 0. or iou_threshold > 1.0:
        raise ValueError('IOU threshold must be in [0, 1]')
    if max_output_size < 0:
        raise ValueError('max_output_size must be bigger than 0.')

    box_mask_list = filter_scores_greater_than(box_mask_list, score_threshold)
    if box_mask_list.num_boxes() == 0:
        return box_mask_list

    # NOTE(review): the greedy loop below relies on this ordering putting the
    # highest score first -- confirm sort_by_field's default order.
    box_mask_list = sort_by_field(box_mask_list, 'scores')

    # Prevent further computation if NMS is disabled.
    if iou_threshold == 1.0:
        if box_mask_list.num_boxes() > max_output_size:
            return gather(box_mask_list, np.arange(max_output_size))
        return box_mask_list

    masks = box_mask_list.get_masks()
    num_masks = box_mask_list.num_boxes()

    # is_index_valid is True only for candidates that have been neither
    # selected nor suppressed yet.
    is_index_valid = np.ones(num_masks, dtype=bool)
    selected_indices = []
    for i in range(num_masks):
        # Stop as soon as the output is full instead of scanning the rest.
        if len(selected_indices) >= max_output_size:
            break
        if not is_index_valid[i]:
            continue

        selected_indices.append(i)
        is_index_valid[i] = False
        valid_indices = np.where(is_index_valid)[0]
        if valid_indices.size == 0:
            break

        intersect_over_union = np_mask_ops.iou(
            np.expand_dims(masks[i], axis=0), masks[valid_indices])
        intersect_over_union = np.squeeze(intersect_over_union, axis=0)
        # Every entry addressed by valid_indices is currently True, so the
        # comparison result alone is the new validity.
        is_index_valid[valid_indices] = intersect_over_union <= iou_threshold

    # Force an integer dtype: an empty selection (e.g. max_output_size == 0)
    # would otherwise become a float64 array, which is not a valid index.
    return gather(box_mask_list, np.array(selected_indices, dtype=int))
# Example no. 3
# 0
 def testIOU(self):
     """Verifies np_mask_ops.iou on the fixture masks against a 2x3 matrix."""
     # One row per mask in self.masks1, one column per mask in self.masks2.
     expected = np.array(
         [
             [1.0, 0.0, 8.0 / 25.0],
             [0.0, 9.0 / 16.0, 7.0 / 28.0],
         ],
         dtype=np.float32)
     actual = np_mask_ops.iou(self.masks1, self.masks2)
     self.assertAllClose(actual, expected)