Пример #1
0
class RetinaHead(nn.Module):
    """RetinaNet-style detection head with cached per-level anchors.

    The head applies one classification tower and one regression tower to
    every feature level (the regression output of level ``j`` additionally
    goes through ``self.scales[j]``) and keeps a list of anchors, one tensor
    per level, in ``self.anchors``.

    ``Scale``, ``BoxCoder``, ``RetinaClsHead``, ``RetinaRegHead`` and the
    ``default_*`` values are defined outside this block.
    """

    def __init__(self,
                 in_channel,
                 inner_channel,
                 num_cls=80,
                 num_convs=4,
                 layer_num=5,
                 anchor_sizes=None,
                 anchor_scales=None,
                 anchor_ratios=None,
                 strides=None):
        """
        :param in_channel: forwarded to the cls / reg towers
        :param inner_channel: forwarded to the cls / reg towers
        :param num_cls: number of classes, forwarded to the cls tower
        :param num_convs: forwarded to the cls / reg towers
        :param layer_num: number of feature levels the head is applied to
        :param anchor_sizes: base anchor size per level (``default_anchor_sizes`` if None)
        :param anchor_scales: size multipliers per location (``default_anchor_scales`` if None)
        :param anchor_ratios: aspect ratios per location (``default_anchor_ratios`` if None)
        :param strides: feature stride per level (``default_strides`` if None)
        """
        super(RetinaHead, self).__init__()
        self.num_cls = num_cls
        self.layer_num = layer_num
        if anchor_sizes is None:
            anchor_sizes = default_anchor_sizes
        self.anchor_sizes = anchor_sizes
        if anchor_scales is None:
            anchor_scales = default_anchor_scales
        self.anchor_scales = anchor_scales
        if anchor_ratios is None:
            anchor_ratios = default_anchor_ratios
        self.anchor_ratios = anchor_ratios
        if strides is None:
            strides = default_strides
        self.strides = strides
        # anchors per spatial location
        self.anchor_nums = len(self.anchor_scales) * len(self.anchor_ratios)
        self.scales = nn.ModuleList(
            [Scale(init_val=1.0) for _ in range(self.layer_num)])
        # Empty placeholders; real anchors are built lazily in forward().
        self.anchors = [
            torch.zeros(size=(0, 4)) for _ in range(self.layer_num)
        ]
        # (spatial shape of every level, device) the cached anchors were built for.
        self._anchor_key = None
        self.box_coder = BoxCoder()
        self.cls_head = RetinaClsHead(in_channel, inner_channel,
                                      self.anchor_nums, num_cls, num_convs)
        self.reg_head = RetinaRegHead(in_channel, inner_channel,
                                      self.anchor_nums, num_convs)

    def build_anchors_delta(self, size=32.):
        """
        Build the corner offsets of all anchors placed at one location.

        For every (scale, ratio) pair the half extents are
        ``w = scale * size * sqrt(ratio) / 2`` and
        ``h = scale * size / sqrt(ratio) / 2``.

        :param size: base anchor size of the level
        :return: [anchor_num, 4] offsets (-w, -h, w, h) relative to the anchor centre
        """
        scales = torch.tensor(self.anchor_scales).float()
        ratio = torch.tensor(self.anchor_ratios).float()
        scale_size = (scales * size)
        w = (scale_size[:, None] * ratio[None, :].sqrt()).view(-1) / 2
        h = (scale_size[:, None] / ratio[None, :].sqrt()).view(-1) / 2
        delta = torch.stack([-w, -h, w, h], dim=1)
        return delta

    def build_anchors(self, feature_maps):
        """
        Build the anchors of every level from the feature-map sizes.

        Anchor centres sit at ``(index + 0.5) * stride``; the tensors are
        created on the default device.

        :param feature_maps: one 4-D tensor per level; only the last two dims (ny, nx) are read
        :return: list(anchor) anchor:[all,4] (x1,y1,x2,y2)
        """
        assert self.layer_num == len(feature_maps)
        assert len(self.anchor_sizes) == len(feature_maps)
        assert len(self.anchor_sizes) == len(self.strides)
        anchors = list()
        for stride, size, feature_map in zip(self.strides, self.anchor_sizes,
                                             feature_maps):
            # [anchor_num, 4]
            anchor_delta = self.build_anchors_delta(size)
            _, _, ny, nx = feature_map.shape
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
            # [ny, nx, 4]
            grid = torch.stack([xv, yv, xv, yv], 2).float()
            anchor = (grid[:, :, None, :] +
                      0.5) * stride + anchor_delta[None, None, :, :]
            anchor = anchor.view(-1, 4)
            anchors.append(anchor)
        return anchors

    def forward(self, xs):
        """
        Run both towers on every level and refresh the anchor cache if needed.

        :param xs: sequence of feature maps, one 4-D tensor per level
        :return: ``(cls_outputs, reg_outputs, anchors)`` in training mode,
                 otherwise a list with ``cat([decoded_boxes, cls_out], dim=-1)`` per level
        """
        cls_outputs = list()
        reg_outputs = list()
        for j, x in enumerate(xs):
            cls_outputs.append(self.cls_head(x))
            reg_outputs.append(self.scales[j](self.reg_head(x)))

        # Anchors depend on the spatial size of *every* level, and the cached
        # tensors live on one device.  Comparing only the level-0 anchor count
        # misses inputs with the same H*W but a different layout (e.g. 2x3 vs
        # 3x2) as well as device changes, so key the cache on both.
        anchor_key = (tuple(tuple(x.shape[2:]) for x in xs), xs[0].device)
        # getattr: instances restored without running __init__ have no key yet.
        stale = getattr(self, "_anchor_key", None) != anchor_key
        # Original count-based check kept as a fallback (e.g. anchors reset by hand).
        stale = stale or self.anchors[0] is None or self.anchors[
            0].shape[0] != cls_outputs[0].shape[1]
        if stale:
            with torch.no_grad():
                anchors = self.build_anchors(xs)
                assert len(anchors) == len(self.anchors)
                for i, anchor in enumerate(anchors):
                    self.anchors[i] = anchor.to(xs[0].device)
            self._anchor_key = anchor_key

        if self.training:
            return cls_outputs, reg_outputs, self.anchors
        else:
            predicts_list = list()
            for cls_out, reg_out, anchor in zip(cls_outputs, reg_outputs,
                                                self.anchors):
                scale_reg = self.box_coder.decoder(reg_out, anchor)
                predicts_out = torch.cat([scale_reg, cls_out], dim=-1)
                predicts_list.append(predicts_out)
            return predicts_list
