Example #1
    def smooth_l1_loss(self):
        """
        Compute the smooth L1 loss for box regression.

        Returns:
            scalar Tensor
        """
        if self._no_instances:
            return 0.0 * self.pred_proposal_deltas.sum()

        gt_proposal_deltas = self.box2box_transform.get_deltas(
            self.proposals.tensor, self.gt_boxes.tensor)
        box_dim = gt_proposal_deltas.size(1)  # 4 or 5
        cls_agnostic_bbox_reg = self.pred_proposal_deltas.size(1) == box_dim
        device = self.pred_proposal_deltas.device

        bg_class_ind = self.pred_class_logits.shape[1] - 1

        # Box delta loss is only computed between the prediction for the gt class k
        # (if 0 <= k < bg_class_ind) and the target; there is no loss defined on predictions
        # for non-gt classes and background.
        # Empty fg_inds produces a valid loss of zero because we pass
        # reduction="sum" to smooth_l1_loss; with reduction="mean" it would
        # average over zero elements and produce a NaN loss.
        fg_inds = torch.nonzero(
            (self.gt_classes >= 0) & (self.gt_classes < bg_class_ind),
            as_tuple=False).squeeze(1)
        if cls_agnostic_bbox_reg:
            # With class-agnostic regression, pred_proposal_deltas predicts a
            # single box per proposal, shared by all foreground classes.
            gt_class_cols = torch.arange(box_dim, device=device)
        else:
            fg_gt_classes = self.gt_classes[fg_inds]
            # pred_proposal_deltas for class k are located in columns [b * k : b * k + b],
            # where b is the dimension of box representation (4 or 5)
            # Note that compared to Detectron1,
            # we do not perform bounding box regression for background classes.
            gt_class_cols = box_dim * fg_gt_classes[:, None] + torch.arange(
                box_dim, device=device)

        loss_box_reg = smooth_l1_loss(
            self.pred_proposal_deltas[fg_inds[:, None], gt_class_cols],
            gt_proposal_deltas[fg_inds],
            self.smooth_l1_beta,
            reduction="sum",
        )
        # The loss is normalized using the total number of regions (R), not the number
        # of foreground regions even though the box regression loss is only defined on
        # foreground regions. Why? Because doing so gives equal training influence to
        # each foreground example. To see how, consider two different minibatches:
        #  (1) Contains a single foreground region
        #  (2) Contains 100 foreground regions
        # If we normalize by the number of foreground regions, the single example in
        # minibatch (1) will be given 100 times as much influence as each foreground
        # example in minibatch (2). Normalizing by the total number of regions, R,
        # means that the single example in minibatch (1) and each of the 100 examples
        # in minibatch (2) are given equal influence.
        loss_box_reg = loss_box_reg / self.gt_classes.numel()
        return loss_box_reg
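For reference, the `smooth_l1_loss` helper called throughout these examples appears to follow the fvcore/Detectron2 definition: quadratic for |x| < beta, linear beyond, and degenerating to pure L1 when beta == 0. A minimal sketch, assuming that signature:

import torch

def smooth_l1_loss(input, target, beta: float, reduction: str = "none"):
    # 0.5 * x**2 / beta if |x| < beta, else |x| - 0.5 * beta.
    if beta < 1e-5:
        # Avoid a divide-by-zero: with beta == 0 this is exactly L1.
        loss = torch.abs(input - target)
    else:
        n = torch.abs(input - target)
        loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss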
Example #2
    def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
               pred_anchor_deltas):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`RetinaNetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_anchor_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum) * max(num_foreground.item(), 1)

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / self.loss_normalizer

        # regression loss
        loss_box_reg = smooth_l1_loss(
            pred_anchor_deltas[foreground_idxs],
            gt_anchors_deltas[foreground_idxs],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        ) / self.loss_normalizer

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
Example #3
    def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
               pred_anchor_deltas):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`EfficientDet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`EfficientDetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_anchor_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # Classification loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # Regression loss, refer to the official released code.
        # See: https://github.com/google/automl/blob/master/efficientdet/det_model_fn.py
        loss_box_reg = self.box_loss_weight * self.smooth_l1_loss_beta * smooth_l1_loss(
            pred_anchor_deltas[foreground_idxs],
            gt_anchors_deltas[foreground_idxs],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        ) / max(1, num_foreground * self.regress_norm)

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
Example #4
File: net.py  Project: zzzhoudj/BorderDet
    def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
               pred_anchor_deltas):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`RetinaNetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_anchor_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

        pos_inds = torch.nonzero(
            (gt_classes >= 0) & (gt_classes != self.num_classes),
            as_tuple=False).squeeze(1)

        retinanet_regression_loss = smooth_l1_loss(
            pred_anchor_deltas[pos_inds],
            gt_anchors_deltas[pos_inds],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        ) / max(1, pos_inds.numel() * self.regress_norm)

        # Shift foreground labels from 0-based [0, num_classes) to 1-based
        # [1, num_classes]; background (== num_classes) becomes 0 and ignored
        # anchors keep their -1 label.
        labels = torch.ones_like(gt_classes)
        labels[pos_inds] += gt_classes[pos_inds]
        labels[gt_classes == -1] = -1
        labels[gt_classes == self.num_classes] = 0
        labels = labels.int()

        retinanet_cls_loss = self.box_cls_loss_func(pred_class_logits, labels)

        return {
            "loss_cls": retinanet_cls_loss,
            "loss_box_reg": retinanet_regression_loss
        }
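The label remapping is easy to get wrong, so here is a toy trace of the four cases (values assume num_classes == 80, as the shift comment suggests):

import torch

gt = torch.tensor([3, 79, 80, -1])   # two foreground, one background (== 80), one ignore
labels = torch.ones_like(gt)
pos = (gt >= 0) & (gt != 80)
labels[pos] += gt[pos]
labels[gt == -1] = -1
labels[gt == 80] = 0
print(labels)  # tensor([ 4, 80,  0, -1])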
Example #5
def rpn_losses(
    gt_objectness_logits,
    gt_anchor_deltas,
    pred_objectness_logits,
    pred_anchor_deltas,
    smooth_l1_beta,
):
    """
    Args:
        gt_objectness_logits (Tensor): shape (N,), each element in {-1, 0, 1} representing
            ground-truth objectness labels with: -1 = ignore; 0 = not object; 1 = object.
        gt_anchor_deltas (Tensor): shape (N, box_dim), row i represents ground-truth
            box2box transform targets (dx, dy, dw, dh) or (dx, dy, dw, dh, da) that map anchor i to
            its matched ground-truth box.
        pred_objectness_logits (Tensor): shape (N,), each element is a predicted objectness
            logit.
        pred_anchor_deltas (Tensor): shape (N, box_dim), each row is a predicted box2box
            transform (dx, dy, dw, dh) or (dx, dy, dw, dh, da)
        smooth_l1_beta (float): The transition point between L1 and L2 loss in
            the smooth L1 loss function. When set to 0, the loss becomes L1. When
            set to +inf, the loss becomes constant 0.

    Returns:
        objectness_loss, localization_loss, both unnormalized (summed over samples).
    """
    pos_masks = gt_objectness_logits == 1
    localization_loss = smooth_l1_loss(pred_anchor_deltas[pos_masks],
                                       gt_anchor_deltas[pos_masks],
                                       smooth_l1_beta,
                                       reduction="sum")

    valid_masks = gt_objectness_logits >= 0
    objectness_loss = F.binary_cross_entropy_with_logits(
        pred_objectness_logits[valid_masks],
        gt_objectness_logits[valid_masks].to(torch.float32),
        reduction="sum",
    )
    return objectness_loss, localization_loss
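A minimal smoke test for `rpn_losses` with random tensors, assuming the snippet's own imports (`torch`, `torch.nn.functional as F`) and a `smooth_l1_loss` like the sketch after Example #1 are in scope; shapes follow the docstring:

import torch

N, box_dim = 12, 4
gt_obj = torch.randint(-1, 2, (N,))        # objectness labels in {-1, 0, 1}
gt_deltas = torch.randn(N, box_dim)
pred_obj = torch.randn(N)
pred_deltas = torch.randn(N, box_dim)

obj_loss, loc_loss = rpn_losses(gt_obj, gt_deltas, pred_obj, pred_deltas, smooth_l1_beta=0.0)
# Both are unnormalized sums; callers typically divide by the number of
# sampled anchors across the batch.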
Example #6
    def losses(
        self,
        gt_classes,
        gt_shifts_deltas,
        gt_centerness,
        gt_classes_border,
        gt_deltas_border,
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_box_cls,
        border_bbox_reg,
    ):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`BorderDet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`BorderHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        (
            pred_class_logits,
            pred_shift_deltas,
            pred_centerness,
            border_class_logits,
            border_shift_deltas,
        ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            border_box_cls, border_bbox_reg, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        # fcos
        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        # Cast to float so the in-place division after all_reduce is valid.
        num_foreground = foreground_idxs.sum().float()
        acc_centerness_num = gt_centerness[foreground_idxs].sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        dist.all_reduce(num_foreground)
        num_foreground /= dist.get_world_size()
        dist.all_reduce(acc_centerness_num)
        acc_centerness_num /= dist.get_world_size()

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1, acc_centerness_num)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        # borderdet
        gt_classes_border = gt_classes_border.flatten()
        gt_deltas_border = gt_deltas_border.view(-1, 4)

        valid_idxs_border = gt_classes_border >= 0
        foreground_idxs_border = (gt_classes_border >= 0) & (
            gt_classes_border != self.num_classes)
        num_foreground_border = foreground_idxs_border.sum().float()

        gt_classes_border_target = torch.zeros_like(border_class_logits)
        gt_classes_border_target[foreground_idxs_border,
                                 gt_classes_border[foreground_idxs_border]] = 1

        dist.all_reduce(num_foreground_border)
        num_foreground_border /= dist.get_world_size()

        num_foreground_border = max(num_foreground_border, 1.0)
        loss_border_cls = sigmoid_focal_loss_jit(
            border_class_logits[valid_idxs_border],
            gt_classes_border_target[valid_idxs_border],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_foreground_border

        if foreground_idxs_border.numel() > 0:
            loss_border_reg = (
                smooth_l1_loss(border_shift_deltas[foreground_idxs_border],
                               gt_deltas_border[foreground_idxs_border],
                               beta=0,
                               reduction="sum") / num_foreground_border)
        else:
            loss_border_reg = border_shift_deltas.sum()

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
            "loss_border_cls": loss_border_cls,
            "loss_border_reg": loss_border_reg,
        }
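The all_reduce / divide-by-world-size pair appears twice above: it averages the local normalizers across GPUs so every rank divides its summed loss by the same global statistic, matching what a single process would compute over the combined batch. A small helper capturing the pattern (a sketch assuming `torch.distributed` is initialized):

import torch.distributed as dist

def reduce_mean(value):
    # value: a 0-dim float tensor local to this rank.
    dist.all_reduce(value)               # in-place sum across all ranks
    return value / dist.get_world_size()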
Example #7
    def losses(
        self,
        gt_class_info,
        gt_delta_info,
        gt_mask_info,
        num_fg,
        pred_logits,
        pred_deltas,
        pred_masks,
    ):
        """
        Args:
            For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg` parameters, see
                :meth:`TensorMask.get_ground_truth`.
            For `pred_logits`, `pred_deltas` and `pred_masks`, see
                :meth:`TensorMaskHead.forward`.
        Returns:
            losses (dict[str: Tensor]): mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The potential dict keys are:
                "loss_cls", "loss_box_reg" and "loss_mask".
        """
        gt_classes_target, gt_valid_inds = gt_class_info
        gt_deltas, gt_fg_inds = gt_delta_info
        gt_masks, gt_mask_inds = gt_mask_info
        loss_normalizer = torch.tensor(max(1, num_fg),
                                       dtype=torch.float32,
                                       device=self.device)

        # classification and regression
        pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_logits, pred_deltas, self.num_classes)
        loss_cls = (sigmoid_focal_loss_star_jit(
            pred_logits[gt_valid_inds],
            gt_classes_target[gt_valid_inds],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / loss_normalizer)

        if num_fg == 0:
            loss_box_reg = pred_deltas.sum() * 0
        else:
            loss_box_reg = (smooth_l1_loss(
                pred_deltas[gt_fg_inds], gt_deltas, beta=0.0, reduction="sum")
                            / loss_normalizer)
        losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}

        # mask prediction
        if self.mask_on:
            loss_mask = 0
            for lvl in range(self.num_levels):
                cur_level_factor = 2**lvl if self.bipyramid_on else 1
                for anc in range(self.num_anchors):
                    cur_gt_mask_inds = gt_mask_inds[lvl][anc]
                    if cur_gt_mask_inds is None:
                        loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
                    else:
                        cur_mask_size = self.mask_sizes[anc] * cur_level_factor
                        # TODO maybe there are numerical issues when mask sizes are large
                        cur_size_divider = torch.tensor(
                            self.mask_loss_weight / (cur_mask_size**2),
                            dtype=torch.float32,
                            device=self.device,
                        )

                        cur_pred_masks = pred_masks[lvl][anc][
                            cur_gt_mask_inds[:, 0],  # N
                            :,  # V x U
                            cur_gt_mask_inds[:, 1],  # H
                            cur_gt_mask_inds[:, 2],  # W
                        ]

                        loss_mask += F.binary_cross_entropy_with_logits(
                            # V, U
                            cur_pred_masks.view(-1, cur_mask_size,
                                                cur_mask_size),
                            gt_masks[lvl][anc].to(dtype=torch.float32),
                            reduction="sum",
                            weight=cur_size_divider,
                            pos_weight=self.mask_pos_weight,
                        )
            losses["loss_mask"] = loss_mask / loss_normalizer
        return losses
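Both the `num_fg == 0` branch and the `cur_gt_mask_inds is None` branch return a zero that is still connected to the computation graph. Multiplying a prediction by zero, instead of returning a constant, keeps (zero) gradients flowing to every parameter, which wrappers such as DistributedDataParallel rely on. A tiny demonstration:

import torch

x = torch.randn(3, requires_grad=True)
(x.sum() * 0).backward()
print(x.grad)  # tensor([0., 0., 0.]) -- gradients exist, rather than None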
Example #8
    def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine,
               targets):
        """
        Args:
            center_pts: (list[list[Tensor]]): a list of N=#images elements.
                Each is a list of #feature-level tensors holding the shift
                coordinates of that image at the corresponding feature level.
            cls_outs: List[Tensor], each item in the list with
                shape [N, num_classes, H, W]
            pts_outs_init: List[Tensor], each item in the list with
                shape [N, num_points*2, H, W]
            pts_outs_refine: List[Tensor], each item in the list with
                shape [N, num_points*2, H, W]
            targets: (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.
                Specify `targets` during training only.

        Returns:
            dict[str:Tensor]:
                mapping from a named loss to scalar tensor
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
        assert len(featmap_sizes) == len(center_pts[0])

        pts_dim = 2 * self.num_points

        cls_outs = [
            cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1,
                                                self.num_classes)
            for cls_out in cls_outs
        ]
        pts_outs_init = [
            pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1,
                                                     pts_dim)
            for pts_out_init in pts_outs_init
        ]
        pts_outs_refine = [
            pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0),
                                                       -1, pts_dim)
            for pts_out_refine in pts_outs_refine
        ]

        cls_outs = torch.cat(cls_outs, dim=1)
        pts_outs_init = torch.cat(pts_outs_init, dim=1)
        pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

        pts_strides = []
        for i, s in enumerate(center_pts[0]):
            pts_strides.append(
                cls_outs.new_full((s.size(0), ), self.fpn_strides[i]))
        pts_strides = torch.cat(pts_strides, dim=0)

        center_pts = [
            torch.cat(c_pts, dim=0).to(self.device) for c_pts in center_pts
        ]

        pred_cls = []
        pred_init = []
        pred_refine = []

        target_cls = []
        target_init = []
        target_refine = []

        num_pos_init = 0
        num_pos_refine = 0

        for img, (per_center_pts, cls_prob, pts_init, pts_refine,
                  per_targets) in enumerate(
                      zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                          targets)):
            assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]

            gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
            gt_labels = per_targets.gt_classes.to(cls_prob.device)

            pts_init_bbox_targets, pts_init_labels_targets = \
                self.point_targets(per_center_pts,
                                   pts_strides,
                                   gt_bboxes.tensor,
                                   gt_labels)

            # per_center_pts, shape:[N, 18]
            per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)

            normalize_term = self.point_base_scale * pts_strides
            normalize_term = normalize_term.reshape(-1, 1)

            # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)
            per_pts_strides = pts_strides.reshape(-1, 1)
            pts_init_coordinate = pts_init * per_pts_strides + \
                per_center_pts_repeat
            init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

            foreground_idxs = (pts_init_labels_targets >= 0) & \
                (pts_init_labels_targets != self.num_classes)

            pred_init.append(init_bbox_pred[foreground_idxs] /
                             normalize_term[foreground_idxs])
            target_init.append(pts_init_bbox_targets[foreground_idxs] /
                               normalize_term[foreground_idxs])
            num_pos_init += foreground_idxs.sum()

            # An alternative way to convert predicted offsets to bboxes:
            # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * \
            #     per_pts_strides
            # init_bbox_pred = bbox_center + bbox_pred_init

            pts_refine_bbox_targets, pts_refine_labels_targets = \
                self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)

            pts_refine_coordinate = pts_refine * per_pts_strides + \
                per_center_pts_repeat
            refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)

            # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
            # refine_bbox_pred = bbox_center + bbox_pred_refine

            foreground_idxs = (pts_refine_labels_targets >= 0) & \
                (pts_refine_labels_targets != self.num_classes)

            pred_refine.append(refine_bbox_pred[foreground_idxs] /
                               normalize_term[foreground_idxs])
            target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                                 normalize_term[foreground_idxs])
            num_pos_refine += foreground_idxs.sum()

            gt_classes_target = torch.zeros_like(cls_prob)
            gt_classes_target[foreground_idxs,
                              pts_refine_labels_targets[foreground_idxs]] = 1
            pred_cls.append(cls_prob)
            target_cls.append(gt_classes_target)

        pred_cls = torch.cat(pred_cls, dim=0)
        pred_init = torch.cat(pred_init, dim=0)
        pred_refine = torch.cat(pred_refine, dim=0)

        target_cls = torch.cat(target_cls, dim=0)
        target_init = torch.cat(target_init, dim=0)
        target_refine = torch.cat(target_refine, dim=0)

        loss_cls = sigmoid_focal_loss_jit(
            pred_cls,
            target_cls,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum") / max(
                1, num_pos_refine.item()) * self.loss_cls_weight

        loss_pts_init = smooth_l1_loss(
            pred_init, target_init, beta=0.11, reduction='sum') / max(
                1, num_pos_init.item()) * self.loss_bbox_init_weight

        loss_pts_refine = smooth_l1_loss(
            pred_refine, target_refine, beta=0.11, reduction='sum') / max(
                1, num_pos_refine.item()) * self.loss_bbox_refine_weight

        return {
            "loss_cls": loss_cls,
            "loss_pts_init": loss_pts_init,
            "loss_pts_refine": loss_pts_refine
        }
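`self.pts_to_bbox` is not shown in this snippet. A common choice from the RepPoints paper is the min-max function over the point set; a hypothetical version matching the (x, y, x, y, ...) layout produced by `repeat(1, self.num_points)` could look like:

import torch

def pts_to_bbox_minmax(pts):
    # pts: [R, 2 * num_points], interleaved as (x1, y1, x2, y2, ...).
    # Hypothetical helper; the project's own pts_to_bbox may differ.
    pts = pts.view(pts.size(0), -1, 2)
    xy_min = pts.min(dim=1).values
    xy_max = pts.max(dim=1).values
    return torch.cat([xy_min, xy_max], dim=1)  # [R, 4] boxes (x1, y1, x2, y2)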
Example #9
    def losses(self, anchors, gt_instances, box_cls, box_delta):
        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]

        box_cls_flattened = [
            permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
        ]
        box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        pred_class_logits = cat(box_cls_flattened, dim=1)
        pred_anchor_deltas = cat(box_delta_flattened, dim=1)

        pred_class_probs = pred_class_logits.sigmoid()
        pred_box_probs = []
        num_foreground = 0
        positive_losses = []
        for anchors_per_image, \
            gt_instances_per_image, \
            pred_class_probs_per_image, \
            pred_anchor_deltas_per_image in zip(
                anchors, gt_instances, pred_class_probs, pred_anchor_deltas):
            gt_classes_per_image = gt_instances_per_image.gt_classes

            with torch.no_grad():
                # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
                predicted_boxes_per_image = self.box2box_transform.apply_deltas(
                    pred_anchor_deltas_per_image, anchors_per_image.tensor)
                # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
                gt_pred_iou = pairwise_iou(gt_instances_per_image.gt_boxes,
                                           Boxes(predicted_boxes_per_image))

                t1 = self.bbox_threshold
                t2 = gt_pred_iou.max(dim=1, keepdim=True).values.clamp_(
                    min=t1 + torch.finfo(torch.float32).eps)
                # gt_pred_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                gt_pred_prob = ((gt_pred_iou - t1) / (t2 - t1)).clamp_(min=0,
                                                                       max=1)

                # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
                nonzero_idxs = torch.nonzero(gt_pred_prob, as_tuple=True)
                pred_box_prob_per_image = torch.zeros_like(
                    pred_class_probs_per_image)
                pred_box_prob_per_image[nonzero_idxs[1], gt_classes_per_image[nonzero_idxs[0]]] \
                    = gt_pred_prob[nonzero_idxs]
                pred_box_probs.append(pred_box_prob_per_image)

            # construct bags for objects
            match_quality_matrix = pairwise_iou(
                gt_instances_per_image.gt_boxes, anchors_per_image)
            _, foreground_idxs = torch.topk(match_quality_matrix,
                                            self.pos_anchor_topk,
                                            dim=1,
                                            sorted=False)

            # matched_pred_class_probs_per_image: P_{ij}^{cls}
            matched_pred_class_probs_per_image = torch.gather(
                pred_class_probs_per_image[foreground_idxs], 2,
                gt_classes_per_image.view(-1, 1,
                                          1).repeat(1, self.pos_anchor_topk,
                                                    1)).squeeze(2)

            # matched_gt_anchor_deltas_per_image: P_{ij}^{loc}
            matched_gt_anchor_deltas_per_image = self.box2box_transform.get_deltas(
                anchors_per_image.tensor[foreground_idxs],
                gt_instances_per_image.gt_boxes.tensor.unsqueeze(1))
            loss_box_reg = smooth_l1_loss(
                pred_anchor_deltas_per_image[foreground_idxs],
                matched_gt_anchor_deltas_per_image,
                beta=self.smooth_l1_loss_beta,
                reduction="none").sum(dim=-1) * self.reg_weight
            matched_pred_reg_probs_per_image = (-loss_box_reg).exp()

            # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
            num_foreground += len(gt_instances_per_image)
            positive_losses.append(
                positive_bag_loss(matched_pred_class_probs_per_image *
                                  matched_pred_reg_probs_per_image,
                                  dim=1))

        # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
        positive_loss = torch.cat(positive_losses).sum() / max(
            1, num_foreground)

        # pred_box_probs: P{a_{j} \in A_{+}}
        pred_box_probs = torch.stack(pred_box_probs, dim=0)
        # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
        negative_loss = negative_bag_loss(
            pred_class_probs *
            (1 - pred_box_probs), self.focal_loss_gamma).sum() / max(
                1, num_foreground * self.pos_anchor_topk)

        loss_pos = positive_loss * self.focal_loss_alpha
        loss_neg = negative_loss * (1 - self.focal_loss_alpha)

        return {"loss_pos": loss_pos, "loss_neg": loss_neg}