def test_pad(device, dt, pad, config):
    """Check tensor ``pad`` against PIL, its TorchScript version and batched input.

    ``dt`` optionally casts the uint8 sample data to another dtype; float16 on
    CPU returns early without asserting anything. ``config`` is forwarded as
    keyword arguments to every pad call.
    """
    scripted_pad = torch.jit.script(F.pad)
    img_tensor, pil_img = _create_data(7, 8, device=device)
    batch = _create_data_batch(16, 18, num_samples=4, device=device)

    # float16 is not exercised on CPU.
    if device == "cpu" and dt == torch.float16:
        return

    if dt is not None:
        # Trivial cast of the uint8 sample data so every dtype path gets tested.
        img_tensor = img_tensor.to(dt)
        batch = batch.to(dt)

    padded = F_t.pad(img_tensor, pad, **config)
    padded_pil = F_pil.pad(pil_img, pad, **config)
    label = f"{pad}, {config}"

    # PIL images are compared as uint8, so cast any other dtype first.
    padded_u8 = padded if padded.dtype == torch.uint8 else padded.to(torch.uint8)
    _assert_equal_tensor_to_pil(padded_u8, padded_pil, msg=label)

    # A bare int is wrapped in a list for the scripted call
    # (presumably the scripted signature only accepts a list).
    padding_arg = [pad] if isinstance(pad, int) else pad
    padded_scripted = scripted_pad(img_tensor, padding_arg, **config)
    assert_equal(padded, padded_scripted, msg=label)

    _test_fn_on_batch(batch, F.pad, padding=padding_arg, **config)
def test_ten_crop(device):
    """Check ``F.ten_crop`` against PIL, its TorchScript version and batched input."""
    crop_size = [10, 11]
    scripted_ten_crop = torch.jit.script(F.ten_crop)
    img_tensor, pil_img = _create_data(32, 34, device=device)

    expected_pil_crops = F.ten_crop(pil_img, crop_size)

    # Eager tensor crops must match the PIL crops one by one.
    eager_crops = F.ten_crop(img_tensor, crop_size)
    for idx in range(10):
        _assert_equal_tensor_to_pil(eager_crops[idx], expected_pil_crops[idx])

    # Same for the scripted function.
    scripted_crops = scripted_ten_crop(img_tensor, crop_size)
    for idx in range(10):
        _assert_equal_tensor_to_pil(scripted_crops[idx], expected_pil_crops[idx])

    # Cropping a batch must equal cropping each sample independently.
    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    batched_crops = F.ten_crop(batch, crop_size)
    for sample_idx in range(len(batch)):
        per_sample_crops = F.ten_crop(batch[sample_idx, ...], crop_size)
        assert len(per_sample_crops) == len(batched_crops)
        for crop_idx, single_crop in enumerate(per_sample_crops):
            assert_equal(single_crop, batched_crops[crop_idx][sample_idx, ...])

    # scriptable function test on the batch
    for eager_batch, scripted_batch in zip(batched_crops, scripted_ten_crop(batch, crop_size)):
        assert_equal(eager_batch, scripted_batch)
