def _forward_train(self, image, im_info, gt_boxes):
    """Compute training losses for a detector with one cascade stage.

    Flow: backbone features -> RPN proposals -> target sampling (top_k=1)
    -> ``self.Cascade_0`` -> target re-sampling on the stage-0 proposals
    (pos_threshold=0.6, top_k=1) -> ``self.RCNN``.

    Returns:
        dict merging the RPN, cascade-stage-0 and RCNN loss dicts, applied in
        that order (later dicts win on key collisions).
    """
    # stride: 64,32,16,8,4, p6->p2
    features = self.backbone(image)
    rpn_proposals, rpn_losses = self.RPN(features, im_info, gt_boxes)

    # Stage 0: sample targets for the RPN proposals, then run the first head.
    stage0_rois, stage0_labels, stage0_targets = fpn_roi_target(
        rpn_proposals, im_info, gt_boxes, top_k=1)
    stage0_proposals, stage0_losses = self.Cascade_0(
        features, stage0_rois, stage0_labels, stage0_targets)

    # Re-assign targets to the stage-0 proposals for the final head.
    final_rois, final_labels, final_targets = cascade_roi_target(
        stage0_proposals, im_info, gt_boxes, pos_threshold=0.6, top_k=1)
    # NOTE: the original source carried a disabled second stage here
    # (self.Cascade_1 followed by cascade_roi_target with pos_threshold=0.7).
    rcnn_losses = self.RCNN(features, final_rois, final_labels, final_targets)

    loss_dict = {}
    for partial in (rpn_losses, stage0_losses, rcnn_losses):
        loss_dict.update(partial)
    return loss_dict
Example #2
0
 def _forward_train(self, image, im_info, gt_boxes):
     """Compute training losses: RPN losses plus a single RCNN head.

     Proposals from ``self.RPN`` are matched to ground truth with
     ``fpn_roi_target(..., top_k=1)`` and fed to ``self.RCNN``.

     Returns:
         dict of the RPN losses updated with the RCNN losses.
     """
     # stride: 64,32,16,8,4, p6->p2
     features = self.backbone(image)
     proposals, rpn_losses = self.RPN(features, im_info, gt_boxes)

     sampled_rois, sampled_labels, sampled_targets = fpn_roi_target(
         proposals, im_info, gt_boxes, top_k=1)
     rcnn_losses = self.RCNN(
         features, sampled_rois, sampled_labels, sampled_targets)

     losses = {}
     for partial in (rpn_losses, rcnn_losses):
         losses.update(partial)
     return losses
Example #3
0
 def _forward_train(self, image, im_info, gt_boxes):
     """Compute training losses: RPN losses plus RCNN losses with top_k=2 targets.

     Target assignment runs under ``torch.no_grad()`` so it is excluded from
     the autograd graph.

     Returns:
         dict of the RPN losses updated with the RCNN losses.
     """
     ## stride: 64,32,16,8,4
     features = self.FPN(image)
     # rpn_rois_inds is passed through to fpn_roi_target; presumably
     # per-RoI batch indices — TODO confirm.
     proposals, proposal_inds, rpn_losses = self.RPN(features, im_info, gt_boxes)

     with torch.no_grad():
         sampled_rois, sampled_labels, sampled_targets = fpn_roi_target(
             proposals, proposal_inds, im_info, gt_boxes, top_k=2)
     rcnn_losses = self.RCNN(
         features, sampled_rois, sampled_labels, sampled_targets)

     losses = {}
     for partial in (rpn_losses, rcnn_losses):
         losses.update(partial)
     return losses
