Example #1
    def test_cmap_locations(self):
        """Colormapping with a monotonic map must preserve relative ordering.

        Wherever ``inputs <= inputs2`` holds element-wise, the "gray"-mapped
        outputs (channel 0) must satisfy the same relation.
        """
        inputs = torch.rand(1, 1, 10, 10)
        inputs2 = torch.rand(1, 1, 10, 10)
        out1 = apply_colormap(inputs, "gray")
        out2 = apply_colormap(inputs2, "gray")

        greater_input = inputs <= inputs2
        greater_output = out1[:, 0, ...] <= out2[:, 0, ...]
        # BUG FIX: the original `assert ~(xor).all()` binds as `~(xor.all())`,
        # i.e. "not EVERY element disagrees", which passes even when almost all
        # elements disagree. The intent is that NO element disagrees, so the
        # XOR of the two comparison masks must be False everywhere.
        assert not torch.logical_xor(greater_input, greater_output).any()
Example #2
    def apply_colormap(img: Tensor, cmap: str = "gnuplot") -> Tensor:
        """Colormap ``img`` with ``cmap``, keep only the RGB channels, and
        squeeze singleton dimensions from the result.

        NOTE(review): as written, this calls itself unconditionally and would
        recurse forever. Presumably the inner call is meant to target a
        differently scoped ``apply_colormap`` (e.g. a shadowed module-level
        utility) — confirm against the original file's imports.
        """
        # ensure batch dim present
        if img.ndim == 3:
            img = img[None]

        # keep only RGB (drop the alpha channel) from the colormapped output
        img = apply_colormap(img, cmap=cmap)[:, :3, :, :]
        # in-place squeeze removes all singleton dimensions before returning
        img = img.squeeze_()
        return img
Example #3
    def test_input_output_shape(self, inputs, dtype):
        """Output is RGBA: batch and spatial dims match the input, channels == 4.

        Also exercises integer inputs by rescaling floats into ``long``/``byte``
        ranges before colormapping.
        """
        if dtype == "long":
            inputs = inputs.mul(1024).long()
        elif dtype == "byte":
            inputs = inputs.mul(255).byte()

        result = apply_colormap(inputs)
        assert isinstance(result, Tensor)
        assert result.shape[0] == inputs.shape[0]
        assert result.shape[1] == 4
        assert result.shape[2:] == inputs.shape[2:]
    def expected_calls(self, data, model, mode, callback):
        """Reconstruct the image/name lists a visualization callback is expected
        to log, mirroring its pipeline: channel splitting, aspect-preserving
        downscaling, then colormapping.

        NOTE(review): this visible span builds ``img``/``name`` but never
        returns them — the snippet is presumably truncated and the original
        assembles them into expected call tuples. Confirm against the full file.
        """
        # The callback is a no-op when the model lacks the watched attribute.
        if not hasattr(model, callback.attr_name):
            return []
        data = prepare_image(data)

        B, _, H, W = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                # ``callback.name`` may be one string (suffixed with the split
                # index) or a per-split sequence of names
                if isinstance(callback.name, str):
                    n = f"{mode}/{callback.name}_{i}"
                else:
                    n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        # Downscale any image exceeding max_resolution using one scale factor
        # per image, so aspect ratio is preserved.
        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            scale_factor = []
            for i in img:
                H, W = i.shape[-2:]
                height_ratio, width_ratio = H / H_max, W / W_max
                # scale by the most-violating dimension so both bounds fit
                s = 1 / max(height_ratio, width_ratio)
                scale_factor.append(s)

            # only shrink (s < 1); images already within bounds pass through
            img = [
                F.interpolate(i, scale_factor=s, mode=resize_mode)
                if s < 1 else i for i, s in zip(img, scale_factor)
            ]

        # Colormap each image (a single string is broadcast to all images);
        # the alpha channel is dropped, keeping RGB only.
        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]
    def expected_calls(self, data, model, mode, callback):
        """Variant of ``expected_calls`` for ``(data, target)`` input pairs:
        reconstructs the expected logged images/names after channel splitting,
        resizing, and colormapping.

        NOTE(review): this span also builds ``img``/``name`` with no visible
        return — presumably truncated; confirm against the full file.
        """
        if not hasattr(model, callback.attr_name):
            return []

        data, target = data
        data = prepare_image(data)
        B, _, _, _ = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                # assumes callback.name is indexable per split here — unlike
                # the sibling implementation there is no plain-string fallback
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        # Resize directly to ``max_resolution`` (this variant does NOT preserve
        # aspect ratio) for any image exceeding either bound. Note that
        # ``target`` from the tuple unpack above is shadowed here.
        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [
                i.shape[-2] > H_max or i.shape[-1] > W_max for i in img
            ]
            img = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img, needs_resize)
            ]

        # Colormap each image (string broadcast to all); drop the alpha channel.
        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]
Example #6
 def test_cuda_tensor(self):
     """apply_colormap should accept a CUDA tensor and keep the result on GPU."""
     inputs = torch.rand(1, 1, 10, 10).cuda()
     output = apply_colormap(inputs, "gray")
     assert isinstance(output, Tensor)
     # FIX: the original compared a torch.device against the string "cpu",
     # which relies on torch's device/string comparison semantics and only
     # asserts "not cpu". `is_cuda` states the actual intent directly: the
     # output stayed on a CUDA device.
     assert output.is_cuda
