def __init__(self, root, download=False, **kwargs):
    self.root = root
    self.download = download
    self.CALTECH256_DRIVE_ID = '1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK'
    self.CALTECH256_ZIP_FILE = '256_ObjectCategories.tar'
    self.CALTECH256_ZIP_PATH = os.path.join(self.root, self.CALTECH256_ZIP_FILE)
    self.CALTECH256_FOLDER = os.path.join(self.root, '256_ObjectCategories')

    if not os.path.exists(self.CALTECH256_FOLDER):
        # Extract Caltech256 zip file
        if not os.path.exists(self.CALTECH256_ZIP_PATH):
            if not self.download:
                raise FileNotFoundError("Dataset files not found")
            print("Downloading {} file...".format(self.CALTECH256_ZIP_FILE))
            download_large_file_from_drive(self.CALTECH256_DRIVE_ID,
                                           self.CALTECH256_ZIP_PATH)
            print("Done!")
        print("Extracting {} file...".format(self.CALTECH256_ZIP_FILE))
        extract_archive(self.CALTECH256_ZIP_PATH, self.root)
        print("Done!")

    super(Caltech256, self).__init__(self.CALTECH256_FOLDER, **kwargs)
    self.samples_per_class = {
        k: self.targets.count(k) for k in self.class_to_idx.values()
    }
def test_extract_tar(self):
    def create_archive(root, ext, mode, content="this is the content"):
        src = os.path.join(root, "src.txt")
        dst = os.path.join(root, "dst.txt")
        archive = os.path.join(root, f"archive{ext}")

        with open(src, "w") as fh:
            fh.write(content)

        with tarfile.open(archive, mode=mode) as fh:
            fh.add(src, arcname=os.path.basename(dst))

        return archive, dst, content

    for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
        with get_tmp_dir() as temp_dir:
            archive, file, content = create_archive(temp_dir, ext, mode)

            utils.extract_archive(archive, temp_dir)

            self.assertTrue(os.path.exists(file))

            with open(file, "r") as fh:
                self.assertEqual(fh.read(), content)
def _download(self):
    print('Downloading...')
    for url, filename in self.file_list.values():
        download_url(url, root=self.root, filename=filename)

    print('Extracting...')
    archive = os.path.join(self.root, self.file_list['imgs'][1])
    extract_archive(archive)
def _download_arrange_data(
    self,
    download_url_address,
    put_data_dir,
    extract_to_dir=None,
    obtained_file_name=None,
):
    """ Download the raw data and arrange the data """
    # Extract to the same dir as the download dir
    if extract_to_dir is None:
        extract_to_dir = put_data_dir

    download_file_name = os.path.basename(download_url_address)
    download_file_path = os.path.join(put_data_dir, download_file_name)
    download_extracted_file_name = download_file_name.split(".")[0]
    download_extracted_dir_path = os.path.join(extract_to_dir,
                                               download_extracted_file_name)

    # Download the raw data if necessary
    if not self._exist_judgement(download_file_path):
        logging.info("Downloading the %s data.....", download_file_name)
        download_url(url=download_url_address,
                     root=put_data_dir,
                     filename=obtained_file_name)

    # Extract the data to the specific dir
    if ".zip" in download_file_name or ".tar.gz" in download_file_name:
        if not self._exist_judgement(download_extracted_dir_path):
            logging.info("Extracting data to %s dir.....", extract_to_dir)
            extract_archive(from_path=download_file_path,
                            to_path=extract_to_dir,
                            remove_finished=False)

    return download_extracted_file_name
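# Hedged usage sketch for _download_arrange_data above, kept as a comment
# because the method depends on its surrounding class (e.g. _exist_judgement).
# The URL and directory below are illustrative placeholders, not from the
# original code.
#
#     extracted_name = self._download_arrange_data(
#         download_url_address="https://example.com/datasets/flowers.tar.gz",
#         put_data_dir="./data/flowers")
#     # downloads flowers.tar.gz into ./data/flowers (if missing), extracts it
#     # there, and returns extracted_name == "flowers"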
def __init__(self, root, split='train', transform=None, target_transform=None,
             transforms=None):
    super(Deepfashion, self).__init__(root, transforms, transform, target_transform)
    self.images_dir = os.path.join(self.root, 'img')
    self.targets_dir = os.path.join(self.root, 'lbl')
    self.split = split
    self.images = []
    self.targets = []

    valid_modes = ("train", "test", "val")
    msg = ("Unknown value '{}' for argument split. "
           "Valid values are {{{}}}.")
    msg = msg.format(split, iterable_to_str(valid_modes))
    verify_str_arg(split, "split", valid_modes, msg)

    if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
        image_dir_zip = os.path.join(self.root, 'img.zip')
        target_dir_zip = os.path.join(self.root, 'lbl.zip')

        if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
            extract_archive(from_path=image_dir_zip, to_path=self.root)
            extract_archive(from_path=target_dir_zip, to_path=self.root)
        else:
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required'
                               ' folders for the specified "split" are inside the "root" directory')

    data_list = pd.read_csv(os.path.join(self.root, 'list_eval_partition.txt'),
                            sep='\t', skiprows=1)
    data_list = data_list[data_list['evaluation_status'] == self.split]
    for image_path in data_list['image_name']:
        target_path = 'lbl/' + '/'.join(image_path.split('/')[1:])
        self.images.append(image_path)
        self.targets.append(target_path)
def test_extract_tar(self, extension, mode):
    def create_archive(root, extension, mode, content="this is the content"):
        src = os.path.join(root, "src.txt")
        dst = os.path.join(root, "dst.txt")
        archive = os.path.join(root, f"archive{extension}")

        with open(src, "w") as fh:
            fh.write(content)

        with tarfile.open(archive, mode=mode) as fh:
            fh.add(src, arcname=os.path.basename(dst))

        return archive, dst, content

    with get_tmp_dir() as temp_dir:
        archive, file, content = create_archive(temp_dir, extension, mode)

        utils.extract_archive(archive, temp_dir)

        assert os.path.exists(file)

        with open(file, "r") as fh:
            assert fh.read() == content
def initialize_folder(self, **kwargs):
    print('initialize folder')
    mode_list = ['train', 'valid'] if self.valid_set and 'valid' in self.url.keys() else ['train']
    for mode in mode_list:
        zip_path = os.path.join(self.folder_path, f'{self.name}_{mode}_store.zip')
        if os.path.isfile(zip_path):
            print('{yellow}Uncompress file{reset}: '.format(**ansi), zip_path)
            extract_archive(from_path=zip_path, to_path=self.folder_path)
            print('{green}Uncompress finished{reset}: '.format(**ansi), f'{zip_path}')
            print()
            continue
        tar_path = os.path.join(self.folder_path, f'{self.name}_{mode}_store.zip')
        self.download_and_extract_archive(mode=mode)
        os.rename(os.path.join(self.folder_path, self.org_folder_name[mode]),
                  os.path.join(self.folder_path, mode))
        try:
            dirname = os.path.dirname(self.org_folder_name[mode])
            if dirname:
                shutil.rmtree(os.path.join(self.folder_path, dirname))
        except FileNotFoundError:
            pass
def __init__(self, root, download=False, **kwargs):
    self.root = root
    self.download = download
    self.ETH80_URL = 'https://github.com/Kai-Xuan/ETH-80/archive/master.zip'
    self.ETH80_ZIP_FILE = 'ETH-80-master.zip'
    self.ETH80_ZIP_PATH = os.path.join(self.root, self.ETH80_ZIP_FILE)
    self.ETH80_FOLDER = os.path.join(self.root, 'ETH-80-master')
    self.ETH80_FOLDER_ORGANIZED = os.path.join(self.root, 'ETH-80-organized')

    if not os.path.exists(self.ETH80_FOLDER):
        # Extract ETH80 zip file
        if not os.path.exists(self.ETH80_ZIP_PATH):
            if not self.download:
                raise FileNotFoundError("Dataset files not found")
            download_url(self.ETH80_URL, self.root, self.ETH80_ZIP_FILE)
        print("Extracting {} file...".format(self.ETH80_ZIP_FILE))
        extract_archive(self.ETH80_ZIP_PATH, self.root)
        print("Done!")

    # Organize images in a directory structure consistent with that required by
    # ImageFolder. Also, ignore "map" images.
    if not os.path.exists(self.ETH80_FOLDER_ORGANIZED):
        print("Organizing dataset images...")
        os.makedirs(self.ETH80_FOLDER_ORGANIZED)
        for i in range(1, 9):
            dest = os.path.join(self.ETH80_FOLDER_ORGANIZED, str(i))
            os.makedirs(dest)
            for j in range(1, 11):
                src = os.path.join(self.ETH80_FOLDER, str(i), str(j))
                for file in os.listdir(src):
                    if file != 'maps':
                        shutil.copy(os.path.join(src, file), os.path.join(dest, file))
        print("Done!")

    super(ETH80, self).__init__(self.ETH80_FOLDER_ORGANIZED, **kwargs)
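# Hedged usage sketch: since ETH80 above organizes images into an
# ImageFolder-style layout and forwards **kwargs to its parent constructor,
# it can presumably be instantiated like any ImageFolder dataset. The root
# path here is an illustrative assumption.
from torchvision import transforms

eth80 = ETH80(root="./data/eth80", download=True,
              transform=transforms.ToTensor())
image, label = eth80[0]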
def extract_downloaded_files(download_root: str, extract_root: str):
    '''
    :param download_root: Root directory path which saves downloaded dataset files
    :type download_root: str
    :param extract_root: Root directory path which saves extracted files from downloaded files
    :type extract_root: str
    :return: None

    This function defines how to extract downloaded files.
    '''
    temp_ext_dir = os.path.join(download_root, 'temp_ext')
    os.mkdir(temp_ext_dir)
    print(f'Mkdir [{temp_ext_dir}].')
    extract_archive(os.path.join(download_root, 'navgesture-sit.zip'), temp_ext_dir)
    with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
        for zip_file in os.listdir(temp_ext_dir):
            if os.path.splitext(zip_file)[1] == '.zip':
                zip_file = os.path.join(temp_ext_dir, zip_file)
                print(f'Extract [{zip_file}] to [{extract_root}].')
                tpe.submit(extract_archive, zip_file, extract_root)
    shutil.rmtree(temp_ext_dir)
    print(f'Rmtree [{temp_ext_dir}].')
def __init__(self, root, mode, transform=None, target_transform=None):
    super(NABirds, self).__init__(root, transform=transform,
                                  target_transform=target_transform)
    dataset_path = root
    if not os.path.isdir(dataset_path):
        if not check_integrity(os.path.join(root, self.filename), self.md5):
            raise RuntimeError('Dataset not found or corrupted.')
        extract_archive(os.path.join(root, self.filename))

    self.root = os.path.expanduser(root)
    self.loader = default_loader
    # self.transform = transform

    # Load in the class data
    self.class_names = load_class_names(root)
    self.class_hierarchy = load_hierarchy(root)

    if mode == 'train':
        self.data = NABirds.TRAIN_DATA
        self.targets = NABirds.TRAIN_TARGETS
    elif mode == 'query':
        self.data = NABirds.QUERY_DATA
        self.targets = NABirds.QUERY_TARGETS
    elif mode == 'retrieval':
        self.data = NABirds.RETRIEVAL_DATA
        self.targets = NABirds.RETRIEVAL_TARGETS
    else:
        raise ValueError("Invalid argument: mode, can't load dataset!")
def test_paths_benchmark(self):
    download_url(
        'https://storage.googleapis.com/mledu-datasets/'
        'cats_and_dogs_filtered.zip',
        './data',
        'cats_and_dogs_filtered.zip')

    archive_name = os.path.join('./data', 'cats_and_dogs_filtered.zip')
    extract_archive(archive_name, to_path='./data/')

    dirpath = "./data/cats_and_dogs_filtered/train"

    train_experiences = []
    for rel_dir, label in zip(["cats", "dogs"], [0, 1]):
        filenames_list = os.listdir(os.path.join(dirpath, rel_dir))

        experience_paths = []
        for name in filenames_list:
            instance_tuple = (os.path.join(dirpath, rel_dir, name), label)
            experience_paths.append(instance_tuple)
        train_experiences.append(experience_paths)

    generic_scenario = paths_benchmark(
        train_experiences,
        [train_experiences[0]],  # Single test set
        task_labels=[0, 0],
        complete_test_set_only=True,
        train_transform=ToTensor(),
        eval_transform=ToTensor())

    self.assertEqual(2, len(generic_scenario.train_stream))
    self.assertEqual(1, len(generic_scenario.test_stream))
def download(self):
    path = os.path.join(self.root)
    link = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    os.makedirs(path, exist_ok=True)
    download_url(link, path)
    extract_archive(os.path.join(path, 'tiny-imagenet-200.zip'))
def test_filelist_benchmark(self):
    download_url(
        'https://storage.googleapis.com/mledu-datasets/'
        'cats_and_dogs_filtered.zip',
        './data',
        'cats_and_dogs_filtered.zip')

    archive_name = os.path.join('./data', 'cats_and_dogs_filtered.zip')
    extract_archive(archive_name, to_path='./data/')

    dirpath = "./data/cats_and_dogs_filtered/train"

    for filelist, dir, label in zip(
            ["train_filelist_00.txt", "train_filelist_01.txt"],
            ["cats", "dogs"], [0, 1]):
        # First, obtain the list of files
        filenames_list = os.listdir(os.path.join(dirpath, dir))
        with open(filelist, "w") as wf:
            for name in filenames_list:
                wf.write("{} {}\n".format(os.path.join(dir, name), label))

    generic_scenario = filelist_benchmark(
        dirpath,
        ["train_filelist_00.txt", "train_filelist_01.txt"],
        ["train_filelist_00.txt"],
        task_labels=[0, 0],
        complete_test_set_only=True,
        train_transform=ToTensor(),
        eval_transform=ToTensor())

    self.assertEqual(2, len(generic_scenario.train_stream))
    self.assertEqual(1, len(generic_scenario.test_stream))
def verify_dataset(self, split):
    if not os.path.isdir(os.path.join(self.root, self.images_dir)):
        if self.image_mode == 'gtFine':
            image_dir_zip = os.path.join(
                self.root, '{}_trainvaltest.zip'.format(self.image_mode))
        elif self.image_mode == 'gtCoarse':
            image_dir_zip = os.path.join(self.root,
                                         '{}.zip'.format(self.image_mode))
        else:
            if split == 'train_extra':
                if self.image_mode == 'leftImg8bit':
                    image_dir_zip = os.path.join(
                        self.root, '{}_trainextra.zip'.format(self.image_mode))
                else:
                    image_dir_zip = os.path.join(
                        self.root,
                        '{}_trainextra_{}.zip'.format(*self.image_mode.split('_')))
            else:
                if self.image_mode == 'leftImg8bit':
                    image_dir_zip = os.path.join(
                        self.root, '{}_trainvaltest.zip'.format(self.image_mode))
                else:
                    image_dir_zip = os.path.join(
                        self.root,
                        '{}_trainvaltest_{}.zip'.format(*self.image_mode.split('_')))

        if os.path.isfile(image_dir_zip):
            extract_archive(from_path=image_dir_zip, to_path=self.root)
        else:
            raise RuntimeError(
                'Dataset not found or incomplete. Please make sure all required folders for the'
                ' specified "split" and "image_mode" are inside the "root" directory')
def download(self):
    path = os.path.join(self.root, 'notMNIST')
    link = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_small.tar.gz'
    os.makedirs(path, exist_ok=True)
    download_url(link, path)
    extract_archive(os.path.join(path, 'notMNIST_small.tar.gz'))
def parse_train_archive(root, file=None, folder="train"):
    """Parse the train images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive.
            Defaults to 'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder.
            Defaults to 'train'
    """
    archive_meta = ImageNet.archive_meta["train"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
    for archive in archives:
        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=False)
def download_google_drive(self, data_dir: str, filename: str) -> None:
    """Download dataset

    Parameters
    ----------
    data_dir : str
        Path to base dataset directory
    filename : str
        Filename of google drive file being downloaded

    Returns
    -------
    None
    """
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    if not os.path.exists(self.root_dir):
        download_file_from_google_drive(file_id=self.url,
                                        root=data_dir,
                                        filename=filename)
        extract_archive(from_path=os.path.join(data_dir, filename),
                        to_path=data_dir,
                        remove_finished=True)
def _extract_archive(
    self,
    path: Union[str, Path],
    sub_directory: str = None,
    remove_archive: bool = False,
) -> Path:
    """
    Utility method that can be used to extract an archive.

    :param path: The complete path to the archive (for instance obtained by
        calling `_download_file`).
    :param sub_directory: The name of the sub directory where to extract the
        archive. Can be None, which means that the archive will be extracted
        in the root. Beware that some archives already have a root directory
        inside of them, in which case it's probably better to use None here.
        Defaults to None.
    :param remove_archive: If True, the archive will be deleted after a
        successful extraction. Defaults to False.
    :return: The path to the directory the archive was extracted to.
    """
    if sub_directory is None:
        extract_root = self.root
    else:
        extract_root = self.root / sub_directory

    extract_archive(str(path),
                    to_path=str(extract_root),
                    remove_finished=remove_archive)

    return extract_root
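# Hedged usage sketch for _extract_archive above, kept as a comment because it
# assumes `self` is a dataset downloader whose `root` is a pathlib.Path and
# whose archive was obtained earlier (e.g. via `_download_file`). The file and
# directory names are illustrative only.
#
#     archive_path = self.root / "images.tar.gz"
#     data_dir = self._extract_archive(archive_path,
#                                      sub_directory="images",
#                                      remove_archive=False)
#     # data_dir == self.root / "images"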
def prepare_train_folder(folder):
    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
def parse_val_archive(root, file=None, wnids=None, folder="val"):
    """Parse the validation images archive of the ImageNet2012 classification
    dataset and prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive.
            Defaults to 'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images.
            If None is given, the IDs are loaded from the meta file in the root
            directory
        folder (str, optional): Optional name for validation images folder.
            Defaults to 'val'
    """
    archive_meta = ARCHIVE_META["val"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])

    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file,
                    os.path.join(val_root, wnid, os.path.basename(img_file)))
def __init__(self, root, train=True, transform=None, loader=default_loader,
             download=False):
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.loader = loader
    self.train = train
    self.log = logging.getLogger("avalanche")

    if download:
        self.log.error(
            "Download is not supported for this Dataset. "
            "You need to download 'images.tgz' and 'lists.tgz' manually "
            "at: http://www.vision.caltech.edu/visipedia/CUB-200.html")

    if not os.path.exists(os.path.join(self.root, self.filename[:-4])):
        extract_archive(os.path.join(self.root, self.filename))

    if not os.path.exists(os.path.join(self.root, self.metadata[:-4])):
        extract_archive(os.path.join(self.root, self.metadata))

    if not self._check_integrity():
        raise RuntimeError('Dataset corrupted')
def __init__(self, root, train=True, transform=None, target_transform=None,
             download=None):
    super(NABirds, self).__init__(root, transform=transform,
                                  target_transform=target_transform)
    if download is True:
        msg = ("The dataset is no longer publicly accessible. You need to "
               "download the archives externally and place them in the root "
               "directory.")
        raise RuntimeError(msg)

    msg = ("The use of the download flag is deprecated, since the dataset "
           "is no longer publicly accessible.")
    warnings.warn(msg, RuntimeWarning)

    dataset_path = os.path.join(root, "nabirds")
    print(dataset_path)
    if not os.path.isdir(dataset_path):
        if not check_integrity(os.path.join(root, self.filename), self.md5):
            raise RuntimeError('Dataset not found or corrupted.')
        extract_archive(os.path.join(root, self.filename))

    self.loader = default_loader
    self.train = train

    image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                              sep=' ', names=['img_id', 'filepath'])
    image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                     sep=' ', names=['img_id', 'target'])
    image_bounding_boxes = pd.read_csv(os.path.join(dataset_path, 'bounding_boxes.txt'),
                                       sep=' ', names=['img_id', 'x', 'y', 'w', 'h'])
    # Since the raw labels are non-continuous, map them to new ones
    self.label_map = get_continuous_class_map(image_class_labels['target'])
    train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                   sep=' ', names=['img_id', 'is_training_img'])
    data = image_paths.merge(image_class_labels, on='img_id')
    data = data.merge(image_bounding_boxes, on='img_id')
    self.data = data.merge(train_test_split, on='img_id')

    # Load in the train / test split
    if self.train:
        self.data = self.data[self.data.is_training_img == 1]
    else:
        self.data = self.data[self.data.is_training_img == 0]

    # Load in the class data
    self.class_names = load_class_names(dataset_path)
    self.class_hierarchy = load_hierarchy(dataset_path)
    self.targets = None
def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
    filename = "foo"
    file = f"{filename}{extension}"

    mocked = mocker.patch("torchvision.datasets.utils._decompress")
    utils.extract_archive(file, remove_finished=remove_finished)

    mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
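# Hedged sketch of how a test like test_extract_archive_defer_to_decompress
# above is typically driven: `mocker` is the pytest-mock fixture, and the
# `extension` / `remove_finished` arguments are supplied by parametrize
# decorators. Written here as a standalone function for brevity; the extension
# list is an assumption, not taken from the original suite.
import pytest


@pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
@pytest.mark.parametrize("remove_finished", [True, False])
def test_extract_archive_defer_to_decompress_sketch(extension, remove_finished, mocker):
    ...  # body as in the test above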
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
    with self.mocked_root() as (root, data):
        if pre_extract:
            utils.extract_archive(os.path.join(root, data["archive"]))
        dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
        yield dataset, data
def _download(self):
    if not os.path.isfile(os.path.join(self.root, self.filename)):
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))
    else:
        print('Data zip already downloaded')
def download_urban(cfg: DictConfig, cwd: Path) -> None:
    urban_path = Path(cfg.urban_path)
    Path.mkdir(cwd / urban_path.name, parents=True, exist_ok=True)
    download_file_from_google_drive(file_id=cfg.urban_url,
                                    root=cwd / urban_path.parent,
                                    filename="UrbanSound8K.tar.gz")
    extract_archive(from_path=str(cwd / urban_path.parent / "UrbanSound8K.tar.gz"),
                    remove_finished=True)
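# Hedged usage sketch for download_urban above: a minimal OmegaConf/Hydra-style
# config with the two fields the function reads. The path is illustrative and
# the Google Drive id is a placeholder, not the real UrbanSound8K file id.
from pathlib import Path

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "urban_path": "data/UrbanSound8K",       # illustrative target directory
    "urban_url": "<google-drive-file-id>",   # placeholder id
})
download_urban(cfg, cwd=Path.cwd())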
def __init__(self, root, split='train', mode='fine', target_type='instance',
             transform=None, target_transform=None, transforms=None):
    super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
    self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
    self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
    self.targets_dir = os.path.join(self.root, self.mode, split)
    self.target_type = target_type
    self.split = split
    self.images = []
    self.targets = []

    verify_str_arg(mode, "mode", ("fine", "coarse"))
    if mode == "fine":
        valid_modes = ("train", "test", "val")
    else:
        valid_modes = ("train", "train_extra", "val")
    msg = ("Unknown value '{}' for argument split if mode is '{}'. "
           "Valid values are {{{}}}.")
    msg = msg.format(split, mode, iterable_to_str(valid_modes))
    verify_str_arg(split, "split", valid_modes, msg)

    if not isinstance(target_type, list):
        self.target_type = [target_type]
    [verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
     for value in self.target_type]

    if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
        if split == 'train_extra':
            image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
        else:
            image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))

        if self.mode == 'gtFine':
            target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
        elif self.mode == 'gtCoarse':
            target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))

        if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
            extract_archive(from_path=image_dir_zip, to_path=self.root)
            extract_archive(from_path=target_dir_zip, to_path=self.root)
        else:
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

    for city in os.listdir(self.images_dir):
        img_dir = os.path.join(self.images_dir, city)
        target_dir = os.path.join(self.targets_dir, city)
        for file_name in os.listdir(img_dir):
            target_types = []
            for t in self.target_type:
                target_name = file_name
                target_types.append(os.path.join(target_dir, target_name))

            self.images.append(os.path.join(img_dir, file_name))
            self.targets.append(target_types)
def download(self):
    url = "https://github.com/chiayewken/sutd-materials/releases/download/v0.1.0/bert-base-nli-mean-tokens.zip"
    path_zip = self.cache_dir / Path(url).name
    if not path_zip.exists():
        download_url(url, self.cache_dir, filename=path_zip.name)
    dir_model = self.cache_dir / "bert"
    if not dir_model.exists():
        extract_archive(str(path_zip), str(dir_model))
    return dir_model
def download_model(model_path='./model'):
    file_id = '1LxEp7_Hm9RK4oc9UxJ8nu7KUKGFsa4i0'
    tgz_md5 = '90db23115ebec6b2c94499ac1c56ee59'
    filename = 'net.tgz'
    archive = os.path.join(model_path, filename)
    if not os.path.isfile(archive):
        print("Downloading pretrained model...")
        download_file_from_google_drive(file_id, model_path, filename, tgz_md5)
        extract_archive(archive, model_path)
def parse_devkit_archive(root, file=None):
    """Parse the devkit archive of the ImageNet2012 classification dataset and
    save the meta information in a binary file.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive.
            Defaults to 'ILSVRC2012_devkit_t12.tar.gz'
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root):
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
                if num_children == 0]
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(', ')) for clss in classes]
        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root):
        file = os.path.join(devkit_root, "data",
                            "ILSVRC2012_validation_ground_truth.txt")
        with open(file, 'r') as txtfh:
            val_idcs = txtfh.readlines()
        return [int(val_idx) for val_idx in val_idcs]

    @contextmanager
    def get_tmp_dir():
        tmp_dir = tempfile.mkdtemp()
        try:
            yield tmp_dir
        finally:
            shutil.rmtree(tmp_dir)

    archive_meta = ImageNet.archive_meta["devkit"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids),
                   os.path.join(root, ImageNet.meta_file))
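# Hedged usage sketch tying together parse_devkit_archive, parse_train_archive
# and parse_val_archive from this section. The three ImageNet2012 archives are
# assumed to have been downloaded manually into `root` beforehand; the path is
# illustrative only.
root = "./data/imagenet"
parse_devkit_archive(root)   # saves the wnid/class meta file used below
parse_train_archive(root)    # unpacks the per-class tars into root/train/<wnid>/
parse_val_archive(root)      # sorts validation images into root/val/<wnid>/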