示例#1
0
    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg):
        """Compute classification and regression losses for every level.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes, one entry per image.
            gt_labels (list): Ground-truth labels, one entry per image.
            img_metas (list): Per-image meta information.
            cfg: Training config forwarded to ``anchor_target`` and
                ``loss_single``.

        Returns:
            dict | None: ``loss_cls`` / ``loss_reg`` as produced by
            ``multi_apply`` over the per-level inputs, or ``None`` when
            ``anchor_target`` returns ``None``.
        """
        sizes = [score_map.size()[-2:] for score_map in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, valid_flags = self.get_anchors(sizes, img_metas)
        targets = anchor_target(
            anchors, valid_flags, gt_bboxes, img_metas,
            self.target_means, self.target_stds, cfg,
            gt_labels_list=gt_labels,
            cls_out_channels=self.cls_out_channels,
            sampling=False)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, _num_neg) = targets

        # Only the positive count is used as the normaliser here.
        cls_losses, reg_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, num_pos_samples=num_pos, cfg=cfg)
        return dict(loss_cls=cls_losses, loss_reg=reg_losses)
    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, gt_bboxes_8_coo, img_shapes, coo_num, cfg):
        """Compute RPN classification and regression losses.

        Args:
            rpn_cls_scores (list): Per-level RPN classification maps; the last
                two dims of each give that level's feature-map size.
            rpn_bbox_preds (list): Per-level RPN regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_bboxes_8_coo: Extra ground-truth box representation passed to
                ``anchor_target`` (presumably 8-coordinate quadrilaterals —
                confirm against ``anchor_target``).
            img_shapes (list): Per-image shapes used for anchor generation.
            coo_num: Coordinate count forwarded to ``anchor_target`` and
                ``loss_single``.
            cfg: Training config.

        Returns:
            dict | None: ``loss_rpn_cls`` / ``loss_rpn_reg`` from
            ``multi_apply``, or ``None`` when ``anchor_target`` returns
            ``None``.
        """
        sizes = [score_map.size()[-2:] for score_map in rpn_cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, valid_flags = self.get_anchors(sizes, img_shapes)

        targets = anchor_target(anchors, valid_flags, gt_bboxes,
                                gt_bboxes_8_coo, img_shapes,
                                self.target_means, self.target_stds,
                                coo_num, cfg)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets

        # Losses are normalised by the total (positive + negative) count.
        cls_losses, reg_losses = multi_apply(
            self.loss_single, rpn_cls_scores, rpn_bbox_preds, labels,
            label_weights, bbox_targets, bbox_weights,
            num_total_samples=num_pos + num_neg, coo_num=coo_num, cfg=cfg)
        return dict(loss_rpn_cls=cls_losses, loss_rpn_reg=reg_losses)
示例#3
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification and box losses image by image.

        Predictions and targets from all levels are concatenated per image,
        so ``loss_single`` is mapped over images rather than over levels.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information; its length is the
                number of images in the batch.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        device = cls_scores[0].device
        anchors, valid_flags = self.get_anchors(sizes, img_metas,
                                                device=device)
        targets = anchor_target(
            anchors,
            valid_flags,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, _num_neg) = targets

        n_img = len(img_metas)

        def per_image(maps, width):
            # Move dim 1 last (assumes NCHW maps) and join all levels into a
            # single (n_img, -1, width) tensor.
            return torch.cat(
                [m.permute(0, 2, 3, 1).reshape(n_img, -1, width)
                 for m in maps], 1)

        flat_cls = per_image(cls_scores, self.cls_out_channels)
        flat_bbox = per_image(bbox_preds, 4)
        flat_labels = torch.cat(labels, -1).view(n_img, -1)
        flat_label_weights = torch.cat(label_weights, -1).view(n_img, -1)
        flat_bbox_targets = torch.cat(bbox_targets, -2).view(n_img, -1, 4)
        flat_bbox_weights = torch.cat(bbox_weights, -2).view(n_img, -1, 4)

        cls_losses, bbox_losses = multi_apply(
            self.loss_single, flat_cls, flat_bbox, flat_labels,
            flat_label_weights, flat_bbox_targets, flat_bbox_weights,
            num_total_samples=num_pos, cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#4
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and box losses, passing anchors.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(sizes, img_metas,
                                                        device=device)

        # Anchor count of each level, read from the first image.
        level_counts = [lvl_anchors.size(0) for lvl_anchors in anchor_list[0]]
        # One flat anchor tensor per image, regrouped into per-level tensors
        # so they line up with the per-level predictions below.
        flat_per_image = [torch.cat(img_anchors) for img_anchors in anchor_list]
        anchors_by_level = images_to_levels(flat_per_image, level_counts)

        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        if self.sampling:
            avg_factor = num_pos + num_neg
        else:
            avg_factor = num_pos

        cls_losses, bbox_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, anchors_by_level,
            labels, label_weights, bbox_targets, bbox_weights,
            num_total_samples=avg_factor, cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#5
0
文件: anchor_head.py 项目: zyg11/TSD
    def loss(
        self,
        cls_scores,
        bbox_preds,
        gt_bboxes,
        gt_labels,
        img_metas,
        cfg,
        gt_bboxes_ignore=None,
    ):
        """Compute per-level classification and box losses.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, valid_flags = self.get_anchors(
            sizes, img_metas, device=cls_scores[0].device)
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(
            anchors,
            valid_flags,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling,
        )
        if targets is None:
            return None
        labels, label_weights, bbox_targets, bbox_weights, num_pos, num_neg = (
            targets)

        # With sampling, negatives also count towards the normaliser.
        if self.sampling:
            avg_factor = num_pos + num_neg
        else:
            avg_factor = num_pos
        cls_losses, bbox_losses = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            num_total_samples=avg_factor,
            cfg=cfg,
        )
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#6
0
    def loss(self,
             features,
             cls_scores,
             bbox_preds,
             teacher_features,
             teacher_cls_scores,
             teacher_bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute student detection losses plus teacher-guided terms.

        Args:
            features (list): Student feature maps for ``attention_loss``.
            cls_scores (list): Student per-level classification maps; the
                last two dims of each give that level's feature-map size.
            bbox_preds (list): Student per-level box regression maps.
            teacher_features (list): Teacher feature maps paired with
                ``features``.
            teacher_cls_scores (list): Teacher per-level classification maps.
            teacher_bbox_preds (list): Teacher per-level regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config; ``cfg.teacher.beta`` is passed to
                ``attention_loss``.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls``, ``loss_kd_cls``, ``loss_reg`` and
            ``loss_at``, or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, valid_flags = self.get_anchors(sizes, img_metas)
        # Sampling is enabled exactly when focal loss is not used.
        use_sampling = not self.use_focal_loss
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(anchors, valid_flags, gt_bboxes, img_metas,
                                self.target_means, self.target_stds, cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=use_sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        if self.use_focal_loss:
            avg_factor = num_pos
        else:
            avg_factor = num_pos + num_neg

        cls_losses, kd_cls_losses, reg_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, teacher_cls_scores,
            teacher_bbox_preds, labels, label_weights, bbox_targets,
            bbox_weights, num_total_samples=avg_factor, cfg=cfg)
        (attention_losses,) = multi_apply(
            attention_loss, features, teacher_features,
            beta=cfg.teacher.beta)
        return dict(loss_cls=cls_losses,
                    loss_kd_cls=kd_cls_losses,
                    loss_reg=reg_losses,
                    loss_at=attention_losses)
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None,
             iteration=None):
        """Compute per-level losses for a head on 3D feature maps.

        Args:
            cls_scores (list): Per-level classification maps; the last three
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image; also forwarded
                to ``loss_single``.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.
            iteration (optional): Training iteration forwarded to
                ``anchor_target`` and ``loss_single``.

        Returns:
            dict | None: ``loss_cls`` / ``loss_reg`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``. As a side
            effect ``self.pos_indices`` is set to the returned inside flags.
        """
        # 3D images: the last three dims of each map are its size.
        sizes = [score.size()[-3:] for score in cls_scores]

        assert len(sizes) == len(self.anchor_generators)
        anchors, valid_flags = self.get_anchors(sizes, img_metas)
        use_sampling = not self.use_focal_loss
        label_channels = (self.cls_out_channels
                          if self.use_sigmoid_cls else 1)
        targets = anchor_target(
            anchors,
            valid_flags,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=use_sampling,
            iteration=iteration)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg, level_anchors, inside_flags) = targets
        # Stored on the instance; presumably read elsewhere — not visible here.
        self.pos_indices = inside_flags
        avg_factor = num_pos if self.use_focal_loss else num_pos + num_neg
        # One mutable dict shared by every loss_single call; presumably lets
        # it track the current level across calls — confirm in loss_single.
        level_state = {'level': 0}
        cls_losses, reg_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, level_anchors,
            num_total_samples=avg_factor, cfg=cfg, level=level_state,
            gt_bboxes=gt_bboxes, iteration=iteration)
        return dict(loss_cls=cls_losses, loss_reg=reg_losses)
示例#8
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and box losses.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        device = cls_scores[0].device
        anchors, valid_flags = self.get_anchors(sizes, img_metas,
                                                device=device)
        label_channels = (self.cls_out_channels
                          if self.use_sigmoid_cls else 1)
        targets = anchor_target(anchors, valid_flags, gt_bboxes, img_metas,
                                self.target_means, self.target_stds, cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=self.sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        if self.sampling:
            avg_factor = num_pos + num_neg
        else:
            avg_factor = num_pos

        cls_losses, bbox_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, num_total_samples=avg_factor,
            cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#9
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level losses, optionally updating ``ignore_topk`` first.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information; also forwarded
                positionally to ``loss_single``.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(sizes, img_metas)
        del sizes
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(anchor_list, valid_flag_list, gt_bboxes,
                                img_metas, self.target_means,
                                self.target_stds, cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=self.sampling)
        # Drop local references early; presumably to reduce peak memory
        # before the loss computation — confirm intent.
        del (anchor_list, valid_flag_list, gt_bboxes, gt_bboxes_ignore,
             gt_labels, label_channels)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        avg_factor = num_pos + num_neg if self.sampling else num_pos

        # ``topk_func`` names a method on this head that schedules
        # ``ignore_topk`` from (initial value, current iter, max iters).
        schedule_name = self.topk_func
        if schedule_name is not None and hasattr(self, schedule_name):
            schedule = getattr(self, schedule_name)
            self.ignore_topk = schedule(self.init_ignore_topk, self.iter,
                                        self.max_iters)
        cls_losses, bbox_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, img_metas,
            num_total_samples=avg_factor, cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#10
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and regression losses.

        Targets are built by ``anchor_target`` from the ground truth and the
        anchors; the predictions are compared directly against them in
        ``loss_single``.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_reg`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, valid_flags = self.get_anchors(sizes, img_metas)
        use_sampling = not self.use_focal_loss
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(anchors, valid_flags, gt_bboxes, img_metas,
                                self.target_means, self.target_stds, cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=use_sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        if self.use_focal_loss:
            avg_factor = num_pos
        else:
            avg_factor = num_pos + num_neg

        cls_losses, reg_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, num_total_samples=avg_factor,
            cfg=cfg)
        return dict(loss_cls=cls_losses, loss_reg=reg_losses)
示例#11
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and box losses.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict: ``'losses/loss_cls'`` and ``'losses/loss_bbox'``, each the
            per-level output of ``multi_apply(self.loss_single_level, ...)``.
            When ``anchor_target`` returns ``None`` both entries are ``0``,
            so callers always see the same key set.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes,
                                                        img_metas,
                                                        device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(anchor_list,
                                        valid_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=self.sampling)
        if cls_reg_targets is None:
            # Use the same keys as the normal return below; the previous
            # 'loss_cls'/'loss_reg' keys did not match it and split logs.
            return {'losses/loss_cls': 0, 'losses/loss_bbox': 0}
        # Each tensor in these lists corresponds to one feature level.
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox = multi_apply(
            self.loss_single_level,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)
        # Each entry of losses_cls / losses_bbox corresponds to one level.
        return {'losses/loss_cls': losses_cls, 'losses/loss_bbox': losses_bbox}
示例#12
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification and box losses over all prediction levels.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls`` / ``loss_bbox`` from ``multi_apply``,
            or ``None`` when ``anchor_target`` returns ``None``.
        """
        # Size of every prediction level.
        sizes = [score.size()[-2:] for score in cls_scores]
        # One anchor generator per prediction level, so levels pair up 1:1.
        assert len(sizes) == len(self.anchor_generators)

        # All anchors plus flags marking which anchors are valid.
        anchors, valid_flags = self.get_anchors(
            sizes, img_metas, device=cls_scores[0].device)
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        # Encode the anchors against the ground truth to get the targets.
        targets = anchor_target(anchors, valid_flags, gt_bboxes, img_metas,
                                self.target_means, self.target_stds, cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=self.sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets
        # Normaliser: all sampled anchors with sampling, positives otherwise.
        avg_factor = num_pos + num_neg if self.sampling else num_pos
        # loss_single produces both losses for each level.
        cls_losses, bbox_losses = multi_apply(
            self.loss_single, cls_scores, bbox_preds, labels, label_weights,
            bbox_targets, bbox_weights, num_total_samples=avg_factor,
            cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#13
0
    def embed_loss(self,
                   embed_feats_list,
                   gt_bboxes,
                   gt_trackids,
                   img_metas,
                   cfg,
                   gt_bboxes_ignore=None):
        """Compute the per-level triplet embedding loss using track ids.

        ``anchor_target`` is run with ``cfg.track_assigner`` temporarily
        installed as ``cfg.assigner``, and with track ids in place of class
        labels. The original assigner is always restored, even if target
        computation raises, because ``cfg`` is shared with other losses.

        Args:
            embed_feats_list (list): Per-level embedding maps; the last two
                dims of each give that level's feature-map size.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_trackids (list): Track ids per image, used as labels.
            img_metas (list): Per-image meta information.
            cfg: Training config; must provide ``assigner`` and
                ``track_assigner``.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_triplet`` with one entry per level, or
            ``None`` when ``anchor_target`` returns ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in embed_feats_list]
        assert len(featmap_sizes) == len(self.anchor_generators)

        device = embed_feats_list[0].device

        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes,
                                                        img_metas,
                                                        device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cache_assigner = cfg.assigner
        cfg.assigner = cfg.track_assigner
        try:
            embed_targets = anchor_target(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                self.target_means,
                self.target_stds,
                cfg,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                gt_labels_list=gt_trackids,
                label_channels=label_channels,
                sampling=self.sampling)
        finally:
            # Restore the shared config no matter how anchor_target exits.
            cfg.assigner = cache_assigner
        if embed_targets is None:
            return None
        track_targets = embed_targets[0]
        losses_triplet = list(
            map(self.embed_loss_single, embed_feats_list, track_targets))
        return dict(loss_triplet=losses_triplet)
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute guided-anchoring losses: cls, bbox, anchor shape and loc.

        Args:
            cls_scores (list): Per-level classification maps; the last two
                dims of each give that level's feature-map size.
            bbox_preds (list): Per-level box regression maps.
            shape_preds (list): Per-level anchor-shape predictions.
            loc_preds (list): Per-level anchor-location predictions.
            gt_bboxes (list): Ground-truth boxes per image.
            gt_labels (list): Ground-truth labels per image.
            img_metas (list): Per-image meta information.
            cfg: Training config; ``center_ratio`` and ``ignore_ratio`` are
                read, and the presence of ``ga_sampler`` enables sampling
                for the shape targets.
            gt_bboxes_ignore (list, optional): Boxes to ignore in assignment.

        Returns:
            dict | None: ``loss_cls``, ``loss_bbox``, ``loss_shape`` and
            ``loss_loc``, or ``None`` when either ``ga_shape_target`` or
            ``anchor_target`` returns ``None``.

        NOTE(review): ``num_total_pos_cls``, ``num_total_neg_cls``,
        ``bbox_targets_cls_list`` and ``bbox_weights_cls_list`` are never
        assigned in this method. Unless they exist at module scope (not
        visible here), execution raises ``NameError`` once both target
        computations succeed. They presumably should come from an extended
        ``anchor_target`` return value — confirm against the modified
        ``anchor_target`` / ``loss_single`` before relying on this method.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg)
        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas)

        # get shape targets
        # Sampling for shape targets is enabled only if cfg has 'ga_sampler'.
        sampling = False if not hasattr(cfg, 'ga_sampler') else True
        shape_targets = ga_shape_target(approxs_list,
                                        inside_flag_list,
                                        squares_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.approxs_per_octave,
                                        cfg,
                                        sampling=sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (anchor_fg_num if not sampling else anchor_fg_num +
                            anchor_bg_num)

        # get anchor targets
        # NOTE: `sampling` is reassigned here with a different meaning
        # (anchor-target sampling, disabled when cls_focal_loss is set).
        sampling = False if self.cls_focal_loss else True
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(guided_anchors_list,
                                        inside_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=sampling)
        if cls_reg_targets is None:
            return None
        # (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
        #  num_total_pos, num_total_neg) = cls_reg_targets

        # Expects a 7-tuple whose last item is per-level anchors; presumably
        # anchor_target was modified to return it — TODO confirm.
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, level_anchor_list) = cls_reg_targets

        num_total_samples = (num_total_pos if self.cls_focal_loss else
                             num_total_pos + num_total_neg)
        # added by WSK
        # NOTE(review): num_total_pos_cls / num_total_neg_cls are not defined
        # in this method; this line raises NameError unless they exist at
        # module scope.
        num_total_samples_cls = (num_total_pos_cls if self.cls_focal_loss else
                                 num_total_pos_cls + num_total_neg_cls)

        # get classification and bbox regression losses
        # NOTE(review): bbox_targets_cls_list / bbox_weights_cls_list are also
        # undefined in this method.
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            level_anchor_list,  # added by Shengkai Wu
            bbox_targets_cls_list,  # added by Shengkai Wu
            bbox_weights_cls_list,  # added by Shengkai Wu
            num_total_samples_cls=num_total_samples_cls,  # added by Shengkai Wu
            num_total_samples=num_total_samples,
            gt_bboxes=gt_bboxes,  # added by Shengkai Wu
            cfg=cfg)

        # get anchor location loss
        losses_loc = []
        for i in range(len(loc_preds)):
            loss_loc = self.loss_loc_single(loc_preds[i],
                                            loc_targets[i],
                                            loc_weights[i],
                                            loc_avg_factor=loc_avg_factor,
                                            cfg=cfg)
            losses_loc.append(loss_loc)

        # get anchor shape loss
        losses_shape = []
        for i in range(len(shape_preds)):
            loss_shape = self.loss_shape_single(
                shape_preds[i],
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                anchor_total_num=anchor_total_num)
            losses_shape.append(loss_shape)

        return dict(loss_cls=losses_cls,
                    loss_bbox=losses_bbox,
                    loss_shape=losses_shape,
                    loss_loc=losses_loc)
