def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise a transform class on tensor, PIL, scripted and batched inputs.

    Args:
        transform_cls: transform class, instantiated with ``meth_kwargs``.
        device: device forwarded to the data-creation helpers.
        channels: channel count of the generated test data.
        meth_kwargs: constructor keyword arguments; a falsy value means none.
        test_exact_match: when true the tensor and PIL outputs are compared
            exactly, otherwise approximately (tensor output cast to float).
        **match_kwargs: forwarded to the tensor-vs-PIL comparison helper.
    """
    transform = transform_cls(**(meth_kwargs or {}))
    scripted_transform = torch.jit.script(transform)

    tensor, pil_img = _create_data(26, 34, channels, device=device)

    # Re-seed before each call so the tensor and PIL paths see the same RNG state.
    torch.manual_seed(12)
    out_tensor = transform(tensor)
    torch.manual_seed(12)
    out_pil = transform(pil_img)

    if not test_exact_match:
        _assert_approx_equal_tensor_to_pil(out_tensor.float(), out_pil, **match_kwargs)
    else:
        _assert_equal_tensor_to_pil(out_tensor, out_pil, **match_kwargs)

    # The scripted transform must reproduce the eager tensor result under the same seed.
    torch.manual_seed(12)
    out_scripted = scripted_transform(tensor)
    assert_equal(out_tensor, out_scripted)

    batch = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(transform, scripted_transform, batch)

    # Finally make sure the scripted module can be saved to disk.
    with get_tmp_dir() as tmp_dir:
        target = os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt")
        scripted_transform.save(target)
def test_rgb2hsv(device):
    """Compare ``F_t._rgb2hsv`` with ``colorsys.rgb_to_hsv``, its scripted form, and on a batch."""
    scripted_fn = torch.jit.script(F_t._rgb2hsv)
    shape = (3, 150, 100)

    for _ in range(10):
        rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
        hsv_img = F_t._rgb2hsv(rgb_img)
        # (3, H, W) -> (H * W, 3): one row per pixel
        hsv_pixels = hsv_img.permute(1, 2, 0).flatten(0, 1)

        # Build the per-pixel reference with the stdlib implementation.
        red, green, blue = (plane.flatten().cpu().numpy() for plane in rgb_img.unbind(dim=-3))
        reference = torch.tensor(
            [colorsys.rgb_to_hsv(r1, g1, b1) for r1, g1, b1 in zip(red, green, blue)],
            dtype=torch.float32,
            device=device,
        )

        actual_h, actual_sv = torch.split(hsv_pixels, [1, 2], dim=1)
        expected_h, expected_sv = torch.split(reference, [1, 2], dim=1)

        # Hue is compared through sin(2*pi*h) rather than directly.
        hue_error = ((expected_h * 2 * math.pi).sin() - (actual_h * 2 * math.pi).sin()).abs().max()
        sv_error = (expected_sv - actual_sv).abs().max()
        assert max(hue_error, sv_error) < 1e-5

        scripted_hsv = scripted_fn(rgb_img)
        torch.testing.assert_close(hsv_img, scripted_hsv, rtol=1e-5, atol=1e-7)

    batch = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch, F_t._rgb2hsv)
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
    """Cross-check tensor, PIL and scripted variants of one functional op.

    Args:
        fn: functional that is scripted and later run on a batch.
        fn_pil: variant applied to the PIL image.
        fn_t: variant applied to the tensor.
        config: keyword arguments passed to every variant.
        device: device used to create the test data.
        dtype: optional dtype the tensor inputs are converted to first.
        tol: tolerance for the tensor-vs-PIL comparison.
        agg_method: aggregation mode forwarded to the comparison helper.
    """
    script_fn = torch.jit.script(fn)
    torch.manual_seed(15)
    tensor, pil_img = _create_data(26, 34, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dtype is not None:
        tensor, batch_tensors = (F.convert_image_dtype(t, dtype) for t in (tensor, batch_tensors))

    out_fn_t = fn_t(tensor, **config)
    out_pil = fn_pil(pil_img, **config)
    out_scripted = script_fn(tensor, **config)
    assert out_fn_t.dtype == out_scripted.dtype
    # The PIL size tuple is compared reversed against the tensor's trailing dims.
    assert out_fn_t.size()[1:] == out_pil.size[::-1]

    is_uint8 = out_fn_t.dtype == torch.uint8
    out_uint8 = out_fn_t if is_uint8 else F.convert_image_dtype(out_fn_t, torch.uint8)

    # Check that max difference does not exceed 2 in [0, 255] range.
    # Exact matching is not possible because convert_image_dtype and PIL results are incompatible.
    _assert_approx_equal_tensor_to_pil(out_uint8.float(), out_pil, tol=tol, agg_method=agg_method)

    # Looser eager-vs-scripted tolerance for uint8 outputs on CUDA.
    atol = 1.0 if is_uint8 and "cuda" in torch.device(device).type else 1e-6
    assert out_fn_t.allclose(out_scripted, atol=atol)

    # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
    _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
def test_pad(device, dt, pad, config):
    """Check ``pad`` on tensors against PIL, the scripted function and batched input.

    Args:
        device: device used to create the test data.
        dt: optional dtype the tensor inputs are cast to (``None`` keeps them as created).
        pad: padding specification; an ``int`` is wrapped in a list for the scripted call.
        config: extra keyword arguments forwarded to every ``pad`` variant.
    """
    # Skip float16 on CPU before scripting or allocating anything, so the
    # skipped case does no wasted work.
    if dt == torch.float16 and device == "cpu":
        return

    script_fn = torch.jit.script(F.pad)
    tensor, pil_img = _create_data(7, 8, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)
        batch_tensors = batch_tensors.to(dt)

    pad_tensor = F_t.pad(tensor, pad, **config)
    pad_pil_img = F_pil.pad(pil_img, pad, **config)

    # The PIL comparison is done on uint8 data, so cast back when needed.
    pad_tensor_8b = pad_tensor if pad_tensor.dtype == torch.uint8 else pad_tensor.to(torch.uint8)

    msg = f"{pad}, {config}"
    _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg=msg)

    # The scripted function is called with a list, so wrap a bare int.
    script_pad = [pad] if isinstance(pad, int) else pad
    pad_tensor_script = script_fn(tensor, script_pad, **config)
    assert_equal(pad_tensor, pad_tensor_script, msg=msg)

    _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
def test_ten_crop(device):
    """Check ``F.ten_crop`` against PIL, its scripted version, and per-sample results on a batch."""
    script_ten_crop = torch.jit.script(F.ten_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)
    cropped_pil_images = F.ten_crop(pil_img, [10, 11])

    # Eager first, then scripted: each must match the PIL crops one by one.
    for crop_fn in (F.ten_crop, script_ten_crop):
        cropped_tensors = crop_fn(img_tensor, [10, 11])
        for idx in range(10):
            _assert_equal_tensor_to_pil(cropped_tensors[idx], cropped_pil_images[idx])

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    batched_crops = F.ten_crop(batch_tensors, [10, 11])

    # Cropping the batch must agree with cropping each sample on its own.
    for sample_idx in range(len(batch_tensors)):
        sample = batch_tensors[sample_idx, ...]
        sample_crops = F.ten_crop(sample, [10, 11])
        assert len(sample_crops) == len(batched_crops)

        for crop_idx, expected_crop in enumerate(sample_crops):
            assert_equal(expected_crop, batched_crops[crop_idx][sample_idx, ...])

    # scriptable function test
    scripted_batched_crops = script_ten_crop(batch_tensors, [10, 11])
    for eager_crop, scripted_crop in zip(batched_crops, scripted_batched_crops):
        assert_equal(eager_crop, scripted_crop)
def test_x_crop(fn, method, out_length, size, device):
    """Check a list-returning crop functional and its class counterpart.

    Args:
        fn: functional applied to both tensor and PIL inputs.
        method: transform class built with the same ``size`` keyword.
        out_length: expected number of returned crops.
        size: crop size passed as the ``size`` keyword.
        device: device used to create the test data.
    """
    fn_kwargs = {"size": size}
    meth_kwargs = fn_kwargs
    scripted_fn = torch.jit.script(fn)

    tensor, pil_img = _create_data(height=20, width=20, device=device)
    tensor_crops = fn(tensor, **fn_kwargs)
    pil_crops = fn(pil_img, **fn_kwargs)

    assert len(tensor_crops) == len(pil_crops)
    assert len(tensor_crops) == out_length
    for tensor_crop, pil_crop in zip(tensor_crops, pil_crops):
        _assert_equal_tensor_to_pil(tensor_crop, pil_crop)

    # The scripted functional is fed a detached copy of the input.
    scripted_crops = scripted_fn(tensor.detach().clone(), **fn_kwargs)
    assert len(tensor_crops) == len(scripted_crops)
    assert len(scripted_crops) == out_length
    for tensor_crop, scripted_crop in zip(tensor_crops, scripted_crops):
        assert_equal(tensor_crop, scripted_crop)

    # test for class interface
    transform = method(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)
    output = scripted_transform(tensor)
    assert len(output) == len(scripted_crops)

    # test on batch of tensors
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    batched_crops = transform(batch_tensors)

    for sample_idx in range(len(batch_tensors)):
        sample = batch_tensors[sample_idx, ...]
        # Same seed as for the batch call above.
        torch.manual_seed(12)
        sample_crops = transform(sample)

        for sample_crop, batched_crop in zip(sample_crops, batched_crops):
            assert_equal(sample_crop, batched_crop[sample_idx, ...])
def test_batches(self, device, dt):
    """Run ``F.affine`` through the batch-consistency helper for the given device and dtype."""
    skip_case = dt == torch.float16 and device == "cpu"
    if skip_case:
        # skip float16 on CPU case
        return

    batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
    if dt is not None:
        batch_tensors = batch_tensors.to(dtype=dt)

    affine_kwargs = {
        "angle": -43,
        "translate": [-3, 4],
        "scale": 1.2,
        "shear": [4.0, 5.0],
    }
    _test_fn_on_batch(batch_tensors, F.affine, **affine_kwargs)
def test_rotate_batch(self, device, dt):
    """Run ``F.rotate`` (NEAREST, expand, explicit center) through the batch-consistency helper."""
    skip_case = dt == torch.float16 and device == "cpu"
    if skip_case:
        # skip float16 on CPU case
        return

    batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
    if dt is not None:
        batch_tensors = batch_tensors.to(dtype=dt)

    rotate_kwargs = {
        "angle": 32,
        "interpolation": NEAREST,
        "expand": True,
        "center": (20, 22),
    }
    _test_fn_on_batch(batch_tensors, F.rotate, **rotate_kwargs)
def test_hflip(device):
    """Check ``F.hflip`` on a tensor against PIL, the scripted function and batched input."""
    script_hflip = torch.jit.script(F.hflip)
    img_tensor, pil_img = _create_data(16, 18, device=device)

    flipped_tensor = F.hflip(img_tensor)
    flipped_pil = F.hflip(pil_img)
    _assert_equal_tensor_to_pil(flipped_tensor, flipped_pil)

    # scriptable function test
    flipped_scripted = script_hflip(img_tensor)
    assert_equal(flipped_tensor, flipped_scripted)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.hflip)
def test_resize(device, dt, size, max_size, interpolation):
    """Check ``F.resize`` on tensors against PIL, the scripted function and batched input.

    Args:
        device: device used to create the test data.
        dt: optional dtype the tensor inputs are cast to.
        size: requested output size (int or sequence).
        max_size: optional maximum size forwarded to ``F.resize``.
        interpolation: interpolation mode forwarded to ``F.resize``.
    """
    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
        # unsupported combination
        return

    torch.manual_seed(12)
    script_fn = torch.jit.script(F.resize)
    tensor, pil_img = _create_data(26, 36, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)
        batch_tensors = batch_tensors.to(dt)

    resize_kwargs = {"interpolation": interpolation, "max_size": max_size}
    resized_tensor = F.resize(tensor, size=size, **resize_kwargs)
    resized_pil_img = F.resize(pil_img, size=size, **resize_kwargs)

    # The PIL size tuple is compared reversed against the tensor's trailing dims.
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    if interpolation not in (NEAREST,):
        # We can not check values if mode = NEAREST, as results are different
        # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
        # uint8 results are promoted to float before the approximate comparison.
        if resized_tensor.dtype == torch.uint8:
            resized_tensor_f = resized_tensor.to(torch.float)
        else:
            resized_tensor_f = resized_tensor

        # Pay attention to high tolerance for MAE
        _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)

    # The scripted function is called with a list, so wrap a bare int.
    script_size = [size] if isinstance(size, int) else size

    resize_result = script_fn(tensor, size=script_size, **resize_kwargs)
    assert_equal(resized_tensor, resize_result)

    _test_fn_on_batch(batch_tensors, F.resize, size=script_size, **resize_kwargs)
def test_crop(device, top, left, height, width):
    """Check ``F.crop`` (eager and scripted) against PIL, then on batched input."""
    script_crop = torch.jit.script(F.crop)
    img_tensor, pil_img = _create_data(16, 18, device=device)

    box = (top, left, height, width)
    expected_pil = F.crop(pil_img, *box)

    # Eager first, then scripted: both must match the PIL crop.
    for crop_fn in (F.crop, script_crop):
        cropped = crop_fn(img_tensor, *box)
        _assert_equal_tensor_to_pil(cropped, expected_pil)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.crop, top=top, left=left, height=height, width=width)
def test_rgb_to_grayscale(device, num_output_channels):
    """Check ``F.rgb_to_grayscale`` against PIL (within tolerance), the scripted function and batches."""
    script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
    img_tensor, pil_img = _create_data(32, 34, device=device)
    kwargs = {"num_output_channels": num_output_channels}

    gray_pil = F.rgb_to_grayscale(pil_img, **kwargs)
    gray_tensor = F.rgb_to_grayscale(img_tensor, **kwargs)

    # Tensor and PIL outputs are compared approximately, not exactly.
    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil, tol=1.0 + 1e-10, agg_method="max")

    gray_scripted = script_rgb_to_grayscale(img_tensor, **kwargs)
    assert_equal(gray_scripted, gray_tensor)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.rgb_to_grayscale, **kwargs)
def test_center_crop(device):
    """Check ``F.center_crop`` (eager and scripted) against PIL, then on batched input."""
    script_center_crop = torch.jit.script(F.center_crop)
    img_tensor, pil_img = _create_data(32, 34, device=device)

    expected_pil = F.center_crop(pil_img, [10, 11])

    # Eager first, then scripted: both must match the PIL crop.
    for crop_fn in (F.center_crop, script_center_crop):
        cropped = crop_fn(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped, expected_pil)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.center_crop, output_size=[10, 11])
def test_perspective_batch(device, dims_and_points, dt):
    """Run ``F.perspective`` (NEAREST) through the batch-consistency helper.

    Args:
        device: device string or ``torch.device`` used to create the test data.
        dims_and_points: pair ``(data_dims, (startpoints, endpoints))``;
            ``data_dims`` is unpacked into ``_create_data_batch``.
        dt: optional dtype the batch is cast to.
    """
    # Normalise the device so that indexed names such as "cuda:0" and
    # torch.device objects are handled like plain "cuda" / "cpu".
    device_type = torch.device(device).type

    if dt == torch.float16 and device_type == "cpu":
        # skip float16 on CPU case
        return

    data_dims, (spoints, epoints) = dims_and_points

    batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
    if dt is not None:
        batch_tensors = batch_tensors.to(dtype=dt)

    # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
    # the border may be entirely different due to small rounding errors.
    scripted_fn_atol = -1 if (dt == torch.float16 and device_type == "cuda") else 1e-8
    _test_fn_on_batch(
        batch_tensors,
        F.perspective,
        scripted_fn_atol=scripted_fn_atol,
        startpoints=spoints,
        endpoints=epoints,
        interpolation=NEAREST,
    )
def test_resized_crop(device, mode):
    """Check values produced by ``F.resized_crop`` in two known cases, then on a batch.

    Args:
        device: device used to create the test data.
        mode: interpolation mode used for the identity case.
    """
    # Case 1: crop the full image and resize to the same size => should be identity.
    tensor, _ = _create_data(26, 36, device=device)
    identity_out = F.resized_crop(
        tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode
    )
    assert_equal(
        tensor,
        identity_out,
        msg="{} vs {}".format(identity_out[0, :5, :5], tensor[0, :5, :5]),
    )

    # Case 2: crop the top-left corner and resize it by half with NEAREST,
    # expected to equal taking every second pixel of that corner.
    tensor, _ = _create_data(26, 36, device=device)
    halved_out = F.resized_crop(
        tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST
    )
    expected = tensor[:, :20:2, :30:2]
    assert_equal(
        expected,
        halved_out,
        msg="{} vs {}".format(expected[0, :10, :10], halved_out[0, :10, :10]),
    )

    batch = _create_data_batch(26, 36, num_samples=4, device=device)
    batch_kwargs = {
        "top": 1,
        "left": 2,
        "height": 20,
        "width": 30,
        "size": [10, 15],
        "interpolation": NEAREST,
    }
    _test_fn_on_batch(batch, F.resized_crop, **batch_kwargs)
def test_hsv2rgb(device):
    """Compare ``F_t._hsv2rgb`` with ``colorsys.hsv_to_rgb``, its scripted form, and on a batch."""
    scripted_fn = torch.jit.script(F_t._hsv2rgb)
    shape = (3, 100, 150)

    for _ in range(10):
        hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
        rgb_img = F_t._hsv2rgb(hsv_img)
        # (3, H, W) -> (H * W, 3): one row per pixel
        rgb_pixels = rgb_img.permute(1, 2, 0).flatten(0, 1)

        # Build the per-pixel reference with the stdlib implementation.
        hue, sat, val = (plane.flatten().cpu().numpy() for plane in hsv_img.unbind(0))
        reference = torch.tensor(
            [colorsys.hsv_to_rgb(h1, s1, v1) for h1, s1, v1 in zip(hue, sat, val)],
            dtype=torch.float32,
            device=device,
        )
        torch.testing.assert_close(rgb_pixels, reference, rtol=0.0, atol=1e-5)

        scripted_rgb = scripted_fn(hsv_img)
        torch.testing.assert_close(rgb_img, scripted_rgb)

    batch = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch, F_t._hsv2rgb)
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
    """Check a list-returning functional and its transform-class counterpart.

    Args:
        func: attribute name of the functional, looked up on ``F``.
        method: attribute name of the transform class, looked up on ``T``.
        out_length: expected number of returned outputs.
        fn_kwargs: keyword arguments for the functional (``None`` means none).
        meth_kwargs: keyword arguments for the class constructor (``None`` means none).
    """
    fn_kwargs = {} if fn_kwargs is None else fn_kwargs
    meth_kwargs = {} if meth_kwargs is None else meth_kwargs

    functional = getattr(F, func)
    scripted_functional = torch.jit.script(functional)

    tensor, pil_img = _create_data(height=20, width=20, device=self.device)
    tensor_outputs = functional(tensor, **fn_kwargs)
    pil_outputs = functional(pil_img, **fn_kwargs)

    self.assertEqual(len(tensor_outputs), len(pil_outputs))
    self.assertEqual(len(tensor_outputs), out_length)
    for tensor_out, pil_out in zip(tensor_outputs, pil_outputs):
        _assert_equal_tensor_to_pil(tensor_out, pil_out)

    # The scripted functional is fed a detached copy of the input.
    scripted_outputs = scripted_functional(tensor.detach().clone(), **fn_kwargs)
    self.assertEqual(len(tensor_outputs), len(scripted_outputs))
    self.assertEqual(len(scripted_outputs), out_length)
    for tensor_out, scripted_out in zip(tensor_outputs, scripted_outputs):
        assert_equal(
            tensor_out,
            scripted_out,
            msg="{} vs {}".format(tensor_out, scripted_out),
        )

    # test for class interface
    transform = getattr(T, method)(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)
    output = scripted_transform(tensor)
    self.assertEqual(len(output), len(scripted_outputs))

    # test on batch of tensors
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
    torch.manual_seed(12)
    batched_outputs = transform(batch_tensors)

    for sample_idx in range(len(batch_tensors)):
        sample = batch_tensors[sample_idx, ...]
        # Same seed as for the batch call above.
        torch.manual_seed(12)
        sample_outputs = transform(sample)

        for sample_out, batched_out in zip(sample_outputs, batched_outputs):
            assert_equal(
                sample_out,
                batched_out[sample_idx, ...],
                msg="{} vs {}".format(sample_out, batched_out[sample_idx, ...]),
            )

    # Finally make sure the scripted transform can be saved to disk.
    with get_tmp_dir() as tmp_dir:
        target = os.path.join(tmp_dir, "t_op_list_{}.pt".format(method))
        scripted_transform.save(target)