Exemplo n.º 1
0
    def aux_losses(self, gt_classes, pred_class_logits):
        """
        Compute an auxiliary sigmoid focal classification loss.

        Args:
            gt_classes (Tensor): per-shift class labels; flattened here.
                Negative labels are excluded from the loss, labels equal to
                ``self.num_classes`` are kept with an all-zero (background)
                target, and every other label is a foreground class index.
            pred_class_logits (list[Tensor]): per-level classification
                logits; each is permuted with ``permute_to_N_HWA_K`` and the
                levels are concatenated, then viewed as (-1, num_classes).

        Returns:
            dict[str: Tensor]: ``{"loss_cls_aux": loss}`` -- the summed focal
            loss over non-ignored shifts divided by the world-averaged
            foreground count (floored at 1).
        """
        num_classes = self.num_classes
        per_level = [permute_to_N_HWA_K(x, num_classes) for x in pred_class_logits]
        logits = cat(per_level, dim=1).view(-1, num_classes)

        labels = gt_classes.flatten()
        keep = labels >= 0
        is_fg = keep & (labels != num_classes)

        # One-hot targets; background and ignored rows stay all-zero.
        targets = torch.zeros_like(logits)
        targets[is_fg, labels[is_fg]] = 1

        # Average the foreground count over ranks so every worker divides by
        # the same normalizer (assumes comm.all_reduce sums across ranks).
        fg_count = comm.all_reduce(is_fg.sum()) / float(comm.get_world_size())

        focal = sigmoid_focal_loss_jit(
            logits[keep],
            targets[keep],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        return {"loss_cls_aux": focal / max(1.0, fg_count)}
Exemplo n.º 2
0
Arquivo: fcos.py Projeto: zyg11/DeFCN
    def losses(self, gt_classes, gt_shifts_deltas, pred_class_logits,
               pred_shift_deltas, pred_filtering):
        """
        Compute the classification and box-regression losses.

        Args:
            For `gt_classes` and `gt_shifts_deltas` parameters, see
                :meth:`FCOS.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of shifts across levels, i.e. sum(Hi x Wi)
            For `pred_class_logits`, `pred_shift_deltas` and `pred_filtering`, see
                :meth:`FCOSHead.forward`.

            Label convention: negative labels are excluded from the
            classification loss, labels equal to ``self.num_classes`` are kept
            with an all-zero (background) target, and any other label is
            foreground.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        pred_class_logits, pred_shift_deltas, pred_filtering = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_filtering,
                self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)

        valid_idxs = gt_classes >= 0
        foreground_idxs = valid_idxs & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        # One-hot classification targets; background/ignored rows stay zero.
        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # World-averaged foreground count, shared by both losses as the
        # normalizer (assumes comm.all_reduce sums across ranks).
        num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
        normalizer = max(1.0, num_foreground)

        # These are probabilities (the product of two sigmoids), not logits,
        # which is why `focal_loss_jit` is used here instead of
        # `sigmoid_focal_loss_jit` -- presumably it takes probabilities; confirm
        # against its definition. `pred_filtering` is assumed to broadcast
        # across the class dimension (e.g. shape (N x R, 1)) -- TODO confirm.
        pred_class_probs = pred_class_logits.sigmoid() * pred_filtering.sigmoid()

        # classification loss (only non-ignored shifts)
        loss_cls = focal_loss_jit(
            pred_class_probs[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / normalizer

        # regression loss (foreground shifts only), scaled by reg_weight
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / normalizer * self.reg_weight

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
        }
Exemplo n.º 3
0
    def losses(
        self,
        gt_classes,
        gt_shifts_deltas,
        gt_centerness,
        gt_classes_border,
        gt_deltas_border,
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_box_cls,
        border_bbox_reg,
    ):
        """
        Compute the FCOS losses plus the BorderDet border-branch losses.

        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`BorderDet.get_ground_truth`.
            `gt_classes` and `gt_shifts_deltas` have shapes (N, R) and (N, R, 4),
            respectively, where R is the total number of shifts across levels,
            i.e. sum(Hi x Wi); `gt_centerness` is reshaped to one value per shift.
            `gt_classes_border` and `gt_deltas_border` are the border-branch
            targets; they are flattened / viewed as (-1, 4) like their FCOS
            counterparts (presumably also produced by
            :meth:`BorderDet.get_ground_truth` -- confirm).
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`BorderHead.forward`. `border_box_cls` and
            `border_bbox_reg` are the border-branch predictions (presumably
            from the same head -- confirm).

            Label convention (both branches): negative labels are excluded
            from the classification loss, labels equal to ``self.num_classes``
            are kept with an all-zero (background) target, and any other label
            is foreground.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls", "loss_box_reg", "loss_centerness",
                "loss_border_cls" and "loss_border_reg"
        """
        (
            pred_class_logits,
            pred_shift_deltas,
            pred_centerness,
            border_class_logits,
            border_shift_deltas,
        ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            border_box_cls, border_bbox_reg, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        # ---------------------------------------------------------------- fcos
        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = valid_idxs & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # World-averaged normalizers (assumes comm.all_reduce sums over ranks).
        # The box loss is normalized by the summed foreground centerness
        # because centerness is also passed to iou_loss (presumably as a
        # per-box weight).
        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(
            comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1.0, num_foreground)

        # ----------------------------------------------------------- borderdet
        gt_classes_border = gt_classes_border.flatten()
        gt_deltas_border = gt_deltas_border.view(-1, 4)

        valid_idxs_border = gt_classes_border >= 0
        foreground_idxs_border = (
            valid_idxs_border & (gt_classes_border != self.num_classes))
        num_foreground_border = foreground_idxs_border.sum()

        gt_classes_border_target = torch.zeros_like(border_class_logits)
        gt_classes_border_target[foreground_idxs_border,
                                 gt_classes_border[foreground_idxs_border]] = 1

        num_foreground_border = (comm.all_reduce(num_foreground_border) /
                                 float(comm.get_world_size()))
        num_foreground_border = max(1.0, num_foreground_border)

        loss_border_cls = sigmoid_focal_loss_jit(
            border_class_logits[valid_idxs_border],
            gt_classes_border_target[valid_idxs_border],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_foreground_border

        # `foreground_idxs_border` is a boolean mask over *every* shift, so its
        # numel() is the total shift count, not the foreground count; test
        # whether any entry is set instead.
        if foreground_idxs_border.any():
            loss_border_reg = smooth_l1_loss(
                border_shift_deltas[foreground_idxs_border],
                gt_deltas_border[foreground_idxs_border],
                beta=0,
                reduction="sum",
            ) / num_foreground_border
        else:
            # No border foreground: return a zero loss that still depends on
            # the predictions, presumably so every parameter keeps a gradient.
            loss_border_reg = border_shift_deltas.sum() * 0

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
            "loss_border_cls": loss_border_cls,
            "loss_border_reg": loss_border_reg,
        }
Exemplo n.º 4
0
    def losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
               pred_class_logits, pred_shift_deltas, pred_centerness):
        """
        Compute the FCOS losses and, for dynamic heads, a budget loss.

        Args:
            For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
                :meth:`FCOS.get_ground_truth`.
            `gt_classes` and `gt_shifts_deltas` have shapes (N, R) and (N, R, 4),
            respectively, where R is the total number of shifts across levels,
            i.e. sum(Hi x Wi); `gt_centerness` is reshaped to one value per shift.
            For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
                :meth:`FCOSHead.forward`.

            Label convention: negative labels are excluded from the
            classification loss, labels equal to ``self.num_classes`` are kept
            with an all-zero (background) target, and any other label is
            foreground.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls", "loss_box_reg" and "loss_centerness", plus
                "loss_budget" when ``self.is_dynamic_head`` is true and
                ``self.budget_loss_lambda`` is non-zero.
        """
        pred_class_logits, pred_shift_deltas, pred_centerness = \
            permute_all_cls_and_box_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
        gt_centerness = gt_centerness.view(-1, 1)

        valid_idxs = gt_classes >= 0
        foreground_idxs = valid_idxs & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # World-averaged normalizers (assumes comm.all_reduce sums over ranks).
        num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss: centerness is passed to iou_loss (presumably as a
        # per-box weight), hence normalization by the summed centerness.
        loss_box_reg = iou_loss(
            pred_shift_deltas[foreground_idxs],
            gt_shifts_deltas[foreground_idxs],
            gt_centerness[foreground_idxs],
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
        ) / max(1.0, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1.0, num_foreground)

        loss = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
        }

        # budget loss: penalize the mean soft-cost / full-cost ratio and log
        # the mean used-cost / full-cost ratio.
        if self.is_dynamic_head and self.budget_loss_lambda != 0:
            soft_cost, used_cost, full_cost = get_module_running_cost(self)
            loss_budget = (soft_cost / full_cost).mean() * self.budget_loss_lambda
            storage = get_event_storage()
            # NOTE: the metric key is spelled "complxity_ratio"; kept as-is
            # because existing logs/dashboards may read this exact key.
            storage.put_scalar("complxity_ratio", (used_cost / full_cost).mean())
            loss["loss_budget"] = loss_budget

        return loss
Exemplo n.º 5
0
    def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                         gt_inds, im_inds, pred_class_logits,
                         pred_shift_deltas, pred_centerness, pred_inst_params,
                         fpn_levels, shifts):
        """
        Compute the detection losses and collect foreground proposals.

        Every per-shift ground-truth tensor (`gt_classes`, `gt_shifts_deltas`,
        `gt_centerness`, `gt_inds`, `im_inds`, `fpn_levels`) is flattened so it
        lines up row-for-row with the predictions after
        ``permute_all_to_N_HWA_K_and_concat``. `shifts` is a nested sequence
        that is concatenated twice (assumed per image, then per level --
        confirm against the caller).

        Label convention: negative labels are excluded from the classification
        loss, labels equal to ``self.num_classes`` are kept with an all-zero
        (background) target, and any other label is foreground.

        Returns:
            tuple: ``(losses, proposals)`` where ``losses`` is a dict with keys
            "loss_cls", "loss_box_reg" and "loss_centerness", and ``proposals``
            is an ``Instances`` built from the foreground shifts, passed through
            ``self.select_instances`` when it is non-empty.
        """
        (pred_class_logits, pred_shift_deltas, pred_centerness,
         pred_inst_params) = permute_all_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            pred_inst_params, self.num_gen_params, self.num_classes,
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        gt_classes, fpn_levels, im_inds, gt_inds = [
            t.flatten() for t in (gt_classes, fpn_levels, im_inds, gt_inds)
        ]
        gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
        gt_centerness = gt_centerness.reshape(-1, 1)

        valid_mask = gt_classes >= 0
        fg_mask = valid_mask & (gt_classes != self.num_classes)

        # One-hot targets; background and ignored rows stay all-zero.
        cls_targets = torch.zeros_like(pred_class_logits)
        cls_targets[fg_mask, gt_classes[fg_mask]] = 1

        # World-averaged normalizers (assumes comm.all_reduce sums across
        # ranks). The collective calls keep their original order.
        avg_fg = comm.all_reduce(fg_mask.sum()) / float(comm.get_world_size())
        fg_centerness = gt_centerness[fg_mask]
        avg_centerness = comm.all_reduce(fg_centerness.sum()) / float(
            comm.get_world_size())

        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_mask],
            cls_targets[valid_mask],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, avg_fg)

        # Centerness is passed to iou_loss (presumably as a per-box weight), so
        # the loss is normalized by the summed centerness; the tiny floor keeps
        # the divisor positive without clamping legitimate sums below 1.
        loss_box_reg = iou_loss(
            pred_shift_deltas[fg_mask],
            gt_shifts_deltas[fg_mask],
            fg_centerness,
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
            reduction="sum",
            smooth=self.iou_smooth,
        ) / max(1e-6, avg_centerness)

        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[fg_mask],
            fg_centerness,
            reduction="sum",
        ) / max(1, avg_fg)

        losses = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
        }

        flat_shifts = torch.cat([torch.cat(group) for group in shifts])

        proposals = Instances((0, 0))
        # Attribute spelling "inst_parmas" is kept: other code may read it.
        proposals.inst_parmas = pred_inst_params[fg_mask]
        proposals.fpn_levels = fpn_levels[fg_mask]
        proposals.shifts = flat_shifts[fg_mask]
        proposals.gt_inds = gt_inds[fg_mask]
        proposals.im_inds = im_inds[fg_mask]

        if not len(proposals):
            return losses, proposals

        # Extra fields are only attached when top-k selection is enabled,
        # to save memory otherwise.
        if self.topk_proposals_per_im != -1:
            proposals.gt_cls = gt_classes[fg_mask]
            proposals.pred_logits = pred_class_logits[fg_mask]
            proposals.pred_centerness = pred_centerness[fg_mask]
        return losses, self.select_instances(proposals)