def test_encode_jpeg_reference(img_path):
    """Compare torchvision's JPEG encoder with a stored PIL-encoded reference.

    The reference bytes for the image at ``img_path`` are read from
    ``<dirname>/jpeg_write/<stem>_pil.jpg``. The image is decoded with
    ``decode_jpeg`` and re-encoded with ``encode_jpeg`` at quality 75, both
    as-is and after ``.contiguous()``. Each encoding must equal the reference
    byte-for-byte.

    This test is knowingly *wrong*. The torchvision side starts from
    ``decode_jpeg`` output, while the reference was produced from PIL's own
    decoding, and the two decoders can yield different pixels (see
    test_decode..., which uses a high tolerance). A valid comparison encodes
    from the exact same decoded image, as ``test_encode_jpeg`` does.
    Unfortunately, that more correct test fails on Windows, probably because
    torchvision and PIL use different libjpeg builds.

    FIXME: make the correct tests pass on Windows and remove this one.
    """
    source_dir = os.path.dirname(img_path)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    reference_path = os.path.join(source_dir, 'jpeg_write', f'{stem}_pil.jpg')

    decoded = decode_jpeg(read_file(img_path))

    with open(reference_path, 'rb') as fh:
        reference = torch.as_tensor(list(fh.read()), dtype=torch.uint8)

    for candidate in (decoded, decoded.contiguous()):
        # PIL sets JPEG quality to 75 by default, so match it here.
        assert_equal(encode_jpeg(candidate, quality=75), reference)
def test_encode_jpeg(img_path):
    """Check that ``encode_jpeg`` matches PIL's JPEG output at quality 75.

    The image is read once with ``read_image`` and the same pixels go to both
    encoders, so the comparison is not skewed by decoder differences. Both the
    tensor as read and its ``.contiguous()`` copy are encoded and compared
    byte-for-byte against the PIL-encoded bytes.

    NOTE(review): another ``test_encode_jpeg(img_path)`` definition appears
    later in this file. If both are at module level, the later one shadows
    this one and pytest will only collect that one. Confirm which should stay.
    """
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)
    # buf.getvalue() returns immutable bytes. torch.frombuffer on a read-only
    # buffer warns ("The given buffer is not writable") and yields a tensor
    # aliasing memory it must not write. A bytearray copy is writable.
    encoded_jpeg_pil = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
def test_encode_jpeg(img_path):
    """Check that ``encode_jpeg`` matches PIL's JPEG output at quality 75.

    The image is read once with ``read_image`` and the same pixels go to both
    encoders, so the comparison is not skewed by decoder differences. Both the
    tensor as read and its ``.contiguous()`` copy are encoded and compared
    byte-for-byte against the PIL-encoded bytes.
    """
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    buf = io.BytesIO()
    pil_img.save(buf, format='JPEG', quality=75)
    # Turn the raw bytes into a tensor via numpy; older torch releases have no
    # direct bytes-to-tensor constructor. np.frombuffer over immutable bytes
    # returns a read-only array, which torch.as_tensor warns about. Copying
    # into a bytearray makes the array (and the resulting tensor) writable.
    pil_bytes = np.frombuffer(bytearray(buf.getvalue()), dtype=np.uint8)
    encoded_jpeg_pil = torch.as_tensor(pil_bytes)

    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
def test_encode_jpeg_errors():
    """``encode_jpeg`` must reject invalid inputs with a descriptive error.

    Each case lists the expected exception type and message pattern, the
    shape and dtype of an uninitialised input tensor, and any extra keyword
    arguments. The cases are checked in order; the first failure stops the test.
    """
    quality_msg = "Image quality should be a positive number between 1 and 100"
    ndim_msg = "Input data should be a 3-dimensional tensor"

    # (expected exception, message regex, tensor shape, tensor dtype, extra kwargs)
    bad_inputs = [
        (RuntimeError, "Input tensor dtype should be uint8",
         (3, 100, 100), torch.float32, {}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5",
         (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (100, 100), torch.uint8, {}),
    ]

    for exc_type, pattern, shape, dtype, extra_kwargs in bad_inputs:
        with pytest.raises(exc_type, match=pattern):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra_kwargs)
def test_encode_jpeg(self):
    """Round-trip every ENCODE_JPEG test image, then check input validation.

    First, every ``.jpg`` returned by ``get_images(ENCODE_JPEG, ".jpg")`` is
    decoded with ``decode_jpeg`` and re-encoded with ``encode_jpeg`` at
    quality 75, both as-is and after ``.contiguous()``. Each result must equal
    the stored PIL reference at ``<dirname>/jpeg_write/<stem>_pil.jpg``.

    Then invalid inputs (wrong dtype, out-of-range quality, bad channel count,
    wrong rank) must raise with the expected message.
    """
    for img_path in get_images(ENCODE_JPEG, ".jpg"):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        reference_path = os.path.join(
            os.path.dirname(img_path), 'jpeg_write', f'{stem}_pil.jpg')
        decoded = decode_jpeg(read_file(img_path))

        with open(reference_path, 'rb') as fh:
            reference = torch.as_tensor(list(fh.read()), dtype=torch.uint8)

        for candidate in (decoded, decoded.contiguous()):
            # PIL sets jpeg quality to 75 by default
            produced = encode_jpeg(candidate, quality=75)
            self.assertTrue(produced.equal(reference))

    quality_msg = "Image quality should be a positive number between 1 and 100"
    ndim_msg = "Input data should be a 3-dimensional tensor"

    # (expected exception, message regex, tensor shape, tensor dtype, extra kwargs)
    invalid_cases = (
        (RuntimeError, "Input tensor dtype should be uint8",
         (3, 100, 100), torch.float32, {}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5",
         (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (100, 100), torch.uint8, {}),
    )
    for exc_type, pattern, shape, dtype, extra_kwargs in invalid_cases:
        with self.assertRaisesRegex(exc_type, pattern):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra_kwargs)