Example #1
    def download_google_drive(self, data_dir: str, filename: str) -> None:
        """Download dataset

        Parameters
        ----------
        data_dir : str
            Path to base dataset directory
        filename : str
            Filename of google drive file being downloaded

        Returns
        -------
        None

        """
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        if not os.path.exists(self.root_dir):
            download_file_from_google_drive(file_id=self.url,
                                            root=data_dir,
                                            filename=filename)
            extract_archive(from_path=os.path.join(data_dir, filename),
                            to_path=data_dir,
                            remove_finished=True)
Example #2
def load_smoothing_imagenet_model(noise_level, **load_args):
    import tarfile

    # load their checkpoint
    folder = os.path.join(torch.hub._get_torch_home(), "checkpoints")
    os.makedirs(folder, exist_ok=True)

    tar_fn = "locuslab-smoothing.tar"
    if not os.path.exists(os.path.join(folder, tar_fn)):
        from torchvision.datasets import utils

        utils.download_file_from_google_drive(
            "1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4", folder, filename=tar_fn)

    with tarfile.open(os.path.join(folder, tar_fn), "r") as tar:
        fn = f"models/imagenet/resnet50/noise_{noise_level:.2f}/checkpoint.pth.tar"
        checkpoint = torch.load(tar.extractfile(fn), **load_args)

    # they checkpointed the model inside Sequential(DataParallel(model))
    def rewrite(k):
        assert k.startswith("1.module.")
        return k[9:]

    assert checkpoint["arch"] == "resnet50"
    sd = {rewrite(k): v for k, v in checkpoint["state_dict"].items()}

    model = models.resnet50(pretrained=False)
    model.load_state_dict(sd)
    return model
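The key rewrite above comes up whenever a checkpoint was saved from a wrapped model: torch.nn.DataParallel prefixes every parameter name with "module.", and the leading "1." comes from the enclosing Sequential. A minimal self-contained sketch of the same idea (the strip_prefix helper is ours, not part of the snippet above):

import torch

def strip_prefix(state_dict, prefix):
    # drop a wrapper prefix (e.g. 'module.') from every key of a state dict
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# a toy model wrapped the same way the checkpoint above was saved
model = torch.nn.Linear(4, 2)
wrapped = torch.nn.Sequential(torch.nn.Identity(), torch.nn.DataParallel(model))
sd = wrapped.state_dict()  # keys look like '1.module.weight'
sd = strip_prefix(strip_prefix(sd, '1.'), 'module.')
model.load_state_dict(sd)  # loads cleanly into the bare model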
Example #3
 def download(self):
     if self._check_integrity():
         print('Files already downloaded and verified')
         return
     download_file_from_google_drive(file_id=self.gdrive_id,
                                     root=str(self.download_root),
                                     filename=self.filename,
                                     md5=self.file_md5)
Example #4
 def download_urban(cfg: DictConfig, cwd: Path) -> None:
     urban_path = Path(cfg.urban_path)
     (cwd / urban_path.name).mkdir(parents=True, exist_ok=True)
     download_file_from_google_drive(file_id=cfg.urban_url,
                                     root=cwd / urban_path.parent,
                                     filename="UrbanSound8K.tar.gz")
     extract_archive(from_path=str(cwd / urban_path.parent /
                                   "UrbanSound8K.tar.gz"),
                     remove_finished=True)
Example #5
def download_model(model_path='./model'):
    file_id = '1LxEp7_Hm9RK4oc9UxJ8nu7KUKGFsa4i0'
    tgz_md5 = '90db23115ebec6b2c94499ac1c56ee59'
    filename = 'net.tgz'
    archive = os.path.join(model_path, filename)
    if not os.path.isfile(archive):
        print("Downloading pretrained model...")
        download_file_from_google_drive(file_id, model_path, filename, tgz_md5)
    extract_archive(archive, model_path)
Example #6
 def download_(file_id, root=None):
     if root is None:
         root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             "../../", "data/data_files")
     root = os.path.expanduser(root)
     os.makedirs(root, exist_ok=True)
     dtutil.download_file_from_google_drive(
         file_id=file_id, root=root, filename=f"adv_data_{file_id}.pkl")
     return os.path.join(os.path.abspath(root), f"adv_data_{file_id}.pkl")
Example #7
    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)
Example #8
    def download(self):
        import zipfile

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(
                file_id, os.path.join(self.root, self.base_folder), filename,
                md5)

        with zipfile.ZipFile(
                os.path.join(self.root, self.base_folder,
                             "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))
Example #9
    def download(self):
        import zipfile
        import shutil
        import glob
        from tqdm import tqdm

        if self._check_integrity():
            return

        zip_filename = os.path.join(self.root, self.zip_filename)
        if not os.path.isfile(zip_filename):
            download_file_from_google_drive(self.gdrive_id,
                                            self.root,
                                            self.zip_filename,
                                            md5=self.zip_md5)

        zip_foldername = os.path.join(self.root, self.image_folder)
        if not os.path.isdir(zip_foldername):
            with zipfile.ZipFile(zip_filename, 'r') as f:
                for member in tqdm(f.infolist(), desc='Extracting '):
                    try:
                        f.extract(member, self.root)
                    except zipfile.BadZipFile:
                        print('Error: Zip file is corrupted')

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename.format(split))
            if os.path.isfile(filename):
                continue

            labels = get_asset(self.folder, '{0}.json'.format(split))
            labels_filename = os.path.join(self.root,
                                           self.filename_labels.format(split))
            with open(labels_filename, 'w') as f:
                json.dump(labels, f)

            image_folder = os.path.join(zip_foldername, split)

            with h5py.File(filename, 'w') as f:
                group = f.create_group('datasets')
                dtype = h5py.special_dtype(vlen=np.uint8)
                for label in tqdm(labels, desc=filename):
                    images = glob.glob(
                        os.path.join(image_folder, label, '*.png'))
                    images.sort()
                    dataset = group.create_dataset(label, (len(images), ),
                                                   dtype=dtype)
                    for j, image in enumerate(images):
                        with open(image, 'rb') as img_file:
                            array = bytearray(img_file.read())
                            dataset[j] = np.asarray(array, dtype=np.uint8)

        if os.path.isdir(zip_foldername):
            shutil.rmtree(zip_foldername)
Example #10
def get_coqgym_container(path):
    from torchvision.datasets.utils import download_file_from_google_drive, extract_archive

    data_root_path = path.expanduser()

    # https://drive.google.com/file/d/1dzNR8uj5fpo9bN40xfJo_EkFR0vr1FAt/view?usp=sharing
    file_id = '1dzNR8uj5fpo9bN40xfJo_EkFR0vr1FAt'
    path_to_container = data_root_path / Path('coq_gym.simg')
    # if the container image is not there, download it
    if not path_to_container.exists():
        download_file_from_google_drive(file_id, str(data_root_path),
                                        path_to_container.name)
Example #11
    def download(self):
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))
Example #12
    def download(self):
        import tarfile
        import shutil
        import glob
        from tqdm import tqdm

        if self._check_integrity():
            return

        tgz_filename = os.path.join(self.root, self.tgz_filename)
        if not os.path.isfile(tgz_filename):
            download_file_from_google_drive(self.gdrive_id,
                                            self.root,
                                            self.tgz_filename,
                                            md5=self.tgz_md5)

        with tarfile.open(tgz_filename, 'r') as f:
            f.extractall(self.root)
        image_folder = os.path.join(self.root, self.image_folder)

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename.format(split))
            if os.path.isfile(filename):
                continue

            labels = get_asset(self.folder, '{0}.json'.format(split))
            labels_filename = os.path.join(self.root,
                                           self.filename_labels.format(split))
            with open(labels_filename, 'w') as f:
                json.dump(labels, f)

            with h5py.File(filename, 'w') as f:
                group = f.create_group('datasets')
                dtype = h5py.special_dtype(vlen=np.uint8)
                for label in tqdm(labels, desc=filename):
                    images = glob.glob(
                        os.path.join(image_folder, label, '*.jpg'))
                    images.sort()
                    dataset = group.create_dataset(label, (len(images), ),
                                                   dtype=dtype)
                    for j, image in enumerate(images):
                        with open(image, 'rb') as img_file:
                            array = bytearray(img_file.read())
                            dataset[j] = np.asarray(array, dtype=np.uint8)

        tar_folder, _ = os.path.splitext(tgz_filename)
        if os.path.isdir(tar_folder):
            shutil.rmtree(tar_folder)

        attributes_filename = os.path.join(self.root, 'attributes.txt')
        if os.path.isfile(attributes_filename):
            os.remove(attributes_filename)
Example #13
def download_cub(extract_root='./data'):
    file_id = '10A3CXDCYuGSdhAv9aFk-OdNpPbH-2aVK'
    tgz_md5 = '97eceeb196236b17998738112f37df78'
    filename = 'CUB_200_2011.tgz'
    archive = os.path.join(extract_root, filename)
    if not os.path.isfile(archive):
        print("Downloading CUB 200 dataset...")
        download_file_from_google_drive(file_id, extract_root, filename,
                                        tgz_md5)
    checkFile = os.path.join(extract_root, 'attributes.txt')
    if not os.path.isfile(checkFile):
        print("Extracting {} to {}".format(archive, extract_root))
        extract_archive(archive, extract_root)
Example #14
 def _download(self):
     download_file_from_google_drive(
         "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45",
         self.root,
         filename="CUB_200_2011.tgz",
     )
     download_file_from_google_drive(
         "1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP",
         self.root,
         filename="segmentations.tgz",
     )
     extract_archive(os.path.join(self.root, "CUB_200_2011.tgz"))
     extract_archive(os.path.join(self.root, "segmentations.tgz"))
Example #15
def download_G(root='checkpoints'):
    """Downloads a 128x128 BigGAN checkpoint to use for direction discovery."""
    # This is the corresponding file ID for the PyTorch BigGAN 138k checkpoint available at the following URL:
    # https://drive.google.com/file/d/1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW/view
    ID = '1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW'
    path = f'{root}/138k'
    if not os.path.isdir(path):
        zip_path = f'{path}.zip'
        print(f'Downloading BigGAN checkpoint directory to {path}')
        download_file_from_google_drive(ID, root)
        shutil.move(f'{root}/{ID}', zip_path)
        extract_archive(zip_path, remove_finished=True)
    else:
        print(f'Resuming from checkpoint at {path}.')
Example #16
    def download(self) -> None:
        """Download file from Google drive."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self.dataset_name + '.tar.gz'
        download_file_from_google_drive(file_id=self.google_drive_id,
                                        root=self.root,
                                        filename=filename,
                                        md5=self.zip_md5)

        archive = os.path.join(self.root, filename)
        print("Extracting {} to {}".format(archive, self.root))
        extract_archive(archive, self.root)
Example #17
def _download(base: Path) -> None:
    """Attempt to download data if files cannot be found in the base folder."""
    import zipfile

    from torchvision.datasets.utils import download_file_from_google_drive

    if _check_integrity(base):
        print("Files already downloaded and verified")
        return

    download_file_from_google_drive(_FILE_ID, str(base), _ZIP_FILE)

    fpath = base / _ZIP_FILE
    with zipfile.ZipFile(fpath, "r") as fhandle:
        fhandle.extractall(str(base))
Example #18
def _download(base: Path) -> None:
    """Attempt to download data if files cannot be found in the base folder."""
    if not common.TORCHVISION_AVAILABLE:
        raise RuntimeError("Need torchvision to download data.")
    import zipfile

    from torchvision.datasets.utils import download_file_from_google_drive

    if _check_integrity(base):
        print("Files already downloaded and verified")
        return

    for (file_id, md5, filename) in CELEBA_FILE_LIST:
        download_file_from_google_drive(file_id, str(base), filename, md5)

    with zipfile.ZipFile(base / "img_align_celeba.zip", "r") as fhandle:
        fhandle.extractall(str(base))
Example #19
    def __init__(self,
                 root,
                 download=False,
                 loader=default_loader,
                 extensions=None,
                 transform=None,
                 target_transform=None,
                 transform_aug=None,
                 num_trans_aug=1,
                 show=False):
        if extensions is None:
            extensions = IMG_EXTENSIONS

        if download:
            file_id = "1w6HkrxUQGmZw42ReciF_Gt_D4fz266qv"
            filename = os.path.basename(root)
            parent_dir = os.path.dirname(root)
            download_file_from_google_drive(file_id, parent_dir,
                                            "{}.zip".format(filename))
            extract_archive("{}/{}.zip".format(parent_dir, filename))

        classes, class_to_idx = self._find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + root +
                                "\n"
                                "Supported extensions are: " +
                                ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform
        self.transform_aug = transform_aug
        self.num_trans_aug = num_trans_aug

        self.imgs = self.samples
        self.show = show
Example #20
def download_func(
    file_id,
    final_name,
    cache_dir="./",
    extract=False,
):

    dtutil.download_file_from_google_drive(file_id=file_id,
                                           root=cache_dir,
                                           filename=final_name)
    if extract:
        from_path = os.path.join(cache_dir, final_name)
        to_path = os.path.join(cache_dir)
        print("from -->to ", from_path, to_path)
        dtutil.extract_archive(from_path, to_path, remove_finished=False)
Example #21
 def download_and_extract_archive(self, mode: str):
     file_name = f'{self.name}_{mode}{self.ext[mode]}'
     file_path = os.path.normpath(os.path.join(self.folder_path, file_name))
     md5 = self.md5.get(mode)
     if not check_integrity(file_path, md5=md5):
         prints('{yellow}Downloading Dataset{reset} '.format(**ansi),
                f'{self.name} {mode:5s}: {file_path}',
                indent=10)
         download_file_from_google_drive(file_id=self.url[mode],
                                         root=self.folder_path,
                                         filename=file_name,
                                         md5=md5)
         print('{upline}{clear_line}'.format(**ansi))
     else:
         prints('{yellow}File Already Exists{reset}: '.format(**ansi),
                file_path,
                indent=10)
     extract_archive(from_path=file_path, to_path=self.folder_path)
Example #22
def get_net(name="",
            dt_name=None,
            device=torch.device("cuda"),
            use_cache=True,
            cache_dir=default_trained_dir,
            prefix_lambda=lambda x: f"torch_best_{x[0]}_{x[1]}.pth",
            is_masked=False,
            ):
    name = map_dict.get(name, name)

    assert name in name_list, \
        f"expected a name contained in {name_list}, but got {name}"

    dt_name = None if "_" not in name else name.split("_")[-1]
    name = None if "_" not in name else name.split("_")[0]

    if is_masked:
        net = _get_masked_net(name=name, dt_name=dt_name)
    else:
        net = _get_net(name=name, dt_name=dt_name)

    assert net is not None, f"name == {name}, dt_name == {dt_name}: net is None"

    if use_cache:
        import torchvision.datasets.utils as dtutil
        final_name = f"{name}_{dt_name}"
        (file_id, filename) = file_id_list[final_name]
        dtutil.download_file_from_google_drive(file_id=file_id,
                                               root=cache_dir,
                                               filename=filename)
        net.load_state_dict(torch.load(f"{cache_dir}/{filename}"))
        setattr(net, "fid", file_id)

    net = net.to(device)
    net.eval()
    return net
Example #23
def _construct_spam_dataset(dest: Path) -> Tuple[Path, Path]:
    r""" Use ELMo to construct the dataset """
    # Define the output files
    processed_dir = dest / "processed"
    lst_path = processed_dir / "training.pt", processed_dir / "test.pt"
    spam_datasets = SpamDataset.TREC2005.value, SpamDataset.TREC2007.value

    processed_dir.mkdir(parents=True, exist_ok=True)
    # Downloads the processed tensors
    for tensor_pth, ds_info in zip(lst_path, spam_datasets):
        if tensor_pth.exists():
            logging.debug(
                f"Spam tensor \"{str(tensor_pth)}\" already exists. Skipping download..."
            )
            continue
        download_file_from_google_drive(file_id=ds_info.pth_file_id,
                                        root=str(tensor_pth.parent),
                                        filename=tensor_pth.name)
    return lst_path
Example #24
def download_dummy_dataset():
    # Full File URL: https://drive.google.com/file/d/1U4D23R8u8MJX9KVKb92bZZX-tbpKWtga/view?usp=sharing
    gdrive_file_id = "1U4D23R8u8MJX9KVKb92bZZX-tbpKWtga"
    output_directory = os.getcwd()
    output_file_name = "demo_datasets.zip"

    # skip the download entirely if the extracted dataset is already present
    if os.path.exists(os.path.join(os.getcwd(), "demo_dataset")):
        print("Skipping Download and Extraction")
        return

    print("Downloading Dummy Datasets")
    download_file_from_google_drive(gdrive_file_id, output_directory, output_file_name)

    print("Extracting Dummy Dataset")
    zip_file = os.path.join(os.getcwd(), output_file_name)
    extract_archive(zip_file)

    print("Datasets downloaded and extracted. Run this program again to run the demo.")
    exit()
Example #25
def get_weights(model_name):
	"""
	get pretrained weights or download them from google drive first 

	usage of drive links to download files:
		share the corrsponding document such that anyone with the link can see it
		copy the link and use only the file id for
			download_file_from_google_drive from torchvision.datasets.utils
			https://github.com/pytorch/vision/blob/1b7c0f54e2913e159394a19ac5e50daa69c142c7/torchvision/datasets/utils.py#L169

		example:
			link: https://drive.google.com/file/d/14M3uC29aAx2AMeCeidLQjqjkVpGqnb6k/view?usp=sharing
			with id = 14M3uC29aAx2AMeCeidLQjqjkVpGqnb6k


	Parameters
	----------
	model_name : str
		name of file_id saved in config.WEIGHT_IDS

	Returns
	-------
	state_dict
		state_dict containing pretrained weights
	"""	
	from os.path import join, exists
	import config
	import torch
	from torchvision.datasets.utils import download_file_from_google_drive

	
	pretrained_path = join(config.DATA_FOLDER, 'pretrained_models')	
	assert model_name in config.WEIGHT_IDS, f'no weights for model: {model_name}'

	file_name = f'{model_name}.pt'
	file_dest = join(pretrained_path, file_name)

	if not exists(file_dest):
		id = config.WEIGHT_IDS[model_name]
		download_file_from_google_drive(file_id=id, root=pretrained_path, filename=file_name)
		print(f'weights downloaded & saved at: {file_dest}')
		
	return torch.load(file_dest) 
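The docstring above walks through pulling the file id out of a Drive share link by hand; a small parser can automate that step for the common /file/d/<id>/ and ?id=<id> link shapes (drive_file_id is a hypothetical helper, not part of torchvision):

import re

def drive_file_id(url):
    # extract the file id from a Google Drive share link
    m = re.search(r"/file/d/([\w-]+)", url) or re.search(r"[?&]id=([\w-]+)", url)
    if m is None:
        raise ValueError(f"no file id found in: {url}")
    return m.group(1)

# the example link from the docstring above
url = "https://drive.google.com/file/d/14M3uC29aAx2AMeCeidLQjqjkVpGqnb6k/view?usp=sharing"
assert drive_file_id(url) == "14M3uC29aAx2AMeCeidLQjqjkVpGqnb6k"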
Example #26
    def _download(self):
        """Download the data if it doesn't exist in already."""
        if self._check_exists():
            return

        if self.verbose:
            print("Making directories...")
        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        if self.verbose:
            print("Downloading...")
        for fileid, md5 in self._resources:
            filename = "tecator.npz"
            download_file_from_google_drive(fileid,
                                            root=self.raw_folder,
                                            filename=filename,
                                            md5=md5)

        if self.verbose:
            print("Processing...")
        with np.load(os.path.join(self.raw_folder, "tecator.npz"),
                     allow_pickle=False) as f:
            x_train, y_train = f["x_train"], f["y_train"]
            x_test, y_test = f["x_test"], f["y_test"]
        training_set = [
            torch.tensor(x_train, dtype=torch.float32),
            torch.tensor(y_train),
        ]
        test_set = [
            torch.tensor(x_test, dtype=torch.float32),
            torch.tensor(y_test),
        ]

        with open(os.path.join(self.processed_folder, self.training_file),
                  "wb") as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file),
                  "wb") as f:
            torch.save(test_set, f)

        if self.verbose:
            print("Done!")
Example #27
 def _download_google_driver_arrange_data(
     self,
     download_file_id,
     extract_download_file_name,
     put_data_dir,
 ):
     download_data_file_name = extract_download_file_name + ".zip"
     download_data_path = os.path.join(put_data_dir,
                                       download_data_file_name)
     extract_data_path = os.path.join(put_data_dir,
                                      extract_download_file_name)
     if not self._exist_judgement(download_data_path):
         logging.info("Downloading the data to %s", download_data_path)
         download_file_from_google_drive(file_id=download_file_id,
                                         root=put_data_dir,
                                         filename=download_data_file_name)
     if not self._exist_judgement(extract_data_path):
         extract_archive(from_path=download_data_path,
                         to_path=put_data_dir,
                         remove_finished=True)
Example #28
    def download(self):
        import tarfile

        if self._check_integrity():
            return

        download_file_from_google_drive(self.gdrive_id,
                                        self.root,
                                        self.gz_filename,
                                        md5=self.gz_md5)

        filename = os.path.join(self.root, self.gz_filename)
        with tarfile.open(filename, 'r') as f:
            f.extractall(self.root)

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename.format(split))
            if os.path.isfile(filename):
                continue

            pkl_filename = os.path.join(self.root,
                                        self.pkl_filename.format(split))
            if not os.path.isfile(pkl_filename):
                raise IOError('Missing pickle file: {0}'.format(pkl_filename))
            with open(pkl_filename, 'rb') as f:
                data = pickle.load(f)
                images, classes = data['image_data'], data['class_dict']

            with h5py.File(filename, 'w') as f:
                group = f.create_group('datasets')
                for name, indices in classes.items():
                    group.create_dataset(name, data=images[indices])

            labels_filename = os.path.join(self.root,
                                           self.filename_labels.format(split))
            with open(labels_filename, 'w') as f:
                labels = sorted(list(classes.keys()))
                json.dump(labels, f)

            if os.path.isfile(pkl_filename):
                os.remove(pkl_filename)
Example #29
 def get_official_weights(self,
                          dataset: str = None,
                          **kwargs) -> OrderedDict[str, torch.Tensor]:
     if dataset is None and isinstance(self.dataset, ImageSet):
         dataset = self.dataset.name
     folder_path = os.path.join(torch.hub.get_dir(), 'lanet')
     file_path = os.path.join(folder_path, f'lanet_{dataset}.pt')
     if not os.path.exists(file_path):
         zip_file_name = 'temp.zip'
         zip_path = os.path.join(folder_path, zip_file_name)
         download_file_from_google_drive(file_id=self.model_urls[dataset],
                                         root=folder_path,
                                         filename=zip_file_name)
         with zipfile.ZipFile(zip_path, 'r') as zf:
             data = zf.read('lanas_128_99.03/top1.pt')
         with open(file_path, 'wb') as f:
             f.write(data)
         os.remove(zip_path)
     print('get official model weights from Google Drive: ',
           self.model_urls[dataset])
     _dict: OrderedDict[str, torch.Tensor] = torch.load(file_path,
                                                        map_location='cpu')
     if 'model_state_dict' in _dict.keys():
         _dict = _dict['model_state_dict']
     new_dict: OrderedDict[str, torch.Tensor] = self.state_dict()
     old_keys = list(_dict.keys())
     new_keys = list(new_dict.keys())
     new2old: dict[str, str] = {}
     i = 0
     j = 0
      while i < len(new_keys) and j < len(old_keys):
          if ('auxiliary_head' not in new_keys[i]
                  and 'auxiliary_head' in old_keys[j]):
              j += 1
              continue
          new2old[new_keys[i]] = old_keys[j]
          i += 1
          j += 1
      for key in new_keys:
          new_dict[key] = _dict[new2old[key]]
     return new_dict
Example #30
    def get_official_weights(self,
                             dataset: str = None,
                             **kwargs) -> OrderedDict[str, torch.Tensor]:
        assert str(self.genotype) == str(genotypes.DARTS)
        if dataset is None and isinstance(self.dataset, ImageSet):
            dataset = self.dataset.name
        file_name = f'darts_{dataset}.pt'
        folder_path = os.path.join(torch.hub.get_dir(), 'darts')
        download_file_from_google_drive(file_id=url[dataset],
                                        root=folder_path,
                                        filename=file_name)
        print('get official model weights from Google Drive: ', url[dataset])
        _dict: OrderedDict[str, torch.Tensor] = torch.load(
            os.path.join(folder_path, file_name), map_location='cpu')
        if 'state_dict' in _dict.keys():
            _dict = _dict['state_dict']

        new_dict: OrderedDict[str, torch.Tensor] = self.state_dict()
        old_keys = list(_dict.keys())
        new_keys = list(new_dict.keys())
        new2old: dict[str, str] = {}
        i = 0
        j = 0
        while i < len(new_keys) and j < len(old_keys):
            if 'num_batches_tracked' in new_keys[i]:
                i += 1
                continue
            if ('auxiliary_head' not in new_keys[i]
                    and 'auxiliary_head' in old_keys[j]):
                j += 1
                continue
            new2old[new_keys[i]] = old_keys[j]
            i += 1
            j += 1
        for key in new_keys:
            if 'num_batches_tracked' in key:
                new_dict[key] = torch.tensor(0)
            else:
                new_dict[key] = _dict[new2old[key]]
        return new_dict