def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Check a transform class on tensor vs PIL input, under TorchScript, on a batch, and via save.

    ``meth_kwargs`` constructs the transform. ``match_kwargs`` are forwarded to
    the tensor-vs-PIL comparison helper. When ``test_exact_match`` is False, the
    tensor output is converted to float and compared approximately.
    """
    transform = transform_cls(**(meth_kwargs or {}))
    scripted_transform = torch.jit.script(transform)
    tensor, pil_img = _create_data(26, 34, channels, device=device)

    # Re-seed before every call so tensor and PIL inputs get the same transformation.
    seed = 12
    torch.manual_seed(seed)
    out_tensor = transform(tensor)
    torch.manual_seed(seed)
    out_pil = transform(pil_img)

    if not test_exact_match:
        _assert_approx_equal_tensor_to_pil(out_tensor.float(), out_pil, **match_kwargs)
    else:
        _assert_equal_tensor_to_pil(out_tensor, out_pil, **match_kwargs)

    torch.manual_seed(seed)
    assert_equal(out_tensor, scripted_transform(tensor))

    batch = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(transform, scripted_transform, batch)

    # The scripted module must also be savable to disk.
    with get_tmp_dir() as tmp_dir:
        scripted_transform.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
def test_x_crop(fn, method, out_length, size, device):
    """Check a multi-crop op in functional form and in transform-class form.

    ``fn`` is the functional op and ``method`` the matching transform class.
    ``out_length`` is the number of crops each must return. Both forms are
    built from the same ``{"size": size}`` kwargs, so their outputs are compared
    crop by crop:
    - tensor vs PIL;
    - eager vs scripted function;
    - scripted class vs functional;
    - per-sample vs batched.
    """
    crop_kwargs = {"size": size}
    scripted_fn = torch.jit.script(fn)
    tensor, pil_img = _create_data(height=20, width=20, device=device)

    # Functional op: tensor and PIL paths must agree crop by crop.
    tensor_crops = fn(tensor, **crop_kwargs)
    pil_crops = fn(pil_img, **crop_kwargs)
    assert len(tensor_crops) == len(pil_crops)
    assert len(tensor_crops) == out_length
    for tensor_crop, pil_crop in zip(tensor_crops, pil_crops):
        _assert_equal_tensor_to_pil(tensor_crop, pil_crop)

    # Scripted functional op must reproduce the eager result exactly.
    scripted_crops = scripted_fn(tensor.detach().clone(), **crop_kwargs)
    assert len(tensor_crops) == len(scripted_crops)
    assert len(scripted_crops) == out_length
    for tensor_crop, scripted_crop in zip(tensor_crops, scripted_crops):
        assert_equal(tensor_crop, scripted_crop)

    # Class interface. Compare every crop, not just the count, so a scripted
    # module returning wrong crops fails. A separate name is used so the `fn`
    # parameter is not shadowed.
    transform = method(**crop_kwargs)
    scripted_transform = torch.jit.script(transform)
    class_crops = scripted_transform(tensor)
    assert len(class_crops) == len(scripted_crops)
    for tensor_crop, class_crop in zip(tensor_crops, class_crops):
        assert_equal(tensor_crop, class_crop)

    # Batched input must give the same crops as each sample on its own.
    batch = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    batch_crops = transform(batch)
    for i in range(len(batch)):
        torch.manual_seed(12)
        sample_crops = transform(batch[i, ...])
        # Guard against zip() silently truncating a shorter output.
        assert len(sample_crops) == len(batch_crops)
        for sample_crop, batch_crop in zip(sample_crops, batch_crops):
            assert_equal(sample_crop, batch_crop[i, ...])
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Apply functional op ``f`` to a 10x10 tensor and its PIL twin, then compare them.

    ``fn_kwargs`` is forwarded to ``f``. ``match_kwargs`` goes to the comparison
    helper, which is exact unless ``test_exact_match`` is False.
    """
    op_kwargs = fn_kwargs or {}
    tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
    out_tensor = f(tensor, **op_kwargs)
    out_pil = f(pil_img, **op_kwargs)

    compare = _assert_equal_tensor_to_pil if test_exact_match else _assert_approx_equal_tensor_to_pil
    compare(out_tensor, out_pil, **match_kwargs)
def test_hflip(device):
    """Check ``F.hflip`` against PIL, its TorchScript version and batched input."""
    scripted_hflip = torch.jit.script(F.hflip)
    img_tensor, pil_img = _create_data(16, 18, device=device)

    flipped = F.hflip(img_tensor)
    _assert_equal_tensor_to_pil(flipped, F.hflip(pil_img))

    # The scripted function must match eager execution exactly.
    assert_equal(flipped, scripted_hflip(img_tensor))

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.hflip)
def test_crop(device, top, left, height, width):
    """Check ``F.crop`` with the given box against PIL, TorchScript and batched input."""
    scripted_crop = torch.jit.script(F.crop)
    img_tensor, pil_img = _create_data(16, 18, device=device)
    box = (top, left, height, width)

    expected = F.crop(pil_img, *box)
    # Eager and scripted tensor crops must both match the PIL crop.
    _assert_equal_tensor_to_pil(F.crop(img_tensor, *box), expected)
    _assert_equal_tensor_to_pil(scripted_crop(img_tensor, *box), expected)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.crop, top=top, left=left, height=height, width=width)
