Code example #1
    def __init__(self, root, download=False, **kwargs):
        self.root = root
        self.download = download
        self.CALTECH256_DRIVE_ID = '1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK'
        self.CALTECH256_ZIP_FILE = '256_ObjectCategories.tar'
        self.CALTECH256_ZIP_PATH = os.path.join(self.root,
                                                self.CALTECH256_ZIP_FILE)
        self.CALTECH256_FOLDER = os.path.join(self.root,
                                              '256_ObjectCategories')

        if not os.path.exists(self.CALTECH256_FOLDER):
            # Extract Caltech256 zip file
            if not os.path.exists(self.CALTECH256_ZIP_PATH):
                if not self.download:
                    raise FileNotFoundError("Dataset files not found")
                print("Downloading {} file...".format(
                    self.CALTECH256_ZIP_FILE))
                download_large_file_from_drive(self.CALTECH256_DRIVE_ID,
                                               self.CALTECH256_ZIP_PATH)
                print("Done!")
            print("Extracting {} file...".format(self.CALTECH256_ZIP_FILE))
            extract_archive(self.CALTECH256_ZIP_PATH, self.root)
            print("Done!")

        super(Caltech256, self).__init__(self.CALTECH256_FOLDER, **kwargs)

        self.samples_per_class = {
            k: self.targets.count(k)
            for k in self.class_to_idx.values()
        }
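A minimal usage sketch for the class above (assuming it is named Caltech256 and subclasses torchvision.datasets.ImageFolder, so ImageFolder keyword arguments such as transform pass through **kwargs):

    from torchvision import transforms

    # Hypothetical usage: downloads and extracts the archive on the first
    # run, then reuses the extracted folder on subsequent runs.
    dataset = Caltech256(root='./data', download=True,
                         transform=transforms.ToTensor())
    print(len(dataset), 'images,', len(dataset.class_to_idx), 'classes')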
Code example #2
    def test_extract_tar(self):
        def create_archive(root, ext, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{ext}")

            with open(src, "w") as fh:
                fh.write(content)

            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        for ext, mode in zip(['.tar', '.tar.gz', '.tgz'],
                             ['w', 'w:gz', 'w:gz']):
            with get_tmp_dir() as temp_dir:
                archive, file, content = create_archive(temp_dir, ext, mode)

                utils.extract_archive(archive, temp_dir)

                self.assertTrue(os.path.exists(file))

                with open(file, "r") as fh:
                    self.assertEqual(fh.read(), content)
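The get_tmp_dir helper above comes from torchvision's test suite; outside of it, the same round trip can be sketched with the standard library alone (a minimal sketch, not the test's actual fixture):

    import os
    import tarfile
    import tempfile

    from torchvision.datasets import utils

    with tempfile.TemporaryDirectory() as temp_dir:
        # Write a source file and pack it into a gzipped tar archive.
        src = os.path.join(temp_dir, "src.txt")
        with open(src, "w") as fh:
            fh.write("this is the content")
        archive = os.path.join(temp_dir, "archive.tar.gz")
        with tarfile.open(archive, mode="w:gz") as fh:
            fh.add(src, arcname="dst.txt")

        # extract_archive infers the archive format from the file extension.
        utils.extract_archive(archive, temp_dir)
        assert os.path.exists(os.path.join(temp_dir, "dst.txt"))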
Code example #3
File: cars.py Project: zhqiu/pytorch-fgvc-dataset
 def _download(self):
     print('Downloading...')
     for url, filename in self.file_list.values():
         download_url(url, root=self.root, filename=filename)
     print('Extracting...')
     archive = os.path.join(self.root, self.file_list['imgs'][1])
     extract_archive(archive)
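Note that extract_archive is called here without to_path; torchvision then defaults the destination to the directory containing the archive, so this unpacks alongside the downloaded file in self.root. The explicit equivalent would be:

    extract_archive(archive, to_path=os.path.dirname(archive))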
Code example #4
    def _download_arrange_data(
        self,
        download_url_address,
        put_data_dir,
        extract_to_dir=None,
        obtained_file_name=None,
    ):
        """ Download the raw data and arrange the data """
        # Extract to the same dir as the download dir
        if extract_to_dir is None:
            extract_to_dir = put_data_dir

        download_file_name = os.path.basename(download_url_address)
        download_file_path = os.path.join(put_data_dir, download_file_name)

        download_extracted_file_name = download_file_name.split(".")[0]
        download_extracted_dir_path = os.path.join(
            extract_to_dir, download_extracted_file_name)
        # Download the raw data if necessary
        if not self._exist_judgement(download_file_path):
            logging.info("Downloading the %s data.....", download_file_name)
            download_url(url=download_url_address,
                         root=put_data_dir,
                         filename=obtained_file_name)

        # Extract the data to the specific dir
        if ".zip" in download_file_name or ".tar.gz" in download_file_name:
            if not self._exist_judgement(download_extracted_dir_path):
                logging.info("Extracting data to %s dir.....", extract_to_dir)
                extract_archive(from_path=download_file_path,
                                to_path=extract_to_dir,
                                remove_finished=False)

        return download_extracted_file_name
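Note that download_file_name.split(".")[0] keeps only the text before the first dot, so foo.tar.gz maps to an extraction directory named foo. For the common case where no renaming or existence checks are needed, torchvision also bundles both steps into one helper; a hedged sketch with a placeholder URL:

    from torchvision.datasets.utils import download_and_extract_archive

    # Hypothetical URL and directory; this downloads the archive into
    # download_root and extracts it there in a single call.
    download_and_extract_archive(
        url="https://example.com/data/foo.tar.gz",
        download_root="./data",
        remove_finished=False)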
Code example #5
File: deepfashion.py Project: yqGANs/GroupDNet
    def __init__(self, root, split='train', transform=None, target_transform=None, transforms=None):
        super(Deepfashion, self).__init__(root, transforms, transform, target_transform)
        self.images_dir = os.path.join(self.root, 'img')
        self.targets_dir = os.path.join(self.root, 'lbl')
        self.split = split
        self.images = []
        self.targets = []

        valid_modes = ("train", "test", "val")
        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
               "Valid values are {{{}}}.")
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            image_dir_zip = os.path.join(self.root, 'img.zip')
            target_dir_zip = os.path.join(self.root, 'lbl.zip')

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required'
                                   ' folders for the specified "split" are inside the "root" directory')

        data_list = pd.read_csv(os.path.join(self.root, 'list_eval_partition.txt'), sep='\t', skiprows=1)
        data_list = data_list[data_list['evaluation_status'] == self.split]
        for image_path in data_list['image_name']:
            target_path = 'lbl/' + '/'.join(image_path.split('/')[1:])
            self.images.append(image_path)
            self.targets.append(target_path)
Code example #6
    def test_extract_tar(self, extension, mode):
        def create_archive(root,
                           extension,
                           mode,
                           content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        with get_tmp_dir() as temp_dir:
            archive, file, content = create_archive(temp_dir, extension, mode)

            utils.extract_archive(archive, temp_dir)

            assert os.path.exists(file)

            with open(file, "r") as fh:
                assert fh.read() == content
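Unlike example #2, this version receives extension and mode as arguments; in the torchvision test suite they are supplied by pytest parametrization, presumably along these lines (an assumption, since the decorators are not shown on this page):

    @pytest.mark.parametrize(
        ("extension", "mode"),
        [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz")])
    def test_extract_tar(self, extension, mode):
        ...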
Code example #7
 def initialize_folder(self, **kwargs):
     print('initialize folder')
     mode_list = [
         'train', 'valid'
     ] if self.valid_set and 'valid' in self.url.keys() else ['train']
     for mode in mode_list:
         zip_path = os.path.join(self.folder_path,
                                 f'{self.name}_{mode}_store.zip')
         if os.path.isfile(zip_path):
             print('{yellow}Uncompress file{reset}: '.format(**ansi),
                   zip_path)
             extract_archive(from_path=zip_path, to_path=self.folder_path)
             print('{green}Uncompress finished{reset}: '.format(**ansi),
                   f'{zip_path}')
             print()
             continue
         self.download_and_extract_archive(mode=mode)
         os.rename(
             os.path.join(self.folder_path, self.org_folder_name[mode]),
             os.path.join(self.folder_path, mode))
         try:
             dirname = os.path.dirname(self.org_folder_name[mode])
             if dirname:
                 shutil.rmtree(os.path.join(self.folder_path, dirname))
         except FileNotFoundError:
             pass
Code example #8
	def __init__(self, root, download=False, **kwargs):
		self.root = root
		self.download = download
		self.ETH80_URL = 'https://github.com/Kai-Xuan/ETH-80/archive/master.zip'
		self.ETH80_ZIP_FILE = 'ETH-80-master.zip'
		self.ETH80_ZIP_PATH = os.path.join(self.root, self.ETH80_ZIP_FILE)
		self.ETH80_FOLDER = os.path.join(self.root, 'ETH-80-master')
		self.ETH80_FOLDER_ORGANIZED = os.path.join(self.root, 'ETH-80-organized')
		
		if not os.path.exists(self.ETH80_FOLDER):
			# Extract ETH80 zip file
			if not os.path.exists(self.ETH80_ZIP_PATH):
				if not self.download:
					raise FileNotFoundError("Dataset files not found")
				download_url(self.ETH80_URL, self.root, self.ETH80_ZIP_FILE)
			print("Extracting {} file...".format(self.ETH80_ZIP_FILE))
			extract_archive(self.ETH80_ZIP_PATH, self.root)
			print("Done!")
		 
		# Organize images in a directory structure consistent with that required by ImageFolder. Also, ignore "map" images.
		if not os.path.exists(self.ETH80_FOLDER_ORGANIZED):
			print("Organizing dataset images...")
			os.makedirs(self.ETH80_FOLDER_ORGANIZED)
			for i in range(1, 9):
				dest = os.path.join(self.ETH80_FOLDER_ORGANIZED, str(i))
				os.makedirs(dest)
				for j in range(1, 11):
					src = os.path.join(self.ETH80_FOLDER, str(i), str(j))
					for file in os.listdir(src):
						if file != 'maps':
							shutil.copy(os.path.join(src, file), os.path.join(dest, file))
			print("Done!")
			
		super(ETH80, self).__init__(self.ETH80_FOLDER_ORGANIZED, **kwargs)
Code example #9
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract the downloaded files.
        '''
        temp_ext_dir = os.path.join(download_root, 'temp_ext')
        os.mkdir(temp_ext_dir)
        print(f'Mkdir [{temp_ext_dir}].')
        extract_archive(os.path.join(download_root, 'navgesture-sit.zip'),
                        temp_ext_dir)
        with ThreadPoolExecutor(
                max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
            for zip_file in os.listdir(temp_ext_dir):
                if os.path.splitext(zip_file)[1] == '.zip':
                    zip_file = os.path.join(temp_ext_dir, zip_file)
                    print(f'Extract [{zip_file}] to [{extract_root}].')
                    tpe.submit(extract_archive, zip_file, extract_root)

        shutil.rmtree(temp_ext_dir)
        print(f'Rmtree [{temp_ext_dir}].')
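One caveat with this pattern: an exception raised inside a job passed to tpe.submit is stored on its Future and silently discarded here, since the futures are never inspected. A hedged variant that surfaces extraction errors (zip_files and extract_root stand in for the values computed above):

    from concurrent.futures import ThreadPoolExecutor

    futures = []
    with ThreadPoolExecutor(max_workers=4) as tpe:
        for zip_file in zip_files:
            futures.append(tpe.submit(extract_archive, zip_file, extract_root))
    for future in futures:
        future.result()  # re-raises any error from the worker thread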
Code example #10
    def __init__(self, root, mode, transform=None, target_transform=None):
        super(NABirds, self).__init__(root,
                                      transform=transform,
                                      target_transform=target_transform)

        dataset_path = root
        if not os.path.isdir(dataset_path):
            if not check_integrity(os.path.join(root, self.filename),
                                   self.md5):
                raise RuntimeError('Dataset not found or corrupted.')
            extract_archive(os.path.join(root, self.filename))
        self.root = os.path.expanduser(root)
        self.loader = default_loader
        # self.transform = transform
        # Load in the class data
        self.class_names = load_class_names(root)
        self.class_hierarchy = load_hierarchy(root)

        if mode == 'train':
            self.data = NABirds.TRAIN_DATA
            self.targets = NABirds.TRAIN_TARGETS
        elif mode == 'query':
            self.data = NABirds.QUERY_DATA
            self.targets = NABirds.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = NABirds.RETRIEVAL_DATA
            self.targets = NABirds.RETRIEVAL_TARGETS
        else:
            raise ValueError("Invalid argument: mode, can't load dataset!")
Code example #11
    def test_paths_benchmark(self):
        download_url(
            'https://storage.googleapis.com/mledu-datasets/'
            'cats_and_dogs_filtered.zip', './data',
            'cats_and_dogs_filtered.zip')
        archive_name = os.path.join('./data', 'cats_and_dogs_filtered.zip')
        extract_archive(archive_name, to_path='./data/')

        dirpath = "./data/cats_and_dogs_filtered/train"

        train_experiences = []
        for rel_dir, label in zip(["cats", "dogs"], [0, 1]):
            filenames_list = os.listdir(os.path.join(dirpath, rel_dir))

            experience_paths = []
            for name in filenames_list:
                instance_tuple = (os.path.join(dirpath, rel_dir, name), label)
                experience_paths.append(instance_tuple)
            train_experiences.append(experience_paths)

        generic_scenario = paths_benchmark(
            train_experiences,
            [train_experiences[0]],  # Single test set
            task_labels=[0, 0],
            complete_test_set_only=True,
            train_transform=ToTensor(),
            eval_transform=ToTensor())

        self.assertEqual(2, len(generic_scenario.train_stream))
        self.assertEqual(1, len(generic_scenario.test_stream))
Code example #12
    def download(self):
        path = os.path.join(self.root)
        link = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

        os.makedirs(path, exist_ok=True)
        download_url(link, path)
        extract_archive(os.path.join(path, 'tiny-imagenet-200.zip'))
Code example #13
    def test_filelist_benchmark(self):
        download_url(
            'https://storage.googleapis.com/mledu-datasets/'
            'cats_and_dogs_filtered.zip', './data',
            'cats_and_dogs_filtered.zip')
        archive_name = os.path.join('./data', 'cats_and_dogs_filtered.zip')
        extract_archive(archive_name, to_path='./data/')

        dirpath = "./data/cats_and_dogs_filtered/train"

        for filelist, rel_dir, label in zip(
            ["train_filelist_00.txt", "train_filelist_01.txt"],
            ["cats", "dogs"], [0, 1]):
            # First, obtain the list of files
            filenames_list = os.listdir(os.path.join(dirpath, rel_dir))
            with open(filelist, "w") as wf:
                for name in filenames_list:
                    wf.write("{} {}\n".format(os.path.join(rel_dir, name), label))

        generic_scenario = filelist_benchmark(
            dirpath, ["train_filelist_00.txt", "train_filelist_01.txt"],
            ["train_filelist_00.txt"],
            task_labels=[0, 0],
            complete_test_set_only=True,
            train_transform=ToTensor(),
            eval_transform=ToTensor())

        self.assertEqual(2, len(generic_scenario.train_stream))
        self.assertEqual(1, len(generic_scenario.test_stream))
Code example #14
    def verify_dataset(self, split):
        if not os.path.isdir(os.path.join(self.root, self.images_dir)):
            if self.image_mode == 'gtFine':
                image_dir_zip = os.path.join(
                    self.root, '{}_trainvaltest.zip'.format(self.image_mode))
            elif self.image_mode == 'gtCoarse':
                image_dir_zip = os.path.join(self.root,
                                             '{}.zip'.format(self.image_mode))
            else:
                if split == 'train_extra':
                    if self.image_mode == 'leftImg8bit':
                        image_dir_zip = os.path.join(
                            self.root,
                            '{}_trainextra.zip'.format(self.image_mode))
                    else:
                        image_dir_zip = os.path.join(
                            self.root, '{}_trainextra_{}.zip'.format(
                                *self.image_mode.split('_')))
                else:
                    if self.image_mode == 'leftImg8bit':
                        image_dir_zip = os.path.join(
                            self.root,
                            '{}_trainvaltest.zip'.format(self.image_mode))
                    else:
                        image_dir_zip = os.path.join(
                            self.root, '{}_trainvaltest_{}.zip'.format(
                                *self.image_mode.split('_')))

            if os.path.isfile(image_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    'Dataset not found or incomplete. Please make sure all required folders for the'
                    ' specified "split" and "image_mode" are inside the "root" directory'
                )
Code example #15
    def download(self):
        path = os.path.join(self.root, 'notMNIST')
        link = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_small.tar.gz'

        os.makedirs(path, exist_ok=True)
        download_url(link, path)
        extract_archive(os.path.join(path, 'notMNIST_small.tar.gz'))
Code example #16
def parse_train_archive(root, file=None, folder="train"):
    """Parse the train images archive of the ImageNet2012 classification
        dataset and prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder.
            Defaults to 'train'
    """
    archive_meta = ImageNet.archive_meta["train"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    archives = [
        os.path.join(train_root, archive) for archive in os.listdir(train_root)
    ]
    for archive in archives:
        extract_archive(archive,
                        os.path.splitext(archive)[0],
                        remove_finished=False)
Code example #17
File: base.py Project: tangzhenjie/pytorch-enhance
    def download_google_drive(self, data_dir: str, filename: str) -> None:
        """Download dataset

        Parameters
        ----------
        data_dir : str
            Path to base dataset directory
        filename : str
            Filename of google drive file being downloaded

        Returns
        -------
        None

        """
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        if not os.path.exists(self.root_dir):

            download_file_from_google_drive(file_id=self.url,
                                            root=data_dir,
                                            filename=filename)
            extract_archive(from_path=os.path.join(data_dir, filename),
                            to_path=data_dir,
                            remove_finished=True)
Code example #18
    def _extract_archive(
        self,
        path: Union[str, Path],
        sub_directory: str = None,
        remove_archive: bool = False,
    ) -> Path:
        """
        Utility method that can be used to extract an archive.

        :param path: The complete path to the archive (for instance obtained
            by calling `_download_file`).
        :param sub_directory: The name of the sub directory where to extract the
            archive. Can be None, which means that the archive will be extracted
            in the root. Beware that some archives already have a root directory
            inside of them, in which case it's probably better to use None here.
            Defaults to None.
        :param remove_archive: If True, the archive will be deleted after a
            successful extraction. Defaults to False.
        :return: The path of the directory the archive was extracted to.
        """

        if sub_directory is None:
            extract_root = self.root
        else:
            extract_root = self.root / sub_directory

        extract_archive(str(path),
                        to_path=str(extract_root),
                        remove_finished=remove_archive)

        return extract_root
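A hedged usage sketch, from inside the owning class (archive_path stands for a value returned by the companion _download_file helper mentioned in the docstring):

    extracted_dir = self._extract_archive(archive_path,
                                          sub_directory="raw",
                                          remove_archive=True)
    # extracted_dir is self.root / "raw"; the archive itself has been deleted.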
Code example #19
File: imagenet.py Project: kwotsin/mimicry
def prepare_train_folder(folder):
    for archive in [
            os.path.join(folder, archive) for archive in os.listdir(folder)
    ]:
        extract_archive(archive,
                        os.path.splitext(archive)[0],
                        remove_finished=True)
Code example #20
def parse_val_archive(root, file=None, wnids=None, folder="val"):
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'
    """
    archive_meta = ARCHIVE_META["val"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    images = sorted(
        [os.path.join(val_root, image) for image in os.listdir(val_root)])

    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file,
                    os.path.join(val_root, wnid, os.path.basename(img_file)))
Code example #21
    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 loader=default_loader,
                 download=False):

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = loader
        self.train = train
        self.log = logging.getLogger("avalanche")

        if download:
            self.log.error(
                "Download is not supported for this Dataset."
                "You need to download 'images.tgz' and 'lists.tgz' manually "
                "at: http://www.vision.caltech.edu/visipedia/CUB-200.html")

        if not os.path.exists(os.path.join(self.root, self.filename[:-4])):
            extract_archive(os.path.join(self.root, self.filename))
        if not os.path.exists(os.path.join(self.root, self.metadata[:-4])):
            extract_archive(os.path.join(self.root, self.metadata))

        if not self._check_integrity():
            raise RuntimeError('Dataset corrupted')
Code example #22
    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=None):
        super(NABirds, self).__init__(root,
                                      transform=transform,
                                      target_transform=target_transform)
        if download is True:
            msg = (
                "The dataset is no longer publicly accessible. You need to "
                "download the archives externally and place them in the root "
                "directory.")
            raise RuntimeError(msg)
        msg = ("The use of the download flag is deprecated, since the dataset "
               "is no longer publicly accessible.")
        warnings.warn(msg, RuntimeWarning)

        dataset_path = os.path.join(root, "nabirds")
        print(dataset_path)
        if not os.path.isdir(dataset_path):
            if not check_integrity(os.path.join(root, self.filename),
                                   self.md5):
                raise RuntimeError('Dataset not found or corrupted.')
            extract_archive(os.path.join(root, self.filename))
        self.loader = default_loader
        self.train = train

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ',
                                  names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(
            dataset_path, 'image_class_labels.txt'),
                                         sep=' ',
                                         names=['img_id', 'target'])
        image_bounding_boxes = pd.read_csv(
            os.path.join(dataset_path, 'bounding_boxes.txt'),
            sep=' ',
            names=['img_id', 'x', 'y', 'w', 'h'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path,
                                                    'train_test_split.txt'),
                                       sep=' ',
                                       names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        data = data.merge(image_bounding_boxes, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)
        self.targets = None
Code example #23
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        filename = "foo"
        file = f"{filename}{extension}"

        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
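This test depends on two fixtures that are not shown here: pytest parametrization for extension and remove_finished, and the mocker fixture from the pytest-mock plugin, which patches torchvision's private _decompress helper so no real file I/O happens. A plausible decorator stack (an assumption based on the argument names):

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension,
                                                 remove_finished, mocker):
        ...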
Code example #24
File: test_datasets.py Project: truongkyle/vision
 @contextlib.contextmanager
 def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
     with self.mocked_root() as (root, data):
         if pre_extract:
             utils.extract_archive(os.path.join(root, data["archive"]))
         dataset = torchvision.datasets.STL10(root,
                                              download=download,
                                              **kwargs)
         yield dataset, data
Code example #25
 def _download(self):
     if not os.path.isfile(os.path.join(self.root, self.filename)):
         print('Downloading...')
         download_url(self.url, root=self.root, filename=self.filename)
         print('Extracting...')
         extract_archive(os.path.join(self.root, self.filename))
     else:
         print('Data zip already downloaded')
Code example #26
 def download_urban(cfg: DictConfig, cwd: Path) -> None:
     urban_path = Path(cfg.urban_path)
     Path.mkdir(cwd / urban_path.name, parents=True, exist_ok=True)
     download_file_from_google_drive(file_id=cfg.urban_url,
                                     root=cwd / urban_path.parent,
                                     filename="UrbanSound8K.tar.gz")
     extract_archive(from_path=str(cwd / urban_path.parent /
                                   "UrbanSound8K.tar.gz"),
                     remove_finished=True)
Code example #27
    def __init__(self, root, split='train', mode='fine', target_type='instance',
                 transform=None, target_transform=None, transforms=None):
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images = []
        self.targets = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
               "Valid values are {{{}}}.")
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        if not isinstance(target_type, list):
            self.target_type = [target_type]
        [verify_str_arg(value, "target_type",
                        ("instance", "semantic", "polygon", "color"))
         for value in self.target_type]

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            if split == 'train_extra':
                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
            else:
                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))

            if self.mode == 'gtFine':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
            elif self.mode == 'gtCoarse':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_types = []
                for t in self.target_type:
                    target_name = file_name
                    target_types.append(os.path.join(target_dir, target_name))

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_types)
Code example #28
File: datasets.py Project: yunyikang/DL_project
 def download(self):
     url = "https://github.com/chiayewken/sutd-materials/releases/download/v0.1.0/bert-base-nli-mean-tokens.zip"
     path_zip = self.cache_dir / Path(url).name
     if not path_zip.exists():
         download_url(url, self.cache_dir, filename=path_zip.name)
     dir_model = self.cache_dir / "bert"
     if not dir_model.exists():
         extract_archive(str(path_zip), str(dir_model))
     return dir_model
Code example #29
def download_model(model_path='./model'):
    file_id = '1LxEp7_Hm9RK4oc9UxJ8nu7KUKGFsa4i0'
    tgz_md5 = '90db23115ebec6b2c94499ac1c56ee59'
    filename = 'net.tgz'
    archive = os.path.join(model_path, filename)
    if not os.path.isfile(archive):
        print("Downloading pretrained model...")
        download_file_from_google_drive(file_id, model_path, filename, tgz_md5)
    extract_archive(archive, model_path)
Code example #30
def parse_devkit_archive(root, file=None):
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root):
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
        nums_children = list(zip(*meta))[4]
        meta = [
            meta[idx] for idx, num_children in enumerate(nums_children)
            if num_children == 0
        ]
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(', ')) for clss in classes]
        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root):
        file = os.path.join(devkit_root, "data",
                            "ILSVRC2012_validation_ground_truth.txt")
        with open(file, 'r') as txtfh:
            val_idcs = txtfh.readlines()
        return [int(val_idx) for val_idx in val_idcs]

    @contextmanager
    def get_tmp_dir():
        tmp_dir = tempfile.mkdtemp()
        try:
            yield tmp_dir
        finally:
            shutil.rmtree(tmp_dir)

    archive_meta = ImageNet.archive_meta["devkit"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids),
                   os.path.join(root, ImageNet.meta_file))
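Taken together, the three ImageNet parsers shown above (examples #16, #20, and this one) are typically run in this order, since parse_val_archive reads the meta file that parse_devkit_archive writes. A sketch with a placeholder root:

    root = "/data/imagenet"      # directory holding the downloaded archives
    parse_devkit_archive(root)   # saves the wnid/class meta file
    parse_val_archive(root)      # sorts val images into per-wnid folders
    parse_train_archive(root)    # extracts the per-class train tars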