示例#1
0
def test_damaged_corrupt_images(img_path):
    """Decoding a damaged JPEG fixture must raise ``RuntimeError``.

    The ``corrupt34`` fixture is expected to fail with a truncation error;
    every other fixture is expected to fail on an unsupported marker.
    """
    encoded = read_file(img_path)
    expected_error = (
        "Image is incomplete or truncated"
        if "corrupt34" in img_path
        else "Unsupported marker type"
    )
    with pytest.raises(RuntimeError, match=expected_error):
        decode_jpeg(encoded)
示例#2
0
    def test_damaged_images(self):
        """Check decoder behaviour on the damaged JPEG fixtures.

        ``bad_huffman.jpg`` has a bad Huffman encoding but must still decode,
        whereas every ``corrupt*.jpg`` file is truncated and must make
        ``decode_jpeg`` raise ``RuntimeError``.
        """
        # Test image with bad Huffman encoding (should not raise)
        bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
        try:
            decode_jpeg(bad_huff)
        except RuntimeError as err:
            # self.fail surfaces the decoder's message; assertTrue(False)
            # only reported the opaque "False is not true".
            self.fail('decoding bad_huffman.jpg raised RuntimeError: {}'.format(err))

        # Truncated images should raise an exception
        truncated_images = sorted(glob.glob(
            os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')))
        # Without this guard, missing fixtures would make the loop below
        # pass vacuously.
        self.assertTrue(truncated_images,
                        'no corrupt*.jpg fixtures found in {}'.format(DAMAGED_JPEG))
        for image_path in truncated_images:
            # subTest reports which fixture failed and keeps checking the rest.
            with self.subTest(image_path=image_path):
                data = read_file(image_path)
                with self.assertRaises(RuntimeError):
                    decode_jpeg(data)
示例#3
0
def test_encode_jpeg_reference(img_path):
    """Compare ``encode_jpeg`` at quality 75 with PIL's stored encoding.

    The reference file is ``jpeg_write/<stem>_pil.jpg`` next to *img_path*.
    """
    # NOTE: this comparison is known to be flawed. The torchvision side starts
    # from an image produced by decode_jpeg, while the PIL reference was
    # produced from PIL's own decode, and the two decoders can differ (the
    # decode tests use a high tolerance for exactly this reason). A valid
    # comparison would encode the very same decoded image on both sides; that
    # is what test_encode_jpeg does, but those stricter tests fail on Windows,
    # probably due to a libjpeg difference between torchvision and PIL.
    # FIXME: make the correct tests pass on windows and remove this.
    base_dir = os.path.dirname(img_path)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    reference_path = os.path.join(base_dir, 'jpeg_write', '{0}_pil.jpg'.format(stem))
    decoded = decode_jpeg(read_file(img_path))

    with open(reference_path, 'rb') as fh:
        raw_reference = fh.read()
    reference = torch.as_tensor(list(raw_reference), dtype=torch.uint8)

    # Check both the tensor as returned by decode_jpeg and a contiguous copy.
    for candidate in (decoded, decoded.contiguous()):
        # PIL sets jpeg quality to 75 by default
        assert_equal(encode_jpeg(candidate, quality=75), reference)
示例#4
0
    def test_decode_jpeg(self):
        """Compare ``decode_jpeg`` with PIL and check input validation.

        Every ``jpg`` image under ``IMAGE_ROOT`` is decoded by both PIL and
        ``decode_jpeg``. The summed absolute pixel difference, normalised by
        ``numel * 255``, must not exceed 1e-2.
        """
        for img_path in get_images(IMAGE_ROOT, "jpg"):
            # Context manager closes the PIL file handle once pixels are read.
            with Image.open(img_path) as pil_image:
                img_pil = torch.from_numpy(np.array(pil_image))
            size = os.path.getsize(img_path)
            img_ljpeg = decode_jpeg(
                torch.from_file(img_path, dtype=torch.uint8, size=size))

            # numel() equals the product of all dims (previously spelled out
            # as shape[0] * shape[1] * shape[2]).
            norm = img_ljpeg.numel() * 255
            diff = img_ljpeg.flatten().float() - img_pil.flatten().float()
            # .item() turns the 0-d tensor into a float so a failure message
            # shows the actual error value.
            err = diff.abs().sum().item() / norm

            self.assertLessEqual(err, 1e-2)

        with self.assertRaisesRegex(
                ValueError, "Expected a non empty 1-dimensional tensor."):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(ValueError,
                                    "Expected a torch.uint8 tensor."):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        # torch.zeros rather than torch.empty: empty() returns uninitialized
        # memory, so the bytes fed to the header parser were nondeterministic.
        with self.assertRaisesRegex(RuntimeError,
                                    "Error while reading jpeg headers"):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))
