def loss_first(self,
               loc_preds,
               bbox_preds,
               gt_bboxes,
               img_metas,
               cfg,
               gt_labels=None,
               gt_bboxes_ignore=None):
        """Compute the first-stage bbox regression loss and the location loss.

        Location targets come from ``ga_loc_target``; bbox regression targets
        come from ``anchor_target`` applied to the anchors produced by
        ``self.get_anchors``.

        Args:
            loc_preds: Indexable sequence of location predictions; one loss
                entry is produced per element.
            bbox_preds: Indexable sequence of bbox predictions. Each element
                must expose ``.size()`` and the first one ``.device``.
            gt_bboxes: Ground-truth boxes, forwarded to the target helpers.
            img_metas: Image meta information, forwarded to the helpers.
            cfg: Config object exposing ``center_ratio`` and ``ignore_ratio``;
                also forwarded to the target and loss helpers.
            gt_labels: Optional labels, forwarded as ``gt_labels_list``.
            gt_bboxes_ignore: Optional boxes to ignore, forwarded as
                ``gt_bboxes_ignore_list``.

        Returns:
            ``None`` when ``anchor_target`` returns ``None``; otherwise a dict
            with keys ``loss_rpn_bbox_first`` and ``loss_loc``, each holding a
            list with one loss per entry of ``bbox_preds`` / ``loc_preds``.
        """
        level_shapes = [pred.size()[-2:] for pred in bbox_preds]
        assert len(level_shapes) == len(self.anchor_generators)

        device = bbox_preds[0].device

        # Targets for the anchor-location branch.
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            level_shapes,
            self.anchor_scales,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        anchor_list, valid_flag_list = self.get_anchors(
            level_shapes, img_metas, device=device)

        # Sampling is disabled whenever focal loss is used for classification.
        use_sampling = not self.cls_focal_loss
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1

        targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=use_sampling)
        if targets is None:
            return None

        # Labels and label weights are unpacked but not needed at this stage.
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = targets

        if self.cls_focal_loss:
            num_total_samples = num_total_pos
        else:
            num_total_samples = num_total_pos + num_total_neg

        # One first-stage bbox regression loss per prediction entry.
        losses_bbox_first = [
            self.loss_first_single(
                pred,
                bbox_targets_list[idx],
                bbox_weights_list[idx],
                num_total_samples=num_total_samples,
                cfg=cfg)
            for idx, pred in enumerate(bbox_preds)
        ]

        # One anchor-location loss per prediction entry.
        losses_loc = [
            self.loss_loc_single(
                pred,
                loc_targets[idx],
                loc_weights[idx],
                loc_avg_factor=loc_avg_factor,
                cfg=cfg)
            for idx, pred in enumerate(loc_preds)
        ]

        return dict(loss_rpn_bbox_first=losses_bbox_first, loss_loc=losses_loc)