示例#15
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute anchor-head losses, averaging the cls loss over two passes.

        Pass 1 assigns targets to the prior anchors. Pass 2 decodes the
        predicted deltas onto the same prior anchors, assigns targets to the
        decoded boxes and recomputes the classification loss. The returned
        per-level ``loss_cls`` is the mean of both passes; ``loss_bbox`` comes
        from pass 1 only (the pass-2 regression loss is discarded).

        Args:
            cls_scores (list[Tensor]): classification outputs, one per level.
            bbox_preds (list[Tensor]): box-delta outputs, one per level;
                ``bbox_preds[lvl][img]`` is decoded per image in pass 2.
            gt_bboxes, gt_labels, gt_bboxes_ignore: forwarded to
                ``anchor_target``.
            img_metas (list[dict]): per-image meta information.
            cfg: training config. ``cfg['sampler']['pos_fraction']`` is
                overwritten in place (0.5 for pass 1, 0.55 for pass 2) and
                is left at 0.55 after a complete run.

        Returns:
            dict | None: ``loss_cls`` and ``loss_bbox`` (one entry per
            level), or ``None`` if either ``anchor_target`` call returns
            ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes,
                                                        img_metas,
                                                        device=device)
        # Shallow copies reused by pass 2. NOTE(review): presumably needed
        # because anchor_target replaces the per-image entries of the lists
        # it receives - confirm before removing.
        anchor_list0 = anchor_list.copy()
        valid_flag_list0 = valid_flag_list.copy()
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # ---- Pass 1: targets for the prior anchors ----
        # NOTE: mutates the caller's config in place.
        cfg['sampler']['pos_fraction'] = 0.5
        cls_reg_targets = anchor_target(anchor_list,
                                        valid_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # ---- Pass 2: targets for the boxes decoded from the predictions ----
        # anchor_list0 is indexed [img][lvl]; bbox_preds is indexed [lvl][img].
        num_imgs = len(anchor_list0)
        decoded_bboxes = [[] for _ in range(num_imgs)]
        for img_id in range(num_imgs):
            for lvl_id in range(len(anchor_list0[img_id])):
                # Flatten to one 4-delta row per anchor, channels last.
                bbox_pred = bbox_preds[lvl_id][img_id].permute(
                    1, 2, 0).reshape(-1, 4)
                bboxes = delta2bbox(anchor_list0[img_id][lvl_id], bbox_pred,
                                    self.target_means, self.target_stds)
                decoded_bboxes[img_id].append(bboxes.clone())

        cfg['sampler']['pos_fraction'] = 0.55
        cls_reg_targets = anchor_target(
            decoded_bboxes,
            valid_flag_list0,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos + num_total_neg
                             if self.sampling else num_total_pos)
        losses_cls2, _ = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # Average the two passes level by level. The per-level loss lists
        # have one entry per pyramid level, not per image, so pair them
        # directly instead of iterating over the image count.
        losses_cls = [
            cls1 / 2 + cls2 / 2 for cls1, cls2 in zip(losses_cls, losses_cls2)
        ]
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
示例#16
0
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Guided-anchoring losses with a second classification-target pass.

        Computes the anchor location loss, the anchor shape loss, and cls/bbox
        losses on the guided anchors. A second pass regenerates the guided
        anchors, reassigns classification targets with a different sampler
        ``pos_fraction`` and recomputes the cls loss. The returned per-level
        ``loss_cls`` is the mean of both passes; ``loss_bbox``, ``loss_shape``
        and ``loss_loc`` come from the first pass only.

        Args:
            cls_scores, bbox_preds, shape_preds, loc_preds (list[Tensor]):
                head outputs, one per level; ``bbox_preds[lvl][img]`` is
                decoded per image in the second pass.
            gt_bboxes, gt_labels, gt_bboxes_ignore: forwarded to the target
                builders.
            img_metas (list[dict]): per-image meta information.
            cfg: training config. ``cfg['sampler']['pos_fraction']`` (and
                ``cfg['ga_sampler']['pos_fraction']`` when a ``ga_sampler``
                is configured) are overwritten in place and stay overwritten
                after the call.

        Returns:
            dict | None: ``loss_cls``, ``loss_bbox``, ``loss_shape`` and
            ``loss_loc`` (one entry per level), or ``None`` if any target
            builder returns ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        device = cls_scores[0].device
        # The shape branch samples only when a dedicated sampler exists.
        shape_sampling = hasattr(cfg, 'ga_sampler')
        # Classification targets are sampled unless focal loss is used.
        sampling = False if self.cls_focal_loss else True
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # get sampled approxes, squares and guided anchors
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg, device=device)
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)

        # get shape targets. NOTE: this runs before ga_sampler.pos_fraction is
        # overwritten below, so on repeated calls it sees the value left by
        # the previous call's second pass.
        shape_targets = ga_shape_target(
            approxs_list,
            inside_flag_list,
            squares_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=shape_sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (anchor_fg_num if not shape_sampling else
                            anchor_fg_num + anchor_bg_num)

        # ---- First classification/regression pass ----
        # Only touch ga_sampler when it exists; the code above treats it as
        # optional.
        if shape_sampling:
            cfg['ga_sampler']['pos_fraction'] = 0.5
        cfg['sampler']['pos_fraction'] = 0.5
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos if self.cls_focal_loss else
                             num_total_pos + num_total_neg)
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # get anchor location loss
        losses_loc = [
            self.loss_loc_single(loc_preds[i],
                                 loc_targets[i],
                                 loc_weights[i],
                                 loc_avg_factor=loc_avg_factor,
                                 cfg=cfg) for i in range(len(loc_preds))
        ]

        # get anchor shape loss
        losses_shape = [
            self.loss_shape_single(shape_preds[i],
                                   bbox_anchors_list[i],
                                   bbox_gts_list[i],
                                   anchor_weights_list[i],
                                   anchor_total_num=anchor_total_num)
            for i in range(len(shape_preds))
        ]

        # ---- Second classification pass ----
        # Regenerate approxes and guided anchors instead of reusing the lists
        # already passed to anchor_target above.
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg, device=device)
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
        # NOTE(review): the recomputed shape targets no longer feed any loss;
        # only the None check below has an effect.
        shape_targets = ga_shape_target(
            approxs_list,
            inside_flag_list,
            squares_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=shape_sampling)
        if shape_targets is None:
            return None

        # Decode the predicted deltas onto the guided anchors.
        # guided_anchors_list is indexed [img][lvl]; bbox_preds is [lvl][img].
        # NOTE(review): decoded_bboxes is never used - the anchor_target call
        # below still runs on guided_anchors_list. Confirm which was intended.
        num_imgs = len(guided_anchors_list)
        decoded_bboxes = [[] for _ in range(num_imgs)]
        for img_id in range(num_imgs):
            for lvl_id in range(len(guided_anchors_list[img_id])):
                bbox_pred = bbox_preds[lvl_id][img_id].permute(
                    1, 2, 0).reshape(-1, 4)
                bboxes = delta2bbox(guided_anchors_list[img_id][lvl_id],
                                    bbox_pred, self.target_means,
                                    self.target_stds)
                decoded_bboxes[img_id].append(bboxes.clone())

        if shape_sampling:
            cfg['ga_sampler']['pos_fraction'] = 0.55
        cfg['sampler']['pos_fraction'] = 0.6
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos if self.cls_focal_loss else
                             num_total_pos + num_total_neg)
        losses_cls2, _ = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # Average the two passes per pyramid level. The loss lists are
        # per-level, so pair them directly rather than iterating over images.
        losses_cls = [
            cls1 / 2 + cls2 / 2 for cls1, cls2 in zip(losses_cls, losses_cls2)
        ]
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_shape=losses_shape,
            loss_loc=losses_loc)
