def test_decode_png(self):
    """Check ``decode_png`` against PIL for every PNG returned by ``get_images``.

    Each file's raw bytes are loaded with ``torch.from_file`` and decoded with
    ``decode_png``; the result must equal the array PIL produces for the same
    file. The test also checks that an empty tensor and a buffer of non-PNG
    bytes both decode to an empty result.
    """
    for img_path in get_images(IMAGE_DIR, "png"):
        # Use a context manager so PIL releases the file handle; a bare
        # Image.open(...) keeps the file open after np.array() reads it.
        with Image.open(img_path) as img:
            img_pil = torch.from_numpy(np.array(img))
        size = os.path.getsize(img_path)
        img_lpng = decode_png(
            torch.from_file(img_path, dtype=torch.uint8, size=size))
        # unittest's assertEqual evaluates bool(a == b), which raises for
        # tensors with more than one element; compare with Tensor.equal.
        self.assertTrue(img_lpng.equal(img_pil), msg=img_path)

    # Degenerate inputs are expected to yield an empty tensor. Check emptiness
    # directly: bool() of an empty comparison tensor is also ambiguous.
    self.assertEqual(decode_png(torch.empty()).numel(), 0)
    self.assertEqual(decode_png(torch.randint(3, 5, (300, ))).numel(), 0)
def test_decode_png(self):
    """Check ``decode_png`` against PIL and check its error handling.

    Every PNG from ``get_images(IMAGE_DIR, ".png")`` is read as raw bytes via
    ``torch.from_file`` and decoded; the result must equal PIL's decoding of
    the same file. A 0-dim tensor must raise ``ValueError``, and a buffer of
    non-PNG bytes must raise ``RuntimeError``.
    """
    for img_path in get_images(IMAGE_DIR, ".png"):
        # Close the PIL image explicitly; otherwise the file handle stays open
        # after np.array() has read the pixels.
        with Image.open(img_path) as img:
            img_pil = torch.from_numpy(np.array(img))
        size = os.path.getsize(img_path)
        img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
        self.assertTrue(img_lpng.equal(img_pil), msg=img_path)

    # A 0-dim tensor carries no bytes to decode.
    with self.assertRaises(ValueError):
        decode_png(torch.empty((), dtype=torch.uint8))
    # 300 bytes drawn from {3, 4} cannot start with the PNG signature byte 0x89.
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_png(self):
    """Check ``decode_png`` against PIL and check its error handling.

    Every PNG from ``get_images(IMAGE_DIR, ".png")`` is read with
    ``read_file`` and decoded; the result must equal PIL's pixels after
    ``permute(2, 0, 1)`` (PIL/numpy order is H, W, C, so the decoded tensor is
    expected in C, H, W order). Degenerate and non-PNG inputs must raise
    ``RuntimeError``.
    """
    for img_path in get_images(IMAGE_DIR, ".png"):
        # Close the PIL image explicitly so the file handle is released.
        with Image.open(img_path) as img:
            img_pil = torch.from_numpy(np.array(img))
        img_pil = img_pil.permute(2, 0, 1)
        data = read_file(img_path)
        img_lpng = decode_png(data)
        self.assertTrue(img_lpng.equal(img_pil), msg=img_path)

    # A 0-dim tensor carries no bytes to decode.
    with self.assertRaises(RuntimeError):
        decode_png(torch.empty((), dtype=torch.uint8))
    # 300 bytes drawn from {3, 4} cannot start with the PNG signature byte 0x89.
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_png(self):
    """Check ``decode_image`` on PNGs for several channel counts against PIL.

    For every PNG under ``FAKEDATA_DIR``, the file is decoded with each
    ``channels`` value and compared to PIL's output after converting to the
    matching PIL mode. When no conversion is requested the match must be
    exact; when a mode conversion is involved, a difference of 1 per element
    is tolerated. Degenerate and non-PNG inputs to ``decode_png`` must raise
    ``RuntimeError``.
    """
    # (PIL mode to convert to, ``channels`` argument for decode_image).
    # None means "no PIL conversion"; presumably channels=0 likewise means
    # "use the file's own channel layout" -- TODO confirm against decode_image.
    conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
    for img_path in get_images(FAKEDATA_DIR, ".png"):
        # The encoded bytes do not depend on the requested mode, so read them
        # once per file instead of once per (file, mode) pair.
        data = read_file(img_path)
        for pil_mode, channels in conversion:
            with Image.open(img_path) as img:
                converted = img.convert(pil_mode) if pil_mode is not None else img
                img_pil = torch.from_numpy(np.array(converted))
            img_pil = normalize_dimensions(img_pil)

            img_lpng = decode_image(data, channels=channels)

            # Previously this tested ``conversion is None``. ``conversion`` is
            # the list above and is never None, so the tolerance was always 1.
            # The exact-match case is the one with no PIL conversion.
            tol = 0 if pil_mode is None else 1
            # Compare in float: allclose on uint8 computes |a - b| in uint8,
            # which wraps around (3 - 4 -> 255), so an off-by-one in one
            # direction would not be tolerated.
            self.assertTrue(
                img_lpng.float().allclose(img_pil.float(), atol=tol),
                msg=f"{img_path} (mode={pil_mode}, channels={channels})",
            )

    # A 0-dim tensor carries no bytes to decode.
    with self.assertRaises(RuntimeError):
        decode_png(torch.empty((), dtype=torch.uint8))
    # 300 bytes drawn from {3, 4} cannot start with the PNG signature byte 0x89.
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_png_errors():
    """``decode_png`` must reject unusable input with a descriptive RuntimeError."""
    # A 0-dim tensor is not a non-empty 1-D byte buffer.
    scalar_input = torch.empty((), dtype=torch.uint8)
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(scalar_input)

    # 300 bytes drawn from {3, 4} (randint's upper bound is exclusive) can
    # never begin with the PNG signature byte 0x89.
    garbage_bytes = torch.randint(3, 5, (300, ), dtype=torch.uint8)
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(garbage_bytes)