Пример #1
0
    def test_create_classification_target(self, stride, center_radius, size_target):
        """Smoke-test the type and shape of the classification target.

        Builds three boxes with labels ``0, 0, 1``, derives their positive-location
        mask via ``FCOSLoss.bbox_to_mask`` and checks that
        ``FCOSLoss.create_classification_target`` returns a ``Tensor`` shaped
        ``(num_classes, *size_target)``.

        ``stride``, ``center_radius`` and ``size_target`` are supplied by the
        caller (presumably pytest fixtures or parametrization -- not visible here).
        """
        n_classes = 2
        boxes = torch.tensor(
            [
                [0, 0, 9, 9],
                [3, 4, 8, 6],
                [4, 4, 6, 6],
            ]
        )
        # One class id per box, stored as a column of shape (3, 1).
        labels = torch.tensor([[0], [0], [1]])
        positive_mask = FCOSLoss.bbox_to_mask(boxes, stride, size_target, center_radius)

        target = FCOSLoss.create_classification_target(boxes, labels, positive_mask, n_classes, size_target)

        expected_shape = torch.Size([n_classes, *size_target])
        assert isinstance(target, Tensor)
        assert target.shape == expected_shape
Пример #2
0
    def test_bbox_to_mask(self, stride, center_radius, size_target):
        """Check ``FCOSLoss.bbox_to_mask`` against an independently built mask.

        For every box the expected positive region is the box itself or, when
        ``center_radius`` is given, a window of ``center_radius * stride`` around
        the box center clipped to the box. A grid location counts as inside when
        its center, ``index * stride + stride / 2``, lies within that region on
        both axes.

        ``stride``, ``center_radius`` and ``size_target`` are supplied by the
        caller (presumably pytest fixtures or parametrization -- not visible here).

        NOTE(review): assumes boxes are ``(x1, y1, x2, y2)`` and each returned
        mask is laid out ``(H, W)``, as the naming in this test implies --
        confirm against ``FCOSLoss.bbox_to_mask``.
        """
        bbox = torch.tensor(
            [
                [0, 0, 9, 9],
                [2, 2, 5, 5],
                [1, 1, 2, 2],
            ]
        )
        result = FCOSLoss.bbox_to_mask(bbox, stride, size_target, center_radius)

        assert isinstance(result, Tensor)
        assert result.shape == torch.Size([bbox.shape[-2], *size_target])

        # Cell-center positions along each axis. They are the same for every box,
        # so they are computed once. The (H, 1) / (1, W) shapes broadcast to
        # (H, W), which keeps rows paired with y bounds and columns with x bounds.
        rows = torch.arange(result.shape[-2], dtype=torch.float, device=bbox.device)
        cols = torch.arange(result.shape[-1], dtype=torch.float, device=bbox.device)
        grid_y = rows.mul_(stride).add_(stride / 2).view(-1, 1)
        grid_x = cols.mul_(stride).add_(stride / 2).view(1, -1)

        for box, res in zip(bbox, result):
            center_x = (box[0] + box[2]).true_divide(2)
            center_y = (box[1] + box[3]).true_divide(2)
            radius_x = (box[2] - box[0]).true_divide(2)
            radius_y = (box[3] - box[1]).true_divide(2)

            if center_radius is not None:
                x1 = center_x - center_radius * stride
                x2 = center_x + center_radius * stride
                y1 = center_y - center_radius * stride
                y2 = center_y + center_radius * stride
            else:
                x1 = center_x - radius_x
                x2 = center_x + radius_x
                y1 = center_y - radius_y
                y2 = center_y + radius_y

            # The center-sampling window must never extend past the box itself.
            x1.clamp_min_(center_x - radius_x)
            x2.clamp_max_(center_x + radius_x)
            y1.clamp_min_(center_y - radius_y)
            y2.clamp_max_(center_y + radius_y)

            inside_y = (grid_y >= y1).logical_and(grid_y <= y2)
            inside_x = (grid_x >= x1).logical_and(grid_x <= x2)
            mask = inside_y.logical_and(inside_x)
            pos_region = res[mask]

            # At least one positive location exists, every expected location is
            # positive, and nothing outside the expected region is positive.
            assert res.any()
            assert pos_region.all()
            assert res.sum() - pos_region.sum() == 0