示例#17
0
    def loss(self,
             fam_cls_scores,
             fam_bbox_preds,
             refine_anchors,
             odm_cls_scores,
             odm_bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute Feature Alignment Module (FAM) and Oriented Detection Module (ODM) losses.

        FAM targets are built on the initial anchors using ``cfg.fam_cfg``.
        ODM targets are built on the refined anchors using ``cfg.odm_cfg``.
        The loss inputs are also kept on ``self.last_vals`` for debugging. A
        message is printed when any summed loss is huge or non-finite.

        Args:
            fam_cls_scores, fam_bbox_preds (list[Tensor]): FAM head outputs,
                one per level.
            refine_anchors: refined anchors passed to ``get_refine_anchors``.
            odm_cls_scores, odm_bbox_preds (list[Tensor]): ODM head outputs,
                one per level.
            gt_bboxes (list[Tensor]): ground-truth boxes per image. Modified
                in place: zeros in columns 2:4 are replaced with 1.
            gt_labels, gt_bboxes_ignore: forwarded to ``anchor_target``.
            img_metas (list[dict]): per-image meta information.
            cfg: config providing ``fam_cfg`` and ``odm_cfg``.

        Returns:
            dict | None: ``loss_fam_cls``, ``loss_fam_bbox``, ``loss_odm_cls``
            and ``loss_odm_bbox``, or ``None`` if ``anchor_target`` returns
            ``None`` for either stage.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in odm_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        # Guard against degenerate boxes: zeros in columns 2:4 become 1.
        # NOTE(review): presumably these columns are width/height; the
        # caller's tensors are modified in place.
        for img_bboxes in gt_bboxes:
            zero_inds = img_bboxes[:, 2:4] == 0
            img_bboxes[:, 2:4][zero_inds] = 1

        device = odm_cls_scores[0].device
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        def anchors_by_level(per_img_anchors):
            """Regroup ``anchors[img][lvl]`` into one entry per level."""
            num_level_anchors = [
                anchors.size(0) for anchors in per_img_anchors[0]
            ]
            concat_anchor_list = [
                torch.cat(img_anchors) for img_anchors in per_img_anchors
            ]
            return images_to_levels(concat_anchor_list, num_level_anchors)

        # ---- Feature Alignment Module ----
        anchor_list, valid_flag_list = self.get_init_anchors(featmap_sizes,
                                                             img_metas,
                                                             device=device)
        # Regroup before anchor_target receives the per-image lists.
        all_anchor_list = anchors_by_level(anchor_list)
        cls_reg_targets = anchor_target(anchor_list,
                                        valid_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg.fam_cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_fam_cls, losses_fam_bbox = multi_apply(
            self.loss_fam_single,
            fam_cls_scores,
            fam_bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg.fam_cfg)

        # ---- Oriented Detection Module ----
        refine_anchors_list, valid_flag_list = self.get_refine_anchors(
            featmap_sizes, refine_anchors, img_metas, device=device)
        all_anchor_list = anchors_by_level(refine_anchors_list)
        cls_reg_targets = anchor_target(refine_anchors_list,
                                        valid_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg.odm_cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_odm_cls, losses_odm_bbox = multi_apply(
            self.loss_odm_single,
            odm_cls_scores,
            odm_bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg.odm_cfg)

        # Debug snapshot of this call's inputs.
        self.last_vals = dict(
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas,
            fam_cls_scores=fam_cls_scores,
            fam_bbox_preds=fam_bbox_preds,
            refine_anchors=refine_anchors,
            odm_cls_scores=odm_cls_scores,
            odm_bbox_preds=odm_bbox_preds,
        )
        # Report exploding or non-finite losses. A plain `> 1E10` test misses
        # NaN (every comparison with NaN is False) and -inf.
        for losses in (losses_fam_cls, losses_fam_bbox, losses_odm_cls,
                       losses_odm_bbox):
            total = torch.as_tensor(sum(losses))
            if bool(total > 1E10) or not bool(torch.isfinite(total).all()):
                print("bad loss")
                break
        return dict(loss_fam_cls=losses_fam_cls,
                    loss_fam_bbox=losses_fam_bbox,
                    loss_odm_cls=losses_odm_cls,
                    loss_odm_bbox=losses_odm_bbox)
示例#18
0
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the guided-anchoring head losses.

        Builds location targets, shape targets (on the approxes) and
        classification/regression targets (on the guided anchors). It then
        evaluates the four per-level losses with ``multi_apply``.

        Returns:
            dict | None: ``loss_cls``, ``loss_reg``, ``loss_shape`` and
            ``loss_loc``, or ``None`` if the shape or anchor target builder
            returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.approx_generators)

        # Targets for the anchor-location branch.
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)
        approxs, flags, base_approxs, guided_anchors = self.get_anchors(
            sizes, shape_preds, img_metas)

        # Shape targets; sample only when a dedicated ga_sampler exists.
        use_shape_sampler = hasattr(cfg, 'ga_sampler')
        shape_targets = ga_shape_target(approxs,
                                        flags,
                                        base_approxs,
                                        gt_bboxes,
                                        img_metas,
                                        self.approxs_per_octave,
                                        cfg,
                                        sampling=use_shape_sampler)
        if shape_targets is None:
            return None
        (shape_anchors, shape_gts, shape_weights, inside_flags, fg_num,
         bg_num) = shape_targets
        shape_avg_factor = fg_num + bg_num if use_shape_sampler else fg_num

        # Classification/regression targets on the guided anchors; sampling
        # is skipped for focal loss.
        use_cls_sampler = not self.cls_focal_loss
        num_label_channels = (self.cls_out_channels
                              if self.cls_sigmoid_loss else 1)
        targets = anchor_target(guided_anchors,
                                inside_flags,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=num_label_channels,
                                sampling=use_cls_sampler)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights, num_pos,
         num_neg) = targets
        cls_avg_factor = num_pos + num_neg if use_cls_sampler else num_pos

        cls_losses, reg_losses = multi_apply(self.loss_single,
                                             cls_scores,
                                             bbox_preds,
                                             labels,
                                             label_weights,
                                             bbox_targets,
                                             bbox_weights,
                                             num_total_samples=cls_avg_factor,
                                             cfg=cfg)
        (loc_losses,) = multi_apply(self.loss_loc_single,
                                    loc_preds,
                                    loc_targets,
                                    loc_weights,
                                    loc_avg_factor=loc_avg_factor,
                                    cfg=cfg)
        (shape_losses,) = multi_apply(self.loss_shape_single,
                                      shape_preds,
                                      shape_anchors,
                                      shape_gts,
                                      shape_weights,
                                      anchor_total_num=shape_avg_factor)
        return dict(loss_cls=cls_losses,
                    loss_reg=reg_losses,
                    loss_shape=shape_losses,
                    loss_loc=loc_losses)
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification and regression losses image by image.

        Targets are built for every anchor with no sampling
        (``sampling=False``, ``unmap_outputs=False``, single label channel).
        Predictions and targets from all levels are then flattened into one
        anchor axis per image before ``loss_single`` is applied.

        Returns:
            dict | None: ``loss_cls`` and ``loss_reg``, or ``None`` if
            ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, flags = self.get_anchors(sizes, img_metas)
        targets = anchor_target(
            anchors,
            flags,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights, num_pos,
         _num_neg) = targets

        batch = len(img_metas)

        def per_image(level_outputs, channels):
            # [B, C, H, W] per level -> [B, total_anchors, channels]
            # (total_anchors is e.g. 8732 for SSD300).
            flat = [
                out.permute(0, 2, 3, 1).reshape(batch, -1, channels)
                for out in level_outputs
            ]
            return torch.cat(flat, 1)

        flat_cls_scores = per_image(cls_scores, self.cls_out_channels)
        flat_bbox_preds = per_image(bbox_preds, 4)
        flat_labels = torch.cat(labels, -1).view(batch, -1)
        flat_label_weights = torch.cat(label_weights, -1).view(batch, -1)
        flat_bbox_targets = torch.cat(bbox_targets, -2).view(batch, -1, 4)
        flat_bbox_weights = torch.cat(bbox_weights, -2).view(batch, -1, 4)

        cls_losses, reg_losses = multi_apply(
            self.loss_single,
            flat_cls_scores,
            flat_bbox_preds,
            flat_labels,
            flat_label_weights,
            flat_bbox_targets,
            flat_bbox_weights,
            num_total_samples=num_pos,
            cfg=cfg)
        return dict(loss_cls=cls_losses, loss_reg=reg_losses)
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-image classification and box losses using anchor positions.

        Targets are built for every anchor with no sampling
        (``sampling=False``, ``unmap_outputs=False``, single label channel).
        ``anchor_target`` here is expected to also return the anchors per
        level, and those anchors are passed to ``loss_single``. All levels
        are flattened into one anchor axis per image.

        :param cls_scores: list of tensors, one per pyramid level, each
            ``[num_images, A * cls_out_channels, H_i, W_i]``.
        :param bbox_preds: list of tensors, one per pyramid level, each
            ``[num_images, A * 4, H_i, W_i]``. They are reshaped to 4 values
            per anchor, so other channel layouts are not supported.
        :param gt_bboxes: forwarded to ``anchor_target``.
        :param gt_labels: forwarded to ``anchor_target``.
        :param img_metas: per-image meta information; its length is the batch
            size.
        :param cfg: training config forwarded to ``anchor_target`` and
            ``loss_single``.
        :param gt_bboxes_ignore: forwarded to ``anchor_target``.
        :return: dict with ``loss_cls`` and ``loss_bbox``, or ``None`` if
            ``anchor_target`` returns ``None``.
        """
        sizes = [score.size()[-2:] for score in cls_scores]
        assert len(sizes) == len(self.anchor_generators)

        anchors, flags = self.get_anchors(sizes, img_metas)
        targets = anchor_target(anchors,
                                flags,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=1,
                                sampling=False,
                                unmap_outputs=False)
        if targets is None:
            return None
        # This anchor_target variant also returns the anchors of each level.
        (labels, label_weights, bbox_targets, bbox_weights, num_pos,
         _num_neg, level_anchors) = targets

        batch = len(img_metas)

        def per_image(level_outputs, channels):
            # [B, C, H, W] per level -> [B, total_anchors, channels]
            flat = [
                out.permute(0, 2, 3, 1).reshape(batch, -1, channels)
                for out in level_outputs
            ]
            return torch.cat(flat, 1)

        # Every flattened tensor below has `total_anchors` (all anchors of
        # all levels) along dim 1, not just the positive/negative samples.
        flat_cls_scores = per_image(cls_scores, self.cls_out_channels)
        flat_bbox_preds = per_image(bbox_preds, 4)
        flat_labels = torch.cat(labels, -1).view(batch, -1)
        flat_label_weights = torch.cat(label_weights, -1).view(batch, -1)
        flat_bbox_targets = torch.cat(bbox_targets, -2).view(batch, -1, 4)
        flat_bbox_weights = torch.cat(bbox_weights, -2).view(batch, -1, 4)
        flat_anchors = torch.cat(level_anchors, -2).view(batch, -1, 4)

        cls_losses, bbox_losses = multi_apply(
            self.loss_single,
            flat_cls_scores,
            flat_bbox_preds,
            flat_labels,
            flat_label_weights,
            flat_bbox_targets,
            flat_bbox_weights,
            flat_anchors,
            num_total_samples=num_pos,
            cfg=cfg)
        return dict(loss_cls=cls_losses, loss_bbox=bbox_losses)
示例#21
0
    def loss(self,
             fam_cls_scores,
             fam_bbox_preds,
             refine_anchors,
             odm_cls_scores,
             odm_bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute Feature Alignment Module (FAM) and Oriented Detection Module (ODM) losses.

        Each stage assigns targets with ``anchor_target`` under its own
        sub-config (``cfg.fam_cfg`` / ``cfg.odm_cfg``). It then evaluates its
        per-level loss function with ``multi_apply``. The FAM stage uses the
        initial anchors; the ODM stage uses anchors derived from
        ``refine_anchors``.

        Returns:
            dict | None: ``loss_fam_cls``, ``loss_fam_bbox``, ``loss_odm_cls``
            and ``loss_odm_bbox``, or ``None`` if either stage gets no
            targets.
        """
        sizes = [score.size()[-2:] for score in odm_cls_scores]
        assert len(sizes) == len(self.anchor_generators)
        device = odm_cls_scores[0].device
        num_label_channels = (self.cls_out_channels
                              if self.use_sigmoid_cls else 1)

        def stage_losses(anchors, flags, stage_cfg, loss_fn, cls_out,
                         bbox_out):
            # Assign targets for one stage and evaluate its per-level losses;
            # returns None when anchor_target yields no targets.
            targets = anchor_target(anchors,
                                    flags,
                                    gt_bboxes,
                                    img_metas,
                                    self.target_means,
                                    self.target_stds,
                                    stage_cfg,
                                    gt_bboxes_ignore_list=gt_bboxes_ignore,
                                    gt_labels_list=gt_labels,
                                    label_channels=num_label_channels,
                                    sampling=self.sampling)
            if targets is None:
                return None
            (labels, label_weights, bbox_targets, bbox_weights, num_pos,
             num_neg) = targets
            avg_factor = (num_pos + num_neg) if self.sampling else num_pos
            return multi_apply(loss_fn,
                               cls_out,
                               bbox_out,
                               labels,
                               label_weights,
                               bbox_targets,
                               bbox_weights,
                               num_total_samples=avg_factor,
                               cfg=stage_cfg)

        # Feature Alignment Module on the initial anchors.
        init_anchors, init_flags = self.get_init_anchors(sizes,
                                                         img_metas,
                                                         device=device)
        fam = stage_losses(init_anchors, init_flags, cfg.fam_cfg,
                           self.loss_fam_single, fam_cls_scores,
                           fam_bbox_preds)
        if fam is None:
            return None
        fam_cls_losses, fam_bbox_losses = fam

        # Oriented Detection Module on the refined anchors.
        odm_anchors, odm_flags = self.get_refine_anchors(sizes,
                                                         refine_anchors,
                                                         img_metas,
                                                         device=device)
        odm = stage_losses(odm_anchors, odm_flags, cfg.odm_cfg,
                           self.loss_odm_single, odm_cls_scores,
                           odm_bbox_preds)
        if odm is None:
            return None
        odm_cls_losses, odm_bbox_losses = odm

        return dict(loss_fam_cls=fam_cls_losses,
                    loss_fam_bbox=fam_bbox_losses,
                    loss_odm_cls=odm_cls_losses,
                    loss_odm_bbox=odm_bbox_losses)
