def _forward_train(self, image, im_info, gt_boxes):
    """Run the training forward pass and merge the losses of every stage.

    The image goes through the backbone, the RPN, one cascade stage
    (``Cascade_0``) and the final RCNN head. The loss dictionaries
    produced by the RPN, the cascade stage and the RCNN head are merged
    into a single dictionary, which is returned.

    NOTE: a second cascade stage (``Cascade_1`` followed by re-targeting
    with ``pos_threshold=0.7``) is currently disabled.
    """
    # stride: 64,32,16,8,4, p6->p2
    feature_maps = self.backbone(image)
    rpn_rois, rpn_losses = self.RPN(feature_maps, im_info, gt_boxes)

    # First stage: targets are assigned to the RPN proposals.
    rois, labels, bbox_targets = fpn_roi_target(
        rpn_rois, im_info, gt_boxes, top_k=1)
    proposals, cascade_0_losses = self.Cascade_0(
        feature_maps, rois, labels, bbox_targets)

    # Final stage: the cascade output is re-targeted with pos_threshold=0.6.
    rois, labels, bbox_targets = cascade_roi_target(
        proposals, im_info, gt_boxes, pos_threshold=0.6, top_k=1)
    rcnn_losses = self.RCNN(feature_maps, rois, labels, bbox_targets)

    # Merge in the same order as the stages ran; later keys win on clashes.
    loss_dict = {}
    for stage_losses in (rpn_losses, cascade_0_losses, rcnn_losses):
        loss_dict.update(stage_losses)
    return loss_dict
def _forward_train(self, image, im_info, gt_boxes):
    """Run the training forward pass and merge the RPN and RCNN losses.

    The image goes through the backbone and the RPN; targets are then
    assigned to the RPN proposals and fed to the RCNN head. The loss
    dictionaries of the RPN and the RCNN head are merged and returned.
    """
    # stride: 64,32,16,8,4, p6->p2
    feature_maps = self.backbone(image)
    rpn_rois, rpn_losses = self.RPN(feature_maps, im_info, gt_boxes)

    rois, labels, bbox_targets = fpn_roi_target(
        rpn_rois, im_info, gt_boxes, top_k=1)
    rcnn_losses = self.RCNN(feature_maps, rois, labels, bbox_targets)

    # RPN losses first, then RCNN; later keys win on clashes.
    loss_dict = {}
    for stage_losses in (rpn_losses, rcnn_losses):
        loss_dict.update(stage_losses)
    return loss_dict
def _forward_train(self, image, im_info, gt_boxes):
    """Run the training forward pass and merge the RPN and RCNN losses.

    The image goes through the FPN and the RPN. Targets are assigned to
    the RPN proposals (two per RoI, ``top_k=2``) with gradients disabled,
    and the RCNN head computes its losses on them. The loss dictionaries
    of the RPN and the RCNN head are merged and returned.
    """
    # stride: 64,32,16,8,4
    feature_maps = self.FPN(image)
    rpn_rois, rpn_rois_inds, rpn_losses = self.RPN(
        feature_maps, im_info, gt_boxes)

    # Target assignment is excluded from autograd.
    with torch.no_grad():
        rois, labels, bbox_targets = fpn_roi_target(
            rpn_rois, rpn_rois_inds, im_info, gt_boxes, top_k=2)
    rcnn_losses = self.RCNN(feature_maps, rois, labels, bbox_targets)

    # RPN losses first, then RCNN; later keys win on clashes.
    loss_dict = {}
    for stage_losses in (rpn_losses, rcnn_losses):
        loss_dict.update(stage_losses)
    return loss_dict
