def download(self):
    """
    **Description**

    Downloads the FC100 archive to `self.root/fc100.zip`, extracts it into
    `self.root`, and deletes the archive afterwards.

    Google Drive is tried first; if the file obtained from there is not a
    valid zip archive, the download is retried from the Dropbox mirror.

    **Raises**

    * `zipfile.BadZipFile` - if the Dropbox copy is not a valid archive either.
    """
    archive_path = os.path.join(self.root, 'fc100.zip')
    print('Downloading FC100. (160Mb)')
    try:  # Download from Google Drive first
        download_file_from_google_drive(GOOGLE_DRIVE_FILE_ID, archive_path)
        # The context manager closes the archive handle even when extraction
        # fails, and guarantees it is closed before the file is removed.
        with zipfile.ZipFile(archive_path) as archive_file:
            archive_file.extractall(self.root)
    except zipfile.BadZipFile:
        # The first download did not yield a usable zip: fall back to the
        # Dropbox link, overwriting the same archive path.
        download_file(DROPBOX_LINK, archive_path)
        with zipfile.ZipFile(archive_path) as archive_file:
            archive_file.extractall(self.root)
    # Reached only when one of the two extractions succeeded.
    os.remove(archive_path)
def download(self):
    """
    **Description**

    Downloads the Describable Textures archive from `ARCHIVE_URL` into
    `self.root/DATA_DIR`, extracts it there, and deletes the archive.

    The target directories are created if they do not exist yet.
    """
    data_path = os.path.join(self.root, DATA_DIR)
    # makedirs creates self.root and data_path in one call (including any
    # missing parents) and does not fail if they already exist.
    os.makedirs(data_path, exist_ok=True)

    tar_path = os.path.join(data_path, os.path.basename(ARCHIVE_URL))
    print('Downloading Describable Textures dataset (600Mb)')
    download_file(ARCHIVE_URL, tar_path)

    # NOTE(review): extractall trusts the member paths stored in the archive;
    # acceptable only as long as ARCHIVE_URL points at a trusted source.
    with tarfile.open(tar_path) as tar_file:  # closed even if extraction fails
        tar_file.extractall(data_path)
    os.remove(tar_path)
def __init__(self, root, mode='train', transform=None, target_transform=None, download=False):
    """
    **Description**

    Loads the pickled mini-ImageNet cache for the requested split and builds
    the image tensor `self.x` and the label array `self.y`.

    **Arguments**

    * **root** (str) - Directory holding the cache file; `~` is expanded and
      the directory is created if missing.
    * **mode** (str, *optional*, default='train') - One of `train`,
      `validation` or `test`.
    * **transform** (callable, *optional*, default=None) - Stored on the
      instance as `self.transform`.
    * **target_transform** (callable, *optional*, default=None) - Stored on
      the instance as `self.target_transform`.
    * **download** (bool, *optional*, default=False) - Download the cache
      file when it is not found.

    **Raises**

    * `ValueError` - if `mode` is not `train`, `validation` or `test`.
    """
    super(MiniImagenet, self).__init__()
    self.root = os.path.expanduser(root)
    # exist_ok avoids the check-then-create race and handles nested roots.
    os.makedirs(self.root, exist_ok=True)
    self.transform = transform
    self.target_transform = target_transform
    self.mode = mode
    self._bookkeeping_path = os.path.join(
        self.root, 'mini-imagenet-bookkeeping-' + mode + '.pkl')

    # Each split has a primary (Google Drive) and a fallback (Dropbox) source.
    if self.mode == 'test':
        google_drive_file_id = '1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD'
        dropbox_file_link = 'https://www.dropbox.com/s/ye9jeb5tyz0x01b/mini-imagenet-cache-test.pkl?dl=1'
    elif self.mode == 'train':
        google_drive_file_id = '1I3itTXpXxGV68olxM5roceUMG8itH9Xj'
        dropbox_file_link = 'https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1'
    elif self.mode == 'validation':
        google_drive_file_id = '1KY5e491bkLFqJDp0-UWou3463Mo8AOco'
        dropbox_file_link = 'https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1'
    else:
        # Raising a tuple is itself a TypeError; raise a real exception.
        raise ValueError('Needs to be train, test or validation')

    pickle_file = os.path.join(self.root, 'mini-imagenet-cache-' + mode + '.pkl')
    # NOTE: pickle.load executes arbitrary code from the file; the cache must
    # only ever come from a trusted source.
    try:
        if not self._check_exists() and download:
            print('Downloading mini-ImageNet --', mode)
            download_pkl(google_drive_file_id, self.root, mode)
        with open(pickle_file, 'rb') as f:
            self.data = pickle.load(f)
    except pickle.UnpicklingError:
        # NOTE(review): if _check_exists() only tests for the file's presence,
        # a corrupt first download makes this condition false and the Dropbox
        # retry never runs -- confirm against _check_exists.
        if not self._check_exists() and download:
            print('Download failed. Re-trying mini-ImageNet --', mode)
            download_file(dropbox_file_link, pickle_file)
        with open(pickle_file, 'rb') as f:
            self.data = pickle.load(f)

    # Move the last axis of the 4-D image array to position 1
    # (presumably channels-last -> channels-first; confirm with the cache format).
    self.x = torch.from_numpy(self.data["image_data"]).permute(0, 3, 1, 2).float()
    self.y = np.ones(len(self.x))

    # TODO Remove index_classes from here
    self.class_idx = index_classes(self.data['class_dict'].keys())
    # class_dict maps each class name to the sample indices belonging to it.
    for class_name, idxs in self.data['class_dict'].items():
        for idx in idxs:
            self.y[idx] = self.class_idx[class_name]
