def test_places365(self):
    """Run the generic classification checks for every split/size combination.

    Each ``(split, small)`` combination runs inside its own ``subTest`` so
    that a failure in one combination is reported individually and does not
    prevent the remaining combinations from being exercised.
    """
    splits = ("train-standard", "train-challenge", "val")
    for split, small in itertools.product(splits, (False, True)):
        with self.subTest(split=split, small=small):
            # places365_root yields a (root, data) pair; data["imgs"] is used
            # below as the reference for the expected number of images.
            with places365_root(split=split, small=small) as (root, data):
                dataset = torchvision.datasets.Places365(
                    root, split=split, small=small, download=True
                )
                self.generic_classification_dataset_test(
                    dataset, num_images=len(data["imgs"])
                )
def test_places365_devkit_no_download(self):
    """Expect ``RuntimeError`` when the dataset is built with ``download=False``.

    The check is repeated for each split, each in its own ``subTest``.
    """
    splits = ("train-standard", "train-challenge", "val")
    for split in splits:
        # The second element yielded by places365_root is not needed here.
        with self.subTest(split=split), places365_root(split=split) as (root, _):
            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self):
    """Check that every image path listed by the dataset exists on disk.

    Runs once per ``(split, small)`` combination, each in its own
    ``subTest``. Missing paths are collected and reported in the failure
    message instead of relying on a bare ``assert``, which is removed under
    ``python -O`` and would not say which file is absent.
    """
    splits = ("train-standard", "train-challenge", "val")
    for split, small in itertools.product(splits, (False, True)):
        with self.subTest(split=split, small=small):
            with places365_root(split=split, small=small) as (root, _):
                dataset = torchvision.datasets.Places365(
                    root, split=split, small=small, download=True
                )
                # item[0] is the entry of each dataset.imgs record that is
                # checked as a filesystem path.
                missing = [
                    item[0] for item in dataset.imgs if not os.path.exists(item[0])
                ]
                self.assertEqual(
                    missing, [], f"Image files missing after download: {missing}"
                )
def places365():
    """Build the Places365 download configs from logged download attempts.

    Instantiates ``datasets.Places365`` with ``download=True`` for every
    split/size combination while ``log_download_attempts(patch=False)`` is
    active, then hands the logged value to ``make_download_configs``.
    """
    splits = ("train-standard", "train-challenge", "val")
    with log_download_attempts(patch=False) as attempts:
        for split in splits:
            for small in (False, True):
                # Only the root is needed; the dataset object itself is discarded.
                with places365_root(split=split, small=small) as (root, _):
                    datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(attempts, "Places365")
def test_places365_images_download_preexisting(self):
    """Expect ``RuntimeError`` when the images directory already exists.

    The ``data_large_standard`` directory is created under the root before
    the dataset is constructed with ``download=True``.
    """
    dataset_kwargs = {"split": "train-standard", "small": False}

    with places365_root(**dataset_kwargs) as (root, _):
        preexisting_dir = os.path.join(root, "data_large_standard")
        os.mkdir(preexisting_dir)

        with self.assertRaises(RuntimeError):
            torchvision.datasets.Places365(root, download=True, **dataset_kwargs)
def collect_urls_and_md5s(self):
    """Return what ``self.log_download_attempts`` logged for Places365.

    ``datasets.Places365`` is instantiated with ``download=True`` for every
    split/size combination while the logging context is active. The value
    bound by that context is returned unchanged (presumably url/md5 pairs,
    judging by the name -- confirm against ``log_download_attempts``).
    """
    splits = ("train-standard", "train-challenge", "val")
    with self.log_download_attempts(patch=False) as logged:
        for split in splits:
            for small in (False, True):
                with places365_root(split=split, small=small) as (root, _):
                    datasets.Places365(root, split=split, small=small, download=True)

    return logged
def places365():
    """Return one ``DownloadConfig`` per logged Places365 download attempt.

    ``datasets.Places365`` is instantiated with ``download=True`` for every
    split/size combination while ``log_download_attempts(patch=False)`` is
    active. Each logged ``(url, md5)`` pair becomes a ``DownloadConfig``
    whose id is ``"Places365, <url>"``.
    """
    splits = ("train-standard", "train-challenge", "val")
    with log_download_attempts(patch=False) as attempts:
        for split in splits:
            for small in (False, True):
                with places365_root(split=split, small=small) as (root, _):
                    datasets.Places365(root, split=split, small=small, download=True)

    configs = [
        DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in attempts
    ]
    return configs
def test_places365_downloadable(self, download_url):
    """Check that every URL passed to ``download_url`` answers a HEAD request.

    Args:
        download_url: Object whose ``call_args_list`` records the download
            calls made while the datasets are constructed (presumably a mock
            injected by a patch decorator outside this view -- confirm).
    """
    splits = ("train-standard", "train-challenge", "val")
    for split, small in itertools.product(splits, (False, True)):
        with places365_root(split=split, small=small) as (root, _):
            torchvision.datasets.Places365(root, split=split, small=small, download=True)

    # The first positional argument of each recorded call is the URL; a set
    # removes duplicates across the split/size combinations.
    urls = {call_args[0][0] for call_args in download_url.call_args_list}

    # Sorted so the subtests run in a deterministic order between runs.
    for url in sorted(urls):
        with self.subTest(url=url):
            # Use the response as a context manager so the connection is
            # closed even when the assertion fails.
            with urlopen(Request(url, method="HEAD")) as response:
                self.assertEqual(
                    response.code,
                    200,
                    f"Server returned status code {response.code} for {url}.",
                )
def test_places365_devkit_download(self):
    """Compare the dataset's devkit-derived attributes with the reference data.

    For each split, ``classes``, ``class_to_idx`` and ``imgs`` of the
    constructed dataset are checked against the matching entries of the
    ``data`` mapping yielded by ``places365_root``. Each attribute gets its
    own nested ``subTest``.
    """
    splits = ("train-standard", "train-challenge", "val")
    for split in splits:
        with self.subTest(split=split), places365_root(split=split) as (root, expected):
            actual = torchvision.datasets.Places365(root, split=split, download=True)

            with self.subTest("classes"):
                self.assertSequenceEqual(actual.classes, expected["classes"])

            with self.subTest("class_to_idx"):
                self.assertDictEqual(actual.class_to_idx, expected["class_to_idx"])

            with self.subTest("imgs"):
                self.assertSequenceEqual(actual.imgs, expected["imgs"])
def test_places365_transforms(self):
    """Check that ``transform`` and ``target_transform`` results are returned.

    Both transforms ignore their input and return fixed sentinel strings;
    the first dataset item must consist of exactly those sentinels.
    """
    expected_image, expected_target = "image", "target"

    def transform(image):
        # The input is ignored on purpose; only the sentinel matters.
        return expected_image

    def target_transform(target):
        return expected_target

    with places365_root() as (root, _):
        dataset = torchvision.datasets.Places365(
            root,
            transform=transform,
            target_transform=target_transform,
            download=True,
        )
        first_item = dataset[0]
        actual_image, actual_target = first_item

        self.assertEqual(actual_image, expected_image)
        self.assertEqual(actual_target, expected_target)
def test_places365_repr_smoke(self):
    """Smoke test: ``repr`` of a Places365 dataset yields a ``str``.

    The root is prepared with ``extract_images=False``.
    """
    with places365_root(extract_images=False) as (root, _):
        dataset = torchvision.datasets.Places365(root, download=True)
        representation = repr(dataset)
        self.assertIsInstance(representation, str)