Example #1
    def aux_losses(self, gt_classes, pred_class_logits):
        pred_class_logits = cat(
            [permute_to_N_HWA_K(x, self.num_classes) for x in pred_class_logits],
            dim=1,
        ).view(-1, self.num_classes)

        gt_classes = gt_classes.flatten()

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())

        # logits loss
        loss_cls_aux = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        return {"loss_cls_aux": loss_cls_aux}
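
The helper permute_to_N_HWA_K is not shown above. A minimal sketch of what it is expected to do here (an assumption consistent with its detectron2 namesake): reshape one level's (N, A*K, H, W) prediction into (N, H*W*A, K) so all levels can be concatenated along dim 1.

    import torch

    def permute_to_N_HWA_K(tensor, K):
        # (N, A*K, H, W) -> (N, H*W*A, K)
        N, _, H, W = tensor.shape
        tensor = tensor.view(N, -1, K, H, W)    # (N, A, K, H, W)
        tensor = tensor.permute(0, 3, 4, 1, 2)  # (N, H, W, A, K)
        return tensor.reshape(N, -1, K)

    x = torch.randn(2, 9 * 80, 32, 32)      # N=2, A=9 anchors, K=80 classes
    print(permute_to_N_HWA_K(x, 80).shape)  # torch.Size([2, 9216, 80])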
Example #2
    def losses(self, gt_classes, gt_shifts_deltas, pred_class_logits,
               pred_shift_deltas):
        """
        Args:
            For `gt_classes` and `gt_shifts_deltas` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits` and `pred_shift_deltas`, see
                :meth:`FCOSHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_shift_deltas = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_foreground) * self.reg_weight

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
        }
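
sigmoid_focal_loss_jit is a TorchScript-fused focal loss. A plain sketch of the same computation, matching the signature used above (alpha balances positives vs. negatives, gamma down-weights easy examples):

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="sum"):
        # targets are one-hot with the same shape as logits
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = p * targets + (1 - p) * (1 - targets)  # prob assigned to the true label
        loss = ce * (1 - p_t) ** gamma
        if alpha >= 0:
            loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss
        return loss.sum() if reduction == "sum" else loss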
Example #3
    def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
               pred_anchor_deltas):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`RetinaNetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_anchor_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum) * max(num_foreground.item(), 1)

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / self.loss_normalizer

        # regression loss
        loss_box_reg = smooth_l1_loss(
            pred_anchor_deltas[foreground_idxs],
            gt_anchors_deltas[foreground_idxs],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        ) / self.loss_normalizer

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
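
Unlike Example #2, which normalizes by the synchronized per-iteration foreground count, this variant keeps an exponential moving average of that count. A toy numeric trace of the self.loss_normalizer update (initial value and counts are illustrative):

    m = 0.9                  # loss_normalizer_momentum
    loss_normalizer = 100.0  # initial value (illustrative)
    for num_fg in (120, 80, 95):
        loss_normalizer = m * loss_normalizer + (1 - m) * max(num_fg, 1)
        print(f"{loss_normalizer:.2f}")  # 102.00, 99.80, 99.32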
Example #4
    def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
               pred_anchor_deltas):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`EfficientDet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`EfficientDetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_anchor_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # Classification loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # Regression loss; refer to the officially released code.
        # See: https://github.com/google/automl/blob/master/efficientdet/det_model_fn.py
        loss_box_reg = self.box_loss_weight * self.smooth_l1_loss_beta * smooth_l1_loss(
            pred_anchor_deltas[foreground_idxs],
            gt_anchors_deltas[foreground_idxs],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        ) / max(1, num_foreground * self.regress_norm)

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
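
A sketch of smooth_l1_loss in the detectron2-style form assumed above, where beta is the switch point between the quadratic and linear regimes. The extra self.smooth_l1_loss_beta factor in loss_box_reg then appears to reproduce the TF/EfficientDet huber loss with delta = beta, since beta * (0.5 x^2 / beta) = 0.5 x^2 and beta * (|x| - 0.5 beta) = beta * |x| - 0.5 beta^2.

    import torch

    def smooth_l1_loss(input, target, beta, reduction="sum"):
        diff = torch.abs(input - target)
        if beta < 1e-5:
            loss = diff  # degenerates to pure L1
        else:
            loss = torch.where(diff < beta,
                               0.5 * diff ** 2 / beta,
                               diff - 0.5 * beta)
        return loss.sum() if reduction == "sum" else loss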
Example #5
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        if self.use_focal:
            flat_src_logits = src_logits.flatten(0, 1)
            target_classes = target_classes.flatten(0, 1)
            pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
            labels = torch.zeros_like(flat_src_logits)
            labels[pos_inds, target_classes[pos_inds]] = 1
            # compute focal loss.
            class_loss = sigmoid_focal_loss_jit(
                flat_src_logits,
                labels,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_boxes
            losses = {"loss_ce": class_loss}
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
            losses = {"loss_ce": loss_ce}

        if log:
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]

        return losses
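
For context, _get_src_permutation_idx (from DETR) turns the matcher output into (batch_idx, query_idx) tensors that can index the (B, Q, K) logits; a sketch consistent with its use above:

    import torch

    def get_src_permutation_idx(indices):
        # indices: per-image (src, tgt) index pairs produced by the matcher
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    indices = [(torch.tensor([0, 3]), torch.tensor([1, 0])),
               (torch.tensor([2]), torch.tensor([0]))]
    print(get_src_permutation_idx(indices))  # (tensor([0, 0, 1]), tensor([0, 3, 2]))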
Example #6
    def losses(
        self,
        gt_classes,
        gt_shifts_deltas,
        gt_centerness,
        gt_classes_border,
        gt_deltas_border,
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_box_cls,
        border_bbox_reg,
    ):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`BorderDet.get_ground_truth`.
            Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`BorderHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls", "loss_box_reg", "loss_centerness",
                "loss_border_cls" and "loss_border_reg"
        """
        (
            pred_class_logits,
            pred_shift_deltas,
            pred_centerness,
            border_class_logits,
            border_shift_deltas,
        ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            border_box_cls, border_bbox_reg, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        # fcos
        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()
        acc_centerness_num = gt_centerness[foreground_idxs].sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        dist.all_reduce(num_foreground)
        num_foreground /= dist.get_world_size()
        dist.all_reduce(acc_centerness_num)
        acc_centerness_num /= dist.get_world_size()

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1, acc_centerness_num)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        # borderdet
        gt_classes_border = gt_classes_border.flatten()
        gt_deltas_border = gt_deltas_border.view(-1, 4)

        valid_idxs_border = gt_classes_border >= 0
        foreground_idxs_border = (gt_classes_border >= 0) & \
            (gt_classes_border != self.num_classes)
        num_foreground_border = foreground_idxs_border.sum()

        gt_classes_border_target = torch.zeros_like(border_class_logits)
        gt_classes_border_target[foreground_idxs_border,
                                 gt_classes_border[foreground_idxs_border]] = 1

        dist.all_reduce(num_foreground_border)
        num_foreground_border /= dist.get_world_size()

        num_foreground_border = max(num_foreground_border, 1.0)
        loss_border_cls = sigmoid_focal_loss_jit(
            border_class_logits[valid_idxs_border],
            gt_classes_border_target[valid_idxs_border],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_foreground_border

        if foreground_idxs_border.sum() > 0:
            loss_border_reg = (
                smooth_l1_loss(border_shift_deltas[foreground_idxs_border],
                               gt_deltas_border[foreground_idxs_border],
                               beta=0,
                               reduction="sum") / num_foreground_border)
        else:
            loss_border_reg = border_shift_deltas.sum()

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
            "loss_border_cls": loss_border_cls,
            "loss_border_reg": loss_border_reg,
        }
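
The dist.all_reduce / world-size division above averages the foreground counts across workers, so every replica normalizes by the same denominator (the comm.all_reduce calls in other examples wrap the same pattern). A guarded helper sketch that degrades to a no-op in single-process runs:

    import torch
    import torch.distributed as dist

    def mean_across_workers(t: torch.Tensor) -> torch.Tensor:
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(t)            # sum over all workers, in place
            t = t / dist.get_world_size()
        return t

    print(mean_across_workers(torch.tensor(12.0)))  # tensor(12.) single-process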
Example #7
    def losses(self, ins_preds, cate_preds, ins_label_list, cate_label_list,
               ins_ind_label_list):
        """
        Compute losses:

            L = L_cate + λ * L_mask

        Args:
            ins_preds (list[Tensor]): each element in the list is mask prediction results
                of one level, and the shape of each element is [N, G*G, H, W], where:
                * N is the number of images per mini-batch
                * G is the side length of each level of the grids
                * H and W are the height and width of the predicted mask

            cate_preds (list[Tensor]): each element in the list is category prediction results
                of one level, and the shape of each element is [#N, #C, #G, #G], where:
                * C is the number of classes

            ins_label_list (list[list[Tensor]]): each element in the list is mask ground truth
                of one image, and each element is a list which contains mask tensors per level
                with shape [H, W], where:
                * H and W are the ground truth mask size per level (same as `ins_preds`)

            cate_label_list (list[list[Tensor]]): each element in the list is category ground truth
                of one image, and each element is a list which contains tensors with shape [G, G]
                per level.

            ins_ind_label_list (list[list[Tensor]]): used to indicate which grids
                contain objects; these grids need to calculate the mask loss. Each
                element in the list is the indicator of one image, and each element
                is a list which contains tensors with shape [G*G] per level.

        Returns:
            dict[str -> Tensor]: losses.
        """
        # ins, per level
        ins_preds_valid = []
        ins_labels_valid = []
        cate_labels_valid = []
        num_images = len(ins_label_list)
        num_levels = len(ins_label_list[0])
        for level_idx in range(num_levels):
            ins_preds_per_level = []
            ins_labels_per_level = []
            cate_labels_per_level = []
            for img_idx in range(num_images):
                valid_ins_inds = ins_ind_label_list[img_idx][level_idx]
                ins_preds_per_level.append(
                    ins_preds[level_idx][img_idx][valid_ins_inds, ...])
                ins_labels_per_level.append(
                    ins_label_list[img_idx][level_idx][valid_ins_inds, ...])
                cate_labels_per_level.append(
                    cate_label_list[img_idx][level_idx].flatten())
            ins_preds_valid.append(torch.cat(ins_preds_per_level))
            ins_labels_valid.append(torch.cat(ins_labels_per_level))
            cate_labels_valid.append(torch.cat(cate_labels_per_level))

        # dice loss, per_level
        loss_ins = []
        for input, target in zip(ins_preds_valid, ins_labels_valid):
            if input.size()[0] == 0:
                continue
            input = torch.sigmoid(input)
            target = target.float() / 255.
            loss_ins.append(dice_loss(input, target))
        # loss_ins (list[Tensor]): each element's shape is [#Ins]
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.loss_ins_weight

        # cate
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        cate_preds = torch.cat(cate_preds)

        flatten_cate_labels = torch.cat(cate_labels_valid)
        foreground_idxs = flatten_cate_labels != self.num_classes
        cate_labels = torch.zeros_like(cate_preds)
        cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1
        num_ins = foreground_idxs.sum()

        loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
            cate_preds,
            cate_labels,
            alpha=self.loss_cat_alpha,
            gamma=self.loss_cat_gamma,
            reduction="sum",
        ) / max(1, num_ins)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)
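
dice_loss is not shown; one common formulation consistent with its use above (a per-instance loss vector, later concatenated across levels and averaged) is:

    import torch

    def dice_loss(input, target, eps=1e-5):
        # input: sigmoid probabilities, target: {0, 1} masks, both (num_ins, H, W)
        input = input.reshape(input.size(0), -1)
        target = target.reshape(target.size(0), -1).float()
        inter = (input * target).sum(dim=1)
        denom = (input * input).sum(dim=1) + (target * target).sum(dim=1) + eps
        return 1.0 - 2.0 * inter / denom  # shape (num_ins,)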
Example #8
    def losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
               pred_class_logits, pred_shift_deltas, pred_centerness):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`FCOSHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls", "loss_box_reg" and "loss_centerness" (plus
                "loss_budget" when the dynamic head's budget loss is enabled)
        """
        pred_class_logits, pred_shift_deltas, pred_centerness = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        loss = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
        }

        # budget loss
        if self.is_dynamic_head and self.budget_loss_lambda != 0:
            soft_cost, used_cost, full_cost = get_module_running_cost(self)
            loss_budget = (soft_cost / full_cost).mean() * self.budget_loss_lambda
            storage = get_event_storage()
            storage.put_scalar("complexity_ratio", (used_cost / full_cost).mean())
            loss.update({"loss_budget": loss_budget})

        return loss
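
For reference, gt_centerness in FCOS is conventionally derived from the ground-truth (l, t, r, b) shift deltas; a sketch assuming that standard definition:

    import torch

    def centerness_targets(ltrb):
        # ltrb: (R, 4) ground-truth (left, top, right, bottom) deltas
        l, t, r, b = ltrb.unbind(dim=1)
        lr = torch.min(l, r) / torch.max(l, r).clamp(min=1e-12)
        tb = torch.min(t, b) / torch.max(t, b).clamp(min=1e-12)
        return torch.sqrt(lr * tb)  # in [0, 1], equal to 1 at the box center

    print(centerness_targets(torch.tensor([[5., 5., 5., 5.],
                                           [1., 8., 9., 2.]])))  # 1.0000, 0.1667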
Example #9
    def losses(self, indices, gt_instances, anchors, pred_class_logits,
               pred_anchor_deltas):
        pred_class_logits = cat(pred_class_logits,
                                dim=1).view(-1, self.num_classes)
        pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

        # list[Boxes(R, 4)], one for each image
        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
        N = len(anchors)
        # Tensor(N * R, 4)
        all_anchors = Boxes.cat(anchors).tensor
        predicted_boxes = self.box2box_transform.apply_deltas(
            pred_anchor_deltas, all_anchors)
        predicted_boxes = predicted_boxes.reshape(N, -1, 4)

        ious = []
        pos_ious = []
        for i in range(N):
            src_idx, tgt_idx = indices[i]
            iou, _ = box_iou(predicted_boxes[i, ...],
                             gt_instances[i].gt_boxes.tensor)
            if iou.numel() == 0:
                max_iou = iou.new_full((iou.size(0), ), 0)
            else:
                max_iou = iou.max(dim=1)[0]
            a_iou, _ = box_iou(anchors[i].tensor,
                               gt_instances[i].gt_boxes.tensor)
            if a_iou.numel() == 0:
                pos_iou = a_iou.new_full((0, ), 0)
            else:
                pos_iou = a_iou[src_idx, tgt_idx]
            ious.append(max_iou)
            pos_ious.append(pos_iou)
        ious = torch.cat(ious)
        ignore_idx = ious > self.neg_ignore_thresh
        pos_ious = torch.cat(pos_ious)
        pos_ignore_idx = pos_ious < self.pos_ignore_thresh

        src_idx = torch.cat([
            src + idx * anchors[0].tensor.shape[0]
            for idx, (src, _) in enumerate(indices)
        ])
        gt_classes = torch.full(pred_class_logits.shape[:1],
                                self.num_classes,
                                dtype=torch.int64,
                                device=pred_class_logits.device)
        gt_classes[ignore_idx] = -1
        target_classes_o = torch.cat(
            [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
        target_classes_o[pos_ignore_idx] = -1
        gt_classes[src_idx] = target_classes_o

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        if comm.get_world_size() > 1:
            dist.all_reduce(num_foreground)
        num_foreground = num_foreground * 1.0 / comm.get_world_size()

        # cls loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        # reg loss
        target_boxes = torch.cat(
            [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
            dim=0)
        target_boxes = target_boxes[~pos_ignore_idx]
        matched_predicted_boxes = predicted_boxes.reshape(
            -1, 4)[src_idx[~pos_ignore_idx]]
        loss_box_reg = (1 - torch.diag(
            generalized_box_iou(matched_predicted_boxes, target_boxes))).sum()

        return {
            "loss_cls": loss_cls / max(1, num_foreground),
            "loss_box_reg": loss_box_reg / max(1, num_foreground),
        }
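
The regression term keeps only matched (prediction, target) pairs via the diagonal of the pairwise GIoU matrix. A toy demonstration, using torchvision's generalized_box_iou as a stand-in for the helper used above:

    import torch
    from torchvision.ops import generalized_box_iou

    pred = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
    gt = torch.tensor([[1., 1., 11., 11.], [5., 5., 15., 15.]])
    # the pairwise matrix is (num_pred, num_gt); the diagonal aligns pair i with pair i
    loss_box_reg = (1 - torch.diag(generalized_box_iou(pred, gt))).sum()
    print(loss_box_reg)  # GIoU loss summed over the two matched pairs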
Example #10
    def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine,
               targets):
        """
        Args:
            center_pts (list[list[Tensor]]): a list of N=#image elements. Each
                is a list of #feature level tensors. Each tensor contains the
                shifts of this image on the corresponding feature level.
            cls_outs (list[Tensor]): each item in the list has
                shape [N, num_classes, H, W]
            pts_outs_init (list[Tensor]): each item in the list has
                shape [N, num_points*2, H, W]
            pts_outs_refine (list[Tensor]): each item in the list has
                shape [N, num_points*2, H, W]
            targets (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.
                Specify `targets` during training only.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
        assert len(featmap_sizes) == len(center_pts[0])

        pts_dim = 2 * self.num_points

        cls_outs = [
            cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1,
                                                self.num_classes)
            for cls_out in cls_outs
        ]
        pts_outs_init = [
            pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1,
                                                     pts_dim)
            for pts_out_init in pts_outs_init
        ]
        pts_outs_refine = [
            pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0),
                                                       -1, pts_dim)
            for pts_out_refine in pts_outs_refine
        ]

        cls_outs = torch.cat(cls_outs, dim=1)
        pts_outs_init = torch.cat(pts_outs_init, dim=1)
        pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

        pts_strides = []
        for i, s in enumerate(center_pts[0]):
            pts_strides.append(
                cls_outs.new_full((s.size(0), ), self.fpn_strides[i]))
        pts_strides = torch.cat(pts_strides, dim=0)

        center_pts = [
            torch.cat(c_pts, dim=0).to(self.device) for c_pts in center_pts
        ]

        pred_cls = []
        pred_init = []
        pred_refine = []

        target_cls = []
        target_init = []
        target_refine = []

        num_pos_init = 0
        num_pos_refine = 0

        for img, (per_center_pts, cls_prob, pts_init, pts_refine,
                  per_targets) in enumerate(
                      zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                          targets)):
            assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]

            gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
            gt_labels = per_targets.gt_classes.to(cls_prob.device)

            pts_init_bbox_targets, pts_init_labels_targets = \
                self.point_targets(per_center_pts,
                                   pts_strides,
                                   gt_bboxes.tensor,
                                   gt_labels)

            # per_center_pts, shape:[N, 18]
            per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)

            normalize_term = self.point_base_scale * pts_strides
            normalize_term = normalize_term.reshape(-1, 1)

            # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)
            per_pts_strides = pts_strides.reshape(-1, 1)
            pts_init_coordinate = pts_init * per_pts_strides + \
                per_center_pts_repeat
            init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

            foreground_idxs = (pts_init_labels_targets >= 0) & \
                (pts_init_labels_targets != self.num_classes)

            pred_init.append(init_bbox_pred[foreground_idxs] /
                             normalize_term[foreground_idxs])
            target_init.append(pts_init_bbox_targets[foreground_idxs] /
                               normalize_term[foreground_idxs])
            num_pos_init += foreground_idxs.sum()

            # Another way to convert the predicted offsets to a bbox
            # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * \
            #     per_pts_strides
            # init_bbox_pred = bbox_center + bbox_pred_init

            pts_refine_bbox_targets, pts_refine_labels_targets = \
                self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)

            pts_refine_coordinate = pts_refine * per_pts_strides + \
                per_center_pts_repeat
            refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)

            # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
            # refine_bbox_pred = bbox_center + bbox_pred_refine

            foreground_idxs = (pts_refine_labels_targets >= 0) & \
                (pts_refine_labels_targets != self.num_classes)

            pred_refine.append(refine_bbox_pred[foreground_idxs] /
                               normalize_term[foreground_idxs])
            target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                                 normalize_term[foreground_idxs])
            num_pos_refine += foreground_idxs.sum()

            gt_classes_target = torch.zeros_like(cls_prob)
            gt_classes_target[foreground_idxs,
                              pts_refine_labels_targets[foreground_idxs]] = 1
            pred_cls.append(cls_prob)
            target_cls.append(gt_classes_target)

        pred_cls = torch.cat(pred_cls, dim=0)
        pred_init = torch.cat(pred_init, dim=0)
        pred_refine = torch.cat(pred_refine, dim=0)

        target_cls = torch.cat(target_cls, dim=0)
        target_init = torch.cat(target_init, dim=0)
        target_refine = torch.cat(target_refine, dim=0)

        loss_cls = sigmoid_focal_loss_jit(
            pred_cls,
            target_cls,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum") / max(
                1, num_pos_refine.item()) * self.loss_cls_weight

        loss_pts_init = smooth_l1_loss(
            pred_init, target_init, beta=0.11, reduction='sum') / max(
                1, num_pos_init.item()) * self.loss_bbox_init_weight

        loss_pts_refine = smooth_l1_loss(
            pred_refine, target_refine, beta=0.11, reduction='sum') / max(
                1, num_pos_refine.item()) * self.loss_bbox_refine_weight

        return {
            "loss_cls": loss_cls,
            "loss_pts_init": loss_pts_init,
            "loss_pts_refine": loss_pts_refine
        }
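
pts_to_bbox converts a learned point set into a box. The min/max ("minmax") reduction is one of RepPoints' standard choices; a hypothetical sketch assuming interleaved (x, y) coordinates:

    import torch

    def pts_to_bbox(pts):
        # pts: (R, 2 * num_points), assumed laid out as (x1, y1, x2, y2, ...)
        xs, ys = pts[:, 0::2], pts[:, 1::2]
        return torch.stack([xs.min(1).values, ys.min(1).values,
                            xs.max(1).values, ys.max(1).values], dim=1)

    pts = torch.tensor([[1., 2., 4., 0., 3., 5.]])  # three points of one sample
    print(pts_to_bbox(pts))  # tensor([[1., 0., 4., 5.]])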
Example #11
    def losses(self, ins_preds_x, ins_preds_y, cate_preds, ins_label_list,
               cate_label_list, ins_ind_label_list, ins_ind_label_list_xy):
        # ins, per level
        ins_labels = []  # per level
        for ins_labels_level, ins_ind_labels_level in \
                zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
            ins_labels_per_level = []
            for ins_labels_level_img, ins_ind_labels_level_img in \
                    zip(ins_labels_level, ins_ind_labels_level):
                ins_labels_per_level.append(
                    ins_labels_level_img[ins_ind_labels_level_img, ...])
            ins_labels.append(torch.cat(ins_labels_per_level))

        ins_preds_x_final = []
        for ins_preds_level_x, ins_ind_labels_level in \
                zip(ins_preds_x, zip(*ins_ind_label_list_xy)):
            ins_preds_x_final_per_level = []
            for ins_preds_level_img_x, ins_ind_labels_level_img in \
                    zip(ins_preds_level_x, ins_ind_labels_level):
                ins_preds_x_final_per_level.append(
                    ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...])
            ins_preds_x_final.append(torch.cat(ins_preds_x_final_per_level))

        ins_preds_y_final = []
        for ins_preds_level_y, ins_ind_labels_level in \
                zip(ins_preds_y, zip(*ins_ind_label_list_xy)):
            ins_preds_y_final_per_level = []
            for ins_preds_level_img_y, ins_ind_labels_level_img in \
                    zip(ins_preds_level_y, ins_ind_labels_level):
                ins_preds_y_final_per_level.append(
                    ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...])
            ins_preds_y_final.append(torch.cat(ins_preds_y_final_per_level))

        num_ins = 0.
        # dice loss, per_level
        loss_ins = []
        for input_x, input_y, target in zip(ins_preds_x_final,
                                            ins_preds_y_final, ins_labels):
            mask_n = input_x.size(0)
            if mask_n == 0:
                continue
            num_ins += mask_n
            input = (input_x.sigmoid()) * (input_y.sigmoid())
            target = target.float() / 255.
            loss_ins.append(dice_loss(input, target))

        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.loss_ins_weight

        # cate
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        cate_preds = torch.cat(cate_preds)

        cate_labels = []
        for cate_labels_level in zip(*cate_label_list):
            cate_labels_per_level = []
            for cate_labels_level_img in cate_labels_level:
                cate_labels_per_level.append(cate_labels_level_img.flatten())
            cate_labels.append(torch.cat(cate_labels_per_level))
        flatten_cate_labels = torch.cat(cate_labels)
        foreground_idxs = flatten_cate_labels != self.num_classes
        cate_labels = torch.zeros_like(cate_preds)
        cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1

        loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
            cate_preds,
            cate_labels,
            alpha=self.loss_cat_alpha,
            gamma=self.loss_cat_gamma,
            reduction="sum",
        ) / max(1, num_ins)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)
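
The decoupled head predicts separate per-column (x) and per-row (y) mask branches; the final mask probability is their elementwise product after sigmoid, exactly as in the dice-loss loop above:

    import torch

    mask_x = torch.randn(3, 16, 16)             # masks gathered by grid x index
    mask_y = torch.randn(3, 16, 16)             # masks gathered by grid y index
    mask = mask_x.sigmoid() * mask_y.sigmoid()  # joint mask probability
    print(mask.shape)                           # torch.Size([3, 16, 16])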
Example #12
    def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                         gt_inds, im_inds, pred_class_logits,
                         pred_shift_deltas, pred_centerness, pred_inst_params,
                         fpn_levels, shifts):

        pred_class_logits, pred_shift_deltas, pred_centerness, pred_inst_params = \
            permute_all_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                pred_inst_params, self.num_gen_params, self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
        gt_centerness = gt_centerness.reshape(-1, 1)
        fpn_levels = fpn_levels.flatten()
        im_inds = im_inds.flatten()
        gt_inds = gt_inds.flatten()

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(
            comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(pred_shift_deltas[foreground_idxs],
                                gt_shifts_deltas[foreground_idxs],
                                gt_centerness[foreground_idxs],
                                box_mode="ltrb",
                                loss_type=self.iou_loss_type,
                                reduction="sum",
                                smooth=self.iou_smooth) / max(
                                    1e-6, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        proposals_losses = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }

        all_shifts = torch.cat([torch.cat(shift) for shift in shifts])
        proposals = Instances((0, 0))
        proposals.inst_parmas = pred_inst_params[foreground_idxs]
        proposals.fpn_levels = fpn_levels[foreground_idxs]
        proposals.shifts = all_shifts[foreground_idxs]
        proposals.gt_inds = gt_inds[foreground_idxs]
        proposals.im_inds = im_inds[foreground_idxs]

        # select_instances to save memory
        if len(proposals):
            if self.topk_proposals_per_im != -1:
                proposals.gt_cls = gt_classes[foreground_idxs]
                proposals.pred_logits = pred_class_logits[foreground_idxs]
                proposals.pred_centerness = pred_centerness[foreground_idxs]
            proposals = self.select_instances(proposals)

        return proposals_losses, proposals
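
select_instances is not shown; the surrounding code suggests it caps memory by keeping at most topk_proposals_per_im foreground proposals per image. A hypothetical sketch of that idea (the ranking score is an assumption):

    import torch

    def select_topk_per_image(scores, im_inds, topk):
        # scores: (P,) per-proposal ranking score; im_inds: (P,) image index
        keep = []
        for im in im_inds.unique():
            idx = (im_inds == im).nonzero(as_tuple=True)[0]
            k = min(topk, idx.numel())
            keep.append(idx[scores[idx].topk(k).indices])
        return torch.cat(keep) if keep else im_inds.new_empty(0)

    scores = torch.tensor([0.9, 0.1, 0.8, 0.7])
    im_inds = torch.tensor([0, 0, 0, 1])
    print(select_topk_per_image(scores, im_inds, topk=2))  # tensor([0, 2, 3])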
Example #13
    def get_lla_assignments_and_losses(self, shifts, targets, box_cls,
                                       box_delta, box_iou):

        box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
        box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

        box_cls = torch.cat(box_cls, dim=1)
        box_delta = torch.cat(box_delta, dim=1)
        box_iou = torch.cat(box_iou, dim=1)

        losses_cls = []
        losses_box_reg = []
        losses_iou = []

        num_fg = 0

        for shifts_per_image, targets_per_image, box_cls_per_image, \
                box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):

            shifts_over_all = torch.cat(shifts_per_image, dim=0)

            gt_boxes = targets_per_image.gt_boxes
            gt_classes = targets_per_image.gt_classes

            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = deltas.min(dim=-1).values > 0.01

            shape = (len(targets_per_image), len(shifts_over_all), -1)
            box_cls_per_image_unexpanded = box_cls_per_image
            box_delta_per_image_unexpanded = box_delta_per_image

            box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
            gt_cls_per_image = F.one_hot(
                torch.max(gt_classes, torch.zeros_like(gt_classes)),
                self.num_classes).float().unsqueeze(1).expand(shape)

            with torch.no_grad():
                loss_cls = sigmoid_focal_loss_jit(
                    box_cls_per_image,
                    gt_cls_per_image,
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma).sum(dim=-1)
                loss_cls_bg = sigmoid_focal_loss_jit(
                    box_cls_per_image_unexpanded,
                    torch.zeros_like(box_cls_per_image_unexpanded),
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma).sum(dim=-1)
                box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(
                    shape)
                gt_delta_per_image = self.shift2box_transform.get_deltas(
                    shifts_over_all, gt_boxes.tensor.unsqueeze(1))
                loss_delta = iou_loss(box_delta_per_image,
                                      gt_delta_per_image,
                                      box_mode="ltrb",
                                      loss_type='iou')

                ious = get_ious(box_delta_per_image,
                                gt_delta_per_image,
                                box_mode="ltrb",
                                loss_type='iou')

                loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (
                    1 - is_in_boxes.float())
                loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

                num_gt = loss.shape[0] - 1
                num_anchor = loss.shape[1]

                # Topk
                matching_matrix = torch.zeros_like(loss)
                _, topk_idx = torch.topk(loss[:-1],
                                         k=self.topk,
                                         dim=1,
                                         largest=False)
                matching_matrix[torch.arange(num_gt).unsqueeze(1).
                                repeat(1, self.topk).view(-1),
                                topk_idx.view(-1)] = 1.

                # make sure each anchor is matched to at most one gt
                anchor_matched_gt = matching_matrix.sum(0)
                if (anchor_matched_gt > 1).sum() > 0:
                    loss_min, loss_argmin = torch.min(
                        loss[:-1, anchor_matched_gt > 1], dim=0)
                    matching_matrix[:, anchor_matched_gt > 1] *= 0.
                    matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                    anchor_matched_gt = matching_matrix.sum(0)
                num_fg += matching_matrix.sum()
                matching_matrix[
                    -1] = 1. - anchor_matched_gt  # assignment for Background
                assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

                gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                    (gt_cls_per_image.size(1),
                     gt_cls_per_image.size(2))).unsqueeze(0)
                gt_cls_per_image_with_bg = torch.cat(
                    [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
                cls_target_per_image = gt_cls_per_image_with_bg[
                    assigned_gt_inds,
                    torch.arange(num_anchor)]

                # Deal with the CrowdHuman ignore label
                gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
                anchor_cls_labels = gt_classes_[assigned_gt_inds]
                valid_flag = anchor_cls_labels >= 0

                pos_mask = assigned_gt_inds != len(
                    targets_per_image)  # get foreground mask
                valid_fg = pos_mask & valid_flag
                assigned_fg_inds = assigned_gt_inds[valid_fg]
                range_fg = torch.arange(num_anchor)[valid_fg]
                ious_fg = ious[assigned_fg_inds, range_fg]

            anchor_loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded[valid_flag],
                cls_target_per_image[valid_flag],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)

            delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
            anchor_loss_delta = 2. * iou_loss(
                box_delta_per_image_unexpanded[valid_fg],
                delta_target,
                box_mode="ltrb",
                loss_type=self.iou_loss_type)

            anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
                box_iou_per_image.squeeze(1)[valid_fg],
                ious_fg,
                reduction='none')

            losses_cls.append(anchor_loss_cls.sum())
            losses_box_reg.append(anchor_loss_delta.sum())
            losses_iou.append(anchor_loss_iou.sum())

        if self.norm_sync:
            dist.all_reduce(num_fg)
            num_fg = num_fg.float() / dist.get_world_size()

        return {
            'loss_cls': torch.stack(losses_cls).sum() / num_fg,
            'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
            'loss_iou': torch.stack(losses_iou).sum() / num_fg
        }
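
A toy trace of the top-k cost assignment used above: each gt claims its k cheapest anchors, and any anchor claimed by several gts is re-assigned to the gt with the lowest cost:

    import torch

    cost = torch.tensor([[0.1, 0.5, 0.9, 0.20],
                         [0.3, 0.1, 0.8, 0.15]])  # (num_gt, num_anchor)
    k = 2
    match = torch.zeros_like(cost)
    topk_idx = cost.topk(k, dim=1, largest=False).indices
    match[torch.arange(cost.size(0)).unsqueeze(1).expand(-1, k).reshape(-1),
          topk_idx.reshape(-1)] = 1.0
    multi = match.sum(0) > 1  # anchors claimed more than once
    if multi.any():
        winner = cost[:, multi].argmin(dim=0)
        match[:, multi] = 0.0
        match[winner, multi.nonzero(as_tuple=True)[0]] = 1.0
    print(match)  # [[1, 0, 0, 0], [0, 1, 0, 1]]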
Example #14
    def get_ground_truth(self, shifts, targets, box_cls, box_delta):
        """
        Args:
            shifts (list[list[Tensor]]): a list of N=#image elements. Each is a
                list of #feature level tensors. Each tensor contains the shifts
                of this image on the corresponding feature level.
            targets (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.  Specify `targets` during training only.

        Returns:
            gt_classes (Tensor):
                An integer tensor of shape (N, R) storing ground-truth
                labels for each shift.
                R is the total number of shifts, i.e. the sum of Hi x Wi for all levels.
                Shifts in the valid boxes are assigned their corresponding label in the
                [0, K-1] range. Shifts in the background are assigned the label "K".
                Shifts in the ignore areas are assigned a label "-1", i.e. ignore.
            gt_shifts_deltas (Tensor):
                Shape (N, R, 4).
                The last dimension represents ground-truth shift2box transform
                targets (dl, dt, dr, db) that map each shift to its matched ground-truth box.
                The values in the tensor are meaningful only when the corresponding
                shift is labeled as foreground.
        """
        gt_classes = []
        gt_shifts_deltas = []

        box_cls = torch.cat(
            [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls], dim=1)
        box_delta = torch.cat([permute_to_N_HWA_K(x, 4) for x in box_delta],
                              dim=1)

        num_fg = 0
        num_gt = 0

        for shifts_per_image, targets_per_image, box_cls_per_image, box_delta_per_image in zip(
                shifts, targets, box_cls, box_delta):
            shifts_over_all_feature_maps = torch.cat(shifts_per_image, dim=0)

            gt_boxes = targets_per_image.gt_boxes

            shape = (len(targets_per_image), len(shifts_over_all_feature_maps),
                     -1)

            gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                         self.num_classes).float()
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image.unsqueeze(0).expand(shape),
                gt_cls_per_image.unsqueeze(1).expand(shape),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
            ).sum(dim=2)

            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))
            loss_delta = iou_loss(
                box_delta_per_image.unsqueeze(0).expand(shape),
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type=self.iou_loss_type,
            ) * self.reg_weight

            loss = loss_cls + loss_delta

            INF = 1e8

            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

            if self.center_sampling_radius > 0:
                centers = gt_boxes.get_centers()
                is_in_boxes = []
                for stride, shifts_i in zip(self.fpn_strides,
                                            shifts_per_image):
                    radius = stride * self.center_sampling_radius
                    center_boxes = torch.cat((
                        torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                        torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                    ), dim=-1)
                    center_deltas = self.shift2box_transform.get_deltas(
                        shifts_i, center_boxes.unsqueeze(1))
                    is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
                is_in_boxes = torch.cat(is_in_boxes, dim=1)
            else:
                # without center sampling, all locations inside a ground-truth box are used
                is_in_boxes = deltas.min(dim=-1).values > 0

            loss[~is_in_boxes] = INF

            gt_idxs, shift_idxs = linear_sum_assignment(loss.cpu().numpy())

            num_fg += len(shift_idxs)
            num_gt += len(targets_per_image)

            gt_classes_i = shifts_over_all_feature_maps.new_full(
                (len(shifts_over_all_feature_maps), ),
                self.num_classes,
                dtype=torch.long)
            gt_shifts_reg_deltas_i = shifts_over_all_feature_maps.new_zeros(
                len(shifts_over_all_feature_maps), 4)
            if len(targets_per_image) > 0:
                # ground truth classes
                gt_classes_i[shift_idxs] = targets_per_image.gt_classes[
                    gt_idxs]
                # ground truth box regression
                gt_shifts_reg_deltas_i[
                    shift_idxs] = self.shift2box_transform.get_deltas(
                        shifts_over_all_feature_maps[shift_idxs],
                        gt_boxes[gt_idxs].tensor)

            gt_classes.append(gt_classes_i)
            gt_shifts_deltas.append(gt_shifts_reg_deltas_i)

        get_event_storage().put_scalar("num_fg_per_gt", num_fg / num_gt)

        return torch.stack(gt_classes), torch.stack(gt_shifts_deltas)
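
linear_sum_assignment (SciPy's Hungarian solver) is what enforces the one-to-one matching here: each ground truth is assigned exactly one shift so that the summed cost is minimal. A toy run:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.2, 1.0, 0.5],
                     [0.9, 0.1, 0.8]])  # (num_gt, num_shifts)
    gt_idxs, shift_idxs = linear_sum_assignment(cost)
    print(gt_idxs, shift_idxs)  # [0 1] [0 1], total cost 0.3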
Example #15
    def get_ground_truth(self, shifts, targets, box_cls, box_delta, box_iou):

        gt_classes = []
        gt_shifts_deltas = []
        gt_ious = []
        assigned_units = []

        box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
        box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

        box_cls = torch.cat(box_cls, dim=1)
        box_delta = torch.cat(box_delta, dim=1)
        box_iou = torch.cat(box_iou, dim=1)

        for shifts_per_image, targets_per_image, box_cls_per_image, \
                box_delta_per_image, box_iou_per_image in zip(
                    shifts, targets, box_cls, box_delta, box_iou):

            shifts_over_all = torch.cat(shifts_per_image, dim=0)

            gt_boxes = targets_per_image.gt_boxes

            # Inside the gt box and close to its center.
            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = deltas.min(dim=-1).values > 0.01

            center_sampling_radius = 2.5
            centers = gt_boxes.get_centers()
            is_in_centers = []
            for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
                radius = stride * center_sampling_radius
                center_boxes = torch.cat((
                    torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                    torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                ), dim=-1)
                center_deltas = self.shift2box_transform.get_deltas(
                    shifts_i, center_boxes.unsqueeze(1))
                is_in_centers.append(center_deltas.min(dim=-1).values > 0)
            is_in_centers = torch.cat(is_in_centers, dim=1)
            del centers, center_boxes, deltas, center_deltas
            is_in_boxes = (is_in_boxes & is_in_centers)

            num_gt = len(targets_per_image)
            num_anchor = len(shifts_over_all)
            shape = (num_gt, num_anchor, -1)

            gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                         self.num_classes).float()

            with torch.no_grad():
                loss_cls = sigmoid_focal_loss_jit(
                    box_cls_per_image.unsqueeze(0).expand(shape),
                    gt_cls_per_image.unsqueeze(1).expand(shape),
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma,
                ).sum(dim=-1)

                loss_cls_bg = sigmoid_focal_loss_jit(
                    box_cls_per_image,
                    torch.zeros_like(box_cls_per_image),
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma,
                ).sum(dim=-1)

                gt_delta_per_image = self.shift2box_transform.get_deltas(
                    shifts_over_all, gt_boxes.tensor.unsqueeze(1))

                ious, loss_delta = get_ious_and_iou_loss(
                    box_delta_per_image.unsqueeze(0).expand(shape),
                    gt_delta_per_image,
                    box_mode="ltrb",
                    loss_type='iou')

                loss = loss_cls + self.reg_weight * loss_delta + 1e6 * (
                    1 - is_in_boxes.float())

                # Performing Dynamic k Estimation
                topk_ious, _ = torch.topk(ious * is_in_boxes.float(),
                                          self.top_candidates,
                                          dim=1)
                mu = ious.new_ones(num_gt + 1)
                mu[:-1] = torch.clamp(topk_ious.sum(1).int(), min=1).float()
                mu[-1] = num_anchor - mu[:-1].sum()
                nu = ious.new_ones(num_anchor)
                loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

                # Solve for the optimal transportation plan pi via Sinkhorn iteration.
                _, pi = self.sinkhorn(mu, nu, loss)

                # Rescale pi so that the max pi for each gt equals 1.
                rescale_factor, _ = pi.max(dim=1)
                pi = pi / rescale_factor.unsqueeze(1)

                max_assigned_units, matched_gt_inds = torch.max(pi, dim=0)
                gt_classes_i = targets_per_image.gt_classes.new_ones(
                    num_anchor) * self.num_classes
                fg_mask = matched_gt_inds != num_gt
                gt_classes_i[fg_mask] = targets_per_image.gt_classes[
                    matched_gt_inds[fg_mask]]
                gt_classes.append(gt_classes_i)
                assigned_units.append(max_assigned_units)

                box_target_per_image = gt_delta_per_image.new_zeros(
                    (num_anchor, 4))
                box_target_per_image[fg_mask] = \
                    gt_delta_per_image[matched_gt_inds[fg_mask], torch.arange(num_anchor)[fg_mask]]
                gt_shifts_deltas.append(box_target_per_image)

                gt_ious_per_image = ious.new_zeros((num_anchor, 1))
                gt_ious_per_image[fg_mask] = ious[
                    matched_gt_inds[fg_mask],
                    torch.arange(num_anchor)[fg_mask]].unsqueeze(1)
                gt_ious.append(gt_ious_per_image)

        return torch.cat(gt_classes), torch.cat(gt_shifts_deltas), torch.cat(
            gt_ious)
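
self.sinkhorn solves an entropy-regularized optimal transport between supplies mu (the dynamic k per gt, plus one background row) and unit demands nu (one per anchor). A minimal sketch using standard Sinkhorn scaling iterations (eps, the iteration count, and returning only the plan are illustrative assumptions):

    import torch

    def sinkhorn(mu, nu, cost, eps=0.1, iters=50):
        K = torch.exp(-cost / eps)  # (num_gt + 1, num_anchor)
        u, v = torch.ones_like(mu), torch.ones_like(nu)
        for _ in range(iters):
            u = mu / (K @ v).clamp(min=1e-8)
            v = nu / (K.t() @ u).clamp(min=1e-8)
        return u.unsqueeze(1) * K * v.unsqueeze(0)  # transport plan pi

    pi = sinkhorn(torch.tensor([2., 1., 5.]), torch.ones(8), torch.rand(3, 8))
    print(pi.sum(dim=1))  # ~ tensor([2., 1., 5.]), i.e. the supplies mu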