Example #1
0
    def test_non_float_input(self):
        """Blending two ``uint8`` images must yield a ``uint8`` result.

        A single full-intensity pixel (255) is placed at a different location
        in each input; the test expects both locations to come out as 127
        after blending with the default arguments of ``alpha_blend``.
        """
        background = torch.zeros(1, 1, 10, 10).byte()
        foreground = torch.zeros_like(background)

        # one saturated pixel per image, at non-overlapping positions
        background[0, 0, 0, 0] = 255
        foreground[0, 0, 1, 1] = 255

        blended, _ = alpha_blend(foreground, background)

        # dtype must be preserved for integer inputs
        assert blended.dtype == torch.uint8
        # both the background-only and the foreground-only pixel are halved
        for row, col in ((0, 0), (1, 1)):
            assert blended[0, 0, row, col] == 127
Example #2
0
    def test_output_channels(self, cuda):
        """Blend two float images and check device placement and pixel values.

        Args:
            cuda: When truthy, both inputs are moved to the GPU before the
                pixels are written and the blend is performed.
        """
        dest = torch.zeros(1, 1, 10, 10).float()
        src = torch.zeros_like(dest).float()

        if cuda:
            dest, src = dest.cuda(), src.cuda()

        # one full-intensity pixel per image, at non-overlapping positions
        dest[0, 0, 0, 0] = 1.0
        src[0, 0, 1, 1] = 1.0

        blended, _ = alpha_blend(src, dest)

        # result must live on the same device as the inputs
        assert blended.device == dest.device
        # each lone pixel is expected to be halved by the default blend
        for row, col in ((0, 0), (1, 1)):
            assert blended[0, 0, row, col] == 0.5
Example #3
0
    def alpha_blend(dest: Tensor, src: Tensor, resize_mode="bilinear", *args, **kwargs) -> Tensor:
        r"""Match the shapes of ``src`` and ``dest`` and alpha blend them.

        Both inputs are given a batch dimension if they are 3-D, a
        single-channel input is repeated to match the channel count of the
        other, and ``src`` is resized to the spatial size of ``dest`` before
        the blend is delegated to the ``alpha_blend`` visible from the
        enclosing scope.

        NOTE(review): the inner call resolves ``alpha_blend`` by name. This is
        presumably the library-level blend function imported at module scope;
        confirm that this wrapper does not shadow it.

        Args:
            dest (:class:`torch.Tensor`):
                Destination image, 3-D or 4-D.

            src (:class:`torch.Tensor`):
                Source image, 3-D or 4-D.

            resize_mode (str):
                Interpolation mode passed to ``F.interpolate`` when the
                spatial sizes of ``src`` and ``dest`` differ.

            *args, **kwargs:
                Forwarded unchanged to the underlying ``alpha_blend`` call.

        Returns:
            The blended image, viewed with the (batched) shape of ``dest``.

        Raises:
            ValueError: If the channel counts differ and neither is 1.
        """
        # ensure batch dim present
        if dest.ndim == 3:
            dest = dest[None]
        if src.ndim == 3:
            # fix: the original unsqueezed ``dest`` here, leaving ``src`` 3-D
            src = src[None]

        B1, C1, H1, W1 = dest.shape
        B2, C2, H2, W2 = src.shape

        # broadcast a single-channel input across the other's channels
        if C1 != C2:
            if C1 == 1:
                dest = dest.repeat(1, C2, 1, 1)
            elif C2 == 1:
                src = src.repeat(1, C1, 1, 1)
            else:
                raise ValueError(f"could not match shapes {dest.shape}, {src.shape}")

        # bring src to the spatial resolution of dest
        if (H1, W1) != (H2, W2):
            src = F.interpolate(src, (H1, W1), mode=resize_mode)

        # the underlying blend returns a pair; only the blended image is kept
        blended, _ = alpha_blend(src, dest, *args, **kwargs)
        return blended.view_as(dest)
