Esempio n. 1
0
    def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
        """EMD R-CNN head producing two (cls, delta) predictions per RoI.

        Features are pooled to 7x7 by ``roi_pooler`` ("ROIAlignV2") from
        ``fpn_fms[1:][::-1]`` with strides 4, 8, 16, 32. They then pass
        through ``fc1``/``fc2`` (in-place ReLU) and four prediction layers.

        Args:
            fpn_fms: Sequence of FPN feature maps. The first entry is dropped
                and the remaining ones are reversed. Presumably the input is
                ordered by stride 64, 32, 16, 8, 4 (TODO confirm).
            rcnn_rois: RoI tensor whose columns 1:5 are the boxes the deltas
                are decoded against. Column 0 is presumably the batch index
                (verify).
            labels: Classification targets, used only when ``self.training``.
            bbox_targets: Regression targets, used only when ``self.training``.

        Returns:
            When training: ``{'loss_rcnn_emd': loss}``, where ``loss`` is the
            mean over rows of the smaller of the two assignment-order losses.
            Otherwise: a tensor with one row per (RoI, foreground class) of
            ``[box_0 (4), score_0, class_id, box_1 (4), score_1, class_id]``.
        """
        levels = fpn_fms[1:][::-1]
        strides = [4, 8, 16, 32]
        pooled = roi_pooler(levels, rcnn_rois, strides, (7, 7), "ROIAlignV2")

        hidden = torch.flatten(pooled, start_dim=1)
        hidden = F.relu_(self.fc1(hidden))
        hidden = F.relu_(self.fc2(hidden))

        cls_a = self.emd_pred_cls_0(hidden)
        delta_a = self.emd_pred_delta_0(hidden)
        cls_b = self.emd_pred_cls_1(hidden)
        delta_b = self.emd_pred_delta_1(hidden)

        if not self.training:
            # Column 0 of the logits is background; every other class is
            # emitted as its own row per RoI.
            num_fg = cls_a.shape[-1] - 1
            class_ids = torch.arange(num_fg).type_as(cls_a) + 1
            class_ids = class_ids.repeat(cls_a.shape[0], 1).reshape(-1, 1)
            rois = rcnn_rois[:, 1:5].repeat(1, num_fg).reshape(-1, 4)
            per_prediction = []
            for logits, deltas in ((cls_a, delta_a), (cls_b, delta_b)):
                scores = F.softmax(logits, dim=-1)[:, 1:].reshape(-1, 1)
                fg_deltas = deltas[:, 4:].reshape(-1, 4)
                boxes = restore_bbox(rois, fg_deltas, True)
                per_prediction.append(
                    torch.cat([boxes, scores, class_ids], dim=1))
            return torch.cat(per_prediction, dim=1)

        # Evaluate both ways of pairing the two predictions with the targets.
        loss_a = emd_loss_softmax(delta_a, cls_a, delta_b, cls_b,
                                  bbox_targets, labels)
        loss_b = emd_loss_softmax(delta_b, cls_b, delta_a, cls_a,
                                  bbox_targets, labels)
        pair_loss = torch.cat([loss_a, loss_b], dim=1)
        # The argmin indices carry no gradient; only the selected losses do.
        best = pair_loss.min(dim=1)[1]
        chosen = pair_loss[torch.arange(pair_loss.shape[0]), best]
        return {'loss_rcnn_emd': chosen.mean()}