def forward(self, fpn_fms, rois, gtboxes=None, im_info=None):
    """Pool RoI features, run the box head and return losses and/or boxes.

    Args:
        fpn_fms: list of FPN feature maps. The first entry is dropped and
            the remainder is reversed before pooling with strides
            ``[4, 8, 16, 32]``.
        rois: candidate boxes to pool features for.
        gtboxes: ground-truth boxes; only read when ``self.training``.
        im_info: image metadata forwarded to ``fpn_roi_target``; only
            read when ``self.training``.

    Returns:
        ``(loss, pred_bboxes)`` when ``self.training`` is set, otherwise
        just the predicted boxes. When ``self.refinement`` is set the
        boxes are decoded from the refinement module's output in both
        modes, and in training the refinement loss is merged into
        ``loss``.
    """
    # Slicing copies the list, so the in-place reverse below does not
    # touch the caller's list.
    rpn_fms = fpn_fms[1:]
    rpn_fms.reverse()
    stride = [4, 8, 16, 32]

    if self.training:
        rcnn_rois, labels, bbox_targets = fpn_roi_target(
            rois, im_info, gtboxes, self.iou_thresh, top_k=self.nheads)
        pool5, rcnn_rois, labels, bbox_targets = roi_pooler(
            rpn_fms, rcnn_rois, stride, (7, 7), 'roi_align',
            labels, bbox_targets)
    else:
        pool5, rcnn_rois, _, _ = roi_pooler(
            rpn_fms, rois, stride, (7, 7), 'roi_align')

    pool5 = F.flatten(pool5, start_axis=1)
    fc1 = self.relu(self.fc1(pool5))
    fc2 = self.relu(self.fc2(fc1))
    prob = self.p(fc2)

    # The boxes handed back to the caller always come from the most
    # refined head output that is available.
    if self.refinement:
        final_pred = self.refinement_module(prob, fc2)
        box_source = final_pred
    else:
        final_pred = None
        box_source = prob

    if not self.training:
        # Inference: return the detection boxes and their scores.
        return self.recover_pred_boxes(rcnn_rois, box_source, self.nheads)

    # Training: compute the loss, then decode the boxes.
    if self.nheads > 1:
        bbox_targets = bbox_targets.reshape(-1, 4)
    labels = labels.reshape(-1)
    if self.nheads < 2:
        loss = self.compute_regular_loss(prob, bbox_targets, labels)
    else:
        loss = self.compute_gemini_loss_opr(prob, bbox_targets, labels)

    if self.refinement:
        auxi_loss = self.compute_gemini_loss_opr(
            final_pred, bbox_targets, labels)
        loss.update(auxi_loss)

    # Previously the refined boxes were stored in a misspelled local
    # (`pred_boxes`) and discarded, so training returned the unrefined
    # boxes while inference returned the refined ones.
    pred_bboxes = self.recover_pred_boxes(rcnn_rois, box_source, self.nheads)
    return loss, pred_bboxes
def _forward_train(self, image, im_info, gt_boxes, isEval=False, extra=None):
    """Run the training forward pass, optionally with a flipped-image branch.

    Args:
        image: input batch fed to ``self.FPN``.
        im_info: image metadata forwarded to the RPN and target assignment.
        gt_boxes: ground-truth boxes.
        isEval: forwarded to the RPN and the RCNN head.
        extra: optional dict of auxiliary inputs, forwarded to the RCNN
            head. ``extra['aug_data']`` is read when both
            ``self.args.flip_JSD`` and ``self.args.flip_aug`` are set.
            Defaults to a fresh empty dict.

    Returns:
        A dict holding the RPN losses updated with the RCNN losses.
    """
    # A fresh dict per call; a ``{}`` default would be shared by all calls.
    if extra is None:
        extra = {}

    # stride: 64,32,16,8,4
    fpn_fms = self.FPN(image)
    rpn_rois, rpn_rois_inds, loss_dict_rpn = self.RPN(
        fpn_fms, im_info, gt_boxes, isEval)

    # Feature maps of the horizontally flipped input, flipped back so they
    # line up with `fpn_fms`.
    flip_fms = None
    if self.args.flip_JSD_0g:
        with torch.no_grad():
            flip_fms = [x.flip(-1) for x in self.FPN(image.flip(-1))]
    if self.args.flip_JSD:
        # NOTE(review): when flip_JSD_0g is also set, this recomputation
        # (with gradients) replaces the no_grad result above, so that
        # pass is wasted -- confirm this is intended.
        flip_source = extra['aug_data'] if self.args.flip_aug else image
        flip_fms = [x.flip(-1) for x in self.FPN(flip_source.flip(-1))]

    # The head only receives the flipped features when flip_JSD is on.
    rcnn_flip_fms = flip_fms if self.args.flip_JSD else None

    if self.args.recursive:
        # The recursive head performs its own target assignment.
        loss_dict_rcnn = self.RCNN(
            fpn_fms, rpn_rois, rpn_rois_inds, im_info,
            gt_boxes=gt_boxes, isEval=isEval,
            flip_fms=rcnn_flip_fms, extra=extra)
    else:
        with torch.no_grad():
            rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
                rpn_rois, rpn_rois_inds, im_info, gt_boxes, top_k=2)
        loss_dict_rcnn = self.RCNN(
            fpn_fms, rcnn_rois, rcnn_labels, rcnn_bbox_targets, isEval,
            flip_fms=rcnn_flip_fms, extra=extra)

    loss_dict = {}
    loss_dict.update(loss_dict_rpn)
    loss_dict.update(loss_dict_rcnn)
    return loss_dict
