예제 #1
0
    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
                    label_weights, bbox_targets, num_total_samples):
        """Compute classification, box and centerness losses for one input set.

        ``cls_score``, ``bbox_pred`` and ``centerness`` are permuted with
        ``(0, 2, 3, 1)`` and flattened so that every row lines up with one row
        of ``anchors`` / ``bbox_targets`` / ``labels``.
        NOTE(review): the prediction inputs are presumably ``(N, C, H, W)``
        maps of a single feature level -- confirm against the caller.

        Args:
            anchors: Anchor boxes, reshaped to ``(-1, 4)``.
            cls_score: Classification logits, flattened to
                ``(-1, self.cls_out_channels)``.
            bbox_pred: Box deltas, flattened to ``(-1, 4)``.
            centerness: Centerness logits, flattened to ``(-1,)``.
            labels: Per-anchor labels; entries ``> 0`` are treated as
                positives.
            label_weights: Flattened but otherwise unused in this method.
            bbox_targets: Box regression targets, reshaped to ``(-1, 4)``.
            num_total_samples: Unused in this method.

        Returns:
            tuple: ``(loss_cls, loss_bbox, loss_centerness, centerness_sum)``
            where ``centerness_sum`` is the sum of the centerness targets of
            the positive anchors (zero when there are none).
        """
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        pos_inds = (labels > 0).nonzero().squeeze(1)
        num_pos = len(pos_inds)
        # NOTE(review): the "+ 2" keeps the divisor non-zero when there are no
        # positives; `num_total_samples` is not used for normalisation here --
        # confirm this is intended.
        loss_cls = sigmoid_focal_loss(cls_score, labels, self.train_cfg.gamma,
                                      self.train_cfg.alpha,
                                      'none').sum()[None] / (num_pos + 2)
        if num_pos > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)
            # Decode both predictions and targets relative to the same
            # anchors so the box loss compares boxes rather than deltas.
            pos_decode_bbox_pred = delta2bbox(pos_anchors, pos_bbox_pred,
                                              self.target_means,
                                              self.target_stds)
            pos_decode_bbox_targets = delta2bbox(pos_anchors, pos_bbox_targets,
                                                 self.target_means,
                                                 self.target_stds)

            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(pos_decode_bbox_pred,
                                       pos_decode_bbox_targets,
                                       weight=centerness_targets,
                                       avg_factor=1.0)

            # centerness loss
            loss_centerness = F.binary_cross_entropy_with_logits(
                pos_centerness, centerness_targets, reduction='mean')[None]

        else:
            # Zero-valued losses that still depend on the predictions.
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            # Create the zero on the same device/dtype as the targets instead
            # of hard-coding CUDA, so the method also runs on CPU.
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
예제 #2
0
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the classification, regression and centerness losses.

        Each of ``cls_scores``, ``bbox_preds`` and ``centernesses`` is a list
        with one 4-D tensor per level; all three lists must have the same
        length. Every tensor is permuted with ``(0, 2, 3, 1)`` and flattened
        before the levels are concatenated.
        NOTE(review): the tensors are presumably laid out as
        ``(N, C, H, W)`` -- confirm against the caller.

        Args:
            cls_scores: Per-level classification logits.
            bbox_preds: Per-level box predictions (4 values per location).
            centernesses: Per-level centerness logits.
            gt_bboxes: Ground-truth boxes, forwarded to ``self.fcos_target``.
            gt_labels: Ground-truth labels, forwarded to ``self.fcos_target``.
            img_metas: Unused in this method.
            cfg: Config object providing ``gamma`` and ``alpha`` for the
                focal loss.
            gt_bboxes_ignore: Unused in this method.

        Returns:
            dict: ``loss_cls``, ``loss_reg`` and ``loss_centerness``, each a
            one-element tensor.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)

        reference = bbox_preds[0]
        level_sizes = [score_map.shape[-2:] for score_map in cls_scores]
        level_points = self.get_points(level_sizes, reference.dtype,
                                       reference.device)
        labels, bbox_targets = self.fcos_target(level_points, gt_bboxes,
                                                gt_labels)

        batch_size = cls_scores[0].size(0)

        # Channels-last, one row per location, all levels stacked.
        all_scores = torch.cat([
            level.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for level in cls_scores
        ])
        all_boxes = torch.cat(
            [level.permute(0, 2, 3, 1).reshape(-1, 4) for level in bbox_preds])
        all_centerness = torch.cat(
            [level.permute(0, 2, 3, 1).reshape(-1) for level in centernesses])
        all_labels = torch.cat(labels)
        all_box_targets = torch.cat(bbox_targets)
        # Tile the points once per image so they align with the predictions.
        all_points = torch.cat(
            [pts.repeat(batch_size, 1) for pts in level_points])

        positives = torch.nonzero(all_labels).reshape(-1)
        num_pos = len(positives)

        # Adding the batch size keeps the divisor non-zero without positives.
        focal = sigmoid_focal_loss(all_scores, all_labels, cfg.gamma,
                                   cfg.alpha, 'none')
        loss_cls = focal.sum()[None] / (num_pos + batch_size)

        box_pos = all_boxes[positives]
        target_pos = all_box_targets[positives]
        centerness_pos = all_centerness[positives]

        if num_pos == 0:
            # Sums over empty selections: zero-valued one-element tensors.
            loss_reg = box_pos.sum()[None]
            loss_centerness = centerness_pos.sum()[None]
        else:
            centerness_goal = self.centerness_target(target_pos)

            points_pos = all_points[positives]
            decoded_pred = distance2bbox(points_pos, box_pos)
            decoded_goal = distance2bbox(points_pos, target_pos)

            # IoU loss weighted by, and normalised with, the centerness
            # targets.
            weighted_iou = iou_loss(
                decoded_pred, decoded_goal,
                reduction='none') * centerness_goal
            loss_reg = (weighted_iou.sum() / centerness_goal.sum())[None]

            loss_centerness = F.binary_cross_entropy_with_logits(
                centerness_pos, centerness_goal, reduction='mean')[None]

        return {
            'loss_cls': loss_cls,
            'loss_reg': loss_reg,
            'loss_centerness': loss_centerness,
        }