Esempio n. 2
0
 def forward(self, fpn_fms, proposals, labels=None, bbox_targets=None):
     """Cascade R-CNN stage: score and regress proposals, then refine them.

     Args:
         fpn_fms: FPN feature maps. ``fpn_fms[1:][::-1]`` is pooled with
             strides 4, 8, 16, 32 (presumably levels p2-p5; TODO confirm).
         proposals: RoI tensor. Column 0 is carried through unchanged and
             columns 1:5 are the boxes the predicted deltas are decoded
             against.
         labels: Per-RoI labels (training only). Labels ``> 0`` are
             foreground for the regression loss. Labels ``>= 0`` are kept in
             the classification loss and counted in the normalizer; negative
             labels are masked out.
         bbox_targets: Regression targets indexed by the foreground mask
             (training only).

     Returns:
         Training: ``(pred_proposals, loss_dict)``, where ``loss_dict`` has
         the keys ``self.stage_name + '_loc'`` and
         ``self.stage_name + '_cls'``.
         Inference: ``(pred_proposals, pred_scores)``, where ``pred_scores``
         is the softmax probability of class 1 reshaped to ``(-1, 1)``.
         ``pred_proposals`` rows are ``[proposals[:, 0], decoded box]`` and
         are detached from the autograd graph.
     """
     def _refine(deltas):
         # Decode deltas against the input boxes and re-attach column 0.
         # The result is detached, presumably so later stages do not
         # backpropagate into this one.
         boxes = restore_bbox(proposals[:, 1:5], deltas, True).detach()
         return torch.cat([proposals[:, 0].reshape(-1, 1), boxes], dim=1)

     # input p2-p5
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features = roi_pooler(fpn_fms, proposals, stride, (7, 7),
                                "ROIAlignV2")
     flatten_feature = torch.flatten(pool_features, start_dim=1)
     flatten_feature = F.relu_(self.fc1(flatten_feature))
     flatten_feature = F.relu_(self.fc2(flatten_feature))
     pred_cls = self.pred_cls(flatten_feature)
     pred_delta = self.pred_delta(flatten_feature)
     if not self.training:
         pred_proposals = _refine(pred_delta)
         pred_scores = F.softmax(pred_cls, dim=-1)[:, 1].reshape(-1, 1)
         return pred_proposals, pred_scores

     labels = labels.long().flatten()
     fg_masks = labels > 0
     valid_masks = labels >= 0
     # loss for regression (foreground RoIs only)
     localization_loss = smooth_l1_loss(pred_delta[fg_masks],
                                        bbox_targets[fg_masks],
                                        config.rcnn_smooth_l1_beta)
     # loss for classification; RoIs with negative labels are masked out
     objectness_loss = softmax_loss(pred_cls, labels)
     objectness_loss = objectness_loss * valid_masks
     # If every label is negative, the count is 0 and ``1.0 / 0`` would
     # raise ZeroDivisionError. Fall back to a normalizer of 1 instead.
     normalizer = 1.0 / max(valid_masks.sum().item(), 1)
     loss_dict = {
         self.stage_name + '_loc': localization_loss.sum() * normalizer,
         self.stage_name + '_cls': objectness_loss.sum() * normalizer,
     }
     # Refined proposals for the next stage; no gradient is tracked.
     # NOTE: proposals are not clipped to the image boundary here.
     with torch.no_grad():
         pred_proposals = _refine(pred_delta)
     return pred_proposals, loss_dict
