Example #1
    def compute_bbox_per_image(self, flatten_bbox_targets_reshape,
                               flatten_labels_targets_reshape,
                               flatten_bbox_preds_reshape,
                               flatten_points_reshape, flatten_conv_reshape,
                               bg_class_ind):
        # select positive boxes per image based on labels, then decode the
        # distance-format predictions and targets into corner-format boxes
        bbox_targets_moc = []
        bbox_preds_moc = []
        conv_moc = []
        for bbox_targets, labels, bbox_preds, points, conv in zip(
                flatten_bbox_targets_reshape, flatten_labels_targets_reshape,
                flatten_bbox_preds_reshape, flatten_points_reshape,
                flatten_conv_reshape):
            pos_inds = ((labels >= 0) &
                        (labels < bg_class_ind)).nonzero().reshape(-1)
            pos_bbox_preds = bbox_preds[pos_inds]
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_points = points[pos_inds]
            pos_conv = conv[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            bbox_targets_moc.append(pos_decoded_target_preds)
            bbox_preds_moc.append(pos_decoded_bbox_preds)
            conv_moc.append(pos_conv)
        return bbox_preds_moc, bbox_targets_moc, conv_moc
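Every example here decodes (left, top, right, bottom) distances with `distance2bbox`. For reference, a minimal sketch of the mmdet-style implementation it corresponds to; the `norm=` keyword seen in some later examples is a repo-specific extension:

import torch

def distance2bbox(points, distance, max_shape=None):
    """Decode (l, t, r, b) distances from anchor points into
    (x1, y1, x2, y2) corner-format boxes."""
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # clip decoded boxes to the image boundary (height, width)
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)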
Example #2
    def loss(self,
             gt_bboxes,
             cls_scores,
             bbox_preds,
             bbox_iou,
             labels,
             label_weight,
             bbox_targets,
             bbox_weights,
             points,
             reduction_override=None):
        assert len(cls_scores) == len(bbox_preds)

        num_imgs = cls_scores.size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = cls_scores.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels)
        flatten_bbox_preds = bbox_preds.permute(0, 2, 3, 1).reshape(-1, 4)
        flatten_bbox_iou = bbox_iou.permute(0, 2, 3, 1).reshape(-1, 1)
        flatten_labels = labels.reshape(-1)
        flatten_bbox_targets = bbox_targets.reshape(-1, 4)
        # repeat points to align with bbox_preds
        flatten_points = points.reshape(-1, 2)

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos +
                                 num_imgs)  # avoid division by zero when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_iou = flatten_bbox_iou[pos_inds]

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds,
                                       avg_factor=num_pos)
            bbox_iou_targets = bbox_goverlaps(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                is_aligned=True).clamp(min=1e-6)[:, None]
            loss_bbox_iou = self.loss_iou(pos_bbox_iou, bbox_iou_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_bbox_iou = pos_bbox_iou.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_iou=loss_bbox_iou)
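Example #2 supervises its IoU branch with `bbox_goverlaps(..., is_aligned=True)`, which looks like a repo-local helper. Assuming it behaves like mmdet's aligned `bbox_overlaps`, a minimal sketch of the aligned-IoU computation it would perform:

import torch

def aligned_iou(boxes1, boxes2, eps=1e-6):
    """Element-wise IoU between two (N, 4) corner-format box sets."""
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # intersection top-left
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    return inter / (area1 + area2 - inter + eps)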
Example #3
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          centernesses,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        if cfg.stat_2d:
            stat = cv2.imread('kitti_tools/stat/stat_2d.png').astype(
                np.float32)
            stat = cv2.resize(stat, (106, 32)).astype(np.float32)
            # std = np.std(stat, axis=(0, 1))
            # stat /= std
            w_stat = torch.from_numpy(stat[:, :,
                                           1]).float().cuda().unsqueeze(0)
            h_stat = torch.from_numpy(stat[:, :,
                                           2]).float().cuda().unsqueeze(0)
            stat_all = torch.cat([w_stat, h_stat, w_stat, h_stat], dim=0)
            # scale each level's distance predictions by the 2-D statistics
            # prior; the broadcast assumes 32x106 feature maps
            bbox_preds = [bbox_pred * stat_all
                          for bbox_pred in bbox_preds]
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes,
                                                mlvl_scores,
                                                cfg.score_thr,
                                                cfg.nms,
                                                cfg.max_per_img,
                                                score_factors=mlvl_centerness)
        return det_bboxes, det_labels
Example #4
    def loss_single(self, anchors, cls_score, bbox_pred, labels,
                    label_weights, bbox_targets, stride, num_total_samples, cfg):

        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * (self.reg_max + 1))
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        pos_inds = torch.nonzero(labels).squeeze(1)
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds] # (n, 4 * (reg_max + 1))
            pos_anchors = anchors[pos_inds]

            norm_anchor_center = self.anchor_center(pos_anchors) / stride

            pos_bbox_pred_distance = self.distribution_project(pos_bbox_pred)

            pos_decode_bbox_pred = distance2bbox(norm_anchor_center,
                                                 pos_bbox_pred_distance)
            pos_decode_bbox_targets = pos_bbox_targets / stride

            target_ltrb = bbox2distance(norm_anchor_center,
                                        pos_decode_bbox_targets, 
                                        self.reg_max).reshape(-1)
            score[pos_inds] = self.iou_target(pos_decode_bbox_pred.detach(),
                                              pos_decode_bbox_targets)
            weight_targets = \
                    cls_score.detach().sigmoid().max(dim=1)[0][pos_inds]

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            pred_ltrb = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            # dfl loss TODO
            loss_dfl = self.loss_dfl(
                pred_ltrb, 
                target_ltrb, 
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            # zero weight on the correct device/dtype when there are no positives
            weight_targets = bbox_pred.new_tensor(0.)
        
        # qfl loss TODO
        loss_qfl = self.loss_qfl(cls_score, labels, score,
                                 avg_factor=num_total_samples)

        return loss_qfl, loss_bbox, loss_dfl, weight_targets.sum()
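Example #4 turns `4 * (reg_max + 1)` distribution logits into four distances via `self.distribution_project`. A sketch of the GFL-style integral module this presumably wraps: a softmax over the bins followed by an expectation over the fixed bin centers 0..reg_max:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Integral(nn.Module):
    """Project discrete distribution logits onto expected distances (GFL)."""

    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        # fixed bin centers 0, 1, ..., reg_max
        self.register_buffer('project',
                             torch.linspace(0, reg_max, reg_max + 1))

    def forward(self, x):
        # x: (N, 4 * (reg_max + 1)) logits -> (N, 4) expected (l, t, r, b)
        prob = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        return (prob * self.project).sum(dim=1).reshape(-1, 4)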
Example #5
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          coef_preds,
                          centernesses,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_coefs = []
        mlvl_centerness = []
        for cls_score, bbox_pred, coef_pred, centerness, points in zip(
                cls_scores, bbox_preds, coef_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()

            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
            coef_pred = coef_pred.permute(1, 2, 0).reshape(-1, self.num_bases)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                coef_pred = coef_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            coefs = coef_pred
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
            mlvl_coefs.append(coefs)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_coefs = torch.cat(mlvl_coefs)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        det_bboxes, det_labels, det_coefs = multiclass_nms_with_mask(
            mlvl_bboxes,
            mlvl_scores,
            mlvl_coefs,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness,
            num_bases=self.num_bases)
        return det_bboxes, det_labels, det_coefs
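Example #5 only carries the mask coefficients through NMS; how they are consumed is not shown. Assuming the head follows the YOLACT recipe, the returned `det_coefs` would be combined with prototype bases roughly as below (`assemble_masks` is an illustrative name, not this repo's API):

import torch

def assemble_masks(bases, coefs):
    """Linearly combine prototype bases with per-box coefficients.

    bases: (num_bases, H, W) prototype masks from the mask branch.
    coefs: (num_dets, num_bases) per-detection coefficients.
    Returns (num_dets, H, W) mask probabilities.
    """
    num_bases, h, w = bases.shape
    masks = coefs @ bases.view(num_bases, -1)  # (num_dets, H*W)
    return masks.view(-1, h, w).sigmoid()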
Example #6
    def get_bboxes_single(
            self,
            cls_scores,
            bbox_preds,
            centernesses,
            mlvl_points,  # each FPN point's location in the original image
            img_shape,
            scale_factor,
            cfg,
            rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)  # number of FPN levels
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):  # one level at a time
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]  # spatial sizes must match
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:  # keep only the top candidates before NMS
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)  # [num_points, 4]
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)  # [num_points, 80]
        # downstream NMS expects a background column, but FCOS itself
        # predicts no background score, so pad with zeros
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)  # [num_points, 1]
        # [num_points, num_classes + 1]: column 0 holds the (zero) background
        # scores, which the NMS ignores
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)  #[num_points]
        det_bboxes, det_labels = multiclass_nms(  # NMS
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
        return det_bboxes, det_labels
Example #7
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          centernesses,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        # TODO: change output to proposals without labels
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_proposals = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            # iteration by levels
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            scores = scores.squeeze(-1)  # (N,) objectness for the single RPN class
            # scores *= centerness
            proposals = distance2bbox(points, bbox_pred, max_shape=img_shape)
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]

        return proposals
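Example #7 depends on an external `nms(proposals, iou_thr)` over (N, 5) score-augmented proposals. For readers without mmcv, a self-contained greedy sketch of that contract (illustrative and O(N²); the real operator is a fused kernel):

import torch

def nms(proposals, iou_thr):
    """Greedy NMS over (N, 5) [x1, y1, x2, y2, score] proposals.

    Returns the kept proposals and their indices, mirroring the
    (dets, inds) contract used in Example #7.
    """
    x1, y1, x2, y2, scores = proposals.unbind(dim=1)
    areas = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # intersection of the current best box with the remaining ones
        lt_x = torch.max(x1[i], x1[rest])
        lt_y = torch.max(y1[i], y1[rest])
        rb_x = torch.min(x2[i], x2[rest])
        rb_y = torch.min(y2[i], y2[rest])
        inter = (rb_x - lt_x).clamp(min=0) * (rb_y - lt_y).clamp(min=0)
        iou = inter / (areas[i] + areas[rest] - inter + 1e-6)
        order = rest[iou <= iou_thr]  # drop overlapping boxes
    keep = torch.as_tensor(keep, dtype=torch.long, device=proposals.device)
    return proposals[keep], keep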
Example #8
    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           centernesses,
                           mlvl_points,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
        return det_bboxes, det_labels
Example #9
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        for stride, cls_score, bbox_pred, anchors in zip(
                self.anchor_strides, cls_scores, bbox_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0)
            bbox_pred = self.distribution_project(bbox_pred) * stride

            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            
            bboxes = distance2bbox(self.anchor_center(anchors), bbox_pred,
                                   max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)

        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)

        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img)
        return det_bboxes, det_labels
Example #10
    def get_bbox_prob_and_overlap(self, points, bbox_preds, gt_bboxes):
        bbox_targets = bbox2distance(points,
                                     gt_bboxes[:, None, :].repeat(
                                         1, points.shape[1], 1),
                                     norm=self.distance_norm)
        bbox_prob = self.loss_bbox(bbox_preds,
                                   bbox_targets,
                                   reduction_override='none').neg().exp()

        pred_boxes = distance2bbox(points, bbox_preds, norm=self.distance_norm)
        bbox_overlap = bbox_overlaps(gt_bboxes[:,
                                               None, :].expand_as(pred_boxes),
                                     pred_boxes,
                                     is_aligned=True)

        return bbox_prob, bbox_overlap
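Examples #4 and #10 encode boxes back into distances with `bbox2distance`; the `norm=` keyword above is repo-specific. The mmdet baseline it extends looks like this:

import torch

def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode (x1, y1, x2, y2) boxes as (l, t, r, b) distances from points."""
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        # keep targets strictly inside the DFL bin range [0, max_dis)
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)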
Example #11
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, points in zip(cls_scores, bbox_preds,
                                                mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if 0 < nms_pre < scores.shape[0]:
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bboxes = distance2bbox(points,
                                   bbox_pred,
                                   norm=self.distance_norm,
                                   max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
Example #12
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             reid_feats,
             gt_bboxes,
             gt_labels,
             gt_ids, 
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): Centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            reid_feats (list[Tensor]): Re-identification features for each
                scale level, each is a 4D-tensor, the channel number is
                num_points * feat_channels.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_ids (list[Tensor]): identity indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(reid_feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, ids, bbox_targets = self.get_targets(all_level_points,
                                                     gt_bboxes, gt_labels,
                                                     gt_ids)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_reid = [
            reid_feat.permute(0, 2, 3, 1).reshape(-1, self.feat_channels)
            for reid_feat in reid_feats
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_reid = torch.cat(flatten_reid)
        #print("flatten reid", flatten_reid.shape)
        flatten_labels = torch.cat(labels)
        flatten_ids = torch.cat(ids)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # avoid division by zero when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        # background index
        '''
        bg_inds = ((flatten_labels < 0)
                    | (flatten_labels == bg_class_ind)).nonzero().reshape(-1)
        num_bg = len(bg_inds)
        bg_cls_scores = flatten_cls_scores[bg_inds]
        if num_bg > num_pos:
            cls_ids = torch.argsort(bg_cls_scores.squeeze(), descending=True)
            bg_inds = bg_inds[cls_ids[:num_pos]]
        '''

        pos_reid = flatten_reid[pos_inds]
        #bg_reid = flatten_reid[bg_inds]
        #pos_reid = torch.cat((pos_reid, bg_reid))
        pos_reid = F.normalize(pos_reid)


        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)

            
            pos_reid_ids = flatten_ids[pos_inds]
            #bg_reid_ids = flatten_ids[bg_inds]
            #pos_reid_ids = torch.cat((pos_reid_ids, bg_reid_ids))
            #loss_oim = self.loss_reid(pos_reid, pos_reid_ids)
            
            # reid oim loss
            labeled_matching_scores = self.labeled_matching_layer(pos_reid, pos_reid_ids)
            labeled_matching_scores *= 10
            unlabeled_matching_scores = self.unlabeled_matching_layer(pos_reid, pos_reid_ids)
            unlabeled_matching_scores *= 10
            matching_scores = torch.cat((labeled_matching_scores, unlabeled_matching_scores), dim=1)
            pid_labels = pos_reid_ids.clone()
            pid_labels[pid_labels == -2] = -1
            loss_oim = F.cross_entropy(matching_scores, pid_labels, ignore_index=-1)
            '''
            # softmax 
            matching_scores = self.classifier_reid(pos_reid).contiguous()
            loss_oim = F.cross_entropy(matching_scores, pos_reid_ids, ignore_index=-1)
            '''

            
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
            loss_oim = pos_reid.sum()
            print('no gt box')

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness,
            loss_oim=loss_oim)
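Example #12's `labeled_matching_layer` / `unlabeled_matching_layer` follow the OIM (Online Instance Matching) pattern: cosine scores against a lookup table of identity features that is updated outside the autograd graph. A rough sketch of the labeled side, with the momentum update simplified into the forward pass (the real layers update during backward and keep a separate circular queue for unlabeled features):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LabeledMatchingLayer(nn.Module):
    """OIM-style lookup table (sketch): scores are cosine similarities
    against stored per-identity features."""

    def __init__(self, num_ids, feat_dim, momentum=0.5):
        super().__init__()
        self.momentum = momentum
        self.register_buffer('lut', torch.zeros(num_ids, feat_dim))

    def forward(self, feats, ids):
        # feats: (num_pos, feat_dim), already L2-normalized; ids: (num_pos,)
        scores = feats @ self.lut.t()  # (num_pos, num_ids) cosine scores
        with torch.no_grad():
            for feat, pid in zip(feats, ids):
                if pid >= 0:  # skip unlabeled (-2) and ignored (-1) entries
                    mixed = self.momentum * self.lut[pid] + \
                        (1 - self.momentum) * feat
                    self.lut[pid] = F.normalize(mixed, dim=0)
        return scores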
Example #13
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None,
             batch_idx=0,
             analysis_scale=1.0):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): Centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            batch_idx (int): index of the current batch, used to name the
                saved anchor-flag visualizations.
            analysis_scale (float): scale tag used in the saved visualization
                filenames.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos +
                                 num_imgs)  # avoid division by zero when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        pos_anchor_flags = torch.zeros_like(flatten_centerness)
        pos_anchor_flags[pos_inds] = 1.0

        pos_anchor_flags_list = []
        pre_idx = 0
        for i, featmap_size in enumerate(featmap_sizes):
            cur_featmap_size = featmap_size[0] * featmap_size[1]
            cur_pos_anchor_flags = pos_anchor_flags[pre_idx:pre_idx +
                                                    cur_featmap_size]
            cur_pos_anchor_flags = cur_pos_anchor_flags.view(
                1, 1, featmap_size[0], featmap_size[1])
            save_image(
                cur_pos_anchor_flags,
                f"analysis_results_fcos/image_{batch_idx}_feature_{i}_flatten_anchor_flags_scale_{analysis_scale}.png"
            )

            pre_idx = pre_idx + cur_featmap_size

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds,
                                       weight=pos_centerness_targets,
                                       avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_centerness=loss_centerness)
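Examples #12, #13, #14, #18 and #19 all call `self.centerness_target` on the positive (l, t, r, b) targets. The FCOS definition, for reference:

import torch

def centerness_target(pos_bbox_targets):
    """FCOS centerness: sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))."""
    left_right = pos_bbox_targets[:, [0, 2]]
    top_bottom = pos_bbox_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                 (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)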
Example #14
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = sigmoid_focal_loss(
            flatten_cls_scores, flatten_labels, cfg.gamma, cfg.alpha,
            'none').sum()[None] / (num_pos + num_imgs)  # avoid division by zero when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if num_pos > 0:
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)

            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_reg = (
                (iou_loss(pos_decoded_bbox_preds,
                          pos_decoded_target_preds,
                          reduction='none') * pos_centerness_targets).sum() /
                pos_centerness_targets.sum())[None]
            loss_centerness = F.binary_cross_entropy_with_logits(
                pos_centerness, pos_centerness_targets, reduction='mean')[None]
        else:
            loss_reg = pos_bbox_preds.sum()[None]
            loss_centerness = pos_centerness.sum()[None]

        return dict(loss_cls=loss_cls,
                    loss_reg=loss_reg,
                    loss_centerness=loss_centerness)
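Example #14 calls a fused `sigmoid_focal_loss(logits, labels, gamma, alpha, 'none')` with 1-based labels and 0 reserved for background (the pre-v2.0 mmdet convention). A functionally equivalent sketch in plain PyTorch (numerically simplified; the real op works in logit space):

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25, reduction='none'):
    """Per-element sigmoid focal loss; `target` holds 1-based class labels
    with 0 as background."""
    num_classes = pred.size(1)
    # one-hot encode, dropping the background column
    one_hot = F.one_hot(target.long(), num_classes + 1)[:, 1:].type_as(pred)
    p = pred.sigmoid()
    pt = p * one_hot + (1 - p) * (1 - one_hot)  # prob of the true case
    alpha_t = alpha * one_hot + (1 - alpha) * (1 - one_hot)
    loss = -alpha_t * (1 - pt).pow(gamma) * torch.log(pt.clamp(min=1e-12))
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'mean':
        return loss.mean()
    return loss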
Example #15
    def _get_bboxes(self,
                    cls_scores,
                    bbox_preds,
                    mlvl_anchors,
                    img_shapes,
                    scale_factors,
                    cfg,
                    rescale=False,
                    with_nms=True):
        """Transform outputs for a single batch item into labeled boxes.

        Args:
            cls_scores (list[Tensor]): Box scores for a single scale level
                has shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for a single
                scale level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            mlvl_anchors (list[Tensor]): Box reference for a single scale level
                with shape (num_total_anchors, 4).
            img_shapes (list[tuple[int]]): Shape of the input image,
                list[(height, width, 3)].
            scale_factors (list[ndarray]): Scale factor of the image arranged
                as (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is an (n, 5) tensor, where 5 represent
                (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
                The shape of the second tensor in the tuple is (n,), and
                each element represents the class label of the corresponding
                box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        batch_size = cls_scores[0].shape[0]

        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, stride, anchors in zip(
                cls_scores, bbox_preds, self.anchor_generator.strides,
                mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            assert stride[0] == stride[1]
            scores = cls_score.permute(0, 2, 3, 1).reshape(
                batch_size, -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(0, 2, 3, 1)

            bbox_pred = self.integral(bbox_pred) * stride[0]
            bbox_pred = bbox_pred.reshape(batch_size, -1, 4)

            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[1] > nms_pre:
                max_scores, _ = scores.max(-1)
                _, topk_inds = max_scores.topk(nms_pre)
                batch_inds = torch.arange(batch_size).view(
                    -1, 1).expand_as(topk_inds).long()
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
                scores = scores[batch_inds, topk_inds, :]
            else:
                anchors = anchors.expand_as(bbox_pred)

            bboxes = distance2bbox(self.anchor_center(anchors),
                                   bbox_pred,
                                   max_shape=img_shapes)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        if rescale:
            batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
                scale_factors).unsqueeze(1)

        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
        # Add a dummy background class to the backend when using sigmoid
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        padding = batch_mlvl_scores.new_zeros(batch_size,
                                              batch_mlvl_scores.shape[1], 1)
        batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

        if with_nms:
            det_results = []
            for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
                                                  batch_mlvl_scores):
                det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                     cfg.score_thr, cfg.nms,
                                                     cfg.max_per_img)
                det_results.append(tuple([det_bbox, det_label]))
        else:
            det_results = [
                tuple(mlvl_bs)
                for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
            ]
        return det_results
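Examples #9 and #15 decode GFL distances from `self.anchor_center(anchors)`, which simply collapses each corner-format anchor to its center point:

import torch

def anchor_center(anchors):
    """Center points (x, y) of (N, 4) corner-format anchors (GFL-style)."""
    cx = (anchors[:, 0] + anchors[:, 2]) / 2
    cy = (anchors[:, 1] + anchors[:, 3]) / 2
    return torch.stack([cx, cy], dim=-1)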
Example #16
    def get_bboxes_single(self,
                          cls_score,
                          bbox_pred,
                          bbox_iou,
                          points,
                          img_shape,
                          cfg,
                          rescale=False,
                          scale_factor=None):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(1, 2,
                                   0).reshape(-1,
                                              self.cls_out_channels).sigmoid()
        bbox_iou = bbox_iou.permute(1, 2, 0).reshape(-1).sigmoid()
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        points = points[0]
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            max_scores, _ = (scores * bbox_iou[:, None]).max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
            bbox_iou = bbox_iou[topk_inds]
        bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)

        scores_cpu = np.array(scores.cpu().detach())

        max_cls = np.unravel_index(scores_cpu.argmax(), scores_cpu.shape)[1]
        scores_iou = scores[:, max_cls] * bbox_iou
        bbox_ind = scores_iou.argmax()
        det_bboxes = bboxes[bbox_ind]
        # note: despite the name, det_labels holds IoU-weighted class scores
        det_labels = scores[bbox_ind] * bbox_iou[bbox_ind]


        return det_bboxes, det_labels
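Example #16 round-trips through numpy just to locate the best class. The same single-box selection can stay on-device; a pure-torch equivalent of the selection logic (`pick_best_box` is a hypothetical helper name):

import torch

def pick_best_box(scores, bbox_iou, bboxes):
    """Select the single box maximizing class score * predicted IoU."""
    max_cls = scores.max(dim=0)[0].argmax()     # class of the global max score
    scores_iou = scores[:, max_cls] * bbox_iou  # fuse cls score with IoU branch
    bbox_ind = scores_iou.argmax()
    return bboxes[bbox_ind], scores[bbox_ind] * bbox_iou[bbox_ind]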
Example #17
    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
                cls_scores (list[Tensor]): Classification scores for all scale
                    levels, each is a 4D-tensor, the channel number is
                    num_anchors * num_classes.
                bbox_preds (list[Tensor]): Decoded box for all scale levels,
                    each is a 4D-tensor, the channel number is
                    num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format.
        """
        cls_scores = []
        bbox_preds = []
        for idx, (x, scale, stride) in enumerate(
                zip(feats, self.scales, self.prior_generator.strides)):
            b, c, h, w = x.shape
            anchor = self.prior_generator.single_level_grid_priors(
                (h, w), idx, device=x.device)
            anchor = torch.cat([anchor for _ in range(b)])
            # extract task interactive features
            inter_feats = []
            for inter_conv in self.inter_convs:
                x = inter_conv(x)
                inter_feats.append(x)
            feat = torch.cat(inter_feats, 1)

            # task decomposition
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_feat = self.cls_decomp(feat, avg_feat)
            reg_feat = self.reg_decomp(feat, avg_feat)

            # cls prediction and alignment
            cls_logits = self.tood_cls(cls_feat)
            cls_prob = self.cls_prob_module(feat)
            cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)

            # reg prediction and alignment
            if self.anchor_type == 'anchor_free':
                reg_dist = scale(self.tood_reg(reg_feat).exp()).float()
                reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
                reg_bbox = distance2bbox(
                    self.anchor_center(anchor) / stride[0],
                    reg_dist).reshape(b, h, w, 4).permute(0, 3, 1,
                                                          2)  # (b, c, h, w)
            elif self.anchor_type == 'anchor_based':
                reg_dist = scale(self.tood_reg(reg_feat)).float()
                reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
                reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape(
                    b, h, w, 4).permute(0, 3, 1, 2) / stride[0]
            else:
                raise NotImplementedError(
                    f'Unknown anchor type: {self.anchor_type}.'
                    f'Please use `anchor_free` or `anchor_based`.')
            reg_offset = self.reg_offset_module(feat)
            bbox_pred = self.deform_sampling(reg_bbox.contiguous(),
                                             reg_offset.contiguous())
            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)
        return tuple(cls_scores), tuple(bbox_preds)
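Example #17's aligned classification score comes from `sigmoid_geometric_mean`. Functionally it is just the geometric mean of the two sigmoid activations (mmdet implements it as a custom autograd Function for a more stable gradient):

import torch

def sigmoid_geometric_mean(x, y):
    """TOOD-style aligned score: sqrt(sigmoid(x) * sigmoid(y))."""
    return (x.sigmoid() * y.sigmoid()).sqrt()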
Example #18
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        if cfg.stat_2d:
            stat = cv2.imread('kitti_tools/stat/stat_2d.png').astype(
                np.float32)
            stat = cv2.resize(stat, (106, 32)).astype(np.float32)
            # std = np.std(stat, axis=(0, 1))
            # stat /= std
            w_stat = torch.from_numpy(stat[:, :,
                                           1]).float().cuda().unsqueeze(0)
            h_stat = torch.from_numpy(stat[:, :,
                                           2]).float().cuda().unsqueeze(0)
            stat_all = torch.cat([w_stat, h_stat, w_stat, h_stat], dim=0)
            # scale each level's distance predictions by the 2-D statistics
            # prior; the broadcast assumes 32x106 feature maps
            bbox_preds = [bbox_pred * stat_all
                          for bbox_pred in bbox_preds]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)

        # check NaN and Inf
        assert torch.isfinite(flatten_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(flatten_bbox_preds).all().item(), \
            'bbox predictions become infinite or NaN!'
        assert torch.isfinite(flatten_centerness).all().item(), \
            'bbox centerness becomes infinite or NaN!'

        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos +
                                 num_imgs)  # avoid division by zero when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds,
                                       weight=pos_centerness_targets,
                                       avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_centerness=loss_centerness)
Example #19
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             cof_preds,
             feat_masks,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None,
             gt_masks_list=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points, all_level_strides = self.get_points(
            featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
        (labels, bbox_targets, label_list, bbox_targets_list,
         gt_inds) = self.fcos_target(all_level_points, gt_bboxes, gt_labels)
        # decode detections and ground-truth boxes
        det_bboxes = []
        det_targets = []
        num_levels = len(bbox_preds)

        for img_id in range(len(img_metas)):
            bbox_pred_list = [
                bbox_preds[i][img_id].permute(1, 2, 0).reshape(-1, 4).detach() for i in range(num_levels)
            ]
            bbox_target_list = bbox_targets_list[img_id]

            bboxes = []
            targets = []
            for i in range(len(bbox_pred_list)):
                bbox_pred = bbox_pred_list[i]
                bbox_target = bbox_target_list[i]
                points = all_level_points[i]
                bboxes.append(distance2bbox(points, bbox_pred))
                targets.append(distance2bbox(points, bbox_target))

            bboxes = torch.cat(bboxes, dim=0)
            targets = torch.cat(targets, dim=0)

            det_bboxes.append(bboxes)
            det_targets.append(targets)
        gt_masks = []
        for i in range(len(gt_labels)):
            gt_label = gt_labels[i]
            gt_masks.append(
                torch.from_numpy(
                    np.array(gt_masks_list[i][:gt_label.shape[0]],
                             dtype=np.float32)).to(gt_label.device))

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])
        flatten_strides = torch.cat([
            strides.view(-1, 1).repeat(num_imgs, 1)
            for strides in all_level_strides
        ])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # avoid num_pos being 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_strides = flatten_strides[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds/pos_strides)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets/pos_strides)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        ########## mask loss ##########
        flatten_cls_scores1 = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_cls_scores1 = torch.cat(flatten_cls_scores1, dim=1)

        flatten_cof_preds = [
            cof_pred.permute(0, 2, 3, 1).reshape(cof_pred.shape[0], -1, 32 * 4)
            for cof_pred in cof_preds
        ]

        loss_mask = 0
        loss_iou = 0
        num_iou = 0.1  # small initial value to avoid division by zero below
        flatten_cof_preds = torch.cat(flatten_cof_preds, dim=1)
        for i in range(num_imgs):
            labels = torch.cat(
                [labels_level.flatten() for labels_level in label_list[i]])
            bbox_dt = det_bboxes[i] / 2
            bbox_dt = bbox_dt.detach()
            pos_inds = (labels > 0).nonzero().view(-1)
            cof_pred = flatten_cof_preds[i][pos_inds]
            img_mask = feat_masks[i]
            mask_h = img_mask.shape[1]
            mask_w = img_mask.shape[2]
            idx_gt = gt_inds[i]
            bbox_dt = bbox_dt[pos_inds, :4]

            area = (bbox_dt[:, 2] - bbox_dt[:, 0]) * (bbox_dt[:, 3] - bbox_dt[:, 1])
            bbox_dt = bbox_dt[area > 1.0, :]
            idx_gt = idx_gt[area > 1.0]
            cof_pred = cof_pred[area > 1.0]
            if bbox_dt.shape[0] == 0:
                loss_mask += area.sum() * 0  # zero loss that keeps the graph alive
                continue

            bbox_gt = gt_bboxes[i]
            cls_score = flatten_cls_scores1[
                i, pos_inds, labels[pos_inds] - 1].sigmoid().detach()
            cls_score = cls_score[area > 1.0]
            pos_inds = pos_inds[area > 1.0]
            ious = bbox_overlaps(bbox_gt[idx_gt] / 2, bbox_dt, is_aligned=True)
            with torch.no_grad():
                weighting = cls_score * ious
                weighting = weighting / (torch.sum(weighting) + 0.0001) * len(weighting)

            gt_mask = F.interpolate(gt_masks[i].unsqueeze(0),
                                    scale_factor=0.5,
                                    mode='bilinear',
                                    align_corners=False).squeeze(0)

            shape = np.minimum(feat_masks[i].shape, gt_mask.shape)
            gt_mask_new = gt_mask.new_zeros(gt_mask.shape[0], mask_h, mask_w)
            gt_mask_new[:gt_mask.shape[0], :shape[1], :shape[2]] = gt_mask[:gt_mask.shape[0], :shape[1], :shape[2]]
            gt_mask_new = gt_mask_new.gt(0.5).float()

            gt_mask_new = torch.index_select(gt_mask_new, 0,
                                             idx_gt).permute(1, 2, 0).contiguous()

            ####### spp #######
            img_mask1 = img_mask.permute(1, 2, 0)
            pos_masks00 = torch.sigmoid(img_mask1 @ cof_pred[:, 0:32].t())
            pos_masks01 = torch.sigmoid(img_mask1 @ cof_pred[:, 32:64].t())
            pos_masks10 = torch.sigmoid(img_mask1 @ cof_pred[:, 64:96].t())
            pos_masks11 = torch.sigmoid(img_mask1 @ cof_pred[:, 96:128].t())
            pred_masks = torch.stack([pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
            pred_masks = self.crop_cuda(pred_masks, bbox_dt)
            gt_mask_crop = self.crop_gt_cuda(gt_mask_new, bbox_dt)
            # pred_masks, gt_mask_crop = crop_split(pos_masks00, pos_masks01, pos_masks10, pos_masks11, bbox_dt,
            #                                       gt_mask_new)

            pre_loss = F.binary_cross_entropy(pred_masks, gt_mask_crop, reduction='none')
            pos_get_csize = center_size(bbox_dt)
            gt_box_width = pos_get_csize[:, 2]
            gt_box_height = pos_get_csize[:, 3]
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height / pos_get_csize.shape[0]
            loss_mask += torch.sum(pre_loss * weighting.detach())

            if self.rescoring_flag:
                pos_labels = labels[pos_inds] - 1
                input_iou = pred_masks.detach().unsqueeze(0).permute(3, 0, 1, 2)
                pred_iou = self.convs_scoring(input_iou)
                pred_iou = self.relu(self.mask_scoring(pred_iou))
                pred_iou = F.max_pool2d(pred_iou, kernel_size=pred_iou.size()[2:]).squeeze(-1).squeeze(-1)
                pred_iou = pred_iou[range(pred_iou.size(0)), pos_labels]
                with torch.no_grad():
                    mask_pred = (pred_masks > 0.4).float()
                    mask_pred_areas = mask_pred.sum((0, 1))
                    overlap_areas = (mask_pred * gt_mask_new).sum((0, 1))
                    gt_full_areas = gt_mask_new.sum((0, 1))
                    iou_targets = overlap_areas / (mask_pred_areas + gt_full_areas - overlap_areas + 0.1)

                    iou_weights = ((iou_targets > 0.1) & (iou_targets <= 1.0) & (gt_full_areas >= 10 * 10)).float()

                loss_iou += self.loss_iou(pred_iou.view(-1, 1), iou_targets.view(-1, 1), iou_weights.view(-1, 1))
                num_iou += torch.sum(iou_weights.detach())
        loss_mask = loss_mask / num_imgs
        if self.rescoring_flag:
            loss_iou = loss_iou * 10 / num_iou.detach()
            return dict(
                loss_cls=loss_cls,
                loss_bbox=loss_bbox,
                loss_centerness=loss_centerness,
                loss_mask=loss_mask,
                loss_iou=loss_iou)
        else:
            return dict(
                loss_cls=loss_cls,
                loss_bbox=loss_bbox,
                loss_centerness=loss_centerness,
                loss_mask=loss_mask)
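The mask loss in Exemple #19 normalizes the per-pixel BCE by each detected box's width and height via center_size. A plausible minimal version of that helper, assuming the usual (x1, y1, x2, y2) to (cx, cy, w, h) conversion (the real utility may differ):

import torch

def center_size(boxes):
    # (x1, y1, x2, y2) -> (cx, cy, w, h); a sketch of the helper the
    # mask loss uses to read off box width/height.
    return torch.cat([(boxes[:, :2] + boxes[:, 2:]) / 2,   # cx, cy
                      boxes[:, 2:] - boxes[:, :2]], dim=1)  # w, h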
Exemple #20
    def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, stride, soft_targets, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of integral
                set.
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor with
                shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            soft_targets (Tensor): Teacher's box distribution logits used
                for localization distillation, with the same shape as
                bbox_pred after flattening.
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple(Tensor): Loss components and the summed weight targets.
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, 4 * (self.reg_max + 1))
        soft_targets = soft_targets.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       4 * (self.reg_max + 1))

        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
                                                 pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            score[pos_inds] = bbox_overlaps(pos_decode_bbox_pred.detach(),
                                            pos_decode_bbox_targets,
                                            is_aligned=True)
            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            pos_soft_targets = soft_targets[pos_inds]
            soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)

            target_corners = bbox2distance(pos_anchor_centers,
                                           pos_decode_bbox_targets,
                                           self.reg_max).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(pos_decode_bbox_pred,
                                       pos_decode_bbox_targets,
                                       weight=weight_targets,
                                       avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(pred_corners,
                                     target_corners,
                                     weight=weight_targets[:, None].expand(
                                         -1, 4).reshape(-1),
                                     avg_factor=4.0)

            # ld loss
            loss_ld = self.loss_ld(pred_corners,
                                   soft_corners,
                                   weight=weight_targets[:, None].expand(
                                       -1, 4).reshape(-1),
                                   avg_factor=4.0)

        else:
            loss_ld = bbox_pred.sum() * 0
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(cls_score, (labels, score),
                                 weight=label_weights,
                                 avg_factor=num_total_samples)

        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
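Exemple #20 relies on self.integral to collapse GFL-style distribution logits into expected distances. A minimal sketch of such an integral layer, assuming the standard formulation (softmax over reg_max + 1 bins, then the expectation per side):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Integral(nn.Module):
    """Sketch of a GFL integral layer: softmax over reg_max + 1 bins,
    then the expectation sum(i * p_i) for each of the 4 box sides."""

    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        self.register_buffer('project',
                             torch.linspace(0, reg_max, reg_max + 1))

    def forward(self, x):
        # x: (N, 4 * (reg_max + 1)) -> (N, 4) expected l, t, r, b
        x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        return F.linear(x, self.project.unsqueeze(0)).reshape(-1, 4)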
Exemple #21
    def _get_bboxes(self,
                    cls_scores,
                    bbox_preds,
                    centernesses,
                    mlvl_points,
                    img_shapes,
                    scale_factors,
                    cfg,
                    rescale=False,
                    with_nms=True):
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for a single scale level
                with shape (N, num_points * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for a single scale
                level with shape (N, num_points * 4, H, W).
            centernesses (list[Tensor]): Centerness for a single scale level
                with shape (N, num_points, H, W).
            mlvl_points (list[Tensor]): Box reference for a single scale level
                with shape (num_total_points, 2).
            img_shapes (list[tuple[int]]): Shape of the input image,
                list[(height, width, 3)].
            scale_factors (list[ndarray]): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple(Tensor):
                det_bboxes (Tensor): BBox predictions in shape (n, 5), where
                    the first 4 columns are bounding box positions
                    (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                    between 0 and 1.
                det_labels (Tensor): A (n,) tensor where each item is the
                    predicted class label of the corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        device = cls_scores[0].device
        batch_size = cls_scores[0].shape[0]
        # convert to tensor to keep tracing
        nms_pre_tensor = torch.tensor(cfg.get('nms_pre', -1),
                                      device=device,
                                      dtype=torch.long)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(0, 2, 3, 1).reshape(
                batch_size, -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(0, 2, 3,
                                            1).reshape(batch_size,
                                                       -1).sigmoid()

            bbox_pred = bbox_pred.permute(0, 2, 3,
                                          1).reshape(batch_size, -1, 4)
            # Always keep topk op for dynamic input in onnx
            if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
                                       or scores.shape[-2] > nms_pre_tensor):
                from torch import _shape_as_tensor
                # keep shape as tensor and get k
                num_anchor = _shape_as_tensor(scores)[-2].to(device)
                nms_pre = torch.where(nms_pre_tensor < num_anchor,
                                      nms_pre_tensor, num_anchor)

                max_scores, _ = (scores * centerness[..., None]).max(-1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                batch_inds = torch.arange(batch_size).view(
                    -1, 1).expand_as(topk_inds).long()
                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
                scores = scores[batch_inds, topk_inds, :]
                centerness = centerness[batch_inds, topk_inds]

            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shapes)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)

        batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        if rescale:
            batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
                scale_factors).unsqueeze(1)
        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
        batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)

        # Set max number of boxes to be fed into nms in deployment
        deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
        if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
            batch_mlvl_scores, _ = (
                batch_mlvl_scores *
                batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores)
            ).max(-1)
            _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre)
            batch_inds = torch.arange(batch_mlvl_scores.shape[0]).view(
                -1, 1).expand_as(topk_inds)
            batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
            batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
            batch_mlvl_centerness = batch_mlvl_centerness[batch_inds,
                                                          topk_inds]

        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        padding = batch_mlvl_scores.new_zeros(batch_size,
                                              batch_mlvl_scores.shape[1], 1)
        batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

        if with_nms:
            det_results = []
            for (mlvl_bboxes, mlvl_scores,
                 mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                                         batch_mlvl_centerness):
                det_bbox, det_label = multiclass_nms(
                    mlvl_bboxes,
                    mlvl_scores,
                    cfg.score_thr,
                    cfg.nms,
                    cfg.max_per_img,
                    score_factors=mlvl_centerness)
                det_results.append(tuple([det_bbox, det_label]))
        else:
            det_results = [
                tuple(mlvl_bs)
                for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                                   batch_mlvl_centerness)
            ]
        return det_results
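The batched pre-NMS selection above pairs topk indices with a broadcast batch_inds tensor so that advanced indexing keeps the top candidates per image. The pattern in isolation, with illustrative shapes:

import torch

batch_size, num_points, k = 2, 100, 5
scores = torch.rand(batch_size, num_points)    # per-point scores
boxes = torch.rand(batch_size, num_points, 4)  # per-point boxes

_, topk_inds = scores.topk(k, dim=1)           # (B, k) indices per image
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)
top_boxes = boxes[batch_inds, topk_inds, :]    # (B, k, 4)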
Exemple #22
    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           centernesses,
                           mlvl_points,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False,
                           with_nms=True):
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for a single scale level
                with shape (num_points * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for a single scale
                level with shape (num_points * 4, H, W).
            centernesses (list[Tensor]): Centerness for a single scale level
                with shape (num_points * 1, H, W).
            mlvl_points (list[Tensor]): Box reference for a single scale level
                with shape (num_total_points, 2).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple(Tensor):
                det_bboxes (Tensor): BBox predictions in shape (n, 5), where
                    the first 4 columns are bounding box positions
                    (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                    between 0 and 1.
                det_labels (Tensor): A (n,) tensor where each item is the
                    predicted class label of the corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)

        if with_nms:
            det_bboxes, det_labels = multiclass_nms(
                mlvl_bboxes,
                mlvl_scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness)
            return det_bboxes, det_labels
        else:
            return mlvl_bboxes, mlvl_scores, mlvl_centerness
Exemple #23
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(len(pos_inds),
                               dtype=torch.float,
                               device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds,
                                       weight=pos_centerness_targets,
                                       avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets,
                                                   avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_centerness=loss_centerness)
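self.centerness_target in Exemple #23 (and the FCOS-style heads above) maps positive (l, t, r, b) regression targets to a centerness weight. A minimal sketch matching the standard FCOS definition, which the usage here appears to follow:

import torch

def centerness_target(pos_bbox_targets):
    # pos_bbox_targets: (num_pos, 4) distances (l, t, r, b).
    # Centerness is sqrt of the two min/max ratios, in [0, 1].
    left_right = pos_bbox_targets[:, [0, 2]]
    top_bottom = pos_bbox_targets[:, [1, 3]]
    centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * \
                 (top_bottom.min(-1)[0] / top_bottom.max(-1)[0])
    return torch.sqrt(centerness)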
Exemple #24
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             mask_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_masks=None,
             gt_bboxes_ignore=None,
             gt_centers=None,
             gt_max_centerness=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
            mask_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        self.num_points_per_level = [i.size()[0] for i in all_level_points]

        labels, bbox_targets, mask_targets, centerness_targets = self.polar_target(
            all_level_points, gt_labels, gt_bboxes, gt_masks, gt_centers,
            gt_max_centerness)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]

        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        if self.use_fourier:
            if self.loss_on_coe:
                flatten_mask_preds = [
                    mask_pred.permute(0, 2, 3, 1).reshape(-1, self.num_coe, 2)
                    for mask_pred in mask_preds
                ]
            else:
                flatten_mask_preds = []
                flatten_bbox_preds = []
                for mask_pred, points in zip(mask_preds, all_level_points):
                    mask_pred = mask_pred.permute(0, 2, 3, 1).reshape(
                        -1, self.num_coe, 2)
                    if self.bbox_from_mask:
                        xy, m = self.distance2mask(points.repeat(num_imgs, 1),
                                                   mask_pred,
                                                   train=True)
                        b = torch.stack([
                            xy[:, 0].min(1)[0], xy[:, 1].min(1)[0],
                            xy[:, 0].max(1)[0], xy[:, 1].max(1)[0]
                        ], -1)
                        flatten_bbox_preds.append(b)
                        flatten_mask_preds.append(m)
                    else:
                        m = torch.irfft(
                            torch.cat([
                                mask_pred,
                                torch.zeros(mask_pred.shape[0],
                                            self.contour_points - self.num_coe,
                                            2).to(mask_pred.device)
                            ], 1), 1, True, False).float().exp()
                        flatten_mask_preds.append(m)

        else:
            flatten_mask_preds = [
                mask_pred.permute(0, 2, 3, 1).reshape(-1, self.contour_points)
                for mask_pred in mask_preds
            ]
        if not self.bbox_from_mask:
            flatten_bbox_preds = [
                bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
                for bbox_pred in bbox_preds
            ]

        flatten_cls_scores = torch.cat(flatten_cls_scores)  # [num_pixel, 80]
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)  # [num_pixel, 4]
        flatten_mask_preds = torch.cat(flatten_mask_preds)  # [num_pixel, n]
        flatten_centerness = torch.cat(flatten_centerness)  # [num_pixel]

        flatten_labels = torch.cat(labels).long()  # [num_pixel]
        flatten_centerness_targets = torch.cat(centerness_targets)
        flatten_bbox_targets = torch.cat(bbox_targets)  # [num_pixel, 4]
        flatten_mask_targets = torch.cat(mask_targets)  # [num_pixel, n]
        flatten_points = torch.cat([
            points.repeat(num_imgs, 1) for points in all_level_points
        ])  # [num_pixel,2]
        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)

        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos +
                                 num_imgs)  # avoid num_pos being 0
        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_mask_preds = flatten_mask_preds[pos_inds]

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_mask_targets = flatten_mask_targets[pos_inds]
            pos_centerness_targets = flatten_centerness_targets[pos_inds]

            pos_points = flatten_points[pos_inds]
            if self.bbox_from_mask:
                pos_decoded_bbox_preds = pos_bbox_preds
            else:
                pos_decoded_bbox_preds = distance2bbox(pos_points,
                                                       pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)

            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds,
                                       weight=pos_centerness_targets,
                                       avg_factor=pos_centerness_targets.sum())

            if self.loss_on_coe:
                pos_mask_targets = torch.rfft(pos_mask_targets, 1, True, False)
                pos_mask_targets = pos_mask_targets[..., :self.num_coe, :]
                loss_mask = self.loss_mask(pos_mask_preds, pos_mask_targets)
            else:
                loss_mask = self.loss_mask(
                    pos_mask_preds,
                    pos_mask_targets,
                    weight=pos_centerness_targets,
                    avg_factor=pos_centerness_targets.sum())

            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_mask = pos_mask_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_mask=loss_mask,
                    loss_centerness=loss_centerness)
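Exemple #24 also decodes contours with self.distance2mask. In polar-style heads this maps per-point radial distances at fixed angles to contour coordinates; a rough sketch of that core idea (the actual method, with its Fourier branch and train flag, is more involved):

import math
import torch

def distance2mask_sketch(points, distances):
    # points: (N, 2) centers; distances: (N, K) radii at K fixed angles.
    num_rays = distances.shape[1]
    angles = torch.arange(num_rays, dtype=distances.dtype,
                          device=distances.device) * (2 * math.pi / num_rays)
    x = points[:, 0, None] + distances * torch.cos(angles)
    y = points[:, 1, None] + distances * torch.sin(angles)
    return torch.stack([x, y], dim=1)  # (N, 2, K) contour coordinates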
Exemple #25
    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False,
                           nms=True):
        """Transform outputs for a single batch item into labeled boxes.

        Args:
            cls_scores (list[Tensor]): Box scores for a single scale level
                has shape (num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for a single
                scale level with shape (4*(n+1), H, W), n is max value of
                integral set.
            mlvl_anchors (list[Tensor]): Box reference for a single scale level
                with shape (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple(Tensor):
                det_bboxes (Tensor): Bbox predictions in shape (N, 5), where
                    the first 4 columns are bounding box positions
                    (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                    between 0 and 1.
                det_labels (Tensor): A (N,) tensor where each item is the
                    predicted class label of the corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, stride, anchors in zip(
                cls_scores, bbox_preds, self.anchor_generator.strides,
                mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            assert stride[0] == stride[1]

            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0)
            bbox_pred = self.integral(bbox_pred) * stride[0]

            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]

            bboxes = distance2bbox(self.anchor_center(anchors),
                                   bbox_pred,
                                   max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)

        mlvl_scores = torch.cat(mlvl_scores)
        # Add a dummy background class to the backend when using sigmoid
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

        if nms:
            det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)
            return det_bboxes, det_labels
        else:
            return mlvl_bboxes, mlvl_scores
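self.anchor_center above reduces (x1, y1, x2, y2) anchors to center points before decoding distances. Presumably it is little more than:

def anchor_center(anchors):
    # anchors: (N, 4) in (x1, y1, x2, y2) format -> (N, 2) centers.
    return (anchors[:, :2] + anchors[:, 2:]) / 2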
Exemple #26
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mask_preds,
                          centernesses,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_masks = []
        mlvl_centerness = []
        for cls_score, bbox_pred, mask_pred, centerness, points in zip(
                cls_scores, bbox_preds, mask_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()

            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            if self.use_fourier:
                mask_pred = mask_pred.permute(1, 2,
                                              0).reshape(-1, self.num_coe * 2)
            else:
                mask_pred = mask_pred.permute(1, 2, 0).reshape(
                    -1, self.contour_points)
            nms_pre = cfg.get('nms_pre', -1)
            if 0 < nms_pre < scores.shape[0]:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                mask_pred = mask_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            if not self.bbox_from_mask:
                bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
                # masks, _ = self.distance2mask(points, mask_pred, bbox=bboxes)
                masks, _ = self.distance2mask(points,
                                              mask_pred,
                                              max_shape=img_shape)
            else:
                masks, _ = self.distance2mask(points,
                                              mask_pred,
                                              max_shape=img_shape)
                bboxes = torch.stack([
                    masks[:, 0].min(1)[0], masks[:, 1].min(1)[0],
                    masks[:, 0].max(1)[0], masks[:, 1].max(1)[0]
                ], -1)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
            mlvl_masks.append(masks)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_masks = torch.cat(mlvl_masks)
        if rescale:
            _mlvl_bboxes = mlvl_bboxes / mlvl_bboxes.new_tensor(scale_factor)
            try:
                # TODO: change cuda
                scale_factor = torch.tensor(scale_factor)[:2].cuda().unsqueeze(
                    1).repeat(1, self.contour_points)
                _mlvl_masks = mlvl_masks / scale_factor
            except (RuntimeError, TypeError, NameError, IndexError):
                _mlvl_masks = mlvl_masks / mlvl_masks.new_tensor(scale_factor)
        else:
            # keep the names defined for the NMS branches below
            _mlvl_bboxes = mlvl_bboxes
            _mlvl_masks = mlvl_masks

        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)

        if self.mask_nms:
            '''1 mask->min_bbox->nms, performance same to origin box'''
            _mlvl_bboxes = torch.stack([
                _mlvl_masks[:, 0].min(1)[0], _mlvl_masks[:, 1].min(1)[0],
                _mlvl_masks[:, 0].max(1)[0], _mlvl_masks[:, 1].max(1)[0]
            ], -1)
            det_bboxes, det_labels, det_masks = multiclass_nms_with_mask(
                _mlvl_bboxes,
                mlvl_scores,
                _mlvl_masks,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness + self.centerness_factor)

        else:
            '''2 origin bbox->nms, performance same to mask->min_bbox'''
            det_bboxes, det_labels, det_masks = multiclass_nms_with_mask(
                _mlvl_bboxes,
                mlvl_scores,
                _mlvl_masks,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness + self.centerness_factor)

        return det_bboxes, det_labels, det_masks
Exemple #27
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # avoid num_pos being 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)

        if num_pos > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)
Exemple #28
    def loss(self,
             cls_scores,
             bbox_preds,
             bbox_preds_refine,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box iou-aware scores for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box offsets for each
                scale level, each is a 4D-tensor, the channel number is
                num_points * 4.
            bbox_preds_refine (list[Tensor]): Refined Box offsets for
                each scale level, each is a 4D-tensor, the channel
                number is num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, label_weights, bbox_targets, bbox_weights = self.get_targets(
            cls_scores, all_level_points, gt_bboxes, gt_labels, img_metas,
            gt_bboxes_ignore)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and bbox_preds_refine
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3,
                              1).reshape(-1,
                                         self.cls_out_channels).contiguous()
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
            for bbox_pred in bbox_preds
        ]
        flatten_bbox_preds_refine = [
            bbox_pred_refine.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
            for bbox_pred_refine in bbox_preds_refine
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_bbox_preds_refine = torch.cat(flatten_bbox_preds_refine)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = torch.where(
            ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)) > 0)[0]
        num_pos = len(pos_inds)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_preds_refine = flatten_bbox_preds_refine[pos_inds]
        pos_labels = flatten_labels[pos_inds]

        # sync num_pos across all gpus
        if self.sync_num_pos:
            num_pos_avg_per_gpu = reduce_mean(
                pos_inds.new_tensor(num_pos).float()).item()
            num_pos_avg_per_gpu = max(num_pos_avg_per_gpu, 1.0)
        else:
            num_pos_avg_per_gpu = num_pos

        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_points = flatten_points[pos_inds]

        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
        iou_targets_ini = bbox_overlaps(pos_decoded_bbox_preds,
                                        pos_decoded_target_preds.detach(),
                                        is_aligned=True).clamp(min=1e-6)
        bbox_weights_ini = iou_targets_ini.clone().detach()
        iou_targets_ini_avg_per_gpu = reduce_mean(
            bbox_weights_ini.sum()).item()
        bbox_avg_factor_ini = max(iou_targets_ini_avg_per_gpu, 1.0)

        if num_pos > 0:
            loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                       pos_decoded_target_preds.detach(),
                                       weight=bbox_weights_ini,
                                       avg_factor=bbox_avg_factor_ini)

            pos_decoded_bbox_preds_refine = \
                distance2bbox(pos_points, pos_bbox_preds_refine)
            iou_targets_rf = bbox_overlaps(pos_decoded_bbox_preds_refine,
                                           pos_decoded_target_preds.detach(),
                                           is_aligned=True).clamp(min=1e-6)
            bbox_weights_rf = iou_targets_rf.clone().detach()
            iou_targets_rf_avg_per_gpu = reduce_mean(
                bbox_weights_rf.sum()).item()
            bbox_avg_factor_rf = max(iou_targets_rf_avg_per_gpu, 1.0)
            loss_bbox_refine = self.loss_bbox_refine(
                pos_decoded_bbox_preds_refine,
                pos_decoded_target_preds.detach(),
                weight=bbox_weights_rf,
                avg_factor=bbox_avg_factor_rf)

            # build IoU-aware cls_score targets
            if self.use_vfl:
                pos_ious = iou_targets_rf.clone().detach()
                cls_iou_targets = torch.zeros_like(flatten_cls_scores)
                cls_iou_targets[pos_inds, pos_labels] = pos_ious
        else:
            loss_bbox = pos_bbox_preds.sum() * 0
            loss_bbox_refine = pos_bbox_preds_refine.sum() * 0
            if self.use_vfl:
                cls_iou_targets = torch.zeros_like(flatten_cls_scores)

        if self.use_vfl:
            loss_cls = self.loss_cls(flatten_cls_scores,
                                     cls_iou_targets,
                                     avg_factor=num_pos_avg_per_gpu)
        else:
            loss_cls = self.loss_cls(flatten_cls_scores,
                                     flatten_labels,
                                     weight=label_weights,
                                     avg_factor=num_pos_avg_per_gpu)

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_bbox_rf=loss_bbox_refine)
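Exemple #28 weights its IoU losses with bbox_overlaps(..., is_aligned=True), i.e. element-wise IoU between equal-length box lists. A minimal aligned-IoU sketch of that computation:

import torch

def bbox_overlaps_aligned(boxes1, boxes2, eps=1e-6):
    # Element-wise IoU of two (N, 4) box sets in (x1, y1, x2, y2) format.
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])   # intersection top-left
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])   # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    return inter / (area1 + area2 - inter + eps)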
Exemple #29
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          centernesses,
                          cof_preds,
                          feat_mask,
                          mlvl_points,
                          img_shape,
                          ori_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        mlvl_cofs = []
        for cls_score, bbox_pred, cof_pred, centerness, points in zip(
                cls_scores, bbox_preds, cof_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            cof_pred = cof_pred.permute(1, 2, 0).reshape(-1, 32 * 4)

            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                cof_pred = cof_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_cofs.append(cof_pred)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_cofs = torch.cat(mlvl_cofs)

        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)

        if self.ssd_flag is False:
            det_bboxes, det_labels, idxs_keep = multiclass_nms_idx(
                mlvl_bboxes,
                mlvl_scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness)
        else:
            mlvl_scores = mlvl_scores * mlvl_centerness.view(-1, 1)
            det_bboxes, det_labels, det_cofs = self.fast_nms(
                mlvl_bboxes,
                mlvl_scores[:, 1:].transpose(1, 0).contiguous(),
                mlvl_cofs,
                iou_threshold=cfg.nms.iou_thr,
                score_thr=cfg.score_thr)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        mask_scores = [[] for _ in range(self.num_classes - 1)]
        if det_bboxes.shape[0] > 0:
            scale = 2

            if self.ssd_flag is False:
                det_cofs = mlvl_cofs[idxs_keep]
            ##### spp #####
            img_mask1 = feat_mask.permute(1, 2, 0)
            pos_masks00 = torch.sigmoid(img_mask1 @ det_cofs[:, 0:32].t())
            pos_masks01 = torch.sigmoid(img_mask1 @ det_cofs[:, 32:64].t())
            pos_masks10 = torch.sigmoid(img_mask1 @ det_cofs[:, 64:96].t())
            pos_masks11 = torch.sigmoid(img_mask1 @ det_cofs[:, 96:128].t())
            pos_masks = torch.stack(
                [pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
            pos_masks = self.crop_cuda(
                pos_masks,
                det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor) / scale)
            # pos_masks = crop_split(pos_masks00, pos_masks01, pos_masks10, pos_masks11,
            #                        det_bboxes * det_bboxes.new_tensor(scale_factor) / scale)

            pos_masks = pos_masks.permute(2, 0, 1)
            # masks = F.interpolate(pos_masks.unsqueeze(0), scale_factor=scale/scale_factor, mode='bilinear', align_corners=False).squeeze(0)
            if self.ssd_flag:
                masks = F.interpolate(pos_masks.unsqueeze(0),
                                      scale_factor=scale / scale_factor[3:1:-1],
                                      mode='bilinear',
                                      align_corners=False).squeeze(0)
            else:
                masks = F.interpolate(pos_masks.unsqueeze(0),
                                      scale_factor=scale / scale_factor,
                                      mode='bilinear',
                                      align_corners=False).squeeze(0)
            masks.gt_(0.4)

            if self.rescoring_flag:
                pred_iou = pos_masks.unsqueeze(1)
                pred_iou = self.convs_scoring(pred_iou)
                pred_iou = self.relu(self.mask_scoring(pred_iou))
                pred_iou = F.max_pool2d(pred_iou, kernel_size=pred_iou.size()[2:]).squeeze(-1).squeeze(-1)
                pred_iou = pred_iou[range(pred_iou.size(0)), det_labels].squeeze()
                mask_scores = pred_iou * det_bboxes[:, -1]
                mask_scores = mask_scores.cpu().numpy()
                mask_scores = [mask_scores[det_labels.cpu().numpy() == i] for i in range(self.num_classes - 1)]

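        # paste each binary mask onto the original image canvas and RLE-encode it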
        for i in range(det_bboxes.shape[0]):
            label = det_labels[i]
            mask = masks[i].cpu().numpy()
            im_mask = np.zeros((ori_shape[0], ori_shape[1]), dtype=np.uint8)
            shape = np.minimum(mask.shape, ori_shape[0:2])
            im_mask[:shape[0], :shape[1]] = mask[:shape[0], :shape[1]]
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label].append(rle)

        if self.rescoring_flag:
            return det_bboxes, det_labels, (cls_segms, mask_scores)
        else:
            return det_bboxes, det_labels, cls_segms
Example #30
0
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             mocs,
             gt_bboxes,
             gt_labels,
             img_metas,
             imgs,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): Centerness scores for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 1.
            mocs (list[Tensor]): Coefficients for each scale level, each is
                a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            imgs (list[Tensor]): Input images (used when computing the MOC
                overlap).
            gt_bboxes_ignore (None | list[Tensor]): Specifies which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_moc = [moc.permute(0, 2, 3, 1).reshape(-1) for moc in mocs]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        flatten_mocs = torch.cat(flatten_moc)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 flatten_labels,
                                 avg_factor=num_pos +
                                 num_imgs)  # avoid avg_factor being 0 when num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        num_points = [center.size(0) for center in all_level_points]
        # coefficient branch: regroup the flattened per-level tensors into per-image lists
        moc_result_list = self.convertlevel2img(flatten_bbox_targets,
                                                flatten_labels,
                                                flatten_bbox_preds,
                                                flatten_points, flatten_mocs,
                                                num_points, num_imgs)
        flatten_bbox_targets_reshape = moc_result_list[0]
        flatten_labels_targets_reshape = moc_result_list[1]
        flatten_bbox_preds_reshape = moc_result_list[2]
        flatten_points_reshape = moc_result_list[3]
        flatten_conv_reshape = moc_result_list[4]

        pos_moc = flatten_mocs[pos_inds]
        if num_pos > 0:

            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)

            assert len(flatten_bbox_targets_reshape) == len(
                flatten_labels_targets_reshape) == len(
                    flatten_bbox_preds_reshape)
            bbox_preds_moc, bbox_targets_moc, conv_moc = self.compute_bbox_per_image(
                flatten_bbox_targets_reshape, flatten_labels_targets_reshape,
                flatten_bbox_preds_reshape, flatten_points_reshape,
                flatten_conv_reshape, bg_class_ind)
            moc_result, conv_mocs, loss_conv_moc_for_clcs = self.moc_overlap(
                bbox_preds_moc, bbox_targets_moc, conv_moc, imgs)
            loss_moc = self.loss_moc(conv_mocs.to(pos_centerness.device),
                                     moc_result.to(pos_centerness.device))
            # centerness weighted iou loss
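            # fall back to centerness weighting when moc_result has no positive
            # entries, so the avg_factor of the IoU loss cannot be zero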
            if moc_result is not None and not torch.any(moc_result > 0.):
                loss_bbox = self.loss_bbox(
                    pos_decoded_bbox_preds,
                    pos_decoded_target_preds,
                    weight=pos_centerness_targets,
                    avg_factor=pos_centerness_targets.sum())

            else:
                loss_bbox = self.loss_bbox(
                    pos_decoded_bbox_preds,
                    pos_decoded_target_preds,
                    weight=moc_result.to(pos_centerness_targets.device),
                    avg_factor=moc_result.sum().to(
                        pos_centerness_targets.device))

            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
            loss_moc = pos_moc.sum()

        return dict(loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_moc=loss_moc,
                    loss_centerness=loss_centerness)
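
Both examples lean on mmdet's `distance2bbox` to turn per-point (left, top, right, bottom) distances into absolute boxes. For reference, here is a minimal sketch of that decoding, assuming the usual mmdet convention; `distance2bbox_sketch` is an illustrative re-implementation, not the library function itself:

import torch


def distance2bbox_sketch(points, distances, max_shape=None):
    """Decode (l, t, r, b) distances from anchor points into x1y1x2y2 boxes.

    Illustrative sketch; mmdet's own ``distance2bbox`` is the
    authoritative version.
    """
    # points: (N, 2) point coordinates; distances: (N, 4) as (l, t, r, b)
    x1 = points[:, 0] - distances[:, 0]
    y1 = points[:, 1] - distances[:, 1]
    x2 = points[:, 0] + distances[:, 2]
    y2 = points[:, 1] + distances[:, 3]
    if max_shape is not None:
        # optionally clamp boxes to the image; max_shape is (h, w)
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], dim=-1)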