def forward(self, fpn_fms, rpn_rois, rpn_rois_inds=None, im_info=None,
            gt_boxes=None, isEval=False, flip_fms=None, extra=None):
    """Run the recursive RCNN head and return losses or scored boxes.

    Args:
        fpn_fms: FPN feature maps passed to ``self._recursive_forward``.
        rpn_rois: proposals from the RPN.
        rpn_rois_inds: forwarded to ``fpn_roi_target`` in loss mode.
        im_info: forwarded to ``fpn_roi_target`` in loss mode.
        gt_boxes: ground-truth boxes, used in loss mode.
        isEval: when true, losses are computed even if ``self.training``
            is false.
        flip_fms: feature maps of the flipped input; only read when
            ``self.args.flip_JSD`` is set.
        extra: accepted for interface compatibility; not read here.

    Returns:
        In loss mode (``self.training or isEval``): a dict with
        ``'loss_rcnn_emd'`` and ``'loss_ref_emd'``, plus
        ``'loss_flip_JSD'`` when ``self.args.flip_JSD`` is set.
        Otherwise: the scored boxes of every refined head concatenated
        along dim 0, or ``None`` if the head produced no predictions.
    """
    compute_loss = self.training or isEval

    if compute_loss:
        with torch.no_grad():
            rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
                rpn_rois, rpn_rois_inds, im_info, gt_boxes, top_k=2)
    else:
        rcnn_rois = rpn_rois

    (pred_ref_pred_cls, pred_ref_pred_delta, pred_cls_unrefined,
     pred_delta_unrefined, _pool_features) = self._recursive_forward(
        fpn_fms, rcnn_rois, keep_pool_feature=True)

    if not compute_loss:
        # Inference: append each head's score (column 1 of the softmax)
        # to its decoded boxes and stack all heads row-wise.
        scored_boxes = []
        for p_cls, p_delta in zip(pred_ref_pred_cls, pred_ref_pred_delta):
            scores = F.softmax(p_cls, dim=-1)
            boxes = restore_bbox(rcnn_rois[:, 1:5], p_delta, True)
            scored_boxes.append(
                torch.cat([boxes, scores[:, 1].reshape(-1, 1)], dim=1))
        return torch.cat(scored_boxes, dim=0) if scored_boxes else None

    def _pairing_loss(deltas, classes):
        """Average, over rows, the cheaper of the two head/target pairings."""
        straight = emd_loss(
            deltas[0], classes[0], deltas[1], classes[1],
            rcnn_bbox_targets, rcnn_labels)
        swapped = emd_loss(
            deltas[1], classes[1], deltas[0], classes[0],
            rcnn_bbox_targets, rcnn_labels)
        # assumes emd_loss returns one column per call -- TODO confirm
        both = torch.cat([straight, swapped], dim=1)
        # Selecting the pairing is not differentiated; the selected loss
        # values themselves still carry gradients.
        with torch.no_grad():
            _, best = both.min(dim=1)
        chosen = both[torch.arange(both.shape[0]), best]
        return chosen.sum() / chosen.shape[0]

    loss_dict = {
        'loss_rcnn_emd': _pairing_loss(pred_delta_unrefined, pred_cls_unrefined),
        'loss_ref_emd': _pairing_loss(pred_ref_pred_delta, pred_ref_pred_cls),
    }

    if self.args.flip_JSD:
        # Only the first output (refined class logits) of the flipped pass
        # is needed. Starred unpacking accepts any number of further
        # outputs; the previous code unpacked 4 values in one branch and
        # 2 in the other, so one of them had to fail.
        if self.args.flip_JSD_0g:
            with torch.no_grad():
                f_pred_ref_pred_cls, *_ = self._recursive_forward(
                    flip_fms, rcnn_rois)
        else:
            f_pred_ref_pred_cls, *_ = self._recursive_forward(
                flip_fms, rcnn_rois)

        head_losses = [
            _flip_loss_JSD(
                F.softmax(pred_ref_pred_cls[head], dim=-1),
                F.softmax(f_pred_ref_pred_cls[head], dim=-1))
            for head in (0, 1)
        ]
        loss_dict['loss_flip_JSD'] = head_losses[0] + head_losses[1]

    return loss_dict