    def test_batched_input(self, same_on_batch, expected):
        input = torch.tensor([
            [[[-2.0, 0.0], [0.0, 1.0]]],
            [[[-1.0, 0.5], [0.0, 1.0]]],
        ])
        result = to_8bit(input, per_channel=True, same_on_batch=same_on_batch)
        assert torch.allclose(result, expected)
    def test_random_input(self, input, per_channel, shape, dtype):
        result = to_8bit(input, per_channel)
        assert result.shape == shape
        assert result.min() >= 0
        assert result.max() <= 255
        assert result.unique().numel() >= 10
        assert result.dtype == torch.uint8

        if dtype == "byte":
            assert torch.allclose(input, result)
Example #3
    def _process_image(
        self,
        img: Tensor,
        colormap: Optional[Union[Colormap, ColormapSequence]],
        split_channels: Optional[ChannelSplits],
        name: Union[str, List[str]],
        max_resolution: Optional[Tuple[int, int]],
        as_uint8: bool = True,
    ) -> Dict[str, Tensor]:
        img = prepare_image(img)

        # split channels
        images: Dict[str, Tensor]
        if split_channels is not None:
            _images = torch.split(img, split_channels, dim=1)
            if isinstance(name, str):
                _name = [f"{name}_{i}" for i in range(len(_images))]
            else:
                _name = name
            assert len(_name) == len(_images)
            images = {n: x for n, x in zip(_name, _images)}
        else:
            assert isinstance(name, str)
            images = {name: img}

        # apply resolution limit
        if max_resolution is not None:
            images = {
                k: self._apply_log_resolution(v, max_resolution,
                                              self.resize_mode)
                for k, v in images.items()
            }

        # apply colormap
        colormap_seq: Optional[ColormapSequence] = None
        if isinstance(colormap, str):
            colormap_seq = [colormap] * len(images)  # type: ignore
        elif colormap is not None:
            colormap_seq = colormap
        if colormap_seq is not None:
            assert len(colormap_seq) == len(images)
            images = {
                k: self.apply_colormap(v, cmap) if cmap is not None else v
                for cmap, (k, v) in zip(colormap_seq, images.items())
            }

        # convert to byte
        if as_uint8:
            images = {
                k: to_8bit(v, same_on_batch=not self.per_img_norm)
                for k, v in images.items()
            }

        return images
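For reference, a standalone illustration of the channel-splitting step above; the shapes, the split sizes, and the base name "pred" are made-up values for the example, not from the source:

import torch

img = torch.rand(2, 3, 8, 8)                        # (B, C, H, W)
split_channels = [2, 1]                             # channel group sizes
parts = torch.split(img, split_channels, dim=1)     # -> 2-channel and 1-channel tensors
names = [f"pred_{i}" for i in range(len(parts))]    # expansion of a single str name "pred"
images = dict(zip(names, parts))
# images == {"pred_0": (2, 2, 8, 8) tensor, "pred_1": (2, 1, 8, 8) tensor}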
    def test_known_input_per_channel_result(self, per_channel):
        input = torch.tensor([
            [[0, 1000], [1001, 2000]],
            [[0, 500], [500, 1000]],
        ]).long()

        if per_channel:
            expected = torch.tensor([[[0, 128], [128, 255]],
                                     [[0, 128], [128, 255]]]).byte()
        else:
            expected = torch.tensor([[[0, 128], [128, 255]],
                                     [[0, 64], [64, 128]]]).byte()

        result = to_8bit(input, per_channel=per_channel)
        assert result.shape == input.shape
        assert torch.allclose(result, expected)
Example #5
    def _process_image(
        self,
        dest: Tensor,
        src: Tensor,
        colormap: Tuple[Optional[Colormap], Optional[Colormap]],
        split_channels: Tuple[Optional[ChannelSplits], Optional[ChannelSplits]],
        name: Union[str, List[str]],
        max_resolution: Optional[Tuple[int, int]],
        as_uint8: bool = True,
    ) -> Dict[str, Tensor]:

        B, C, H, W = dest.shape
        if any(split_channels):
            name1, name2 = (name if split is not None else "__name__" for split in split_channels)
        else:
            name1, name2 = name, name

        images_dest = super()._process_image(dest, colormap[0], split_channels[0], name1, max_resolution, False)
        images_src = super()._process_image(src, colormap[1], split_channels[1], name2, max_resolution, False)

        if len(images_dest) != len(images_src):
            if len(images_dest) == 1:
                val = next(iter(images_dest.values()))
                images_dest = {k: val for k in images_src.keys()}
            elif len(images_src) == 1:
                val = next(iter(images_src.values()))
                images_src = {k: val for k in images_dest.keys()}
            else:
                raise ValueError(
                    f"Unable to broadcast processed images:\n"
                    f"Destination dict:\n {k: v.shape for k, v in images_dest.items()}\n\n"
                    f"Source dict:\n {k: v.shape for k, v in images_src.items()}"
                )

        dest_alpha, src_alpha = self.alpha
        images = {
            (k1 if "__name__" not in k1 else k2): VisualizeCallback.alpha_blend(
                d, s, self.resize_mode, dest_alpha=dest_alpha, src_alpha=src_alpha
            )
            for (k1, d), (k2, s) in zip(images_dest.items(), images_src.items())
        }
        assert "__name__" not in images.keys()

        if as_uint8:
            images = {k: to_8bit(v, same_on_batch=not self.per_img_norm) for k, v in images.items()}

        return images
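A minimal sketch of the blending convention this method relies on, assuming `alpha_blend` computes a plain weighted sum; the real `VisualizeCallback.alpha_blend` also handles resizing via `resize_mode` and may differ in detail:

import torch
from torch import Tensor

def alpha_blend_sketch(dest: Tensor, src: Tensor, dest_alpha: float = 0.5, src_alpha: float = 0.5) -> Tensor:
    # weighted sum of the two images; callers are assumed to pass tensors
    # that are already normalized to a common range
    return dest_alpha * dest.float() + src_alpha * src.float()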
Example #6
    def _process_image(
        self,
        img: Tensor,
        keypoint_dict: Dict[str, Tensor],
        colormap: Optional[Union[Colormap, ColormapSequence]],
        split_channels: Optional[ChannelSplits],
        name: Union[str, List[str]],
        max_resolution: Optional[Tuple[int, int]],
        as_uint8: bool = True,
    ) -> Dict[str, Tensor]:
        img = prepare_image(img)

        B, C, H, W = img.shape
        images = super()._process_image(img, colormap, split_channels, name,
                                        max_resolution, False)

        # apply resolution limit to keypoints
        if max_resolution is not None:
            H_scaled, W_scaled = next(iter(images.values())).shape[-2:]
            if H_scaled / H != 1.0:
                coords = keypoint_dict["coords"].clone().float()
                padding = coords < 0
                coords.mul_(H_scaled / H)
                coords[padding] = -1
                keypoint_dict["coords"] = coords

        # overlay keypoints, accounting for channel splitting
        if self.split_channels:
            keypoints = [keypoint_dict] * C
            assert len(keypoints) == C
            images = {
                k: self.overlay_keypoints(v, d)
                for d, (k, v) in zip(keypoints, images.items())
            }
        else:
            images = {
                k: self.overlay_keypoints(v, keypoint_dict)
                for k, v in images.items()
            }

        if as_uint8:
            images = {
                k: to_8bit(v, same_on_batch=not self.per_img_norm)
                for k, v in images.items()
            }

        return images
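A small standalone check of the keypoint rescaling step above, with illustrative values: downscaling from H=512 to H_scaled=256 halves the coordinates while padded entries (negative coords) stay at -1:

import torch

coords = torch.tensor([[100.0, 200.0], [-1.0, -1.0]])   # (N, 2); -1 marks padding
H, H_scaled = 512, 256                                   # original vs. resized height
padding = coords < 0
coords = coords * (H_scaled / H)
coords[padding] = -1
# coords is now tensor([[ 50., 100.], [ -1.,  -1.]])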
Example #7
    def _process_image(
        self,
        img: Tensor,
        colormap: Optional[Union[str, List[str]]],
        split_channels: Optional[ChannelSplits],
        name: Union[str, List[str]],
        max_resolution: Optional[Tuple[int, int]],
        as_uint8: bool = True,
    ) -> Dict[str, Tensor]:
        # split channels
        if split_channels is not None:
            images = torch.split(img, split_channels, dim=1)
            assert isinstance(name, list)
            assert len(name) == len(images)
            images = {n: x for n, x in zip(name, images)}
        else:
            images = {name: img}

        # apply resolution limit
        if max_resolution is not None:
            images = {k: self._apply_log_resolution(v, max_resolution, self.resize_mode) for k, v in images.items()}

        # apply colormap
        if colormap is not None:
            if isinstance(colormap, str):
                colormap = [colormap] * len(images)
            assert len(colormap) == len(images)
            images = {
                k: self.apply_colormap(v, cmap) if cmap is not None else v
                for cmap, (k, v) in zip(colormap, images.items())
            }

        # convert to byte
        if as_uint8:
            images = {k: to_8bit(v, same_on_batch=not self.per_img_norm) for k, v in images.items()}

        return images
class TestVisualizeCallback:

    callback_cls = VisualizeCallback

    # training, validation, or testing mode
    @pytest.fixture(params=["train", "val", "test"])
    def mode(self, request):
        return request.param

    # returns a no-args closure of callback function appropriate for `mode`
    @pytest.fixture
    def callback_func(self, mode, trainer, model, callback):
        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")

        def func():
            # invoke the mode-appropriate hook on the callback under test
            f = getattr(callback, attr)
            return f(trainer, model)

        return func

    @pytest.fixture
    def data_shape(self):
        return 2, 3, 32, 32

    @pytest.fixture
    def data(self, data_shape):
        data = create_image(*data_shape)
        return data

    @pytest.fixture
    def model(self, request, mocker, callback, trainer, data, mode):

        if hasattr(request, "param"):
            step = request.param.pop("step", 10)
            epoch = request.param.pop("epoch", 1)
        else:
            step = 10
            epoch = 1

        model = mocker.MagicMock(name="module")
        model.current_epoch = epoch
        model.global_step = step
        if callback.attr_name is not None:
            setattr(model, callback.attr_name, data)

        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")
        callback.trigger = lambda: getattr(callback, attr)(trainer, model)

        return model

    @pytest.fixture
    def callback(self, request, trainer):
        cls = self.callback_cls
        init_signature = inspect.signature(cls)
        defaults = {
            k: v.default
            for k, v in init_signature.parameters.items()
            if v.default is not inspect.Parameter.empty
        }

        if hasattr(request, "param"):
            name = request.param.get("name", "image")
            defaults.update(request.param)
        else:
            name = "image"
            defaults["name"] = name

        callback = cls(**defaults)
        return callback

    @pytest.fixture
    def logger_func(self, model):
        return model.logger.experiment.add_images

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []

        B, C, H, W = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            scale_factor = []
            for i in img:
                H, W = i.shape[-2:]
                height_ratio, width_ratio = H / H_max, W / W_max
                s = 1 / max(height_ratio, width_ratio)
                scale_factor.append(s)

            img = [
                F.interpolate(i, scale_factor=s, mode=resize_mode)
                if s < 1 else i for i, s in zip(img, scale_factor)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
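A small standalone check of the scale-factor logic used in the fixture above; the image size, resolution limit, and "bilinear" mode are illustrative values only:

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 64, 128)
H_max, W_max = 32, 32
s = 1 / max(img.shape[-2] / H_max, img.shape[-1] / W_max)    # 1 / max(2.0, 4.0) = 0.25
small = F.interpolate(img, scale_factor=s, mode="bilinear", align_corners=False)
assert small.shape[-2:] == (16, 32)    # longest side fits, aspect ratio preserved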
class TestBlendVisualizeCallback(TestVisualizeCallback):

    callback_cls = BlendVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        img = create_image(B, C, H, W)
        return img.clone(), img.clone()

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback, mocker):
        if not hasattr(model, callback.attr_name):
            return []

        if callback.split_channels == (2, 1):
            pytest.skip("incompatible test")

        data1, data2 = data
        B, C, H, W = data1.shape
        img1 = [data1]
        img2 = [data2]
        name1 = [
            f"{mode}/{callback.name}",
        ]
        name2 = [
            f"{mode}/{callback.name}",
        ]
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting

        for pos in range(2):
            if callback.split_channels[pos]:
                img_new, name_new = [], []
                splits = torch.split(data[pos],
                                     callback.split_channels[pos],
                                     dim=-3)
                for i, s in enumerate(splits):
                    n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [
                    i.shape[-2] > H_max or i.shape[-1] > W_max
                    for i in img[pos]
                ]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :]
                    if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]

        name = name[0]
        final_img = []
        for pos, (d, s) in enumerate(zip(img[0], img[1])):
            B1, C1, H1, W1 = d.shape
            B2, C2, H2, W2 = s.shape

            if C1 != C2:
                if C1 == 1:
                    d = d.repeat(1, C2, 1, 1)
                elif C2 == 1:
                    s = s.repeat(1, C1, 1, 1)
                else:
                    raise ValueError(
                        f"could not match shapes {d.shape}, {s.shape}")

            final_img.append(
                alpha_blend(d, s, callback.alpha[1], callback.alpha[0])[0])
        img = final_img

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
class TestKeypointVisualizeCallback(TestVisualizeCallback):

    callback_cls = KeypointVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        N = 3
        img = create_image(B, C, H, W)
        bbox = self.create_bbox(B, N)
        bbox = bbox.float() if request.param else bbox.long()
        cls = self.create_classes(B, N)
        score = self.create_scores(B, N)
        target = {"coords": bbox, "class": cls, "score": score}
        return img, target

    def create_bbox(self, B, N):
        torch.random.manual_seed(42)
        bbox = torch.empty(B, N, 4).fill_(-1).float()
        return bbox

    def create_classes(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1).float()

    def create_scores(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1)

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback, mocker):
        if not hasattr(model, callback.attr_name):
            return []

        data, target = data
        B, C, H, W = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [
                i.shape[-2] > H_max or i.shape[-1] > W_max for i in img
            ]
            img = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img, needs_resize)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]

        img = [i.repeat(1, 3, 1, 1) if i.shape[-3] == 1 else i for i in img]

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
def check_call(call, name, img, step, as_uint8=True):
    if as_uint8:
        img = to_8bit(img, same_on_batch=True)
    assert call.args[0] == name
    assert torch.allclose(call.args[1], img, atol=1)
    assert call.args[2] == step
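A hypothetical test method on TestVisualizeCallback showing how check_call and the fixtures above are presumably wired together; this glue is not shown in the source, and since expected_calls already applies to_8bit when callback.as_uint8 is set, as_uint8=False is passed here:

    def test_logs_expected_images(self, callback, model, logger_func, expected_calls):
        callback.trigger()
        assert logger_func.call_count == len(expected_calls)
        for call, (name, img, step) in zip(logger_func.call_args_list, expected_calls):
            check_call(call, name, img, step, as_uint8=False)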
Example #12
    def test_input_unchanged(self):
        input = torch.tensor([[0, 1000], [1001, 2000]]).unsqueeze(0).long()
        original_input = input.clone()
        result = to_8bit(input, per_channel=True)
        assert torch.allclose(input, original_input)
Example #13
    def test_known_long_input_per_channel(self):
        input = torch.tensor([[0, 1000], [1001, 2000]]).unsqueeze(0).long()
        expected = torch.tensor([[0, 128], [128, 255]]).unsqueeze(0).byte()
        result = to_8bit(input, per_channel=True)
        assert result.shape == input.shape
        assert torch.allclose(result, expected)
Example #14
    def test_known_float_input_per_channel(self):
        input = torch.tensor([[-1.0, 0.0], [0.0, 1.0]]).unsqueeze(0)
        expected = torch.tensor([[0, 128], [128, 255]]).unsqueeze(0).byte()
        result = to_8bit(input, per_channel=True)
        assert result.shape == input.shape
        assert torch.allclose(result, expected)
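Taken together, the tests above pin down the expected behavior of to_8bit: min-max scaling to [0, 255] followed by rounding, with per_channel controlling whether each channel is normalized independently. Below is a minimal sketch consistent with those expectations; the real to_8bit also supports same_on_batch and passes byte inputs through unchanged, which this sketch omits:

import torch
from torch import Tensor

def to_8bit_sketch(x: Tensor, per_channel: bool = True) -> Tensor:
    # min-max scale to [0, 255] and round; per_channel normalizes each
    # leading-dim slice over its remaining dims independently
    x = x.float()
    if per_channel:
        flat = x.reshape(x.shape[0], -1)
        lo = flat.min(dim=-1, keepdim=True).values
        hi = flat.max(dim=-1, keepdim=True).values
        out = (flat - lo) / (hi - lo).clamp_min(1e-8) * 255
        return out.round().reshape(x.shape).to(torch.uint8)
    lo, hi = x.min(), x.max()
    return ((x - lo) / (hi - lo).clamp_min(1e-8) * 255).round().to(torch.uint8)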