Example #4
0
    def alpha_blend(dest: Tensor,
                    src: Tensor,
                    resize_mode="bilinear",
                    *args,
                    **kwargs) -> Tensor:
        r"""Normalize, shape-match and alpha blend ``src`` with ``dest``.

        Inputs are coerced to 4-D (a batch dimension is added to 3-D inputs,
        inputs with more than 4 dimensions have their leading dimensions
        flattened into one), passed through ``clamp_normalize``, matched in
        channel count and spatial size, and finally handed to the
        ``alpha_blend`` visible from the enclosing scope.

        NOTE(review): the inner call resolves ``alpha_blend`` by name. This is
        presumably the library-level blend function imported at module scope;
        confirm that this wrapper does not shadow it.
        NOTE(review): ``clamp_normalize`` is called with ``inplace=True``, so
        the caller's tensors may be modified; confirm this is intended.

        Args:
            dest (:class:`torch.Tensor`):
                Destination image with at least 3 dimensions.

            src (:class:`torch.Tensor`):
                Source image with at least 3 dimensions.

            resize_mode (str):
                Interpolation mode passed to ``F.interpolate`` (together with
                ``align_corners=True``) when spatial sizes differ.

            *args, **kwargs:
                Forwarded unchanged to the underlying ``alpha_blend`` call.

        Returns:
            The blended image, viewed with the 4-D shape of ``dest``.

        Raises:
            ValueError: If the channel counts differ and neither is 1.
        """
        # ensure batch dim present
        if dest.ndim == 3:
            dest = dest[None]
        elif dest.ndim > 4:
            dest = dest.view(-1, *dest.shape[-3:])
        if src.ndim == 3:
            src = src[None]
        elif src.ndim > 4:
            # fix: the original tested ``dest.ndim`` here, so a >4-D ``src``
            # was never flattened and the shape unpack below failed
            src = src.view(-1, *src.shape[-3:])

        src = clamp_normalize(src, inplace=True)
        dest = clamp_normalize(dest, inplace=True)

        B1, C1, H1, W1 = dest.shape
        B2, C2, H2, W2 = src.shape

        # broadcast a single-channel input across the other's channels
        if C1 != C2:
            if C1 == 1:
                dest = dest.repeat(1, C2, 1, 1)
            elif C2 == 1:
                src = src.repeat(1, C1, 1, 1)
            else:
                raise ValueError(
                    f"could not match shapes {dest.shape}, {src.shape}")

        # bring src to the spatial resolution of dest
        if (H1, W1) != (H2, W2):
            src = F.interpolate(src, (H1, W1),
                                mode=resize_mode,
                                align_corners=True)

        # the underlying blend returns a pair; only the blended image is kept
        blended, _ = alpha_blend(src, dest, *args, **kwargs)
        return blended.view_as(dest)
Example #5
0
 def test_output_alpha(self):
     """The alpha map returned alongside the blend must be 1 everywhere."""
     shape = (1, 1, 10, 10)
     dest = torch.rand(*shape)
     src = torch.rand(*shape)

     # only the second element of the returned pair is of interest here
     _, out_alpha = alpha_blend(src, dest)
     assert torch.all(out_alpha == 1)
Example #6
0
 def test_input_shapes(self, shape):
     """Blending two random tensors of ``shape`` must preserve that shape.

     Args:
         shape: Iterable of dimension sizes used to build both inputs.
     """
     dest, src = torch.rand(*shape), torch.rand(*shape)
     blended, _ = alpha_blend(src, dest)
     assert blended.shape == dest.shape
Example #7
0
    def visualize_heatmap(
        heatmap: Tensor,
        background: Optional[Tensor] = None,
        cmap: str = "gnuplot",
        same_on_batch: bool = True,
        heatmap_alpha: float = 0.5,
        background_alpha: float = 0.5,
    ) -> List[ByteTensor]:
        r"""Generates visualizations of a CenterNet heatmap. Can optionally overlay the
        heatmap on top of a background image.

        Each channel of the heatmap is converted to 8-bit, colorized with the
        requested colormap (on CPU), stripped of its alpha channel and, when a
        background is supplied, resized to the background's spatial size and
        alpha blended with it.

        Args:
            heatmap (:class:`torch.Tensor`):
                The heatmap to visualize

            background (:class:`torch.Tensor`):
                An optional background image for the heatmap visualization.
                A single-channel background is repeated to three channels.

            cmap (str):
                Matplotlib colormap

            same_on_batch (bool):
                See :func:`combustion.vision.to_8bit`

            heatmap_alpha (float):
                See :func:`combustion.util.alpha_blend`

            background_alpha (float):
                See :func:`combustion.util.alpha_blend`

        Returns:
            List of tensors, where each tensor is a heatmap visualization for one class in the heatmap

        Shape:
            * ``heatmap`` - :math:`(N, C, H, W)` where :math:`C` is the number of classes in the heatmap.
            * Output - :math:`(N, 3, H, W)`
        """
        check_is_tensor(heatmap, "heatmap")
        if background is not None:
            # fix: report validation failures under the correct argument name
            check_is_tensor(background, "background")
            # need background to be float [0, 1] for alpha blend w/ heatmap
            background = to_8bit(background, same_on_batch=same_on_batch).float().div_(255).cpu()

            # expand a single-channel background to 3 channels so it can be
            # blended with the 3-channel colorized heatmap
            if background.shape[-3] == 1:
                repetitions = [1] * background.ndim
                repetitions[-3] = 3
                background = background.repeat(*repetitions)

        num_channels = heatmap.shape[-3]

        result = []
        for channel_idx in range(num_channels):
            # slice (rather than index) to keep the channel dimension
            channel = heatmap[..., channel_idx : channel_idx + 1, :, :]
            channel = to_8bit(channel, same_on_batch=same_on_batch)

            # output is float from [0, 1]
            heatmap_channel = apply_colormap(channel.cpu(), cmap=cmap)

            # drop alpha channel
            heatmap_channel = heatmap_channel[..., :3, :, :]

            # alpha blend w/ background
            if background is not None:
                heatmap_channel = F.interpolate(
                    heatmap_channel, size=background.shape[-2:], mode="bilinear", align_corners=True
                )
                # alpha_blend returns a pair; element 0 is the blended image
                heatmap_channel = alpha_blend(heatmap_channel, background, heatmap_alpha, background_alpha)[0]

            # back to 8-bit for the final visualization
            heatmap_channel = heatmap_channel.mul_(255).byte()
            result.append(heatmap_channel)

        return result
