def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
    """Run the two-branch EMD head on pooled RoI features.

    Args:
        fpn_fms: Sequence of FPN feature maps. The first entry is dropped
            and the remainder reversed before pooling (per the original
            comment, the incoming order is stride 64, 32, 16, 8, 4).
        rcnn_rois: RoI tensor; columns 1:5 are used as the base boxes at
            inference (presumably column 0 is a batch index -- confirm
            against the caller).
        labels: Passed to ``emd_loss_softmax``; only used in training mode.
        bbox_targets: Passed to ``emd_loss_softmax``; only used in
            training mode.

    Returns:
        In training mode, a dict ``{'loss_rcnn_emd': <scalar tensor>}``.
        Otherwise a tensor made by concatenating, along axis 1, the
        ``[restored box, score, class tag]`` rows of branch 0 and branch 1.
    """
    # Re-order the pyramid so its levels line up with strides 4, 8, 16, 32.
    pyramid = fpn_fms[1:][::-1]
    level_strides = [4, 8, 16, 32]
    pooled = roi_pooler(pyramid, rcnn_rois, level_strides, (7, 7), "ROIAlignV2")

    # Two fully connected layers, each followed by an in-place ReLU.
    hidden = torch.flatten(pooled, start_dim=1)
    for fc_layer in (self.fc1, self.fc2):
        hidden = F.relu_(fc_layer(hidden))

    cls_0 = self.emd_pred_cls_0(hidden)
    delta_0 = self.emd_pred_delta_0(hidden)
    cls_1 = self.emd_pred_cls_1(hidden)
    delta_1 = self.emd_pred_delta_1(hidden)

    if self.training:
        # Evaluate both assignments of the two branches to the targets and
        # keep, per row, whichever assignment is cheaper.
        loss_straight = emd_loss_softmax(
            delta_0, cls_0, delta_1, cls_1, bbox_targets, labels)
        loss_swapped = emd_loss_softmax(
            delta_1, cls_1, delta_0, cls_0, bbox_targets, labels)
        pair_losses = torch.cat([loss_straight, loss_swapped], axis=1)
        # The selected indices are integer tensors and carry no gradient.
        best_choice = pair_losses.min(axis=1)[1]
        row_ids = torch.arange(pair_losses.shape[0])
        return {'loss_rcnn_emd': pair_losses[row_ids, best_choice].mean()}

    # Inference: one row per (RoI, foreground class) pair.
    num_fg = cls_0.shape[-1] - 1
    class_tag = torch.arange(num_fg).type_as(cls_0) + 1
    class_tag = class_tag.repeat(cls_0.shape[0], 1).reshape(-1, 1)
    roi_boxes = rcnn_rois[:, 1:5].repeat(1, num_fg).reshape(-1, 4)

    per_branch = []
    for logits, deltas in ((cls_0, delta_0), (cls_1, delta_1)):
        # Column 0 of the scores and columns 0:4 of the deltas are skipped
        # (presumably the background class -- confirm).
        scores = F.softmax(logits, dim=-1)[:, 1:].reshape(-1, 1)
        boxes = restore_bbox(roi_boxes, deltas[:, 4:].reshape(-1, 4), True)
        per_branch.append(torch.cat([boxes, scores, class_tag], axis=1))
    return torch.cat(per_branch, axis=1)