Пример #2
0
class RetinaLoss(object):
    """Focal classification loss plus IoU regression loss for anchor-based RetinaNet.

    Anchor / ground-truth assignment is delegated to ``RetinaLossBuilder``,
    the box regression loss to ``IOULoss`` and box decoding to ``BoxCoder``;
    all three are defined outside this block.
    """

    def __init__(self, iou_thresh=0.5, ignore_thresh=0.4, alpha=0.25, gamma=2.0,
                 iou_type="giou",
                 coord_type="xyxy"):
        """
        :param iou_thresh: forwarded to ``RetinaLossBuilder``
        :param ignore_thresh: forwarded to ``RetinaLossBuilder``
        :param alpha: focal-loss weight of positive entries (negatives use ``1 - alpha``)
        :param gamma: focal-loss focusing exponent
        :param iou_type: forwarded to ``IOULoss``
        :param coord_type: forwarded to ``IOULoss``
        """
        self.iou_thresh = iou_thresh
        self.ignore_thresh = ignore_thresh
        self.alpha = alpha
        # NOTE: attribute name is misspelled; kept so existing readers keep working.
        self.gama = gamma
        self.builder = RetinaLossBuilder(iou_thresh, ignore_thresh)
        self.iou_loss = IOULoss(iou_type, coord_type)
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, reg_predicts, anchors, targets):
        """
        :param cls_predicts: list(cls_predict) cls_predict[bs,all,num_cls]
        :param reg_predicts: list(reg_predict) reg_predict[bs,all,4]
        :param anchors: list(anchor) anchor[all,4]
        :param targets: [gt_num,7] (batch_id,weights,label_id,x1,y1,x2,y2)
        :return: ``(total_loss, detached tensor [cls_loss, reg_loss], pos_num_sum)``;
                 losses are divided by the number of positive anchors, or by the
                 batch size when the whole batch has no positive anchor
        """
        # Promote fp16 logits to fp32 in a *new* list: the caller's sequence is
        # left untouched and tuples are accepted as well.
        cls_predicts = [
            cls_item.float() if cls_item.dtype == torch.float16 else cls_item
            for cls_item in cls_predicts
        ]
        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        # Per image: flags over all anchors and matched gt rows; indexed below as
        # gt[:, 1] = label and gt[:, 2:] = box.
        flags, gt_targets, all_anchors = self.builder(bs, anchors, targets)
        cls_loss_list = list()
        reg_loss_list = list()

        pos_num_sum = 0
        for bi in range(bs):
            batch_cls_predict = torch.cat([cls_item[bi] for cls_item in cls_predicts], dim=0) \
                .sigmoid() \
                .clamp(1e-6, 1 - 1e-6)
            batch_reg_predict = torch.cat([reg_item[bi] for reg_item in reg_predicts], dim=0)
            flag = flags[bi]
            gt = gt_targets[bi]
            pos_idx = (flag == 1).nonzero(as_tuple=False).squeeze(1)
            pos_num = len(pos_idx)
            if pos_num == 0:
                # No positive anchor: every anchor of the image is treated as background.
                neg_cls_loss = -(1 - self.alpha) * batch_cls_predict ** self.gama * ((1 - batch_cls_predict).log())
                cls_loss_list.append(neg_cls_loss.sum())
                continue
            pos_num_sum += pos_num
            # Anchors whose flag is neither 1 nor 0 take no part in the classification loss.
            neg_idx = (flag == 0).nonzero(as_tuple=False).squeeze(1)
            valid_idx = torch.cat([pos_idx, neg_idx])
            valid_cls_predicts = batch_cls_predict[valid_idx, :]
            cls_targets = torch.zeros(size=valid_cls_predicts.shape, device=device)
            # Positives occupy the first pos_num rows because of the cat order above.
            cls_targets[range(pos_num), gt[pos_idx, 1].long()] = 1.
            pos_loss = -self.alpha * cls_targets * ((1 - valid_cls_predicts) ** self.gama) * valid_cls_predicts.log()
            neg_loss = -(1 - self.alpha) * (1. - cls_targets) * (valid_cls_predicts ** self.gama) * (
                (1 - valid_cls_predicts).log())
            cls_loss = (pos_loss + neg_loss).sum()
            cls_loss_list.append(cls_loss)

            valid_reg_predicts = batch_reg_predict[pos_idx, :]
            predict_box = self.box_coder.decoder(valid_reg_predicts, all_anchors[pos_idx])
            gt_bbox = gt[pos_idx, 2:]
            reg_loss = self.iou_loss(predict_box, gt_bbox)
            reg_loss_list.append(reg_loss.sum())

        cls_loss_sum = torch.stack(cls_loss_list).sum()
        if pos_num_sum == 0:
            total_loss = cls_loss_sum / bs
            # Report the value that is actually optimised, not the unnormalised sum.
            return total_loss, torch.stack([total_loss, torch.tensor(data=0., device=device)]).detach(), pos_num_sum
        reg_loss_sum = torch.stack(reg_loss_list).sum()

        cls_loss_mean = cls_loss_sum / pos_num_sum
        reg_loss_mean = reg_loss_sum / pos_num_sum
        total_loss = cls_loss_mean + reg_loss_mean

        return total_loss, torch.stack([cls_loss_mean, reg_loss_mean]).detach(), pos_num_sum
