Example #1
0
    def __call__(self, mask_logits):
        """
        Compute the mask prediction loss over sampled positive proposals.

        Arguments:
            mask_logits (Tensor): per-class mask logits for the sampled
                proposals.
        Return:
            mask_loss (Tensor): scalar weighted BCE-with-logits loss.
        """
        proposals = self.positive_proposals
        labels = cat([p.get_field("labels") for p in proposals], dim=0)
        mask_targets = cat(
            [p.get_field("mask_targets") for p in proposals], dim=0
        )

        # Foreground proposals and their ground-truth class ids.
        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[positive_inds]

        # The mean reduction inside binary_cross_entropy_with_logits cannot
        # handle an empty batch; return a zero that still keeps the autograd
        # dependency on mask_logits.
        if mask_targets.numel() == 0:
            return mask_logits.sum() * 0

        per_class_logits = mask_logits[positive_inds, labels_pos]
        mask_loss = F.binary_cross_entropy_with_logits(
            per_class_logits, mask_targets
        )
        return mask_loss * cfg.MRCNN.LOSS_WEIGHT
Example #2
0
    def __call__(self, keypoint_logits):
        """Compute the keypoint heatmap cross-entropy loss over valid targets."""
        heatmaps, valid = [], []
        for proposals in self.positive_proposals:
            keypoints = proposals.get_field("keypoints_target")
            hm, vis = project_keypoints_to_heatmap(
                keypoints, proposals, self.resolution)
            heatmaps.append(hm.view(-1))
            valid.append(vis.view(-1))

        keypoint_targets = cat(heatmaps, dim=0)
        # Indices of keypoints whose target is marked valid.
        valid = torch.nonzero(
            cat(valid, dim=0).to(dtype=torch.uint8)).squeeze(1)

        # cross_entropy cannot reduce an empty batch; return a zero that
        # still depends on keypoint_logits for autograd.
        if keypoint_targets.numel() == 0 or len(valid) == 0:
            return keypoint_logits.sum() * 0

        N, K, H, W = keypoint_logits.shape
        flat_logits = keypoint_logits.view(N * K, H * W)

        keypoint_loss = F.cross_entropy(flat_logits[valid],
                                        keypoint_targets[valid])
        return keypoint_loss * cfg.KRCNN.LOSS_WEIGHT
Example #3
0
 def convert_to_roi_format(self, boxes):
     """Flatten BoxLists into the (batch_idx, x1, y1, x2, y2) ROI layout."""
     concat_boxes = cat([b.bbox for b in boxes], dim=0)
     device = concat_boxes.device
     dtype = concat_boxes.dtype
     # One column of image indices, repeated per box in that image.
     per_image_ids = [
         torch.full((len(b), 1), img_idx, dtype=dtype, device=device)
         for img_idx, b in enumerate(boxes)
     ]
     ids = cat(per_image_ids, dim=0)
     return torch.cat([ids, concat_boxes], dim=1)
Example #4
0
    def __call__(self, class_logits, box_regression):
        """
        Computes the loss for Faster R-CNN.
        This requires that the subsample method has been called beforehand.

        Arguments:
            class_logits (list[Tensor])
            box_regression (list[Tensor])

        Returns:
            dict with "loss_classifier" and/or "loss_box_reg" depending on
            which branches are active.
        """
        if not hasattr(self, "_proposals"):
            raise RuntimeError("subsample needs to be called before")

        loss_dict = {}
        proposals = self._proposals
        labels = cat([p.get_field("labels") for p in proposals], dim=0)

        assert class_logits[0] is not None or box_regression[0] is not None, 'Fast R-CNN should keep 1 branch at least'

        if class_logits[0] is not None:
            loss_dict["loss_classifier"] = F.cross_entropy(
                cat(class_logits, dim=0), labels)

        if box_regression[0] is not None:
            box_regression = cat(box_regression, dim=0)
            device = box_regression.device
            regression_targets = cat(
                [p.get_field("regression_targets") for p in proposals], dim=0)

            # Advanced indexing: for each sampled positive, pick the four
            # regression channels belonging to its ground-truth class.
            sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
            labels_pos = labels[sampled_pos_inds_subset]
            if self.cls_agnostic_bbox_reg:
                map_inds = torch.tensor([4, 5, 6, 7], device=device)
            else:
                map_inds = 4 * labels_pos[:, None] + torch.tensor(
                    [0, 1, 2, 3], device=device)

            box_loss = smooth_l1_loss(
                box_regression[sampled_pos_inds_subset[:, None], map_inds],
                regression_targets[sampled_pos_inds_subset],
                size_average=False,
                beta=cfg.FAST_RCNN.SMOOTH_L1_BETA,
            )
            # Normalize by the total number of sampled proposals (pos + neg).
            loss_dict["loss_box_reg"] = box_loss / labels.numel()
        return loss_dict