class TestBlendVisualizeCallback(TestVisualizeCallback):
    """Test suite for ``BlendVisualizeCallback``.

    Reuses the tests of ``TestVisualizeCallback`` and overrides the ``data``
    and ``expected_calls`` fixtures so that they describe a pair of images
    that are blended together rather than a single image.
    """

    callback_cls = BlendVisualizeCallback

    # NOTE(review): the fixture is parametrized ("float" / "long") but
    # ``request.param`` is never read below, so both parametrizations
    # currently produce the same data -- confirm whether a dtype cast
    # is missing.
    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        """Return two independent clones of one test image built from ``data_shape``."""
        B, C, H, W = data_shape
        img = create_image(B, C, H, W)
        return img.clone(), img.clone()

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback, mocker):
        """Build the list of ``(name, image, step)`` tuples expected from the callback.

        The pipeline is re-implemented step by step from the callback's
        attributes: channel splitting, resizing to ``max_resolution``,
        colormapping, alpha blending, optional 8-bit conversion and optional
        batch splitting.

        Returns an empty list when ``model`` lacks the attribute named by
        ``callback.attr_name``. Skips the test for ``split_channels == (2, 1)``.
        """
        if not hasattr(model, callback.attr_name):
            return []

        if callback.split_channels == (2, 1):
            pytest.skip("incompatible test")

        data1, data2 = data
        B, C, H, W = data1.shape
        # each position (0 / 1) starts with one image and one name
        img1 = [data1]
        img2 = [data2]
        name1 = [
            f"{mode}/{callback.name}",
        ]
        name2 = [
            f"{mode}/{callback.name}",
        ]
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting

        for pos in range(2):
            if callback.split_channels[pos]:
                img[pos] = []
                name[pos] = []
                img_new, name_new = [], []
                splits = torch.split(data[pos],
                                     callback.split_channels[pos],
                                     dim=-3)
                for i, s in enumerate(splits):
                    # NOTE(review): ``callback.name`` is indexed here, so it is
                    # presumably a sequence of names when splitting is enabled
                    n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        # if only one side was split, repeat the unsplit image so that both
        # sides have the same number of entries
        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        # resize only the images whose height or width exceeds max_resolution
        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [
                    i.shape[-2] > H_max or i.shape[-1] > W_max
                    for i in img[pos]
                ]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        # apply the per-position colormap(s); a single string is applied to
        # every image, and only the first 3 output channels are kept
        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :]
                    if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]

        # names of position 0 are used for the blended outputs
        name = name[0]
        final_img = []
        for pos, (d, s) in enumerate(zip(img[0], img[1])):
            B1, C1, H1, W1 = d.shape
            B2, C2, H2, W2 = s.shape

            # repeat a single-channel image to match the other's channels
            if C1 != C2:
                if C1 == 1:
                    d = d.repeat(1, C2, 1, 1)
                elif C2 == 1:
                    s = s.repeat(1, C1, 1, 1)
                else:
                    raise ValueError(
                        f"could not match shapes {d.shape}, {s.shape}")

            # alpha_blend returns a pair; element 0 is the blended image
            final_img.append(
                alpha_blend(d, s, callback.alpha[1], callback.alpha[0])[0])
        img = final_img

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        # one entry per batch element, named "<name>/<batch index>"
        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        # every expected call carries the same step: epoch or global step
        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
Example #9
0
 def blend_and_save(self, path, src, dest):
     """Colorize ``src``, blend it onto ``dest`` and save the result to ``path``.

     ``src`` is passed through ``apply_colormap`` (keeping the first three
     channels), resized to the spatial size of ``dest``, and alpha blended
     with ``dest``. Dimension 0 of the blended image is squeezed in place
     before it is handed to ``self.save``.
     """
     colored = apply_colormap(src)[..., :3, :, :]
     resized = F.interpolate(colored, dest.shape[-2:])
     # alpha_blend returns a pair; element 0 is the blended image
     blended = alpha_blend(resized, dest)[0]
     self.save(path, blended.squeeze_(0))