Exemplo n.º 1
0
    def test_imagenet(self, mock_verify):
        """Run the generic classification checks on both ImageNet splits.

        The ``'train'`` split and then the ``'val'`` split are each constructed
        from the root yielded by ``imagenet_root()`` and passed to
        ``generic_classification_dataset_test``.

        ``mock_verify`` is not used in the body; presumably it is injected by a
        ``mock.patch`` decorator outside this view — TODO confirm.
        """
        with imagenet_root() as root:
            # Same root, same checks, one split at a time (train first, then val).
            for split_name in ('train', 'val'):
                split_dataset = torchvision.datasets.ImageNet(root, split=split_name)
                self.generic_classification_dataset_test(split_dataset)
Exemplo n.º 2
0
    def test_imagenet(self, mock_download):
        """Check that each ImageNet split loads exactly one ``fakedata`` sample.

        For both the ``'train'`` and ``'val'`` splits, the dataset is built with
        ``download=True`` from the root yielded by ``imagenet_root()``. Each split
        must contain a single ``(PIL.Image.Image, int)`` pair whose target equals
        the index of the ``'fakedata'`` class.

        ``mock_download`` is not used in the body; presumably it is injected by a
        ``mock.patch`` decorator outside this view so no real download happens —
        TODO confirm.
        """
        with imagenet_root() as root:
            for split in ('train', 'val'):
                # subTest labels any failure with the split that produced it.
                with self.subTest(split=split):
                    dataset = torchvision.datasets.ImageNet(root,
                                                            split=split,
                                                            download=True)
                    self.assertEqual(len(dataset), 1)
                    img, target = dataset[0]
                    # assertIsInstance reports the actual type on failure,
                    # unlike assertTrue(isinstance(...)).
                    self.assertIsInstance(img, PIL.Image.Image)
                    self.assertIsInstance(target, int)
                    self.assertEqual(dataset.class_to_idx['fakedata'], target)