Example #4
0
    def forward(self, fpn_fms, rois, gtboxes=None, im_info = None):
        """Run the RCNN box head on RoI-aligned features from the FPN levels.

        Args:
            fpn_fms: FPN feature maps. ``fpn_fms[0]`` is dropped and the rest are
                reversed before pooling. This assumes a list, since
                ``.reverse()`` is called on the slice.
            rois: proposal boxes to classify/regress.
            gtboxes: ground-truth boxes; used only in training for target assignment.
            im_info: image metadata; used only in training for target assignment.

        Returns:
            Training: ``(loss, pred_bboxes)``, where ``loss`` is the dict from
            ``compute_regular_loss`` (nheads < 2) or ``compute_gemini_loss_opr``,
            updated with the refinement loss when enabled.
            Inference: boxes from ``recover_pred_boxes``.
            In both modes, boxes are decoded from the refinement module's output
            when ``self.refinement`` is set, and from the raw head output otherwise.
        """
        # Slicing copies the list, so reversing does not mutate the caller's fpn_fms.
        rpn_fms = fpn_fms[1:]
        rpn_fms.reverse()
        # One stride per pooled level, paired positionally with rpn_fms.
        stride = [4, 8, 16, 32]
        rcnn_rois = rois
        if self.training:
            rcnn_rois, labels, bbox_targets = fpn_roi_target(rois, im_info, gtboxes,
                self.iou_thresh, top_k=self.nheads)
            # roi_pooler returns the rois/labels/targets alongside the pooled
            # features; downstream code uses its versions.
            pool5, rcnn_rois, labels, bbox_targets = roi_pooler(
                rpn_fms, rcnn_rois, stride, (7, 7), 'roi_align',
                labels, bbox_targets)
        else:
            pool5, rcnn_rois, _, _ = roi_pooler(rpn_fms, rcnn_rois, stride, (7, 7),
                'roi_align')

        pool5 = F.flatten(pool5, start_axis=1)
        fc1 = self.relu(self.fc1(pool5))
        fc2 = self.relu(self.fc2(fc1))
        prob = self.p(fc2)

        final_pred = self.refinement_module(prob, fc2) if self.refinement else None
        # Predictions that get decoded into boxes, identical in train and test.
        box_source = final_pred if self.refinement else prob

        if self.training:
            # With several heads the targets are flattened to one row of 4 coords.
            bbox_targets = bbox_targets.reshape(-1, 4) if self.nheads > 1 else bbox_targets
            labels = labels.reshape(-1)
            if self.nheads < 2:
                loss = self.compute_regular_loss(prob, bbox_targets, labels)
            else:
                loss = self.compute_gemini_loss_opr(prob, bbox_targets, labels)

            if self.refinement:
                auxi_loss = self.compute_gemini_loss_opr(final_pred, bbox_targets, labels)
                # NOTE(review): if auxi_loss uses the same keys as the main loss,
                # this update replaces those entries rather than adding to them
                # — confirm intent.
                loss.update(auxi_loss)

            # Decode from the refined predictions when refinement is on, matching
            # the inference branch.
            pred_bboxes = self.recover_pred_boxes(rcnn_rois, box_source, self.nheads)
            return loss, pred_bboxes

        # Inference: return the detection boxes and their scores.
        return self.recover_pred_boxes(rcnn_rois, box_source, self.nheads)
Example #5
0
    def _forward_train(self, image, im_info, gt_boxes,  isEval = False, extra = None):
        """Compute RPN + RCNN losses, optionally with flipped features for a flip-JSD loss.

        Args:
            image: input batch fed to ``self.FPN``.
            im_info: image metadata forwarded to the RPN and target assignment.
            gt_boxes: ground-truth boxes.
            isEval: forwarded to ``self.RPN`` and ``self.RCNN``.
            extra: optional dict of auxiliary inputs, also forwarded to
                ``self.RCNN``. ``extra['aug_data']`` is read when
                ``self.args.flip_JSD`` and ``self.args.flip_aug`` are set.
                Defaults to a fresh empty dict on every call.

        Returns:
            dict of the RPN losses updated with the RCNN losses.
        """
        # Fresh dict per call: a shared mutable default would leak state
        # between calls if any callee mutates it.
        extra = {} if extra is None else extra
        loss_dict = {}
        fpn_fms = self.FPN(image)

        ## stride: 64,32,16,8,4
        rpn_rois, rpn_rois_inds, loss_dict_rpn = \
            self.RPN(fpn_fms, im_info, gt_boxes , isEval)

        # Flipped-input features are built only when the flip-JSD loss is on.
        # The input is flipped along its last axis, and each output map is
        # flipped back along its last axis (presumably width) to realign with
        # fpn_fms. With flip_JSD_0g, this branch is computed without gradients.
        flip_fms = None
        if self.args.flip_JSD:
            flip_src = extra['aug_data'] if self.args.flip_aug else image
            if self.args.flip_JSD_0g:
                with torch.no_grad():
                    flip_fms = [x.flip(-1) for x in self.FPN(flip_src.flip(-1))]
            else:
                flip_fms = [x.flip(-1) for x in self.FPN(flip_src.flip(-1))]

        if self.args.recursive:
            # The recursive RCNN receives the raw proposals and ground truth,
            # presumably doing its own target assignment — TODO confirm.
            loss_dict_rcnn = self.RCNN(
                    fpn_fms,
                    rpn_rois, rpn_rois_inds, im_info,
                    gt_boxes = gt_boxes,
                    isEval = isEval,
                    flip_fms = flip_fms,
                    extra=extra)
        else:
            with torch.no_grad():
                rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
                    rpn_rois, rpn_rois_inds, im_info, gt_boxes, top_k=2)

            loss_dict_rcnn = self.RCNN(
                    fpn_fms, rcnn_rois, rcnn_labels, rcnn_bbox_targets,
                    isEval,
                    flip_fms = flip_fms,
                    extra=extra)

        loss_dict.update(loss_dict_rpn)
        loss_dict.update(loss_dict_rcnn)
        return loss_dict