Example #7
 def test_input_output_shape(self, inputs):
     """Colormapped output is RGBA: batch/spatial dims preserved, 4 channels."""
     result = apply_colormap(inputs)
     assert isinstance(result, Tensor)
     assert result.shape[0] == inputs.shape[0]
     assert result.shape[1] == 4
     assert result.shape[2:] == inputs.shape[2:]
Example #8
    def visualize_heatmap(
        heatmap: Tensor,
        background: Optional[Tensor] = None,
        cmap: str = "gnuplot",
        same_on_batch: bool = True,
        heatmap_alpha: float = 0.5,
        background_alpha: float = 0.5,
    ) -> List[ByteTensor]:
        r"""Generates visualizations of a CenterNet heatmap. Can optionally overlay the
        heatmap on top of a background image.

        Args:
            heatmap (:class:`torch.Tensor`):
                The heatmap to visualize

            background (:class:`torch.Tensor`):
                An optional background image for the heatmap visualization

            cmap (str):
                Matplotlib colormap

            same_on_batch (bool):
                See :func:`combustion.vision.to_8bit`

            heatmap_alpha (float):
                See :func:`combustion.util.alpha_blend`

            background_alpha (float):
                See :func:`combustion.util.alpha_blend`

        Returns:
            List of tensors, where each tensor is a heatmap visualization for one class in the heatmap

        Shape:
            * ``heatmap`` - :math:`(N, C, H, W)` where :math:`C` is the number of classes in the heatmap.
            * Output - each element is :math:`(N, 3, H, W)`
        """
        check_is_tensor(heatmap, "heatmap")
        if background is not None:
            # BUG FIX: the argument name reported on validation failure was
            # "heatmap"; it must be "background" for an accurate error message.
            check_is_tensor(background, "background")
            # need background to be float [0, 1] for alpha blend w/ heatmap
            background = to_8bit(background, same_on_batch=same_on_batch).float().div_(255).cpu()

            # broadcast grayscale backgrounds to 3 channels so they can be
            # blended with the RGB colormapped heatmap
            if background.shape[-3] == 1:
                repetitions = [
                    1,
                ] * background.ndim
                repetitions[-3] = 3
                background = background.repeat(*repetitions)

        num_channels = heatmap.shape[-3]

        # build one visualization per heatmap class channel
        result = []
        for channel_idx in range(num_channels):
            _ = heatmap[..., channel_idx : channel_idx + 1, :, :]
            _ = to_8bit(_, same_on_batch=same_on_batch)

            # output is float from [0, 1]
            heatmap_channel = apply_colormap(_.cpu(), cmap=cmap)

            # drop alpha channel
            heatmap_channel = heatmap_channel[..., :3, :, :]

            # alpha blend w/ background
            if background is not None:
                # match the background's spatial size before blending
                heatmap_channel = F.interpolate(
                    heatmap_channel, size=background.shape[-2:], mode="bilinear", align_corners=True
                )
                heatmap_channel = alpha_blend(heatmap_channel, background, heatmap_alpha, background_alpha)[0]

            # convert float [0, 1] back to byte [0, 255] for the caller
            heatmap_channel = heatmap_channel.mul_(255).byte()
            result.append(heatmap_channel)

        return result
    def expected_calls(self, data, model, mode, callback, mocker):
        """Two-input variant: reconstructs the expected logged images/names for
        a pair of inputs, each with its own split-channels and colormap config.

        NOTE(review): ``mocker`` is unused in this visible span, and ``name``
        is built but never returned — the snippet is presumably truncated.
        """
        if not hasattr(model, callback.attr_name):
            return []

        # this split configuration is known-incompatible with the paired path
        if callback.split_channels == (2, 1):
            pytest.skip("incompatible test")

        data1, data2 = data
        B, C, H, W = data1.shape
        img1 = [data1]
        img2 = [data2]
        name1 = [
            f"{mode}/{callback.name}",
        ]
        name2 = [
            f"{mode}/{callback.name}",
        ]
        # parallel [first, second] structures processed position-by-position
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting

        for pos in range(2):
            # per-position split config; falsy entries leave the input whole
            if callback.split_channels[pos]:
                img[pos] = []
                name[pos] = []
                img_new, name_new = [], []
                splits = torch.split(data[pos],
                                     callback.split_channels[pos],
                                     dim=-3)
                for i, s in enumerate(splits):
                    # assumes callback.name is indexable per split here
                    n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        # Equalize list lengths by repeating a singleton side; two unequal
        # multi-element sides cannot be paired.
        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        # Resize directly to ``max_resolution`` (aspect ratio not preserved)
        # for any image exceeding either bound.
        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [
                    i.shape[-2] > H_max or i.shape[-1] > W_max
                    for i in img[pos]
                ]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        # Per-position colormapping (string broadcast across that position's
        # images); the alpha channel is dropped, keeping RGB only.
        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :]
                    if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]
Example #10
 def blend_and_save(self, path, src, dest):
     """Colormap ``src`` (keeping RGB only), resize it to ``dest``'s spatial
     size, alpha-blend the pair, and save the blended image to ``path``."""
     colored = apply_colormap(src)[..., :3, :, :]
     resized = F.interpolate(colored, dest.shape[-2:])
     blended = alpha_blend(resized, dest)[0].squeeze_(0)
     self.save(path, blended)