def simple_test_mask_refinement_v2(self, x, img_meta, det_bboxes, det_labels, rescale=False):
    """Run the refinement mask head on already-detected boxes (single image).

    Returns per-class segmentation results produced by
    ``refinement_mask_head.get_seg_masks``, or one empty list per
    foreground class when there are no detections.
    """
    meta = img_meta[0]
    scale_factor = meta['scale_factor']
    # First two dims of the original shape, plus the third spatial dim of
    # the test-scale image shape rescaled back to original size.
    ori_shape = (*meta['ori_shape'][:2], int(meta['img_shape'][3] / scale_factor))

    if det_bboxes.shape[0] == 0:
        # No detections: one empty result list per foreground class.
        return [[] for _ in range(self.refinement_mask_head.num_classes - 1)]

    # In the two-heads implementation the boxes are deliberately NOT
    # rescaled before RoI extraction (the stock implementation multiplied
    # by scale_factor when `rescale` was set; that was removed on purpose).
    boxes = det_bboxes[:, :6]
    refine_rois = bbox2roi3D([boxes])
    extractor = self.refinement_mask_roi_extractor
    feats = extractor(x[:len(extractor.featmap_strides)], refine_rois)
    if self.with_shared_head:
        feats = self.shared_head(feats)
    pred = self.refinement_mask_head(feats)
    return self.refinement_mask_head.get_seg_masks(
        pred, boxes, det_labels, self.test_cfg.rcnn, ori_shape,
        scale_factor, rescale)
def simple_test_bbox_refinement_3(self, x, img_meta, proposals, rcnn_test_cfg, rescale=False):
    """Test only det bboxes without augmentation (scale-3 refinement head)."""
    rois = bbox2roi3D(proposals)
    extractor = self.bbox_roi_extractor_refinement_3
    roi_feats = extractor(x[:len(extractor.featmap_strides)], rois)
    if self.with_shared_head:
        roi_feats = self.shared_head(roi_feats)
    # Regression-only head: produces bbox adjustments, no class scores.
    # (A class+regression variant via self.refinement_head existed but
    # is not used here.)
    bbox_pred = self.refinement_head_3(roi_feats)
    meta = img_meta[0]
    return self.refinement_head_3.get_det_bboxes(
        rois,
        bbox_pred,
        meta['img_shape'],
        meta['scale_factor'],
        rescale=rescale,
        cfg=rcnn_test_cfg)