Example #5
0
    def __call__(self, parsing_logits):
        """
        Compute the human-parsing (per-pixel classification) loss.

        Arguments:
            parsing_logits (Tensor): per-pixel class logits for the sampled
                positive proposals.
        Return:
            parsing_loss (Tensor): scalar weighted cross-entropy loss.
            If the parsing IoU head is enabled, additionally returns the
            per-proposal mean-IoU targets (None when there are no targets).
        """
        parsing_targets = cat(
            [p.get_field("parsing_targets") for p in self.positive_proposals],
            dim=0,
        )

        # cross_entropy cannot reduce an empty batch; return a zero that
        # still carries the autograd dependency on parsing_logits.
        if parsing_targets.numel() == 0:
            if not self.parsingiou_on:
                return parsing_logits.sum() * 0
            else:
                return parsing_logits.sum() * 0, None

        if self.parsingiou_on:
            # TODO: use tensor for speeding up
            pred_parsings_np = parsing_logits.detach().argmax(dim=1).cpu().numpy()
            parsing_targets_np = parsing_targets.cpu().numpy()

            N = parsing_targets_np.shape[0]
            # BUG FIX: np.float was deprecated in NumPy 1.20 and removed in
            # 1.24 (AttributeError); use the explicit dtype it aliased.
            parsingiou_targets = np.zeros(N, dtype=np.float64)

            for i in range(N):
                parsing_iou = cal_one_mean_iou(
                    parsing_targets_np[i], pred_parsings_np[i],
                    cfg.PRCNN.NUM_PARSING)
                # Mean over classes, ignoring classes absent from both maps.
                parsingiou_targets[i] = np.nanmean(parsing_iou)
            parsingiou_targets = torch.from_numpy(parsingiou_targets).to(
                parsing_targets.device, dtype=torch.float)

        parsing_loss = F.cross_entropy(
            parsing_logits, parsing_targets, reduction="mean"
        )
        parsing_loss *= cfg.PRCNN.LOSS_WEIGHT

        if not self.parsingiou_on:
            return parsing_loss
        else:
            return parsing_loss, parsingiou_targets
Example #6
0
    def __call__(self, semantic_pred, targets):
        """Weighted cross-entropy loss for semantic segmentation."""
        resized = self.semseg_batch_resize(targets)
        labels = cat(list(resized), dim=0).long()
        # Expect a (N, H, W) label map after concatenation.
        assert len(labels.shape) == 3

        loss_semseg = self.loss_weight * F.cross_entropy(
            semantic_pred, labels, ignore_index=self.ignore_label)
        return loss_semseg
Example #7
0
    def __call__(self, boxlists):
        """
        Map each box to an FPN level index via Eqn.(1) of the FPN paper.

        Arguments:
            boxlists (list[BoxList])
        """
        box_scales = torch.sqrt(cat([b.area() for b in boxlists]))

        # lvl = floor(lvl0 + log2(scale / s0)), clamped to [k_min, k_max],
        # then shifted so the lowest level maps to index 0.
        raw_lvls = torch.floor(
            self.lvl0 + torch.log2(box_scales / self.s0 + self.eps))
        clamped = torch.clamp(raw_lvls, min=self.k_min, max=self.k_max)
        return clamped.to(torch.int64) - self.k_min
Example #8
0
    def __call__(self, parsing_logits):
        """Weighted cross-entropy loss over concatenated parsing targets."""
        targets_per_img = [
            p.get_field("parsing_targets") for p in self.positive_proposals
        ]
        parsing_targets = cat(targets_per_img, dim=0)

        # cross_entropy cannot reduce an empty batch; keep the autograd graph
        # alive with a zero derived from the logits.
        if parsing_targets.numel() == 0:
            return parsing_logits.sum() * 0

        loss = F.cross_entropy(
            parsing_logits, parsing_targets, reduction="mean")
        return loss * cfg.PRCNN.LOSS_WEIGHT
