def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
  """Suppresses boxes of the current tile that overlap an earlier tile.

  Used as the body of a `tf.while_loop`. Compares `box_slice` against tile
  number `inner_idx` of `boxes` and zeroes every row of `box_slice` whose IoU
  with any box of that tile is >= `iou_threshold`.

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    box_slice: a tensor of shape [batch_size, NMS_TILE_SIZE, 4] holding the
      tile currently being processed.
    iou_threshold: a float IoU threshold.
    inner_idx: an integer scalar; index of the tile to compare against.

  Returns:
    The updated loop state `(boxes, box_slice, iou_threshold, inner_idx + 1)`,
    where suppressed rows of `box_slice` have been multiplied by zero.
  """
  num_batches = tf.shape(boxes)[0]
  tile_start = inner_idx * NMS_TILE_SIZE
  other_tile = tf.slice(boxes, [0, tile_start, 0],
                        [num_batches, NMS_TILE_SIZE, 4])
  overlaps = box_ops.bbox_overlap(other_tile, box_slice)
  # Axis 1 of `overlaps` indexes `other_tile`: a box of `box_slice` survives
  # only if it is below the threshold against every box of `other_tile`.
  survives = tf.reduce_all(overlaps < iou_threshold, [1])
  keep_weight = tf.expand_dims(tf.cast(survives, box_slice.dtype), 2)
  return boxes, keep_weight * box_slice, iou_threshold, inner_idx + 1
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
  """Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).

  The tile is first suppressed against every preceding tile, then against
  itself. The surviving tile is written back into `boxes`, and the number of
  boxes left in it is added to `output_size`.

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    output_size: an int32 tensor of size [batch_size]. Representing the
      number of selected boxes for each batch.
    idx: an integer scalar representing induction variable.

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
  boxes_shape = tf.shape(boxes)
  batch_size = boxes_shape[0]
  num_tiles = boxes_shape[1] // NMS_TILE_SIZE

  current_tile = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
                          [batch_size, NMS_TILE_SIZE, 4])

  # Cross-suppression: visit every tile that can possibly suppress the current
  # one, i.e. tiles 0 .. idx-1.
  def _earlier_tiles_remain(unused_boxes, unused_tile, unused_threshold,
                            inner_idx):
    return inner_idx < idx

  _, current_tile, _, _ = tf.while_loop(
      _earlier_tiles_remain, _cross_suppression,
      [boxes, current_tile, iou_threshold, tf.constant(0)])

  # Self-suppression: keep only IoU entries (i, j) with j > i that reach the
  # threshold, then iterate `_self_suppression` until its condition output
  # becomes False.
  self_iou = box_ops.bbox_overlap(current_tile, current_tile)
  positions = tf.range(NMS_TILE_SIZE)
  upper_triangle = tf.expand_dims(
      tf.reshape(positions, [1, -1]) > tf.reshape(positions, [-1, 1]), 0)
  self_iou *= tf.cast(
      tf.logical_and(upper_triangle, self_iou >= iou_threshold),
      self_iou.dtype)

  def _keep_iterating(unused_iou, loop_condition, unused_iou_sum):
    return loop_condition

  suppressed_iou, _, _ = tf.while_loop(
      _keep_iterating, _self_suppression,
      [self_iou, tf.constant(True), tf.reduce_sum(self_iou, [1, 2])])
  is_suppressed = tf.reduce_sum(suppressed_iou, 1) > 0
  current_tile *= tf.expand_dims(
      1.0 - tf.cast(is_suppressed, current_tile.dtype), 2)

  # Write the processed tile back into slot `idx` of `boxes` via a one-hot
  # selector over the tile axis.
  tile_selector = tf.reshape(
      tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype),
      [1, -1, 1, 1])
  broadcast_tile = tf.tile(
      tf.expand_dims(current_tile, [1]), [1, num_tiles, 1, 1])
  tiled_boxes = tf.reshape(boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4])
  boxes = broadcast_tile * tile_selector + tiled_boxes * (1 - tile_selector)
  boxes = tf.reshape(boxes, [batch_size, -1, 4])

  # A box counts as selected if any of its coordinates is still positive.
  survivors = tf.reduce_any(current_tile > 0, [2])
  output_size += tf.reduce_sum(tf.cast(survivors, tf.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1
def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
                                            aspect_ratio_range,
                                            min_overlap_params, max_retry):
  """Crops a random slice from the input image.

  The function will correspondingly recompute the bounding boxes and filter out
  outside boxes and their labels.

  References:
  [1] End-to-End Object Detection with Transformers
  https://arxiv.org/abs/2005.12872

  The preprocessing steps:
  1. Sample a minimum IoU overlap. If the sampled value (clipped to [0, 1.1])
     exceeds 1.0, no crop is performed.
  2. For each trial, sample the new image width, height, and top-left corner.
     Trials with an empty crop or an out-of-range aspect ratio are skipped.
  3. Compute the IoUs of bounding boxes with the cropped image and retry if the
     maximum IoU is below the sampled threshold.
  4. Find boxes whose centers are strictly inside the cropped image.
  5. Compute new bounding boxes in the cropped region and only select those
     boxes' labels.

  Args:
    img: a 'Tensor' of shape [height, width, 3] representing the input image.
    boxes: a 'Tensor' of shape [N, 4] representing the ground-truth bounding
      boxes with (ymin, xmin, ymax, xmax), normalized to [0, 1] relative to
      the image size.
    labels: a 'Tensor' of shape [N,] representing the class labels of the
      boxes.
    min_scale: a 'float' in [0.0, 1.0) indicating the lower bound of the random
      scale variable.
    aspect_ratio_range: a list of two 'float' that specifies the lower and upper
      bound of the random aspect ratio, measured as crop height / crop width.
    min_overlap_params: a list of four 'float' representing the min value, max
      value, step size, and offset for the minimum overlap sample.
    max_retry: an 'int' representing the number of trials for cropping. If it
      is exhausted, no cropping will be performed.

  Returns:
    img: a Tensor representing the random cropped image. Can be the
      original image if max_retry is exhausted.
    boxes: a Tensor representing the bounding boxes in the cropped image,
      normalized to and clipped into [0, 1] relative to the crop.
    labels: a Tensor representing the new bounding boxes' labels.
  """
  shape = tf.shape(img)
  original_h = shape[0]
  original_w = shape[1]
  # Loop-invariant float copies of the image size, used for normalization.
  float_h = tf.cast(original_h, dtype=tf.float32)
  float_w = tf.cast(original_w, dtype=tf.float32)

  minval, maxval, step, offset = min_overlap_params

  # Quantize a uniform sample down to a multiple of `step`, then shift it.
  min_overlap = tf.math.floordiv(
      tf.random.uniform([], minval=minval, maxval=maxval), step) * step - offset

  min_overlap = tf.clip_by_value(min_overlap, 0.0, 1.1)

  # No IoU can exceed 1.0, so this outcome means "keep the original image".
  if min_overlap > 1.0:
    return img, boxes, labels

  aspect_ratio_low = aspect_ratio_range[0]
  aspect_ratio_high = aspect_ratio_range[1]

  for _ in tf.range(max_retry):
    scale_h = tf.random.uniform([], min_scale, 1.0)
    scale_w = tf.random.uniform([], min_scale, 1.0)
    new_h = tf.cast(scale_h * float_h, dtype=tf.int32)
    new_w = tf.cast(scale_w * float_w, dtype=tf.int32)

    # An empty crop can never contain a box center; skip it explicitly rather
    # than computing a NaN/inf aspect ratio from a zero denominator.
    if new_h <= 0 or new_w <= 0:
      continue

    # Aspect ratio (height / width) has to be in the prespecified range.
    aspect_ratio = new_h / new_w
    if aspect_ratio_low > aspect_ratio or aspect_ratio > aspect_ratio_high:
      continue

    # Integer tf.random.uniform excludes `maxval`; the `+ 1` makes offsets in
    # [0, original - new] inclusive, so the crop may touch the right/bottom
    # edge of the image.
    left = tf.random.uniform([], 0, original_w - new_w + 1, dtype=tf.int32)
    right = left + new_w
    top = tf.random.uniform([], 0, original_h - new_h + 1, dtype=tf.int32)
    bottom = top + new_h

    normalized_left = tf.cast(left, dtype=tf.float32) / float_w
    normalized_right = tf.cast(right, dtype=tf.float32) / float_w
    normalized_top = tf.cast(top, dtype=tf.float32) / float_h
    normalized_bottom = tf.cast(bottom, dtype=tf.float32) / float_h

    cropped_box = tf.expand_dims(
        tf.stack([
            normalized_top,
            normalized_left,
            normalized_bottom,
            normalized_right,
        ]),
        axis=0)

    iou = box_ops.bbox_overlap(
        tf.expand_dims(cropped_box, axis=0),
        tf.expand_dims(boxes, axis=0))  # (1, 1, n_ground_truth)
    iou = tf.squeeze(iou, axis=[0, 1])

    # If not a single bounding box has a Jaccard overlap of at least the
    # minimum, try again.
    if tf.reduce_max(iou) < min_overlap:
      continue

    centroids = box_ops.yxyx_to_cycxhw(boxes)
    mask = tf.math.logical_and(
        tf.math.logical_and(centroids[:, 0] > normalized_top,
                            centroids[:, 0] < normalized_bottom),
        tf.math.logical_and(centroids[:, 1] > normalized_left,
                            centroids[:, 1] < normalized_right))
    # If not a single bounding box has its center in the crop, try again.
    if tf.reduce_any(mask):
      indices = tf.squeeze(tf.where(mask), axis=1)
      filtered_boxes = tf.gather(boxes, indices)

      # Convert to pixels, shift to the crop origin, and re-normalize by the
      # crop size; clipping trims boxes that extend past the crop.
      image_scale = tf.cast(
          tf.stack([original_h, original_w, original_h, original_w]),
          dtype=tf.float32)
      crop_origin = tf.cast(tf.stack([top, left, top, left]), dtype=tf.float32)
      crop_scale = tf.cast(
          tf.stack([new_h, new_w, new_h, new_w]), dtype=tf.float32)
      boxes = tf.clip_by_value(
          (filtered_boxes * image_scale - crop_origin) / crop_scale, 0.0, 1.0)

      img = tf.image.crop_to_bounding_box(img, top, left, bottom - top,
                                          right - left)
      labels = tf.gather(labels, indices)
      break

  return img, boxes, labels