def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
    """
    Resize the output instances.

    The input images are often resized when entering an object detector.
    As a result, we often need the outputs of the detector in a different
    resolution from its inputs.

    This function will resize the raw outputs of an R-CNN detector
    to produce outputs according to the desired output resolution.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees,
            as (height, width). This object might be modified in-place.
        output_height, output_width: the desired output resolution.
        mask_threshold (float): threshold forwarded to `paste_masks_in_image`
            when `pred_masks` are pasted into the output image.

    Returns:
        Instances: the resized output from the model, based on the output resolution

    Raises:
        ValueError: if `results` has neither a `pred_boxes` nor a
            `proposal_boxes` field, so there is nothing to rescale.
    """
    # image_size is (height, width): index 1 drives the x scale, index 0 the y scale.
    scale_x = output_width / results.image_size[1]
    scale_y = output_height / results.image_size[0]

    # Re-wrap the same fields under the new image size; the field objects are shared
    # with the input, which is why the input may be modified in-place below.
    results = Instances((output_height, output_width), **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        # Without this branch `output_boxes` would be unbound and the next
        # statement would fail with an uninformative UnboundLocalError.
        raise ValueError(
            "Predictions must contain boxes: expected a 'pred_boxes' or "
            "'proposal_boxes' field in results."
        )

    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)

    # Keep only the instances whose boxes are still non-empty after clipping.
    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        results.pred_masks = paste_masks_in_image(
            results.pred_masks[:, 0, :, :],  # N, 1, M, M
            results.pred_boxes,
            results.image_size,
            threshold=mask_threshold,
        )

    if results.has("pred_keypoints"):
        # Channel 0 is scaled as x and channel 1 as y; any further channels are untouched.
        results.pred_keypoints[:, :, 0] *= scale_x
        results.pred_keypoints[:, :, 1] *= scale_y

    return results
def _paste_mask_lists_in_image(masks, boxes, image_shape, threshold=0.5):
    """
    Paste a list of masks that are of various resolutions (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Args:
        masks (list(Tensor)): A list of Tensor of shape (1, Hmask_i, Wmask_i).
            Values are in [0, 1]. The list length, Bimg, is the
            number of detected object instances in the image.
        boxes (Boxes): A Boxes of length Bimg. boxes.tensor[i] and masks[i] correspond
            to the same object instance.
        image_shape (tuple): height, width. Any (height, width) sequence is accepted
            for the empty-input case.
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
            number of detected object instances and Himage, Wimage are the image height
            and width. img_masks[i] is a binary mask for object instance i.
            NOTE: when `masks` is empty, the returned tensor is an empty uint8 tensor of
            shape (0, 1, Himage, Wimage); this shape is kept for backward compatibility.
    """
    if len(masks) == 0:
        # tuple() lets callers pass a list (or other sequence) for image_shape;
        # concatenating a tuple with a list would raise TypeError.
        return torch.empty((0, 1) + tuple(image_shape), dtype=torch.uint8)

    # Loop over mask groups. Each group has the same mask prediction size
    # (groups are keyed on the last dimension of each mask).
    mask_sizes = torch.tensor([m.shape[-1] for m in masks])
    pasted_groups = []
    group_indices = []
    for msize in torch.unique(mask_sizes).tolist():
        cur_ind = torch.where(mask_sizes == msize)[0]
        group_indices.append(cur_ind)

        cur_masks = cat([masks[i] for i in cur_ind])
        cur_boxes = boxes[cur_ind]
        pasted_groups.append(
            paste_masks_in_image(cur_masks, cur_boxes, image_shape, threshold)
        )

    img_masks = cat(pasted_groups)
    ind_masks = cat(group_indices)

    # The groups were concatenated in size order; scatter each pasted mask back
    # to the position of its instance in the original `masks` list.
    img_masks_out = torch.empty_like(img_masks)
    img_masks_out[ind_masks, :, :] = img_masks

    return img_masks_out