class RetinaAnchorFreeLoss(object):
    """Bag-based ("free anchor" style) detection loss.

    For every ground-truth box the ``top_k`` anchors with the highest IoU form
    a bag whose classification and localisation likelihoods feed the positive
    term; all anchors feed a focal-style negative term that is damped where a
    decoded prediction overlaps a ground-truth box of the same class.

    ``BoxCoder``, ``box_iou``, ``smooth_l1_loss`` and ``mean_max`` are defined
    outside this block.
    """

    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=64,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        """
        :param gamma: focusing exponent of the negative term
        :param alpha: weight of the positive term (the negative term uses ``1 - alpha``)
        :param top_k: number of anchors in each ground-truth bag
        :param box_iou_thresh: IoU below which a decoded box gets zero assignment probability
        :param box_reg_weight: multiplier of the regression loss inside ``exp(-...)``
        :param beta: passed to ``smooth_l1_loss``
        """
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta
        self.box_coder = BoxCoder()

    def _negative_focal_sum(self, prob):
        """Return sum of ``-(1 - alpha) * prob**gamma * log(1 - prob)`` over all entries."""
        focal = -(1 - self.alpha) * (prob**self.gamma) * (1 - prob).log()
        return focal.sum()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        """
        :param cls_predicts: list of per-level class logits, indexed as [bs, anchors, num_cls]
        :param box_predicts: list of per-level box regressions, indexed as [bs, anchors, 4]
        :param anchors: list of per-level anchors, concatenated along dim 0
        :param targets: rows whose column 0 is the batch index, column 2 the
                        label and columns 3: the box
        :return: ``(total_loss, tensor [negative_loss, positive_loss], len(targets))``
        """
        first_level = cls_predicts[0]
        device = first_level.device
        batch_size = first_level.shape[0]
        num_classes = first_level.shape[-1]
        all_anchors = torch.cat(anchors, dim=0)

        neg_terms = []
        pos_terms = []

        for image_idx in range(batch_size):
            logits = torch.cat([level[image_idx] for level in cls_predicts],
                               dim=0)
            cls_prob = logits.sigmoid().clamp(min=1e-6, max=1 - 1e-6)
            # Drop the batch-index column: label is column 1, box is columns 2:.
            image_gt = targets[targets[:, 0] == image_idx, 1:]

            # Image without ground truth: pure background loss.
            if len(image_gt) == 0:
                neg_terms.append(self._negative_focal_sum(cls_prob))
                continue

            box_pred = torch.cat([level[image_idx] for level in box_predicts],
                                 dim=0)
            gt_boxes = image_gt[:, 2:]
            gt_labels = image_gt[:, [1]]

            # ---- positive term: one bag of top_k anchors per ground-truth box
            anchor_iou = box_iou(gt_boxes, all_anchors)
            _, bag_idx = anchor_iou.topk(k=self.top_k, dim=1, sorted=False)

            label_index = gt_labels[:, None, :].long().repeat(
                1, self.top_k, 1)
            bag_cls_prob = cls_prob[bag_idx].gather(
                dim=-1, index=label_index).squeeze(-1)

            bag_box_target = self.box_coder.encoder(all_anchors[bag_idx],
                                                    image_gt[:, None, 2:])
            bag_reg_loss = smooth_l1_loss(box_pred[bag_idx], bag_box_target,
                                          self.beta).sum(-1)
            bag_box_prob = (-self.box_reg_weight * bag_reg_loss).exp()

            pos_terms.append(
                self.alpha * mean_max(bag_cls_prob * bag_box_prob).sum())

            # ---- negative term: damp anchors whose decoded box hits a gt box
            with torch.no_grad():
                decoded = self.box_coder.decoder(box_pred, all_anchors)
            pred_iou = box_iou(gt_boxes, decoded)
            low = self.box_iou_thresh
            high = pred_iou.max(dim=1, keepdim=True)[0].clamp(min=low + 1e-6)
            assign_prob = ((pred_iou - low) / (high - low)).clamp(min=0.,
                                                                  max=1.)

            # Sparse tensor indexed by (gt index, class id) holding per-anchor values.
            sparse_index = torch.stack(
                [torch.arange(len(image_gt), device=device), image_gt[:, 1]],
                dim=0).long()
            sparse_prob = torch.sparse_coo_tensor(sparse_index,
                                                  assign_prob,
                                                  device=device)
            hits = torch.sparse.sum(
                sparse_prob, dim=0).to_dense().nonzero(as_tuple=False).t()
            cls_idx, anchor_idx = hits

            # No decoded box overlaps any gt box enough: nothing to damp.
            if len(cls_idx) == 0:
                neg_terms.append(self._negative_focal_sum(cls_prob))
                continue

            # For each (class, anchor) hit keep the max over gt boxes of that class.
            per_anchor_max = torch.where(
                gt_labels.long() == cls_idx, assign_prob[:, anchor_idx],
                torch.tensor(data=0., device=device)).max(dim=0)[0]

            soft_assign = torch.zeros(size=(len(all_anchors), num_classes),
                                      device=device)
            soft_assign[anchor_idx, cls_idx] = per_anchor_max
            neg_terms.append(
                self._negative_focal_sum(cls_prob * (1 - soft_assign)))

        negative_losses = torch.stack(neg_terms).sum() / max(1, len(targets))
        if len(pos_terms) == 0:
            no_positive = torch.tensor(data=0., device=device)
            return negative_losses, torch.stack(
                [negative_losses, no_positive]), len(targets)

        positive_losses = torch.stack(pos_terms).sum() / max(1, len(targets))
        summed = negative_losses + positive_losses
        return summed, torch.stack([negative_losses,
                                    positive_losses]), len(targets)
