Example #1
    def simple_test_mask_refinement_v2(self,
                                       x,
                                       img_meta,
                                       det_bboxes,
                                       det_labels,
                                       rescale=False):
        # image shape of the first image in the batch (only one)
        scale_factor = img_meta[0]['scale_factor']
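        # Reconstruct the original volume shape: keep the original in-plane
        # size and map the depth from the (resized) img_shape back through the
        # scale factor; img_shape[3] is assumed to index the depth axis here.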
        ori_shape = (*img_meta[0]['ori_shape'][:2],
                     int(img_meta[0]['img_shape'][3] / scale_factor))
        if det_bboxes.shape[0] == 0:
            segm_result = [
                [] for _ in range(self.refinement_mask_head.num_classes - 1)
            ]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.

            # Original code, removed because in the two-heads implementation
            # the bboxes should not be scaled:
            # _bboxes = (det_bboxes[:, :6] * scale_factor if rescale else det_bboxes)
            _bboxes = det_bboxes[:, :6]

            mask_rois_refinement = bbox2roi3D([_bboxes])
            mask_feats = self.refinement_mask_roi_extractor(
                x[:len(self.refinement_mask_roi_extractor.featmap_strides)],
                mask_rois_refinement)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
            mask_pred = self.refinement_mask_head(mask_feats)
            segm_result = self.refinement_mask_head.get_seg_masks(
                mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
                scale_factor, rescale)
        return segm_result
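
Every example here funnels boxes through bbox2roi3D. The real helper lives elsewhere in this project; a minimal sketch of the behavior these call sites assume (a 3D analog of mmdetection's bbox2roi, illustrative only):

import torch

def bbox2roi3D(bbox_list):
    """Convert a list of per-image (n, 6+) boxes into (K, 7) RoIs.

    Each RoI is [batch_idx, x1, y1, z1, x2, y2, z2]; the prepended batch
    index lets a single RoI extraction call serve the whole batch.
    """
    rois_list = []
    for img_id, bboxes in enumerate(bbox_list):
        if bboxes.size(0) > 0:
            img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
            rois = torch.cat([img_inds, bboxes[:, :6]], dim=-1)
        else:
            rois = bboxes.new_zeros((0, 7))
        rois_list.append(rois)
    return torch.cat(rois_list, 0)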
Example #2
    def simple_test_bbox_refinement_3(self,
                                      x,
                                      img_meta,
                                      proposals,
                                      rcnn_test_cfg,
                                      rescale=False):
        """Test only det bboxes without augmentation."""
        rois = bbox2roi3D(proposals)
        roi_feats = self.bbox_roi_extractor_refinement_3(
            x[:len(self.bbox_roi_extractor_refinement_3.featmap_strides)],
            rois)
        if self.with_shared_head:
            roi_feats = self.shared_head(roi_feats)
        # regression only
        bbox_pred = self.refinement_head_3(roi_feats)
        # class-and-regression variant:
        # _, bbox_pred = self.refinement_head(roi_feats)

        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes = self.refinement_head_3.get_det_bboxes(rois,
                                                           bbox_pred,
                                                           img_shape,
                                                           scale_factor,
                                                           rescale=rescale,
                                                           cfg=rcnn_test_cfg)
        return det_bboxes
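
How this refinement pass gets invoked is not shown here; a plausible, purely illustrative call, assuming first-stage detections are already available as per-image proposal tensors:

# Hypothetical chaining; `detector` defines the method above and
# `det_bboxes` holds first-stage 3D detections (6 coords per box).
proposals = [det_bboxes[:, :6]]
refined_bboxes = detector.simple_test_bbox_refinement_3(
    x, img_meta, proposals, detector.test_cfg.rcnn, rescale=True)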
Example #3
    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation."""
        rois = bbox2roi3D(proposals)
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        if self.with_shared_head:
            roi_feats = self.shared_head(roi_feats)
        cls_score, bbox_pred, parcellation_score = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels, det_parcellations = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            parcellation_score,
            img_shape,
            scale_factor,
            rescale=rescale,
            cfg=rcnn_test_cfg)
        return det_bboxes, det_labels, det_parcellations
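
Unlike the plain heads in the other examples, this bbox_head returns a third, parcellation score. A toy stand-in for such a three-output head, assuming a flattened RoI feature vector (illustrative only, not the project's actual head):

import torch.nn as nn

class BBoxParcellationHead(nn.Module):
    """Toy stand-in: classification, 3D box regression, parcellation scores."""

    def __init__(self, in_feats, num_classes, num_parcellations):
        super().__init__()
        self.fc_cls = nn.Linear(in_feats, num_classes)
        self.fc_reg = nn.Linear(in_feats, 6 * num_classes)  # 3D deltas
        self.fc_parc = nn.Linear(in_feats, num_parcellations)

    def forward(self, x):
        x = x.flatten(1)
        return self.fc_cls(x), self.fc_reg(x), self.fc_parc(x)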
Example #4
    def forward_train(self,
                      imgs,
                      img_meta,
                      gt_bboxes,
                      gt_labels,
                      gt_bregions,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None):
        # self.print_iterations()
        assert imgs.shape[1] == 3 # make sure channel size is 3
        x = self.extract_feat(imgs)
        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs,
                gt_bboxes_ignore=gt_bboxes_ignore,
                iteration=self.iteration)
            losses.update(rpn_losses)

            proposal_inputs = rpn_outs + (img_meta, self.train_cfg.rpn_proposal)
            proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
            # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
            # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
            # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
            # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
            # self.visualize_gt_bboxes_masks(imgs, gt_bboxes, img_meta, gt_masks) # debug only
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(
                self.train_cfg.rcnn.sampler, context=self)
            num_imgs = imgs.size(0)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []

            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]
                gt_bregions_cur_pat = gt_bregions[i]

                assign_result = bbox_assigner.assign(
                    proposal_list[i], gt_bboxes_cur_pat,
                    gt_bboxes_ignore_cur_pat, gt_labels_cur_pat,
                    gt_bregions_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    gt_bregions_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi3D([res.bboxes for res in sampling_results])
            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred, parcellation_score = self.bbox_head(bbox_feats)
            bbox_targets = self.bbox_head.get_target(
                sampling_results, gt_bboxes, gt_labels, parcellation_score,
                self.train_cfg.rcnn)
            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, parcellation_score,
                                            *bbox_targets)
            losses.update(loss_bbox)

        # mask head forward and loss
        if self.with_mask:
            if not self.share_roi_extractor:
                pos_rois = bbox2roi3D(
                    [res.pos_bboxes for res in sampling_results])
                mask_feats = self.mask_roi_extractor(
                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
            else:
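                # Shared RoI extractor: reuse bbox_feats and pick out the
                # positive samples with a boolean mask instead of running a
                # second feature extraction on the positive RoIs.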
                pos_inds = []
                device = bbox_feats.device
                for res in sampling_results:
                    pos_inds.append(
                        torch.ones(
                            res.pos_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                    pos_inds.append(
                        torch.zeros(
                            res.neg_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                pos_inds = torch.cat(pos_inds)
                mask_feats = bbox_feats[pos_inds]
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(
                sampling_results, gt_masks, self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)
            losses.update(loss_mask)

        # self.save_losses_plt(losses) #debug only...
        self.iteration += 1
        return losses
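
A minimal sketch of a single training step driving this forward_train, assuming the data loader yields the keyword arguments used above (all surrounding names are illustrative):

# Hypothetical step; `detector` defines forward_train above and `batch`
# mimics the loader's output for one iteration.
losses = detector.forward_train(
    batch['imgs'], batch['img_meta'], batch['gt_bboxes'],
    batch['gt_labels'], batch['gt_bregions'], gt_masks=batch['gt_masks'])
# Each value may be a tensor or a list of per-level tensors.
total = sum(sum(v) if isinstance(v, list) else v for v in losses.values())
total.backward()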
Example #5
    def forward_train(self,
                      imgs,
                      img_meta,
                      imgs_2,
                      img_meta_2,
                      imgs_3,
                      img_meta_3,
                      gt_bboxes,
                      gt_bboxes_2,
                      gt_bboxes_3,
                      gt_labels,
                      gt_labels_2,
                      gt_labels_3,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_masks_2=None,
                      gt_masks_3=None,
                      pp=None,
                      pp_2=None,
                      proposals=None):
        # self.print_iterations()
        assert imgs.shape[1] == 3 and imgs_2.shape[1] == 3 \
            and imgs_3.shape[1] == 3  # make sure channel size is 3

        # Default FPN
        x = self.extract_feat(imgs)
        x_2 = self.extract_feat(imgs_2)
        x_3 = self.extract_feat(imgs_3)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_outs_2 = self.rpn_head_2(x_2)
            rpn_outs_3 = self.rpn_head_3(x_3)

            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                              self.train_cfg.rpn)
            rpn_loss_inputs_3 = rpn_outs_3 + (gt_bboxes_3, img_meta,
                                              self.train_cfg.rpn)

            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                            gt_bboxes_ignore=gt_bboxes_ignore,
                                            iteration=self.iteration)
            rpn_losses_2 = self.rpn_head_2.loss(
                *rpn_loss_inputs_2,
                gt_bboxes_ignore=gt_bboxes_ignore,
                iteration=self.iteration,
                img_meta_2=img_meta_2)
            rpn_losses_3 = self.rpn_head_3.loss(
                *rpn_loss_inputs_3,
                gt_bboxes_ignore=gt_bboxes_ignore,
                iteration=self.iteration,
                img_meta_3=img_meta_3)

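            # The suffixed loss dicts are assumed to use distinct keys
            # (e.g. a '_2'/'_3' suffix); plain dict.update would otherwise
            # overwrite the base-scale entries.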
            losses.update(rpn_losses)
            losses.update(rpn_losses_2)
            losses.update(rpn_losses_3)

            proposal_inputs = rpn_outs + (img_meta,
                                          self.train_cfg.rpn_proposal)
            proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                              self.train_cfg.rpn_proposal)
            proposal_inputs_3 = rpn_outs_3 + (img_meta,
                                              self.train_cfg.rpn_proposal)

            proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
            proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
                *proposal_inputs_2, img_meta_2=img_meta_2)
            proposal_list_3, anchors_3 = self.rpn_head_3.get_bboxes(
                *proposal_inputs_3, img_meta_3=img_meta_3)

            # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
            # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
            # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
            # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
            # breakpoint()
            # self.visualize_gt_bboxes(imgs_2, gt_bboxes_2, img_meta_2) #debug only
            # breakpoint()
            # self.visualize_gt_bboxes(imgs_3, gt_bboxes_3, img_meta_3) #debug only
            # breakpoint()
            # self.visualize_gt_bboxes_masks(imgs_2, gt_bboxes_2, img_meta_2, gt_masks) # debug only
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []

            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

            bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                           context=self)
            num_imgs_2 = imgs_2.size(0)
            gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
            sampling_results_2 = []

            for i in range(num_imgs_2):
                gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
                gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
                gt_labels_cur_pat_2 = gt_labels_2[i]

                assign_result_2 = bbox_assigner_2.assign(
                    proposal_list_2[i], gt_bboxes_cur_pat_2,
                    gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
                sampling_result_2 = bbox_sampler_2.sample(
                    assign_result_2,
                    proposal_list_2[i],
                    gt_bboxes_cur_pat_2,
                    gt_labels_cur_pat_2,
                    feats=[lvl_feat[i][None] for lvl_feat in x_2])
                sampling_results_2.append(sampling_result_2)

            bbox_assigner_3 = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler_3 = build_sampler(self.train_cfg.rcnn.sampler,
                                           context=self)
            num_imgs_3 = imgs_3.size(0)
            gt_bboxes_ignore_3 = [None for _ in range(num_imgs_3)]
            sampling_results_3 = []

            for i in range(num_imgs_3):
                gt_bboxes_cur_pat_3 = gt_bboxes_3[i]
                gt_bboxes_ignore_cur_pat_3 = gt_bboxes_ignore_3[i]
                gt_labels_cur_pat_3 = gt_labels_3[i]

                assign_result_3 = bbox_assigner_3.assign(
                    proposal_list_3[i], gt_bboxes_cur_pat_3,
                    gt_bboxes_ignore_cur_pat_3, gt_labels_cur_pat_3)
                sampling_result_3 = bbox_sampler_3.sample(
                    assign_result_3,
                    proposal_list_3[i],
                    gt_bboxes_cur_pat_3,
                    gt_labels_cur_pat_3,
                    feats=[lvl_feat[i][None] for lvl_feat in x_3])
                sampling_results_3.append(sampling_result_3)

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi3D([res.bboxes for res in sampling_results])
            rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])
            rois_3 = bbox2roi3D([res.bboxes for res in sampling_results_3])

            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            bbox_feats_2 = self.bbox_roi_extractor_2(
                x_2[:self.bbox_roi_extractor_2.num_inputs], rois_2)
            bbox_feats_3 = self.bbox_roi_extractor_3(
                x_3[:self.bbox_roi_extractor_3.num_inputs], rois_3)

            cls_score, bbox_pred = self.bbox_head(bbox_feats)
            cls_score_2, bbox_pred_2 = self.bbox_head_2(bbox_feats_2)
            cls_score_3, bbox_pred_3 = self.bbox_head_3(bbox_feats_3)

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            bbox_targets_2 = self.bbox_head_2.get_target(
                sampling_results_2, gt_bboxes_2, gt_labels_2,
                self.train_cfg.rcnn)
            bbox_targets_3 = self.bbox_head_3.get_target(
                sampling_results_3, gt_bboxes_3, gt_labels_3,
                self.train_cfg.rcnn)

            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets)
            loss_bbox_2 = self.bbox_head_2.loss(cls_score_2,
                                                bbox_pred_2,
                                                *bbox_targets_2,
                                                img_meta_2=img_meta_2)
            loss_bbox_3 = self.bbox_head_3.loss(cls_score_3,
                                                bbox_pred_3,
                                                *bbox_targets_3,
                                                img_meta_3=img_meta_3)

            losses.update(loss_bbox)
            losses.update(loss_bbox_2)
            losses.update(loss_bbox_3)

        if self.refinement_head_2:
            # prepare upscaled data for refinement head
            upscaled_factor_2 = (img_meta_2[0]['ori_shape'][0] /
                                 img_meta[0]['ori_shape'][0])
            # convert parameterized adjustments to actual bounding box coordinates
            pred_bboxes_2 = self.bbox_head_2.convert_adjustments_to_bboxes(
                rois_2, bbox_pred_2, img_meta_2[0]['img_shape']
            )[:, 6:].cpu().detach().numpy() / upscaled_factor_2
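            # Dividing by upscaled_factor_2 maps boxes predicted on the
            # higher-resolution input back onto the base-resolution grid of x,
            # from which the refinement RoI features are extracted below.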

            pred_cls_score_2 = cls_score_2[:, 1, None].cpu().detach().numpy()
            pred_bboxes_2 = np.concatenate((pred_bboxes_2, pred_cls_score_2),
                                           axis=1)
            pred_bboxes_2 = [torch.from_numpy(pred_bboxes_2).cuda()]
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results_refinement_2 = []
            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(pred_bboxes_2[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    pred_bboxes_2[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results_refinement_2.append(sampling_result)
            rois_refinement_2 = bbox2roi3D(
                [res.bboxes for res in sampling_results_refinement_2])
            bbox_feats_refinement_2 = self.bbox_roi_extractor_refinement_2(
                x[:self.bbox_roi_extractor_refinement_2.num_inputs],
                rois_refinement_2)
            # training refinement head
            refined_bbox_pred_2 = self.refinement_head_2(
                bbox_feats_refinement_2)
            bbox_targets_refinement_2 = self.refinement_head_2.get_target(
                sampling_results_refinement_2, gt_bboxes, gt_labels,
                self.train_cfg.rcnn)
            loss_refinement_2 = self.refinement_head_2.loss(
                refined_bbox_pred_2,
                *bbox_targets_refinement_2,
                img_meta_2=img_meta_2)
            losses.update(loss_refinement_2)

        if self.refinement_head_3:
            # prepare upscaled data for refinement head
            upscaled_factor_3 = (img_meta_3[0]['ori_shape'][0] /
                                 img_meta[0]['ori_shape'][0])
            # convert parameterized adjustments to actual bounding box coordinates
            pred_bboxes_3 = self.bbox_head_3.convert_adjustments_to_bboxes(
                rois_3, bbox_pred_3, img_meta_3[0]['img_shape']
            )[:, 6:].cpu().detach().numpy() / upscaled_factor_3

            pred_cls_score_3 = cls_score_3[:, 1, None].cpu().detach().numpy()
            pred_bboxes_3 = np.concatenate((pred_bboxes_3, pred_cls_score_3),
                                           axis=1)
            pred_bboxes_3 = [torch.from_numpy(pred_bboxes_3).cuda()]
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results_refinement_3 = []
            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(pred_bboxes_3[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    pred_bboxes_3[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results_refinement_3.append(sampling_result)
            rois_refinement_3 = bbox2roi3D(
                [res.bboxes for res in sampling_results_refinement_3])
            bbox_feats_refinement_3 = self.bbox_roi_extractor_refinement_3(
                x[:self.bbox_roi_extractor_refinement_3.num_inputs],
                rois_refinement_3)
            # training refinement head
            refined_bbox_pred_3 = self.refinement_head_3(
                bbox_feats_refinement_3)
            bbox_targets_refinement_3 = self.refinement_head_3.get_target(
                sampling_results_refinement_3, gt_bboxes, gt_labels,
                self.train_cfg.rcnn)
            loss_refinement_3 = self.refinement_head_3.loss(
                refined_bbox_pred_3,
                *bbox_targets_refinement_3,
                img_meta_3=img_meta_3)
            losses.update(loss_refinement_3)

        # mask head forward and loss
        if self.with_mask:
            # lower resolution mask head
            pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(sampling_results,
                                                     gt_masks,
                                                     self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)
            losses.update(loss_mask)

            # refinement head mask head 2
            pos_rois = bbox2roi3D(
                [res.pos_bboxes for res in sampling_results_refinement_2])
            mask_feats = self.mask_roi_extractor_refinement_2(
                x[:self.mask_roi_extractor_refinement_2.num_inputs], pos_rois)
            mask_pred = self.mask_head_refinement_2(mask_feats)
            mask_targets = self.mask_head_refinement_2.get_target(
                sampling_results_refinement_2, gt_masks, self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results_refinement_2])
            loss_mask_refinement_2 = self.mask_head_refinement_2.loss(
                mask_pred, mask_targets, pos_labels, img_meta_2=img_meta_2)
            losses.update(loss_mask_refinement_2)

            # refinement head mask head 3
            pos_rois = bbox2roi3D(
                [res.pos_bboxes for res in sampling_results_refinement_3])
            mask_feats = self.mask_roi_extractor_refinement_3(
                x[:self.mask_roi_extractor_refinement_3.num_inputs], pos_rois)
            mask_pred = self.mask_head_refinement_3(mask_feats)
            mask_targets = self.mask_head_refinement_3.get_target(
                sampling_results_refinement_3, gt_masks, self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results_refinement_3])
            loss_mask_refinement_3 = self.mask_head_refinement_3.loss(
                mask_pred, mask_targets, pos_labels, img_meta_3=img_meta_3)
            losses.update(loss_mask_refinement_3)

        # self.save_losses_plt(losses) #debug only...
        self.iteration += 1
        return losses
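
The refinement branches hinge on the simple scale mapping computed above; a worked mini-example with illustrative values:

# If the base volume is 128 voxels along axis 0 and the second input is
# 256, the factor is 2.0, and a box edge predicted at 100 on the
# high-resolution input lands at 50 on the base-resolution grid.
ori_dim, ori_dim_2 = 128, 256   # img_meta[0] / img_meta_2[0] 'ori_shape'[0]
upscaled_factor_2 = ori_dim_2 / ori_dim
assert 100.0 / upscaled_factor_2 == 50.0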
Example #6
    def forward_train(self,
                      imgs,
                      img_meta,
                      imgs_2,
                      img_meta_2,
                      gt_bboxes,
                      gt_bboxes_2,
                      gt_labels,
                      gt_labels_2,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_masks_2=None,
                      proposals=None):
        assert imgs.shape[1] == 3 and imgs_2.shape[1] == 3  # make sure channel size is 3
        x = self.extract_feat(imgs)
        x_2 = self.extract_feat(imgs_2)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_outs_2 = self.rpn_head_2(x_2)

            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                              self.train_cfg.rpn)

            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                            gt_bboxes_ignore=gt_bboxes_ignore,
                                            iteration=self.iteration)
            rpn_losses_2 = self.rpn_head_2.loss(
                *rpn_loss_inputs_2,
                gt_bboxes_ignore=gt_bboxes_ignore,
                iteration=self.iteration,
                img_meta_2=img_meta_2)

            losses.update(rpn_losses)
            losses.update(rpn_losses_2)

            proposal_inputs = rpn_outs + (img_meta,
                                          self.train_cfg.rpn_proposal)
            proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                              self.train_cfg.rpn_proposal)

            proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
            proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
                *proposal_inputs_2, img_meta_2=img_meta_2)
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []

            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

            bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                           context=self)
            num_imgs_2 = imgs_2.size(0)
            gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
            sampling_results_2 = []

            for i in range(num_imgs_2):
                gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
                gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
                gt_labels_cur_pat_2 = gt_labels_2[i]

                assign_result_2 = bbox_assigner_2.assign(
                    proposal_list_2[i], gt_bboxes_cur_pat_2,
                    gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
                sampling_result_2 = bbox_sampler_2.sample(
                    assign_result_2,
                    proposal_list_2[i],
                    gt_bboxes_cur_pat_2,
                    gt_labels_cur_pat_2,
                    feats=[lvl_feat[i][None] for lvl_feat in x_2])
                sampling_results_2.append(sampling_result_2)

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi3D([res.bboxes for res in sampling_results])
            rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])

            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            bbox_feats_2 = self.bbox_roi_extractor(
                x_2[:self.bbox_roi_extractor.num_inputs], rois_2)

            cls_score, bbox_pred = self.bbox_head(bbox_feats)
            cls_score_2, bbox_pred_2 = self.bbox_head(bbox_feats_2)

            cls_score = torch.cat((cls_score, cls_score_2), 0)
            bbox_pred = torch.cat((bbox_pred, bbox_pred_2), 0)
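            # A single shared bbox head serves both scales here; predictions
            # (and, below, their targets) are concatenated along the sample
            # dimension so one loss covers both inputs.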

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            bbox_targets_2 = self.bbox_head.get_target(sampling_results_2,
                                                       gt_bboxes_2,
                                                       gt_labels_2,
                                                       self.train_cfg.rcnn)
            bbox_targets_combined = []
            for bbox_target, bbox_target_2 in zip(bbox_targets,
                                                  bbox_targets_2):
                bbox_targets_combined.append(
                    torch.cat((bbox_target, bbox_target_2), 0))
            bbox_targets_combined = tuple(bbox_targets_combined)

            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets_combined)

            losses.update(loss_bbox)

        if self.refinement_head:
            # prepare upscaled data for refinement head
            upscaled_factor = (img_meta_2[0]['ori_shape'][0] /
                               img_meta[0]['ori_shape'][0])
            # convert parameterized adjustments to actual bounding box coordinates
            pred_bboxes_2 = self.bbox_head.convert_adjustments_to_bboxes(
                rois_2, bbox_pred_2, img_meta_2[0]['img_shape']
            )[:, 6:].cpu().detach().numpy() / upscaled_factor

            pred_cls_score_2 = cls_score_2[:, 1, None].cpu().detach().numpy()
            pred_bboxes_2 = np.concatenate((pred_bboxes_2, pred_cls_score_2),
                                           axis=1)
            pred_bboxes_2 = [torch.from_numpy(pred_bboxes_2).cuda()]
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results_refinement = []
            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(pred_bboxes_2[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    pred_bboxes_2[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results_refinement.append(sampling_result)
            rois_refinement = bbox2roi3D(
                [res.bboxes for res in sampling_results_refinement])
            bbox_feats_refinement = self.bbox_roi_extractor_refinement(
                x[:self.bbox_roi_extractor_refinement.num_inputs],
                rois_refinement)
            # training refinement head
            refined_bbox_pred = self.refinement_head(bbox_feats_refinement)
            bbox_targets_refinement = self.refinement_head.get_target(
                sampling_results_refinement, gt_bboxes, gt_labels,
                self.train_cfg.rcnn)
            loss_refinement = self.refinement_head.loss(
                refined_bbox_pred, *bbox_targets_refinement)
            losses.update(loss_refinement)

        # mask head forward and loss
        if self.with_mask:
            pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(sampling_results,
                                                     gt_masks,
                                                     self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)

            losses.update(loss_mask)

        if self.refinement_mask_head:
            pos_rois_refined = bbox2roi3D(
                [res.pos_bboxes for res in sampling_results_refinement])
            mask_feats_refined = self.refinement_mask_roi_extractor(
                x[:self.refinement_mask_roi_extractor.num_inputs],
                pos_rois_refined)
            mask_pred_refined = self.refinement_mask_head(mask_feats_refined)
            mask_targets_refined = self.refinement_mask_head.get_target(
                sampling_results_refinement, gt_masks, self.train_cfg.rcnn)
            pos_labels_refined = torch.cat(
                [res.pos_gt_labels for res in sampling_results_refinement])
            loss_refinement_mask = self.refinement_mask_head.loss(
                mask_pred_refined,
                mask_targets_refined,
                pos_labels_refined,
                img_meta_refinement=True)
            losses.update(loss_refinement_mask)

        self.iteration += 1
        return losses
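
Note that pred_bboxes_2 passes through .cpu().detach().numpy() before re-entering the pipeline, so the refinement losses cannot backpropagate into the first-stage heads; a minimal standalone demonstration:

import torch

preds = torch.randn(3, 6, requires_grad=True)
boxes = torch.from_numpy(preds.detach().cpu().numpy())  # graph is cut here
assert not boxes.requires_grad
# Any loss built from `boxes` trains only the refinement branch; the
# first stage sees its own detections as fixed proposals.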
Example #7
    def forward_train(self,
                      imgs,
                      img_meta,
                      imgs_2,
                      img_meta_2,
                      gt_bboxes,
                      gt_bboxes_2,
                      gt_labels,
                      gt_labels_2,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_masks_2=None,
                      pp=None,
                      pp_2=None,
                      proposals=None):
        # self.print_iterations()
        assert imgs.shape[1] == 3 and imgs_2.shape[1] == 3  # make sure channel size is 3
        # Default FPN
        x = self.extract_feat(imgs)
        x_2 = self.extract_feat(imgs_2)

        # NOTE: the following fusion variants gave worse performance.
        # Fused FPN for 2 scales, v1:
        # x, x_2 = self.fuse_feature_maps(x, x_2)
        # Fused FPN for 2 scales, v2:
        # x, x_2 = self.extract_feat_fusion(imgs, imgs_2)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_outs_2 = self.rpn_head_2(x_2)

            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                              self.train_cfg.rpn)

            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                            gt_bboxes_ignore=gt_bboxes_ignore,
                                            iteration=self.iteration)
            rpn_losses_2 = self.rpn_head_2.loss(
                *rpn_loss_inputs_2,
                gt_bboxes_ignore=gt_bboxes_ignore,
                iteration=self.iteration,
                img_meta_2=img_meta_2)

            losses.update(rpn_losses)
            losses.update(rpn_losses_2)

            proposal_inputs = rpn_outs + (img_meta,
                                          self.train_cfg.rpn_proposal)
            proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                              self.train_cfg.rpn_proposal)

            proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
            proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
                *proposal_inputs_2, img_meta_2=img_meta_2)

            if pp is not None and pp_2 is not None:
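                # pp / pp_2 appear to be precomputed proposals; they are
                # appended to the RPN proposals (single image per batch).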
                proposal_list = torch.cat((proposal_list[0], pp[0]), 0)
                proposal_list = [proposal_list]
                proposal_list_2 = torch.cat((proposal_list_2[0], pp_2[0]), 0)
                proposal_list_2 = [proposal_list_2]

            # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
            # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
            # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
            # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
            # breakpoint()
            # self.visualize_gt_bboxes(imgs_2, gt_bboxes_2, img_meta_2) #debug only
            # breakpoint()
            # self.visualize_gt_bboxes_masks(imgs_2, gt_bboxes_2, img_meta_2, gt_masks) # debug only
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                         context=self)
            num_imgs = imgs.size(0)
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []

            for i in range(num_imgs):
                gt_bboxes_cur_pat = gt_bboxes[i]
                gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
                gt_labels_cur_pat = gt_labels[i]

                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes_cur_pat,
                                                     gt_bboxes_ignore_cur_pat,
                                                     gt_labels_cur_pat)
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes_cur_pat,
                    gt_labels_cur_pat,
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

            bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                           context=self)
            num_imgs_2 = imgs_2.size(0)
            gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
            sampling_results_2 = []

            for i in range(num_imgs_2):
                gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
                gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
                gt_labels_cur_pat_2 = gt_labels_2[i]

                assign_result_2 = bbox_assigner_2.assign(
                    proposal_list_2[i], gt_bboxes_cur_pat_2,
                    gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
                sampling_result_2 = bbox_sampler_2.sample(
                    assign_result_2,
                    proposal_list_2[i],
                    gt_bboxes_cur_pat_2,
                    gt_labels_cur_pat_2,
                    feats=[lvl_feat[i][None] for lvl_feat in x_2])
                sampling_results_2.append(sampling_result_2)

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi3D([res.bboxes for res in sampling_results])
            rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])

            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            bbox_feats_2 = self.bbox_roi_extractor_2(
                x_2[:self.bbox_roi_extractor_2.num_inputs], rois_2)

            cls_score, bbox_pred = self.bbox_head(bbox_feats)
            cls_score_2, bbox_pred_2 = self.bbox_head_2(bbox_feats_2)

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            bbox_targets_2 = self.bbox_head_2.get_target(
                sampling_results_2, gt_bboxes_2, gt_labels_2,
                self.train_cfg.rcnn)

            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets)
            loss_bbox_2 = self.bbox_head_2.loss(cls_score_2,
                                                bbox_pred_2,
                                                *bbox_targets_2,
                                                img_meta_2=img_meta_2)

            losses.update(loss_bbox)
            losses.update(loss_bbox_2)

        # mask head forward and loss
        if self.with_mask:
            # # implementation #1
            # # only utilize one mask head for higher resolution feature maps
            # pos_rois = bbox2roi3D(
            #     [res.pos_bboxes for res in sampling_results_2])
            # mask_feats = self.mask_roi_extractor(
            #     x_2[:self.mask_roi_extractor.num_inputs], pos_rois)
            # mask_pred = self.mask_head(mask_feats)
            # mask_targets = self.mask_head.get_target(
            #     sampling_results_2, gt_masks, self.train_cfg.rcnn)
            # pos_labels = torch.cat(
            #     [res.pos_gt_labels for res in sampling_results_2])
            # loss_mask = self.mask_head.loss(mask_pred, mask_targets,
            #                                 pos_labels)
            # losses.update(loss_mask)

            # implementation #2
            # lower resolution mask head
            pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(sampling_results,
                                                     gt_masks,
                                                     self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)
            losses.update(loss_mask)

            # higher resolution mask head
            pos_rois = bbox2roi3D(
                [res.pos_bboxes for res in sampling_results_2])
            mask_feats = self.mask_roi_extractor_2(
                x_2[:self.mask_roi_extractor_2.num_inputs], pos_rois)
            mask_pred = self.mask_head_2(mask_feats)
            mask_targets = self.mask_head_2.get_target(sampling_results_2,
                                                       gt_masks_2,
                                                       self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results_2])
            loss_mask_2 = self.mask_head_2.loss(mask_pred,
                                                mask_targets,
                                                pos_labels,
                                                img_meta_2=img_meta_2)
            losses.update(loss_mask_2)

        # self.save_losses_plt(losses) #debug only...
        self.iteration += 1
        return losses
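
All variants above call extract_feat once per input scale; a minimal sketch of the usual backbone-plus-neck composition this assumes (standard two-stage-detector layout, illustrative only):

    def extract_feat(self, imgs):
        """Backbone features, optionally refined by an FPN neck."""
        feats = self.backbone(imgs)
        if self.with_neck:
            feats = self.neck(feats)
        return feats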