Example #6
0
    def forward(self, fpn_fms,rpn_rois, rpn_rois_inds = None, im_info = None, gt_boxes=None, isEval = False, flip_fms = None, extra = None):
        """Run the two-head recursive RCNN.

        Args:
            fpn_fms: FPN feature maps passed to ``_recursive_forward``.
            rpn_rois: proposals. In training/eval they go through
                ``fpn_roi_target(top_k=2)``; at inference they are used as-is.
            rpn_rois_inds: forwarded to ``fpn_roi_target``.
            im_info: forwarded to ``fpn_roi_target``.
            gt_boxes: forwarded to ``fpn_roi_target``.
            isEval: compute losses even when not in training mode.
            flip_fms: flipped feature maps; required when ``self.args.flip_JSD`` is set.
            extra: unused here; accepted for call-site compatibility.

        Returns:
            Training/eval: dict with ``'loss_rcnn_emd'``, ``'loss_ref_emd'`` and,
            if ``self.args.flip_JSD``, ``'loss_flip_JSD'``.
            Inference: per-head box predictions, each with its class-1 score
            appended as an extra column, concatenated along dim 0. Returns
            ``None`` if there are no heads.
        """
        compute_loss = self.training or isEval
        if compute_loss:
            with torch.no_grad():
                rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
                    rpn_rois, rpn_rois_inds, im_info, gt_boxes, top_k=2)
        else:
            rcnn_rois = rpn_rois

        (pred_ref_pred_cls, pred_ref_pred_delta, pred_cls_unrefined,
         pred_delta_unrefined, _pool_features) = self._recursive_forward(
            fpn_fms, rcnn_rois, keep_pool_feature = True)

        if compute_loss:
            loss_dict = {}
            loss_dict['loss_rcnn_emd'] = self._min_order_emd_loss(
                pred_delta_unrefined, pred_cls_unrefined, rcnn_bbox_targets, rcnn_labels)
            loss_dict['loss_ref_emd'] = self._min_order_emd_loss(
                pred_ref_pred_delta, pred_ref_pred_cls, rcnn_bbox_targets, rcnn_labels)

            if self.args.flip_JSD:
                # Without keep_pool_feature, _recursive_forward returns four values.
                # Only the refined class logits (index 0) of the flipped branch
                # are needed. Indexing avoids unpacking into the wrong arity and
                # avoids clobbering the locals above.
                if self.args.flip_JSD_0g:
                    with torch.no_grad():
                        f_pred_ref_pred_cls = self._recursive_forward(flip_fms, rcnn_rois)[0]
                else:
                    f_pred_ref_pred_cls = self._recursive_forward(flip_fms, rcnn_rois)[0]
                loss_flip_JSD = _flip_loss_JSD(F.softmax(pred_ref_pred_cls[0], dim=-1), F.softmax(f_pred_ref_pred_cls[0], dim=-1))
                loss_flip_JSD += _flip_loss_JSD(F.softmax(pred_ref_pred_cls[1], dim=-1), F.softmax(f_pred_ref_pred_cls[1], dim=-1))
                loss_dict['loss_flip_JSD'] = loss_flip_JSD

            return loss_dict

        # Inference: decode each head's deltas against the RoIs. rcnn_rois[:, 1:5]
        # presumably holds box coords, with column 0 as a batch index — TODO confirm.
        per_head = []
        for p_cls, p_delta in zip(pred_ref_pred_cls, pred_ref_pred_delta):
            pred_ref_scores = F.softmax(p_cls, dim=-1)
            pred_bbox = restore_bbox(rcnn_rois[:, 1:5], p_delta, True)
            per_head.append(torch.cat([pred_bbox, pred_ref_scores[:, 1].reshape(-1, 1)], dim=1))
        return torch.cat(per_head, dim=0) if per_head else None

    def _min_order_emd_loss(self, pred_delta, pred_cls, bbox_targets, labels):
        """Return the mean over RoIs of the smaller EMD loss across both head orderings.

        ``emd_loss`` is evaluated with heads (0, 1) and with heads (1, 0). For
        each RoI the smaller of the two losses is kept; the argmin itself is
        selected without gradients. The kept losses are then averaged.
        """
        loss_ab = emd_loss(pred_delta[0], pred_cls[0], pred_delta[1], pred_cls[1],
                           bbox_targets, labels)
        loss_ba = emd_loss(pred_delta[1], pred_cls[1], pred_delta[0], pred_cls[0],
                           bbox_targets, labels)
        # assumes emd_loss returns one column per RoI, e.g. (N, 1) — TODO confirm
        both = torch.cat([loss_ab, loss_ba], dim=1)
        with torch.no_grad():
            _, best = both.min(dim=1)
        rows = torch.arange(both.shape[0], device=both.device)
        chosen = both[rows, best]
        return chosen.sum() / chosen.shape[0]