Example #1
def test_decode_png(img_path, pil_mode, mode):
    """Check that PNG decoding agrees with PIL for one image and read mode.

    Args:
        img_path: Path of the PNG file to decode.
        pil_mode: PIL mode the reference image is converted to, or ``None``
            to keep the image as PIL opened it.
        mode: Read mode forwarded to the decoder under test.
    """
    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)

    if img_path.endswith("16.png"):
        # 16 bits image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError,
                           match="At most 8-bit PNG images are supported"):
            decode_image(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16 bits pngs in uint8
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        # Only non-16-bit files go through the public decoder. Running it
        # unconditionally would raise for "16.png" inputs and throw away the
        # rescaled private-API result computed above.
        img_lpng = decode_image(data, mode=mode)

    # Exact match when PIL did no conversion; allow off-by-one otherwise.
    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
Example #2
    def test_decode_image(self):
        """Decode every JPEG and PNG fixture and compare it with its reference.

        JPEG references are loaded from the ``.pth`` file next to each image;
        PNG references come from PIL. Both references are permuted so their
        last axis becomes the first before the comparison.
        """
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            img_pil = torch.load(img_path.replace('jpg', 'pth'))
            img_pil = img_pil.permute(2, 0, 1)
            img_ljpeg = decode_image(read_file(img_path))
            # Without this check the loop only proves decoding does not crash.
            self.assertTrue(img_ljpeg.equal(img_pil))

        for img_path in get_images(IMAGE_DIR, ".png"):
            # Close the PIL handle as soon as the pixels are copied out.
            with Image.open(img_path) as img:
                img_pil = torch.from_numpy(np.array(img))
            img_pil = img_pil.permute(2, 0, 1)
            img_lpng = decode_image(read_file(img_path))
            self.assertTrue(img_lpng.equal(img_pil))
Example #3
def test_read_file(tmpdir):
    """Check that ``read_file`` returns a file's bytes and rejects a missing path.

    Args:
        tmpdir: Directory in which the temporary input file is created.
    """
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    # One uint8 element per byte written above.
    expected = torch.tensor(list(content), dtype=torch.uint8)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
def test_read_file():
    """Check that ``read_file`` returns a file's bytes and rejects a missing path."""
    with get_tmp_dir() as d:
        fname, content = 'test1.bin', b'TorchVision\211\n'
        fpath = os.path.join(d, fname)
        with open(fpath, 'wb') as f:
            f.write(content)

        data = read_file(fpath)
        # One uint8 element per byte written above.
        expected = torch.tensor(list(content), dtype=torch.uint8)
        assert_equal(data, expected)
        # Drop the tensor before the temporary directory is cleaned up; on
        # Windows the file reportedly stays locked while the tensor is alive
        # (https://github.com/pytorch/vision/issues/2743#issuecomment-703817293).
        del data

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file('tst')
Example #5
    def test_read_file(self):
        """Check that ``read_file`` returns a file's bytes and rejects a missing path."""
        with get_tmp_dir() as d:
            fname, content = 'test1.bin', b'TorchVision\211\n'
            fpath = os.path.join(d, fname)
            with open(fpath, 'wb') as f:
                f.write(content)

            data = read_file(fpath)
            # One uint8 element per byte written above.
            expected = torch.tensor(list(content), dtype=torch.uint8)
            self.assertTrue(data.equal(expected))
            # Drop the tensor before the temporary directory is cleaned up; on
            # Windows the file reportedly stays locked while the tensor is
            # alive (https://github.com/pytorch/vision/issues/2743#issuecomment-703817293).
            del data

        with self.assertRaisesRegex(
                RuntimeError, "No such file or directory: 'tst'"):
            read_file('tst')
Example #6
    def test_damaged_images(self):
        """Check JPEG decoding of damaged files.

        A file with a bad Huffman encoding must still decode; files matching
        ``corrupt*.jpg`` must make ``decode_jpeg`` raise ``RuntimeError``.
        """
        # Test image with bad Huffman encoding (should not raise)
        bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
        try:
            _ = decode_jpeg(bad_huff)
        except RuntimeError as err:
            # Report an ordinary test failure with the decoder's message
            # instead of letting the exception surface as an error.
            self.fail("decode_jpeg raised on bad_huffman.jpg: {0}".format(err))

        # Truncated images should raise an exception
        truncated_images = glob.glob(
            os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
        for image_path in truncated_images:
            data = read_file(image_path)
            with self.assertRaises(RuntimeError):
                decode_jpeg(data)
Example #7
def test_encode_jpeg_reference(img_path):
    """Compare torchvision's JPEG encoding with a stored PIL-encoded file.

    This test is *wrong*: the torchvision bytes are produced from an image
    returned by ``decode_jpeg``, which can differ from what PIL decodes (see
    the decode tests, which use a high tolerance), while the reference was
    encoded by PIL. A valid comparison would encode the exact same decoded
    image on both sides. test_encode_jpeg does that, but those more correct
    tests fail on windows (probably because of a difference in libjpeg
    between torchvision and PIL).

    FIXME: make the correct tests pass on windows and remove this.

    Args:
        img_path: Path of the source JPEG; the reference is looked up in the
            ``jpeg_write`` folder next to it.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    reference_path = os.path.join(
        os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))
    decoded = decode_jpeg(read_file(img_path))

    with open(reference_path, 'rb') as handle:
        raw_reference = handle.read()
    pil_bytes = torch.as_tensor(list(raw_reference), dtype=torch.uint8)

    # PIL sets jpeg quality to 75 by default
    for candidate in (decoded, decoded.contiguous()):
        assert_equal(encode_jpeg(candidate, quality=75), pil_bytes)
Example #8
    def test_decode_jpeg(self):
        """Compare JPEG decoding with PIL for each (PIL mode, channels) pair.

        Also checks that ``decode_jpeg`` rejects a 2-D tensor, a non-uint8
        tensor, and a uint8 buffer that was never filled with JPEG data.
        """
        conversion = [(None, 0), ("L", 1), ("RGB", 3)]
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            for pil_mode, channels in conversion:
                with Image.open(img_path) as img:
                    is_cmyk = img.mode == "CMYK"
                    if pil_mode is not None:
                        if is_cmyk:
                            # libjpeg does not support the conversion
                            continue
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))
                    if is_cmyk:
                        # flip the colors to match libjpeg
                        img_pil = 255 - img_pil

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_ljpeg = decode_image(data, channels=channels)

                # Permit a small variation on pixel values to account for implementation
                # differences between Pillow and LibJPEG.
                abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
                self.assertTrue(abs_mean_diff < 2)

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.empty((100), dtype=torch.uint8))
Example #9
def test_damaged_corrupt_images(img_path):
    """Check that decoding a corrupt JPEG raises with the expected message.

    Args:
        img_path: Path of the corrupt JPEG file; files whose path contains
            ``corrupt34`` are expected to fail with a different message.
    """
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)
Example #10
def test_read_file_non_ascii(tmpdir):
    """Check that ``read_file`` handles a file name with non-ASCII characters.

    Args:
        tmpdir: Directory in which the temporary input file is created.
    """
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    # One uint8 element per byte written above.
    expected = torch.tensor(list(content), dtype=torch.uint8)
    assert_equal(data, expected)
Example #11
    def test_read_file(self):
        """Check that ``read_file`` returns a file's bytes and rejects a missing path."""
        with get_tmp_dir() as d:
            fname, content = 'test1.bin', b'TorchVision\211\n'
            fpath = os.path.join(d, fname)
            with open(fpath, 'wb') as f:
                f.write(content)

            data = read_file(fpath)
            # One uint8 element per byte written above.
            expected = torch.tensor(list(content), dtype=torch.uint8)
            # Compare before releasing the tensor below.
            self.assertTrue(data.equal(expected))
            # Windows holds into the file until the tensor is alive
            # so need to del the tensor before deleting the file see
            # https://github.com/pytorch/vision/issues/2743#issuecomment-703817293
            del data

        with self.assertRaisesRegex(RuntimeError,
                                    "No such file or directory: 'tst'"):
            read_file('tst')
Example #12
def test_decode_jpeg_cuda_errors():
    """Check the error raised for each kind of invalid CUDA decoding request."""
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    data = read_file(first_jpeg)

    # Each entry pairs the expected message with the call that must raise it.
    # The calls are lambdas so that every argument is built inside the
    # ``pytest.raises`` block.
    failing_calls = (
        ("Expected a non empty 1-dimensional tensor",
         lambda: decode_jpeg(data.reshape(-1, 1), device='cuda')),
        ("input tensor must be on CPU",
         lambda: decode_jpeg(data.to('cuda'), device='cuda')),
        ("Expected a torch.uint8 tensor",
         lambda: decode_jpeg(data.to(torch.float), device='cuda')),
        ("Expected a cuda device",
         lambda: torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')),
    )
    for expected_message, invoke in failing_calls:
        with pytest.raises(RuntimeError, match=expected_message):
            invoke()
Example #13
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Compare CUDA JPEG decoding with the result of ``decode_image``.

    Args:
        mode: Read mode forwarded to both decoders.
        img_path: Path of the JPEG file to decode.
        scripted: When true, ``decode_jpeg`` is run through
            ``torch.jit.script`` first.
    """
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    encoded = read_file(img_path)
    reference = decode_image(encoded, mode=mode)

    if scripted:
        decoder = torch.jit.script(decode_jpeg)
    else:
        decoder = decode_jpeg
    on_gpu = decoder(encoded, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    mean_abs_diff = (reference.float() - on_gpu.cpu().float()).abs().mean()
    assert mean_abs_diff < 2
Example #14
    def test_read_file_non_ascii(self):
        """Check that ``read_file`` handles a file name with non-ASCII characters."""
        with get_tmp_dir() as d:
            fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
            fpath = os.path.join(d, fname)
            with open(fpath, 'wb') as f:
                f.write(content)

            data = read_file(fpath)
            # One uint8 element per byte written above.
            expected = torch.tensor(list(content), dtype=torch.uint8)
            self.assertTrue(data.equal(expected))
            # Drop the tensor before the temporary directory is cleaned up; on
            # Windows the file reportedly stays locked while the tensor is
            # alive (https://github.com/pytorch/vision/issues/2743#issuecomment-703817293).
            del data
Example #15
    def test_decode_png(self):
        """Compare ``decode_png`` with PIL on every PNG fixture and check bad input.

        The PIL reference is permuted so its last axis becomes the first
        before the comparison. ``decode_png`` must raise ``RuntimeError`` for
        a 0-dimensional tensor and for a buffer of random bytes.
        """
        for img_path in get_images(IMAGE_DIR, ".png"):
            # Close the PIL handle as soon as the pixels are copied out.
            with Image.open(img_path) as img:
                img_pil = torch.from_numpy(np.array(img))
            img_pil = img_pil.permute(2, 0, 1)
            data = read_file(img_path)
            img_lpng = decode_png(data)
            # Without this check the loop only proves decoding does not crash.
            self.assertTrue(img_lpng.equal(img_pil))

        # These checks do not depend on the image, so run them once rather
        # than once per fixture (and still run them if no fixture is found).
        with self.assertRaises(RuntimeError):
            decode_png(torch.empty((), dtype=torch.uint8))
        with self.assertRaises(RuntimeError):
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_png(img_path, pil_mode, mode):
    """Check that PNG decoding agrees with PIL for one image and read mode.

    Args:
        img_path: Path of the PNG file to decode.
        pil_mode: PIL mode the reference image is converted to, or ``None``
            to keep the image as PIL opened it.
        mode: Read mode forwarded to ``decode_image``.
    """
    with Image.open(img_path) as pil_image:
        source = pil_image if pil_mode is None else pil_image.convert(pil_mode)
        reference = torch.from_numpy(np.array(source))
    reference = normalize_dimensions(reference)

    decoded = decode_image(read_file(img_path), mode=mode)

    # Exact match when PIL did no conversion; allow off-by-one otherwise.
    atol = 1 if pil_mode is not None else 0
    assert decoded.allclose(reference, atol=atol)
Example #17
    def test_decode_jpeg(self):
        """Compare ``decode_jpeg`` with stored references and check bad input.

        References are loaded from the ``.pth`` file next to each JPEG and
        permuted so their last axis becomes the first. ``decode_jpeg`` must
        reject a 2-D tensor, a non-uint8 tensor, and a uint8 buffer that was
        never filled with JPEG data.
        """
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            img_pil = torch.load(img_path.replace('jpg', 'pth'))
            img_pil = img_pil.permute(2, 0, 1)
            data = read_file(img_path)
            img_ljpeg = decode_jpeg(data)
            # Without this check the loop only proves decoding does not crash.
            self.assertTrue(img_ljpeg.equal(img_pil))

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.empty((100), dtype=torch.uint8))
Example #18
    def test_encode_jpeg(self):
        """Compare ``encode_jpeg`` with PIL-encoded references and check bad input.

        NOTE(review): the reference file name, the two out-of-range quality
        values and the exception type of the channel-count check were missing
        from the block and have been reconstructed; the file name follows the
        ``{name}_pil.jpg`` pattern used by the other JPEG reference tests in
        this file. Confirm against the intended fixtures.
        """
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            dirname = os.path.dirname(img_path)
            filename, _ = os.path.splitext(os.path.basename(img_path))
            write_folder = os.path.join(dirname, 'jpeg_write')
            expected_file = os.path.join(
                write_folder, '{0}_pil.jpg'.format(filename))
            img = decode_jpeg(read_file(img_path))

            with open(expected_file, 'rb') as f:
                pil_bytes = f.read()
            pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
            for src_img in [img, img.contiguous()]:
                # PIL sets jpeg quality to 75 by default
                jpeg_bytes = encode_jpeg(src_img, quality=75)
                # Without this check the loop only proves encoding does not
                # crash.
                self.assertTrue(jpeg_bytes.equal(pil_bytes))

        with self.assertRaisesRegex(RuntimeError,
                                    "Input tensor dtype should be uint8"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

        # Below the accepted quality range.
        with self.assertRaisesRegex(
                ValueError, "Image quality should be a positive number "
                "between 1 and 100"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8),
                        quality=-1)

        # Above the accepted quality range.
        with self.assertRaisesRegex(
                ValueError, "Image quality should be a positive number "
                "between 1 and 100"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8),
                        quality=101)

        with self.assertRaisesRegex(
                RuntimeError,
                "The number of channels should be 1 or 3, got: 5"):
            encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

        with self.assertRaisesRegex(
                RuntimeError, "Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

        with self.assertRaisesRegex(
                RuntimeError, "Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
Example #19
def test_write_jpeg_reference(img_path, tmpdir):
    """Check that ``write_jpeg`` produces the same bytes as the PIL reference file.

    Args:
        img_path: Path of the source JPEG; the reference is looked up in the
            ``jpeg_write`` folder next to it.
        tmpdir: Directory that receives the file written by ``write_jpeg``.
    """
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    img = decode_jpeg(read_file(img_path))

    stem = os.path.splitext(os.path.basename(img_path))[0]
    torch_jpeg = os.path.join(tmpdir, f"{stem}_torch.jpg")
    pil_jpeg = os.path.join(
        os.path.dirname(img_path), "jpeg_write", f"{stem}_pil.jpg")

    write_jpeg(img, torch_jpeg, quality=75)

    with open(torch_jpeg, "rb") as written:
        torch_bytes = written.read()
    with open(pil_jpeg, "rb") as reference:
        pil_bytes = reference.read()

    assert_equal(torch_bytes, pil_bytes)
Example #20
    def test_decode_png(self):
        """Compare PNG decoding with PIL for each (PIL mode, channels) pair.

        Also checks that ``decode_png`` raises ``RuntimeError`` for a
        0-dimensional tensor and for a buffer of random bytes.
        """
        conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
        for img_path in get_images(FAKEDATA_DIR, ".png"):
            for pil_mode, channels in conversion:
                with Image.open(img_path) as img:
                    if pil_mode is not None:
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_lpng = decode_image(data, channels=channels)

                # Exact match when PIL did no conversion. The tolerance must
                # key on ``pil_mode``: ``conversion`` is the table itself and
                # is never None, which silently allowed off-by-one everywhere.
                tol = 0 if pil_mode is None else 1
                self.assertTrue(img_lpng.allclose(img_pil, atol=tol))

        with self.assertRaises(RuntimeError):
            decode_png(torch.empty((), dtype=torch.uint8))
        with self.assertRaises(RuntimeError):
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_write_jpeg_reference(img_path):
    """Check that ``write_jpeg`` produces the same bytes as the PIL reference file.

    NOTE(review): the reference file name was missing from the block; it has
    been reconstructed as ``{name}_pil.jpg`` inside ``jpeg_write``, matching
    the other JPEG reference tests in this file. Confirm against the fixtures.

    Args:
        img_path: Path of the source JPEG; the reference is looked up in the
            ``jpeg_write`` folder next to it.
    """
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    with get_tmp_dir() as d:
        data = read_file(img_path)
        img = decode_jpeg(data)

        basedir = os.path.dirname(img_path)
        filename, _ = os.path.splitext(os.path.basename(img_path))
        torch_jpeg = os.path.join(d, '{0}_torch.jpg'.format(filename))
        pil_jpeg = os.path.join(basedir, 'jpeg_write',
                                '{0}_pil.jpg'.format(filename))

        write_jpeg(img, torch_jpeg, quality=75)

        with open(torch_jpeg, 'rb') as f:
            torch_bytes = f.read()

        with open(pil_jpeg, 'rb') as f:
            pil_bytes = f.read()

        assert_equal(torch_bytes, pil_bytes)
Example #22
    def test_write_jpeg(self):
        """Check that ``write_jpeg`` reproduces the PIL reference bytes for each fixture.

        NOTE(review): the reference file name was missing from the block; it
        has been reconstructed as ``{name}_pil.jpg`` inside ``jpeg_write``,
        matching the other JPEG reference tests in this file. Confirm against
        the fixtures.
        """
        with get_tmp_dir() as d:
            for img_path in get_images(ENCODE_JPEG, ".jpg"):
                data = read_file(img_path)
                img = decode_jpeg(data)

                basedir = os.path.dirname(img_path)
                filename, _ = os.path.splitext(os.path.basename(img_path))
                torch_jpeg = os.path.join(d, '{0}_torch.jpg'.format(filename))
                pil_jpeg = os.path.join(basedir, 'jpeg_write',
                                        '{0}_pil.jpg'.format(filename))

                write_jpeg(img, torch_jpeg, quality=75)

                with open(torch_jpeg, 'rb') as f:
                    torch_bytes = f.read()

                with open(pil_jpeg, 'rb') as f:
                    pil_bytes = f.read()

                self.assertEqual(torch_bytes, pil_bytes)
Example #23
def test_decode_png(img_path, pil_mode, mode):
    """Check that PNG decoding agrees with PIL for one image and read mode.

    Args:
        img_path: Path of the PNG file to decode.
        pil_mode: PIL mode the reference image is converted to, or ``None``
            to keep the image as PIL opened it.
        mode: Read mode forwarded to ``decode_image``.
    """
    with Image.open(img_path) as pil_image:
        source = pil_image if pil_mode is None else pil_image.convert(pil_mode)
        expected = torch.from_numpy(np.array(source))
    expected = normalize_dimensions(expected)

    actual = decode_image(read_file(img_path), mode=mode)

    # Exact match when PIL did no conversion; allow off-by-one otherwise.
    atol = 1 if pil_mode is not None else 0

    if pil_mode == "LA" and PILLOW_VERSION >= (8, 3):
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        actual = actual[0]
        expected = expected[0]

    torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
Example #24
def test_decode_jpeg(img_path, pil_mode, mode):
    """Check that JPEG decoding stays close to PIL for one image and read mode.

    Args:
        img_path: Path of the JPEG file to decode.
        pil_mode: PIL mode the reference image is converted to, or ``None``
            to keep the image as PIL opened it.
        mode: Read mode forwarded to ``decode_image``.
    """
    with Image.open(img_path) as pil_image:
        is_cmyk = pil_image.mode == "CMYK"
        wants_conversion = pil_mode is not None
        if wants_conversion and is_cmyk:
            # libjpeg does not support the conversion
            pytest.xfail("Decoding a CMYK jpeg isn't supported")
        if wants_conversion:
            pil_image = pil_image.convert(pil_mode)
        reference = torch.from_numpy(np.array(pil_image))
        if is_cmyk:
            # flip the colors to match libjpeg
            reference = 255 - reference

    reference = normalize_dimensions(reference)
    decoded = decode_image(read_file(img_path), mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    difference = decoded.type(torch.float32) - reference
    assert difference.abs().mean().item() < 2
Example #25
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    # Use the first JPEG whose path does not mention "cmyk".
    usable_paths = filter(lambda candidate: "cmyk" not in candidate,
                          get_images(IMAGE_ROOT, ".jpg"))
    encoded = read_file(next(usable_paths))
    decode_jpeg(encoded, device=cuda_device)
Example #26
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    # Skip CMYK files: the CUDA decoding test in this file xfails on them
    # with "Decoding a CMYK jpeg isn't supported". Taking whichever JPEG
    # comes first would make this test depend on directory ordering.
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg")
                if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)
Example #27
def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))