def test_batched_input(self, same_on_batch, expected):
    input = torch.tensor([
        [[[-2.0, 0.0], [0.0, 1.0]]],
        [[[-1.0, 0.5], [0.0, 1.0]]],
    ])
    result = to_8bit(input, per_channel=True, same_on_batch=same_on_batch)
    assert torch.allclose(result, expected)
def test_random_input(self, input, per_channel, shape, dtype):
    result = to_8bit(input, per_channel)
    assert result.shape == shape
    assert result.min() >= 0
    assert result.max() <= 255
    assert result.unique().numel() >= 10
    assert result.dtype == torch.uint8
    if dtype == "byte":
        assert torch.allclose(input, result)
def _process_image(
    self,
    img: Tensor,
    colormap: Optional[Union[Colormap, ColormapSequence]],
    split_channels: Optional[ChannelSplits],
    name: Union[str, List[str]],
    max_resolution: Optional[Tuple[int, int]],
    as_uint8: bool = True,
) -> Dict[str, Tensor]:
    img = prepare_image(img)

    # split channels
    images: Dict[str, Tensor]
    if split_channels is not None:
        _images = torch.split(img, split_channels, dim=1)
        if isinstance(name, str):
            _name = [f"{name}_{i}" for i in range(len(_images))]
        else:
            _name = name
        assert len(_name) == len(_images)
        images = {n: x for n, x in zip(_name, _images)}
    else:
        assert isinstance(name, str)
        images = {name: img}

    # apply resolution limit
    if max_resolution is not None:
        images = {
            k: self._apply_log_resolution(v, max_resolution, self.resize_mode)
            for k, v in images.items()
        }

    # apply colormap
    colormap_seq: Optional[ColormapSequence] = None
    if isinstance(colormap, str):
        colormap_seq = [colormap] * len(images)  # type: ignore
    elif colormap is not None:
        colormap_seq = colormap
    if colormap_seq is not None:
        assert len(colormap_seq) == len(images)
        images = {
            k: self.apply_colormap(v, cmap) if cmap is not None else v
            for cmap, (k, v) in zip(colormap_seq, images.items())
        }

    # convert to byte
    if as_uint8:
        images = {
            k: to_8bit(v, same_on_batch=not self.per_img_norm)
            for k, v in images.items()
        }

    return images
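# A minimal usage sketch of the channel-splitting step above. `ChannelSplits`
# is assumed to be the same sizes sequence accepted by torch.split; the derived
# names follow the f"{name}_{i}" convention used in _process_image.
import torch

img = torch.rand(2, 3, 8, 8)       # (B, C, H, W)
split_channels = (2, 1)            # split C=3 into a 2-channel and a 1-channel group
chunks = torch.split(img, split_channels, dim=1)
names = [f"image_{i}" for i in range(len(chunks))]
images = dict(zip(names, chunks))  # {"image_0": (2, 2, 8, 8), "image_1": (2, 1, 8, 8)}
assert [v.shape[1] for v in images.values()] == [2, 1]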
def test_known_input_per_channel_result(self, per_channel):
    input = torch.tensor([
        [[0, 1000], [1001, 2000]],
        [[0, 500], [500, 1000]],
    ]).long()
    if per_channel:
        expected = torch.tensor([[[0, 128], [128, 255]], [[0, 128], [128, 255]]]).byte()
    else:
        expected = torch.tensor([[[0, 128], [128, 255]], [[0, 64], [64, 128]]]).byte()
    result = to_8bit(input, per_channel=per_channel)
    assert result.shape == input.shape
    assert torch.allclose(result, expected)
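# The known-input tests pin down the normalization: min-max rescale to
# [0, 255], rounding, then a uint8 cast. This is a hedged sketch of an
# equivalent implementation, not the library's actual to_8bit; `per_channel`
# reduces over spatial dims only, and `same_on_batch` shares the min/max
# across the batch dimension.
import torch
from torch import Tensor

def to_8bit_sketch(x: Tensor, per_channel: bool = True, same_on_batch: bool = False) -> Tensor:
    x = x.float()
    # choose which dims the min/max are computed over
    if per_channel:
        dims = [-2, -1]        # per channel: reduce H, W
    else:
        dims = [-3, -2, -1]    # reduce C, H, W together
    if same_on_batch and x.ndim == 4:
        dims = [0] + [x.ndim + d for d in dims]  # also share stats across B
    lo = x.amin(dim=dims, keepdim=True)
    hi = x.amax(dim=dims, keepdim=True)
    # non-inplace sub() leaves the caller's tensor untouched (see test_input_unchanged)
    return x.sub(lo).div_((hi - lo).clamp_min(1e-8)).mul_(255).round_().to(torch.uint8)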
def _process_image(
    self,
    dest: Tensor,
    src: Tensor,
    colormap: Tuple[Optional[Colormap], Optional[Colormap]],
    split_channels: Tuple[Optional[ChannelSplits], Optional[ChannelSplits]],
    name: Union[str, List[str]],
    max_resolution: Optional[Tuple[int, int]],
    as_uint8: bool = True,
) -> Dict[str, Tensor]:
    B, C, H, W = dest.shape
    if any(split_channels):
        name1, name2 = (name if split is not None else "__name__" for split in split_channels)
    else:
        name1, name2 = name, name

    images_dest = super()._process_image(dest, colormap[0], split_channels[0], name1, max_resolution, False)
    images_src = super()._process_image(src, colormap[1], split_channels[1], name2, max_resolution, False)

    if len(images_dest) != len(images_src):
        if len(images_dest) == 1:
            val = next(iter(images_dest.values()))
            images_dest = {k: val for k in images_src.keys()}
        elif len(images_src) == 1:
            val = next(iter(images_src.values()))
            images_src = {k: val for k in images_dest.keys()}
        else:
            dest_shapes = {k: v.shape for k, v in images_dest.items()}
            src_shapes = {k: v.shape for k, v in images_src.items()}
            raise ValueError(
                f"Unable to broadcast processed images:\n"
                f"Destination dict:\n {dest_shapes}\n\n"
                f"Source dict:\n {src_shapes}"
            )

    dest_alpha, src_alpha = self.alpha
    images = {
        k1 if "__name__" not in k1 else k2: VisualizeCallback.alpha_blend(
            d, s, self.resize_mode, dest_alpha=dest_alpha, src_alpha=src_alpha
        )
        for (k1, d), (k2, s) in zip(images_dest.items(), images_src.items())
    }
    assert "__name__" not in images.keys()

    if as_uint8:
        images = {k: to_8bit(v, same_on_batch=not self.per_img_norm) for k, v in images.items()}
    return images
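# A hedged sketch of what the alpha_blend helper used above might do: resize
# the source to the destination resolution, then mix the two images linearly.
# The signature and return value are assumptions inferred from the call site
# (the real helper appears to return a tuple, per the tests below), not the
# library's actual implementation.
import torch.nn.functional as F
from torch import Tensor

def alpha_blend_sketch(
    dest: Tensor,
    src: Tensor,
    resize_mode: str = "bilinear",
    dest_alpha: float = 1.0,
    src_alpha: float = 0.5,
) -> Tensor:
    if src.shape[-2:] != dest.shape[-2:]:
        src = F.interpolate(src.float(), size=dest.shape[-2:], mode=resize_mode)
    return dest.float() * dest_alpha + src.float() * src_alpha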
def _process_image(
    self,
    img: Tensor,
    keypoint_dict: Dict[str, Tensor],
    colormap: Optional[Union[Colormap, ColormapSequence]],
    split_channels: Optional[ChannelSplits],
    name: Union[str, List[str]],
    max_resolution: Optional[Tuple[int, int]],
    as_uint8: bool = True,
) -> Dict[str, Tensor]:
    img = prepare_image(img)
    B, C, H, W = img.shape
    images = super()._process_image(img, colormap, split_channels, name, max_resolution, False)

    # apply resolution limit to keypoints
    if max_resolution is not None:
        H_scaled, W_scaled = next(iter(images.values())).shape[-2:]
        if H_scaled / H != 1.0:
            coords = keypoint_dict["coords"].clone().float()
            padding = coords < 0
            coords.mul_(H_scaled / H)
            coords[padding] = -1
            keypoint_dict["coords"] = coords

    # overlay keypoints, accounting for channel splitting
    if self.split_channels:
        keypoints = [keypoint_dict] * C
        assert len(keypoints) == C
        images = {
            k: self.overlay_keypoints(v, d)
            for d, (k, v) in zip(keypoints, images.items())
        }
    else:
        images = {
            k: self.overlay_keypoints(v, keypoint_dict)
            for k, v in images.items()
        }

    if as_uint8:
        images = {
            k: to_8bit(v, same_on_batch=not self.per_img_norm)
            for k, v in images.items()
        }
    return images
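# A small worked example of the keypoint rescaling step above: coordinates are
# scaled by the resize ratio while -1 padding entries are preserved.
import torch

coords = torch.tensor([[4.0, 8.0], [-1.0, -1.0]])  # second row is padding
H, H_scaled = 32, 16                               # image halved in size
padding = coords < 0
coords = coords * (H_scaled / H)
coords[padding] = -1
assert coords.tolist() == [[2.0, 4.0], [-1.0, -1.0]]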
def _process_image(
    self,
    img: Tensor,
    colormap: Optional[Union[str, List[str]]],
    split_channels: Optional[ChannelSplits],
    name: Union[str, List[str]],
    max_resolution: Optional[Tuple[int, int]],
    as_uint8: bool = True,
) -> Dict[str, Tensor]:
    # split channels
    if split_channels is not None:
        splits = torch.split(img, split_channels, dim=1)
        assert isinstance(name, list)
        assert len(name) == len(splits)
        images = {n: x for n, x in zip(name, splits)}
    else:
        images = {name: img}

    # apply resolution limit
    if max_resolution is not None:
        images = {
            k: self._apply_log_resolution(v, max_resolution, self.resize_mode)
            for k, v in images.items()
        }

    # apply colormap
    if colormap is not None:
        if isinstance(colormap, str):
            colormap = [colormap] * len(images)
        assert len(colormap) == len(images)
        images = {
            k: self.apply_colormap(v, cmap) if cmap is not None else v
            for cmap, (k, v) in zip(colormap, images.items())
        }

    # convert to byte
    if as_uint8:
        images = {k: to_8bit(v, same_on_batch=not self.per_img_norm) for k, v in images.items()}
    return images
class TestVisualizeCallback:
    callback_cls = VisualizeCallback

    # training, validation, or testing mode
    @pytest.fixture(params=["train", "val", "test"])
    def mode(self, request):
        return request.param

    # returns a no-args closure of the callback hook appropriate for `mode`
    @pytest.fixture
    def callback_func(self, mode, trainer, model, callback):
        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")

        def func():
            f = getattr(callback, attr)
            return f(trainer, model)

        return func

    @pytest.fixture
    def data_shape(self):
        return 2, 3, 32, 32

    @pytest.fixture
    def data(self, data_shape):
        data = create_image(*data_shape)
        return data

    @pytest.fixture
    def model(self, request, mocker, trainer, callback, data, mode):
        if hasattr(request, "param"):
            step = request.param.pop("step", 10)
            epoch = request.param.pop("epoch", 1)
        else:
            step = 10
            epoch = 1
        model = mocker.MagicMock(name="module")
        model.current_epoch = epoch
        model.global_step = step
        if callback.attr_name is not None:
            setattr(model, callback.attr_name, data)

        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")
        callback.trigger = lambda: getattr(callback, attr)(trainer, model)
        return model

    @pytest.fixture
    def callback(self, request, trainer):
        cls = self.callback_cls
        init_signature = inspect.signature(cls)
        defaults = {
            k: v.default
            for k, v in init_signature.parameters.items()
            if v.default is not inspect.Parameter.empty
        }
        if hasattr(request, "param"):
            name = request.param.get("name", "image")
            defaults.update(request.param)
        else:
            name = "image"
        defaults["name"] = name
        callback = cls(**defaults)
        return callback

    @pytest.fixture
    def logger_func(self, model):
        return model.logger.experiment.add_images

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []
        B, C, H, W = data.shape
        img = [data]
        name = [f"{mode}/{callback.name}"]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            scale_factor = []
            for i in img:
                H, W = i.shape[-2:]
                height_ratio, width_ratio = H / H_max, W / W_max
                s = 1 / max(height_ratio, width_ratio)
                scale_factor.append(s)
            img = [
                F.interpolate(i, scale_factor=s, mode=resize_mode) if s < 1 else i
                for i, s in zip(img, scale_factor)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                for cmap, i in zip(colormap, img)
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        if callback.as_uint8:
            img = [to_8bit(i, same_on_batch=not callback.per_img_norm) for i in img]

        step = [model.current_epoch if callback.epoch_counter else model.global_step] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
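# The fixtures above rely on a `create_image` helper that is not defined in
# this section. A minimal sketch consistent with how it is called
# (create_image(B, C, H, W) -> float image tensor) might look like:
import torch

def create_image(B: int, C: int, H: int, W: int) -> torch.Tensor:
    torch.random.manual_seed(42)  # deterministic test data
    return torch.rand(B, C, H, W)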
class TestBlendVisualizeCallback(TestVisualizeCallback):
    callback_cls = BlendVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long"),
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        img = create_image(B, C, H, W)
        return img.clone(), img.clone()

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback, mocker):
        if not hasattr(model, callback.attr_name):
            return []
        if callback.split_channels == (2, 1):
            pytest.skip("incompatible test")

        data1, data2 = data
        B, C, H, W = data1.shape
        img1 = [data1]
        img2 = [data2]
        name1 = [f"{mode}/{callback.name}"]
        name2 = [f"{mode}/{callback.name}"]
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting
        for pos in range(2):
            if callback.split_channels[pos]:
                img_new, name_new = [], []
                splits = torch.split(data[pos], callback.split_channels[pos], dim=-3)
                for i, s in enumerate(splits):
                    n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [i.shape[-2] > H_max or i.shape[-1] > W_max for i in img[pos]]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]

        name = name[0]
        final_img = []
        for pos, (d, s) in enumerate(zip(img[0], img[1])):
            B1, C1, H1, W1 = d.shape
            B2, C2, H2, W2 = s.shape
            if C1 != C2:
                if C1 == 1:
                    d = d.repeat(1, C2, 1, 1)
                elif C2 == 1:
                    s = s.repeat(1, C1, 1, 1)
                else:
                    raise ValueError(f"could not match shapes {d.shape}, {s.shape}")
            final_img.append(alpha_blend(d, s, callback.alpha[1], callback.alpha[0])[0])
        img = final_img

        if callback.as_uint8:
            img = [to_8bit(i, same_on_batch=not callback.per_img_norm) for i in img]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [model.current_epoch if callback.epoch_counter else model.global_step] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
class TestKeypointVisualizeCallback(TestVisualizeCallback):
    callback_cls = KeypointVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long"),
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        N = 3
        img = create_image(B, C, H, W)
        bbox = self.create_bbox(B, N)
        bbox = bbox.float() if request.param else bbox.long()
        cls = self.create_classes(B, N)
        score = self.create_scores(B, N)
        target = {"coords": bbox, "class": cls, "score": score}
        return img, target

    def create_bbox(self, B, N):
        torch.random.manual_seed(42)
        bbox = torch.empty(B, N, 4).fill_(-1).float()
        return bbox

    def create_classes(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1).float()

    def create_scores(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1)

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback, mocker):
        if not hasattr(model, callback.attr_name):
            return []
        data, target = data
        B, C, H, W = data.shape
        img = [data]
        name = [f"{mode}/{callback.name}"]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [i.shape[-2] > H_max or i.shape[-1] > W_max for i in img]
            img = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img, needs_resize)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :] if cmap is not None else i
                for cmap, i in zip(colormap, img)
            ]

        img = [i.repeat(1, 3, 1, 1) if i.shape[-3] == 1 else i for i in img]

        if callback.as_uint8:
            img = [to_8bit(i, same_on_batch=not callback.per_img_norm) for i in img]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [model.current_epoch if callback.epoch_counter else model.global_step] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
def check_call(call, name, img, step, as_uint8=True):
    if as_uint8:
        img = to_8bit(img, same_on_batch=True)
    assert call.args[0] == name
    assert torch.allclose(call.args[1], img, atol=1)
    assert call.args[2] == step
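# A hedged usage sketch: `check_call` is presumably run against the recorded
# mock calls of the logger (see the `logger_func` fixture above), paired with
# the tuples produced by `expected_calls`. The pairing is an assumption
# inferred from how the fixtures fit together, not code from the suite.
def assert_logged(logger_func, expected_calls):
    for call, (name, img, step) in zip(logger_func.mock_calls, expected_calls):
        check_call(call, name, img, step)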
def test_input_unchanged(self):
    input = torch.tensor([[0, 1000], [1001, 2000]]).unsqueeze(0).long()
    original_input = input.clone()
    to_8bit(input, per_channel=True)
    assert torch.allclose(input, original_input)
def test_known_long_input_per_channel(self):
    input = torch.tensor([[0, 1000], [1001, 2000]]).unsqueeze(0).long()
    expected = torch.tensor([[0, 128], [128, 255]]).unsqueeze(0).byte()
    result = to_8bit(input, per_channel=True)
    assert result.shape == input.shape
    assert torch.allclose(result, expected)
def test_known_float_input_per_channel(self):
    input = torch.tensor([[-1.0, 0.0], [0.0, 1.0]]).unsqueeze(0)
    expected = torch.tensor([[0, 128], [128, 255]]).unsqueeze(0).byte()
    result = to_8bit(input, per_channel=True)
    assert result.shape == input.shape
    assert torch.allclose(result, expected)