示例#5
0
    def test_decode_jpeg(self):
        """Compare ``decode_image`` with PIL across channel modes.

        For each ``.jpg`` under ``IMAGE_ROOT`` and each (PIL mode, channels)
        pair, the mean absolute pixel difference must stay below 2. CMYK
        images are only compared in their native mode. Also checks that
        ``decode_jpeg`` rejects malformed input tensors.
        """
        conversion = [(None, 0), ("L", 1), ("RGB", 3)]
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            for pil_mode, channels in conversion:
                with Image.open(img_path) as img:
                    is_cmyk = img.mode == "CMYK"
                    if pil_mode is not None:
                        if is_cmyk:
                            # libjpeg does not support the conversion
                            continue
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))
                    if is_cmyk:
                        # flip the colors to match libjpeg
                        img_pil = 255 - img_pil

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_ljpeg = decode_image(data, channels=channels)

                # Permit a small variation on pixel values to account for implementation
                # differences between Pillow and LibJPEG.
                abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
                # assertLess reports the offending value, unlike assertTrue(x < 2).
                self.assertLess(
                    abs_mean_diff, 2,
                    msg="path={} mode={}".format(img_path, pil_mode))

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        # Deterministic non-JPEG bytes: torch.empty would feed uninitialized memory.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))
示例#6
0
def test_decode_jpeg_errors():
    """``decode_jpeg`` must reject malformed inputs with specific messages."""
    # Wrong rank: the encoded bytes must be a non-empty 1-D tensor.
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    # Wrong dtype: the encoded bytes must be uint8.
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    # Valid shape and dtype, but not JPEG data. Use torch.zeros, not
    # torch.empty: empty() returns uninitialized memory, so the exact
    # decoder message asserted here previously depended on garbage bytes.
    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.zeros(100, dtype=torch.uint8))
示例#7
0
def test_decode_jpeg_cuda_errors():
    """The CUDA decode path must reject malformed inputs with clear errors."""
    encoded = read_file(next(get_images(IMAGE_ROOT, ".jpg")))

    # Each entry pairs the expected error regex with a deferred call, so every
    # call (including any device transfer) still runs inside pytest.raises.
    cases = [
        # Encoded bytes must be 1-D ...
        ("Expected a non empty 1-dimensional tensor",
         lambda: decode_jpeg(encoded.reshape(-1, 1), device='cuda')),
        # ... must start on the CPU ...
        ("input tensor must be on CPU",
         lambda: decode_jpeg(encoded.to('cuda'), device='cuda')),
        # ... and must be uint8.
        ("Expected a torch.uint8 tensor",
         lambda: decode_jpeg(encoded.to(torch.float), device='cuda')),
        # The raw op itself refuses a non-CUDA target device.
        ("Expected a cuda device",
         lambda: torch.ops.image.decode_jpeg_cuda(
             encoded, ImageReadMode.UNCHANGED.value, 'cpu')),
    ]
    for expected_error, call in cases:
        with pytest.raises(RuntimeError, match=expected_error):
            call()
示例#8
0
    def test_decode_jpeg(self):
        """Decoded JPEGs must exactly match their stored reference tensors.

        Each ``<name>.jpg`` under ``IMAGE_ROOT`` is compared with the tensor
        saved in the sibling ``<name>.pth``. Also checks input validation.
        """
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            # Swap only the extension: img_path.replace('jpg', 'pth') would
            # also rewrite any "jpg" inside directory or file names.
            reference_path = os.path.splitext(img_path)[0] + '.pth'
            img_pil = torch.load(reference_path)
            size = os.path.getsize(img_path)
            img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
            self.assertTrue(img_ljpeg.equal(img_pil))

        with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        # Deterministic non-JPEG bytes: torch.empty would feed uninitialized memory.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))
示例#9
0
    def test_encode_jpeg(self):
        """``encode_jpeg`` must reproduce PIL's output and reject bad inputs.

        For each ``.jpg`` under ``ENCODE_JPEG``, re-encoding the decoded image
        at quality 75 must yield exactly the bytes of
        ``jpeg_write/<stem>_pil.jpg``. Invalid dtype, quality, channel count
        and rank must each raise with a specific message.
        """
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            reference_path = os.path.join(
                os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))
            decoded = decode_jpeg(read_file(img_path))

            with open(reference_path, 'rb') as fh:
                reference = torch.as_tensor(list(fh.read()), dtype=torch.uint8)

            for candidate in (decoded, decoded.contiguous()):
                # PIL sets jpeg quality to 75 by default
                encoded = encode_jpeg(candidate, quality=75)
                self.assertTrue(encoded.equal(reference))

        with self.assertRaisesRegex(RuntimeError,
                                    "Input tensor dtype should be uint8"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

        # Quality is validated on both sides of the accepted [1, 100] range.
        quality_error = "Image quality should be a positive number between 1 and 100"
        for bad_quality in (-1, 101):
            with self.assertRaisesRegex(ValueError, quality_error):
                encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8),
                            quality=bad_quality)

        with self.assertRaisesRegex(
                RuntimeError,
                "The number of channels should be 1 or 3, got: 5"):
            encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

        # Both a 4-D and a 2-D input must be rejected as the wrong rank.
        for bad_shape in ((1, 3, 100, 100), (100, 100)):
            with self.assertRaisesRegex(
                    RuntimeError, "Input data should be a 3-dimensional tensor"):
                encode_jpeg(torch.empty(bad_shape, dtype=torch.uint8))
