Ejemplo n.º 1
0
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Decode 3D boxes from RoIs and optionally apply multi-class NMS.

        Args:
            rois: RoI tensor. Column 0 is dropped (presumably the batch
                index -- TODO confirm) and the remaining columns are used
                as box coordinates.
            cls_score: Classification logits, a list of logits (which is
                averaged element-wise), or None.
            bbox_pred: Box deltas handed to ``delta2bbox3D``; when None the
                RoI coordinates themselves are used as the boxes.
            img_shape: Forwarded unchanged to ``delta2bbox3D``.
            scale_factor: Divisor applied to the boxes when ``rescale`` is
                true.
            rescale: Whether to divide the boxes by ``scale_factor``.
            cfg: When None the raw boxes and scores are returned. Otherwise
                ``cfg.score_thr``, ``cfg.nms`` and ``cfg.max_per_img`` are
                passed to ``multiclass_nms_3d``.

        Returns:
            ``(bboxes, scores)`` when ``cfg`` is None (``scores`` is None if
            ``cls_score`` is None), else ``(det_bboxes, det_labels)`` from
            ``multiclass_nms_3d``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox3D(rois[:, 1:], bbox_pred, self.target_means,
                                  self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:]
            # TODO: add clip here

        if rescale:
            # Divide out of place: when bbox_pred is None, `bboxes` is a view
            # of `rois`, and an in-place division would silently rescale the
            # caller's RoIs as well.
            bboxes = bboxes / scale_factor
            # TODO: we need to change this.... only scale x,y but not z

        if cfg is None:
            return bboxes, scores

        det_bboxes, det_labels = multiclass_nms_3d(bboxes, scores,
                                                   cfg.score_thr, cfg.nms,
                                                   cfg.max_per_img)
        return det_bboxes, det_labels
    def get_det_bboxes(self,
                       rois,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Decode 3D boxes from RoIs (variant without classification scores).

        Args:
            rois: RoI tensor. Column 0 is dropped (presumably the batch
                index -- TODO confirm) and the remaining columns are used
                as box coordinates.
            bbox_pred: Box deltas handed to ``delta2bbox3D``; when None the
                RoI coordinates themselves are used as the boxes.
            img_shape: Forwarded unchanged to ``delta2bbox3D``.
            scale_factor: Divisor applied to the boxes when ``rescale`` is
                true.
            rescale: Whether to divide the boxes by ``scale_factor``.
            cfg: When None the decoded boxes are returned directly.
                Otherwise ``multiclass_nms_3d`` is called (see the review
                note in the body).

        Returns:
            ``bboxes`` when ``cfg`` is None, else
            ``(det_bboxes, det_labels)`` from ``multiclass_nms_3d``.
        """
        if bbox_pred is not None:
            bboxes = delta2bbox3D(rois[:, 1:], bbox_pred, self.target_means,
                                  self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:]

        if rescale:
            # Divide out of place: when bbox_pred is None, `bboxes` is a view
            # of `rois`, and an in-place division would silently rescale the
            # caller's RoIs as well.
            bboxes = bboxes / scale_factor

        if cfg is None:
            return bboxes

        # NOTE(review): `scores` is not bound anywhere in this method, so this
        # branch raises NameError unless a module-level `scores` exists.
        # Decide where the scores should come from (a parameter or a column of
        # `rois`) before relying on cfg-driven NMS here.
        det_bboxes, det_labels = multiclass_nms_3d(bboxes, scores,
                                                   cfg.score_thr, cfg.nms,
                                                   cfg.max_per_img)
        return det_bboxes, det_labels
    def simple_test(self,
                    imgs,
                    img_metas,
                    imgs_2,
                    img_metas_2,
                    imgs_3,
                    img_metas_3,
                    pp=None,
                    pp_2=None,
                    proposals=None,
                    rescale=False,
                    test_cfg2=None,
                    test_cfg3=None):
        """Test without augmentation on three inputs and merge the detections.

        Every input goes through ``extract_feat``, then through its own RPN
        test routine and its own bbox test routine. When both
        ``refinement_head_2`` and ``refinement_head_3`` are truthy, the boxes
        of inputs 2 and 3 are refined against the features of input 1 before
        merging. The merged boxes and scores go through a single
        ``multiclass_nms_3d`` call.

        Args:
            imgs, imgs_2, imgs_3: Inputs for ``extract_feat``.
            img_metas, img_metas_2, img_metas_3: Meta containers; only
                element 0 of each is used.
            pp, pp_2: Not used by this method.
            proposals: If given, used in place of all three RPN outputs.
            rescale: Forwarded to the bbox and refinement test routines.
            test_cfg2, test_cfg3: Optional test config overrides;
                ``test_cfg3`` wins over ``test_cfg2``, which wins over
                ``self.test_cfg``.

        Returns:
            The output of ``bbox2result3D`` for the merged detections.
        """
        assert self.with_bbox, "Bbox head must be implemented."

        if test_cfg3 is None and test_cfg2 is None:
            test_cfg = self.test_cfg
        elif test_cfg3 is not None:
            test_cfg = test_cfg3
        else:
            test_cfg = test_cfg2

        img_metas = img_metas[0]
        img_metas_2 = img_metas_2[0]
        img_metas_3 = img_metas_3[0]

        # Default FPN features, one set per input.
        feats = self.extract_feat(imgs)
        feats_2 = self.extract_feat(imgs_2)
        feats_3 = self.extract_feat(imgs_3)

        # Input 1: proposals, then boxes; drop the proposals right away.
        if proposals is None:
            rpn_out = self.simple_test_rpn(feats, img_metas, test_cfg.rpn)
        else:
            rpn_out = proposals
        bboxes, scores = self.simple_test_bboxes(
            feats, img_metas, rpn_out, None, rescale=rescale)
        del rpn_out

        # Input 2.
        if proposals is None:
            rpn_out = self.simple_test_rpn_2(feats_2, img_metas_2,
                                             test_cfg.rpn)
        else:
            rpn_out = proposals
        bboxes_2, scores_2 = self.simple_test_bboxes_2(
            feats_2, img_metas_2, rpn_out, None, rescale=rescale)
        del rpn_out

        # Input 3.
        if proposals is None:
            rpn_out = self.simple_test_rpn_3(feats_3, img_metas_3,
                                             test_cfg.rpn)
        else:
            rpn_out = proposals
        bboxes_3, scores_3 = self.simple_test_bboxes_3(
            feats_3, img_metas_3, rpn_out, None, rescale=rescale)
        del rpn_out

        use_refinement = bool(self.refinement_head_2
                              and self.refinement_head_3)
        if use_refinement:
            # Columns 6+ of the boxes plus score column 1 form the proposals
            # handed to the refinement routines (which see input-1 features).
            refined_2 = self.simple_test_bbox_refinement_2(
                feats,
                img_metas,
                [torch.cat((bboxes_2[:, 6:], scores_2[:, 1, None]), dim=1)],
                None,
                rescale=rescale)
            refined_3 = self.simple_test_bbox_refinement_3(
                feats,
                img_metas,
                [torch.cat((bboxes_3[:, 6:], scores_3[:, 1, None]), dim=1)],
                None,
                rescale=rescale)
            merged_bboxes = torch.cat((bboxes, refined_2, refined_3), 0)
        else:
            merged_bboxes = torch.cat((bboxes, bboxes_2, bboxes_3), 0)
        merged_scores = torch.cat((scores, scores_2, scores_3), 0)

        det_bboxes, det_labels = multiclass_nms_3d(
            merged_bboxes, merged_scores, test_cfg.rcnn.score_thr,
            test_cfg.rcnn.nms, test_cfg.rcnn.max_per_img)
        bbox_results = bbox2result3D(det_bboxes, det_labels,
                                     self.bbox_head.num_classes)

        # Drop every local tensor reference before emptying the CUDA cache,
        # otherwise inference would run out of memory.
        del feats, feats_2, feats_3
        del bboxes, bboxes_2, bboxes_3
        del scores, scores_2, scores_3
        del merged_bboxes, merged_scores
        del det_bboxes, det_labels
        if use_refinement:
            del refined_2, refined_3
        torch.cuda.empty_cache()

        # Only the bbox results are returned.
        return bbox_results
    def simple_test(self,
                    imgs,
                    img_metas,
                    imgs_2,
                    img_metas_2,
                    imgs_3,
                    img_metas_3,
                    pp=None,
                    pp_2=None,
                    proposals=None,
                    rescale=False,
                    test_cfg2=None,
                    test_cfg3=None):
        """Test without augmentation on three inputs, always refining 2 and 3.

        All three inputs go through ``extract_feat``, their own RPN test
        routine, and the shared ``simple_test_bboxes``. The boxes of inputs
        2 and 3 are then passed through ``simple_test_bbox_refinement``
        together with the features of input 1, merged with the boxes of
        input 1, and filtered by ``multiclass_nms_3d``.

        Args:
            imgs, imgs_2, imgs_3: Inputs for ``extract_feat``.
            img_metas, img_metas_2, img_metas_3: Meta containers; only
                element 0 of each is used.
            pp, pp_2: Not used by this method.
            proposals: If given, used in place of all three RPN outputs.
            rescale: Forwarded to the bbox and refinement test routines.
            test_cfg2, test_cfg3: Optional test config overrides;
                ``test_cfg3`` wins over ``test_cfg2``, which wins over
                ``self.test_cfg``.

        Returns:
            The output of ``bbox2result3D`` for the merged detections.
        """
        assert self.with_bbox, "Bbox head must be implemented."

        if test_cfg3 is None and test_cfg2 is None:
            test_cfg = self.test_cfg
        elif test_cfg3 is not None:
            test_cfg = test_cfg3
        else:
            test_cfg = test_cfg2

        img_metas = img_metas[0]
        img_metas_2 = img_metas_2[0]
        img_metas_3 = img_metas_3[0]

        # Default FPN features, one set per input.
        feats = self.extract_feat(imgs)
        feats_2 = self.extract_feat(imgs_2)
        feats_3 = self.extract_feat(imgs_3)

        # Proposals for all three inputs are produced before any bbox head.
        if proposals is None:
            rpn_out = self.simple_test_rpn(feats, img_metas, test_cfg.rpn)
        else:
            rpn_out = proposals
        if proposals is None:
            rpn_out_2 = self.simple_test_rpn_2(feats_2, img_metas_2,
                                               test_cfg.rpn)
        else:
            rpn_out_2 = proposals
        if proposals is None:
            rpn_out_3 = self.simple_test_rpn_3(feats_3, img_metas_3,
                                               test_cfg.rpn)
        else:
            rpn_out_3 = proposals

        bboxes, scores = self.simple_test_bboxes(
            feats, img_metas, rpn_out, None, rescale=rescale)
        bboxes_2, scores_2 = self.simple_test_bboxes(
            feats_2, img_metas_2, rpn_out_2, None, rescale=rescale)
        bboxes_3, scores_3 = self.simple_test_bboxes(
            feats_3, img_metas_3, rpn_out_3, None, rescale=rescale)

        # Refinement head: columns 6+ of the boxes plus score column 1 are
        # the proposals, evaluated against the features of input 1.
        refine_in_2 = [
            torch.cat((bboxes_2[:, 6:], scores_2[:, 1, None]), dim=1)
        ]
        refined_2 = self.simple_test_bbox_refinement(
            feats, img_metas, refine_in_2, None, rescale=rescale)

        refine_in_3 = [
            torch.cat((bboxes_3[:, 6:], scores_3[:, 1, None]), dim=1)
        ]
        refined_3 = self.simple_test_bbox_refinement(
            feats, img_metas, refine_in_3, None, rescale=rescale)

        # Merge the boxes of input 1 with the refined boxes of inputs 2/3.
        merged_bboxes = torch.cat((bboxes, refined_2, refined_3), 0)
        merged_scores = torch.cat((scores, scores_2, scores_3), 0)
        det_bboxes, det_labels = multiclass_nms_3d(
            merged_bboxes, merged_scores, test_cfg.rcnn.score_thr,
            test_cfg.rcnn.nms, test_cfg.rcnn.max_per_img)

        # Only the bbox results are returned.
        return bbox2result3D(det_bboxes, det_labels,
                             self.bbox_head.num_classes)
Ejemplo n.º 5
0
    def simple_test(self,
                    imgs,
                    img_metas,
                    imgs_2,
                    img_metas_2,
                    proposals=None,
                    rescale=False,
                    test_cfg2=None):
        """Test without augmentation on two inputs, optionally with masks.

        Both inputs go through ``extract_feat``, their own RPN test routine
        and ``simple_test_bboxes``. If ``self.refinement_head`` is truthy the
        boxes of input 2 are refined against the features of input 1 before
        being merged with the boxes of input 1. The merged set is filtered
        with ``multiclass_nms_3d``.

        Args:
            imgs, imgs_2: Inputs for ``extract_feat``.
            img_metas, img_metas_2: Meta containers; only element 0 is used.
            proposals: If given, used in place of both RPN outputs.
            rescale: Forwarded to the bbox, refinement and mask routines.
            test_cfg2: Optional replacement for ``self.test_cfg``.

        Returns:
            ``bbox_results`` alone when ``test_cfg.return_bbox_only`` is
            truthy, otherwise ``(bbox_results, segm_results)``.
        """
        assert self.with_bbox, "Bbox head must be implemented."
        if test_cfg2 is None:
            test_cfg = self.test_cfg
        else:
            test_cfg = test_cfg2

        img_metas = img_metas[0]
        img_metas_2 = img_metas_2[0]
        feats = self.extract_feat(imgs)
        feats_2 = self.extract_feat(imgs_2)

        if proposals is None:
            rpn_out = self.simple_test_rpn(feats, img_metas, test_cfg.rpn)
        else:
            rpn_out = proposals
        if proposals is None:
            rpn_out_2 = self.simple_test_rpn_2(feats_2, img_metas_2,
                                               test_cfg.rpn)
        else:
            rpn_out_2 = proposals

        bboxes, scores = self.simple_test_bboxes(
            feats, img_metas, rpn_out, None, rescale=rescale)
        bboxes_2, scores_2 = self.simple_test_bboxes(
            feats_2, img_metas_2, rpn_out_2, None, rescale=rescale)

        if self.refinement_head:
            # Columns 6+ of the boxes plus score column 1 are the proposals
            # for the refinement head, evaluated on the features of input 1.
            refine_in = [
                torch.cat((bboxes_2[:, 6:], scores_2[:, 1, None]), dim=1)
            ]
            second_boxes = self.simple_test_bbox_refinement(
                feats, img_metas, refine_in, None, rescale=rescale)
        else:
            second_boxes = bboxes_2
        # Input-1 rows come first, so their count marks the boundary between
        # the two sources inside the merged tensor.
        bboxes_combined = torch.cat((bboxes, second_boxes), 0)
        scores_combined = torch.cat((scores, scores_2), 0)

        det_bboxes, det_labels = multiclass_nms_3d(
            bboxes_combined, scores_combined, test_cfg.rcnn.score_thr,
            test_cfg.rcnn.nms, test_cfg.rcnn.max_per_img)
        bbox_results = bbox2result3D(det_bboxes, det_labels,
                                     self.bbox_head.num_classes)

        if test_cfg.return_bbox_only:
            return bbox_results

        if not self.refinement_mask_head:
            segm_results = self.simple_test_mask(
                feats, img_metas, det_bboxes, det_labels, rescale=rescale)
            return bbox_results, segm_results

        def _to_cuda(rows):
            # Stack collected numpy rows and move them back to the GPU.
            return torch.from_numpy(np.array(rows)).cuda()

        # Work out which source each surviving detection came from by
        # matching its first six values against columns 6+ of every candidate.
        det_bboxes_np = det_bboxes.cpu().numpy()
        det_labels_np = det_labels.cpu().numpy()
        candidates_np = bboxes_combined.cpu().numpy()
        first_upscaled_index = len(bboxes)
        nonscaled_bboxes, nonscaled_labels = [], []
        upscaled_bboxes, upscaled_labels = [], []
        for det_bbox, det_label in zip(det_bboxes_np, det_labels_np):
            for index, candidate in enumerate(candidates_np):
                if not np.all(det_bbox[:6] == candidate[6:]):
                    continue
                if index < first_upscaled_index:
                    # Came from input 1 (original scale).
                    nonscaled_bboxes.append(det_bbox)
                    nonscaled_labels.append(det_label)
                else:
                    # Came from input 2 (upscaled).
                    upscaled_bboxes.append(det_bbox)
                    upscaled_labels.append(det_label)

        nonscaled_bboxes_gpu = _to_cuda(nonscaled_bboxes)
        nonscaled_labels_gpu = _to_cuda(nonscaled_labels)
        upscaled_bboxes_gpu = _to_cuda(upscaled_bboxes)
        upscaled_labels_gpu = _to_cuda(upscaled_labels)

        segm_results_nonscaled = self.simple_test_mask(
            feats,
            img_metas,
            nonscaled_bboxes_gpu,
            nonscaled_labels_gpu,
            rescale=rescale)
        segm_results_refinement = self.simple_test_mask_refinement_v2(
            feats,
            img_metas,
            upscaled_bboxes_gpu,
            upscaled_labels_gpu,
            rescale=rescale)

        # Rebuild the detections in "input 1 first, input 2 second" order so
        # they line up with the concatenated mask results below.
        if len(nonscaled_bboxes_gpu) == 0:
            ordered_bboxes = upscaled_bboxes_gpu
            ordered_labels = upscaled_labels_gpu
        elif len(upscaled_bboxes_gpu) == 0:
            ordered_bboxes = nonscaled_bboxes_gpu
            ordered_labels = nonscaled_labels_gpu
        else:
            ordered_bboxes = torch.cat(
                (nonscaled_bboxes_gpu, upscaled_bboxes_gpu), 0)
            ordered_labels = torch.cat(
                (nonscaled_labels_gpu, upscaled_labels_gpu), 0)
        bbox_results = bbox2result3D(ordered_bboxes, ordered_labels,
                                     self.bbox_head.num_classes)

        # Append the refinement masks so that element 0 of
        # segm_results_nonscaled holds the masks of both sources.
        merged_masks = segm_results_nonscaled[0]
        for refined_mask in segm_results_refinement[0]:
            merged_masks.append(refined_mask)

        return bbox_results, segm_results_nonscaled
Ejemplo n.º 6
0
    def simple_test(self,
                    imgs,
                    img_metas,
                    imgs_2,
                    img_metas_2,
                    pp=None,
                    pp_2=None,
                    proposals=None,
                    rescale=False,
                    test_cfg2=None):
        """Test without augmentation on two inputs and return merged boxes.

        Both inputs go through ``extract_feat``, then through their own RPN
        and bbox test routines. The resulting boxes and scores are
        concatenated and filtered with one ``multiclass_nms_3d`` call.

        Only bounding-box results are produced. The mask-branch code that
        used to follow the ``return`` statement could never execute and has
        been removed.

        Args:
            imgs, imgs_2: Inputs for ``extract_feat``.
            img_metas, img_metas_2: Meta containers; only element 0 is used.
            pp: Optional extra proposals for input 1; ``pp[0]`` is
                concatenated to element 0 of the proposal list.
            pp_2: Same as ``pp`` but for input 2.
            proposals: If given, used in place of both RPN outputs.
            rescale: Forwarded to the bbox test routines.
            test_cfg2: Optional replacement for ``self.test_cfg``.

        Returns:
            The output of ``bbox2result3D`` for the merged detections.
        """
        assert self.with_bbox, "Bbox head must be implemented."

        if test_cfg2 is not None:
            test_cfg = test_cfg2
        else:
            test_cfg = self.test_cfg

        img_metas = img_metas[0]
        img_metas_2 = img_metas_2[0]

        # Default FPN
        x = self.extract_feat(imgs)
        x_2 = self.extract_feat(imgs_2)

        # dataset 1
        proposal_list = self.simple_test_rpn(
            x, img_metas, test_cfg.rpn) if proposals is None else proposals
        # use unsupervised learning's proposals
        if pp is not None:
            proposal_list = [torch.cat((proposal_list[0], pp[0]), 0)]
        bboxes, scores = self.simple_test_bboxes(x,
                                                 img_metas,
                                                 proposal_list,
                                                 None,
                                                 rescale=rescale)

        # dataset 2
        proposal_list = self.simple_test_rpn_2(
            x_2, img_metas_2, test_cfg.rpn) if proposals is None else proposals
        # use unsupervised learning's proposals
        if pp_2 is not None:
            proposal_list = [torch.cat((proposal_list[0], pp_2[0]), 0)]
        bboxes_2, scores_2 = self.simple_test_bboxes_2(x_2,
                                                       img_metas_2,
                                                       proposal_list,
                                                       None,
                                                       rescale=rescale)

        # Merge both sources and suppress duplicates in a single NMS pass.
        bboxes = torch.cat((bboxes, bboxes_2), 0)
        scores = torch.cat((scores, scores_2), 0)
        det_bboxes, det_labels = multiclass_nms_3d(bboxes, scores,
                                                   test_cfg.rcnn.score_thr,
                                                   test_cfg.rcnn.nms,
                                                   test_cfg.rcnn.max_per_img)

        bbox_results = bbox2result3D(det_bboxes, det_labels,
                                     self.bbox_head.num_classes)
        return bbox_results
        '''