示例#22
0
文件: anchor_head.py 项目: qinr/MRDet
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the classification and box-regression losses of the head.

        Pipeline:
            1. ``self.get_anchors`` builds the anchors of every image together
               with a per-anchor validity flag (per the original author, the
               flag marks anchors that do not leave the image -- TODO confirm).
            2. ``anchor_target`` matches anchors to ground-truth boxes
               (assigner), selects positive/negative samples (sampler) and
               encodes the regression deltas between each anchor and its
               matched gt box. It returns, per level, the labels, label
               weights, box targets and box weights, plus the total numbers
               of positive and negative samples.
            3. ``self.loss_single`` is applied to every feature-map level via
               ``multi_apply`` to produce the classification and regression
               losses.

        Args:
            cls_scores (list[Tensor]): per-level classification outputs.
            bbox_preds (list[Tensor]): per-level box-regression outputs.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config forwarded to ``anchor_target`` and
                ``loss_single``.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``{'loss_cls': [...], 'loss_bbox': [...]}`` with one
            entry per level, or ``None`` when ``anchor_target`` returns None.
        """
        # size()[-2:] is the spatial (H, W) size of each level.
        featmap_sizes = [level_score.size()[-2:] for level_score in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)
        target_device = cls_scores[0].device

        anchors_per_img, valid_flags_per_img = self.get_anchors(
            featmap_sizes, img_metas, device=target_device)

        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1

        targets = anchor_target(anchors_per_img,
                                valid_flags_per_img,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=self.sampling)
        if targets is None:
            return None
        (labels, label_weights, bbox_targets, bbox_weights,
         num_pos, num_neg) = targets

        # With sampling the averaging factor counts positives and negatives;
        # without sampling only the positives.
        if self.sampling:
            avg_factor = num_pos + num_neg
        else:
            avg_factor = num_pos

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            num_total_samples=avg_factor,
            cfg=cfg)
        # Release references to the target tensors before returning.
        del targets, labels, label_weights, bbox_targets, bbox_weights
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
示例#23
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute RPN-style losses on predictions flattened across levels.

        Targets come from ``anchor_target`` with a single label channel, no
        sampling and ``unmap_outputs=False``. Predictions and targets of all
        levels are then concatenated along the anchor axis into tensors whose
        leading dimension is the number of images, and ``multi_apply`` runs
        ``self.loss_single_level`` on them (presumably once per image, since
        it iterates the leading dimension -- TODO confirm).

        Args:
            cls_scores (list[Tensor]): per-level classification outputs.
            bbox_preds (list[Tensor]): per-level box-regression outputs.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``{'rpn_losses/cls_loss': ..., 'rpn_losses/bbox_loss':
            ...}``, or ``None`` when ``anchor_target`` returns None.

        Raises:
            AssertionError: if the scores or box predictions contain NaN/Inf.
        """
        featmap_sizes = [level_score.size()[-2:] for level_score in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=cls_scores[0].device)
        targets = anchor_target(anchor_list,
                                valid_flag_list,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=1,
                                sampling=False,
                                unmap_outputs=False)
        if targets is None:
            return None
        labels, label_weights, box_targets, box_weights, num_pos, _ = targets

        n_imgs = len(img_metas)
        # NCHW -> NHWC, then one row per anchor; levels concatenated in order.
        flat_scores = torch.cat(
            [score.permute(0, 2, 3, 1).reshape(n_imgs, -1, self.cls_out_channels)
             for score in cls_scores],
            1)
        flat_labels = torch.cat(labels, -1).view(n_imgs, -1)
        flat_label_weights = torch.cat(label_weights, -1).view(n_imgs, -1)
        flat_preds = torch.cat(
            [pred.permute(0, 2, 3, 1).reshape(n_imgs, -1, 4)
             for pred in bbox_preds],
            -2)
        flat_box_targets = torch.cat(box_targets, -2).view(n_imgs, -1, 4)
        flat_box_weights = torch.cat(box_weights, -2).view(n_imgs, -1, 4)

        # Fail fast on diverged training.
        assert torch.isfinite(flat_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(flat_preds).all().item(), \
            'bbox predications become infinite or NaN!'

        losses_cls, losses_bbox = multi_apply(self.loss_single_level,
                                              flat_scores,
                                              flat_preds,
                                              flat_labels,
                                              flat_label_weights,
                                              flat_box_targets,
                                              flat_box_weights,
                                              num_total_samples=num_pos,
                                              cfg=cfg)
        return {
            'rpn_losses/cls_loss': losses_cls,
            'rpn_losses/bbox_loss': losses_bbox
        }
示例#24
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and box-regression losses.

        Anchors from ``self.get_anchors`` are matched to the ground truth by
        ``anchor_target``. ``self.loss_single`` is then applied to every
        feature-map level through ``multi_apply``.

        Args:
            cls_scores (list[Tensor]): per-level classification outputs.
            bbox_preds (list[Tensor]): per-level box-regression outputs.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config forwarded to ``anchor_target`` and
                ``loss_single``.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``{'loss_cls': [...], 'loss_bbox': [...]}`` with one
            entry per level, or ``None`` when ``anchor_target`` returns None.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        # NOTE: an experimental ``cfg.online_select`` branch (calling
        # ``self.anchor_target`` with copies of cls_scores/bbox_preds) used to
        # sit here as a dead string-literal statement; it was never executed.
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        # Averaging factor: positives + negatives with sampling, else positives.
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox = multi_apply(  # one call per level
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
示例#25
0
 def generate_anchor(self):
     """Stub that calls ``anchor_target`` with no arguments.

     NOTE(review): every other call site of ``anchor_target`` in this file
     passes at least seven positional arguments (anchors, valid flags, gt
     boxes, image metas, target means, target stds, cfg), so this
     argument-less call presumably raises ``TypeError``. Its return value is
     also discarded. Confirm the intended behaviour before relying on it.
     """
     anchor_target()
示例#26
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the anchor-head losses, optionally with class-tree losses.

        When ``self.use_forest`` is set and ``cls_scores[0]`` is a list, each
        element of ``cls_scores`` is a list of outputs for one level. The last
        entry (``[-1]``) of each level provides the feature-map size and the
        device. It is also flattened per image and handed to ``anchor_target``
        as ``cls_scores``.

        Args:
            cls_scores (list[Tensor] | list[list[Tensor]]): per-level
                classification outputs (nested lists in forest mode).
            bbox_preds (list[Tensor]): per-level box-regression outputs.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``None`` when ``anchor_target`` returns None. If any
            per-level ``loss_single`` call reports forest mode, returns a dict
            with ``loss_bbox``, ``loss_fine_grained_cls`` and one
            ``loss_tree{k}_parent_cls`` entry for k = 1 ..
            ``len(self.all_classes_num) - 1``. Otherwise returns
            ``{'loss_cls': ..., 'loss_bbox': ...}``.
        """
        if self.use_forest and isinstance(cls_scores[0], list):
            # The last output of every level drives sizes, device and targets.
            last_outputs = [level_outputs[-1] for level_outputs in cls_scores]
            featmap_sizes = [out.size()[-2:] for out in last_outputs]
            device = last_outputs[0].device
            num_im = last_outputs[0].shape[0]
            # anchor_cls_scores[img][level]: 2-D, cls_out_channels columns.
            anchor_cls_scores = [[
                out[i, ...].permute(1, 2, 0).reshape(-1, self.cls_out_channels)
                for out in last_outputs
            ] for i in range(num_im)]
        else:
            featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
            device = cls_scores[0].device
            anchor_cls_scores = None
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes,
                                                        img_metas,
                                                        device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(anchor_list,
                                        valid_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=self.sampling,
                                        cls_scores=anchor_cls_scores)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos +
                             num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox, parent_loss, forest = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)
        # multi_apply yields one list entry per level for every output (the
        # parent_loss[level][tree] indexing below relies on that), so
        # ``forest`` is a list of per-level flags. Testing the list itself is
        # True for any non-empty list, even [False, ...], which previously
        # sent non-forest heads into the tree branch. Inspect the flags.
        if not any(forest):
            return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

        all_loss = dict()
        all_loss['loss_bbox'] = losses_bbox
        all_loss['loss_fine_grained_cls'] = losses_cls
        for tree_idx in range(len(self.all_classes_num) - 1):
            all_loss['loss_tree{}_parent_cls'.format(tree_idx + 1)] = [
                level_parent[tree_idx] for level_parent in parent_loss
            ]
        return all_loss
    def loss_first(self,
                   loc_preds,
                   bbox_preds,
                   gt_bboxes,
                   img_metas,
                   cfg,
                   gt_labels=None,
                   gt_bboxes_ignore=None):
        """First-stage losses: box regression plus anchor-location prediction.

        Location targets come from ``ga_loc_target`` (using ``cfg.center_ratio``
        and ``cfg.ignore_ratio``). Box targets come from ``anchor_target``,
        with sampling disabled whenever ``self.cls_focal_loss`` is truthy.

        Args:
            loc_preds (list[Tensor]): per-level anchor-location predictions.
            bbox_preds (list[Tensor]): per-level box-regression outputs; they
                also determine the feature-map sizes and the device.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config.
            gt_labels (list[Tensor] | None): ground-truth labels.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``{'loss_rpn_bbox_first': [...], 'loss_loc': [...]}``
            with one entry per level, or ``None`` when ``anchor_target``
            returns None.
        """
        featmap_sizes = [pred.size()[-2:] for pred in bbox_preds]
        assert len(featmap_sizes) == len(self.anchor_generators)
        device = bbox_preds[0].device

        # Targets for the anchor-location branch.
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.anchor_scales,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)

        # Focal classification loss runs without sampling.
        use_sampling = not self.cls_focal_loss
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        targets = anchor_target(anchor_list,
                                valid_flag_list,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=use_sampling)
        if targets is None:
            return None
        (_, _, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = targets
        num_total_samples = (num_total_pos + num_total_neg
                             if use_sampling else num_total_pos)

        # First-stage box-regression loss, one entry per level.
        losses_bbox_first = [
            self.loss_first_single(level_pred,
                                   bbox_targets_list[lvl],
                                   bbox_weights_list[lvl],
                                   num_total_samples=num_total_samples,
                                   cfg=cfg)
            for lvl, level_pred in enumerate(bbox_preds)
        ]
        # Anchor-location loss, one entry per level.
        losses_loc = [
            self.loss_loc_single(level_loc,
                                 loc_targets[lvl],
                                 loc_weights[lvl],
                                 loc_avg_factor=loc_avg_factor,
                                 cfg=cfg)
            for lvl, level_loc in enumerate(loc_preds)
        ]
        return dict(loss_rpn_bbox_first=losses_bbox_first, loss_loc=losses_loc)
示例#28
0
    def loss(self,
             fam_cls_scores,
             fam_bbox_preds,
             refine_anchors,
             odm_cls_scores,
             odm_bbox_preds,
             bboxes,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the Feature Alignment Module (FAM) and Oriented Detection
        Module (ODM) losses.

        FAM: ``anchor_target`` builds targets from the initial anchors
        (``self.get_init_anchors``) with ``cfg.fam_cfg`` and ``use_vfl=False``.
        ``self.loss_fam_single`` is then applied per level via ``multi_apply``,
        together with the per-level initial anchors.

        ODM: ``anchor_target`` builds targets from the refined anchors
        (``self.get_refine_anchors`` applied to ``refine_anchors``) with
        ``cfg.odm_cfg`` and ``use_vfl=self.use_vfl``. ``self.loss_odm_single``
        is then applied per level, together with the per-level refined anchors.

        Args:
            fam_cls_scores, fam_bbox_preds: per-level FAM outputs.
            refine_anchors: refined anchors, converted to per-image lists by
                ``self.get_refine_anchors``.
            odm_cls_scores, odm_bbox_preds: per-level ODM outputs; the ODM
                classification maps also define ``featmap_sizes`` and device.
            bboxes: converted by ``self.get_refine_anchors`` into
                ``output_bboxes`` and passed positionally to the ODM
                ``anchor_target`` call (semantics not visible here -- TODO
                confirm).
            gt_bboxes, gt_labels, img_metas: per-image ground truth and meta.
            cfg: config providing ``fam_cfg`` and ``odm_cfg``.
            gt_bboxes_ignore: boxes to ignore, forwarded to ``anchor_target``.

        Returns:
            dict | None: ``loss_fam_cls``, ``loss_fam_bbox``, ``loss_odm_cls``
            and ``loss_odm_bbox``, or ``None`` if either ``anchor_target`` call
            returns None.

        NOTE(review):
            * ``valid_flag_list`` from ``get_init_anchors`` is overwritten by
              the flags from ``get_refine_anchors`` before the FAM targets are
              built, so FAM uses the refine-anchor validity flags.
            * ``get_refine_anchors(featmap_sizes, refine_anchors, ...)`` is
              called twice with identical arguments (FAM and ODM sections).
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in odm_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)
        device = odm_cls_scores[0].device

        anchors_list, valid_flag_list = self.get_init_anchors(
            featmap_sizes, img_metas, device=device)
        # NOTE(review): this overwrites the valid flags returned just above.
        refine_anchors_list, valid_flag_list = self.get_refine_anchors(
            featmap_sizes, refine_anchors, img_metas, device=device)
        # anchor number of multi levels (taken from the first image)
        num_level_anchors = [anchors.size(0) for anchors in anchors_list[0]]
        # Feature Alignment Module
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        # refine_anchors_list is passed positionally as the 8th argument; its
        # role depends on this project's anchor_target -- TODO confirm.
        cls_reg_targets = anchor_target(
            anchors_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg.fam_cfg,
            refine_anchors_list,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling,
            reg_decoded_bbox=self.reg_decoded_bbox,
            use_vfl=False)
        if cls_reg_targets is None:
            return None
        # labels_list is a list whose length is the number of feature-map
        # levels (per the original author, usually 5); each element is a
        # tensor of shape [bs, anchor_nums]. The other target lists are
        # organised the same way.
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        # Regroup the per-image initial anchors into per-level lists.
        all_anchor_list = images_to_levels(anchors_list, num_level_anchors)
        losses_fam_cls, losses_fam_bbox = multi_apply(
            self.loss_fam_single,
            fam_cls_scores,
            fam_bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg.fam_cfg)

        # Oriented Detection Module targets
        # NOTE(review): same call and arguments as in the FAM section above.
        refine_anchors_list, valid_flag_list = self.get_refine_anchors(
            featmap_sizes, refine_anchors, img_metas, device=device)
        num_level_anchors = [anchors.size(0) for anchors in refine_anchors_list[0]]
        # Only the anchor lists are kept; the validity flags are discarded.
        output_bboxes, _ = self.get_refine_anchors(
            featmap_sizes, bboxes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(
            refine_anchors_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg.odm_cfg,
            output_bboxes,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling,
            reg_decoded_bbox=self.reg_decoded_bbox,
            use_vfl=self.use_vfl)
        # Disabled alternative target builders, kept for reference:
        # cls_reg_targets = anchor_target_atss(
        #     refine_anchors_list,
        #     valid_flag_list,
        #     gt_bboxes,
        #     img_metas,
        #     self.target_means,
        #     self.target_stds,
        #     cfg.odm_cfg,
        #     gt_bboxes_ignore_list=gt_bboxes_ignore,
        #     gt_labels_list=gt_labels,
        #     label_channels=label_channels,
        #     sampling=self.sampling,
        #     reg_decoded_bbox=self.reg_decoded_bbox)
        # cls_reg_targets = anchor_target_rotated(
        #     refine_anchors_list,
        #     valid_flag_list,
        #     gt_bboxes,
        #     img_metas,
        #     self.target_means,
        #     self.target_stds,
        #     cfg.odm_cfg,
        #     output_bboxes,
        #     gt_bboxes_ignore_list=gt_bboxes_ignore,
        #     gt_labels_list=gt_labels,
        #     label_channels=label_channels,
        #     sampling=self.sampling,
        #     reg_decoded_bbox=self.reg_decoded_bbox,
        #     use_vfl=self.use_vfl)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        # Regroup the per-image refined anchors into per-level lists.
        all_anchor_list = images_to_levels(refine_anchors_list, num_level_anchors)
        losses_odm_cls, losses_odm_bbox = multi_apply(
            self.loss_odm_single,
            odm_cls_scores,
            odm_bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg.odm_cfg)

        return dict(loss_fam_cls=losses_fam_cls,
                    loss_fam_bbox=losses_fam_bbox,
                    loss_odm_cls=losses_odm_cls,
                    loss_odm_bbox=losses_odm_bbox)
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and regression losses.

        This variant's ``anchor_target`` returns a seventh element,
        ``level_anchor_list`` (presumably the anchors of each level -- TODO
        confirm). It is forwarded per level to ``self.loss_single``, together
        with ``gt_bboxes`` as a keyword argument.

        Args:
            cls_scores (list[Tensor]): one tensor per feature-map level;
                ``cls_scores[i]`` has shape (batch, A*C, H_i, W_i), where
                (H_i, W_i) is the spatial size of the i-th level (PyTorch NCHW
                layout).
            bbox_preds (list[Tensor]): one tensor per feature-map level;
                ``bbox_preds[i]`` has shape (batch, A*4, H_i, W_i).
            gt_bboxes (list[Tensor]): ground-truth boxes of each image, stored
                as top-left and bottom-right corners in image coordinates.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta info of each image.
            cfg: training config.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: ``{'loss_cls': [...], 'loss_bbox': [...]}`` with one
            entry per level, or ``None`` when ``anchor_target`` returns None.
        """
        featmap_sizes = [score.size()[-2:] for score in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes, img_metas)
        label_channels = (self.cls_out_channels
                          if self.use_sigmoid_cls else 1)
        targets = anchor_target(anchor_list,
                                valid_flag_list,
                                gt_bboxes,
                                img_metas,
                                self.target_means,
                                self.target_stds,
                                cfg,
                                gt_bboxes_ignore_list=gt_bboxes_ignore,
                                gt_labels_list=gt_labels,
                                label_channels=label_channels,
                                sampling=self.sampling)
        if targets is None:
            return None
        (labels, label_weights, box_targets, box_weights,
         num_pos, num_neg, level_anchors) = targets

        # With sampling, num_total_samples = num_total_pos + num_total_neg;
        # otherwise only positives count. Per the original author, FocalLoss,
        # GHMC and IOUbalancedSigmoidFocalLoss do not use sampling.
        if self.sampling:
            num_total_samples = num_pos + num_neg
        else:
            num_total_samples = num_pos

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels,
            label_weights,
            box_targets,
            box_weights,
            level_anchors,
            num_total_samples=num_total_samples,
            gt_bboxes=gt_bboxes,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
示例#30
0
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute losses whose per-anchor weights are re-predicted.

        ``self.loss_single`` runs twice:
            1. with ``reduction_override="none"`` to get unreduced per-anchor
               classification and regression losses;
            2. after ``self.predict_weights`` has produced new label/bbox
               weights from the flattened predictions, targets, anchors and
               unreduced losses, with ``reduction_override=None``.
        ``self.global_step`` is incremented on every call.

        Args:
            cls_scores (list[Tensor]): per-level classification outputs (NCHW).
            bbox_preds (list[Tensor]): per-level box-regression outputs (NCHW).
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): ground-truth labels of each image.
            img_metas (list[dict]): meta information of each image.
            cfg: training config.
            gt_bboxes_ignore (list[Tensor] | None): boxes to ignore.

        Returns:
            dict | None: the dict returned by ``predict_weights``, updated with
            ``loss_cls`` and ``loss_bbox`` (one entry per level), or ``None``
            when ``anchor_target`` returns None.
        """
        self.global_step = self.global_step + 1
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        # anchor_list is indexed [image][level]. Regroup it per level and
        # concatenate the images (image-major, matching the flattening below).
        # This is computed before anchor_target runs, in case that call
        # rewrites anchor_list entries in place.
        original_anchors = [torch.cat(list(level_anchors))
                            for level_anchors in zip(*anchor_list)]
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        # First pass: unreduced (per-anchor) losses.
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg,
            reduction_override="none")

        # Flatten every per-level quantity to one row per anchor and
        # concatenate the levels in order.
        cls_scores_flatten = torch.cat(
            [cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
             for cls_score in cls_scores], dim=0)
        bbox_preds_flatten = torch.cat(
            [bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
             for bbox_pred in bbox_preds], dim=0)
        labels_list_flatten = torch.cat(
            [label.reshape(-1, 1) for label in labels_list], dim=0)
        label_weights_list_flatten = torch.cat(
            [label_weights.reshape(-1, 1)
             for label_weights in label_weights_list], dim=0)
        bbox_targets_list_flatten = torch.cat(
            [bbox_targets.reshape(-1, 4)
             for bbox_targets in bbox_targets_list], dim=0)
        bbox_weights_list_flatten = torch.cat(
            [bbox_weights.reshape(-1, 4)
             for bbox_weights in bbox_weights_list], dim=0)
        original_anchors_flatten = torch.cat(
            [anchors.reshape(-1, 4) for anchors in original_anchors], dim=0)
        losses_cls_flatten = torch.cat(
            [loss_cls.reshape(-1, self.cls_out_channels)
             for loss_cls in losses_cls], dim=0)
        losses_bbox_flatten = torch.cat(
            [loss_bbox.reshape(-1, 4) for loss_bbox in losses_bbox], dim=0)
        # Anchors per level across all images (equals shape[0] * shape[1] for
        # the 2-D (num_images, num_anchors) label weights).
        split_point = [m.numel() for m in label_weights_list]

        label_weights, bbox_weights, losses = self.predict_weights(
            cls_score=cls_scores_flatten, bbox_pred=bbox_preds_flatten, labels=labels_list_flatten,
            label_weights=label_weights_list_flatten, bbox_targets=bbox_targets_list_flatten,
            bbox_weights=bbox_weights_list_flatten, anchors=original_anchors_flatten, loss_cls=losses_cls_flatten,
            loss_bbox=losses_bbox_flatten)
        # Restore the per-level (num_images, ...) layout. This used to be
        # hard-coded as reshape(2, ...), which only worked for a batch of
        # exactly two images.
        num_imgs = len(img_metas)
        label_weights_list_new = [
            m.reshape(num_imgs, -1)
            for m in torch.split(label_weights, split_point)
        ]
        bbox_weights_list_new = [
            m.reshape(num_imgs, -1, 4)
            for m in torch.split(bbox_weights, split_point)
        ]

        # Second pass: reduced losses with the predicted weights.
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list_new,
            bbox_targets_list,
            bbox_weights_list_new,
            num_total_samples=num_total_samples,
            cfg=cfg,
            reduction_override=None)
        losses.update(
            dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
        )
        return losses