示例#10
0
    def test_decode_jpeg(self):
        """Decoded JPEGs must exactly match their stored reference tensors.

        Each ``<name>.jpg`` under ``IMAGE_ROOT`` is compared with the tensor in
        the sibling ``<name>.pth``, permuted with ``(2, 0, 1)`` before the
        comparison. Also checks input validation.
        """
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            # Swap only the extension: img_path.replace('jpg', 'pth') would
            # also rewrite any "jpg" inside directory or file names.
            reference_path = os.path.splitext(img_path)[0] + '.pth'
            img_pil = torch.load(reference_path)
            # assumes the stored reference is HWC and decode_jpeg returns CHW
            # -- inferred from the permutation, confirm against the fixtures.
            img_pil = img_pil.permute(2, 0, 1)
            data = read_file(img_path)
            img_ljpeg = decode_jpeg(data)
            self.assertTrue(img_ljpeg.equal(img_pil))

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        # Deterministic non-JPEG bytes: torch.empty would feed uninitialized memory.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))
示例#11
0
def test_write_jpeg_reference(img_path, tmpdir):
    """``write_jpeg`` at quality 75 must match PIL's stored reference file.

    The torchvision output is written into *tmpdir*; the reference is
    ``jpeg_write/<stem>_pil.jpg`` next to *img_path*.
    """
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    decoded = decode_jpeg(read_file(img_path))

    stem = os.path.splitext(os.path.basename(img_path))[0]
    written_path = os.path.join(tmpdir, f"{stem}_torch.jpg")
    reference_path = os.path.join(os.path.dirname(img_path), "jpeg_write", f"{stem}_pil.jpg")

    write_jpeg(decoded, written_path, quality=75)

    def _read_bytes(path):
        # Read the whole file as raw bytes.
        with open(path, "rb") as fh:
            return fh.read()

    assert_equal(_read_bytes(written_path), _read_bytes(reference_path))
示例#12
0
    def test_write_jpeg(self):
        """``write_jpeg`` output must be byte-identical to PIL's references.

        Each ``.jpg`` under ``ENCODE_JPEG`` is decoded, written at quality 75
        into a temporary directory, and compared with
        ``jpeg_write/<stem>_pil.jpg``.
        """
        with get_tmp_dir() as out_dir:
            for img_path in get_images(ENCODE_JPEG, ".jpg"):
                decoded = decode_jpeg(read_file(img_path))

                stem = os.path.splitext(os.path.basename(img_path))[0]
                written_path = os.path.join(out_dir, '{0}_torch.jpg'.format(stem))
                reference_path = os.path.join(
                    os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))

                write_jpeg(decoded, written_path, quality=75)

                with open(written_path, 'rb') as fh:
                    written = fh.read()
                with open(reference_path, 'rb') as fh:
                    reference = fh.read()

                self.assertEqual(written, reference)
def test_write_jpeg_reference(img_path):
    """``write_jpeg`` at quality 75 must match PIL's stored reference file.

    The reference is ``jpeg_write/<stem>_pil.jpg`` next to *img_path*.
    """
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    with get_tmp_dir() as out_dir:
        decoded = decode_jpeg(read_file(img_path))

        stem = os.path.splitext(os.path.basename(img_path))[0]
        written_path = os.path.join(out_dir, '{0}_torch.jpg'.format(stem))
        reference_path = os.path.join(
            os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))

        write_jpeg(decoded, written_path, quality=75)

        # Read torchvision's output first, then the PIL reference.
        blobs = []
        for path in (written_path, reference_path):
            with open(path, 'rb') as fh:
                blobs.append(fh.read())

        assert_equal(*blobs)
示例#14
0
def test_decode_jpeg_cuda_device_param(cuda_device):
    """``decode_jpeg`` must accept ``device`` as a string or a ``torch.device``."""
    # Use the first fixture whose path does not mention "cmyk".
    non_cmyk_paths = (p for p in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in p)
    encoded = read_file(next(non_cmyk_paths))
    decode_jpeg(encoded, device=cuda_device)
示例#15
0
def test_decode_bad_huffman_images():
    """Sanity check: a JPEG with a bad Huffman encoding still decodes."""
    # No assertion needed: the test fails only if decode_jpeg raises.
    huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    decode_jpeg(read_file(huffman_path))
示例#16
0
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    # Skip CMYK fixtures: taking the first .jpg blindly could pick one, and
    # presumably the CUDA decode path does not handle CMYK (TODO: confirm).
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg")
                if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)