Exemplo n.º 1
0
 def download(self):
     """Download the FC100 zip archive and extract it into ``self.root``.

     The archive is fetched from Google Drive first. If that file cannot be
     read as a zip (``zipfile.BadZipFile``), it is fetched again from the
     Dropbox mirror. The archive is deleted once it has been extracted.
     """
     archive_path = os.path.join(self.root, 'fc100.zip')

     def extract_and_remove():
         # Close the zip handle (via ``with``) before deleting the archive so
         # the handle is not leaked and os.remove also works on Windows,
         # where a file that is still open cannot be deleted.
         with zipfile.ZipFile(archive_path) as archive_file:
             archive_file.extractall(self.root)
         os.remove(archive_path)

     print('Downloading FC100. (160Mb)')
     try:  # Download from Google Drive first
         download_file_from_google_drive(GOOGLE_DRIVE_FILE_ID, archive_path)
         extract_and_remove()
     except zipfile.BadZipFile:
         # Google Drive did not yield a valid zip (presumably an HTML page
         # instead of the file) -- fall back to the Dropbox mirror.
         download_file(DROPBOX_LINK, archive_path)
         extract_and_remove()
Exemplo n.º 2
0
 def download(self):
     """Download the Describable Textures archive and extract it.

     Creates ``self.root`` and its ``DATA_DIR`` subdirectory if needed,
     downloads ``ARCHIVE_URL`` into that subdirectory, extracts it there and
     deletes the downloaded archive afterwards.
     """
     data_path = os.path.join(self.root, DATA_DIR)
     # Creates self.root as well as data_path (and any missing parents).
     os.makedirs(data_path, exist_ok=True)
     tar_path = os.path.join(data_path, os.path.basename(ARCHIVE_URL))
     print('Downloading Describable Textures dataset (600Mb)')
     download_file(ARCHIVE_URL, tar_path)
     # ``with`` guarantees the archive is closed even if extraction fails.
     # NOTE(review): extractall trusts member paths stored in the archive;
     # consider filter='data' (Python 3.12+) if the source is not trusted.
     with tarfile.open(tar_path) as tar_file:
         tar_file.extractall(data_path)
     os.remove(tar_path)
Exemplo n.º 3
0
    def __init__(self,
                 root,
                 mode='train',
                 transform=None,
                 target_transform=None,
                 download=False):
        """Load the cached mini-ImageNet split ``mode`` from ``root``.

        Arguments:
            root (str): Directory holding ``mini-imagenet-cache-<mode>.pkl``.
                ``~`` is expanded and the directory is created if missing.
            mode (str): One of ``'train'``, ``'validation'`` or ``'test'``.
            transform: Stored as ``self.transform``.
            target_transform: Stored as ``self.target_transform``.
            download (bool): Whether the cache may be downloaded (Google Drive
                first, Dropbox if the result cannot be unpickled).

        Raises:
            ValueError: If ``mode`` is not a supported split.
            pickle.UnpicklingError: If the cache is unreadable and
                ``download`` is False.
        """
        super(MiniImagenet, self).__init__()
        self.root = os.path.expanduser(root)
        # makedirs also creates any missing parent directories.
        os.makedirs(self.root, exist_ok=True)
        self.transform = transform
        self.target_transform = target_transform
        self.mode = mode
        self._bookkeeping_path = os.path.join(
            self.root, 'mini-imagenet-bookkeeping-' + mode + '.pkl')
        if self.mode == 'test':
            google_drive_file_id = '1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD'
            dropbox_file_link = 'https://www.dropbox.com/s/ye9jeb5tyz0x01b/mini-imagenet-cache-test.pkl?dl=1'
        elif self.mode == 'train':
            google_drive_file_id = '1I3itTXpXxGV68olxM5roceUMG8itH9Xj'
            dropbox_file_link = 'https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1'
        elif self.mode == 'validation':
            google_drive_file_id = '1KY5e491bkLFqJDp0-UWou3463Mo8AOco'
            dropbox_file_link = 'https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1'
        else:
            # Must raise an exception instance; raising a tuple is a TypeError.
            raise ValueError('Needs to be train, test or validation')

        pickle_file = os.path.join(self.root,
                                   'mini-imagenet-cache-' + mode + '.pkl')
        # NOTE(review): pickle.load can execute arbitrary code stored in the
        # file; only load caches obtained from trusted sources.
        try:
            if not self._check_exists() and download:
                print('Downloading mini-ImageNet --', mode)
                download_pkl(google_drive_file_id, self.root, mode)
            with open(pickle_file, 'rb') as f:
                self.data = pickle.load(f)
        except pickle.UnpicklingError:
            # A corrupt pickle (e.g. from a failed Google Drive download) is
            # still present on disk, so gating the retry on an existence
            # check would skip it (assumes _check_exists tests for this file
            # -- confirm). Re-download from Dropbox whenever allowed.
            if not download:
                raise
            print('Download failed. Re-trying mini-ImageNet --', mode)
            download_file(dropbox_file_link, pickle_file)
            with open(pickle_file, 'rb') as f:
                self.data = pickle.load(f)

        # Axes are permuted (0, 3, 1, 2); presumably the cache stores images
        # as NHWC and this yields NCHW -- confirm against the cache format.
        self.x = torch.from_numpy(self.data["image_data"]).permute(0, 3, 1,
                                                                   2).float()
        self.y = np.ones(len(self.x))

        # TODO Remove index_classes from here
        self.class_idx = index_classes(self.data['class_dict'].keys())
        # Label every sample listed under a class with that class's entry
        # in class_idx.
        for class_name, idxs in self.data['class_dict'].items():
            for idx in idxs:
                self.y[idx] = self.class_idx[class_name]
