def test_decode_png(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)

    if img_path.endswith("16.png"):
        # 16-bit image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            img_lpng = decode_image(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16-bit PNGs to uint8
        img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
    else:
        data = read_file(img_path)
        img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
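# -----------------------------------------------------------------------------
# Note: the tests in this excerpt rely on module-level setup that is not shown
# here: the directory constants IMAGE_ROOT, IMAGE_DIR, FAKEDATA_DIR, DAMAGED_JPEG
# and ENCODE_JPEG, the helpers get_images and normalize_dimensions, the
# PILLOW_VERSION tuple, and the shared test utilities assert_equal / get_tmp_dir.
# The sketch below shows one plausible shape for that shared setup; the exact
# paths, imports and helper signatures in the real test module may differ.
# -----------------------------------------------------------------------------
import glob
import os

import numpy as np
import PIL
import pytest
import torch
from PIL import Image

from torchvision.io import (
    ImageReadMode,
    decode_image,
    decode_jpeg,
    decode_png,
    encode_jpeg,
    read_file,
    write_jpeg,
)
# Assumed: private helper used by the 16-bit PNG branch above.
from torchvision.io.image import _read_png_16

# Assumed: Pillow's version parsed into a comparable tuple, e.g. (8, 3, 1).
PILLOW_VERSION = tuple(int(x) for x in PIL.__version__.split(".")[:3])


def get_images(directory, img_ext):
    # Assumed helper: walk `directory` and yield every file with extension `img_ext`.
    for root, _, files in os.walk(directory):
        for fname in files:
            if os.path.splitext(fname)[1] == img_ext:
                yield os.path.join(root, fname)


def normalize_dimensions(img_pil):
    # Assumed helper: bring the PIL reference tensor to CHW so it matches the
    # layout returned by torchvision's decoders; 2-D grayscale arrays get a
    # leading channel dimension instead of a permute.
    if img_pil.ndim == 3:
        img_pil = img_pil.permute(2, 0, 1)
    else:
        img_pil = img_pil.unsqueeze(0)
    return img_pil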
def test_decode_image(self):
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        img_pil = torch.load(img_path.replace('jpg', 'pth'))
        img_pil = img_pil.permute(2, 0, 1)
        img_ljpeg = decode_image(read_file(img_path))
        self.assertTrue(img_ljpeg.equal(img_pil))

    for img_path in get_images(IMAGE_DIR, ".png"):
        img_pil = torch.from_numpy(np.array(Image.open(img_path)))
        img_pil = img_pil.permute(2, 0, 1)
        img_lpng = decode_image(read_file(img_path))
        self.assertTrue(img_lpng.equal(img_pil))
def test_read_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
def test_read_file():
    with get_tmp_dir() as d:
        fname, content = 'test1.bin', b'TorchVision\211\n'
        fpath = os.path.join(d, fname)
        with open(fpath, 'wb') as f:
            f.write(content)

        data = read_file(fpath)
        expected = torch.tensor(list(content), dtype=torch.uint8)
        os.unlink(fpath)
        assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file('tst')
def test_read_file(self):
    with get_tmp_dir() as d:
        fname, content = 'test1.bin', b'TorchVision\211\n'
        fpath = os.path.join(d, fname)
        with open(fpath, 'wb') as f:
            f.write(content)

        data = read_file(fpath)
        expected = torch.tensor(list(content), dtype=torch.uint8)
        self.assertTrue(data.equal(expected))
        os.unlink(fpath)

    with self.assertRaisesRegex(
            RuntimeError, "No such file or directory: 'tst'"):
        read_file('tst')
def test_damaged_images(self):
    # Test image with bad Huffman encoding (should not raise)
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
    try:
        _ = decode_jpeg(bad_huff)
    except RuntimeError:
        self.fail("Decoding an image with bad Huffman encoding should not raise")

    # Truncated images should raise an exception
    truncated_images = glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
    for image_path in truncated_images:
        data = read_file(image_path)
        with self.assertRaises(RuntimeError):
            decode_jpeg(data)
def test_encode_jpeg_reference(img_path):
    # This test is *wrong*.
    # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
    # starts encoding the torchvision version from an image that comes from
    # decode_jpeg, which can yield different results from pil.decode (see
    # test_decode... which uses a high tolerance).
    # Instead, we should start encoding from the exact same decoded image, for a
    # valid comparison. This is done in test_encode_jpeg, but unfortunately
    # these more correct tests fail on windows (probably because of a difference
    # in libjpeg between torchvision and PIL).
    # FIXME: make the correct tests pass on windows and remove this.
    dirname = os.path.dirname(img_path)
    filename, _ = os.path.splitext(os.path.basename(img_path))
    write_folder = os.path.join(dirname, 'jpeg_write')
    expected_file = os.path.join(
        write_folder, '{0}_pil.jpg'.format(filename))
    img = decode_jpeg(read_file(img_path))

    with open(expected_file, 'rb') as f:
        pil_bytes = f.read()
        pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
    for src_img in [img, img.contiguous()]:
        # PIL sets jpeg quality to 75 by default
        jpeg_bytes = encode_jpeg(src_img, quality=75)
        assert_equal(jpeg_bytes, pil_bytes)
def test_decode_jpeg(self):
    conversion = [(None, 0), ("L", 1), ("RGB", 3)]
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        for pil_mode, channels in conversion:
            with Image.open(img_path) as img:
                is_cmyk = img.mode == "CMYK"
                if pil_mode is not None:
                    if is_cmyk:
                        # libjpeg does not support the conversion
                        continue
                    img = img.convert(pil_mode)
                img_pil = torch.from_numpy(np.array(img))
                if is_cmyk:
                    # flip the colors to match libjpeg
                    img_pil = 255 - img_pil

            img_pil = normalize_dimensions(img_pil)
            data = read_file(img_path)
            img_ljpeg = decode_image(data, channels=channels)

            # Permit a small variation on pixel values to account for implementation
            # differences between Pillow and LibJPEG.
            abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
            self.assertTrue(abs_mean_diff < 2)

    with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100, ), dtype=torch.float16))

    with self.assertRaises(RuntimeError):
        decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_damaged_corrupt_images(img_path):
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)
def test_read_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)
def test_read_file(self):
    with get_tmp_dir() as d:
        fname, content = 'test1.bin', b'TorchVision\211\n'
        fpath = os.path.join(d, fname)
        with open(fpath, 'wb') as f:
            f.write(content)

        data = read_file(fpath)
        expected = torch.tensor(list(content), dtype=torch.uint8)
        self.assertTrue(data.equal(expected))

        # Windows holds onto the file as long as the tensor is alive, so we need
        # to del the tensor before deleting the file; see
        # https://github.com/pytorch/vision/issues/2743#issuecomment-703817293
        del data
        os.unlink(fpath)

    with self.assertRaisesRegex(RuntimeError, "No such file or directory: 'tst'"):
        read_file('tst')
def test_decode_jpeg_cuda_errors():
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device='cuda')
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to('cuda'), device='cuda')
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device='cuda')
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
def test_decode_jpeg_cuda(mode, img_path, scripted):
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
def test_read_file_non_ascii(self):
    with get_tmp_dir() as d:
        fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
        fpath = os.path.join(d, fname)
        with open(fpath, 'wb') as f:
            f.write(content)

        data = read_file(fpath)
        expected = torch.tensor(list(content), dtype=torch.uint8)
        self.assertTrue(data.equal(expected))
        os.unlink(fpath)
def test_decode_png(self):
    for img_path in get_images(IMAGE_DIR, ".png"):
        img_pil = torch.from_numpy(np.array(Image.open(img_path)))
        img_pil = img_pil.permute(2, 0, 1)
        data = read_file(img_path)
        img_lpng = decode_png(data)
        self.assertTrue(img_lpng.equal(img_pil))

    with self.assertRaises(RuntimeError):
        decode_png(torch.empty((), dtype=torch.uint8))
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_png(img_path, pil_mode, mode):
    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1
    assert img_lpng.allclose(img_pil, atol=tol)
def test_decode_jpeg(self):
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        img_pil = torch.load(img_path.replace('jpg', 'pth'))
        img_pil = img_pil.permute(2, 0, 1)
        data = read_file(img_path)
        img_ljpeg = decode_jpeg(data)
        self.assertTrue(img_ljpeg.equal(img_pil))

    with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100, ), dtype=torch.float16))

    with self.assertRaises(RuntimeError):
        decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_encode_jpeg(self):
    for img_path in get_images(ENCODE_JPEG, ".jpg"):
        dirname = os.path.dirname(img_path)
        filename, _ = os.path.splitext(os.path.basename(img_path))
        write_folder = os.path.join(dirname, 'jpeg_write')
        expected_file = os.path.join(write_folder, '{0}_pil.jpg'.format(filename))
        img = decode_jpeg(read_file(img_path))

        with open(expected_file, 'rb') as f:
            pil_bytes = f.read()
            pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
        for src_img in [img, img.contiguous()]:
            # PIL sets jpeg quality to 75 by default
            jpeg_bytes = encode_jpeg(src_img, quality=75)
            self.assertTrue(jpeg_bytes.equal(pil_bytes))

    with self.assertRaisesRegex(RuntimeError, "Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with self.assertRaisesRegex(
            ValueError, "Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with self.assertRaisesRegex(
            ValueError, "Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with self.assertRaisesRegex(
            RuntimeError, "The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with self.assertRaisesRegex(
            RuntimeError, "Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with self.assertRaisesRegex(
            RuntimeError, "Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg_reference(img_path, tmpdir):
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    data = read_file(img_path)
    img = decode_jpeg(data)

    basedir = os.path.dirname(img_path)
    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
    pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")

    write_jpeg(img, torch_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()
    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)
def test_decode_png(self):
    conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
    for img_path in get_images(FAKEDATA_DIR, ".png"):
        for pil_mode, channels in conversion:
            with Image.open(img_path) as img:
                if pil_mode is not None:
                    img = img.convert(pil_mode)
                img_pil = torch.from_numpy(np.array(img))

            img_pil = normalize_dimensions(img_pil)
            data = read_file(img_path)
            img_lpng = decode_image(data, channels=channels)

            tol = 0 if pil_mode is None else 1
            self.assertTrue(img_lpng.allclose(img_pil, atol=tol))

    with self.assertRaises(RuntimeError):
        decode_png(torch.empty((), dtype=torch.uint8))
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_write_jpeg_reference(img_path):
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    with get_tmp_dir() as d:
        data = read_file(img_path)
        img = decode_jpeg(data)

        basedir = os.path.dirname(img_path)
        filename, _ = os.path.splitext(os.path.basename(img_path))
        torch_jpeg = os.path.join(d, '{0}_torch.jpg'.format(filename))
        pil_jpeg = os.path.join(basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

        write_jpeg(img, torch_jpeg, quality=75)

        with open(torch_jpeg, 'rb') as f:
            torch_bytes = f.read()
        with open(pil_jpeg, 'rb') as f:
            pil_bytes = f.read()

        assert_equal(torch_bytes, pil_bytes)
def test_write_jpeg(self):
    with get_tmp_dir() as d:
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            data = read_file(img_path)
            img = decode_jpeg(data)

            basedir = os.path.dirname(img_path)
            filename, _ = os.path.splitext(os.path.basename(img_path))
            torch_jpeg = os.path.join(d, '{0}_torch.jpg'.format(filename))
            pil_jpeg = os.path.join(basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

            write_jpeg(img, torch_jpeg, quality=75)

            with open(torch_jpeg, 'rb') as f:
                torch_bytes = f.read()
            with open(pil_jpeg, 'rb') as f:
                pil_bytes = f.read()

            self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(img_path, pil_mode, mode):
    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
def test_decode_jpeg(img_path, pil_mode, mode):
    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        if pil_mode is not None:
            if is_cmyk:
                # libjpeg does not support the conversion
                pytest.xfail("Decoding a CMYK jpeg isn't supported")
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))
        if is_cmyk:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_ljpeg = decode_image(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2
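# -----------------------------------------------------------------------------
# Note: the pytest-style tests above receive `img_path`, `pil_mode`, `mode`,
# `scripted`, `cuda_device` and `tmpdir` through @pytest.mark.parametrize
# decorators and fixtures that were stripped from this excerpt. The test below
# is a hypothetical minimal sketch of that parametrization pattern, not one of
# the original tests; it assumes the get_images helper and IMAGE_ROOT constant
# sketched earlier.
# -----------------------------------------------------------------------------
import pytest
import torch

from torchvision.io import ImageReadMode, decode_image, read_file


@pytest.mark.parametrize("img_path", get_images(IMAGE_ROOT, ".jpg"))
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
def test_decode_jpeg_modes_smoke(img_path, mode):
    # Hypothetical smoke test: decoding should succeed and yield a CHW uint8 tensor.
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")
    img = decode_image(read_file(img_path), mode=mode)
    assert img.dtype == torch.uint8
    assert img.ndim == 3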
def test_decode_jpeg_cuda_device_param(cuda_device): """Make sure we can pass a string or a torch.device as device param""" path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path) data = read_file(path) decode_jpeg(data, device=cuda_device)
def test_decode_jpeg_cuda_device_param(cuda_device): """Make sure we can pass a string or a torch.device as device param""" data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) decode_jpeg(data, device=cuda_device)
def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)