예제 #3
0
    def loss(
        self,
        cls_scores,
        bbox_preds,
        centernesses,
        attr_scores,
        gt_bboxes,
        gt_labels,
        gt_attributes,
        img_metas,
        cfg,
        gt_bboxes_ignore=None,
    ):
        """Compute classification, regression, centerness and attribute losses.

        ``cls_scores``, ``bbox_preds``, ``centernesses`` and ``attr_scores``
        are lists with one 4-D tensor per level and must all have the same
        length. Every tensor is permuted with ``(0, 2, 3, 1)`` and flattened
        before the levels are concatenated.
        NOTE(review): the tensors are presumably laid out as
        ``(N, C, H, W)`` -- confirm against the caller.

        Args:
            cls_scores: Per-level classification logits.
            bbox_preds: Per-level box predictions (4 values per location).
            centernesses: Per-level centerness logits.
            attr_scores: Per-level attribute logits; the number of attributes
                is taken from dimension 1 of each tensor.
            gt_bboxes: Ground-truth boxes, forwarded to ``self.fcos_target``.
            gt_labels: Ground-truth labels, forwarded to ``self.fcos_target``.
            gt_attributes: Ground-truth attributes, forwarded to
                ``self.fcos_target``.
            img_metas: Unused in this method.
            cfg: Config object providing ``gamma`` and ``alpha`` for the
                focal loss.
            gt_bboxes_ignore: Unused in this method.

        Returns:
            dict: ``loss_cls``, ``loss_reg``, ``loss_centerness`` and
            ``loss_attr``, each a one-element tensor.
        """
        assert (
            len(cls_scores) == len(bbox_preds) == len(centernesses) == len(attr_scores)
        )
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(
            featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device
        )
        labels, bbox_targets, attrs = self.fcos_target(
            all_level_points, gt_bboxes, gt_labels, gt_attributes
        )

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds, centerness and attr_scores
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1) for centerness in centernesses
        ]
        # Read the attribute count from the tensor itself rather than
        # hard-coding it, so heads with a different number of attributes work.
        flatten_attr_scores = [
            attr_score.permute(0, 2, 3, 1).reshape(-1, attr_score.size(1))
            for attr_score in attr_scores
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_attr_scores = torch.cat(flatten_attr_scores)
        flatten_labels = torch.cat(labels)
        flatten_attrs = torch.cat(attrs)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points]
        )

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        # Adding num_imgs keeps the divisor non-zero when num_pos is 0.
        loss_cls = sigmoid_focal_loss(
            flatten_cls_scores, flatten_labels, cfg.gamma, cfg.alpha, "none"
        ).sum()[None] / (num_pos + num_imgs)

        if num_pos > 0:
            pos_bbox_preds = flatten_bbox_preds[pos_inds]
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness = flatten_centerness[pos_inds]
            pos_attr_pred = flatten_attr_scores[pos_inds]
            pos_attr_targets = flatten_attrs[pos_inds]

            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
            # centerness weighted iou loss
            weighted_iou = (
                iou_loss(
                    pos_decoded_bbox_preds,
                    pos_decoded_target_preds,
                    reduction="none",
                )
                * pos_centerness_targets
            )
            loss_reg = (weighted_iou.sum() / pos_centerness_targets.sum())[None]
            loss_centerness = F.binary_cross_entropy_with_logits(
                pos_centerness, pos_centerness_targets, reduction="mean"
            )[None]

            # Only train on positives that carry at least one attribute.
            valid_attr_idx = pos_attr_targets.sum(dim=1) != 0
            if valid_attr_idx.any():
                loss_attr = F.binary_cross_entropy_with_logits(
                    pos_attr_pred[valid_attr_idx],
                    pos_attr_targets[valid_attr_idx],
                    reduction="mean",
                )[None]
            else:
                # A mean over an empty selection would be NaN; use a
                # zero-valued loss that still depends on the predictions.
                loss_attr = pos_attr_pred.sum()[None] * 0
        else:
            # No positives: zero-valued losses tied to the predictions.
            loss_reg = flatten_bbox_preds.sum()[None] * 0
            loss_centerness = flatten_centerness.sum()[None] * 0
            loss_attr = flatten_attr_scores.sum()[None] * 0

        return dict(
            loss_cls=loss_cls,
            loss_reg=loss_reg,
            loss_centerness=loss_centerness,
            loss_attr=loss_attr,
        )
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             ious,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification, regression, centerness and IoU losses.

        All four prediction lists must have the same length (one 4-D tensor
        per FPN level). Every tensor is permuted with ``(0, 2, 3, 1)`` and
        flattened before the levels are concatenated.
        NOTE(review): the tensors are presumably laid out as
        ``(batch, channels, height, width)`` -- confirm against the caller.

        :param cls_scores: list of tensors, classification logits of the
              i-th FPN level.
        :param bbox_preds: list of tensors, box predictions (4 values per
              location) of the i-th FPN level.
        :param centernesses: list of tensors, centerness logits of the
              i-th FPN level.
        :param ious: list of tensors, predicted IoU logits of the i-th FPN
              level.
        :param gt_bboxes: ground-truth boxes, forwarded to
              ``self.fcos_target``.
        :param gt_labels: ground-truth labels, forwarded to
              ``self.fcos_target``.
        :param img_metas: unused in this method.
        :param cfg: config object providing ``gamma`` and ``alpha`` for the
              focal loss.
        :param gt_bboxes_ignore: unused in this method.
        :return: dict with ``loss_cls``, ``loss_reg``, ``loss_centerness``
              and ``loss_iou``, each a one-element tensor.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
            ious)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds, centerness and ious
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_iou = [iou.permute(0, 2, 3, 1).reshape(-1) for iou in ious]

        flatten_cls_scores = torch.cat(
            flatten_cls_scores)  # (num_all, num_class)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)  # (num_all, 4)
        flatten_centerness = torch.cat(flatten_centerness)  # (num_all)
        flatten_iou = torch.cat(flatten_iou)  # (num_all)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        # Adding num_imgs keeps the divisor non-zero when num_pos is 0.
        loss_cls = sigmoid_focal_loss(
            flatten_cls_scores, flatten_labels, cfg.gamma, cfg.alpha,
            'none').sum()[None] / (num_pos + num_imgs)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]  # (num_pos, 4)
        pos_bbox_targets = flatten_bbox_targets[pos_inds]  # (num_pos, 4)
        pos_centerness = flatten_centerness[pos_inds]  # (num_pos)
        pos_iou = flatten_iou[pos_inds]  # (num_pos)

        if num_pos > 0:
            # Computed only when positives exist, so the target helper is
            # never called on empty tensors.
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)

            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(
                pos_points, pos_bbox_preds)  # (num_pos, 4)
            pos_decoded_target_preds = distance2bbox(
                pos_points, pos_bbox_targets)  # (num_pos, 4)
            # centerness weighted iou loss
            loss_reg = (
                (iou_loss(pos_decoded_bbox_preds,
                          pos_decoded_target_preds,
                          reduction='none') * pos_centerness_targets).sum() /
                pos_centerness_targets.sum())[None]
            loss_centerness = F.binary_cross_entropy_with_logits(
                pos_centerness, pos_centerness_targets, reduction='mean')[None]

            # IoU branch: regress the overlap between each decoded prediction
            # and its decoded target.
            # NOTE(review): the target is not detached, so gradients may also
            # reach the box predictions through it -- confirm this is
            # intended.
            pos_iou_target = bbox_overlaps(pos_decoded_target_preds,
                                           pos_decoded_bbox_preds,
                                           is_aligned=True)
            loss_iou = F.binary_cross_entropy_with_logits(
                pos_iou, pos_iou_target, reduction='mean')[None]

        else:
            # Sums over empty selections: zero-valued one-element tensors.
            loss_reg = pos_bbox_preds.sum()[None]
            loss_centerness = pos_centerness.sum()[None]
            loss_iou = pos_iou.sum()[None]

        return dict(loss_cls=loss_cls,
                    loss_reg=loss_reg,
                    loss_centerness=loss_centerness,
                    loss_iou=loss_iou)