示例#1
0
    def test_places365(self):
        """Run the generic classification checks for every split / image-size combination.

        Each ``(split, small)`` pair runs inside its own ``subTest`` so that a failure
        reports the offending pair and does not prevent the remaining pairs from running.
        """
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    # data["imgs"] holds the fixture's expected image entries, so its length is the
                    # number of images the dataset is expected to expose.
                    self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
示例#2
0
    def test_places365_devkit_no_download(self):
        """Constructing Places365 with ``download=False`` must raise ``RuntimeError`` for each split below."""
        splits = ("train-standard", "train-challenge", "val")
        for current_split in splits:
            with self.subTest(split=current_split), places365_root(split=current_split) as (root, _):
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.Places365(root, split=current_split, download=False)
示例#3
0
    def test_places365_images_download(self):
        """Check that after ``download=True`` every path listed in ``dataset.imgs`` exists on disk.

        Runs once per split / image-size combination, each inside its own ``subTest``.
        """
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, _ = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # Collect the missing paths rather than using a bare ``assert all(...)``:
                    # the failure then names the absent files, and the check is not stripped
                    # when Python runs with ``-O``.
                    missing = [item[0] for item in dataset.imgs if not os.path.exists(item[0])]
                    self.assertEqual(missing, [], msg="image files missing after download")
示例#4
0
def places365():
    """Build the download configs for every Places365 split / image-size variant.

    Each variant is instantiated with ``download=True`` inside
    ``log_download_attempts(patch=False)``. The ``urls_and_md5s`` it yields are then
    passed to ``make_download_configs`` under the name ``"Places365"``.
    """
    splits = ("train-standard", "train-challenge", "val")
    sizes = (False, True)
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(splits, sizes):
            # Named ``fixture`` so the local does not shadow this function's own name.
            with places365_root(split=split, small=small) as fixture:
                root, _ = fixture
                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, "Places365")
示例#5
0
    def test_places365_images_download_preexisting(self):
        """Downloading images when the images directory already exists must raise ``RuntimeError``."""
        split, small = "train-standard", False
        # Presumably the directory Places365 uses for split="train-standard", small=False — confirm.
        images_dir = "data_large_standard"

        with places365_root(split=split, small=small) as (root, _):
            # Create the images directory up front so it already exists when downloading.
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def collect_urls_and_md5s(self):
        """Instantiate every Places365 split / size variant with ``download=True``.

        Returns the ``urls_and_md5s`` yielded by ``self.log_download_attempts(patch=False)``
        around those instantiations.
        """
        variants = itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        with self.log_download_attempts(patch=False) as urls_and_md5s:
            for split, small in variants:
                with places365_root(split=split, small=small) as fixture:
                    root, _ = fixture
                    datasets.Places365(root, split=split, small=small, download=True)

        return urls_and_md5s
示例#7
0
def places365():
    """Return one ``DownloadConfig`` per ``(url, md5)`` pair collected while building Places365.

    Every split is instantiated with both ``small=False`` and ``small=True`` inside
    ``log_download_attempts(patch=False)``. Each collected pair becomes a config whose
    id is ``"Places365, <url>"``.
    """
    with log_download_attempts(patch=False) as urls_and_md5s:
        # Split is the outer loop and small the inner one, so variants are visited in a fixed order.
        for split in ("train-standard", "train-challenge", "val"):
            for small in (False, True):
                with places365_root(split=split, small=small) as fixture:
                    root, _ = fixture
                    datasets.Places365(root, split=split, small=small, download=True)

    configs = [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
    return configs
示例#8
0
    def test_places365_downloadable(self, download_url):
        """Check that every URL Places365 passes to ``download_url`` answers a HEAD request with 200.

        ``download_url`` is a mock (its ``call_args_list`` is read below). The dataset is
        instantiated for every split / image-size combination so that each download URL is
        recorded. Each unique URL is then probed in its own ``subTest``.
        """
        from urllib.error import HTTPError

        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, _ = places365

                torchvision.datasets.Places365(root, split=split, small=small, download=True)

        # The URL is the first positional argument of every recorded call.
        urls = {call_args[0][0] for call_args in download_url.call_args_list}
        for url in urls:
            with self.subTest(url=url):
                try:
                    # The context manager releases the connection even if the assertion fails.
                    with urlopen(Request(url, method="HEAD")) as response:
                        code = response.code
                except HTTPError as error:
                    # urlopen raises for 4xx/5xx, so take the status from the error. The assertion
                    # below then reports it, instead of the test erroring without the URL's status.
                    code = error.code
                assert code == 200, f"Server returned status code {code} for {url}."
示例#9
0
    def test_places365_devkit_download(self):
        """After ``download=True``, ``classes``, ``class_to_idx`` and ``imgs`` must match the fixture data."""
        # (attribute name, assertion used to compare dataset value against fixture value)
        expectations = (
            ("classes", self.assertSequenceEqual),
            ("class_to_idx", self.assertDictEqual),
            ("imgs", self.assertSequenceEqual),
        )
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split), places365_root(split=split) as (root, data):
                dataset = torchvision.datasets.Places365(root, split=split, download=True)

                for attribute, assert_equal in expectations:
                    with self.subTest(attribute):
                        assert_equal(getattr(dataset, attribute), data[attribute])
示例#10
0
    def test_places365_transforms(self):
        """``dataset[0]`` must return exactly what ``transform`` / ``target_transform`` produce."""
        expected_image, expected_target = "image", "target"

        with places365_root() as (root, _):
            # Constant transforms make the expected sample independent of the fixture contents.
            dataset = torchvision.datasets.Places365(
                root,
                transform=lambda image: expected_image,
                target_transform=lambda target: expected_target,
                download=True,
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)
示例#11
0
    def test_places365_repr_smoke(self):
        """Smoke test: ``repr`` of a Places365 dataset must be a ``str``."""
        with places365_root(extract_images=False) as (root, _):
            representation = repr(torchvision.datasets.Places365(root, download=True))
            self.assertIsInstance(representation, str)