def simple_test_bboxes(self, x, img_meta, proposals, rcnn_test_cfg, rescale=False):
    """Test only det bboxes without augmentation.

    Returns ``(det_bboxes, det_labels, det_parcellations)`` from the
    bbox head, which additionally predicts a parcellation score.
    """
    rois = bbox2roi3D(proposals)
    roi_feats = self.bbox_roi_extractor(
        x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
    if self.with_shared_head:
        roi_feats = self.shared_head(roi_feats)
    cls_score, bbox_pred, parcellation_score = self.bbox_head(roi_feats)
    meta = img_meta[0]
    det_bboxes, det_labels, det_parcellations = self.bbox_head.get_det_bboxes(
        rois,
        cls_score,
        bbox_pred,
        parcellation_score,
        meta['img_shape'],
        meta['scale_factor'],
        rescale=rescale,
        cfg=rcnn_test_cfg)
    return det_bboxes, det_labels, det_parcellations
def forward_train(self, imgs, img_meta, gt_bboxes, gt_labels, gt_bregions,
                  gt_bboxes_ignore=None, gt_masks=None, proposals=None):
    """Single-scale training forward pass (RPN -> bbox head -> mask head).

    This variant additionally consumes per-box brain-region labels
    (``gt_bregions``) and a parcellation branch on the bbox head.
    Returns a dict of named losses.
    """
    # self.print_iterations()
    assert imgs.shape[1] == 3  # make sure channel size is 3
    x = self.extract_feat(imgs)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(
            *rpn_loss_inputs,
            gt_bboxes_ignore=gt_bboxes_ignore,
            iteration=self.iteration)
        losses.update(rpn_losses)
        proposal_inputs = rpn_outs + (img_meta, self.train_cfg.rpn_proposal)
        proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
        # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
        # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
        # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
        # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
        # self.visualize_gt_bboxes_masks(imgs, gt_bboxes, img_meta, gt_masks) # debug only
    else:
        proposal_list = proposals
    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(
            self.train_cfg.rcnn.sampler, context=self)
        num_imgs = imgs.size(0)
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            # Per-patient (per-batch-item) ground truth; the assigner and
            # sampler also receive the brain-region labels.
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            gt_bregions_cur_pat = gt_bregions[i]
            assign_result = bbox_assigner.assign(
                proposal_list[i], gt_bboxes_cur_pat,
                gt_bboxes_ignore_cur_pat, gt_labels_cur_pat,
                gt_bregions_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                gt_bregions_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
    # bbox head forward and loss
    if self.with_bbox:
        rois = bbox2roi3D([res.bboxes for res in sampling_results])
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred, parcellation_score = self.bbox_head(bbox_feats)
        # NOTE(review): parcellation_score (a prediction) is passed into
        # get_target alongside the ground truth — presumably the head uses
        # it to build parcellation targets; confirm against bbox_head API.
        bbox_targets = self.bbox_head.get_target(
            sampling_results, gt_bboxes, gt_labels, parcellation_score,
            self.train_cfg.rcnn)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                        parcellation_score, *bbox_targets)
        losses.update(loss_bbox)
    # mask head forward and loss
    if self.with_mask:
        if not self.share_roi_extractor:
            # Dedicated mask RoI extractor over positive boxes only.
            pos_rois = bbox2roi3D(
                [res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            # Shared extractor: select the positive rows out of bbox_feats
            # with a boolean mask (ones for pos, zeros for neg samples).
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(
                        res.pos_bboxes.shape[0],
                        device=device,
                        dtype=torch.uint8))
                pos_inds.append(
                    torch.zeros(
                        res.neg_bboxes.shape[0],
                        device=device,
                        dtype=torch.uint8))
            pos_inds = torch.cat(pos_inds)
            mask_feats = bbox_feats[pos_inds]
        mask_pred = self.mask_head(mask_feats)
        mask_targets = self.mask_head.get_target(
            sampling_results, gt_masks, self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                        pos_labels)
        losses.update(loss_mask)
    # self.save_losses_plt(losses) #debug only...
    self.iteration += 1
    return losses
def forward_train(self, imgs, img_meta, imgs_2, img_meta_2, imgs_3,
                  img_meta_3, gt_bboxes, gt_bboxes_2, gt_bboxes_3, gt_labels,
                  gt_labels_2, gt_labels_3, gt_bboxes_ignore=None,
                  gt_masks=None, gt_masks_2=None, gt_masks_3=None, pp=None,
                  pp_2=None, proposals=None):
    """Three-scale training forward pass.

    Runs independent RPN + bbox heads on three resolutions of the same
    volume (``imgs``, ``imgs_2``, ``imgs_3``), then trains refinement
    bbox/mask heads on the scale-1 features using boxes predicted at the
    higher scales mapped back to scale 1. Returns a dict of named losses.
    ``pp`` / ``pp_2`` are accepted but unused in this variant.
    """
    # self.print_iterations()
    assert imgs.shape[1] == 3 and imgs_2.shape[
        1] == 3 and imgs_3.shape[1] == 3  # make sure channel size is 3
    # Default FPN
    x = self.extract_feat(imgs)
    x_2 = self.extract_feat(imgs_2)
    x_3 = self.extract_feat(imgs_3)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_outs_2 = self.rpn_head_2(x_2)
        rpn_outs_3 = self.rpn_head_3(x_3)
        # NOTE(review): the scale-2/3 loss inputs pack the scale-1
        # img_meta; the matching metas travel separately via the
        # img_meta_2 / img_meta_3 kwargs below — confirm intended.
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)
        rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                          self.train_cfg.rpn)
        rpn_loss_inputs_3 = rpn_outs_3 + (gt_bboxes_3, img_meta,
                                          self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                        gt_bboxes_ignore=gt_bboxes_ignore,
                                        iteration=self.iteration)
        rpn_losses_2 = self.rpn_head_2.loss(
            *rpn_loss_inputs_2,
            gt_bboxes_ignore=gt_bboxes_ignore,
            iteration=self.iteration,
            img_meta_2=img_meta_2)
        rpn_losses_3 = self.rpn_head_3.loss(
            *rpn_loss_inputs_3,
            gt_bboxes_ignore=gt_bboxes_ignore,
            iteration=self.iteration,
            img_meta_3=img_meta_3)
        losses.update(rpn_losses)
        losses.update(rpn_losses_2)
        losses.update(rpn_losses_3)
        proposal_inputs = rpn_outs + (img_meta, self.train_cfg.rpn_proposal)
        proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                          self.train_cfg.rpn_proposal)
        proposal_inputs_3 = rpn_outs_3 + (img_meta,
                                          self.train_cfg.rpn_proposal)
        proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
        proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
            *proposal_inputs_2, img_meta_2=img_meta_2)
        proposal_list_3, anchors_3 = self.rpn_head_3.get_bboxes(
            *proposal_inputs_3, img_meta_3=img_meta_3)
        # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
        # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
        # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
        # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
        # self.visualize_gt_bboxes(imgs_2, gt_bboxes_2, img_meta_2) #debug only
        # self.visualize_gt_bboxes(imgs_3, gt_bboxes_3, img_meta_3) #debug only
        # self.visualize_gt_bboxes_masks(imgs_2, gt_bboxes_2, img_meta_2, gt_masks) # debug only
    else:
        proposal_list = proposals
    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        # ---- scale 1 ----
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        # NOTE(review): unconditionally overwrites any caller-provided
        # gt_bboxes_ignore (unlike the single-scale variant, which only
        # fills in None) — confirm intended.
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
        # ---- scale 2 ----
        bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                       context=self)
        num_imgs_2 = imgs_2.size(0)
        gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
        sampling_results_2 = []
        for i in range(num_imgs_2):
            gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
            gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
            gt_labels_cur_pat_2 = gt_labels_2[i]
            assign_result_2 = bbox_assigner_2.assign(
                proposal_list_2[i], gt_bboxes_cur_pat_2,
                gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
            sampling_result_2 = bbox_sampler_2.sample(
                assign_result_2,
                proposal_list_2[i],
                gt_bboxes_cur_pat_2,
                gt_labels_cur_pat_2,
                feats=[lvl_feat[i][None] for lvl_feat in x_2])
            sampling_results_2.append(sampling_result_2)
        # ---- scale 3 ----
        bbox_assigner_3 = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler_3 = build_sampler(self.train_cfg.rcnn.sampler,
                                       context=self)
        num_imgs_3 = imgs_3.size(0)
        gt_bboxes_ignore_3 = [None for _ in range(num_imgs_3)]
        sampling_results_3 = []
        for i in range(num_imgs_3):
            gt_bboxes_cur_pat_3 = gt_bboxes_3[i]
            gt_bboxes_ignore_cur_pat_3 = gt_bboxes_ignore_3[i]
            gt_labels_cur_pat_3 = gt_labels_3[i]
            assign_result_3 = bbox_assigner_3.assign(
                proposal_list_3[i], gt_bboxes_cur_pat_3,
                gt_bboxes_ignore_cur_pat_3, gt_labels_cur_pat_3)
            sampling_result_3 = bbox_sampler_3.sample(
                assign_result_3,
                proposal_list_3[i],
                gt_bboxes_cur_pat_3,
                gt_labels_cur_pat_3,
                feats=[lvl_feat[i][None] for lvl_feat in x_3])
            sampling_results_3.append(sampling_result_3)
    # bbox head forward and loss
    if self.with_bbox:
        rois = bbox2roi3D([res.bboxes for res in sampling_results])
        rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])
        rois_3 = bbox2roi3D([res.bboxes for res in sampling_results_3])
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_feats_2 = self.bbox_roi_extractor_2(
            x_2[:self.bbox_roi_extractor_2.num_inputs], rois_2)
        bbox_feats_3 = self.bbox_roi_extractor_3(
            x_3[:self.bbox_roi_extractor_3.num_inputs], rois_3)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)
        cls_score_2, bbox_pred_2 = self.bbox_head_2(bbox_feats_2)
        cls_score_3, bbox_pred_3 = self.bbox_head_3(bbox_feats_3)
        bbox_targets = self.bbox_head.get_target(sampling_results,
                                                 gt_bboxes, gt_labels,
                                                 self.train_cfg.rcnn)
        bbox_targets_2 = self.bbox_head_2.get_target(
            sampling_results_2, gt_bboxes_2, gt_labels_2,
            self.train_cfg.rcnn)
        bbox_targets_3 = self.bbox_head_3.get_target(
            sampling_results_3, gt_bboxes_3, gt_labels_3,
            self.train_cfg.rcnn)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
        loss_bbox_2 = self.bbox_head_2.loss(cls_score_2, bbox_pred_2,
                                            *bbox_targets_2,
                                            img_meta_2=img_meta_2)
        loss_bbox_3 = self.bbox_head_3.loss(cls_score_3, bbox_pred_3,
                                            *bbox_targets_3,
                                            img_meta_3=img_meta_3)
        losses.update(loss_bbox)
        losses.update(loss_bbox_2)
        losses.update(loss_bbox_3)
    if self.refinement_head_2:
        # prepare upscaled data for refinement head:
        # scale-2 predictions are decoded to absolute boxes, then divided
        # by the scale-2/scale-1 size ratio so they live in scale-1 space.
        upscaled_factor_2 = img_meta_2[0]['ori_shape'][0] / img_meta[0][
            'ori_shape'][0]
        # convert parameterized adjustments to actual bounding boxes coordinates
        pred_bboxes_2 = self.bbox_head_2.convert_adjustments_to_bboxes(
            rois_2, bbox_pred_2, img_meta_2[0]['img_shape'])[:, 6:].cpu(
            ).detach().numpy() / upscaled_factor_2
        # Foreground-class probability column, appended as the box score.
        pred_cls_score_2 = cls_score_2[:, 1, None].cpu().detach().numpy()
        pred_bboxes_2 = np.concatenate((pred_bboxes_2, pred_cls_score_2),
                                       axis=1)
        # NOTE(review): .cuda() hard-codes the default CUDA device.
        pred_bboxes_2 = [torch.from_numpy(pred_bboxes_2).cuda()]
        # Re-assign and re-sample against scale-1 GT, using scale-1 feats.
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results_refinement_2 = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(pred_bboxes_2[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                pred_bboxes_2[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results_refinement_2.append(sampling_result)
        rois_refinement_2 = bbox2roi3D(
            [res.bboxes for res in sampling_results_refinement_2])
        bbox_feats_refinement_2 = self.bbox_roi_extractor_refinement_2(
            x[:self.bbox_roi_extractor_refinement_2.num_inputs],
            rois_refinement_2)
        # training refinement head
        refined_bbox_pred_2 = self.refinement_head_2(
            bbox_feats_refinement_2)
        bbox_targets_refinement_2 = self.refinement_head_2.get_target(
            sampling_results_refinement_2, gt_bboxes, gt_labels,
            self.train_cfg.rcnn)
        loss_refinement_2 = self.refinement_head_2.loss(
            refined_bbox_pred_2, *bbox_targets_refinement_2,
            img_meta_2=img_meta_2)
        losses.update(loss_refinement_2)
    if self.refinement_head_3:
        # prepare upscaled data for refinement head (scale 3, same scheme
        # as scale 2 above).
        upscaled_factor_3 = img_meta_3[0]['ori_shape'][0] / img_meta[0][
            'ori_shape'][0]
        # convert parameterized adjustments to actual bounding boxes coordinates
        pred_bboxes_3 = self.bbox_head_3.convert_adjustments_to_bboxes(
            rois_3, bbox_pred_3, img_meta_3[0]['img_shape'])[:, 6:].cpu(
            ).detach().numpy() / upscaled_factor_3
        pred_cls_score_3 = cls_score_3[:, 1, None].cpu().detach().numpy()
        pred_bboxes_3 = np.concatenate((pred_bboxes_3, pred_cls_score_3),
                                       axis=1)
        pred_bboxes_3 = [torch.from_numpy(pred_bboxes_3).cuda()]
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results_refinement_3 = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(pred_bboxes_3[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                pred_bboxes_3[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results_refinement_3.append(sampling_result)
        rois_refinement_3 = bbox2roi3D(
            [res.bboxes for res in sampling_results_refinement_3])
        bbox_feats_refinement_3 = self.bbox_roi_extractor_refinement_3(
            x[:self.bbox_roi_extractor_refinement_3.num_inputs],
            rois_refinement_3)
        # training refinement head
        refined_bbox_pred_3 = self.refinement_head_3(
            bbox_feats_refinement_3)
        bbox_targets_refinement_3 = self.refinement_head_3.get_target(
            sampling_results_refinement_3, gt_bboxes, gt_labels,
            self.train_cfg.rcnn)
        loss_refinement_3 = self.refinement_head_3.loss(
            refined_bbox_pred_3, *bbox_targets_refinement_3,
            img_meta_3=img_meta_3)
        losses.update(loss_refinement_3)
    # mask head forward and loss
    # NOTE(review): the two refinement mask branches below require
    # sampling_results_refinement_2/_3, which only exist when the
    # corresponding refinement heads are enabled — with_mask without
    # those heads would raise NameError; confirm configs always pair them.
    if self.with_mask:
        # lower resolution mask head (scale-1 samples and features)
        pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
        mask_feats = self.mask_roi_extractor(
            x[:self.mask_roi_extractor.num_inputs], pos_rois)
        mask_pred = self.mask_head(mask_feats)
        mask_targets = self.mask_head.get_target(sampling_results, gt_masks,
                                                 self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_pred, mask_targets, pos_labels)
        losses.update(loss_mask)
        # refinement head mask head 2 — uses scale-1 features (x) and
        # scale-1 gt_masks, since the refined boxes were mapped to scale 1.
        pos_rois = bbox2roi3D(
            [res.pos_bboxes for res in sampling_results_refinement_2])
        mask_feats = self.mask_roi_extractor_refinement_2(
            x[:self.mask_roi_extractor_refinement_2.num_inputs], pos_rois)
        mask_pred = self.mask_head_refinement_2(mask_feats)
        mask_targets = self.mask_head_refinement_2.get_target(
            sampling_results_refinement_2, gt_masks, self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results_refinement_2])
        loss_mask_refinement_2 = self.mask_head_refinement_2.loss(
            mask_pred, mask_targets, pos_labels, img_meta_2=img_meta_2)
        losses.update(loss_mask_refinement_2)
        # refinement head mask head 3
        pos_rois = bbox2roi3D(
            [res.pos_bboxes for res in sampling_results_refinement_3])
        mask_feats = self.mask_roi_extractor_refinement_3(
            x[:self.mask_roi_extractor_refinement_3.num_inputs], pos_rois)
        mask_pred = self.mask_head_refinement_3(mask_feats)
        mask_targets = self.mask_head_refinement_3.get_target(
            sampling_results_refinement_3, gt_masks, self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results_refinement_3])
        loss_mask_refinement_3 = self.mask_head_refinement_3.loss(
            mask_pred, mask_targets, pos_labels, img_meta_3=img_meta_3)
        losses.update(loss_mask_refinement_3)
    # self.save_losses_plt(losses) #debug only...
    self.iteration += 1
    return losses
def forward_train(self, imgs, img_meta, imgs_2, img_meta_2, gt_bboxes,
                  gt_bboxes_2, gt_labels, gt_labels_2, gt_bboxes_ignore=None,
                  gt_masks=None, gt_masks_2=None, proposals=None):
    """Two-scale training forward pass with SHARED bbox head.

    Both scales are scored by the same ``bbox_head`` (scores/targets are
    concatenated into one loss), then an optional refinement bbox head and
    refinement mask head are trained on scale-1 features using scale-2
    predictions mapped back to scale-1 space. Returns a dict of losses.
    """
    assert imgs.shape[
        1] == 3 and imgs_2.shape[1] == 3  # make sure channel size is 3
    x = self.extract_feat(imgs)
    x_2 = self.extract_feat(imgs_2)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_outs_2 = self.rpn_head_2(x_2)
        # NOTE(review): scale-2 loss inputs pack the scale-1 img_meta; the
        # scale-2 meta travels via the img_meta_2 kwarg — confirm intended.
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)
        rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                          self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                        gt_bboxes_ignore=gt_bboxes_ignore,
                                        iteration=self.iteration)
        rpn_losses_2 = self.rpn_head_2.loss(
            *rpn_loss_inputs_2,
            gt_bboxes_ignore=gt_bboxes_ignore,
            iteration=self.iteration,
            img_meta_2=img_meta_2)
        losses.update(rpn_losses)
        losses.update(rpn_losses_2)
        proposal_inputs = rpn_outs + (img_meta, self.train_cfg.rpn_proposal)
        proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                          self.train_cfg.rpn_proposal)
        proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
        proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
            *proposal_inputs_2, img_meta_2=img_meta_2)
    else:
        proposal_list = proposals
    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        # ---- scale 1 ----
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        # NOTE(review): unconditionally discards any caller-provided
        # gt_bboxes_ignore — confirm intended.
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
        # ---- scale 2 ----
        bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                       context=self)
        num_imgs_2 = imgs_2.size(0)
        gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
        sampling_results_2 = []
        for i in range(num_imgs_2):
            gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
            gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
            gt_labels_cur_pat_2 = gt_labels_2[i]
            assign_result_2 = bbox_assigner_2.assign(
                proposal_list_2[i], gt_bboxes_cur_pat_2,
                gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
            sampling_result_2 = bbox_sampler_2.sample(
                assign_result_2,
                proposal_list_2[i],
                gt_bboxes_cur_pat_2,
                gt_labels_cur_pat_2,
                feats=[lvl_feat[i][None] for lvl_feat in x_2])
            sampling_results_2.append(sampling_result_2)
    # bbox head forward and loss
    if self.with_bbox:
        rois = bbox2roi3D([res.bboxes for res in sampling_results])
        rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])
        # TODO: a more flexible way to decide which feature maps to use
        # Shared extractor and head across both scales in this variant.
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_feats_2 = self.bbox_roi_extractor(
            x_2[:self.bbox_roi_extractor.num_inputs], rois_2)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)
        cls_score_2, bbox_pred_2 = self.bbox_head(bbox_feats_2)
        # Concatenate predictions and targets of both scales into a single
        # loss computation for the shared head.
        cls_score = torch.cat((cls_score, cls_score_2), 0)
        bbox_pred = torch.cat((bbox_pred, bbox_pred_2), 0)
        bbox_targets = self.bbox_head.get_target(sampling_results,
                                                 gt_bboxes, gt_labels,
                                                 self.train_cfg.rcnn)
        bbox_targets_2 = self.bbox_head.get_target(sampling_results_2,
                                                   gt_bboxes_2, gt_labels_2,
                                                   self.train_cfg.rcnn)
        bbox_targets_combined = []
        for bbox_target, bbox_target_2 in zip(bbox_targets,
                                              bbox_targets_2):
            bbox_targets_combined.append(
                torch.cat((bbox_target, bbox_target_2), 0))
        bbox_targets_combined = tuple(bbox_targets_combined)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                        *bbox_targets_combined)
        losses.update(loss_bbox)
    if self.refinement_head:
        # prepare upscaled data for refinement head: decode scale-2
        # predictions to absolute boxes and divide by the scale ratio so
        # they live in scale-1 space.
        upscaled_factor = img_meta_2[0]['ori_shape'][0] / img_meta[0][
            'ori_shape'][0]
        # convert parameterized adjustments to actual bounding boxes coordinates
        pred_bboxes_2 = self.bbox_head.convert_adjustments_to_bboxes(
            rois_2, bbox_pred_2, img_meta_2[0]
            ['img_shape'])[:, 6:].cpu().detach().numpy() / upscaled_factor
        # Foreground probability column appended as the proposal score.
        pred_cls_score_2 = cls_score_2[:, 1, None].cpu().detach().numpy()
        pred_bboxes_2 = np.concatenate((pred_bboxes_2, pred_cls_score_2),
                                       axis=1)
        # NOTE(review): .cuda() hard-codes the default CUDA device.
        pred_bboxes_2 = [torch.from_numpy(pred_bboxes_2).cuda()]
        # Re-assign/re-sample the mapped boxes against scale-1 GT.
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results_refinement = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(pred_bboxes_2[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                pred_bboxes_2[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results_refinement.append(sampling_result)
        rois_refinement = bbox2roi3D(
            [res.bboxes for res in sampling_results_refinement])
        bbox_feats_refinement = self.bbox_roi_extractor_refinement(
            x[:self.bbox_roi_extractor_refinement.num_inputs],
            rois_refinement)
        # training refinement head
        refined_bbox_pred = self.refinement_head(bbox_feats_refinement)
        bbox_targets_refinement = self.refinement_head.get_target(
            sampling_results_refinement, gt_bboxes, gt_labels,
            self.train_cfg.rcnn)
        loss_refinement = self.refinement_head.loss(
            refined_bbox_pred, *bbox_targets_refinement)
        losses.update(loss_refinement)
    # mask head forward and loss
    if self.with_mask:
        pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
        mask_feats = self.mask_roi_extractor(
            x[:self.mask_roi_extractor.num_inputs], pos_rois)
        mask_pred = self.mask_head(mask_feats)
        mask_targets = self.mask_head.get_target(sampling_results, gt_masks,
                                                 self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_pred, mask_targets, pos_labels)
        losses.update(loss_mask)
    # NOTE(review): this branch requires sampling_results_refinement,
    # which only exists when self.refinement_head is enabled — confirm
    # configs always enable both heads together.
    if self.refinement_mask_head:
        pos_rois_refined = bbox2roi3D(
            [res.pos_bboxes for res in sampling_results_refinement])
        mask_feats_refined = self.refinement_mask_roi_extractor(
            x[:self.refinement_mask_roi_extractor.num_inputs],
            pos_rois_refined)
        mask_pred_refined = self.refinement_mask_head(mask_feats_refined)
        mask_targets_refined = self.refinement_mask_head.get_target(
            sampling_results_refinement, gt_masks, self.train_cfg.rcnn)
        pos_labels_refined = torch.cat(
            [res.pos_gt_labels for res in sampling_results_refinement])
        loss_refinement_mask = self.refinement_mask_head.loss(
            mask_pred_refined, mask_targets_refined, pos_labels_refined,
            img_meta_refinement=True)
        losses.update(loss_refinement_mask)
    self.iteration += 1
    return losses
def forward_train(self, imgs, img_meta, imgs_2, img_meta_2, gt_bboxes,
                  gt_bboxes_2, gt_labels, gt_labels_2, gt_bboxes_ignore=None,
                  gt_masks=None, gt_masks_2=None, pp=None, pp_2=None,
                  proposals=None):
    """Two-scale training forward pass with SEPARATE heads per scale.

    Each scale has its own RPN, RoI extractor, bbox head and mask head.
    ``pp`` / ``pp_2`` are optional pre-computed proposals concatenated onto
    the RPN proposals of the respective scale. Returns a dict of losses.
    """
    # self.print_iterations()
    assert imgs.shape[
        1] == 3 and imgs_2.shape[1] == 3  # make sure channel size is 3
    # Default FPN
    x = self.extract_feat(imgs)
    x_2 = self.extract_feat(imgs_2)
    ##### WORSE PERFORMANCE
    # Better FPN for 2 scales v1
    # x, x_2 = self.fuse_feature_maps(x, x_2)
    # Better FPN for 2 scales v2
    # x, x_2 = self.extract_feat_fusion(imgs, imgs_2)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_outs_2 = self.rpn_head_2(x_2)
        # NOTE(review): scale-2 loss inputs pack the scale-1 img_meta; the
        # scale-2 meta travels via the img_meta_2 kwarg — confirm intended.
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)
        rpn_loss_inputs_2 = rpn_outs_2 + (gt_bboxes_2, img_meta,
                                          self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(*rpn_loss_inputs,
                                        gt_bboxes_ignore=gt_bboxes_ignore,
                                        iteration=self.iteration)
        rpn_losses_2 = self.rpn_head_2.loss(
            *rpn_loss_inputs_2,
            gt_bboxes_ignore=gt_bboxes_ignore,
            iteration=self.iteration,
            img_meta_2=img_meta_2)
        losses.update(rpn_losses)
        losses.update(rpn_losses_2)
        proposal_inputs = rpn_outs + (img_meta, self.train_cfg.rpn_proposal)
        proposal_inputs_2 = rpn_outs_2 + (img_meta,
                                          self.train_cfg.rpn_proposal)
        proposal_list, anchors = self.rpn_head.get_bboxes(*proposal_inputs)
        proposal_list_2, anchors_2 = self.rpn_head_2.get_bboxes(
            *proposal_inputs_2, img_meta_2=img_meta_2)
        # Optionally append externally supplied proposals to each scale.
        if pp is not None and pp_2 is not None:
            proposal_list = torch.cat((proposal_list[0], pp[0]), 0)
            proposal_list = [proposal_list]
            proposal_list_2 = torch.cat((proposal_list_2[0], pp_2[0]), 0)
            proposal_list_2 = [proposal_list_2]
        # self.rpn_head.visualize_anchor_boxes(imgs, rpn_outs[0], img_meta, slice_num=45, shuffle=True) # debug only
        # self.visualize_proposals(imgs, proposal_list, gt_bboxes, img_meta, slice_num=None, isProposal=True) #debug only
        # self.visualize_proposals(imgs, anchors, gt_bboxes, img_meta, slice_num=None, isProposal=False) #debug only
        # self.visualize_gt_bboxes(imgs, gt_bboxes, img_meta) #debug only
        # self.visualize_gt_bboxes(imgs_2, gt_bboxes_2, img_meta_2) #debug only
        # self.visualize_gt_bboxes_masks(imgs_2, gt_bboxes_2, img_meta_2, gt_masks) # debug only
    else:
        proposal_list = proposals
    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        # ---- scale 1 ----
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = imgs.size(0)
        # NOTE(review): unconditionally discards any caller-provided
        # gt_bboxes_ignore — confirm intended.
        gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            gt_bboxes_cur_pat = gt_bboxes[i]
            gt_bboxes_ignore_cur_pat = gt_bboxes_ignore[i]
            gt_labels_cur_pat = gt_labels[i]
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes_cur_pat,
                                                 gt_bboxes_ignore_cur_pat,
                                                 gt_labels_cur_pat)
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes_cur_pat,
                gt_labels_cur_pat,
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
        # ---- scale 2 ----
        bbox_assigner_2 = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler_2 = build_sampler(self.train_cfg.rcnn.sampler,
                                       context=self)
        num_imgs_2 = imgs_2.size(0)
        gt_bboxes_ignore_2 = [None for _ in range(num_imgs_2)]
        sampling_results_2 = []
        for i in range(num_imgs_2):
            gt_bboxes_cur_pat_2 = gt_bboxes_2[i]
            gt_bboxes_ignore_cur_pat_2 = gt_bboxes_ignore_2[i]
            gt_labels_cur_pat_2 = gt_labels_2[i]
            assign_result_2 = bbox_assigner_2.assign(
                proposal_list_2[i], gt_bboxes_cur_pat_2,
                gt_bboxes_ignore_cur_pat_2, gt_labels_cur_pat_2)
            sampling_result_2 = bbox_sampler_2.sample(
                assign_result_2,
                proposal_list_2[i],
                gt_bboxes_cur_pat_2,
                gt_labels_cur_pat_2,
                feats=[lvl_feat[i][None] for lvl_feat in x_2])
            sampling_results_2.append(sampling_result_2)
    # bbox head forward and loss
    if self.with_bbox:
        rois = bbox2roi3D([res.bboxes for res in sampling_results])
        rois_2 = bbox2roi3D([res.bboxes for res in sampling_results_2])
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_feats_2 = self.bbox_roi_extractor_2(
            x_2[:self.bbox_roi_extractor_2.num_inputs], rois_2)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)
        cls_score_2, bbox_pred_2 = self.bbox_head_2(bbox_feats_2)
        bbox_targets = self.bbox_head.get_target(sampling_results,
                                                 gt_bboxes, gt_labels,
                                                 self.train_cfg.rcnn)
        bbox_targets_2 = self.bbox_head_2.get_target(
            sampling_results_2, gt_bboxes_2, gt_labels_2,
            self.train_cfg.rcnn)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
        loss_bbox_2 = self.bbox_head_2.loss(cls_score_2, bbox_pred_2,
                                            *bbox_targets_2,
                                            img_meta_2=img_meta_2)
        losses.update(loss_bbox)
        losses.update(loss_bbox_2)
    # mask head forward and loss
    if self.with_mask:
        # # implementation #1
        # # only utilize one mask head for higher resolution feature maps
        # pos_rois = bbox2roi3D(
        #     [res.pos_bboxes for res in sampling_results_2])
        # mask_feats = self.mask_roi_extractor(
        #     x_2[:self.mask_roi_extractor.num_inputs], pos_rois)
        # mask_pred = self.mask_head(mask_feats)
        # mask_targets = self.mask_head.get_target(
        #     sampling_results_2, gt_masks, self.train_cfg.rcnn)
        # pos_labels = torch.cat(
        #     [res.pos_gt_labels for res in sampling_results_2])
        # loss_mask = self.mask_head.loss(mask_pred, mask_targets,
        #                                 pos_labels)
        # losses.update(loss_mask)
        # implementation #2
        # lower resolution mask head
        pos_rois = bbox2roi3D([res.pos_bboxes for res in sampling_results])
        mask_feats = self.mask_roi_extractor(
            x[:self.mask_roi_extractor.num_inputs], pos_rois)
        mask_pred = self.mask_head(mask_feats)
        mask_targets = self.mask_head.get_target(sampling_results, gt_masks,
                                                 self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_pred, mask_targets, pos_labels)
        losses.update(loss_mask)
        # higher resolution mask head (scale-2 features and gt_masks_2)
        pos_rois = bbox2roi3D(
            [res.pos_bboxes for res in sampling_results_2])
        mask_feats = self.mask_roi_extractor_2(
            x_2[:self.mask_roi_extractor_2.num_inputs], pos_rois)
        mask_pred = self.mask_head_2(mask_feats)
        mask_targets = self.mask_head_2.get_target(sampling_results_2,
                                                   gt_masks_2,
                                                   self.train_cfg.rcnn)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results_2])
        loss_mask_2 = self.mask_head_2.loss(mask_pred, mask_targets,
                                            pos_labels,
                                            img_meta_2=img_meta_2)
        losses.update(loss_mask_2)
    # self.save_losses_plt(losses) #debug only...
    self.iteration += 1
    return losses