Esempio n. 3
0
 def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
     """Multi-class R-CNN head with class-specific box regression.

     Args:
         fpn_fms: FPN feature maps. ``fpn_fms[1:][::-1]`` is pooled to 7x7
             with strides 4, 8, 16, 32 (presumably levels p2-p5; TODO
             confirm).
         rcnn_rois: RoI tensor whose columns 1:5 are the boxes used for
             decoding. Column 0 is presumably the batch index (verify).
         labels: Per-RoI labels (training only). Labels ``> 0`` are
             foreground, ``>= 0`` are valid, and negative labels are
             excluded from the classification loss and the normalizer.
         bbox_targets: Regression targets indexed by the foreground mask
             (training only).

     Returns:
         Training: ``{'loss_rcnn_loc': ..., 'loss_rcnn_cls': ...}``, both
         normalized by the number of valid RoIs.
         Inference: a tensor with one row per (RoI, foreground class) of
         ``[box (4), score, class_id]``.
     """
     # input p2-p5
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7),
                                "ROIAlignV2")
     flatten_feature = torch.flatten(pool_features, start_dim=1)
     flatten_feature = F.relu_(self.fc1(flatten_feature))
     flatten_feature = F.relu_(self.fc2(flatten_feature))
     pred_cls = self.pred_cls(flatten_feature)
     pred_delta = self.pred_delta(flatten_feature)
     if self.training:
         labels = labels.long().flatten()
         fg_masks = labels > 0
         valid_masks = labels >= 0
         # multi class: pred_delta holds 4 values per class. Keep only the
         # deltas of each foreground RoI's ground-truth class.
         pred_delta = pred_delta.reshape(-1, config.num_classes, 4)
         fg_gt_classes = labels[fg_masks]
         pred_delta = pred_delta[fg_masks, fg_gt_classes, :]
         # loss for regression
         localization_loss = smooth_l1_loss(pred_delta,
                                            bbox_targets[fg_masks],
                                            config.rcnn_smooth_l1_beta)
         # loss for classification; RoIs with negative labels are masked out
         objectness_loss = softmax_loss(pred_cls, labels)
         objectness_loss = objectness_loss * valid_masks
         # If every label is negative, the count is 0 and ``1.0 / 0`` would
         # raise ZeroDivisionError. Fall back to a normalizer of 1 instead.
         normalizer = 1.0 / max(valid_masks.sum().item(), 1)
         loss_rcnn_loc = localization_loss.sum() * normalizer
         loss_rcnn_cls = objectness_loss.sum() * normalizer
         loss_dict = {}
         loss_dict['loss_rcnn_loc'] = loss_rcnn_loc
         loss_dict['loss_rcnn_cls'] = loss_rcnn_cls
         return loss_dict
     else:
         # Column 0 of the logits is background; each other class gets a row.
         class_num = pred_cls.shape[-1] - 1
         tag = torch.arange(class_num).type_as(pred_cls) + 1
         tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1, 1)
         pred_scores = F.softmax(pred_cls, dim=-1)[:, 1:].reshape(-1, 1)
         # Drop the background deltas (first 4 columns).
         pred_delta = pred_delta[:, 4:].reshape(-1, 4)
         base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
         pred_bbox = restore_bbox(base_rois, pred_delta, True)
         pred_bbox = torch.cat([pred_bbox, pred_scores, tag], dim=1)
         return pred_bbox
