Example #1
    def losses(self, gt_classes, gt_shifts_deltas, pred_class_logits,
               pred_shift_deltas, pred_filtering):
        """
        Args:
            For `gt_classes` and `gt_shifts_deltas` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_filtering`, see
                :meth:`FCOSHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_shift_deltas, pred_filtering = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_filtering,
                self.num_classes
            )  # Shapes: (N x R, K), (N x R, 4) and (N x R, 1), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())

        pred_class_logits = pred_class_logits.sigmoid() * pred_filtering.sigmoid()

        # logits loss
        loss_cls = focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_foreground) * self.reg_weight

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
        }
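
For reference, `permute_all_cls_and_box_to_N_HWA_K_and_concat` is not shown in
these examples. A minimal sketch of the detectron2/cvpods-style helper it is
assumed to match (three-tensor variant matching this example's call; names and
defaults here are illustrative):

import torch

def permute_to_N_HWA_K(tensor, K):
    # (N, A*K, H, W) -> (N, H*W*A, K)
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)
    return tensor.reshape(N, -1, K)

def permute_all_cls_and_box_to_N_HWA_K_and_concat(
        box_cls, box_delta, box_filter, num_classes=80):
    # Flatten per-level predictions so the losses can be computed over
    # (N x R, K), (N x R, 4) and (N x R, 1) tensors, R = sum(Hi x Wi x A).
    box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_filter_flattened = [permute_to_N_HWA_K(x, 1) for x in box_filter]
    box_cls = torch.cat(box_cls_flattened, dim=1).view(-1, num_classes)
    box_delta = torch.cat(box_delta_flattened, dim=1).view(-1, 4)
    box_filter = torch.cat(box_filter_flattened, dim=1).view(-1, 1)
    return box_cls, box_delta, box_filter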
Example #2
    def losses(
        self,
        gt_classes,
        gt_shifts_deltas,
        gt_centerness,
        gt_classes_border,
        gt_deltas_border,
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_box_cls,
        border_bbox_reg,
    ):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`BorderDet.get_ground_truth`.
            Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where
            R is the total number of shifts across levels, i.e. sum(Hi x Wi)
            `gt_classes_border` and `gt_deltas_border` follow the same (N, R)
            and (N, R, 4) layout.
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`BorderHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        (
            pred_class_logits,
            pred_shift_deltas,
            pred_centerness,
            border_class_logits,
            border_shift_deltas,
        ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            border_box_cls, border_bbox_reg, self.num_classes
        )  # Shapes: (N x R, K), (N x R, 4), (N x R, 1), (N x R, K) and (N x R, 4).

        # fcos
        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()
        acc_centerness_num = gt_centerness[foreground_idxs].sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        dist.all_reduce(num_foreground)
        num_foreground /= dist.get_world_size()
        dist.all_reduce(acc_centerness_num)
        acc_centerness_num /= dist.get_world_size()

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1, acc_centerness_num)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        # borderdet
        gt_classes_border = gt_classes_border.flatten()
        gt_deltas_border = gt_deltas_border.view(-1, 4)

        valid_idxs_border = gt_classes_border >= 0
        foreground_idxs_border = \
            (gt_classes_border >= 0) & (gt_classes_border != self.num_classes)
        num_foreground_border = foreground_idxs_border.sum()

        gt_classes_border_target = torch.zeros_like(border_class_logits)
        gt_classes_border_target[foreground_idxs_border,
                                 gt_classes_border[foreground_idxs_border]] = 1

        dist.all_reduce(num_foreground_border)
        num_foreground_border /= dist.get_world_size()

        num_foreground_border = max(num_foreground_border, 1.0)
        loss_border_cls = sigmoid_focal_loss_jit(
            border_class_logits[valid_idxs_border],
            gt_classes_border_target[valid_idxs_border],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_foreground_border

        if foreground_idxs_border.sum() > 0:
            loss_border_reg = (
                smooth_l1_loss(border_shift_deltas[foreground_idxs_border],
                               gt_deltas_border[foreground_idxs_border],
                               beta=0,
                               reduction="sum") / num_foreground_border)
        else:
            loss_border_reg = border_shift_deltas.sum()

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
            "loss_border_cls": loss_border_cls,
            "loss_border_reg": loss_border_reg,
        }
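
`sigmoid_focal_loss_jit` is external to these snippets; it is assumed to be a
TorchScript-compiled version of the standard sigmoid focal loss
(Lin et al., 2017), e.g.:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0,
                       reduction="none"):
    # Binary focal loss on raw logits; down-weights well-classified examples.
    p = torch.sigmoid(logits)
    ce_loss = F.binary_cross_entropy_with_logits(logits, targets,
                                                 reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss

sigmoid_focal_loss_jit = torch.jit.script(sigmoid_focal_loss)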
Example #3
    def losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
               pred_class_logits, pred_shift_deltas, pred_centerness):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where
            R is the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`FCOSHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_shift_deltas, pred_centerness = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                self.num_classes
            )  # Shapes: (N x R, K), (N x R, 4) and (N x R, 1), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        loss = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
        }

        # budget loss
        if self.is_dynamic_head and self.budget_loss_lambda != 0:
            soft_cost, used_cost, full_cost = get_module_running_cost(self)
            loss_budget = (soft_cost / full_cost).mean() * self.budget_loss_lambda
            storage = get_event_storage()
            storage.put_scalar("complxity_ratio", (used_cost / full_cost).mean())
            loss.update({"loss_budget": loss_budget})

        return loss
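
For context, `gt_centerness` is typically produced in `get_ground_truth` using
the FCOS centerness definition; a sketch, assuming the (l, t, r, b) delta order
used by the docstrings above:

import torch

def compute_centerness_targets(deltas):
    # deltas: (F, 4) foreground shift2box targets in (l, t, r, b) order.
    # centerness = sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)):
    # 1 at the box center, decaying towards the borders.
    left_right = deltas[:, [0, 2]]
    top_bottom = deltas[:, [1, 3]]
    centerness = (left_right.min(dim=-1).values / left_right.max(dim=-1).values) \
        * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values)
    return torch.sqrt(centerness)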
Example #4
    def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                         gt_inds, im_inds, pred_class_logits,
                         pred_shift_deltas, pred_centerness, pred_inst_params,
                         fpn_levels, shifts):

        pred_class_logits, pred_shift_deltas, pred_centerness, pred_inst_params = \
            permute_all_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                pred_inst_params, self.num_gen_params, self.num_classes
            )  # Shapes: (N x R, K), (N x R, 4), (N x R, 1) and (N x R, num_gen_params), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
        gt_centerness = gt_centerness.reshape(-1, 1)
        fpn_levels = fpn_levels.flatten()
        im_inds = im_inds.flatten()
        gt_inds = gt_inds.flatten()

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(
            comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(pred_shift_deltas[foreground_idxs],
                                gt_shifts_deltas[foreground_idxs],
                                gt_centerness[foreground_idxs],
                                box_mode="ltrb",
                                loss_type=self.iou_loss_type,
                                reduction="sum",
                                smooth=self.iou_smooth) / max(
                                    1e-6, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        proposals_losses = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }

        all_shifts = torch.cat([torch.cat(shift) for shift in shifts])
        proposals = Instances((0, 0))
        proposals.inst_parmas = pred_inst_params[foreground_idxs]
        proposals.fpn_levels = fpn_levels[foreground_idxs]
        proposals.shifts = all_shifts[foreground_idxs]
        proposals.gt_inds = gt_inds[foreground_idxs]
        proposals.im_inds = im_inds[foreground_idxs]

        # select_instances for saving memory
        if len(proposals):
            if self.topk_proposals_per_im != -1:
                proposals.gt_cls = gt_classes[foreground_idxs]
                proposals.pred_logits = pred_class_logits[foreground_idxs]
                proposals.pred_centerness = pred_centerness[foreground_idxs]
            proposals = self.select_instances(proposals)

        return proposals_losses, proposals
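
`iou_loss` with box_mode="ltrb" is also external. A sketch of what it is
assumed to compute (cvpods-style; the optional third positional argument is an
element-wise weight, e.g. `gt_centerness` above, and `smooth` is omitted):

import torch

def iou_loss_ltrb(inputs, targets, weight=None, loss_type="iou",
                  reduction="none"):
    # inputs/targets: (..., 4) distances (l, t, r, b) from a shift to the
    # four sides of the predicted / ground-truth box.
    pred_l, pred_t, pred_r, pred_b = inputs.unbind(dim=-1)
    gt_l, gt_t, gt_r, gt_b = targets.unbind(dim=-1)

    pred_area = (pred_l + pred_r) * (pred_t + pred_b)
    gt_area = (gt_l + gt_r) * (gt_t + gt_b)
    w_intersect = torch.min(pred_l, gt_l) + torch.min(pred_r, gt_r)
    h_intersect = torch.min(pred_t, gt_t) + torch.min(pred_b, gt_b)
    area_intersect = w_intersect.clamp(min=0) * h_intersect.clamp(min=0)
    area_union = pred_area + gt_area - area_intersect
    ious = area_intersect / area_union.clamp(min=1e-7)

    if loss_type == "iou":
        loss = -ious.clamp(min=1e-7).log()
    else:  # "linear_iou"
        loss = 1 - ious

    if weight is not None:
        loss = loss * weight.view(loss.size())
    if reduction == "sum":
        loss = loss.sum()
    elif reduction == "mean":
        loss = loss.mean()
    return loss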
Example #5
    def get_lla_assignments_and_losses(self, shifts, targets, box_cls,
                                       box_delta, box_iou):

        box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
        box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

        box_cls = torch.cat(box_cls, dim=1)
        box_delta = torch.cat(box_delta, dim=1)
        box_iou = torch.cat(box_iou, dim=1)

        losses_cls = []
        losses_box_reg = []
        losses_iou = []

        num_fg = 0

        for shifts_per_image, targets_per_image, box_cls_per_image, \
                box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):

            shifts_over_all = torch.cat(shifts_per_image, dim=0)

            gt_boxes = targets_per_image.gt_boxes
            gt_classes = targets_per_image.gt_classes

            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = deltas.min(dim=-1).values > 0.01

            shape = (len(targets_per_image), len(shifts_over_all), -1)
            box_cls_per_image_unexpanded = box_cls_per_image
            box_delta_per_image_unexpanded = box_delta_per_image

            box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
            gt_cls_per_image = F.one_hot(
                torch.max(gt_classes, torch.zeros_like(gt_classes)),
                self.num_classes).float().unsqueeze(1).expand(shape)

            with torch.no_grad():
                loss_cls = sigmoid_focal_loss_jit(
                    box_cls_per_image,
                    gt_cls_per_image,
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma).sum(dim=-1)
                loss_cls_bg = sigmoid_focal_loss_jit(
                    box_cls_per_image_unexpanded,
                    torch.zeros_like(box_cls_per_image_unexpanded),
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma).sum(dim=-1)
                box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(
                    shape)
                gt_delta_per_image = self.shift2box_transform.get_deltas(
                    shifts_over_all, gt_boxes.tensor.unsqueeze(1))
                loss_delta = iou_loss(box_delta_per_image,
                                      gt_delta_per_image,
                                      box_mode="ltrb",
                                      loss_type='iou')

                ious = get_ious(box_delta_per_image,
                                gt_delta_per_image,
                                box_mode="ltrb",
                                loss_type='iou')

                loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (
                    1 - is_in_boxes.float())
                loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

                num_gt = loss.shape[0] - 1
                num_anchor = loss.shape[1]

                # Topk
                matching_matrix = torch.zeros_like(loss)
                _, topk_idx = torch.topk(loss[:-1],
                                         k=self.topk,
                                         dim=1,
                                         largest=False)
                matching_matrix[
                    torch.arange(num_gt).unsqueeze(1).repeat(1, self.topk).view(-1),
                    topk_idx.view(-1)] = 1.

                # make sure one anchor with one gt
                anchor_matched_gt = matching_matrix.sum(0)
                if (anchor_matched_gt > 1).sum() > 0:
                    loss_min, loss_argmin = torch.min(
                        loss[:-1, anchor_matched_gt > 1], dim=0)
                    matching_matrix[:, anchor_matched_gt > 1] *= 0.
                    matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                    anchor_matched_gt = matching_matrix.sum(0)
                num_fg += matching_matrix.sum()
                # remaining anchors are assigned to the background row
                matching_matrix[-1] = 1. - anchor_matched_gt
                assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

                gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                    (gt_cls_per_image.size(1),
                     gt_cls_per_image.size(2))).unsqueeze(0)
                gt_cls_per_image_with_bg = torch.cat(
                    [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
                cls_target_per_image = gt_cls_per_image_with_bg[
                    assigned_gt_inds,
                    torch.arange(num_anchor)]

                # Dealing with Crowdhuman ignore label
                gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
                anchor_cls_labels = gt_classes_[assigned_gt_inds]
                valid_flag = anchor_cls_labels >= 0

                # foreground mask: anchors not assigned to the background row
                pos_mask = assigned_gt_inds != len(targets_per_image)
                valid_fg = pos_mask & valid_flag
                assigned_fg_inds = assigned_gt_inds[valid_fg]
                range_fg = torch.arange(num_anchor)[valid_fg]
                ious_fg = ious[assigned_fg_inds, range_fg]

            anchor_loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded[valid_flag],
                cls_target_per_image[valid_flag],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)

            delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
            anchor_loss_delta = 2. * iou_loss(
                box_delta_per_image_unexpanded[valid_fg],
                delta_target,
                box_mode="ltrb",
                loss_type=self.iou_loss_type)

            anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
                box_iou_per_image.squeeze(1)[valid_fg],
                ious_fg,
                reduction='none')

            losses_cls.append(anchor_loss_cls.sum())
            losses_box_reg.append(anchor_loss_delta.sum())
            losses_iou.append(anchor_loss_iou.sum())

        # sync the foreground count once, after the per-image loop
        if self.norm_sync:
            dist.all_reduce(num_fg)
            num_fg = num_fg.float() / dist.get_world_size()

        return {
            'loss_cls': torch.stack(losses_cls).sum() / num_fg,
            'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
            'loss_iou': torch.stack(losses_iou).sum() / num_fg
        }
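
To make the matching block concrete, here is a tiny self-contained run of the
same top-k scheme with made-up costs (all numbers hypothetical): each ground
truth claims its `topk` cheapest anchors, a doubly-claimed anchor goes to the
ground truth with the lower cost, and unclaimed anchors fall to the background
row.

import torch

loss = torch.tensor([[0.1, 0.2, 0.9, 0.8],    # costs for gt 0
                     [0.3, 0.1, 0.2, 0.7],    # costs for gt 1
                     [0.5, 0.5, 0.5, 0.5]])   # background row
topk = 2
num_gt = loss.shape[0] - 1
matching_matrix = torch.zeros_like(loss)
_, topk_idx = torch.topk(loss[:-1], k=topk, dim=1, largest=False)
matching_matrix[torch.arange(num_gt).unsqueeze(1).repeat(1, topk).view(-1),
                topk_idx.view(-1)] = 1.
anchor_matched_gt = matching_matrix.sum(0)
if (anchor_matched_gt > 1).sum() > 0:
    # anchor 1 is claimed by both gts; gt 1 wins (0.1 < 0.2)
    _, loss_argmin = torch.min(loss[:-1, anchor_matched_gt > 1], dim=0)
    matching_matrix[:, anchor_matched_gt > 1] *= 0.
    matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
    anchor_matched_gt = matching_matrix.sum(0)
matching_matrix[-1] = 1. - anchor_matched_gt
print(torch.argmax(matching_matrix, dim=0))  # tensor([0, 1, 1, 2])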
Example #6
    def losses(self, shifts, gt_instances, box_cls, box_delta, box_center):
        box_cls_flattened = [
            permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
        ]
        box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        box_center_flattened = [permute_to_N_HWA_K(x, 1) for x in box_center]
        pred_class_logits = cat(box_cls_flattened, dim=1)
        pred_shift_deltas = cat(box_delta_flattened, dim=1)
        pred_obj_logits = cat(box_center_flattened, dim=1)

        pred_class_probs = pred_class_logits.sigmoid()
        pred_obj_probs = pred_obj_logits.sigmoid()
        pred_box_probs = []
        num_foreground = pred_class_logits.new_zeros(1)
        num_background = pred_class_logits.new_zeros(1)
        positive_losses = []
        gaussian_norm_losses = []

        for shifts_per_image, gt_instances_per_image, \
            pred_class_probs_per_image, pred_shift_deltas_per_image, \
            pred_obj_probs_per_image in zip(
                shifts, gt_instances, pred_class_probs, pred_shift_deltas,
                pred_obj_probs):
            locations = torch.cat(shifts_per_image, dim=0)
            labels = gt_instances_per_image.gt_classes
            gt_boxes = gt_instances_per_image.gt_boxes

            target_shift_deltas = self.shift2box_transform.get_deltas(
                locations, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = target_shift_deltas.min(dim=-1).values > 0

            foreground_idxs = torch.nonzero(is_in_boxes, as_tuple=True)

            with torch.no_grad():
                # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
                predicted_boxes_per_image = self.shift2box_transform.apply_deltas(
                    pred_shift_deltas_per_image, locations)
                # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
                gt_pred_iou = pairwise_iou(
                    gt_boxes, Boxes(predicted_boxes_per_image)).max(
                        dim=0, keepdim=True).values.repeat(
                            len(gt_instances_per_image), 1)

                # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
                pred_box_prob_per_image = torch.zeros_like(
                    pred_class_probs_per_image)
                box_prob = 1 / (1 - gt_pred_iou[foreground_idxs]).clamp_(1e-12)
                for i in range(len(gt_instances_per_image)):
                    idxs = foreground_idxs[0] == i
                    if idxs.sum() > 0:
                        box_prob[idxs] = normalize(box_prob[idxs])
                pred_box_prob_per_image[foreground_idxs[1],
                                        labels[foreground_idxs[0]]] = box_prob
                pred_box_probs.append(pred_box_prob_per_image)

            normal_probs = []
            for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
                gt_shift_deltas = self.shift2box_transform.get_deltas(
                    shifts_i, gt_boxes.tensor.unsqueeze(1))
                distances = (gt_shift_deltas[..., :2] -
                             gt_shift_deltas[..., 2:]) / 2
                normal_probs.append(
                    normal_distribution(distances / stride,
                                        self.mu[labels].unsqueeze(1),
                                        self.sigma[labels].unsqueeze(1)))
            normal_probs = torch.cat(normal_probs, dim=1).prod(dim=-1)

            composed_cls_prob = \
                pred_class_probs_per_image[:, labels] * pred_obj_probs_per_image

            # matched_gt_shift_deltas: P_{ij}^{loc}
            loss_box_reg = iou_loss(pred_shift_deltas_per_image.unsqueeze(0),
                                    target_shift_deltas,
                                    box_mode="ltrb",
                                    loss_type=self.iou_loss_type,
                                    reduction="none") * self.reg_weight
            pred_reg_probs = (-loss_box_reg).exp()

            # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
            positive_losses.append(
                positive_bag_loss(
                    composed_cls_prob.permute(1, 0) * pred_reg_probs,
                    is_in_boxes.float(), normal_probs))

            num_foreground += len(gt_instances_per_image)
            num_background += normal_probs[foreground_idxs].sum().item()

            gaussian_norm_losses.append(
                len(gt_instances_per_image) /
                normal_probs[foreground_idxs].sum().clamp_(1e-12))

        if dist.is_initialized():
            dist.all_reduce(num_foreground)
            num_foreground /= dist.get_world_size()
            dist.all_reduce(num_background)
            num_background /= dist.get_world_size()

        # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
        positive_loss = torch.cat(positive_losses).sum() / max(
            1, num_foreground)

        # pred_box_probs: P{a_{j} \in A_{+}}
        pred_box_probs = torch.stack(pred_box_probs, dim=0)
        # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
        negative_loss = negative_bag_loss(
            pred_class_probs * pred_obj_probs * (1 - pred_box_probs),
            self.focal_loss_gamma).sum() / max(1, num_background)

        loss_pos = positive_loss * self.focal_loss_alpha
        loss_neg = negative_loss * (1 - self.focal_loss_alpha)
        loss_norm = torch.stack(gaussian_norm_losses).mean() * (
            1 - self.focal_loss_alpha)

        return {
            "loss_pos": loss_pos,
            "loss_neg": loss_neg,
            "loss_norm": loss_norm,
        }
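
`normal_distribution` and `normalize` are small helpers not shown here.
Sketches of what they are assumed to do (an unnormalized Gaussian center prior
with learnable per-class mu/sigma, and a sum-to-one weight normalization):

import torch

def normal_distribution(x, mu, sigma):
    # Unnormalized Gaussian prior over the scaled offsets between a shift
    # and the ground-truth box center.
    return (-((x - mu) ** 2) / (2 * sigma ** 2)).exp()

def normalize(x):
    # Rescale positive weights so they sum to one.
    return x / x.sum().clamp(min=1e-12)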
Example #7
    def get_ground_truth(self, shifts, targets, box_cls, box_delta):
        """
        Args:
            shifts (list[list[Tensor]]): a list of N=#image elements. Each is a
                list of #feature level tensors. The tensors contains shifts of
                this image on the specific feature level.
            targets (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.  Specify `targets` during training only.

        Returns:
            gt_classes (Tensor):
                An integer tensor of shape (N, R) storing ground-truth
                labels for each shift.
                R is the total number of shifts, i.e. the sum of Hi x Wi for all levels.
                Shifts in the valid boxes are assigned their corresponding label in the
                [0, K-1] range. Shifts in the background are assigned the label "K".
                Shifts in the ignore areas are assigned a label "-1", i.e. ignore.
            gt_shifts_deltas (Tensor):
                Shape (N, R, 4).
                The last dimension represents ground-truth shift2box transform
                targets (dl, dt, dr, db) that map each shift to its matched ground-truth box.
                The values in the tensor are meaningful only when the corresponding
                shift is labeled as foreground.
        """
        gt_classes = []
        gt_shifts_deltas = []

        box_cls = torch.cat(
            [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls], dim=1)
        box_delta = torch.cat([permute_to_N_HWA_K(x, 4) for x in box_delta],
                              dim=1)

        num_fg = 0
        num_gt = 0

        for shifts_per_image, targets_per_image, box_cls_per_image, box_delta_per_image in zip(
                shifts, targets, box_cls, box_delta):
            shifts_over_all_feature_maps = torch.cat(shifts_per_image, dim=0)

            gt_boxes = targets_per_image.gt_boxes

            shape = (len(targets_per_image), len(shifts_over_all_feature_maps),
                     -1)

            gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                         self.num_classes).float()
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image.unsqueeze(0).expand(shape),
                gt_cls_per_image.unsqueeze(1).expand(shape),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
            ).sum(dim=2)

            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))
            loss_delta = iou_loss(
                box_delta_per_image.unsqueeze(0).expand(shape),
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type=self.iou_loss_type,
            ) * self.reg_weight

            loss = loss_cls + loss_delta

            INF = 1e8

            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

            if self.center_sampling_radius > 0:
                centers = gt_boxes.get_centers()
                is_in_boxes = []
                for stride, shifts_i in zip(self.fpn_strides,
                                            shifts_per_image):
                    radius = stride * self.center_sampling_radius
                    center_boxes = torch.cat((
                        torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                        torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                    ), dim=-1)
                    center_deltas = self.shift2box_transform.get_deltas(
                        shifts_i, center_boxes.unsqueeze(1))
                    is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
                is_in_boxes = torch.cat(is_in_boxes, dim=1)
            else:
                # no center sampling, it will use all the locations within a ground-truth box
                is_in_boxes = deltas.min(dim=-1).values > 0

            loss[~is_in_boxes] = INF

            gt_idxs, shift_idxs = linear_sum_assignment(loss.cpu().numpy())

            num_fg += len(shift_idxs)
            num_gt += len(targets_per_image)

            gt_classes_i = shifts_over_all_feature_maps.new_full(
                (len(shifts_over_all_feature_maps), ),
                self.num_classes,
                dtype=torch.long)
            gt_shifts_reg_deltas_i = shifts_over_all_feature_maps.new_zeros(
                len(shifts_over_all_feature_maps), 4)
            if len(targets_per_image) > 0:
                # ground truth classes
                gt_classes_i[shift_idxs] = targets_per_image.gt_classes[
                    gt_idxs]
                # ground truth box regression
                gt_shifts_reg_deltas_i[
                    shift_idxs] = self.shift2box_transform.get_deltas(
                        shifts_over_all_feature_maps[shift_idxs],
                        gt_boxes[gt_idxs].tensor)

            gt_classes.append(gt_classes_i)
            gt_shifts_deltas.append(gt_shifts_reg_deltas_i)

        get_event_storage().put_scalar("num_fg_per_gt", num_fg / num_gt)

        return torch.stack(gt_classes), torch.stack(gt_shifts_deltas)
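
The one-to-one assignment above relies on `linear_sum_assignment` from
`scipy.optimize` (the Hungarian algorithm) applied to the per-image cost
matrix. A toy run with hypothetical costs:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.2, 0.9, 0.4],   # gt 0 vs. three shifts
                 [0.7, 0.1, 0.8]])  # gt 1 vs. three shifts
gt_idxs, shift_idxs = linear_sum_assignment(cost)
print(gt_idxs, shift_idxs)  # [0 1] [0 1]: exactly one shift per ground truth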
Example #8
    def losses(self, gt_classes, gt_shifts_deltas, gt_ious, pred_class_logits,
               pred_shift_deltas, pred_ious):
        """
        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_ious` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where
            R is the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_ious`, see
                :meth:`FCOSHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_shift_deltas, pred_ious = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_ious,
                self.num_classes
            )  # Shapes: (N x R, K), (N x R, 4) and (N x R, 1), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_ious = gt_ious.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)

        num_foreground = foreground_idxs.sum()
        num_target = gt_ious[foreground_idxs].sum()
        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        if self.norm_sync:
            dist.all_reduce(num_foreground)
            num_foreground /= dist.get_world_size()

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, num_foreground)

        # regression loss
        loss_box_reg = 2. * iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1, num_foreground)

        # iou branch loss
        loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            pred_ious[foreground_idxs],
            gt_ious[foreground_idxs],
            reduction="sum",
        ) / max(1, num_foreground)

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_iou": loss_iou
        }
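
The snippet trains `pred_ious` against `gt_ious` but does not show inference.
Purely as an assumed sketch (not taken from this code), a common way such an
IoU branch is consumed at test time is to fuse classification confidence with
predicted localization quality before NMS:

import torch

def iou_aware_scores(pred_class_logits, pred_ious, alpha=0.5):
    # Hypothetical post-processing: geometric fusion of class probabilities
    # (R, K) with predicted IoUs (R, 1); alpha balances the two terms.
    cls_prob = pred_class_logits.sigmoid()
    iou_prob = pred_ious.sigmoid()
    return cls_prob.pow(alpha) * iou_prob.pow(1 - alpha)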