def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise a transform class on tensor, PIL and TorchScript inputs.

    The transform is built as ``transform_cls(**meth_kwargs)``, scripted, and
    applied to the tensor / PIL image pair produced by ``_create_data``. The
    tensor output is compared with the PIL output (exactly, or approximately
    when ``test_exact_match`` is false), then with the scripted output, and
    finally eager and scripted transforms are compared on a batch. The
    scripted module is also saved to a temporary directory.

    Args:
        transform_cls: Transform class under test.
        device: Device forwarded to the data-creation helpers.
        channels: Number of channels forwarded to the data-creation helpers.
        meth_kwargs: Constructor kwargs for ``transform_cls``; ``None`` or an
            empty mapping means no kwargs.
        test_exact_match: Selects the exact or the approximate tensor-vs-PIL
            comparison helper.
        **match_kwargs: Extra kwargs forwarded to the comparison helper.
    """
    if not meth_kwargs:
        meth_kwargs = {}

    # Class interface: eager instance plus its scripted counterpart.
    transform = transform_cls(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)

    tensor, pil_img = _create_data(26, 34, channels, device=device)

    def _run_seeded(op, inpt):
        # Re-seed before every call so that a random transform makes the same
        # choices for the tensor, the PIL image and the scripted run.
        torch.manual_seed(12)
        return op(inpt)

    out_tensor = _run_seeded(transform, tensor)
    out_pil = _run_seeded(transform, pil_img)

    if test_exact_match:
        _assert_equal_tensor_to_pil(out_tensor, out_pil, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(out_tensor.float(), out_pil, **match_kwargs)

    out_scripted = _run_seeded(scripted_transform, tensor)
    assert_equal(out_tensor, out_scripted)

    batch = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(transform, scripted_transform, batch)

    # Saving the scripted transform must succeed as well.
    with get_tmp_dir() as tmp_dir:
        scripted_transform.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
    """Cross-check a functional op between its tensor, PIL and scripted paths.

    Args:
        fn: Public functional; it is scripted here and also passed to the
            batch check.
        fn_pil: Implementation called on the PIL image.
        fn_t: Implementation called on the tensor.
        config: Keyword arguments forwarded to every call.
        device: Device forwarded to the data-creation helpers.
        dtype: If not ``None``, inputs are converted to this dtype with
            ``F.convert_image_dtype`` before the op is applied.
        tol: Tolerance for the tensor-vs-PIL comparison.
        agg_method: Aggregation method for the tensor-vs-PIL comparison.
    """
    scripted = torch.jit.script(fn)
    torch.manual_seed(15)

    tensor, pil_img = _create_data(26, 34, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dtype is not None:
        tensor = F.convert_image_dtype(tensor, dtype)
        batch_tensors = F.convert_image_dtype(batch_tensors, dtype)

    out_fn_t = fn_t(tensor, **config)
    out_pil = fn_pil(pil_img, **config)
    out_scripted = scripted(tensor, **config)

    assert out_fn_t.dtype == out_scripted.dtype
    # PIL reports its size in the reverse order of the tensor's trailing dims.
    assert out_fn_t.shape[1:] == out_pil.size[::-1]

    # Bring the tensor output to uint8 before comparing it with the PIL output.
    if out_fn_t.dtype == torch.uint8:
        rgb_tensor = out_fn_t
    else:
        rgb_tensor = F.convert_image_dtype(out_fn_t, torch.uint8)

    # With the default tol, the max difference must not exceed 2 in the
    # [0, 255] range. Exact matching is not possible because
    # convert_image_dtype and PIL results are not fully compatible.
    _assert_approx_equal_tensor_to_pil(rgb_tensor.float(), out_pil, tol=tol, agg_method=agg_method)

    # uint8 outputs on a CUDA device get a looser eager-vs-scripted tolerance.
    uint8_on_cuda = out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type
    atol = 1.0 if uint8_on_cuda else 1e-6
    assert out_fn_t.allclose(out_scripted, atol=atol)

    # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
    _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Apply a functional op to a tensor and a PIL image and compare results.

    Args:
        f: Functional under test; called as ``f(input, **fn_kwargs)`` for both
            input kinds.
        device: Device forwarded to ``_create_data``.
        channels: Number of channels forwarded to ``_create_data``.
        fn_kwargs: Keyword arguments for ``f``; ``None`` or an empty mapping
            means no kwargs.
        test_exact_match: Selects the exact or the approximate comparison
            helper.
        **match_kwargs: Extra kwargs forwarded to the comparison helper.
    """
    if not fn_kwargs:
        fn_kwargs = {}

    tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
    out_tensor = f(tensor, **fn_kwargs)
    out_pil = f(pil_img, **fn_kwargs)

    # Pick the comparison helper once, then run it.
    compare = _assert_equal_tensor_to_pil if test_exact_match else _assert_approx_equal_tensor_to_pil
    compare(out_tensor, out_pil, **match_kwargs)