Esempio n. 4
0
 def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
     """EMD R-CNN head with a refinement module.

     Two (cls, delta) predictions are produced per RoI from shared features.
     Each is refined by ``fc3``, which sees the shared features concatenated
     with that prediction's non-background deltas and class-1 score, tiled
     four times.

     Args:
         fpn_fms: FPN feature maps. The first entry is dropped and the rest
             reversed; per the note below the input is presumably ordered by
             stride 64, 32, 16, 8, 4 (TODO confirm).
         rcnn_rois: RoI tensor whose columns 1:5 are the boxes used for
             decoding. Column 0 is presumably the batch index (verify).
         labels: Classification targets, used only when ``self.training``.
         bbox_targets: Regression targets, used only when ``self.training``.

     Returns:
         Training: ``{'loss_rcnn_emd': ..., 'loss_ref_emd': ...}``. Each is the
         mean over rows of the smaller of the two assignment-order losses.
         Inference: a tensor built from the *refined* predictions, with one
         row per (RoI, foreground class) of
         ``[box_0 (4), score_0, class_id, box_1 (4), score_1, class_id]``.
     """
     # stride: 64,32,16,8,4 -> 4, 8, 16, 32
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7),
                                "ROIAlignV2")
     # torch.flatten's keyword is ``start_dim`` (``start_axis`` is a TypeError).
     flatten_feature = torch.flatten(pool_features, start_dim=1)
     flatten_feature = F.relu_(self.fc1(flatten_feature))
     flatten_feature = F.relu_(self.fc2(flatten_feature))
     pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature)
     pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature)
     pred_emd_cls_1 = self.emd_pred_cls_1(flatten_feature)
     pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature)
     # torch.nn.functional.softmax takes ``dim``; it has no ``axis`` keyword.
     pred_emd_scores_0 = F.softmax(pred_emd_cls_0, dim=-1)
     pred_emd_scores_1 = F.softmax(pred_emd_cls_1, dim=-1)
     # cons refine feature: non-background deltas + class-1 score, tiled x4
     boxes_feature_0 = cat(
         (pred_emd_delta_0[:, 4:], pred_emd_scores_0[:, 1][:, None]),
         axis=1).repeat(1, 4)
     boxes_feature_1 = cat(
         (pred_emd_delta_1[:, 4:], pred_emd_scores_1[:, 1][:, None]),
         axis=1).repeat(1, 4)
     boxes_feature_0 = cat((flatten_feature, boxes_feature_0), axis=1)
     boxes_feature_1 = cat((flatten_feature, boxes_feature_1), axis=1)
     # fc3 is shared between both branches.
     refine_feature_0 = F.relu_(self.fc3(boxes_feature_0))
     refine_feature_1 = F.relu_(self.fc3(boxes_feature_1))
     # refine
     pred_ref_cls_0 = self.ref_pred_cls_0(refine_feature_0)
     pred_ref_delta_0 = self.ref_pred_delta_0(refine_feature_0)
     pred_ref_cls_1 = self.ref_pred_cls_1(refine_feature_1)
     pred_ref_delta_1 = self.ref_pred_delta_1(refine_feature_1)
     if self.training:
         # Both pairing orders of the two predictions vs. the targets, for
         # the first-stage and the refined predictions.
         loss0 = emd_loss_softmax(pred_emd_delta_0, pred_emd_cls_0,
                                  pred_emd_delta_1, pred_emd_cls_1,
                                  bbox_targets, labels)
         loss1 = emd_loss_softmax(pred_emd_delta_1, pred_emd_cls_1,
                                  pred_emd_delta_0, pred_emd_cls_0,
                                  bbox_targets, labels)
         loss2 = emd_loss_softmax(pred_ref_delta_0, pred_ref_cls_0,
                                  pred_ref_delta_1, pred_ref_cls_1,
                                  bbox_targets, labels)
         loss3 = emd_loss_softmax(pred_ref_delta_1, pred_ref_cls_1,
                                  pred_ref_delta_0, pred_ref_cls_0,
                                  bbox_targets, labels)
         loss_rcnn = cat([loss0, loss1], axis=1)
         loss_ref = cat([loss2, loss3], axis=1)
         # The argmin indices carry no gradient; only the selected losses do.
         _, min_indices_rcnn = loss_rcnn.min(dim=1)
         _, min_indices_ref = loss_ref.min(dim=1)
         loss_rcnn = loss_rcnn[torch.arange(loss_rcnn.shape[0]),
                               min_indices_rcnn]
         loss_rcnn = loss_rcnn.mean()
         loss_ref = loss_ref[torch.arange(loss_ref.shape[0]),
                             min_indices_ref]
         loss_ref = loss_ref.mean()
         loss_dict = {}
         loss_dict['loss_rcnn_emd'] = loss_rcnn
         loss_dict['loss_ref_emd'] = loss_ref
         return loss_dict
     else:
         # Column 0 of the logits is background; each other class gets a row.
         class_num = pred_ref_cls_0.shape[-1] - 1
         tag = torch.arange(class_num).type_as(pred_ref_cls_0) + 1
         tag = tag.repeat(pred_ref_cls_0.shape[0], 1).reshape(-1, 1)
         pred_scores_0 = F.softmax(pred_ref_cls_0,
                                   dim=-1)[:, 1:].reshape(-1, 1)
         pred_scores_1 = F.softmax(pred_ref_cls_1,
                                   dim=-1)[:, 1:].reshape(-1, 1)
         # Drop the background deltas (first 4 columns).
         pred_delta_0 = pred_ref_delta_0[:, 4:].reshape(-1, 4)
         pred_delta_1 = pred_ref_delta_1[:, 4:].reshape(-1, 4)
         base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
         pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True)
         pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True)
         pred_bbox_0 = cat([pred_bbox_0, pred_scores_0, tag], axis=1)
         pred_bbox_1 = cat([pred_bbox_1, pred_scores_1, tag], axis=1)
         pred_bbox = cat((pred_bbox_0, pred_bbox_1), axis=1)
         return pred_bbox