Example #9
0
    def __call__(self, logits):
        """Weighted BCE loss applied to both fused and unfused predictions."""
        targets_per_img = [
            p.get_field("targets") for p in self.positive_proposals
        ]
        targets = cat(targets_per_img, dim=0).float()

        # Empty batch: return a zero that still depends on the logits.
        if targets.numel() == 0:
            return logits['fused'].sum() * 0

        loss = 0
        for branch in ('fused', 'unfused'):
            loss = loss + self.loss_weight * F.binary_cross_entropy_with_logits(
                logits[branch], targets)
        return loss
Example #10
0
    def __call__(self, mask_logits):
        """
        Arguments:
            mask_logits (Tensor)

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
            If we use maskiou head, we will return extra feature for maskiou head.
        """
        # Gather per-image labels / mask targets from the sampled positives.
        labels = [
            proposals_per_img.get_field("labels")
            for proposals_per_img in self.positive_proposals
        ]
        mask_targets = [
            proposals_per_img.get_field("mask_targets")
            for proposals_per_img in self.positive_proposals
        ]
        if self.maskiou_on:
            # mask_ratios: per-proposal ratio used below to recover the full
            # target mask area from the cropped target's pixel sum.
            mask_ratios = [
                proposals_per_img.get_field("mask_ratios")
                for proposals_per_img in self.positive_proposals
            ]

        labels = cat(labels, dim=0)
        mask_targets = cat(mask_targets, dim=0)

        # Indices of proposals with a foreground (non-background) label.
        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[positive_inds]

        # torch.mean (in binary_cross_entropy_with_logits) doesn't
        # accept empty tensors, so handle it separately
        if mask_targets.numel() == 0:
            if not self.maskiou_on:
                return mask_logits.sum() * 0
            else:
                # Still hand the maskiou head a (N, 1, H, W) mask slice so its
                # forward pass stays well-formed; IoU targets are None here.
                selected_index = torch.arange(mask_logits.shape[0],
                                              device=labels.device)
                selected_mask = mask_logits[selected_index, labels]
                mask_num, mask_h, mask_w = selected_mask.shape
                selected_mask = selected_mask.reshape(mask_num, 1, mask_h,
                                                      mask_w)
                return mask_logits.sum() * 0, selected_mask, labels, None

        if self.maskiou_on:
            mask_ratios = cat(mask_ratios, dim=0)
            # Clamp ratios away from zero to avoid a division blow-up below.
            value_eps = 1e-10 * torch.ones(mask_targets.shape[0],
                                           device=labels.device)
            mask_ratios = torch.max(mask_ratios, value_eps)
            # Advanced indexing returns a copy, so the in-place binarization
            # below does not modify mask_logits itself.
            pred_masks = mask_logits[positive_inds, labels_pos]
            # NOTE(review): this thresholds raw logits at 0.5, not sigmoid
            # probabilities — presumably intentional upstream; confirm.
            pred_masks[:] = pred_masks > 0.5
            # Full target mask area recovered from the cropped target sum.
            mask_targets_full_area = mask_targets.sum(dim=[1, 2]) / mask_ratios
            mask_ovr = pred_masks * mask_targets
            mask_ovr_area = mask_ovr.sum(dim=[1, 2])
            mask_union_area = pred_masks.sum(
                dim=[1, 2]) + mask_targets_full_area - mask_ovr_area
            # Guard against degenerate unions (< 1) and negative overlaps.
            value_1 = torch.ones(pred_masks.shape[0], device=labels.device)
            value_0 = torch.zeros(pred_masks.shape[0], device=labels.device)
            mask_union_area = torch.max(mask_union_area, value_1)
            mask_ovr_area = torch.max(mask_ovr_area, value_0)
            maskiou_targets = mask_ovr_area / mask_union_area

        mask_loss = F.binary_cross_entropy_with_logits(
            mask_logits[positive_inds, labels_pos], mask_targets)
        mask_loss *= cfg.MRCNN.LOSS_WEIGHT
        if not self.maskiou_on:
            return mask_loss
        else:
            # Per-proposal mask slice for the maskiou head, as probabilities.
            selected_index = torch.arange(mask_logits.shape[0],
                                          device=labels.device)
            selected_mask = mask_logits[selected_index, labels]
            mask_num, mask_h, mask_w = selected_mask.shape
            selected_mask = selected_mask.reshape(mask_num, 1, mask_h, mask_w)
            selected_mask = selected_mask.sigmoid()
            return mask_loss, selected_mask, labels, maskiou_targets