Example #1
    def test_paste_mask_in_image(self):
        # disable profiling
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)

        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)
        from torchvision.models.detection.roi_heads import paste_masks_in_image
        out = paste_masks_in_image(masks, boxes, o_im_s)
        jit_trace = torch.jit.trace(
            paste_masks_in_image,
            (masks, boxes, [torch.tensor(o_im_s[0]),
                            torch.tensor(o_im_s[1])]))
        out_trace = jit_trace(
            masks, boxes, [torch.tensor(o_im_s[0]),
                           torch.tensor(o_im_s[1])])

        assert torch.all(out.eq(out_trace))

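        # Re-run the traced function on a different number of instances and a
        # different output size to check that the trace generalizes beyond the
        # inputs it was recorded with.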
        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)
        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(
            masks2, boxes2,
            [torch.tensor(o_im_s2[0]),
             torch.tensor(o_im_s2[1])])

        assert torch.all(out2.eq(out_trace2))
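For reference, a minimal standalone call to paste_masks_in_image looks like the sketch below. The shapes and values are illustrative assumptions rather than the test inputs above: masks is (N, 1, H, W) with per-pixel probabilities, boxes is (N, 4) in (x1, y1, x2, y2) pixel coordinates, and the last argument is the output image size as (height, width).

    import torch
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    # Hypothetical inputs: 3 instance masks at 28x28 and their boxes in a 100x120 image.
    masks = torch.rand(3, 1, 28, 28)              # (N, 1, H, W) mask probabilities
    boxes = torch.tensor([[10., 10., 50., 60.],
                          [20., 30., 80., 90.],
                          [ 5., 15., 40., 70.]])  # (N, 4) boxes as (x1, y1, x2, y2)

    pasted = paste_masks_in_image(masks, boxes, (100, 120))
    print(pasted.shape)                           # torch.Size([3, 1, 100, 120])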
Example #2
 def postprocess(self, result, image_shapes, original_image_sizes):
     # type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
     for i, (pred, im_s, o_im_s) in enumerate(
             zip(result, image_shapes, original_image_sizes)):
         boxes = pred["boxes"]
         boxes = resize_boxes(boxes, im_s, o_im_s)
         result[i]["boxes"] = boxes
         if "masks" in pred:
             masks = pred["masks"]
             masks = paste_masks_in_image(masks, boxes, o_im_s)
             result[i]["masks"] = masks
     return result
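The postprocess method above maps each prediction from the resized image space im_s (what the network actually saw) back to the original image space o_im_s. A minimal sketch of the same two steps outside the class, with made-up sizes and a single fabricated prediction:

    import torch
    from torchvision.models.detection.transform import resize_boxes
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    im_s = (800, 800)      # size the network processed (after the transform)
    o_im_s = (480, 640)    # original image size that results are reported in

    pred = {
        "boxes": torch.tensor([[100., 100., 300., 400.]]),  # in im_s coordinates
        "masks": torch.rand(1, 1, 28, 28),                  # per-box mask probabilities
    }
    pred["boxes"] = resize_boxes(pred["boxes"], im_s, o_im_s)                  # back to o_im_s
    pred["masks"] = paste_masks_in_image(pred["masks"], pred["boxes"], o_im_s)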
Example #3
    def predict_boxes(self, images, boxes):
        self.eval()
        device = list(self.parameters())[0].device
        images = images.to(device)
        boxes = boxes.to(device)

        targets = None
        original_image_sizes = [img.shape[-2:] for img in images]

        images, targets = self.transform(images, targets)

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([(0, features)])

        # proposals, proposal_losses = self.rpn(images, features, targets)
        from torchvision.models.detection.transform import resize_boxes

        boxes = resize_boxes(boxes, original_image_sizes[0],
                             images.image_sizes[0])
        proposals = [boxes]

        box_feats = self.roi_heads.box_roi_pool(features, proposals,
                                                images.image_sizes)
        box_features = self.roi_heads.box_head(box_feats)
        class_logits, box_regression = self.roi_heads.box_predictor(
            box_features)

        pred_boxes = self.roi_heads.box_coder.decode(box_regression, proposals)
        pred_scores = F.softmax(class_logits, -1)

        pred_boxes = pred_boxes[:, 1:].squeeze(dim=1).detach()
        pred_boxes = resize_boxes(pred_boxes, images.image_sizes[0],
                                  original_image_sizes[0])
        pred_scores = pred_scores[:, 1:].squeeze(dim=1).detach()

        mask_features = self.roi_heads.mask_roi_pool(features, proposals,
                                                     images.image_sizes)
        cropped_features = self.roi_heads.mask_head(mask_features)
        mask_logits = self.roi_heads.mask_predictor(cropped_features)

        # Copy the foreground-class logits (channel 1) into channel 0, the
        # channel that paste_masks_in_image reads.
        switch_channel_masks = torch.zeros_like(mask_logits)
        switch_channel_masks[:, 0, :, :] = mask_logits[:, 1, :, :]

        # Workaround that only works with 2 classes (background + foreground).
        # Otherwise use maskrcnn_inference, or manually keep the class with the
        # highest score here.
        switch_channel_masks = torch.sigmoid(switch_channel_masks)
        pred_masks = paste_masks_in_image(switch_channel_masks, pred_boxes,
                                          original_image_sizes[0]).detach()

        return pred_boxes, pred_scores, pred_masks
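The method above resizes the candidate boxes into the transformed image space before ROI pooling and resizes the predictions back afterwards. A runnable sketch of that round trip in isolation, with made-up sizes (the values are illustrative, not taken from the example):

    import torch
    from torchvision.models.detection.transform import resize_boxes

    original_size = (480, 640)      # original image (height, width)
    network_size = (800, 1066)      # size produced by GeneralizedRCNNTransform

    boxes = torch.tensor([[50., 60., 200., 220.]])
    boxes_net = resize_boxes(boxes, original_size, network_size)       # into network space
    boxes_back = resize_boxes(boxes_net, network_size, original_size)  # and back again

    assert torch.allclose(boxes, boxes_back)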
Example #4
 def postprocess(self, result, image_shapes, original_image_sizes):
     if self.training:
         return result
     for i, (pred, im_s, o_im_s) in enumerate(
             zip(result, image_shapes, original_image_sizes)):
         boxes = pred["boxes"]
         boxes = resize_boxes(boxes, im_s, o_im_s)
         result[i]["boxes"] = boxes
         if "masks" in pred:
             masks = pred["masks"]
             masks = paste_masks_in_image(masks, boxes, o_im_s)
             result[i]["masks"] = masks
         if "keypoints" in pred:
             keypoints = pred["keypoints"]
             keypoints = resize_keypoints(keypoints, im_s, o_im_s)
             result[i]["keypoints"] = keypoints
     return result
Example #5
 def postprocess(self,
                 result,               # type: List[Dict[str, Tensor]]
                 image_shapes,         # type: List[Tuple[int, int]]
                 original_image_sizes  # type: List[Tuple[int, int]]
                 ):
     # type: (...) -> List[Dict[str, Tensor]]
     if self.training:
         return result
     for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
         boxes = pred["boxes"]
         boxes = resize_boxes(boxes, im_s, o_im_s)
         result[i]["boxes"] = boxes
         if "masks" in pred:
             masks = pred["masks"]
             masks = paste_masks_in_image(masks, boxes, o_im_s)
             result[i]["masks"] = masks
         if "keypoints" in pred:
             keypoints = pred["keypoints"]
             keypoints = resize_keypoints(keypoints, im_s, o_im_s)
             result[i]["keypoints"] = keypoints
     return result
Example #6
    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config
            ori_shape (Tuple): original image height and width, shape (2,)
            scale_factor (float | Tensor): If ``rescale is True``, box
                coordinates are divided by this scale factor to fit
                ``ori_shape``.
            rescale (bool): If True, the resulting masks will be rescaled to
                ``ori_shape``.

        Returns:
            list[list]: encoded masks
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid()
        else:
            mask_pred = det_bboxes.new_tensor(mask_pred)

        device = mask_pred.device
        # background (BG) is not included in num_classes
        cls_segms = [[] for _ in range(self.num_classes)]
        bboxes = det_bboxes[:, :4]
        labels = det_labels

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            if isinstance(scale_factor, float):
                img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
                img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            else:
                w_scale, h_scale = scale_factor[0], scale_factor[1]
                img_h = np.round(ori_shape[0] * h_scale.item()).astype(
                    np.int32)
                img_w = np.round(ori_shape[1] * w_scale.item()).astype(
                    np.int32)
            scale_factor = 1.0

        if not isinstance(scale_factor, (float, torch.Tensor)):
            scale_factor = bboxes.new_tensor(scale_factor)
        bboxes = bboxes / scale_factor

        if torch.onnx.is_in_onnx_export():
            # TODO: Remove after F.grid_sample is supported.
            from torchvision.models.detection.roi_heads \
                import paste_masks_in_image
            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
            thr = rcnn_test_cfg.get('mask_thr_binary', 0)
            if thr > 0:
                masks = masks >= thr
            return masks

        N = len(mask_pred)
        # The actual implementation splits the input into chunks
        # and pastes them chunk by chunk.
        if device.type == 'cpu':
            # CPU is most efficient when they are pasted one by one with
            # skip_empty=True, so that it performs minimal number of
            # operations.
            num_chunks = N
        else:
            # GPU benefits from parallelism for larger chunks,
            # but may run out of memory.
            num_chunks = int(
                np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
            assert (num_chunks <=
                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

        threshold = rcnn_test_cfg.mask_thr_binary
        im_mask = torch.zeros(
            N,
            img_h,
            img_w,
            device=device,
            dtype=torch.bool if threshold >= 0 else torch.uint8)

        if not self.class_agnostic:
            mask_pred = mask_pred[range(N), labels][:, None]

        for inds in chunks:
            masks_chunk, spatial_inds = _do_paste_mask(
                mask_pred[inds],
                bboxes[inds],
                img_h,
                img_w,
                skip_empty=device.type == 'cpu')

            if threshold >= 0:
                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            else:
                # for visualization and debugging
                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

            im_mask[(inds, ) + spatial_inds] = masks_chunk

        for i in range(N):
            cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
        return cls_segms
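The chunking logic above bounds the size of the temporary paste buffer by available GPU memory. Assuming the usual module-level constants BYTES_PER_FLOAT = 4 and GPU_MEM_LIMIT = 1024**3 (1 GB), which are assumptions about the surrounding file rather than values shown here, the chunk count for a given workload can be checked in isolation:

    import numpy as np
    import torch

    BYTES_PER_FLOAT = 4          # assumed module-level constant
    GPU_MEM_LIMIT = 1024 ** 3    # assumed module-level constant (1 GB)

    N, img_h, img_w = 500, 1080, 1920  # hypothetical workload
    num_chunks = int(np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
    chunks = torch.chunk(torch.arange(N), num_chunks)

    print(num_chunks, len(chunks[0]))  # 4 125 -> 500 masks pasted in 4 chunks of 125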
Example #7
    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config
            ori_shape (Tuple): original image height and width, shape (2,)
            scale_factor(float | Tensor): If ``rescale is True``, box
                coordinates are divided by this scale factor to fit
                ``ori_shape``.
            rescale (bool): If True, the resulting masks will be rescaled to
                ``ori_shape``.

        Returns:
            list[list]: encoded masks. The c-th item in the outer list
                corresponds to the c-th class. Given the c-th outer list, the
                i-th item in that inner list is the mask for the i-th box with
                class label c.

        Example:
            >>> import mmcv
            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
            >>> N = 7  # N = number of extracted ROIs
            >>> C, H, W = 11, 32, 32
            >>> # Create example instance of FCN Mask Head.
            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
            >>> inputs = torch.rand(N, self.in_channels, H, W)
            >>> mask_pred = self.forward(inputs)
            >>> # Each input is associated with some bounding box
            >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
            >>> det_labels = torch.randint(0, C, size=(N,))
            >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
            >>> ori_shape = (H * 4, W * 4)
            >>> scale_factor = torch.FloatTensor((1, 1))
            >>> rescale = False
            >>> # Encoded masks are a list for each category.
            >>> encoded_masks = self.get_seg_masks(
            >>>     mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
            >>>     scale_factor, rescale
            >>> )
            >>> assert len(encoded_masks) == C
            >>> assert sum(list(map(len, encoded_masks))) == N
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid()
        else:
            mask_pred = det_bboxes.new_tensor(mask_pred)

        device = mask_pred.device
        # background (BG) is not included in num_classes
        cls_segms = [[] for _ in range(self.num_classes)]
        bboxes = det_bboxes[:, :4]
        labels = det_labels

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            if isinstance(scale_factor, float):
                img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
                img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            else:
                w_scale, h_scale = scale_factor[0], scale_factor[1]
                img_h = np.round(ori_shape[0] * h_scale.item()).astype(
                    np.int32)
                img_w = np.round(ori_shape[1] * w_scale.item()).astype(
                    np.int32)
            scale_factor = 1.0

        if not isinstance(scale_factor, (float, torch.Tensor)):
            scale_factor = bboxes.new_tensor(scale_factor)
        bboxes = bboxes / scale_factor

        if torch.onnx.is_in_onnx_export():
            # TODO: Remove after F.grid_sample is supported.
            from torchvision.models.detection.roi_heads \
                import paste_masks_in_image
            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
            thr = rcnn_test_cfg.get('mask_thr_binary', 0)
            if thr > 0:
                masks = masks >= thr
            return masks

        N = len(mask_pred)
        # The actual implementation splits the input into chunks
        # and pastes them chunk by chunk.
        if device.type == 'cpu':
            # CPU is most efficient when they are pasted one by one with
            # skip_empty=True, so that it performs minimal number of
            # operations.
            num_chunks = N
        else:
            # GPU benefits from parallelism for larger chunks,
            # but may run out of memory.
            num_chunks = int(
                np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
            assert (num_chunks <=
                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

        threshold = rcnn_test_cfg.mask_thr_binary
        im_mask = torch.zeros(
            N,
            img_h,
            img_w,
            device=device,
            dtype=torch.bool if threshold >= 0 else torch.uint8)

        if not self.class_agnostic:
            mask_pred = mask_pred[range(N), labels][:, None]

        for inds in chunks:
            masks_chunk, spatial_inds = _do_paste_mask(
                mask_pred[inds],
                bboxes[inds],
                img_h,
                img_w,
                skip_empty=device.type == 'cpu')

            if threshold >= 0:
                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            else:
                # for visualization and debugging
                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

            im_mask[(inds, ) + spatial_inds] = masks_chunk

        for i in range(N):
            cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
        return cls_segms
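One detail worth isolating from the example above is the class-specific selection mask_pred[range(N), labels][:, None]: for each of the N detections it keeps only the mask channel of that detection's predicted class and then restores a channel dimension. A small standalone illustration with made-up sizes:

    import torch

    N, num_classes, H, W = 4, 3, 28, 28   # hypothetical sizes
    mask_pred = torch.rand(N, num_classes, H, W)
    labels = torch.tensor([2, 0, 1, 2])   # predicted class of each detection

    selected = mask_pred[range(N), labels][:, None]
    print(selected.shape)                 # torch.Size([4, 1, 28, 28])
    assert torch.equal(selected[0, 0], mask_pred[0, labels[0]])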
Example #8
def post_process_proposals(proposals,
                           img_depths,
                           img_shape=(960, 1280),
                           K=None,
                           score_thresh=SCORE_THRESH,
                           RGB_only=False,
                           camera_mat=config.Unreal_camera_mat):
    res = []
    for p in proposals:
        # Initialize optional outputs so the result dict appended at the end of
        # the loop is always well-defined, including when RGB_only is True.
        keypoints_3d = center_voters = axis = voters = pafs = None
        scores = p['scores'].cpu()
        masks = p['masks'].cpu()
        boxes = p['boxes'].cpu()
        labels = p['labels'].cpu()
        if not RGB_only:
            keypoints_offset = None
            if 'keypoints_offset' in p.keys():
                keypoints_offset = p['keypoints_offset']
            axis_keypoints_offset = None
            if 'axis_keypoint_offsets' in p.keys():
                axis_keypoints_offset = p['axis_keypoint_offsets']
            axis_offset = None
            if 'axis_offsets' in p.keys():
                axis_offset = p['axis_offsets']

            pafs = None
            if 'pafs' in p.keys():
                pafs = p['pafs'].cpu()

            norm_vectors = None
            if 'norm_vector' in p.keys():
                norm_vectors = p['norm_vector']

        if K is not None:
            scores = scores[:K]
            masks = masks[:K]
            boxes = boxes[:K]
            labels = labels[:K]
            if not RGB_only:
                if keypoints_offset is not None:
                    keypoints_offset = keypoints_offset[:K]
                if axis_keypoints_offset is not None:
                    axis_keypoints_offset = axis_keypoints_offset[:K]
                if pafs is not None:
                    pafs = pafs[:K]

                if axis_offset is not None:
                    axis_offset = axis_offset[:K]
                if norm_vectors is not None:
                    norm_vectors = norm_vectors[:K]

        accepted = torch.nonzero(scores > score_thresh).squeeze()
        if accepted.nelement() == 0:
            res_dict = dict(scores=torch.empty(0),
                            masks=torch.empty(0, *img_shape),
                            boxes=torch.empty(0, 4),
                            labels=torch.empty(0))
            res.append(res_dict)
            continue
        if accepted.dim() == 0:
            accepted = accepted.unsqueeze(0)
        masks = masks[accepted]
        scores = scores[accepted]
        labels = labels[accepted]
        boxes = boxes[accepted]
        if not RGB_only:
            keypoints_3d = None
            center_voters = None
            if keypoints_offset is not None:
                keypoints_offset = keypoints_offset[accepted]
                img_depths = TF.to_tensor(img_depths / 10).type(
                    torch.float32).cuda()
                keypoints_3d, center_voters = vote_keypoint_inference(
                    img_depths, keypoints_offset.cuda(), boxes.cuda(),
                    masks.cuda(), camera_mat)
                keypoints_3d = keypoints_3d.cpu()

            if pafs is not None:
                pafs = pafs[accepted]
                pafs = pafs * masks
                pafs = pafs.sum((2, 3)) / masks.sum((2, 3))
                pafs = F.normalize(pafs, dim=1)

            axis = None
            voters = None
            if axis_keypoints_offset is not None:
                axis_keypoints_offset = axis_keypoints_offset[accepted]
                axis_keypoint1, voters1 = vote_keypoint_inference(
                    img_depths, axis_keypoints_offset[:, 0], boxes.cuda(),
                    masks.cuda(), camera_mat)
                axis_keypoint2, voters2 = vote_keypoint_inference(
                    img_depths, axis_keypoints_offset[:, 1], boxes.cuda(),
                    masks.cuda(), camera_mat)
                voters = [voters1, voters2]
                axis_keypoint = torch.stack((axis_keypoint1, axis_keypoint2),
                                            dim=1).cpu()
                axis = F.normalize(axis_keypoint[:, 1] - axis_keypoint[:, 0],
                                   dim=1)
            elif axis_offset is not None:
                axis_offset = axis_offset[accepted]
                axis, voters, ori_voters = vote_axis_inference(
                    img_depths, axis_offset.cuda(), boxes.cuda(), masks.cuda(),
                    labels)
            elif norm_vectors is not None:
                norm_vectors = norm_vectors[accepted]
                N = masks.shape[0]
                masks_for_vote = F.interpolate(masks,
                                               scale_factor=2,
                                               mode="bilinear",
                                               align_corners=False)
                _, index = torch.sort(masks_for_vote.view(N, -1), dim=-1)
                index = index[:, -1000:]
                estimate_vectors = []
                for i in range(N):
                    norm_v = norm_vectors[i].view(3, -1)[:, index[i]]
                    estimate_vectors.append(norm_v)
                axis = torch.stack(estimate_vectors, dim=0)
                axis = torch.mean(axis, dim=2)
                axis = F.normalize(axis, dim=1)
        masks = paste_masks_in_image(masks, boxes, img_shape, padding=0)

        res.append(
            dict(scores=scores,
                 masks=masks,
                 boxes=boxes,
                 labels=labels,
                 keypoints_3d=keypoints_3d,
                 axis=axis,
                 axis_voters=voters,
                 center_voters=center_voters,
                 paf_vectors=pafs))
    return res
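The final pasting step above writes each kept per-box mask into a full-resolution image with padding=0, i.e. without the default one-pixel border padding around each box. A minimal sketch of just that step, with hypothetical shapes matching the function's default img_shape of (960, 1280):

    import torch
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    img_shape = (960, 1280)              # default output size used above
    masks = torch.rand(5, 1, 28, 28)     # masks kept after score thresholding
    boxes = torch.rand(5, 4) * 200
    boxes[:, 2:] += boxes[:, :2]         # make the boxes well-formed (x1, y1, x2, y2)

    full_masks = paste_masks_in_image(masks, boxes, img_shape, padding=0)
    print(full_masks.shape)              # torch.Size([5, 1, 960, 1280])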