# NOTE: the imports below are what these functions need in a detectron2-based
# project; `Part` and `extract_bboxes` are project-local helpers assumed to be
# imported elsewhere in this module.
import torch

from detectron2.layers import cat, paste_masks_in_image
from detectron2.structures import Boxes, Instances


def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
    """
    Resize the output instances.

    The input images are often resized when entering an object detector. As a result,
    we often need the outputs of the detector in a different resolution from its inputs.
    This function resizes the raw outputs of an R-CNN detector to produce outputs at
    the desired output resolution.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees.
            This object might be modified in-place.
        output_height, output_width: the desired output resolution.

    Returns:
        Instances: the resized output from the model, based on the output resolution
    """
    scale_x, scale_y = (
        output_width / results.image_size[1],
        output_height / results.image_size[0],
    )
    results = Instances((output_height, output_width), **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        raise ValueError("Predictions must contain either pred_boxes or proposal_boxes!")

    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)

    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        results.pred_masks = paste_masks_in_image(
            results.pred_masks[:, 0, :, :],  # N, 1, M, M
            results.pred_boxes,
            results.image_size,
            threshold=mask_threshold,
        )

    if results.has("pred_masks_soft"):
        results.pred_masks_soft = paste_masks_in_image(
            results.pred_masks_soft[:, 0, :, :],  # N, 1, M, M
            results.pred_boxes,
            results.image_size,
            threshold=mask_threshold,
        )

    if results.has("pred_keypoints"):
        results.pred_keypoints[:, :, 0] *= scale_x
        results.pred_keypoints[:, :, 1] *= scale_y

    return results

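# --- Usage sketch (illustrative, not part of the pipeline) -------------------
# A minimal example of calling detector_postprocess, assuming detectron2 is
# installed. The box, score, and class values are made up, and
# `_demo_detector_postprocess` is a hypothetical helper, not an existing API.
def _demo_detector_postprocess():
    # Fake raw detections at the resolution the detector saw (480 x 640).
    raw = Instances((480, 640))
    raw.pred_boxes = Boxes(torch.tensor([[10.0, 20.0, 100.0, 200.0]]))
    raw.scores = torch.tensor([0.9])
    raw.pred_classes = torch.tensor([0])
    # Rescale the detections to the original image resolution (720 x 960);
    # boxes are scaled by (960 / 640, 720 / 480) = (1.5, 1.5).
    out = detector_postprocess(raw, output_height=720, output_width=960)
    return out.pred_boxes.tensor
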
def generate_parts(id_map, parts):
    """
    Rebuild part instances after pasting their masks into the full image:
    drop parts whose pasted mask is empty, recompute their boxes from the
    pasted masks, and remap their class ids through `id_map`.
    """
    new_parts = []
    keep_areas = []
    for i in range(len(parts)):
        part = parts[i]
        new_part = Part(part.image_size)
        if len(part) > 0:
            pred_masks = paste_masks_in_image(
                part.pred_masks[:, 0, :, :],  # N, 1, M, M
                part.pred_boxes,
                part.image_size,
            )
            # upsample_pred_masks = F.upsample(
            #     pred_masks.float().unsqueeze(0), size=(128, 128), mode='bilinear').squeeze(0)
            # areas = upsample_pred_masks.sum(dim=1).sum(dim=1)
            areas = pred_masks.sum(dim=1).sum(dim=1)
            # Keep only parts whose pasted mask covers at least one pixel.
            area_keep = torch.where(areas > 0)
            keep_areas.append(area_keep)
            pred_keep_masks = pred_masks[area_keep]
            new_part.pred_classes = part.pred_classes[area_keep]
            # Tight boxes recomputed from the pasted masks.
            new_part.pred_boxes = Boxes(extract_bboxes(pred_keep_masks)).to(
                new_part.pred_classes.device)
            # new_thing_instance.pred_masks = upsample_pred_masks[area_keep]
            # new_thing_instance.pred_masks = pred_keep_masks
            # Remap contiguous class ids to dataset category ids via id_map.
            pred_classes = []
            for pred_class in new_part.pred_classes:
                pred_classes.append(id_map[pred_class.item()])
            pred_classes = torch.IntTensor(pred_classes).to(
                new_part.pred_classes.device)
            new_part.pred_classes = pred_classes
        else:
            keep_areas.append([])
        new_parts.append(new_part)
    del parts
    return new_parts, keep_areas

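# --- Illustration of the id_map remapping step --------------------------------
# generate_parts maps each predicted (contiguous) class id through `id_map`
# back to a dataset category id. The sketch below shows that step in isolation;
# the id_map contents and predictions are made-up values, and
# `_demo_id_map_remap` is a hypothetical helper, not an existing API.
def _demo_id_map_remap():
    id_map = {0: 11, 1: 23, 2: 42}              # contiguous id -> dataset id
    pred_classes = torch.tensor([2, 0, 0, 1])   # predictions in contiguous ids
    remapped = torch.IntTensor([id_map[c.item()] for c in pred_classes])
    return remapped                             # tensor([42, 11, 11, 23], dtype=torch.int32)
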
def _paste_mask_lists_in_image(masks, boxes, image_shape, threshold=0.5):
    """
    Paste a list of masks that are of various resolutions (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Args:
        masks (list(Tensor)): A list of Tensor of shape (1, Hmask_i, Wmask_i).
            Values are in [0, 1]. The list length, Bimg, is the number of
            detected object instances in the image.
        boxes (Boxes): A Boxes of length Bimg. boxes.tensor[i] and masks[i]
            correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is
            the number of detected object instances and Himage, Wimage are the image
            height and width. img_masks[i] is a binary mask for object instance i.
    """
    if len(masks) == 0:
        # Shape matches the non-empty path: (Bimg, Himage, Wimage).
        return torch.empty((0,) + image_shape, dtype=torch.uint8)

    # Loop over mask groups. Each group has the same mask prediction size,
    # so its members can be pasted in a single batched call.
    img_masks = []
    ind_masks = []
    mask_sizes = torch.tensor([m.shape[-1] for m in masks])
    unique_sizes = torch.unique(mask_sizes)
    for msize in unique_sizes.tolist():
        cur_ind = torch.where(mask_sizes == msize)[0]
        ind_masks.append(cur_ind)

        cur_masks = cat([masks[i] for i in cur_ind])
        cur_boxes = boxes[cur_ind]
        img_masks.append(
            paste_masks_in_image(cur_masks, cur_boxes, image_shape, threshold))

    img_masks = cat(img_masks)
    ind_masks = cat(ind_masks)

    # Scatter the pasted groups back into the original instance order.
    img_masks_out = torch.empty_like(img_masks)
    img_masks_out[ind_masks, :, :] = img_masks
    return img_masks_out

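# --- Illustration of the group-then-restore-order trick -----------------------
# _paste_mask_lists_in_image processes masks grouped by resolution, then scatters
# the results back into the original instance order via the saved indices. The
# sketch below demonstrates that reordering with scalar stand-ins for the pasted
# masks; all values are made up and `_demo_restore_order` is a hypothetical helper.
def _demo_restore_order():
    sizes = torch.tensor([28, 56, 28, 56, 28])
    processed = []
    indices = []
    for msize in torch.unique(sizes).tolist():
        cur_ind = torch.where(sizes == msize)[0]
        indices.append(cur_ind)
        # Stand-in for paste_masks_in_image: record which group each entry ran in.
        processed.append(torch.full((len(cur_ind),), msize))
    processed = torch.cat(processed)
    indices = torch.cat(indices)
    out = torch.empty_like(processed)
    out[indices] = processed
    return out  # tensor([28, 56, 28, 56, 28]) -- original order restored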