def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
    """Run the two-branch EMD head followed by a refinement stage.

    The first stage predicts class logits and box deltas for two branches.
    Each branch's prediction is appended to the shared feature, passed
    through ``fc3`` and used to produce refined logits and deltas.

    Args:
        fpn_fms: Sequence of FPN feature maps. The first entry is dropped
            and the remainder reversed before pooling (per the original
            comment, the incoming order is stride 64, 32, 16, 8, 4).
        rcnn_rois: RoI tensor; columns 1:5 are used as the base boxes at
            inference (presumably column 0 is a batch index -- confirm
            against the caller).
        labels: Passed to ``emd_loss_softmax``; only used in training mode.
        bbox_targets: Passed to ``emd_loss_softmax``; only used in
            training mode.

    Returns:
        In training mode, a dict with scalar tensors under
        ``'loss_rcnn_emd'`` (first stage) and ``'loss_ref_emd'``
        (refinement stage). Otherwise a tensor made by concatenating,
        along axis 1, the ``[restored box, score, class tag]`` rows of the
        two refined branches.
    """
    # stride: 64,32,16,8,4 -> 4, 8, 16, 32
    fpn_fms = fpn_fms[1:][::-1]
    stride = [4, 8, 16, 32]
    pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2")
    # torch.flatten's keyword is `start_dim`; `start_axis` raises TypeError.
    flatten_feature = torch.flatten(pool_features, start_dim=1)
    flatten_feature = F.relu_(self.fc1(flatten_feature))
    flatten_feature = F.relu_(self.fc2(flatten_feature))

    # First-stage predictions for both branches.
    pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature)
    pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature)
    pred_emd_cls_1 = self.emd_pred_cls_1(flatten_feature)
    pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature)
    # F.softmax only accepts `dim`; passing `axis` raises TypeError.
    pred_emd_scores_0 = F.softmax(pred_emd_cls_0, dim=-1)
    pred_emd_scores_1 = F.softmax(pred_emd_cls_1, dim=-1)

    # Construct the refinement input: deltas from column 4 onward plus the
    # column-1 score of each branch, tiled 4 times along axis 1 and
    # appended to the shared feature.
    boxes_feature_0 = cat(
        (pred_emd_delta_0[:, 4:], pred_emd_scores_0[:, 1][:, None]),
        axis=1).repeat(1, 4)
    boxes_feature_1 = cat(
        (pred_emd_delta_1[:, 4:], pred_emd_scores_1[:, 1][:, None]),
        axis=1).repeat(1, 4)
    boxes_feature_0 = cat((flatten_feature, boxes_feature_0), axis=1)
    boxes_feature_1 = cat((flatten_feature, boxes_feature_1), axis=1)
    # The same fc3 layer is applied to both branches.
    refine_feature_0 = F.relu_(self.fc3(boxes_feature_0))
    refine_feature_1 = F.relu_(self.fc3(boxes_feature_1))

    # Refined predictions.
    pred_ref_cls_0 = self.ref_pred_cls_0(refine_feature_0)
    pred_ref_delta_0 = self.ref_pred_delta_0(refine_feature_0)
    pred_ref_cls_1 = self.ref_pred_cls_1(refine_feature_1)
    pred_ref_delta_1 = self.ref_pred_delta_1(refine_feature_1)

    if self.training:
        # For each stage, evaluate both assignments of the two branches to
        # the targets and keep, per row, the cheaper one.
        loss0 = emd_loss_softmax(
            pred_emd_delta_0, pred_emd_cls_0,
            pred_emd_delta_1, pred_emd_cls_1,
            bbox_targets, labels)
        loss1 = emd_loss_softmax(
            pred_emd_delta_1, pred_emd_cls_1,
            pred_emd_delta_0, pred_emd_cls_0,
            bbox_targets, labels)
        loss2 = emd_loss_softmax(
            pred_ref_delta_0, pred_ref_cls_0,
            pred_ref_delta_1, pred_ref_cls_1,
            bbox_targets, labels)
        loss3 = emd_loss_softmax(
            pred_ref_delta_1, pred_ref_cls_1,
            pred_ref_delta_0, pred_ref_cls_0,
            bbox_targets, labels)
        loss_rcnn = cat([loss0, loss1], axis=1)
        loss_ref = cat([loss2, loss3], axis=1)
        # requires_grad = False
        _, min_indices_rcnn = loss_rcnn.min(dim=1)
        _, min_indices_ref = loss_ref.min(dim=1)
        loss_rcnn = loss_rcnn[torch.arange(loss_rcnn.shape[0]), min_indices_rcnn]
        loss_rcnn = loss_rcnn.mean()
        loss_ref = loss_ref[torch.arange(loss_ref.shape[0]), min_indices_ref]
        loss_ref = loss_ref.mean()
        loss_dict = {}
        loss_dict['loss_rcnn_emd'] = loss_rcnn
        loss_dict['loss_ref_emd'] = loss_ref
        return loss_dict
    else:
        # Inference uses only the refined predictions; one row per
        # (RoI, foreground class) pair.
        class_num = pred_ref_cls_0.shape[-1] - 1
        tag = torch.arange(class_num).type_as(pred_ref_cls_0) + 1
        tag = tag.repeat(pred_ref_cls_0.shape[0], 1).reshape(-1, 1)
        pred_scores_0 = F.softmax(pred_ref_cls_0, dim=-1)[:, 1:].reshape(-1, 1)
        pred_scores_1 = F.softmax(pred_ref_cls_1, dim=-1)[:, 1:].reshape(-1, 1)
        pred_delta_0 = pred_ref_delta_0[:, 4:].reshape(-1, 4)
        pred_delta_1 = pred_ref_delta_1[:, 4:].reshape(-1, 4)
        base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
        pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True)
        pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True)
        pred_bbox_0 = cat([pred_bbox_0, pred_scores_0, tag], axis=1)
        pred_bbox_1 = cat([pred_bbox_1, pred_scores_1, tag], axis=1)
        pred_bbox = cat((pred_bbox_0, pred_bbox_1), axis=1)
        return pred_bbox