示例#1
0
    def criterion(
        self,
        cls: Tensor,
        reg: Tensor,
        centerness: Tensor,
        target: Tensor,
        reg_scaling: float = 1.0,
        centerness_scaling: float = 1.0,
    ) -> Dict[str, Tensor]:
        """Compute normalized classification, regression and centerness losses.

        The raw losses are produced by ``self._criterion``. Each one is divided by the
        number of target label entries that are not ``-1``, clamped to at least 1. The
        results are then combined into a weighted total.

        Args:
            cls: Classification prediction, forwarded to ``self._criterion``.
            reg: Regression prediction, forwarded to ``self._criterion``.
            centerness: Centerness prediction, forwarded to ``self._criterion``.
            target: Combined box/label target, separated with ``split_box_target``.
            reg_scaling: Weight applied to the regression loss in ``total_loss``.
            centerness_scaling: Weight applied to the centerness loss in ``total_loss``.

        Returns:
            Dict with keys ``"type_loss"`` (the classification loss), ``"centerness_loss"``,
            ``"reg_loss"`` and ``"total_loss"``.
        """
        box_target, cls_target = split_box_target(target)
        raw_cls, raw_reg, raw_ctr = self._criterion(cls, reg, centerness, box_target, cls_target)

        # Label entries equal to -1 (presumably padding) are excluded from the normalizer.
        # The clamp prevents division by zero when no entry is valid.
        normalizer = (cls_target != -1).sum().clamp_min(1)
        type_loss, reg_loss, ctr_loss = (loss / normalizer for loss in (raw_cls, raw_reg, raw_ctr))

        weighted_total = type_loss + reg_loss * reg_scaling + ctr_loss * centerness_scaling
        return {
            "type_loss": type_loss,
            "centerness_loss": ctr_loss,
            "reg_loss": reg_loss,
            "total_loss": weighted_total,
        }
示例#2
0
    def test_split_box_target_returns_views(self):
        """``split_box_target`` must return views that reflect in-place edits to ``target``."""
        torch.manual_seed(42)
        boxes = torch.randint(0, 10, (3, 4, 4))
        labels = torch.randint(0, 10, (3, 4, 2))
        target = torch.cat((boxes, labels), dim=-1)

        bbox_view, label_view = split_box_target(target)

        # Mutate the source tensor in place; true views must observe the change.
        target.mul_(10)
        expectations = ((bbox_view, target[..., :4]), (label_view, target[..., 4:]))
        for view, expected in expectations:
            assert torch.allclose(view, expected)
示例#3
0
    def get_pred_target_pairs(
        self,
        pred: Tensor,
        target: Tensor,
    ) -> "tuple[Tensor, Tensor, Tensor]":
        r"""Given CenterNet box predictions and target bounding box labels, use box IoU to
        create a pairing of predicted and target boxes such that each predicted box has
        an associated gold standard label.

        The pairing is performed by :class:`CategoricalLabelIoU`. It is configured from
        two attributes of this instance:

        * ``self.iou_threshold`` (float): Intersection over union threshold for which a
          prediction can be considered a true positive.
        * ``self.true_positive_limit`` (bool): By default, only one predicted box
          overlapping a target box will be counted as a true positive. If ``False``,
          allow multiple true positive boxes per target box.

        .. warning::
            This method should work with batched input, but such inputs are not thoroughly tested

        Args:
            pred (:class:`torch.Tensor`):
                Predicted boxes with scores and classes. They are separated into
                boxes, scores and classes by ``split_bbox_scores_class``.

            target (:class:`torch.Tensor`):
                Target bounding boxes in format ``x1, y1, x2, y2, class``.

        Returns:
            A 3-tuple of tensors as follows:
                1. Predicted binary score for each box
                2. Classification target value for each box, cast to ``long``
                3. Boolean indicating if the prediction was correct

        Shape:
            * ``pred`` - :math:`(N_{pred}, 6)`
            * ``target`` - :math:`(N_{true}, 5)`
            * Output - :math:`(N_o)`, :math:`(N_o)`, :math:`(N_o)`
        """
        # Get a pairing of predicted probability to target labels.
        # If we didn't detect a target box at any threshold, assume P_pred = 0.0.
        xform = CategoricalLabelIoU(self.iou_threshold,
                                    self.true_positive_limit)
        pred_boxes, pred_scores, pred_cls = split_bbox_scores_class(pred)
        target_bbox, target_cls = split_box_target(target)
        pred_out, binary_target, target_out = xform(pred_boxes, pred_scores,
                                                    pred_cls, target_bbox,
                                                    target_cls)

        # Internal sanity checks on the pairing output: flat and aligned 1:1.
        assert pred_out.ndim == 1
        assert target_out.ndim == 1
        assert pred_out.shape == target_out.shape
        # Note: the return order differs from xform's output order.
        return pred_out, target_out.long(), binary_target
示例#4
0
    def test_split_box_target_result(self, split_label):
        """Check that ``split_box_target`` separates box coordinates from labels correctly.

        With ``split_label`` false, the label outputs concatenated along the last dim must
        equal the original label tensor. With ``split_label`` true, the function must return
        exactly one single-column tensor per label column, in column order.
        """
        torch.random.manual_seed(42)
        bbox_target = torch.randint(0, 10, (3, 4, 4))
        label_target = torch.randint(0, 10, (3, 4, 2))
        target = torch.cat([bbox_target, label_target], dim=-1)

        result = split_box_target(target, split_label=split_label)
        bbox = result[0]
        label = result[1:]
        assert torch.allclose(bbox, bbox_target)

        if not split_label:
            assert torch.allclose(torch.cat(label, dim=-1), label_target)
        else:
            # Guard against a vacuous pass. Without this check, returning no label tensors
            # would skip the loop below entirely and the test would still succeed.
            assert len(label) == label_target.shape[-1]
            for i, sub_label in enumerate(label):
                assert torch.allclose(sub_label, label_target[..., i:i + 1])