class RetinaAnchorFreeLoss(object):
    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=64,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        """
        :param cls_predicts:
        :param box_predicts:
        :param anchors:
        :param targets:
        :return:
        """
        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        cls_num = cls_predicts[0].shape[-1]
        expand_anchor = torch.cat(anchors, dim=0)

        negative_loss_list = list()
        positive_loss_list = list()

        for bi in range(bs):
            batch_cls_predicts = torch.cat([cls_item[bi] for cls_item in cls_predicts], dim=0) \
                .sigmoid() \
                .clamp(min=1e-6, max=1 - 1e-6)
            batch_targets = targets[targets[:, 0] == bi, 1:]

            if len(batch_targets) == 0:
                negative_loss = -(1 - self.alpha) * (
                    batch_cls_predicts**
                    self.gamma) * (1 - batch_cls_predicts).log()
                negative_loss_list.append(negative_loss.sum())
                continue

            batch_box_predicts = torch.cat(
                [box_item[bi] for box_item in box_predicts], dim=0)
            # calc_positive_loss
            targets_anchor_iou = box_iou(batch_targets[:, 2:], expand_anchor)
            _, top_k_anchor_idx = targets_anchor_iou.topk(k=self.top_k,
                                                          dim=1,
                                                          sorted=False)

            matched_cls_prob = batch_cls_predicts[top_k_anchor_idx].gather(
                dim=-1,
                index=(batch_targets[:, [1]][:, None, :]).long().repeat(
                    1, self.top_k, 1)).squeeze(-1)
            match_box_target = self.box_coder.encoder(
                expand_anchor[top_k_anchor_idx], batch_targets[:, None, 2:])
            matched_box_prob = (
                -self.box_reg_weight *
                smooth_l1_loss(batch_box_predicts[top_k_anchor_idx],
                               match_box_target, self.beta).sum(-1)).exp()
            positive_loss = self.alpha * mean_max(
                matched_cls_prob * matched_box_prob).sum()
            positive_loss_list.append(positive_loss)

            with torch.no_grad():
                box_localization = self.box_coder.decoder(
                    batch_box_predicts, expand_anchor)
            target_box_iou = box_iou(batch_targets[:, 2:], box_localization)
            t1 = self.box_iou_thresh
            t2 = target_box_iou.max(dim=1,
                                    keepdim=True)[0].clamp(min=t1 + 1e-6)
            target_box_prob = ((target_box_iou - t1) / (t2 - t1)).clamp(min=0.,
                                                                        max=1.)
            indices = torch.stack([
                torch.arange(len(batch_targets), device=device),
                batch_targets[:, 1]
            ],
                                  dim=0).long()
            object_cls_box_prob = torch.sparse_coo_tensor(indices,
                                                          target_box_prob,
                                                          device=device)

            cls_idx, anchor_idx = torch.sparse.sum(
                object_cls_box_prob,
                dim=0).to_dense().nonzero(as_tuple=False).t()
            if len(cls_idx) == 0:
                negative_loss = -(1 - self.alpha) * (
                    batch_cls_predicts**
                    self.gamma) * (1 - batch_cls_predicts).log()
                negative_loss_list.append(negative_loss.sum())
                continue
            anchor_positive_max_prob = torch.where(
                batch_targets[:, [1]].long() == cls_idx,
                target_box_prob[:, anchor_idx],
                torch.tensor(data=0., device=device)).max(dim=0)[0]

            anchor_cls_assign_prob = torch.zeros(size=(len(expand_anchor),
                                                       cls_num),
                                                 device=device)
            anchor_cls_assign_prob[anchor_idx,
                                   cls_idx] = anchor_positive_max_prob
            negative_prob = batch_cls_predicts * (1 - anchor_cls_assign_prob)
            negative_loss = -(1 - self.alpha) * (negative_prob**self.gamma) * (
                1 - negative_prob).log()
            negative_loss_list.append(negative_loss.sum())

        negative_losses = torch.stack(negative_loss_list).sum() / max(
            1, len(targets))
        if len(positive_loss_list) == 0:
            total_loss = negative_losses
            return total_loss, torch.stack(
                [negative_losses,
                 torch.tensor(data=0., device=device)]), len(targets)

        positive_losses = torch.stack(positive_loss_list).sum() / max(
            1, len(targets))
        total_loss = negative_losses + positive_losses
        return total_loss, torch.stack([negative_losses,
                                        positive_losses]), len(targets)