Пример #4
0
class RetinaAnchorFreeLoss(object):
    """Bag-based ("free anchor" style) detection loss with reduced normalisers.

    Every ground-truth box owns a bag of ``top_k`` anchors (highest IoU); the
    positive term is computed by ``positive_bag_loss`` on the product of the
    bag's classification and localisation likelihoods, the negative term by
    ``focal_loss`` on ``cls_prob * (1 - box_prob)``.

    ``BoxCoder``, ``box_iou``, ``smooth_l1_loss``, ``positive_bag_loss``,
    ``focal_loss`` and ``reduce_sum`` are defined outside this block.
    """

    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=50,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        """
        :param gamma: passed to the negative bag loss
        :param alpha: weight of the positive loss (the negative loss uses ``1 - alpha``)
        :param top_k: number of anchors in each ground-truth bag
        :param box_iou_thresh: IoU below which a decoded box gets zero assignment probability
        :param box_reg_weight: passed to ``smooth_l1_loss``
        :param beta: passed to ``smooth_l1_loss``
        """
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta

        self.positive_bag_loss_func = positive_bag_loss
        self.negative_bag_loss_func = focal_loss
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        """
        :param cls_predicts: list of per-level class logits, indexed as [bs, anchors, num_cls]
        :param box_predicts: list of per-level box regressions, indexed as [bs, anchors, 4]
        :param anchors: list of per-level anchors, concatenated along dim 0
        :param targets: [gt_num, 7] (batch_id, conf_score, label_id, x1, y1, x2, y2)
        :return: ``(total_loss, tensor [negative_loss, positive_loss], positive_numels)``
                 where ``positive_numels`` is the number of gt boxes in the batch
        """

        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        cls_num = cls_predicts[0].shape[-1]
        expand_anchor = torch.cat(anchors, dim=0)  # shape=[num_anchors,4]

        positive_numels = 0  # number of gt boxes seen so far
        box_prob = list()  # store P_A+,  P_A-=1-P_A+
        positive_loss_list = list()
        negative_loss_list = list()
        cls_probs = list()

        for bi in range(bs):
            cls_prob = torch.cat(
                [cls_item[bi] for cls_item in cls_predicts],
                dim=0).sigmoid().clamp(
                    min=1e-6,
                    max=1 - 1e-6)  # cls_predict, shape=[num_anchors,num_cls]
            target = targets[
                targets[:, 0] == bi,
                1:]  # gt_box, shape=[num_gts,6]  6==>conf_score,label_id,x1,y1,x2,y2

            # if no gt_box exist, just calc focal loss in negative condition
            if len(target) == 0:
                negative_loss = -(cls_prob**self.gamma) * (
                    (1 - cls_prob).clamp(
                        min=1e-10, max=1.0 - 1e-10).log().clamp(min=-1000.,
                                                                max=1000.))
                negative_loss_list.append(negative_loss.sum())
                continue

            cls_probs.append(cls_prob)
            box_regression = torch.cat(
                [box_item[bi] for box_item in box_predicts],
                dim=0)  # box_predict , shape=[num_anchors,4]

            with torch.set_grad_enabled(False):

                # box_localization: a_{j}^{loc}, shape: [j, 4]
                box_localization = self.box_coder.decoder(
                    box_regression,
                    expand_anchor)  # shape=[num_anchors,4]  4==>x1,y1,x2,y2

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = box_iou(
                    target[:, 2:],
                    box_localization)  # shape=(num_gts,num_anchors)

                t1 = self.box_iou_thresh
                t2 = object_box_iou.max(dim=1, keepdim=True)[0].clamp(
                    min=t1 + 1e-12)  # shape=[num_gts,1]

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1.)
                # indices.shape=[2,num_gts]
                # row 0 holds the index of each gt box, row 1 holds the class
                # id that gt box belongs to
                indices = torch.stack(
                    [torch.arange(len(target), device=device), target[:, 1]],
                    dim=0).long()

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                # object_cls_box_prob.shape=[num_gts, max_cls_id+1, num_anchors],
                # filled according to the class ids.
                # note: if the gt box with index gt_id belongs to class label_id,
                # then object_cls_box_prob[gt_id,label_id]=object_box_prob[gt_id];
                # every other position is 0
                object_cls_box_prob = torch.sparse_coo_tensor(indices,
                                                              object_box_prob,
                                                              device=device)
                # image_box_prob: P{a_{j} in A_{+}}, shape: [j, c] or [num_anchors,num_cls]
                # image_box_prob is the confidence that an anchor can be matched
                # to some object (regardless of which gt box it matches)
                #
                # from "start" to "end" implement:
                # image_box_prob = torch.sparse.max(object_cls_box_prob, dim=0).t()
                # start

                indices = torch.sparse.sum(
                    object_cls_box_prob, dim=0).to_dense().nonzero(
                        as_tuple=False).t()  # shape=[2,N]

                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        expand_anchor.shape[0],
                        cls_num).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        target[:, 1].unsqueeze(dim=-1) ==
                        indices[0],  # (num_gts,1)== (N) ===>(num_gts,N)
                        object_box_prob[:, indices[1]],
                        torch.tensor([
                            0
                        ]).type_as(object_box_prob)).max(dim=0)[0]  # ===> (N)

                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(expand_anchor.shape[0],
                              cls_num),  # shape=[num_anchors,num_cls]
                        device=device).to_dense()
                # end
                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = box_iou(target[:, 2:], expand_anchor)
            # matched: shape=(num_gts,top_k), the indices (along the anchor
            # dim) of the anchors that form each gt box's bag
            _, matched = torch.topk(match_quality_matrix,
                                    self.top_k,
                                    dim=1,
                                    sorted=False)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            # shape=(num_gts,top_k): predicted probability of the gt box's own
            # class at every anchor of its bag
            matched_cls_prob = cls_prob[matched].gather(
                dim=-1,
                index=(target[:,
                              [1]][:,
                                   None, :]).long().repeat(1, self.top_k,
                                                           1)).squeeze(-1)

            # matched_box_prob: P_{ij}^{loc}
            matched_object_targets = self.box_coder.encoder(
                expand_anchor[matched],
                target[:, 2:].unsqueeze(dim=1))  # shape=[num_gts,topk,4]
            # P_loc
            retinanet_regression_loss = smooth_l1_loss(box_regression[matched],
                                                       matched_object_targets,
                                                       self.box_reg_weight,
                                                       self.beta)
            matched_box_prob = torch.exp(-retinanet_regression_loss)

            # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
            positive_numels += len(target)
            positive_loss_list.append(
                self.positive_bag_loss_func(matched_cls_prob *
                                            matched_box_prob,
                                            dim=1))

        # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
        if positive_loss_list:
            item1 = torch.cat(positive_loss_list).sum()
        else:
            # No image in the batch has a gt box: torch.cat would raise on the
            # empty list, so contribute an explicit zero instead.
            item1 = cls_predicts[0].new_zeros(())
        item2 = max(1, positive_numels)
        # reduce_sum is still called on this path so the call sequence is the
        # same as on the normal path.
        # NOTE(review): reduce_sum looks like a cross-process reduction; confirm.
        positive_loss = reduce_sum(item1) / reduce_sum(
            torch.tensor(data=item2, device=device).float()).item()

        # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
        # (1-P_bg)<==>P_cls   shape=[num_anchors,num_cls]
        # P{A-}<==>(1-P{box_cls})
        if len(negative_loss_list) != 0:
            neg_loss_empty = torch.stack(negative_loss_list, dim=0).sum()
        else:
            neg_loss_empty = 0

        if cls_probs:
            # box_prob: P{a_{j} in A_{+}}
            box_prob = torch.stack(box_prob, dim=0)
            cls_probs = torch.stack(cls_probs, dim=0)
            bag_negative = self.negative_bag_loss_func(
                cls_probs * (1 - box_prob), self.gamma)
        else:
            # Same empty-batch case as above: torch.stack would raise.
            bag_negative = cls_predicts[0].new_zeros(())

        item3 = neg_loss_empty + bag_negative
        item4 = max(1, positive_numels * self.top_k)
        negative_loss = reduce_sum(item3) / reduce_sum(
            torch.tensor(data=item4, device=device).float()).item()

        total_loss = positive_loss * self.alpha + negative_loss * (1 -
                                                                   self.alpha)
        return total_loss, torch.stack([negative_loss,
                                        positive_loss]), positive_numels