def test_resize(device, dt, size, max_size, interpolation):
    """Compare ``F.resize`` across tensor, PIL, scripted and batched inputs.

    Returns early (without asserting anything) for float16 on CPU and for the
    unsupported combination of ``max_size`` with a ``size`` sequence whose
    length is not 1.
    """
    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    size_is_multi_element = isinstance(size, Sequence) and len(size) != 1
    if max_size is not None and size_is_multi_element:
        return  # unsupported

    torch.manual_seed(12)
    scripted_resize = torch.jit.script(F.resize)
    tensor, pil_img = _create_data(26, 36, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dt is not None:
        # This is a trivial cast of the created data to dt to test all cases
        tensor = tensor.to(dt)
        batch_tensors = batch_tensors.to(dt)

    # Options shared by every resize call below.
    resize_opts = dict(interpolation=interpolation, max_size=max_size)

    resized_tensor = F.resize(tensor, size=size, **resize_opts)
    resized_pil_img = F.resize(pil_img, size=size, **resize_opts)

    # PIL reports its size in the reverse order of the tensor's trailing dims.
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    if interpolation != NEAREST:
        # We can not check values if mode = NEAREST, as results are different
        # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
        # uint8 results are cast to float for the comparison with the PIL image
        if resized_tensor.dtype == torch.uint8:
            comparable = resized_tensor.to(torch.float)
        else:
            comparable = resized_tensor
        # Pay attention to high tolerance for MAE
        _assert_approx_equal_tensor_to_pil(comparable, resized_pil_img, tol=8.0)

    # The scripted function is called with a list when size is a plain int.
    script_size = [size] if isinstance(size, int) else size

    scripted_out = scripted_resize(tensor, size=script_size, **resize_opts)
    assert_equal(resized_tensor, scripted_out)

    _test_fn_on_batch(batch_tensors, F.resize, size=script_size, **resize_opts)
def test_autoaugment__op_apply_shear(interpolation, mode):
    """Check the autoaugment shear op against a reference PIL affine transform.

    We check that torchvision's implementation of shear is equivalent
    to official CIFAR10 autoaugment implementation:
    https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290

    Args:
        interpolation: ``F.InterpolationMode.NEAREST`` or
            ``F.InterpolationMode.BILINEAR``; any other value raises
            ``KeyError`` when mapped to a PIL resampling filter.
        mode: Shear axis, ``"X"`` or ``"Y"``.
    """
    image_size = 32

    def shear(pil_img, level, mode, resample):
        # Reference implementation: a plain PIL affine transform.
        if mode == "X":
            matrix = (1, level, 0, 0, 1, 0)
        elif mode == "Y":
            matrix = (1, 0, 0, level, 1, 0)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # `matrix` when the test is parametrized with an unknown axis.
            raise ValueError(f"Unsupported shear mode: {mode!r} (expected 'X' or 'Y')")
        return pil_img.transform((image_size, image_size), Image.AFFINE, matrix, resample=resample)

    t_img, pil_img = _create_data(image_size, image_size)

    # Map the torchvision interpolation mode to the matching PIL filter.
    resample_pil = {
        F.InterpolationMode.NEAREST: Image.NEAREST,
        F.InterpolationMode.BILINEAR: Image.BILINEAR,
    }[interpolation]

    level = 0.3
    expected_out = shear(pil_img, level, mode=mode, resample=resample_pil)

    # Check pil output vs expected pil
    out = _apply_op(pil_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
    assert out == expected_out

    if interpolation == F.InterpolationMode.BILINEAR:
        # We skip bilinear mode for tensors as
        # affine transformation results are not exactly the same
        # between tensors and pil images
        # MAE is around 1.40
        # Max Abs error can be 163 or 170
        return

    # Check tensor output vs expected pil
    out = _apply_op(t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
    _assert_approx_equal_tensor_to_pil(out, expected_out)
def test_rgb_to_grayscale(device, num_output_channels):
    """Compare ``F.rgb_to_grayscale`` across tensor, PIL, scripted and batched inputs."""
    scripted_to_gray = torch.jit.script(F.rgb_to_grayscale)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    # Keyword arguments shared by every call below.
    op_kwargs = dict(num_output_channels=num_output_channels)

    gray_pil = F.rgb_to_grayscale(pil_img, **op_kwargs)
    gray_tensor = F.rgb_to_grayscale(img_tensor, **op_kwargs)

    # Tensor and PIL results are compared on the max aggregate with tol just above 1.
    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil, tol=1.0 + 1e-10, agg_method="max")

    gray_scripted = scripted_to_gray(img_tensor, **op_kwargs)
    assert_equal(gray_scripted, gray_tensor)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.rgb_to_grayscale, **op_kwargs)
def test_resize_antialias(device, dt, size, interpolation):
    """Compare antialiased tensor ``F.resize`` with PIL resize and the scripted op.

    Returns early (without asserting anything) for float16 on CPU.
    """
    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    torch.manual_seed(12)
    scripted_resize = torch.jit.script(F.resize)
    tensor, pil_img = _create_data(320, 290, device=device)

    if dt is not None:
        # This is a trivial cast of the created data to dt to test all cases
        tensor = tensor.to(dt)

    # Only the tensor call passes antialias=True; the PIL call does not take it here.
    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)

    # PIL reports its size in the reverse order of the tensor's trailing dims.
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    # uint8 results are cast to float for the comparison with the PIL image
    if resized_tensor.dtype == torch.uint8:
        comparable = resized_tensor.to(torch.float)
    else:
        comparable = resized_tensor

    case_msg = f"{size}, {interpolation}, {dt}"

    # First check: default aggregation with a tight tolerance.
    _assert_approx_equal_tensor_to_pil(comparable, resized_pil_img, tol=0.5, msg=case_msg)

    # Second check: max aggregation. For BICUBIC a high tolerance is needed to
    # make the tests pass, mostly for cases with downsampling and upsampling
    # where we can not exactly match the PIL implementation.
    max_tol = 15.0 if interpolation == BICUBIC else 1.0 + 1e-5
    _assert_approx_equal_tensor_to_pil(
        comparable, resized_pil_img, tol=max_tol, agg_method="max", msg=case_msg
    )

    # The scripted function is called with a list when size is a plain int.
    script_size = [size] if isinstance(size, int) else size

    scripted_out = scripted_resize(tensor, size=script_size, interpolation=interpolation, antialias=True)
    assert_equal(resized_tensor, scripted_out)