def compute_bbox_per_image(self, flatten_bbox_targets_reshape,
                           flatten_labels_targets_reshape,
                           flatten_bbox_preds_reshape,
                           flatten_points_reshape, flatten_conv_reshape,
                           bg_class_ind):
    # Select per-image positive locations based on labels and decode the
    # distance-format boxes back to corner boxes.
    bbox_targets_moc = []
    bbox_preds_moc = []
    conv_moc = []
    for bbox_targets, labels, bbox_preds, points, conv in zip(
            flatten_bbox_targets_reshape, flatten_labels_targets_reshape,
            flatten_bbox_preds_reshape, flatten_points_reshape,
            flatten_conv_reshape):
        pos_inds = ((labels >= 0) &
                    (labels < bg_class_ind)).nonzero().reshape(-1)
        pos_bbox_preds = bbox_preds[pos_inds]
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_points = points[pos_inds]
        pos_conv = conv[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
        bbox_targets_moc.append(pos_decoded_target_preds)
        bbox_preds_moc.append(pos_decoded_bbox_preds)
        conv_moc.append(pos_conv)
    return bbox_preds_moc, bbox_targets_moc, conv_moc
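# All of the heads in this file decode (left, top, right, bottom) distances
# back to corner boxes with `distance2bbox`. A minimal sketch of that helper,
# assuming the standard mmdetection definition (the version actually imported
# here may accept extra keyword arguments such as `norm`):
import torch

def distance2bbox(points, distance, max_shape=None):
    """Decode distances from `points` to the four sides into (x1, y1, x2, y2)."""
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    if max_shape is not None:
        # clip to the image boundary; max_shape is (height, width, ...)
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)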
def loss(self, gt_bboxes, cls_scores, bbox_preds, bbox_iou, labels,
         label_weight, bbox_targets, bbox_weights, points,
         reduction_override=None):
    assert len(cls_scores) == len(bbox_preds)
    num_imgs = cls_scores.size(0)
    # flatten cls_scores, bbox_preds and bbox_iou
    flatten_cls_scores = cls_scores.permute(0, 2, 3, 1).reshape(
        -1, self.cls_out_channels)
    flatten_bbox_preds = bbox_preds.permute(0, 2, 3, 1).reshape(-1, 4)
    flatten_bbox_iou = bbox_iou.permute(0, 2, 3, 1).reshape(-1, 1)
    flatten_labels = labels.reshape(-1)
    flatten_bbox_targets = bbox_targets.reshape(-1, 4)
    # flatten points to align with bbox_preds
    flatten_points = points.reshape(-1, 2)

    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos being 0
    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_bbox_iou = flatten_bbox_iou[pos_inds]

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
        # iou loss on the decoded boxes
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            avg_factor=num_pos)
        bbox_iou_targets = bbox_goverlaps(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            is_aligned=True).clamp(min=1e-6)[:, None]
        loss_bbox_iou = self.loss_iou(pos_bbox_iou, bbox_iou_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_bbox_iou = pos_bbox_iou.sum()

    return dict(
        loss_cls=loss_cls, loss_bbox=loss_bbox, loss_iou=loss_bbox_iou)
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      centernesses,
                      mlvl_points,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    if cfg.stat_2d:
        # rescale the raw (l, t, r, b) predictions by per-location 2D statistics
        stat = cv2.imread('kitti_tools/stat/stat_2d.png').astype(np.float32)
        stat = cv2.resize(stat, (106, 32)).astype(np.float32)
        w_stat = torch.from_numpy(stat[:, :, 1]).float().cuda().unsqueeze(0)
        h_stat = torch.from_numpy(stat[:, :, 2]).float().cuda().unsqueeze(0)
        stat_all = torch.cat([w_stat, h_stat, w_stat, h_stat], dim=0)
        bbox_preds = [bbox_pred * stat_all for bbox_pred in bbox_preds]
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)
    det_bboxes, det_labels = multiclass_nms(
        mlvl_bboxes,
        mlvl_scores,
        cfg.score_thr,
        cfg.nms,
        cfg.max_per_img,
        score_factors=mlvl_centerness)
    return det_bboxes, det_labels
def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, stride, num_total_samples, cfg):
    anchors = anchors.reshape(-1, 4)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * (self.reg_max + 1))
    bbox_targets = bbox_targets.reshape(-1, 4)
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)

    pos_inds = torch.nonzero(labels).squeeze(1)
    score = label_weights.new_zeros(labels.shape)

    if len(pos_inds) > 0:
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]  # (n, 4 * (reg_max + 1))
        pos_anchors = anchors[pos_inds]
        norm_anchor_center = self.anchor_center(pos_anchors) / stride

        pos_bbox_pred_distance = self.distribution_project(pos_bbox_pred)
        pos_decode_bbox_pred = distance2bbox(norm_anchor_center,
                                             pos_bbox_pred_distance)
        pos_decode_bbox_targets = pos_bbox_targets / stride
        target_ltrb = bbox2distance(norm_anchor_center,
                                    pos_decode_bbox_targets,
                                    self.reg_max).reshape(-1)
        score[pos_inds] = self.iou_target(pos_decode_bbox_pred.detach(),
                                          pos_decode_bbox_targets)
        weight_targets = cls_score.detach().sigmoid().max(dim=1)[0][pos_inds]

        # regression loss
        loss_bbox = self.loss_bbox(
            pos_decode_bbox_pred,
            pos_decode_bbox_targets,
            weight=weight_targets,
            avg_factor=1.0)

        pred_ltrb = pos_bbox_pred.reshape(-1, self.reg_max + 1)
        # dfl loss
        loss_dfl = self.loss_dfl(
            pred_ltrb,
            target_ltrb,
            weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
            avg_factor=4.0)
    else:
        loss_bbox = bbox_pred.sum() * 0
        loss_dfl = bbox_pred.sum() * 0
        # keep the zero on the same device/dtype as the predictions
        weight_targets = bbox_pred.new_tensor(0.)

    # qfl loss
    loss_qfl = self.loss_qfl(
        cls_score, labels, score, avg_factor=num_total_samples)

    return loss_qfl, loss_bbox, loss_dfl, weight_targets.sum()
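# `self.distribution_project` above maps the 4*(reg_max+1) distribution logits
# to four expected distances. A minimal sketch of such a module, assuming the
# Integral formulation from Generalized Focal Loss (softmax over each side's
# discrete distribution, then the expectation sum_i i * p_i):
import torch
import torch.nn as nn
import torch.nn.functional as F

class Integral(nn.Module):
    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        # fixed projection vector [0, 1, ..., reg_max]
        self.register_buffer('project',
                             torch.linspace(0, reg_max, reg_max + 1))

    def forward(self, x):
        # (N, 4*(reg_max+1)) -> (N*4, reg_max+1) -> per-side expectation
        x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        x = (x * self.project.type_as(x)).sum(dim=1).reshape(-1, 4)
        return x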
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      coef_preds,
                      centernesses,
                      mlvl_points,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_coefs = []
    mlvl_centerness = []
    for cls_score, bbox_pred, coef_pred, centerness, points in zip(
            cls_scores, bbox_preds, coef_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        coef_pred = coef_pred.permute(1, 2, 0).reshape(-1, self.num_bases)
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            coef_pred = coef_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        coefs = coef_pred
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
        mlvl_coefs.append(coefs)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    mlvl_coefs = torch.cat(mlvl_coefs)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)
    det_bboxes, det_labels, det_coefs = multiclass_nms_with_mask(
        mlvl_bboxes,
        mlvl_scores,
        mlvl_coefs,
        cfg.score_thr,
        cfg.nms,
        cfg.max_per_img,
        score_factors=mlvl_centerness,
        num_bases=self.num_bases)
    return det_bboxes, det_labels, det_coefs
def get_bboxes_single(
        self,
        cls_scores,
        bbox_preds,
        centernesses,
        mlvl_points,  # location in the original image of each point on the FPN feature maps
        img_shape,
        scale_factor,
        cfg,
        rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)  # number of FPN levels
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):  # process one level at a time
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]  # spatial sizes must match
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:  # keep only the top candidates for NMS
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
    mlvl_bboxes = torch.cat(mlvl_bboxes)  # [num_points, 4]
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)  # [num_points, 80]
    # [num_points, 1]: NMS expects a background column, but FCOS itself
    # predicts no background score, so pad with zeros.
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    # [num_points, 81], shape (n, #class); the 0th column holds the
    # background scores and is ignored during NMS.
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)  # [num_points]
    det_bboxes, det_labels = multiclass_nms(  # NMS
        mlvl_bboxes,
        mlvl_scores,
        cfg.score_thr,
        cfg.nms,
        cfg.max_per_img,
        score_factors=mlvl_centerness)
    return det_bboxes, det_labels
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      centernesses,
                      mlvl_points,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    # TODO: change output to proposals without labels
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_proposals = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):
        # iterate over levels
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        scores = scores.squeeze()
        proposals = distance2bbox(points, bbox_pred, max_shape=img_shape)
        proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
        proposals, _ = nms(proposals, cfg.nms_thr)
        proposals = proposals[:cfg.nms_post, :]
        mlvl_proposals.append(proposals)
    proposals = torch.cat(mlvl_proposals, 0)
    if cfg.nms_across_levels:
        proposals, _ = nms(proposals, cfg.nms_thr)
        proposals = proposals[:cfg.max_num, :]
    else:
        scores = proposals[:, 4]
        num = min(cfg.max_num, proposals.shape[0])
        _, topk_inds = scores.topk(num)
        proposals = proposals[topk_inds, :]
    return proposals
def _get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       centernesses,
                       mlvl_points,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False):
    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)
    det_bboxes, det_labels = multiclass_nms(
        mlvl_bboxes,
        mlvl_scores,
        cfg.score_thr,
        cfg.nms,
        cfg.max_per_img,
        score_factors=mlvl_centerness)
    return det_bboxes, det_labels
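# Note the two background-padding conventions in this file: heads written for
# mmdet v2.0+ append the dummy background column last (as above), while the
# older heads prepend it. A quick check of the two layouts:
import torch

scores = torch.rand(5, 80)                        # 80 foreground classes
padding = scores.new_zeros(scores.shape[0], 1)    # dummy background column
v1_scores = torch.cat([padding, scores], dim=1)   # BG first (pre-v2.0 heads)
v2_scores = torch.cat([scores, padding], dim=1)   # BG last (v2.0+ heads)
assert v1_scores.shape == v2_scores.shape == (5, 81)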
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      mlvl_anchors,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
    mlvl_bboxes = []
    mlvl_scores = []
    for stride, cls_score, bbox_pred, anchors in zip(
            self.anchor_strides, cls_scores, bbox_preds, mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0)
        bbox_pred = self.distribution_project(bbox_pred) * stride
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = scores.max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            anchors = anchors[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
        bboxes = distance2bbox(
            self.anchor_center(anchors), bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                            cfg.score_thr, cfg.nms,
                                            cfg.max_per_img)
    return det_bboxes, det_labels
def get_bbox_prob_and_overlap(self, points, bbox_preds, gt_bboxes):
    bbox_targets = bbox2distance(
        points,
        gt_bboxes[:, None, :].repeat(1, points.shape[1], 1),
        norm=self.distance_norm)
    bbox_prob = self.loss_bbox(
        bbox_preds, bbox_targets, reduction_override='none').neg().exp()
    pred_boxes = distance2bbox(points, bbox_preds, norm=self.distance_norm)
    bbox_overlap = bbox_overlaps(
        gt_bboxes[:, None, :].expand_as(pred_boxes),
        pred_boxes,
        is_aligned=True)
    return bbox_prob, bbox_overlap
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      mlvl_points,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    for cls_score, bbox_pred, points in zip(cls_scores, bbox_preds,
                                            mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        cls_score = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels)
        scores = cls_score.sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if 0 < nms_pre < scores.shape[0]:
            max_scores, _ = scores.max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
        bboxes = distance2bbox(
            points, bbox_pred, norm=self.distance_norm, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                            cfg.score_thr, cfg.nms,
                                            cfg.max_per_img)
    return det_bboxes, det_labels
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         reid_feats,
         gt_bboxes,
         gt_labels,
         gt_ids,
         img_metas,
         gt_bboxes_ignore=None):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            each is a 4D-tensor, the channel number is
            num_points * num_classes.
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * 4.
        centernesses (list[Tensor]): Centerness for each scale level, each
            is a 4D-tensor, the channel number is num_points * 1.
        reid_feats (list[Tensor]): Re-ID feature maps for each scale level,
            each is a 4D-tensor with self.feat_channels channels.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box
        gt_ids (list[Tensor]): person identity indices corresponding to
            each box.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(reid_feats)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, ids, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                                 gt_labels, gt_ids)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds, centerness and reid features
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_reid = [
        reid_feat.permute(0, 2, 3, 1).reshape(-1, self.feat_channels)
        for reid_feat in reid_feats
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_reid = torch.cat(flatten_reid)
    flatten_labels = torch.cat(labels)
    flatten_ids = torch.cat(ids)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((flatten_labels >= 0) &
                (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos being 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]
    pos_reid = flatten_reid[pos_inds]
    pos_reid = F.normalize(pos_reid)

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)

        # reid OIM loss: match positives against the labeled and unlabeled
        # identity lookup tables, then apply cross-entropy over identities
        pos_reid_ids = flatten_ids[pos_inds]
        labeled_matching_scores = self.labeled_matching_layer(
            pos_reid, pos_reid_ids)
        labeled_matching_scores *= 10
        unlabeled_matching_scores = self.unlabeled_matching_layer(
            pos_reid, pos_reid_ids)
        unlabeled_matching_scores *= 10
        matching_scores = torch.cat(
            (labeled_matching_scores, unlabeled_matching_scores), dim=1)
        pid_labels = pos_reid_ids.clone()
        pid_labels[pid_labels == -2] = -1
        loss_oim = F.cross_entropy(
            matching_scores, pid_labels, ignore_index=-1)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()
        loss_oim = pos_reid.sum()
        print('no gt box')

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_centerness=loss_centerness,
        loss_oim=loss_oim)
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         gt_bboxes_ignore=None,
         batch_idx=0,
         analysis_scale=1.0):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            each is a 4D-tensor, the channel number is
            num_points * num_classes.
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * 4.
        centernesses (list[Tensor]): Centerness for each scale level, each
            is a 4D-tensor, the channel number is num_points * 1.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((flatten_labels >= 0) &
                (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos being 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]

    # dump per-level positive-location maps for analysis
    # (the slicing below assumes batch size 1)
    pos_anchor_flags = torch.zeros_like(flatten_centerness)
    pos_anchor_flags[pos_inds] = 1.0
    pre_idx = 0
    for i, featmap_size in enumerate(featmap_sizes):
        cur_featmap_size = featmap_size[0] * featmap_size[1]
        cur_pos_anchor_flags = pos_anchor_flags[pre_idx:pre_idx +
                                                cur_featmap_size]
        cur_pos_anchor_flags = cur_pos_anchor_flags.view(
            1, 1, featmap_size[0], featmap_size[1])
        save_image(
            cur_pos_anchor_flags,
            f"analysis_results_fcos/image_{batch_idx}_feature_{i}"
            f"_flatten_anchor_flags_scale_{analysis_scale}.png")
        pre_idx = pre_idx + cur_featmap_size

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_centerness=loss_centerness)
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = sigmoid_focal_loss(
        flatten_cls_scores, flatten_labels, cfg.gamma, cfg.alpha,
        'none').sum()[None] / (num_pos + num_imgs)  # avoid num_pos being 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_bbox_targets = flatten_bbox_targets[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]

    if num_pos > 0:
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_reg = ((iou_loss(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            reduction='none') * pos_centerness_targets).sum() /
                    pos_centerness_targets.sum())[None]
        loss_centerness = F.binary_cross_entropy_with_logits(
            pos_centerness, pos_centerness_targets, reduction='mean')[None]
    else:
        loss_reg = pos_bbox_preds.sum()[None]
        loss_centerness = pos_centerness.sum()[None]

    return dict(
        loss_cls=loss_cls,
        loss_reg=loss_reg,
        loss_centerness=loss_centerness)
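# `iou_loss` above is consumed in its unreduced form and then weighted by
# centerness. A minimal sketch, assuming the standard -log(IoU) loss used by
# FCOS (the repo's own `iou_loss` may differ in details such as `eps`):
import torch

def iou_loss(pred, target, reduction='none', eps=1e-6):
    # aligned IoU between (x1, y1, x2, y2) boxes of identical shape
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]
    area_pred = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_target = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    ious = overlap / (area_pred + area_target - overlap).clamp(min=eps)
    loss = -ious.clamp(min=eps).log()
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss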
def _get_bboxes(self,
                cls_scores,
                bbox_preds,
                mlvl_anchors,
                img_shapes,
                scale_factors,
                cfg,
                rescale=False,
                with_nms=True):
    """Transform outputs for a single batch item into labeled boxes.

    Args:
        cls_scores (list[Tensor]): Box scores for a single scale level
            has shape (N, num_classes, H, W).
        bbox_preds (list[Tensor]): Box distribution logits for a single
            scale level with shape (N, 4*(n+1), H, W), n is max value of
            integral set.
        mlvl_anchors (list[Tensor]): Box reference for a single scale level
            with shape (num_total_anchors, 4).
        img_shapes (list[tuple[int]]): Shape of the input image,
            list[(height, width, 3)].
        scale_factors (list[ndarray]): Scale factor of the image arranged
            as (w_scale, h_scale, w_scale, h_scale).
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
            The first item is an (n, 5) tensor, where 5 represent
            (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
            The shape of the second tensor in the tuple is (n,), and
            each element represents the class label of the corresponding
            box.
    """
    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
    batch_size = cls_scores[0].shape[0]

    mlvl_bboxes = []
    mlvl_scores = []
    for cls_score, bbox_pred, stride, anchors in zip(
            cls_scores, bbox_preds, self.anchor_generator.strides,
            mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        assert stride[0] == stride[1]
        scores = cls_score.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.cls_out_channels).sigmoid()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1)
        bbox_pred = self.integral(bbox_pred) * stride[0]
        bbox_pred = bbox_pred.reshape(batch_size, -1, 4)

        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[1] > nms_pre:
            max_scores, _ = scores.max(-1)
            _, topk_inds = max_scores.topk(nms_pre)
            batch_inds = torch.arange(batch_size).view(
                -1, 1).expand_as(topk_inds).long()
            anchors = anchors[topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
        else:
            anchors = anchors.expand_as(bbox_pred)

        bboxes = distance2bbox(
            self.anchor_center(anchors), bbox_pred, max_shape=img_shapes)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)

    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    if rescale:
        batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
            scale_factors).unsqueeze(1)

    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    # Add a dummy background class to the backend when using sigmoid
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    padding = batch_mlvl_scores.new_zeros(batch_size,
                                          batch_mlvl_scores.shape[1], 1)
    batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

    if with_nms:
        det_results = []
        for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
                                              batch_mlvl_scores):
            det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                 cfg.score_thr, cfg.nms,
                                                 cfg.max_per_img)
            det_results.append(tuple([det_bbox, det_label]))
    else:
        det_results = [
            tuple(mlvl_bs)
            for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
        ]
    return det_results
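# `self.anchor_center` converts (x1, y1, x2, y2) anchors to their (cx, cy)
# centers before distance decoding. A minimal sketch, assuming the standard
# GFL helper:
import torch

def anchor_center(anchors):
    cx = (anchors[..., 0] + anchors[..., 2]) / 2
    cy = (anchors[..., 1] + anchors[..., 3]) / 2
    return torch.stack([cx, cy], dim=-1)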
def get_bboxes_single(self,
                      cls_score,
                      bbox_pred,
                      bbox_iou,
                      points,
                      img_shape,
                      cfg,
                      rescale=False,
                      scale_factor=None):
    assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
    scores = cls_score.permute(1, 2, 0).reshape(
        -1, self.cls_out_channels).sigmoid()
    bbox_iou = bbox_iou.permute(1, 2, 0).reshape(-1).sigmoid()
    bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
    points = points[0]

    nms_pre = cfg.get('nms_pre', -1)
    if nms_pre > 0 and scores.shape[0] > nms_pre:
        max_scores, _ = (scores * bbox_iou[:, None]).max(dim=1)
        _, topk_inds = max_scores.topk(nms_pre)
        points = points[topk_inds, :]
        bbox_pred = bbox_pred[topk_inds, :]
        scores = scores[topk_inds, :]
        bbox_iou = bbox_iou[topk_inds]
    bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)

    # pick a single best box: find the strongest class overall, then take
    # the candidate with the highest IoU-weighted score for that class
    scores_cpu = np.array(scores.cpu().detach())
    max_cls = np.unravel_index(scores_cpu.argmax(), scores_cpu.shape)[1]
    scores_iou = scores[:, max_cls] * bbox_iou
    bbox_ind = scores_iou.argmax()
    det_bboxes = bboxes[bbox_ind]
    det_labels = scores[bbox_ind] * bbox_iou[bbox_ind]
    return det_bboxes, det_labels
def forward(self, feats):
    """Forward features from the upstream network.

    Args:
        feats (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.

    Returns:
        tuple: Usually a tuple of classification scores and bbox prediction
            cls_scores (list[Tensor]): Classification scores for all scale
                levels, each is a 4D-tensor, the channels number is
                num_anchors * num_classes.
            bbox_preds (list[Tensor]): Decoded box for all scale levels,
                each is a 4D-tensor, the channels number is
                num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format.
    """
    cls_scores = []
    bbox_preds = []
    for idx, (x, scale, stride) in enumerate(
            zip(feats, self.scales, self.prior_generator.strides)):
        b, c, h, w = x.shape
        anchor = self.prior_generator.single_level_grid_priors(
            (h, w), idx, device=x.device)
        anchor = torch.cat([anchor for _ in range(b)])

        # extract task-interactive features
        inter_feats = []
        for inter_conv in self.inter_convs:
            x = inter_conv(x)
            inter_feats.append(x)
        feat = torch.cat(inter_feats, 1)

        # task decomposition
        avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        cls_feat = self.cls_decomp(feat, avg_feat)
        reg_feat = self.reg_decomp(feat, avg_feat)

        # cls prediction and alignment
        cls_logits = self.tood_cls(cls_feat)
        cls_prob = self.cls_prob_module(feat)
        cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)

        # reg prediction and alignment
        if self.anchor_type == 'anchor_free':
            reg_dist = scale(self.tood_reg(reg_feat).exp()).float()
            reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
            reg_bbox = distance2bbox(
                self.anchor_center(anchor) / stride[0],
                reg_dist).reshape(b, h, w, 4).permute(0, 3, 1, 2)  # (b, c, h, w)
        elif self.anchor_type == 'anchor_based':
            reg_dist = scale(self.tood_reg(reg_feat)).float()
            reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
            reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape(
                b, h, w, 4).permute(0, 3, 1, 2) / stride[0]
        else:
            raise NotImplementedError(
                f'Unknown anchor type: {self.anchor_type}. '
                f'Please use `anchor_free` or `anchor_based`.')
        reg_offset = self.reg_offset_module(feat)
        bbox_pred = self.deform_sampling(reg_bbox.contiguous(),
                                         reg_offset.contiguous())
        cls_scores.append(cls_score)
        bbox_preds.append(bbox_pred)
    return tuple(cls_scores), tuple(bbox_preds)
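# `sigmoid_geometric_mean` fuses the classification logits with the spatial
# probability map in the head above. A minimal sketch of the math, assuming
# the plain formulation (mmdetection wraps this in a custom autograd function
# for a numerically stabler backward pass):
import torch

def sigmoid_geometric_mean(x, y):
    # geometric mean of the two sigmoid activations
    return (x.sigmoid() * y.sigmoid()).sqrt()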
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    if cfg.stat_2d:
        # rescale the raw (l, t, r, b) predictions by per-location 2D statistics
        stat = cv2.imread('kitti_tools/stat/stat_2d.png').astype(np.float32)
        stat = cv2.resize(stat, (106, 32)).astype(np.float32)
        w_stat = torch.from_numpy(stat[:, :, 1]).float().cuda().unsqueeze(0)
        h_stat = torch.from_numpy(stat[:, :, 2]).float().cuda().unsqueeze(0)
        stat_all = torch.cat([w_stat, h_stat, w_stat, h_stat], dim=0)
        bbox_preds = [bbox_pred * stat_all for bbox_pred in bbox_preds]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    # check for NaN and Inf
    assert torch.isfinite(flatten_cls_scores).all().item(), \
        'classification scores become infinite or NaN!'
    assert torch.isfinite(flatten_bbox_preds).all().item(), \
        'bbox predictions become infinite or NaN!'
    assert torch.isfinite(flatten_centerness).all().item(), \
        'bbox centerness become infinite or NaN!'
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos being 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_centerness=loss_centerness)
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         cof_preds,
         feat_masks,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None,
         gt_masks_list=None):
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points, all_level_strides = self.get_points(
        featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
    labels, bbox_targets, label_list, bbox_targets_list, gt_inds = \
        self.fcos_target(all_level_points, gt_bboxes, gt_labels)

    # decode detections and ground truth
    det_bboxes = []
    det_targets = []
    num_levels = len(bbox_preds)
    for img_id in range(len(img_metas)):
        bbox_pred_list = [
            bbox_preds[i][img_id].permute(1, 2, 0).reshape(-1, 4).detach()
            for i in range(num_levels)
        ]
        bbox_target_list = bbox_targets_list[img_id]
        bboxes = []
        targets = []
        for i in range(len(bbox_pred_list)):
            bbox_pred = bbox_pred_list[i]
            bbox_target = bbox_target_list[i]
            points = all_level_points[i]
            bboxes.append(distance2bbox(points, bbox_pred))
            targets.append(distance2bbox(points, bbox_target))
        bboxes = torch.cat(bboxes, dim=0)
        targets = torch.cat(targets, dim=0)
        det_bboxes.append(bboxes)
        det_targets.append(targets)

    gt_masks = []
    for i in range(len(gt_labels)):
        gt_label = gt_labels[i]
        gt_masks.append(
            torch.from_numpy(
                np.array(gt_masks_list[i][:gt_label.shape[0]],
                         dtype=np.float32)).to(gt_label.device))

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])
    flatten_strides = torch.cat([
        strides.view(-1, 1).repeat(num_imgs, 1)
        for strides in all_level_strides
    ])

    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos being 0
    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_strides = flatten_strides[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points,
                                               pos_bbox_preds / pos_strides)
        pos_decoded_target_preds = distance2bbox(
            pos_points, pos_bbox_targets / pos_strides)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    # ---------------- mask loss ----------------
    flatten_cls_scores1 = [
        cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                              self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_cls_scores1 = torch.cat(flatten_cls_scores1, dim=1)
    flatten_cof_preds = [
        cof_pred.permute(0, 2, 3, 1).reshape(cof_pred.shape[0], -1, 32 * 4)
        for cof_pred in cof_preds
    ]
    loss_mask = 0
    loss_iou = 0
    num_iou = 0.1
    flatten_cof_preds = torch.cat(flatten_cof_preds, dim=1)
    for i in range(num_imgs):
        labels = torch.cat(
            [labels_level.flatten() for labels_level in label_list[i]])
        bbox_dt = det_bboxes[i] / 2
        bbox_dt = bbox_dt.detach()
        pos_inds = (labels > 0).nonzero().view(-1)
        cof_pred = flatten_cof_preds[i][pos_inds]
        img_mask = feat_masks[i]
        mask_h = img_mask.shape[1]
        mask_w = img_mask.shape[2]
        idx_gt = gt_inds[i]
        bbox_dt = bbox_dt[pos_inds, :4]

        # drop degenerate detections
        area = (bbox_dt[:, 2] - bbox_dt[:, 0]) * (bbox_dt[:, 3] - bbox_dt[:, 1])
        bbox_dt = bbox_dt[area > 1.0, :]
        idx_gt = idx_gt[area > 1.0]
        cof_pred = cof_pred[area > 1.0]
        if bbox_dt.shape[0] == 0:
            loss_mask += area.sum() * 0
            continue

        bbox_gt = gt_bboxes[i]
        cls_score = flatten_cls_scores1[i, pos_inds,
                                        labels[pos_inds] - 1].sigmoid().detach()
        cls_score = cls_score[area > 1.0]
        pos_inds = pos_inds[area > 1.0]
        ious = bbox_overlaps(bbox_gt[idx_gt] / 2, bbox_dt, is_aligned=True)
        with torch.no_grad():
            weighting = cls_score * ious
            weighting = weighting / (torch.sum(weighting) + 0.0001) * len(weighting)

        gt_mask = F.interpolate(
            gt_masks[i].unsqueeze(0),
            scale_factor=0.5,
            mode='bilinear',
            align_corners=False).squeeze(0)
        shape = np.minimum(feat_masks[i].shape, gt_mask.shape)
        gt_mask_new = gt_mask.new_zeros(gt_mask.shape[0], mask_h, mask_w)
        gt_mask_new[:gt_mask.shape[0], :shape[1], :shape[2]] = \
            gt_mask[:gt_mask.shape[0], :shape[1], :shape[2]]
        gt_mask_new = gt_mask_new.gt(0.5).float()
        gt_mask_new = torch.index_select(gt_mask_new, 0,
                                         idx_gt).permute(1, 2, 0).contiguous()

        # ---- spp: assemble masks from the four coefficient groups ----
        img_mask1 = img_mask.permute(1, 2, 0)
        pos_masks00 = torch.sigmoid(img_mask1 @ cof_pred[:, 0:32].t())
        pos_masks01 = torch.sigmoid(img_mask1 @ cof_pred[:, 32:64].t())
        pos_masks10 = torch.sigmoid(img_mask1 @ cof_pred[:, 64:96].t())
        pos_masks11 = torch.sigmoid(img_mask1 @ cof_pred[:, 96:128].t())
        pred_masks = torch.stack(
            [pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
        pred_masks = self.crop_cuda(pred_masks, bbox_dt)
        gt_mask_crop = self.crop_gt_cuda(gt_mask_new, bbox_dt)
        pre_loss = F.binary_cross_entropy(
            pred_masks, gt_mask_crop, reduction='none')

        # normalize the per-instance loss by the box size
        pos_get_csize = center_size(bbox_dt)
        gt_box_width = pos_get_csize[:, 2]
        gt_box_height = pos_get_csize[:, 3]
        pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height \
            / pos_get_csize.shape[0]
        loss_mask += torch.sum(pre_loss * weighting.detach())

        if self.rescoring_flag:
            pos_labels = labels[pos_inds] - 1
            input_iou = pred_masks.detach().unsqueeze(0).permute(3, 0, 1, 2)
            pred_iou = self.convs_scoring(input_iou)
            pred_iou = self.relu(self.mask_scoring(pred_iou))
            pred_iou = F.max_pool2d(
                pred_iou, kernel_size=pred_iou.size()[2:]).squeeze(-1).squeeze(-1)
            pred_iou = pred_iou[range(pred_iou.size(0)), pos_labels]
            with torch.no_grad():
                mask_pred = (pred_masks > 0.4).float()
                mask_pred_areas = mask_pred.sum((0, 1))
                overlap_areas = (mask_pred * gt_mask_new).sum((0, 1))
                gt_full_areas = gt_mask_new.sum((0, 1))
                iou_targets = overlap_areas / (
                    mask_pred_areas + gt_full_areas - overlap_areas + 0.1)
                iou_weights = ((iou_targets > 0.1) & (iou_targets <= 1.0) &
                               (gt_full_areas >= 10 * 10)).float()
            loss_iou += self.loss_iou(
                pred_iou.view(-1, 1), iou_targets.view(-1, 1),
                iou_weights.view(-1, 1))
            num_iou += torch.sum(iou_weights.detach())

    loss_mask = loss_mask / num_imgs
    if self.rescoring_flag:
        loss_iou = loss_iou * 10 / num_iou.detach()
        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness,
            loss_mask=loss_mask,
            loss_iou=loss_iou)
    else:
        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness,
            loss_mask=loss_mask)
def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, stride, soft_targets, num_total_samples):
    """Compute loss of a single scale level.

    Args:
        anchors (Tensor): Box reference for each scale level with shape
            (N, num_total_anchors, 4).
        cls_score (Tensor): Cls and quality joint scores for each scale
            level has shape (N, num_classes, H, W).
        bbox_pred (Tensor): Box distribution logits for each scale
            level with shape (N, 4*(n+1), H, W), n is max value of
            integral set.
        labels (Tensor): Labels of each anchors with shape
            (N, num_total_anchors).
        label_weights (Tensor): Label weights of each anchor with shape
            (N, num_total_anchors)
        bbox_targets (Tensor): BBox regression targets of each anchor
            with shape (N, num_total_anchors, 4).
        stride (tuple): Stride in this scale level.
        soft_targets (Tensor): Teacher box distribution logits used for
            distillation, same shape as bbox_pred.
        num_total_samples (int): Number of positive samples that is
            reduced over all GPUs.

    Returns:
        tuple[Tensor]: Loss components and weight targets.
    """
    assert stride[0] == stride[1], 'h stride is not equal to w stride!'
    anchors = anchors.reshape(-1, 4)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * (self.reg_max + 1))
    soft_targets = soft_targets.permute(0, 2, 3, 1).reshape(
        -1, 4 * (self.reg_max + 1))
    bbox_targets = bbox_targets.reshape(-1, 4)
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
    score = label_weights.new_zeros(labels.shape)

    if len(pos_inds) > 0:
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]
        pos_anchors = anchors[pos_inds]
        pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

        weight_targets = cls_score.detach().sigmoid()
        weight_targets = weight_targets.max(dim=1)[0][pos_inds]
        pos_bbox_pred_corners = self.integral(pos_bbox_pred)
        pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
                                             pos_bbox_pred_corners)
        pos_decode_bbox_targets = pos_bbox_targets / stride[0]
        score[pos_inds] = bbox_overlaps(
            pos_decode_bbox_pred.detach(),
            pos_decode_bbox_targets,
            is_aligned=True)
        pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
        pos_soft_targets = soft_targets[pos_inds]
        soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
        target_corners = bbox2distance(pos_anchor_centers,
                                       pos_decode_bbox_targets,
                                       self.reg_max).reshape(-1)

        # regression loss
        loss_bbox = self.loss_bbox(
            pos_decode_bbox_pred,
            pos_decode_bbox_targets,
            weight=weight_targets,
            avg_factor=1.0)

        # dfl loss
        loss_dfl = self.loss_dfl(
            pred_corners,
            target_corners,
            weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
            avg_factor=4.0)

        # ld loss
        loss_ld = self.loss_ld(
            pred_corners,
            soft_corners,
            weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
            avg_factor=4.0)
    else:
        loss_ld = bbox_pred.sum() * 0
        loss_bbox = bbox_pred.sum() * 0
        loss_dfl = bbox_pred.sum() * 0
        weight_targets = bbox_pred.new_tensor(0)

    # cls (qfl) loss
    loss_cls = self.loss_cls(
        cls_score, (labels, score),
        weight=label_weights,
        avg_factor=num_total_samples)

    return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
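# `bbox2distance` is the inverse of `distance2bbox`: it converts target corner
# boxes into the (left, top, right, bottom) targets that the DFL/LD losses
# consume. A minimal sketch, assuming the standard mmdetection helper:
import torch

def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        # keep the targets strictly inside [0, reg_max] for the DFL bins
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)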
def _get_bboxes(self,
                cls_scores,
                bbox_preds,
                centernesses,
                mlvl_points,
                img_shapes,
                scale_factors,
                cfg,
                rescale=False,
                with_nms=True):
    """Transform outputs for a single batch item into bbox predictions.

    Args:
        cls_scores (list[Tensor]): Box scores for a single scale level
            with shape (N, num_points * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for a single scale
            level with shape (N, num_points * 4, H, W).
        centernesses (list[Tensor]): Centerness for a single scale level
            with shape (N, num_points, H, W).
        mlvl_points (list[Tensor]): Box reference for a single scale level
            with shape (num_total_points, 4).
        img_shapes (list[tuple[int]]): Shape of the input image,
            list[(height, width, 3)].
        scale_factors (list[ndarray]): Scale factor of the image arranged
            as (w_scale, h_scale, w_scale, h_scale).
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        tuple(Tensor):
            det_bboxes (Tensor): BBox predictions in shape (n, 5), where
                the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
            det_labels (Tensor): A (n,) tensor where each item is the
                predicted class label of the corresponding box.
    """
    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    device = cls_scores[0].device
    batch_size = cls_scores[0].shape[0]
    # convert to tensor to keep tracing
    nms_pre_tensor = torch.tensor(
        cfg.get('nms_pre', -1), device=device, dtype=torch.long)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(0, 2, 3, 1).reshape(
            batch_size, -1).sigmoid()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)

        # Always keep topk op for dynamic input in onnx
        if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
                                   or scores.shape[-2] > nms_pre_tensor):
            from torch import _shape_as_tensor
            # keep shape as tensor and get k
            num_anchor = _shape_as_tensor(scores)[-2].to(device)
            nms_pre = torch.where(nms_pre_tensor < num_anchor,
                                  nms_pre_tensor, num_anchor)
            max_scores, _ = (scores * centerness[..., None]).max(-1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            batch_inds = torch.arange(batch_size).view(
                -1, 1).expand_as(topk_inds).long()
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
            centerness = centerness[batch_inds, topk_inds]

        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shapes)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)

    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    if rescale:
        batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
            scale_factors).unsqueeze(1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)

    # Set max number of boxes to be fed into nms in deployment
    deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
    if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
        batch_mlvl_scores, _ = (
            batch_mlvl_scores *
            batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores)
        ).max(-1)
        _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre)
        batch_inds = torch.arange(batch_mlvl_scores.shape[0]).view(
            -1, 1).expand_as(topk_inds)
        batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
        batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
        batch_mlvl_centerness = batch_mlvl_centerness[batch_inds, topk_inds]
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    padding = batch_mlvl_scores.new_zeros(batch_size,
                                          batch_mlvl_scores.shape[1], 1)
    batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

    if with_nms:
        det_results = []
        for (mlvl_bboxes, mlvl_scores,
             mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                                     batch_mlvl_centerness):
            det_bbox, det_label = multiclass_nms(
                mlvl_bboxes,
                mlvl_scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness)
            det_results.append(tuple([det_bbox, det_label]))
    else:
        det_results = [
            tuple(mlvl_bs)
            for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                               batch_mlvl_centerness)
        ]
    return det_results
def _get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       centernesses,
                       mlvl_points,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       with_nms=True):
    """Transform outputs for a single batch item into bbox predictions.

    Args:
        cls_scores (list[Tensor]): Box scores for a single scale level
            with shape (num_points * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for a single scale
            level with shape (num_points * 4, H, W).
        centernesses (list[Tensor]): Centerness for a single scale level
            with shape (num_points * 1, H, W).
        mlvl_points (list[Tensor]): Box reference for a single scale level
            with shape (num_total_points, 4).
        img_shape (tuple[int]): Shape of the input image,
            (height, width, 3).
        scale_factor (ndarray): Scale factor of the image arranged as
            (w_scale, h_scale, w_scale, h_scale).
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        tuple(Tensor):
            det_bboxes (Tensor): BBox predictions in shape (n, 5), where
                the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
            det_labels (Tensor): A (n,) tensor where each item is the
                predicted class label of the corresponding box.
    """
    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_scores, bbox_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)
    if with_nms:
        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
        return det_bboxes, det_labels
    else:
        return mlvl_bboxes, mlvl_scores, mlvl_centerness
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         gt_bboxes_ignore=None):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            each is a 4D-tensor, the channel number is
            num_points * num_classes.
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * 4.
        centernesses (list[Tensor]): centerness for each scale level, each
            is a 4D-tensor, the channel number is num_points * 1.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((flatten_labels >= 0) &
                (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
    num_pos = torch.tensor(
        len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
    num_pos = max(reduce_mean(num_pos), 1.0)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels, avg_factor=num_pos)

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]
    pos_bbox_targets = flatten_bbox_targets[pos_inds]
    pos_centerness_targets = self.centerness_target(pos_bbox_targets)
    # centerness weighted iou loss
    centerness_denorm = max(
        reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

    if len(pos_inds) > 0:
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=centerness_denorm)
        loss_centerness = self.loss_centerness(
            pos_centerness, pos_centerness_targets, avg_factor=num_pos)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_centerness=loss_centerness)
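# `self.centerness_target` above derives the centerness supervision from the
# positive (l, t, r, b) regression targets. A minimal sketch, assuming the
# standard FCOS definition sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b)):
import torch

def centerness_target(pos_bbox_targets):
    left_right = pos_bbox_targets[:, [0, 2]]
    top_bottom = pos_bbox_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                 (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)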
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         mask_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_masks=None,
         gt_bboxes_ignore=None,
         gt_centers=None,
         gt_max_centerness=None):
    assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
        mask_preds)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    self.num_points_per_level = [i.size()[0] for i in all_level_points]
    labels, bbox_targets, mask_targets, centerness_targets = \
        self.polar_target(all_level_points, gt_labels, gt_bboxes, gt_masks,
                          gt_centers, gt_max_centerness)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    if self.use_fourier:
        if self.loss_on_coe:
            flatten_mask_preds = [
                mask_pred.permute(0, 2, 3, 1).reshape(-1, self.num_coe, 2)
                for mask_pred in mask_preds
            ]
        else:
            flatten_mask_preds = []
            flatten_bbox_preds = []
            for mask_pred, points in zip(mask_preds, all_level_points):
                mask_pred = mask_pred.permute(0, 2, 3, 1).reshape(
                    -1, self.num_coe, 2)
                if self.bbox_from_mask:
                    xy, m = self.distance2mask(
                        points.repeat(num_imgs, 1), mask_pred, train=True)
                    b = torch.stack([
                        xy[:, 0].min(1)[0], xy[:, 1].min(1)[0],
                        xy[:, 0].max(1)[0], xy[:, 1].max(1)[0]
                    ], -1)
                    flatten_bbox_preds.append(b)
                    flatten_mask_preds.append(m)
                else:
                    # pad the truncated spectrum with zeros and invert it
                    # back to per-ray distances (pre-1.8 torch.irfft API);
                    # exp() maps the log-distances back to distances
                    m = torch.irfft(
                        torch.cat([
                            mask_pred,
                            torch.zeros(mask_pred.shape[0],
                                        self.contour_points - self.num_coe,
                                        2).to("cuda")
                        ], 1), 1, True, False).float().exp()
                    flatten_mask_preds.append(m)
    else:
        flatten_mask_preds = [
            mask_pred.permute(0, 2, 3, 1).reshape(-1, self.contour_points)
            for mask_pred in mask_preds
        ]
    if not self.bbox_from_mask:
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)  # [num_pixel, 80]
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)  # [num_pixel, 4]
    flatten_mask_preds = torch.cat(flatten_mask_preds)  # [num_pixel, n]
    flatten_centerness = torch.cat(flatten_centerness)  # [num_pixel]
    flatten_labels = torch.cat(labels).long()  # [num_pixel]
    flatten_centerness_targets = torch.cat(centerness_targets)
    flatten_bbox_targets = torch.cat(bbox_targets)  # [num_pixel, 4]
    flatten_mask_targets = torch.cat(mask_targets)  # [num_pixel, n]
    flatten_points = torch.cat([
        points.repeat(num_imgs, 1) for points in all_level_points
    ])  # [num_pixel, 2]
    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]
    pos_mask_preds = flatten_mask_preds[pos_inds]

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_mask_targets = flatten_mask_targets[pos_inds]
        pos_centerness_targets = flatten_centerness_targets[pos_inds]
        pos_points = flatten_points[pos_inds]

        if self.bbox_from_mask:
            pos_decoded_bbox_preds = pos_bbox_preds
        else:
            pos_decoded_bbox_preds = distance2bbox(pos_points,
                                                   pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)

        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())

        if self.loss_on_coe:
            # supervise the truncated Fourier coefficients directly
            pos_mask_targets = torch.rfft(pos_mask_targets, 1, True, False)
            pos_mask_targets = pos_mask_targets[..., :self.num_coe, :]
            loss_mask = self.loss_mask(pos_mask_preds, pos_mask_targets)
        else:
            loss_mask = self.loss_mask(
                pos_mask_preds,
                pos_mask_targets,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_mask = pos_mask_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_mask=loss_mask,
        loss_centerness=loss_centerness)
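# The Fourier branch above keeps only the first `num_coe` spectrum entries
# of the per-ray contour signal and inverts the zero-padded spectrum at
# decode time (the head itself uses the pre-1.8 torch.rfft/irfft API). A
# standalone sketch of that truncation/reconstruction with the modern
# torch.fft API; shapes here are illustrative assumptions:
import torch

contour_points, num_coe = 36, 8
distances = torch.rand(contour_points)           # per-ray distances
coe = torch.fft.rfft(distances)[:num_coe]        # truncated spectrum
padded = torch.zeros(contour_points // 2 + 1, dtype=torch.complex64)
padded[:num_coe] = coe
recon = torch.fft.irfft(padded, n=contour_points)  # smoothed contour signal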
def _get_bboxes_single(self,
                       cls_scores,
                       bbox_preds,
                       mlvl_anchors,
                       img_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       nms=True):
    """Transform outputs for a single batch item into labeled boxes.

    Args:
        cls_scores (list[Tensor]): Box scores for a single scale level
            with shape (num_classes, H, W).
        bbox_preds (list[Tensor]): Box distribution logits for a single
            scale level with shape (4*(n+1), H, W), where n is the max
            value of the integral set.
        mlvl_anchors (list[Tensor]): Box reference for a single scale
            level with shape (num_total_anchors, 4).
        img_shape (tuple[int]): Shape of the input image,
            (height, width, 3).
        scale_factor (ndarray): Scale factor of the image arranged as
            (w_scale, h_scale, w_scale, h_scale).
        cfg (mmcv.Config | None): Test / postprocessing configuration.
            If None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.

    Returns:
        tuple(Tensor):
            det_bboxes (Tensor): Bbox predictions in shape (N, 5), where
                the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
            det_labels (Tensor): A (N,) tensor where each item is the
                predicted class label of the corresponding box.
    """
    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
    mlvl_bboxes = []
    mlvl_scores = []
    for cls_score, bbox_pred, stride, anchors in zip(
            cls_scores, bbox_preds, self.anchor_generator.strides,
            mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        assert stride[0] == stride[1]

        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0)
        bbox_pred = self.integral(bbox_pred) * stride[0]

        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = scores.max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            anchors = anchors[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]

        bboxes = distance2bbox(
            self.anchor_center(anchors), bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)

    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)

    mlvl_scores = torch.cat(mlvl_scores)
    # Add a dummy background class to the backend when using sigmoid
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

    if nms:
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
    else:
        return mlvl_bboxes, mlvl_scores
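# For reference: a sketch of what `self.integral` above is assumed to be.
# In the Generalized Focal Loss design, each box side is predicted as a
# discrete distribution over {0, ..., reg_max} and reduced to its
# expectation. Illustrative, not necessarily this repo's exact module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Integral(nn.Module):
    """Integrate a discrete side distribution into a regression value."""

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        self.register_buffer(
            'project', torch.linspace(0, self.reg_max, self.reg_max + 1))

    def forward(self, x):
        # x: (N, 4 * (reg_max + 1)) logits -> (N, 4) expected distances
        x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
        return x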
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      mask_preds,
                      centernesses,
                      mlvl_points,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_masks = []
    mlvl_centerness = []
    for cls_score, bbox_pred, mask_pred, centerness, points in zip(
            cls_scores, bbox_preds, mask_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        if self.use_fourier:
            mask_pred = mask_pred.permute(1, 2, 0).reshape(
                -1, self.num_coe * 2)
        else:
            mask_pred = mask_pred.permute(1, 2, 0).reshape(
                -1, self.contour_points)
        nms_pre = cfg.get('nms_pre', -1)
        if 0 < nms_pre < scores.shape[0]:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            mask_pred = mask_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        if not self.bbox_from_mask:
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            # masks, _ = self.distance2mask(points, mask_pred, bbox=bboxes)
            masks, _ = self.distance2mask(
                points, mask_pred, max_shape=img_shape)
        else:
            masks, _ = self.distance2mask(
                points, mask_pred, max_shape=img_shape)
            bboxes = torch.stack([
                masks[:, 0].min(1)[0], masks[:, 1].min(1)[0],
                masks[:, 0].max(1)[0], masks[:, 1].max(1)[0]
            ], -1)

        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
        mlvl_masks.append(masks)

    mlvl_bboxes = torch.cat(mlvl_bboxes)
    mlvl_masks = torch.cat(mlvl_masks)
    if rescale:
        _mlvl_bboxes = mlvl_bboxes / mlvl_bboxes.new_tensor(scale_factor)
        try:
            # TODO: change cuda
            scale_factor = torch.tensor(scale_factor)[:2].cuda().unsqueeze(
                1).repeat(1, self.contour_points)
            _mlvl_masks = mlvl_masks / scale_factor
        except (RuntimeError, TypeError, NameError, IndexError):
            _mlvl_masks = mlvl_masks / mlvl_masks.new_tensor(scale_factor)
    else:
        # keep the variables defined when no rescaling is requested
        # (the original code only assigned them inside the branch above)
        _mlvl_bboxes = mlvl_bboxes
        _mlvl_masks = mlvl_masks
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)

    if self.mask_nms:
        # 1. mask -> min_bbox -> nms; performance same as the origin box
        _mlvl_bboxes = torch.stack([
            _mlvl_masks[:, 0].min(1)[0], _mlvl_masks[:, 1].min(1)[0],
            _mlvl_masks[:, 0].max(1)[0], _mlvl_masks[:, 1].max(1)[0]
        ], -1)
        det_bboxes, det_labels, det_masks = multiclass_nms_with_mask(
            _mlvl_bboxes,
            mlvl_scores,
            _mlvl_masks,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness + self.centerness_factor)
    else:
        # 2. origin bbox -> nms; performance same as mask -> min_bbox
        det_bboxes, det_labels, det_masks = multiclass_nms_with_mask(
            _mlvl_bboxes,
            mlvl_scores,
            _mlvl_masks,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness + self.centerness_factor)
    return det_bboxes, det_labels, det_masks
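# For reference: an illustrative sketch of the polar decoding behind
# `distance2mask` (the head's own version also handles the Fourier path
# and a `train` flag). Each location predicts one distance per fixed ray
# angle; the contour point on ray i is center + d_i * (cos t_i, sin t_i).
import math
import torch

def polar_distance2mask(points, distances, max_shape=None):
    """points: (N, 2) centers; distances: (N, R) per-ray distances."""
    n, r = distances.shape
    angles = torch.arange(
        r, dtype=distances.dtype, device=distances.device) * (2 * math.pi / r)
    x = points[:, 0, None] + distances * torch.cos(angles)
    y = points[:, 1, None] + distances * torch.sin(angles)
    if max_shape is not None:  # clamp the contour to the image
        x = x.clamp(min=0, max=max_shape[1] - 1)
        y = y.clamp(min=0, max=max_shape[0] - 1)
    # (N, 2, R), matching the masks[:, 0] / masks[:, 1] indexing above
    return torch.stack([x, y], dim=1)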
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    pos_inds = flatten_labels.nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_bbox_targets = flatten_bbox_targets[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]
    pos_centerness_targets = self.centerness_target(pos_bbox_targets)

    if num_pos > 0:
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_centerness=loss_centerness)
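# For reference: the (l, t, r, b) regression targets that `fcos_target`
# produces for a point (x, y) inside a GT box (x1, y1, x2, y2) are the
# distances to the four sides, i.e. the inverse of distance2bbox. Minimal
# sketch (level-range assignment and clamping omitted):
def bbox2distance(points, gt_bboxes):
    left = points[:, 0] - gt_bboxes[:, 0]
    top = points[:, 1] - gt_bboxes[:, 1]
    right = gt_bboxes[:, 2] - points[:, 0]
    bottom = gt_bboxes[:, 3] - points[:, 1]
    return torch.stack([left, top, right, bottom], dim=-1)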
def loss(self,
         cls_scores,
         bbox_preds,
         bbox_preds_refine,
         gt_bboxes,
         gt_labels,
         img_metas,
         gt_bboxes_ignore=None):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box IoU-aware scores for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * num_classes.
        bbox_preds (list[Tensor]): Box offsets for each scale level,
            each is a 4D-tensor, the channel number is num_points * 4.
        bbox_preds_refine (list[Tensor]): Refined box offsets for each
            scale level, each is a 4D-tensor, the channel number is
            num_points * 4.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image
            with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): Class indices corresponding to each
            box.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
            boxes can be ignored when computing the loss. Default: None.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, label_weights, bbox_targets, bbox_weights = self.get_targets(
        cls_scores, all_level_points, gt_bboxes, gt_labels, img_metas,
        gt_bboxes_ignore)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and bbox_preds_refine
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
        for bbox_pred in bbox_preds
    ]
    flatten_bbox_preds_refine = [
        bbox_pred_refine.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
        for bbox_pred_refine in bbox_preds_refine
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_bbox_preds_refine = torch.cat(flatten_bbox_preds_refine)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = torch.where(
        ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)) > 0)[0]
    num_pos = len(pos_inds)

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_bbox_preds_refine = flatten_bbox_preds_refine[pos_inds]
    pos_labels = flatten_labels[pos_inds]

    # sync num_pos across all gpus
    if self.sync_num_pos:
        num_pos_avg_per_gpu = reduce_mean(
            pos_inds.new_tensor(num_pos).float()).item()
        num_pos_avg_per_gpu = max(num_pos_avg_per_gpu, 1.0)
    else:
        num_pos_avg_per_gpu = num_pos

    pos_bbox_targets = flatten_bbox_targets[pos_inds]
    pos_points = flatten_points[pos_inds]

    pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
    pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
    iou_targets_ini = bbox_overlaps(
        pos_decoded_bbox_preds,
        pos_decoded_target_preds.detach(),
        is_aligned=True).clamp(min=1e-6)
    bbox_weights_ini = iou_targets_ini.clone().detach()
    iou_targets_ini_avg_per_gpu = reduce_mean(bbox_weights_ini.sum()).item()
    bbox_avg_factor_ini = max(iou_targets_ini_avg_per_gpu, 1.0)

    if num_pos > 0:
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds.detach(),
            weight=bbox_weights_ini,
            avg_factor=bbox_avg_factor_ini)

        pos_decoded_bbox_preds_refine = \
            distance2bbox(pos_points, pos_bbox_preds_refine)
        iou_targets_rf = bbox_overlaps(
            pos_decoded_bbox_preds_refine,
            pos_decoded_target_preds.detach(),
            is_aligned=True).clamp(min=1e-6)
        bbox_weights_rf = iou_targets_rf.clone().detach()
        iou_targets_rf_avg_per_gpu = reduce_mean(
            bbox_weights_rf.sum()).item()
        bbox_avg_factor_rf = max(iou_targets_rf_avg_per_gpu, 1.0)
        loss_bbox_refine = self.loss_bbox_refine(
            pos_decoded_bbox_preds_refine,
            pos_decoded_target_preds.detach(),
            weight=bbox_weights_rf,
            avg_factor=bbox_avg_factor_rf)

        # build IoU-aware cls_score targets
        if self.use_vfl:
            pos_ious = iou_targets_rf.clone().detach()
            cls_iou_targets = torch.zeros_like(flatten_cls_scores)
            cls_iou_targets[pos_inds, pos_labels] = pos_ious
    else:
        loss_bbox = pos_bbox_preds.sum() * 0
        loss_bbox_refine = pos_bbox_preds_refine.sum() * 0
        if self.use_vfl:
            cls_iou_targets = torch.zeros_like(flatten_cls_scores)

    if self.use_vfl:
        loss_cls = self.loss_cls(
            flatten_cls_scores,
            cls_iou_targets,
            avg_factor=num_pos_avg_per_gpu)
    else:
        loss_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels,
            weight=label_weights,
            avg_factor=num_pos_avg_per_gpu)

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_bbox_rf=loss_bbox_refine)
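# `reduce_mean` above averages a scalar tensor across GPUs so every rank
# normalizes its loss by the same factor. A minimal sketch in the usual
# mmdet style, shown here for reference:
import torch.distributed as dist

def reduce_mean(tensor):
    """Average `tensor` over all processes; no-op when not distributed."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor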
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      centernesses,
                      cof_preds,
                      feat_mask,
                      mlvl_points,
                      img_shape,
                      ori_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    mlvl_cofs = []
    for cls_score, bbox_pred, cof_pred, centerness, points in zip(
            cls_scores, bbox_preds, cof_preds, centernesses, mlvl_points):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        cof_pred = cof_pred.permute(1, 2, 0).reshape(-1, 32 * 4)
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * centerness[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            cof_pred = cof_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            centerness = centerness[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
        mlvl_cofs.append(cof_pred)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    mlvl_cofs = torch.cat(mlvl_cofs)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
    mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    mlvl_centerness = torch.cat(mlvl_centerness)

    if self.ssd_flag is False:
        det_bboxes, det_labels, idxs_keep = multiclass_nms_idx(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
    else:
        mlvl_scores = mlvl_scores * mlvl_centerness.view(-1, 1)
        det_bboxes, det_labels, det_cofs = self.fast_nms(
            mlvl_bboxes,
            mlvl_scores[:, 1:].transpose(1, 0).contiguous(),
            mlvl_cofs,
            iou_threshold=cfg.nms.iou_thr,
            score_thr=cfg.score_thr)

    cls_segms = [[] for _ in range(self.num_classes - 1)]
    mask_scores = [[] for _ in range(self.num_classes - 1)]
    if det_bboxes.shape[0] > 0:
        scale = 2
        if self.ssd_flag is False:
            det_cofs = mlvl_cofs[idxs_keep]
        # ####spp########################
        img_mask1 = feat_mask.permute(1, 2, 0)
        pos_masks00 = torch.sigmoid(img_mask1 @ det_cofs[:, 0:32].t())
        pos_masks01 = torch.sigmoid(img_mask1 @ det_cofs[:, 32:64].t())
        pos_masks10 = torch.sigmoid(img_mask1 @ det_cofs[:, 64:96].t())
        pos_masks11 = torch.sigmoid(img_mask1 @ det_cofs[:, 96:128].t())
        pos_masks = torch.stack(
            [pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
        pos_masks = self.crop_cuda(
            pos_masks,
            det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor) / scale)
        # pos_masks = crop_split(pos_masks00, pos_masks01, pos_masks10,
        #                        pos_masks11, det_bboxes *
        #                        det_bboxes.new_tensor(scale_factor) / scale)
        pos_masks = pos_masks.permute(2, 0, 1)
        if self.ssd_flag:
            masks = F.interpolate(
                pos_masks.unsqueeze(0),
                scale_factor=scale / scale_factor[3:1:-1],
                mode='bilinear',
                align_corners=False).squeeze(0)
        else:
            masks = F.interpolate(
                pos_masks.unsqueeze(0),
                scale_factor=scale / scale_factor,
                mode='bilinear',
                align_corners=False).squeeze(0)
        masks.gt_(0.4)

        if self.rescoring_flag:
            pred_iou = pos_masks.unsqueeze(1)
            pred_iou = self.convs_scoring(pred_iou)
            pred_iou = self.relu(self.mask_scoring(pred_iou))
            pred_iou = F.max_pool2d(
                pred_iou,
                kernel_size=pred_iou.size()[2:]).squeeze(-1).squeeze(-1)
            pred_iou = pred_iou[range(pred_iou.size(0)),
                                det_labels].squeeze()
            mask_scores = pred_iou * det_bboxes[:, -1]
            mask_scores = mask_scores.cpu().numpy()
            mask_scores = [
                mask_scores[det_labels.cpu().numpy() == i]
                for i in range(self.num_classes - 1)
            ]

        for i in range(det_bboxes.shape[0]):
            label = det_labels[i]
            mask = masks[i].cpu().numpy()
            im_mask = np.zeros((ori_shape[0], ori_shape[1]), dtype=np.uint8)
            shape = np.minimum(mask.shape, ori_shape[0:2])
            im_mask[:shape[0], :shape[1]] = mask[:shape[0], :shape[1]]
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label].append(rle)

    if self.rescoring_flag:
        return det_bboxes, det_labels, (cls_segms, mask_scores)
    else:
        return det_bboxes, det_labels, cls_segms
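# `crop_cuda` above is assumed to zero out mask activations that fall
# outside each detection box (YOLACT-style cropping); the real helper also
# handles the four spatial quadrants stacked in dim 0. A simplified
# single-map sketch of the operation:
def crop_masks(masks, boxes):
    """masks: (H, W, N) per-box mask maps; boxes: (N, 4) in xyxy."""
    h, w, n = masks.shape
    rows = torch.arange(w, device=masks.device).view(1, -1, 1).expand(h, w, n)
    cols = torch.arange(h, device=masks.device).view(-1, 1, 1).expand(h, w, n)
    keep = ((rows >= boxes[:, 0]) & (rows < boxes[:, 2]) &
            (cols >= boxes[:, 1]) & (cols < boxes[:, 3]))
    return masks * keep.float()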
def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         mocs,
         gt_bboxes,
         gt_labels,
         img_metas,
         imgs,
         gt_bboxes_ignore=None):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            each is a 4D-tensor, the channel number is
            num_points * num_classes.
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * 4.
        centernesses (list[Tensor]): Centerness for each scale level,
            each is a 4D-tensor, the channel number is num_points * 1.
        mocs (list[Tensor]): Coefficients for each scale level, each is
            a 4D-tensor, the channel number is num_points * 1.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image
            with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): Class indices corresponding to each
            box.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        imgs (list[Tensor]): Images in each level.
        gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
            boxes can be ignored when computing the loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                            gt_labels)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_moc = [moc.permute(0, 2, 3, 1).reshape(-1) for moc in mocs]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    flatten_mocs = torch.cat(flatten_moc)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((flatten_labels >= 0)
                & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
    num_pos = len(pos_inds)
    loss_cls = self.loss_cls(
        flatten_cls_scores, flatten_labels,
        avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]
    num_points = [center.size(0) for center in all_level_points]

    # regroup the flattened level-major tensors into per-image lists for
    # the coefficient (moc) branch
    moc_result_list = self.convertlevel2img(flatten_bbox_targets,
                                            flatten_labels,
                                            flatten_bbox_preds,
                                            flatten_points, flatten_mocs,
                                            num_points, num_imgs)
    flatten_bbox_targets_reshape = moc_result_list[0]
    flatten_labels_targets_reshape = moc_result_list[1]
    flatten_bbox_preds_reshape = moc_result_list[2]
    flatten_points_reshape = moc_result_list[3]
    flatten_conv_reshape = moc_result_list[4]

    pos_moc = flatten_mocs[pos_inds]
    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        assert len(flatten_bbox_targets_reshape) == len(
            flatten_labels_targets_reshape) == len(
                flatten_bbox_preds_reshape)
        bbox_preds_moc, bbox_targets_moc, conv_moc = \
            self.compute_bbox_per_image(
                flatten_bbox_targets_reshape,
                flatten_labels_targets_reshape,
                flatten_bbox_preds_reshape, flatten_points_reshape,
                flatten_conv_reshape, bg_class_ind)
        moc_result, conv_mocs, loss_conv_moc_for_clcs = self.moc_overlap(
            bbox_preds_moc, bbox_targets_moc, conv_moc, imgs)
        loss_moc = self.loss_moc(
            conv_mocs.to(pos_centerness.device),
            moc_result.to(pos_centerness.device))

        # weighted iou loss: fall back to centerness weights when
        # moc_result has no positive entries, otherwise weight by moc
        if moc_result is not None and not torch.any(moc_result > 0.):
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
        else:
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=moc_result.to(pos_centerness_targets.device),
                avg_factor=moc_result.sum().to(
                    pos_centerness_targets.device))
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()
        loss_moc = pos_moc.sum()

    return dict(
        loss_cls=loss_cls,
        loss_bbox=loss_bbox,
        loss_moc=loss_moc,
        loss_centerness=loss_centerness)
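# For reference: an aligned-IoU sketch matching how bbox_overlaps(...,
# is_aligned=True) is used in the heads above — row i of one input is
# compared only against row i of the other, yielding one IoU per pair.
def aligned_iou(bboxes1, bboxes2, eps=1e-6):
    lt = torch.max(bboxes1[:, :2], bboxes2[:, :2])   # intersection top-left
    rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:])   # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]
    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
    return overlap / (area1 + area2 - overlap).clamp(min=eps)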