def get_pretrained_backbone(model, dataset, spec='default', root='~/data', download=False):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/__init__.py)

    **Description**

    Returns pretrained backbone for a benchmark dataset.

    The returned object is a torch.nn.Module instance.

    **Arguments**

    * **model** (str) - The name of the model (`cnn4`, `resnet12`, or `wrn28`)
    * **dataset** (str) - The name of the benchmark dataset (`mini-imagenet` or `tiered-imagenet`).
    * **spec** (str, *optional*, default='default') - Which weight specification to load (`default`).
    * **root** (str, *optional*, default='~/data') - Location of the pretrained weights.
    * **download** (bool, *optional*, default=False) - Download the pretrained weights if not available?

    **Raises**

    * `KeyError` - if no weight URL is registered for the `dataset` / `model` / `spec` combination.
    * `ValueError` - if `model` has no matching backbone constructor.

    **Example**
    ~~~python
    backbone = l2l.vision.models.get_pretrained_backbone(
        model='resnet12',
        dataset='mini-imagenet',
        root='~/.data',
        download=True,
    )
    ~~~
    """
    root = os.path.expanduser(root)
    destination_dir = os.path.join(root, 'pretrained_models', dataset)
    destination = os.path.join(destination_dir, model + '.pth')
    source = _BACKBONE_URLS[dataset][model][spec]
    if not os.path.exists(destination) and download:
        print(f'Downloading {model} weights for {dataset}.')
        os.makedirs(destination_dir, exist_ok=True)
        download_file(source, destination)

    if model == 'cnn4':
        pretrained = CNN4Backbone(channels=3, max_pool=True)
    elif model == 'resnet12':
        pretrained = ResNet12Backbone(avg_pool=False)
    elif model == 'wrn28':
        pretrained = WRN28Backbone()
    else:
        # Without this branch an unrecognised model would leave `pretrained`
        # unbound and fail later with an opaque UnboundLocalError.
        raise ValueError(f'Unsupported model: {model}')

    # Load on CPU so the weights can be read on machines without a GPU.
    weights = torch.load(destination, map_location='cpu')
    pretrained.load_state_dict(weights)
    return pretrained
def download(self):
    """
    **Description**

    Downloads the VGG Flower102 images archive from `IMAGES_URL` into
    `self.root/DATA_DIR`, extracts it there and deletes the archive, then
    downloads the labels file from `LABELS_URL` into the same directory.

    The target directories are created if they do not exist yet.

    **Raises**

    * `requests.HTTPError` - if the labels request returns an error status.
    """
    data_path = os.path.join(self.root, DATA_DIR)
    # makedirs creates self.root and data_path in one call (including any
    # missing parents) and does not fail if they already exist.
    os.makedirs(data_path, exist_ok=True)

    tar_path = os.path.join(data_path, os.path.basename(IMAGES_URL))
    print('Downloading VGG Flower102 dataset (330Mb)')
    download_file(IMAGES_URL, tar_path)

    # NOTE(review): extractall trusts the member paths stored in the archive;
    # acceptable only as long as IMAGES_URL points at a trusted source.
    with tarfile.open(tar_path) as tar_file:  # closed even if extraction fails
        tar_file.extractall(data_path)
    os.remove(tar_path)

    label_path = os.path.join(data_path, os.path.basename(LABELS_URL))
    req = requests.get(LABELS_URL)
    # Fail loudly on an HTTP error instead of saving the error page as labels.
    req.raise_for_status()
    with open(label_path, 'wb') as label_file:
        label_file.write(req.content)