# Example no. 2
# 0
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification, regression, shape and location losses.

        Args:
            cls_scores: Sequence of classification outputs; each element must
                expose ``.size()``.
            bbox_preds: Sequence of bbox regression outputs.
            shape_preds: Sequence of anchor-shape predictions.
            loc_preds: Sequence of anchor-location predictions.
            gt_bboxes: Ground-truth boxes, forwarded to the target helpers.
            gt_labels: Ground-truth labels, forwarded as ``gt_labels_list``.
            img_metas: Image meta information, forwarded to the helpers.
            cfg: Config object exposing ``center_ratio`` and ``ignore_ratio``;
                the presence of a ``ga_sampler`` attribute enables sampling
                for the shape targets.
            gt_bboxes_ignore: Optional boxes to ignore.

        Returns:
            ``None`` when either ``ga_shape_target`` or ``anchor_target``
            returns ``None``; otherwise a dict with keys ``loss_cls``,
            ``loss_reg``, ``loss_shape`` and ``loss_loc``.
        """
        level_shapes = [score.size()[-2:] for score in cls_scores]
        assert len(level_shapes) == len(self.approx_generators)

        # Anchor-location targets.
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            level_shapes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        anchors = self.get_anchors(level_shapes, shape_preds, img_metas)
        (approxs_list, valid_flag_list, base_approxs_list,
         guided_anchors_list) = anchors

        # Anchor-shape targets; sampled only when cfg provides a ga_sampler.
        ga_sampling = hasattr(cfg, 'ga_sampler')
        shape_targets = ga_shape_target(
            approxs_list,
            valid_flag_list,
            base_approxs_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=ga_sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list,
         all_inside_flags, anchor_fg_num, anchor_ng_num) = shape_targets
        if ga_sampling:
            anchor_total_num = anchor_fg_num + anchor_ng_num
        else:
            anchor_total_num = anchor_fg_num

        # Classification / regression targets on the guided anchors.
        cls_sampling = not self.cls_focal_loss
        if self.cls_sigmoid_loss:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            all_inside_flags,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=cls_sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        if self.cls_focal_loss:
            num_total_samples = num_total_pos
        else:
            num_total_samples = num_total_pos + num_total_neg

        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # The helpers below yield a single output sequence, hence the
        # one-element unpacking.
        (losses_loc,) = multi_apply(
            self.loss_loc_single,
            loc_preds,
            loc_targets,
            loc_weights,
            loc_avg_factor=loc_avg_factor,
            cfg=cfg)
        (losses_shape,) = multi_apply(
            self.loss_shape_single,
            shape_preds,
            bbox_anchors_list,
            bbox_gts_list,
            anchor_weights_list,
            anchor_total_num=anchor_total_num)

        return dict(loss_cls=losses_cls,
                    loss_reg=losses_reg,
                    loss_shape=losses_shape,
                    loss_loc=losses_loc)
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification, bbox, anchor-shape and anchor-location losses.

        The method builds location targets (``ga_loc_target``), shape targets
        (``ga_shape_target``) and classification/regression targets
        (``anchor_target``), then evaluates ``self.loss_single``,
        ``self.loss_loc_single`` and ``self.loss_shape_single``.

        Args:
            cls_scores: Sequence of classification outputs; each element must
                expose ``.size()``.
            bbox_preds: Sequence of bbox regression outputs.
            shape_preds: Sequence of anchor-shape predictions.
            loc_preds: Sequence of anchor-location predictions.
            gt_bboxes: Ground-truth boxes, forwarded to the target helpers and
                to ``self.loss_single``.
            gt_labels: Ground-truth labels, forwarded as ``gt_labels_list``.
            img_metas: Image meta information, forwarded to the helpers.
            cfg: Config object exposing ``center_ratio`` and ``ignore_ratio``;
                the presence of a ``ga_sampler`` attribute enables sampling
                for the shape targets.
            gt_bboxes_ignore: Optional boxes to ignore.

        Returns:
            ``None`` when ``ga_shape_target`` or ``anchor_target`` returns
            ``None``; otherwise a dict with keys ``loss_cls``, ``loss_bbox``,
            ``loss_shape`` and ``loss_loc``.

        NOTE(review): as written this method reads ``num_total_pos_cls``,
        ``num_total_neg_cls``, ``bbox_targets_cls_list`` and
        ``bbox_weights_cls_list``, none of which is assigned anywhere in the
        method. Unless they exist at module scope, execution stops with a
        ``NameError`` before any loss is computed.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg)
        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas)

        # get shape targets
        # Sampling for shape targets is on only when cfg has a `ga_sampler`.
        sampling = False if not hasattr(cfg, 'ga_sampler') else True
        shape_targets = ga_shape_target(approxs_list,
                                        inside_flag_list,
                                        squares_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.approxs_per_octave,
                                        cfg,
                                        sampling=sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        # With sampling the normaliser counts foreground and background
        # anchors; without it only foreground anchors.
        anchor_total_num = (anchor_fg_num if not sampling else anchor_fg_num +
                            anchor_bg_num)

        # get anchor targets
        # `sampling` is reused here with a different meaning: it is disabled
        # when focal loss is used for classification.
        sampling = False if self.cls_focal_loss else True
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(guided_anchors_list,
                                        inside_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=sampling)
        if cls_reg_targets is None:
            return None
        # Previous six-element unpacking, kept for reference:
        # (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
        #  num_total_pos, num_total_neg) = cls_reg_targets

        # This variant expects `anchor_target` to return seven values, the
        # last being a per-level anchor list.
        # NOTE(review): requires a modified `anchor_target`; confirm its
        # return arity matches.
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, level_anchor_list) = cls_reg_targets

        num_total_samples = (num_total_pos if self.cls_focal_loss else
                             num_total_pos + num_total_neg)
        # added by WSK
        # NOTE(review): `num_total_pos_cls` / `num_total_neg_cls` are never
        # assigned in this method -- presumably meant to come from
        # `cls_reg_targets`; verify and unpack them above.
        num_total_samples_cls = (num_total_pos_cls if self.cls_focal_loss else
                                 num_total_pos_cls + num_total_neg_cls)

        # get classification and bbox regression losses
        # NOTE(review): `bbox_targets_cls_list` and `bbox_weights_cls_list`
        # are likewise never assigned in this method.
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            level_anchor_list,  # added by Shengkai Wu
            bbox_targets_cls_list,  # added by Shengkai Wu
            bbox_weights_cls_list,  # added by Shengkai Wu
            num_total_samples_cls=num_total_samples_cls,  # added by Shengkai Wu
            num_total_samples=num_total_samples,
            gt_bboxes=gt_bboxes,  # added by Shengkai Wu
            cfg=cfg)

        # get anchor location loss
        # One loss per entry of `loc_preds`, matched by index with the targets.
        losses_loc = []
        for i in range(len(loc_preds)):
            loss_loc = self.loss_loc_single(loc_preds[i],
                                            loc_targets[i],
                                            loc_weights[i],
                                            loc_avg_factor=loc_avg_factor,
                                            cfg=cfg)
            losses_loc.append(loss_loc)

        # get anchor shape loss
        # One loss per entry of `shape_preds`, matched by index with the
        # shape targets.
        losses_shape = []
        for i in range(len(shape_preds)):
            loss_shape = self.loss_shape_single(
                shape_preds[i],
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                anchor_total_num=anchor_total_num)
            losses_shape.append(loss_shape)

        return dict(loss_cls=losses_cls,
                    loss_bbox=losses_bbox,
                    loss_shape=losses_shape,
                    loss_loc=losses_loc)
# Example no. 4
# 0
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute guided-anchoring losses with a two-pass classification loss.

        Classification/regression targets are built twice with different
        sampler positive fractions (pass 1: 0.5 / 0.5, pass 2: 0.55 for
        ``cfg['ga_sampler']`` and 0.6 for ``cfg['sampler']``). The returned
        classification loss is the mean of both passes for the averaged
        indices; the bbox loss comes from pass 1 only.

        Args:
            cls_scores: Sequence of classification outputs; each element must
                expose ``.size()`` and the first one ``.device``.
            bbox_preds: Sequence of bbox regression outputs.
            shape_preds: Sequence of anchor-shape predictions.
            loc_preds: Sequence of anchor-location predictions.
            gt_bboxes: Ground-truth boxes, forwarded to the target helpers.
            gt_labels: Ground-truth labels, forwarded as ``gt_labels_list``.
            img_metas: Image meta information, forwarded to the helpers.
            cfg: Config object exposing ``center_ratio``, ``ignore_ratio`` and
                item access to ``'ga_sampler'`` and ``'sampler'``.
                NOTE: it is mutated in place; after a successful call the
                ``pos_fraction`` entries are left at 0.55 / 0.6.
            gt_bboxes_ignore: Optional boxes to ignore.

        Returns:
            ``None`` when any ``ga_shape_target`` or ``anchor_target`` call
            returns ``None``; otherwise a dict with keys ``loss_cls``,
            ``loss_bbox``, ``loss_shape`` and ``loss_loc``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        device = cls_scores[0].device

        # Anchor-location targets.
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # Sampled approxes, squares and guided anchors.
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg, device=device)
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)

        # Anchor-shape targets; sampled only when cfg provides a ga_sampler.
        ga_sampling = hasattr(cfg, 'ga_sampler')
        shape_targets = ga_shape_target(
            approxs_list,
            inside_flag_list,
            squares_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=ga_sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (
            anchor_fg_num + anchor_bg_num if ga_sampling else anchor_fg_num)

        # Settings shared by both anchor-target passes.
        sampling = not self.cls_focal_loss
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # ---- Pass 1: positive fraction 0.5 for both samplers. ----
        cfg['ga_sampler']['pos_fraction'] = 0.5
        cfg['sampler']['pos_fraction'] = 0.5
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos if self.cls_focal_loss else
            num_total_pos + num_total_neg)

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # Anchor-location loss, one entry per element of `loc_preds`.
        losses_loc = [
            self.loss_loc_single(
                loc_pred,
                loc_targets[i],
                loc_weights[i],
                loc_avg_factor=loc_avg_factor,
                cfg=cfg)
            for i, loc_pred in enumerate(loc_preds)
        ]

        # Anchor-shape loss, one entry per element of `shape_preds`.
        losses_shape = [
            self.loss_shape_single(
                shape_pred,
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                anchor_total_num=anchor_total_num)
            for i, shape_pred in enumerate(shape_preds)
        ]

        # ---- Pass 2: rebuild anchors/targets with higher pos fractions. ----
        # The approxes, anchors and shape targets are recomputed exactly as
        # the original code did, so that call order (and any sampler
        # randomness) is unchanged. Of the shape-target result only the
        # None check matters here.
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg, device=device)
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
        shape_targets = ga_shape_target(
            approxs_list,
            inside_flag_list,
            squares_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=ga_sampling)
        if shape_targets is None:
            return None

        # The original decoded every guided anchor with `delta2bbox` into a
        # list that was never read; that dead work has been removed.
        num_averaged = len(guided_anchors_list)

        cfg['ga_sampler']['pos_fraction'] = 0.55
        cfg['sampler']['pos_fraction'] = 0.6
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos if self.cls_focal_loss else
            num_total_pos + num_total_neg)

        # Only the classification loss of the second pass is used.
        losses_cls2, _ = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # Average the classification loss of both passes.
        # NOTE(review): the range is len(guided_anchors_list), as in the
        # original. The removed decode loop indexed `bbox_preds[k][z]` with
        # `z` running over guided_anchors_list, which suggests that length
        # may be the image count, not the number of loss entries -- confirm
        # this range is intended.
        for k in range(num_averaged):
            losses_cls[k] = losses_cls[k] / 2 + losses_cls2[k] / 2

        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_shape=losses_shape,
            loss_loc=losses_loc)