def test_get_imagenet_images(self):
    """Load ImageNet images at several sizes and check shape and distinctness.

    For each requested side length, the loader is asked for a fixed number
    of images from ``self.test_dir``. The returned array must have shape
    ``(count, side, side, 3)``, and the first two images must have
    different mean values.
    """
    self.create_imagenet_images()
    expected_count = 1000
    for side in (32, 128):
        batch = image_loader.get_imagenet_images(
            expected_count, size=side, root=self.test_dir
        )
        assert batch.shape == (expected_count, side, side, 3)
        # Differing means serve as a cheap check that the first two images
        # are not the same picture.
        first_mean = np.mean(batch[0])
        second_mean = np.mean(batch[1])
        assert first_mean != second_mean
def test_imagenet_small_sample_error(self):
    """Requesting a single ImageNet image must raise ``ValueError``."""
    requested_count = 1
    with pytest.raises(ValueError):
        image_loader.get_imagenet_images(requested_count)