Beispiel #2
0
class RetinaAnchorFreeLoss(object):
    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=50,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta

        self.positive_bag_loss_func = positive_bag_loss
        self.negative_bag_loss_func = focal_loss
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        '''

        :param cls_predicts:
        :param box_predicts:
        :param anchors:
        :param targets:
        :return:
        '''

        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        cls_num = cls_predicts[0].shape[-1]
        expand_anchor = torch.cat(anchors, dim=0)  #shape=[num_anchors,4]

        positive_numels = 0  # gt_box的数量
        box_prob = list()  # store P_A+,  P_A-=1-P_A+
        positive_loss_list = list()
        negative_loss_list = list()
        cls_probs = list()

        for bi in range(bs):
            cls_prob = torch.cat(
                [cls_item[bi] for cls_item in cls_predicts],
                dim=0).sigmoid().clamp(
                    min=1e-6,
                    max=1 - 1e-6)  # cls_predict, shape=[num_anchors,80]
            target = targets[
                targets[:, 0] == bi,
                1:]  # gt_box, shape=[num_gts,6]  6==>conf_score,label_id,x1,y1,x2,y2

            # if no gt_box exist, just calc focal loss in negative condition
            if len(target) == 0:
                #                 negative_loss = -(cls_prob ** self.gamma) * (1 - cls_prob).log()
                negative_loss = -(cls_prob**self.gamma) * (
                    (1 - cls_prob).clamp(
                        min=1e-10, max=1.0 - 1e-10).log().clamp(min=-1000.,
                                                                max=1000.))
                negative_loss_list.append(negative_loss.sum())
                continue

            cls_probs.append(cls_prob)
            box_regression = torch.cat(
                [box_item[bi] for box_item in box_predicts],
                dim=0)  # box_predict , shape=[num_anchors,4]

            with torch.set_grad_enabled(False):

                # box_localization: a_{j}^{loc}, shape: [j, 4]
                box_localization = self.box_coder.decoder(
                    box_regression,
                    expand_anchor)  # shape=[num_anchors,4]  4==>x1,y1,x2,y2

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = box_iou(
                    target[:, 2:],
                    box_localization)  # shape=(num_gts,num_anchors)

                t1 = self.box_iou_thresh
                t2 = object_box_iou.max(dim=1, keepdim=True)[0].clamp(
                    min=t1 + 1e-12)  # shape=[num_gts,1]

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1.)
                '''
                indices.shape=[2,num_gts]
                第0行元素代表所对应的gt_box的索引, 第1行元素代表所对应的gt_box所属的类别
                '''
                indices = torch.stack(
                    [torch.arange(len(target), device=device), target[:, 1]],
                    dim=0).long()

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                '''
                object_cls_box_prob.shape=[num_gts, max_cls_id+1, num_anchors] 按照类别的取值填充
                note: 如果索引为gt_id的gt_box所属的类别为label_id, 则object_cls_box_prob[gt_id,label_id]=target_box_prob[gt_id], 其他位置均为0
                '''
                object_cls_box_prob = torch.sparse_coo_tensor(indices,
                                                              object_box_prob,
                                                              device=device)
                """
                image_box_prob: P{a_{j} \in A_{+}}, shape: [j, c] or [num_anchors,num_cls]
                image_box_prob是用来判断一个anchor是否可以匹配到某个目标(无论类别和匹配到gt box是什么)的置信度

                from "start" to "end" implement:
                image_box_prob = torch.sparse.max(object_cls_box_prob, dim=0).t()
                """
                # start

                # indices = torch.nonzero(torch.sparse.sum(object_cls_box_prob, dim=0).to_dense()).t_()  # shape=[2,N]
                indices = torch.sparse.sum(
                    object_cls_box_prob, dim=0).to_dense().nonzero(
                        as_tuple=False).t()  # shape=[2,N]

                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        expand_anchor.shape[0],
                        cls_num).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        target[:, 1].unsqueeze(dim=-1) ==
                        indices[0],  # (num_gts,1)== (N) ===>(num_gts,N)
                        object_box_prob[:, indices[1]],
                        torch.tensor([
                            0
                        ]).type_as(object_box_prob)).max(dim=0)[0]  # ===> (N)

                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(expand_anchor.shape[0],
                              cls_num),  # shape=[num_anchors,num_cls]
                        device=device).to_dense()
                # end
                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = box_iou(target[:, 2:], expand_anchor)
            _, matched = torch.topk(
                match_quality_matrix, self.top_k, dim=1, sorted=False
            )  # shape=(num_gts,top_k)   元素的取值范围[0,num_gts) 表示匹配到某个gt的anchor集合的索引
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            # shape=(num_gts,top_k) 元素的取值范围[0,num_cls) 表示匹配到某个gt的anchor所属的类别
            matched_cls_prob = cls_prob[matched].gather(
                dim=-1,
                index=(target[:,
                              [1]][:,
                                   None, :]).long().repeat(1, self.top_k,
                                                           1)).squeeze(-1)

            # matched_box_prob: P_{ij}^{loc}
            matched_object_targets = self.box_coder.encoder(
                expand_anchor[matched],
                target[:, 2:].unsqueeze(dim=1))  # shape=[num_gts,topk,4]
            # P_loc
            retinanet_regression_loss = smooth_l1_loss(box_regression[matched],
                                                       matched_object_targets,
                                                       self.box_reg_weight,
                                                       self.beta)
            matched_box_prob = torch.exp(-retinanet_regression_loss)

            # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
            positive_numels += len(target)
            positive_loss_list.append(
                self.positive_bag_loss_func(matched_cls_prob *
                                            matched_box_prob,
                                            dim=1))

        # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
        # positive_loss = torch.cat(positive_loss_list).sum() / max(1, positive_numels)
        item1 = torch.cat(positive_loss_list).sum()
        item2 = max(1, positive_numels)
        positive_loss = reduce_sum(item1) / reduce_sum(
            torch.tensor(data=item2, device=device).float()).item()

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)
        cls_probs = torch.stack(cls_probs, dim=0)

        # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
        '''
        (1-P_bg)<==>P_cls   shape=[num_anchors,num_cls]
        P{A-}<==>(1-P{box_cls})
        '''
        if len(negative_loss_list) != 0:
            neg_loss_empty = torch.stack(negative_loss_list, dim=0).sum()
        else:
            neg_loss_empty = 0

        # negative_loss = (neg_loss_empty + self.negative_bag_loss_func(cls_probs * (1 - box_prob), self.gamma)) / max(1, positive_numels * self.top_k)
        item3 = neg_loss_empty + self.negative_bag_loss_func(
            cls_probs * (1 - box_prob), self.gamma)
        item4 = max(1, positive_numels * self.top_k)
        negative_loss = reduce_sum(item3) / reduce_sum(
            torch.tensor(data=item4, device=device).float()).item()

        total_loss = positive_loss * self.alpha + negative_loss * (1 -
                                                                   self.alpha)
        # total_loss=reduce_sum(total_loss)/get_world_size()
        return total_loss, torch.stack([negative_loss,
                                        positive_loss]), positive_numels