Exemplo n.º 4
0
def get_pretrained_backbone(model, dataset, spec='default', root='~/data', download=False):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/__init__.py)

    **Description**

    Returns pretrained backbone for a benchmark dataset.

    The returned object is a torch.nn.Module instance.

    **Arguments**

    * **model** (str) - The name of the model (`cnn4`, `resnet12`, or `wrn28`)
    * **dataset** (str) - The name of the benchmark dataset (`mini-imagenet` or `tiered-imagenet`).
    * **spec** (str, *optional*, default='default') - Which weight specification to load (`default`).
    * **root** (str, *optional*, default='~/data') - Location of the pretrained weights.
    * **download** (bool, *optional*, default=False) - Download the pretrained weights if not available?

    **Raises**

    * **KeyError** - If `dataset`, `model` or `spec` has no entry in the weight URL table.
    * **ValueError** - If `model` has a URL entry but no matching architecture.
    * **FileNotFoundError** - If the weights are not on disk and `download` is False.

    **Example**
    ~~~python
    backbone = l2l.vision.models.get_pretrained_backbone(
        model='resnet12',
        dataset='mini-imagenet',
        root='~/.data',
        download=True,
    )
    ~~~
    """
    root = os.path.expanduser(root)
    destination_dir = os.path.join(root, 'pretrained_models', dataset)
    # NOTE(review): the cache filename ignores `spec`, so different specs of
    # the same model/dataset would share (and reuse) one weight file.
    destination = os.path.join(destination_dir, model + '.pth')
    source = _BACKBONE_URLS[dataset][model][spec]

    # Build the architecture first so an unsupported model name fails with
    # a clear error before any download (it used to leave `pretrained`
    # unbound and crash with UnboundLocalError).
    if model == 'cnn4':
        pretrained = CNN4Backbone(channels=3, max_pool=True)
    elif model == 'resnet12':
        pretrained = ResNet12Backbone(avg_pool=False)
    elif model == 'wrn28':
        pretrained = WRN28Backbone()
    else:
        raise ValueError(
            f'Unsupported model {model!r}; expected cnn4, resnet12 or wrn28.')

    if not os.path.exists(destination):
        if not download:
            raise FileNotFoundError(
                f'Pretrained weights not found at {destination}; '
                'pass download=True to fetch them.')
        print(f'Downloading {model} weights for {dataset}.')
        os.makedirs(destination_dir, exist_ok=True)
        download_file(source, destination)

    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    weights = torch.load(destination, map_location='cpu')
    pretrained.load_state_dict(weights)
    return pretrained
Exemplo n.º 5
0
    def download(self):
        """Download the VGG Flower102 images and labels.

        Creates ``self.root`` and its ``DATA_DIR`` subdirectory if needed,
        downloads the ``IMAGES_URL`` tarball there, extracts it and deletes
        the tarball, then saves the content of ``LABELS_URL`` next to it.

        Raises:
            requests.HTTPError: If the labels request returns an HTTP error.
        """
        data_path = os.path.join(self.root, DATA_DIR)
        # Creates self.root as well as data_path (and any missing parents).
        os.makedirs(data_path, exist_ok=True)
        tar_path = os.path.join(data_path, os.path.basename(IMAGES_URL))
        print('Downloading VGG Flower102 dataset (330Mb)')
        download_file(IMAGES_URL, tar_path)
        # ``with`` guarantees the archive is closed even if extraction fails.
        # NOTE(review): extractall trusts member paths stored in the archive;
        # consider filter='data' (Python 3.12+) if the source is not trusted.
        with tarfile.open(tar_path) as tar_file:
            tar_file.extractall(data_path)
        os.remove(tar_path)

        label_path = os.path.join(data_path, os.path.basename(LABELS_URL))
        req = requests.get(LABELS_URL)
        # Fail loudly rather than saving an HTTP error page as the labels file.
        req.raise_for_status()
        with open(label_path, 'wb') as label_file:
            label_file.write(req.content)