def test_center_crop(device):
    """Check ``F.center_crop`` against PIL, TorchScript and batched input."""
    scripted_center_crop = torch.jit.script(F.center_crop)
    img_tensor, pil_img = _create_data(32, 34, device=device)

    expected = F.center_crop(pil_img, [10, 11])
    # Eager first, then scripted: each tensor crop must match the PIL crop.
    for center_crop_fn in (F.center_crop, scripted_center_crop):
        _assert_equal_tensor_to_pil(center_crop_fn(img_tensor, [10, 11]), expected)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.center_crop, output_size=[10, 11])
def test_translations(self, device, height, width, dt, t, fn):
    """Check a pure translation against ``F.affine`` applied to the PIL image.

    The transformation is a translation by ``t`` with angle 0, scale 1.0, zero
    shear and NEAREST interpolation. ``fn`` is the tensor-side callable under
    test. Non-uint8 output is cast to uint8 before the PIL comparison.
    """
    # 3) Test translation
    tensor, pil_img = _create_data(height, width, device=device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return
    if dt is not None:
        tensor = tensor.to(dtype=dt)

    def make_affine_kwargs():
        # Fresh dict (and shear list) per call, mirroring separate literal arguments.
        return {
            "angle": 0,
            "translate": t,
            "scale": 1.0,
            "shear": [0.0, 0.0],
            "interpolation": NEAREST,
        }

    out_pil_img = F.affine(pil_img, **make_affine_kwargs())
    out_tensor = fn(tensor, **make_affine_kwargs())

    if out_tensor.dtype != torch.uint8:
        out_tensor = out_tensor.to(torch.uint8)
    _assert_equal_tensor_to_pil(out_tensor, out_pil_img)
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
    """Check a list-returning op, looked up by name in ``F`` and ``T``, end to end.

    ``func`` names the functional op in ``F`` and ``method`` the transform class
    in ``T``. It verifies:
    - tensor vs PIL outputs;
    - the scripted function against eager outputs;
    - the scripted class (output count only);
    - per-sample vs batched results;
    - that the scripted class can be saved.
    Data is created on ``self.device``.
    """
    fn_kwargs = {} if fn_kwargs is None else fn_kwargs
    meth_kwargs = {} if meth_kwargs is None else meth_kwargs

    functional = getattr(F, func)
    scripted_functional = torch.jit.script(functional)
    tensor, pil_img = _create_data(height=20, width=20, device=self.device)

    # Tensor and PIL outputs must agree element by element.
    tensor_outs = functional(tensor, **fn_kwargs)
    pil_outs = functional(pil_img, **fn_kwargs)
    self.assertEqual(len(tensor_outs), len(pil_outs))
    self.assertEqual(len(tensor_outs), out_length)
    for tensor_out, pil_out in zip(tensor_outs, pil_outs):
        _assert_equal_tensor_to_pil(tensor_out, pil_out)

    # Scripted function must reproduce the eager outputs.
    scripted_outs = scripted_functional(tensor.detach().clone(), **fn_kwargs)
    self.assertEqual(len(tensor_outs), len(scripted_outs))
    self.assertEqual(len(scripted_outs), out_length)
    for tensor_out, scripted_out in zip(tensor_outs, scripted_outs):
        assert_equal(tensor_out, scripted_out, msg=f"{tensor_out} vs {scripted_out}")

    # Class interface: only the number of outputs is compared here.
    transform = getattr(T, method)(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)
    self.assertEqual(len(scripted_transform(tensor)), len(scripted_outs))

    # Batched input must give the same outputs as each sample on its own.
    batch = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
    torch.manual_seed(12)
    batch_outs = transform(batch)
    for idx in range(len(batch)):
        torch.manual_seed(12)
        sample_outs = transform(batch[idx, ...])
        for sample_out, batch_out in zip(sample_outs, batch_outs):
            batch_row = batch_out[idx, ...]
            assert_equal(sample_out, batch_row, msg=f"{sample_out} vs {batch_row}")

    with get_tmp_dir() as tmp_dir:
        scripted_transform.save(os.path.join(tmp_dir, f"t_op_list_{method}.pt"))