Example #1
def get_torchmeta_rand_fnn_dataloaders(args):
    # get data
    dataset_train = RandFNN(args.data_path, 'train')
    dataset_val = RandFNN(args.data_path, 'val')
    dataset_test = RandFNN(args.data_path, 'test')
    # get meta-sets
    metaset_train = ClassSplitter(dataset_train,
                                  num_train_per_class=args.k_shots,
                                  num_test_per_class=args.k_eval,
                                  shuffle=True)
    metaset_val = ClassSplitter(dataset_val,
                                num_train_per_class=args.k_shots,
                                num_test_per_class=args.k_eval,
                                shuffle=True)
    metaset_test = ClassSplitter(dataset_test,
                                 num_train_per_class=args.k_shots,
                                 num_test_per_class=args.k_eval,
                                 shuffle=True)
    # get meta-dataloader
    meta_train_dataloader = BatchMetaDataLoader(
        metaset_train,
        batch_size=args.meta_batch_size_train,
        num_workers=args.num_workers)
    meta_val_dataloader = BatchMetaDataLoader(
        metaset_val,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    meta_test_dataloader = BatchMetaDataLoader(
        metaset_test,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    return meta_train_dataloader, meta_val_dataloader, meta_test_dataloader
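
The three loaders returned above yield meta-batches as dictionaries with a 'train' (support) and a 'test' (query) split. A minimal consumption sketch, assuming `args` is a namespace carrying the fields the helper reads:

meta_train_dataloader, meta_val_dataloader, meta_test_dataloader = \
    get_torchmeta_rand_fnn_dataloaders(args)

for batch in meta_train_dataloader:
    # Each entry stacks `args.meta_batch_size_train` tasks along dimension 0.
    support_inputs, support_targets = batch['train']
    query_inputs, query_targets = batch['test']
    break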
Example #2
def helper_with_default(klass,
                        folder,
                        shots,
                        ways,
                        shuffle=True,
                        test_shots=None,
                        seed=None,
                        defaults={},
                        **kwargs):

    if 'num_classes_per_task' in kwargs:
        warnings.warn(
            'Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.',
            stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = defaults.get('transform', ToTensor())
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = defaults.get('target_transform',
                                                  Categorical(ways))
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = defaults.get('class_augmentations',
                                                     None)
    if test_shots is None:
        test_shots = shots
    dataset = klass(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset,
                            shuffle=shuffle,
                            num_train_per_class=shots,
                            num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
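
A hedged usage sketch of the helper above, assuming torchmeta's Omniglot meta-dataset is passed as `klass` and fetched under ./data on first use:

from torchmeta.datasets import Omniglot
from torchmeta.utils.data import BatchMetaDataLoader

dataset = helper_with_default(Omniglot, './data', shots=1, ways=5,
                              test_shots=15, seed=0,
                              meta_train=True, download=True)
loader = BatchMetaDataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)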
Example #3
 def __init__(self, device, problem="default", task_num=16, n_way=5, imgsz=28, k_spt=1, k_qry=19):
     self.device = device
     self.task_num = task_num
     self.n_way, self.imgsz = n_way, imgsz
     self.k_spt, self.k_qry = k_spt, k_qry
     assert k_spt + k_qry <= 20, "k_spt + k_qry must be at most 20"
     class_augmentations = [Rotation([90, 180, 270])]
     meta_train_dataset = Omniglot("data",
                                   transform=Compose([Resize(self.imgsz), ToTensor()]),
                                   target_transform=Categorical(num_classes=self.n_way),
                                   num_classes_per_task=self.n_way,
                                   meta_train=True,
                                   class_augmentations=class_augmentations,
                                   download=True
                                 )
     meta_val_dataset = Omniglot("data",
                                   transform=Compose([Resize(self.imgsz), ToTensor()]),
                                   target_transform=Categorical(num_classes=self.n_way),
                                   num_classes_per_task=self.n_way,
                                   meta_val=True,
                                   class_augmentations=class_augmentations,
                                 )
     meta_test_dataset = Omniglot("data",
                                   transform=Compose([Resize(self.imgsz), ToTensor()]),
                                   target_transform=Categorical(num_classes=self.n_way),
                                   num_classes_per_task=self.n_way,
                                   meta_test=True,
                                   class_augmentations=class_augmentations,
                                 )
     self.train_dataset = ClassSplitter(meta_train_dataset, shuffle=True, num_train_per_class=k_spt, num_test_per_class=k_qry)
     self.val_dataset = ClassSplitter(meta_val_dataset, shuffle=True, num_train_per_class=k_spt, num_test_per_class=k_qry)
     self.test_dataset = ClassSplitter(meta_test_dataset, shuffle=True, num_train_per_class=k_spt, num_test_per_class=k_qry)
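
A minimal consumption sketch for the splits built above, assuming `tasks` is an instance of this class (its name is not shown in the snippet):

from torchmeta.utils.data import BatchMetaDataLoader

train_loader = BatchMetaDataLoader(tasks.train_dataset,
                                   batch_size=tasks.task_num,
                                   shuffle=True,
                                   num_workers=2)
batch = next(iter(train_loader))
support_x, support_y = batch['train']  # (task_num, n_way * k_spt, 1, imgsz, imgsz)
query_x, query_y = batch['test']       # (task_num, n_way * k_qry, 1, imgsz, imgsz)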
Example #4
def get_fewshotsen12msdataset(folder, shots, ways, shuffle=True, test_shots=None,
            seed=None, **kwargs):
    if test_shots is None:
        test_shots = shots

    dataset = Sen12MS(folder, num_classes_per_task=ways, min_samples_per_class=shots + test_shots,
                      min_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
                            num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
Example #5
def _update_args(shots, ways, kwargs, shuffle=True, test_shots=None):
    if 'num_classes_per_task' in kwargs:
        assert ways == kwargs['num_classes_per_task']
        del kwargs['num_classes_per_task']
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]

    if isinstance(shots, int):
        min_shot = max_shot = shots
    else:
        min_shot, max_shot = shots
    if test_shots is None:
        test_shots = min_shot
    if 'dataset_transform' not in kwargs:
        if min_shot == max_shot:
            dataset_transform = ClassSplitter(shuffle=shuffle,
                                            num_train_per_class=min_shot,
                                            num_test_per_class=test_shots)
        else:
            dataset_transform = RandClassSplitter(shuffle=shuffle, 
                                                    min_train_per_class=min_shot,
                                                    max_train_per_class=max_shot,
                                                    num_test_per_class=test_shots)
        kwargs['dataset_transform'] = dataset_transform
    return kwargs
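
For reference, a quick sketch of what `_update_args` fills in for a fixed-shot configuration (the values are illustrative):

kwargs = _update_args(shots=5, ways=20, kwargs={}, test_shots=15)
# kwargs now contains:
#   'target_transform'    -> Categorical(20)
#   'class_augmentations' -> [Rotation([90, 180, 270])]
#   'dataset_transform'   -> ClassSplitter(num_train_per_class=5, num_test_per_class=15)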
Example #6
def cub(folder, shots, ways, shuffle=True, test_shots=None,
        seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Caltech-UCSD Birds dataset.

    Parameters
    ----------
    folder : string
        Root directory where the dataset folder `cub` exists.

    shots : int
        Number of (training) examples per class in each task. This corresponds 
        to `k` in `k-shot` classification.

    ways : int
        Number of classes per task. This corresponds to `N` in `N-way` 
        classification.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the 
        number of test examples is equal to the number of training examples per 
        class.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `CUB` class.

    See also
    --------
    `datasets.cub.CUB` : Meta-dataset for the Caltech-UCSD Birds dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        image_size = 84
        kwargs['transform'] = Compose([
            Resize(int(image_size * 1.5)),
            CenterCrop(image_size),
            ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = CUB(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
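
A hedged usage sketch for the `cub` helper, assuming the CUB images are available under ./data (or can be fetched with `download=True`):

from torchmeta.utils.data import BatchMetaDataLoader

dataset = cub('./data', shots=5, ways=5, test_shots=15, seed=0,
              meta_train=True, download=True)
loader = BatchMetaDataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)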
Example #7
def doublemnist(folder, shots, ways, shuffle=True, test_shots=None,
                seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Double MNIST dataset.

    Parameters
    ----------
    folder : string
        Root directory where the dataset folder `doublemnist` exists.

    shots : int
        Number of (training) examples per class in each task. This corresponds 
        to `k` in `k-shot` classification.

    ways : int
        Number of classes per task. This corresponds to `N` in `N-way` 
        classification.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the 
        number of test examples is equal to the number of training examples per 
        class.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `DoubleMNIST` class.

    See also
    --------
    `datasets.doublemnist.DoubleMNIST` : Meta-dataset for the Double MNIST dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        test_shots = shots

    dataset = DoubleMNIST(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
Example #8
def harmonic(shots, shuffle=True, test_shots=None, seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Harmonic toy dataset.

    Parameters
    ----------
    shots : int
        Number of (training) examples in each task. This corresponds to `k` in
        `k-shot` classification.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples in each task. If `None`, then the number of test
        examples is equal to the number of training examples in each task.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `Harmonic` class.

    See also
    --------
    `torchmeta.toy.Harmonic` : Meta-dataset for the Harmonic toy dataset.
    """
    if 'num_samples_per_task' in kwargs:
        warnings.warn(
            'Both arguments `shots` and `num_samples_per_task` were '
            'set in the helper function for the number of samples in each task. '
            'Ignoring the argument `shots`.',
            stacklevel=2)
        if test_shots is not None:
            shots = kwargs['num_samples_per_task'] - test_shots
            if shots <= 0:
                raise ValueError(
                    'The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                        test_shots, kwargs['num_samples_per_task']))
        else:
            shots = kwargs['num_samples_per_task'] // 2
    if test_shots is None:
        test_shots = shots

    dataset = Harmonic(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset,
                            shuffle=shuffle,
                            num_train_per_class=shots,
                            num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
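
A small sketch of turning the `harmonic` helper into an episodic regression loader; `num_tasks` is forwarded through `**kwargs` to the `Harmonic` class and its value here is illustrative:

from torchmeta.utils.data import BatchMetaDataLoader

dataset = harmonic(shots=10, test_shots=15, seed=0, num_tasks=1000)
loader = BatchMetaDataLoader(dataset, batch_size=16, shuffle=True)
batch = next(iter(loader))
support = batch['train']  # support inputs/targets for 16 regression tasks
query = batch['test']     # query inputs/targets for the same tasks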
Example #9
def create_og_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    input_size,
    n_query,
    batch_size,
    num_workers,
    download=False,
    use_vinyals_split=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for Omniglot

    Args:
        root: Path to Omniglot data root folder (containing an 'omniglot'` subfolder with the
            preprocess json-Files or downloaded zip-files).
        meta_split: see torchmeta.datasets.Omniglot
        k_way: Number of classes per task
        n_shot: Number of samples per class
        input_size: Images are resized to this size.
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to be done on the
            downloaded files).
        use_vinyals_split: see torchmeta.datasets.Omniglot
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = Omniglot(
        root,
        num_classes_per_task=k_way,
        transform=Compose([Resize(input_size), ToTensor()]),
        target_transform=Categorical(num_classes=k_way),
        class_augmentations=[Rotation([90, 180, 270])],
        meta_split=meta_split,
        download=download,
        use_vinyals_split=use_vinyals_split,
    )
    dataset = ClassSplitter(dataset,
                            shuffle=True,
                            num_train_per_class=n_shot,
                            num_test_per_class=n_query)
    dataset.seed(seed)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader
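
A hedged call sketch for the helper above (paths and sizes are illustrative; `download=True` fetches Omniglot on first use):

train_loader = create_og_data_loader(
    root='./data',
    meta_split='train',
    k_way=5,
    n_shot=1,
    input_size=28,
    n_query=15,
    batch_size=16,
    num_workers=2,
    download=True,
)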
Example #10
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None,
                          random_seed=123,
                          num_training_samples=100):
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
    if name == 'quickdraw':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = QuickDraw(
            folder,
            transform=transform,
            target_transform=Categorical(num_ways),
            num_classes_per_task=num_ways,
            meta_train=True,
            dataset_transform=dataset_transform,
            download=True,
            random_seed=random_seed,
            num_training_samples=num_training_samples)
        meta_val_dataset = QuickDraw(folder,
                                     transform=transform,
                                     target_transform=Categorical(num_ways),
                                     num_classes_per_task=num_ways,
                                     meta_val=True,
                                     dataset_transform=dataset_transform)
        meta_test_dataset = QuickDraw(folder,
                                      transform=transform,
                                      target_transform=Categorical(num_ways),
                                      num_classes_per_task=num_ways,
                                      meta_test=True,
                                      dataset_transform=dataset_transform)

        model = ModelConvQuickDraw(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
Example #11
def get_sine_loader(batch_size, num_steps, shots=10, test_shots=15):
    dataset_transform = ClassSplitter(
        shuffle=True, num_train_per_class=shots, num_test_per_class=test_shots
    )
    transform = ToTensor1D()
    dataset = Sinusoid(
        shots + test_shots,
        num_tasks=batch_size * num_steps,
        transform=transform,
        target_transform=transform,
        dataset_transform=dataset_transform,
    )
    loader = BatchMetaDataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
    )
    return loader
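
A sketch of consuming the sine loader above in a MAML-style outer loop (one meta-batch per outer step; the adaptation code is elided):

loader = get_sine_loader(batch_size=16, num_steps=100)
for batch in loader:
    support_x, support_y = batch['train']  # (16, shots, 1)
    query_x, query_y = batch['test']       # (16, test_shots, 1)
    # inner-loop adaptation on the support set, outer loss on the query set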
Example #12
    def __init__(self, num_shots_tr, num_shots_ts, video_data, label_data,
                 option='train', unsupervised=0, frame_depth=10):

        self.transform = transforms.Compose([transforms.ToTensor()])
        self.num_samples_per_task = num_shots_tr + num_shots_ts
        self.frame_depth = frame_depth
        self.option = option
        self.num_shots_tr = num_shots_tr
        self.num_shots_ts = num_shots_ts
        self.unsupervised = unsupervised
        self.dataset_transform = ClassSplitter(shuffle=False, num_train_per_class=num_shots_tr,
                                                   num_test_per_class=num_shots_ts)

        self.video_data = video_data
        self.label = label_data
Example #13
    def helper_meta_imagelist(self, batch_size, input_size, num_shots,
                              num_shots_test, num_ways, num_workers):
        with tempfile.TemporaryDirectory(
        ) as folder, tempfile.NamedTemporaryFile(mode='w+t') as fp:
            create_random_imagelist(folder, fp, input_size, create_image)
            dataset = data.ImagelistMetaDataset(
                imagelistname=fp.name,
                root='',
                transform=transforms.Compose(
                    [transforms.Resize(input_size),
                     transforms.ToTensor()]))
            meta_dataset = CombinationMetaDataset(
                dataset,
                num_classes_per_task=num_ways,
                target_transform=Categorical(num_ways),
                dataset_transform=ClassSplitter(
                    shuffle=True,
                    num_train_per_class=num_shots,
                    num_test_per_class=num_shots_test))
            meta_dataloader = BatchMetaDataLoader(meta_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  pin_memory=True)

            for batch in meta_dataloader:
                batch_data, batch_label = batch['train']
                for img_train, label_train, img_test, label_test in zip(
                        *batch['train'], *batch['test']):
                    classmap = {}
                    for idx in range(img_train.shape[0]):
                        npimg = img_train[idx, ...].detach().numpy()
                        npimg[npimg < 0.001] = 0
                        imagesum = int(np.sum(npimg) * 255)
                        if imagesum not in classmap:
                            classmap[imagesum] = int(label_train[idx])
                        self.assertEqual(classmap[imagesum],
                                         int(label_train[idx]))

                    for idx in range(img_test.shape[0]):
                        npimg = img_test[idx, ...].detach().numpy()
                        npimg[npimg < 0.001] = 0
                        imagesum = int(np.sum(npimg) * 255)
                        self.assertEqual(classmap[imagesum],
                                         int(label_test[idx]),
                                         "Error on {}".format(idx))
Example #14
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        vi = []
        la = []

        data_len = len(self.label[index]) // 2  # support set 1 + query set 1
        for i in range(2):
            vi.append(self.video_data[index][data_len * i:data_len * (i + 1)])
            la.append(self.label[index][data_len * i:data_len * (i + 1)])


        self.dataset_transform = ClassSplitter(shuffle=False, num_train_per_class=self.num_shots_tr,
                                               num_test_per_class=self.num_shots_ts)

        task = PersonTask(vi, la, len(vi))
        if self.dataset_transform is not None:
            task = self.dataset_transform(task)

        return task
Example #15
def test_batch_meta_dataloader_splitter():
    dataset = Sinusoid(20, num_tasks=1000, noise_std=None)
    dataset = ClassSplitter(dataset, num_train_per_class=5,
        num_test_per_class=15)
    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)

    batch = next(iter(meta_dataloader))
    assert isinstance(batch, dict)
    assert 'train' in batch
    assert 'test' in batch

    train_inputs, train_targets = batch['train']
    test_inputs, test_targets = batch['test']
    assert isinstance(train_inputs, torch.Tensor)
    assert isinstance(train_targets, torch.Tensor)
    assert train_inputs.shape == (4, 5, 1)
    assert train_targets.shape == (4, 5, 1)
    assert isinstance(test_inputs, torch.Tensor)
    assert isinstance(test_targets, torch.Tensor)
    assert test_inputs.shape == (4, 15, 1)
    assert test_targets.shape == (4, 15, 1)
Example #16
 def test_custom_db(self):
     pil_logger = logging.getLogger('PIL')
     pil_logger.setLevel(logging.INFO)
     input_size = 40
     num_ways = 10
     num_shots = 4
     num_shots_test = 4
     batch_size = 1
     num_workers = 0
     with tempfile.TemporaryDirectory(
     ) as folder, tempfile.NamedTemporaryFile(mode='w+t') as fp:
         tests.datatests.create_random_imagelist(folder, fp, input_size)
         dataset = data.ImagelistMetaDataset(imagelistname=fp.name,
                                             root='',
                                             transform=transforms.Compose([
                                                 transforms.Resize(84),
                                                 transforms.ToTensor()
                                             ]))
         meta_dataset = CombinationMetaDataset(
             dataset,
             num_classes_per_task=num_ways,
             target_transform=Categorical(num_ways),
             dataset_transform=ClassSplitter(
                 shuffle=True,
                 num_train_per_class=num_shots,
                 num_test_per_class=num_shots_test))
         args = ArgWrapper()
         args.output_folder = folder
         args.dataset = None
         benchmark = Benchmark(meta_train_dataset=meta_dataset,
                               meta_val_dataset=meta_dataset,
                               meta_test_dataset=meta_dataset,
                               model=ModelConvMiniImagenet(
                                   args.num_ways,
                                   hidden_size=args.hidden_size),
                               loss_function=F.cross_entropy)
         train.main(args, benchmark)
Example #17

num_shots = 100
num_shots_test = 100
batch_size = 30
num_train_per_class = 200
num_test_per_class = 200

transform = Compose([
    CenterCrop((178, 178)),
    Resize((32, 32)),
    ToTensor(),
    Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
])
dataset_transform = ClassSplitter(shuffle=True,
                                  num_train_per_class=num_shots,
                                  num_test_per_class=num_shots_test)

image_resolution = (32, 32)
dataset = CelebA(num_samples_per_task=1000,
                 transform=transform,
                 target_transform=transform)
split_dataset = ClassSplitter(dataset,
                              num_train_per_class=num_train_per_class,
                              num_test_per_class=num_test_per_class)
meta_dataloader = BatchMetaDataLoader(split_dataset, batch_size=batch_size)

model = meta_modules.MAMLEncoder(in_features=dataset.img_channels + 2,
                                 out_features=dataset.img_channels,
                                 image_resolution=image_resolution,
                                 hypo_net_nl=opt.hypo_net_nl)
Example #18
def get_dataset(args, dataset_name, phase):
    if dataset_name == 'cub':
        from torchmeta.datasets import CUBMM as dataset_class
        image_size = 84
        padding_len = 8
    elif dataset_name == 'sun':
        from torchmeta.datasets import SUNMM as dataset_class
        image_size = 84
        padding_len = 8
    else:
        raise ValueError('Non-supported Dataset.')

    # augmentations
    # reference: https://github.com/Sha-Lab/FEAT
    if args.augment and phase == 'train':
        transforms_list = [
            transforms.RandomResizedCrop((image_size, image_size)),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    else:
        transforms_list = [
            transforms.Resize((image_size+padding_len, image_size+padding_len)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]


    # pre-processing 
    if args.backbone == 'resnet12':
        transforms_list = transforms.Compose(
                transforms_list + [
                transforms.Normalize(np.array([x / 255.0 for x in [120.39586422,  115.59361427, 104.54012653]]),
                                     np.array([x / 255.0 for x in [70.68188272,   68.27635443,  72.54505529]]))
            ])

    else:
        transforms_list = transforms.Compose(
                transforms_list + [
                transforms.Normalize(np.array([0.485, 0.456, 0.406]),
                                     np.array([0.229, 0.224, 0.225]))
            ])


    # get datasets
    dataset = dataset_class(
        root=args.data_folder,
        num_classes_per_task=args.num_ways,
        meta_split=phase,
        transform=transforms_list,
        target_transform=Categorical(num_classes=args.num_ways),
        download=args.download
    )

    dataset = ClassSplitter(dataset, 
        shuffle=(phase == 'train'),
        num_train_per_class=args.num_shots, 
        num_test_per_class=args.test_shots
    )

    return dataset
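
A hedged usage sketch, assuming `args` provides the fields read above (data_folder, num_ways, num_shots, test_shots, backbone, augment, download) and that the CUBMM/SUNMM variants are importable:

from torchmeta.utils.data import BatchMetaDataLoader

train_dataset = get_dataset(args, dataset_name='cub', phase='train')
train_loader = BatchMetaDataLoader(train_dataset,
                                   batch_size=4,
                                   shuffle=True,
                                   num_workers=2)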
Example #19
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)
    class_augmentations = [Rotation([90, 180, 270])]
    if args.dataset == 'sinusoid':
        transform = ToTensor()

        meta_train_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif args.dataset == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(args.folder,
                                      transform=transform,
                                      target_transform=Categorical(
                                          args.num_ways),
                                      num_classes_per_task=args.num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(args.folder,
                                    transform=transform,
                                    target_transform=Categorical(
                                        args.num_ways),
                                    num_classes_per_task=args.num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)

        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    elif args.dataset == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_train=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_val=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(args.num_ways,
                                      hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=loss_function,
        device=device)

    best_val_accuracy = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))

        if (best_val_accuracy is None) \
                or (best_val_accuracy < results['accuracies_after']):
            best_val_accuracy = results['accuracies_after']
            if args.output_folder is not None:
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()
Example #20
    def generate_batch(self, test):
        '''
        The data loaders of torchmeta are fully compatible with standard
        PyTorch data components, such as Dataset and DataLoader.
        Augments the pool of class candidates with variants, such as rotated
        images.
        '''
        if test:
            meta_train = False
            meta_test = True
            f = "metatest"
        else:
            meta_train = True
            meta_test = False
            f = "metatrain"

        if self.dataset == "miniImageNet":
            dataset = MiniImagenet(
                f,
                # Number of ways
                num_classes_per_task=self.N,
                # Resize the images and converts them
                # to PyTorch tensors (from Torchvision)
                transform=Compose([Resize(84), ToTensor()]),
                # Transform the labels to integers
                target_transform=Categorical(num_classes=self.N),
                # Creates new virtual classes with rotated versions
                # of the images (from Santoro et al., 2016)
                class_augmentations=[Rotation([90, 180, 270])],
                meta_train=meta_train,
                meta_test=meta_test,
                download=True)

        if self.dataset == "tieredImageNet":
            dataset = TieredImagenet(
                f,
                # Number of ways
                num_classes_per_task=self.N,
                # Resize the images and converts them
                # to PyTorch tensors (from Torchvision)
                transform=Compose([Resize(84), ToTensor()]),
                # Transform the labels to integers
                target_transform=Categorical(num_classes=self.N),
                # Creates new virtual classes with rotated versions
                # of the images (from Santoro et al., 2016)
                class_augmentations=[Rotation([90, 180, 270])],
                meta_train=meta_train,
                meta_test=meta_test,
                download=True)

        if self.dataset == "CIFARFS":
            dataset = CIFARFS(
                f,
                # Number of ways
                num_classes_per_task=self.N,
                # Resize the images and converts them
                # to PyTorch tensors (from Torchvision)
                transform=Compose([Resize(32), ToTensor()]),
                # Transform the labels to integers
                target_transform=Categorical(num_classes=self.N),
                # Creates new virtual classes with rotated versions
                # of the images (from Santoro et al., 2016)
                class_augmentations=[Rotation([90, 180, 270])],
                meta_train=meta_train,
                meta_test=meta_test,
                download=True)

        if self.dataset == "FC100":
            dataset = FC100(
                f,
                # Number of ways
                num_classes_per_task=self.N,
                # Resize the images and converts them
                # to PyTorch tensors (from Torchvision)
                transform=Compose([Resize(32), ToTensor()]),
                # Transform the labels to integers
                target_transform=Categorical(num_classes=self.N),
                # Creates new virtual classes with rotated versions
                # of the images (from Santoro et al., 2016)
                class_augmentations=[Rotation([90, 180, 270])],
                meta_train=meta_train,
                meta_test=meta_test,
                download=True)

        if self.dataset == "Omniglot":
            dataset = Omniglot(
                f,
                # Number of ways
                num_classes_per_task=self.N,
                # Resize the images and converts them
                # to PyTorch tensors (from Torchvision)
                transform=Compose([Resize(28), ToTensor()]),
                # Transform the labels to integers
                target_transform=Categorical(num_classes=self.N),
                # Creates new virtual classes with rotated versions
                # of the images (from Santoro et al., 2016)
                class_augmentations=[Rotation([90, 180, 270])],
                meta_train=meta_train,
                meta_test=meta_test,
                download=True)

        dataset = ClassSplitter(dataset,
                                shuffle=True,
                                num_train_per_class=self.K,
                                num_test_per_class=self.num_test_per_class)

        dataloader = BatchMetaDataLoader(dataset,
                                         batch_size=self.batch_size,
                                         num_workers=2)
        return dataloader
Example #21
def main(args):
    wandb.init(project="gbml")

    DataClass = None
    normalizer = None
    if args.dataset == 'omniglot':
        DataClass = Omniglot
        args.in_channels = 1
        normalizer = lambda x: x
    elif args.dataset == 'miniimagenet':
        DataClass = MiniImagenet
        args.in_channels = 3
        normalizer = transforms.Normalize(np.array([0.485, 0.456, 0.406]),
                                          np.array([0.229, 0.224, 0.225]))

    if args.alg == 'MAML':
        model = MAML(args)
    elif args.alg == 'Reptile':
        model = Reptile(args)
    elif args.alg == 'Neumann':
        model = Neumann(args)
    elif args.alg == 'CAVIA':
        model = CAVIA(args)
    elif args.alg == 'iMAML':
        model = iMAML(args)
    else:
        raise ValueError('Not implemented Meta-Learning Algorithm')
    if args.load:
        model.load()
    elif args.load_encoder:
        model.load_encoder()

    wandb.config.update(args)

    train_dataset = DataClass(
        args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='train',
        transform=transforms.Compose([
            transforms.RandomCrop(80, padding=8),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalizer
        ]),
        target_transform=Categorical(num_classes=args.num_way),
        download=True)
    train_dataset = ClassSplitter(train_dataset,
                                  shuffle=True,
                                  num_train_per_class=args.num_shot,
                                  num_test_per_class=args.num_query)
    train_loader = BatchMetaDataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=args.num_workers)

    valid_dataset = DataClass(
        args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='val',
        transform=transforms.Compose(
            [transforms.CenterCrop(80),
             transforms.ToTensor(), normalizer]),
        target_transform=Categorical(num_classes=args.num_way))
    valid_dataset = ClassSplitter(valid_dataset,
                                  shuffle=True,
                                  num_train_per_class=args.num_shot,
                                  num_test_per_class=args.num_query)
    valid_loader = BatchMetaDataLoader(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=args.num_workers)

    test_dataset = DataClass(
        args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='test',
        transform=transforms.Compose(
            [transforms.CenterCrop(80),
             transforms.ToTensor(), normalizer]),
        target_transform=Categorical(num_classes=args.num_way))
    test_dataset = ClassSplitter(test_dataset,
                                 shuffle=True,
                                 num_train_per_class=args.num_shot,
                                 num_test_per_class=args.num_query)
    test_loader = BatchMetaDataLoader(test_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      pin_memory=True,
                                      num_workers=args.num_workers)

    for epoch in range(args.num_epoch):

        res, is_best = run_epoch(epoch, args, model, train_loader,
                                 valid_loader, test_loader)
        dict2tsv(res, os.path.join(args.result_path, args.alg, args.log_path))

        if is_best:
            model.save()
        torch.cuda.empty_cache()

        if args.lr_sched:
            model.lr_sched()

    return None
Example #22
def main(args):

    if args.alg=='MAML':
        model = MAML(args)
    elif args.alg=='Reptile':
        model = Reptile(args)
    elif args.alg=='Neumann':
        model = Neumann(args)
    elif args.alg=='CAVIA':
        model = CAVIA(args)
    elif args.alg=='iMAML':
        model = iMAML(args)
    elif args.alg=='FOMAML':
        model = FOMAML(args)
    else:
        raise ValueError('Not implemented Meta-Learning Algorithm')

    if args.load:
        model.load()
    elif args.load_encoder:
        model.load_encoder()

    train_dataset = Omniglot(args.data_path, num_classes_per_task=args.num_way,
                        meta_split='train', 
                        transform=transforms.Compose([
                        transforms.RandomCrop(80, padding=8),
                        transforms.ToTensor(),
                        ]),
                        target_transform=Categorical(num_classes=args.num_way)
                        )
                        

    train_dataset = ClassSplitter(train_dataset, shuffle=True, num_train_per_class=args.num_shot, num_test_per_class=args.num_query)
    train_loader = BatchMetaDataLoader(train_dataset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=args.num_workers)

    valid_dataset = Omniglot(args.data_path, num_classes_per_task=args.num_way,
                        meta_split='val', 
                        transform=transforms.Compose([
                        transforms.CenterCrop(80),
                        transforms.ToTensor(),
                        ]),
                        target_transform=Categorical(num_classes=args.num_way)
                        )

    valid_dataset = ClassSplitter(valid_dataset, shuffle=True, num_train_per_class=args.num_shot, num_test_per_class=args.num_query)
    valid_loader = BatchMetaDataLoader(valid_dataset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=args.num_workers)

    test_dataset = Omniglot(args.data_path, num_classes_per_task=args.num_way,
                        meta_split='test', 
                        transform=transforms.Compose([
                        transforms.CenterCrop(80),
                        transforms.ToTensor(),
                        ]),
                        target_transform=Categorical(num_classes=args.num_way)
                        )

    test_dataset = ClassSplitter(test_dataset, shuffle=True, num_train_per_class=args.num_shot, num_test_per_class=args.num_query)
    test_loader = BatchMetaDataLoader(test_dataset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=args.num_workers)

    for epoch in range(args.num_epoch):

        res, is_best = run_epoch(epoch, args, model, train_loader, valid_loader, test_loader)

        filename = os.path.join(args.result_path, args.alg, 'omniglot_' '{0}shot_{1}way'.format(args.num_shot, args.num_way)+args.log_path)
        dict2tsv(res, filename)

        if is_best:
            model.save('omniglot_' '{0}shot_{1}way'.format(args.num_shot, args.num_way))
        torch.cuda.empty_cache()

        if args.lr_sched:
            model.lr_sched()

    return None
Example #23
def train():
    transform = transforms.Compose(
        [transforms.Resize(84), transforms.ToTensor()])
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=5,
                                      num_test_per_class=5)
    dataset = MiniImagenet('',
                           transform=transform,
                           num_classes_per_task=5,
                           target_transform=Categorical(num_classes=5),
                           meta_split="train",
                           dataset_transform=dataset_transform)

    dataloader = BatchMetaDataLoader(dataset, batch_size=1, shuffle=True)

    model = ModelConvMiniImagenet(5)
    model.to(device='cuda')
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    accuracy_l = list()

    with tqdm(dataloader, total=1000) as pbar:
        for batch_idx, batch in enumerate(pbar):

            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device='cuda')
            train_targets = train_targets.to(device='cuda')

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device='cuda')
            test_targets = test_targets.to(device='cuda')

            outer_loss = torch.tensor(0., device='cuda')
            accuracy = torch.tensor(0., device='cuda')
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model, inner_loss)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)
            outer_loss.div_(1)
            accuracy.div_(1)

            outer_loss.backward()
            meta_optimizer.step()
            accuracy_l.append(accuracy.item())
            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if (batch_idx >= 1000):
                break

    plt.plot(accuracy_l)
    plt.show()
Example #24
def helper_with_default_tabular(klass, folder, shots, ways, shuffle=True,
                                test_shots=None, seed=None, defaults=None, **kwargs):
    """
    Parameters
    ----------
    klass : CombinationMetaDataset
        the class corresponding to the meta-dataset, e.g., Covertype

    folder : string
        Root directory where the dataset folder exists, e.g., `covertype_task_id_2118`.

    shots : int
        Number of (training) examples per class in each task. This corresponds
        to `k` in `k-shot` classification.

    ways : int
        Number of classes per task. This corresponds to `N` in `N-way`
        classification.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the
        number of test examples is equal to the number of training examples per
        class.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `klass` meta-dataset class.

    Returns
    -------
    klass
        The meta-dataset with ClassSplitter applied, e.g., Covertype.
    """

    if defaults is None:
        defaults = {}

    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']

    if 'transform' not in kwargs:
        kwargs['transform'] = defaults.get('transform', NumpyToTorch())

    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = defaults.get('target_transform',
                                                  Categorical(ways))
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = defaults.get('class_augmentations', None)

    if test_shots is None:
        test_shots = shots
    dataset = klass(folder,
                    num_classes_per_task=ways,
                    **kwargs)
    dataset = ClassSplitter(dataset,
                            shuffle=shuffle,
                            num_train_per_class=shots,
                            num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
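
A hedged call sketch for the tabular helper; `Covertype` is the meta-dataset class named in the docstring (its import path depends on the torchmeta fork providing the tabular benchmarks), and the folder name and split keyword are illustrative:

from torchmeta.utils.data import BatchMetaDataLoader

dataset = helper_with_default_tabular(Covertype, 'covertype_task_id_2118',
                                      shots=5, ways=5, test_shots=15, seed=0,
                                      meta_train=True)
loader = BatchMetaDataLoader(dataset, batch_size=8, shuffle=True)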
Example #25
    transform = None
    target_transform = None
    nclasses = 5
    ntasks = 3
    samples_per_class = 2
    samples_per_test = 2
    classes_meta = np.arange(100, dtype='int')

    size = [2, 32//ds, 32//ds]


    if transform is None:
        transform = Compose([
            CropDims(low_crop=[0,0], high_crop=[32,32], dims=[2,3]),
            Downsample(factor=[dt,1,ds,ds]),
            ToEventSum(T = chunk_size, size = size),
            ToTensor()])

    if target_transform is None:
        target_transform = Compose([Repeat(chunk_size), toOneHot(nclasses)])

    cc = DoubleNMNIST(root = root, meta_val=True, transform = transform, target_transform = target_transform, chunk_size=chunk_size,  num_classes_per_task=5)
    cd = ClassNMNISTDataset(root,meta_train=True, transform = transform, target_transform = target_transform, chunk_size=chunk_size)

    it = BatchMetaDataLoader(ClassSplitter(cc, shuffle=True, num_train_per_class=3, num_test_per_class=3), batch_size=16, num_workers=0)

    from torchmeta.datasets.doublemnist import DoubleMNISTClassDataset
    dataset = DoubleMNISTClassDataset("data/",meta_train=True)
    dataset_h = BatchMetaDataLoader(doublemnist("data/",meta_train=True, ways=5, shots=10), batch_size=16, num_workers=0)

Example #26
def main(args):
    with open(args.config, 'r') as f:
        config = json.load(f)

    if args.folder is not None:
        config['folder'] = args.folder
    if args.num_steps > 0:
        config['num_steps'] = args.num_steps
    if args.num_batches > 0:
        config['num_batches'] = args.num_batches
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    dataset_transform = ClassSplitter(
        shuffle=True,
        num_train_per_class=config['num_shots'],
        num_test_per_class=config['num_shots_test'])
    if config['dataset'] == 'sinusoid':
        transform = ToTensor()
        meta_test_dataset = Sinusoid(config['num_shots'] +
                                     config['num_shots_test'],
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)
        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif config['dataset'] == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])
        meta_test_dataset = Omniglot(config['folder'],
                                     transform=transform,
                                     target_transform=Categorical(
                                         config['num_ways']),
                                     num_classes_per_task=config['num_ways'],
                                     meta_train=True,
                                     dataset_transform=dataset_transform,
                                     download=True)
        model = ModelConvOmniglot(config['num_ways'],
                                  hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    elif config['dataset'] == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])
        meta_test_dataset = MiniImagenet(
            config['folder'],
            transform=transform,
            target_transform=Categorical(config['num_ways']),
            num_classes_per_task=config['num_ways'],
            meta_train=True,
            dataset_transform=dataset_transform,
            download=True)
        model = ModelConvMiniImagenet(config['num_ways'],
                                      hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            config['dataset']))

    with open(config['model_path'], 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))

    meta_test_dataloader = BatchMetaDataLoader(meta_test_dataset,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    metalearner = ModelAgnosticMetaLearning(
        model,
        first_order=config['first_order'],
        num_adaptation_steps=config['num_steps'],
        step_size=config['step_size'],
        loss_function=loss_function,
        device=device)

    results = metalearner.evaluate(meta_test_dataloader,
                                   max_batches=config['num_batches'],
                                   verbose=args.verbose,
                                   desc='Test')

    # Save results
    dirname = os.path.dirname(config['model_path'])
    with open(os.path.join(dirname, 'results.json'), 'w') as f:
        json.dump(results, f)
Example #27
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None,
                          meta_batch_size=1,
                          ensemble_size=0
                          ):
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
    if name == 'sinusoid':
        transform = ToTensor1D()

        meta_train_dataset = Sinusoid(num_shots + num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(num_shots + num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Sinusoid(num_shots + num_shots_test,
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40], meta_batch_size=meta_batch_size, ensemble_size=ensemble_size)
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(folder,
                                      transform=transform,
                                      target_transform=Categorical(num_ways),
                                      num_classes_per_task=num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(folder,
                                    transform=transform,
                                    target_transform=Categorical(num_ways),
                                    num_classes_per_task=num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Omniglot(folder,
                                     transform=transform,
                                     target_transform=Categorical(num_ways),
                                     num_classes_per_task=num_ways,
                                     meta_test=True,
                                     dataset_transform=dataset_transform)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size, meta_batch_size=meta_batch_size, ensemble_size=ensemble_size)
        loss_function = batch_cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(folder,
                                          transform=transform,
                                          target_transform=Categorical(num_ways),
                                          num_classes_per_task=num_ways,
                                          meta_train=True,
                                          dataset_transform=dataset_transform,
                                          download=True)
        meta_val_dataset = MiniImagenet(folder,
                                        transform=transform,
                                        target_transform=Categorical(num_ways),
                                        num_classes_per_task=num_ways,
                                        meta_val=True,
                                        dataset_transform=dataset_transform)
        meta_test_dataset = MiniImagenet(folder,
                                         transform=transform,
                                         target_transform=Categorical(num_ways),
                                         num_classes_per_task=num_ways,
                                         meta_test=True,
                                         dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size, meta_batch_size=meta_batch_size, ensemble_size=ensemble_size)
        loss_function = batch_cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
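
A hedged usage sketch for this helper, assuming Benchmark is a namedtuple-style container (as the return statement suggests) and that the surrounding training script wraps the returned datasets in torchmeta's BatchMetaDataLoader; the folder path and hyperparameter values are placeholders:

benchmark = get_benchmark_by_name('omniglot',
                                  folder='data',
                                  num_ways=5,
                                  num_shots=1,
                                  num_shots_test=15,
                                  hidden_size=64,
                                  meta_batch_size=16,
                                  ensemble_size=4)

# One dataloader per meta-split; batch_size here is the number of tasks per batch.
meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                            batch_size=16,
                                            shuffle=True,
                                            num_workers=4,
                                            pin_memory=True)
meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                          batch_size=16,
                                          shuffle=True,
                                          num_workers=4)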
Example no. 28
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None):

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
    if name == 'sinusoid':
        transform = ToTensor1D()

        meta_train_dataset = Sinusoid(num_shots + num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(num_shots + num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Sinusoid(num_shots + num_shots_test,
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(folder,
                                      transform=transform,
                                      target_transform=Categorical(num_ways),
                                      num_classes_per_task=num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(folder,
                                    transform=transform,
                                    target_transform=Categorical(num_ways),
                                    num_classes_per_task=num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Omniglot(folder,
                                     transform=transform,
                                     target_transform=Categorical(num_ways),
                                     num_classes_per_task=num_ways,
                                     meta_test=True,
                                     dataset_transform=dataset_transform)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            folder,
            transform=transform,
            target_transform=Categorical(num_ways),
            num_classes_per_task=num_ways,
            meta_train=True,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(folder,
                                        transform=transform,
                                        target_transform=Categorical(num_ways),
                                        num_classes_per_task=num_ways,
                                        meta_val=True,
                                        dataset_transform=dataset_transform)
        meta_test_dataset = MiniImagenet(
            folder,
            transform=transform,
            target_transform=Categorical(num_ways),
            num_classes_per_task=num_ways,
            meta_test=True,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'doublenmnist':
        from torchneuromorphic.doublenmnist_torchmeta.doublenmnist_dataloaders import DoubleNMNIST, Compose, ClassNMNISTDataset, CropDims, Downsample, ToCountFrame, ToTensor, ToEventSum, Repeat, toOneHot
        from torchneuromorphic.utils import plot_frames_imshow
        from matplotlib import pyplot as plt
        from torchmeta.utils.data import CombinationMetaDataset

        root = 'data/nmnist/n_mnist.hdf5'
        chunk_size = 300
        ds = 2
        dt = 1000
        transform = None
        target_transform = None

        size = [2, 32 // ds, 32 // ds]

        transform = Compose([
            CropDims(low_crop=[0, 0], high_crop=[32, 32], dims=[2, 3]),
            Downsample(factor=[dt, 1, ds, ds]),
            ToEventSum(T=chunk_size, size=size),
            ToTensor()
        ])

        if target_transform is None:
            target_transform = Compose(
                [Repeat(chunk_size), toOneHot(num_ways)])

        loss_function = F.cross_entropy

        meta_train_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_train=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                           num_train_per_class=num_shots,
                                           num_test_per_class=num_shots_test)
        meta_val_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_val=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                         num_train_per_class=num_shots,
                                         num_test_per_class=num_shots_test)
        meta_test_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_test=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                          num_train_per_class=num_shots,
                                          num_test_per_class=num_shots_test)

        model = ModelConvDoubleNMNIST(num_ways, hidden_size=hidden_size)

    elif name == 'doublenmnistsequence':
        from torchneuromorphic.doublenmnist_torchmeta.doublenmnist_dataloaders import DoubleNMNIST, Compose, ClassNMNISTDataset, CropDims, Downsample, ToCountFrame, ToTensor, ToEventSum, Repeat, toOneHot
        from torchneuromorphic.utils import plot_frames_imshow
        from matplotlib import pyplot as plt
        from torchmeta.utils.data import CombinationMetaDataset

        root = 'data/nmnist/n_mnist.hdf5'
        chunk_size = 300
        ds = 2
        dt = 1000
        transform = None
        target_transform = None

        size = [2, 32 // ds, 32 // ds]

        transform = Compose([
            CropDims(low_crop=[0, 0], high_crop=[32, 32], dims=[2, 3]),
            Downsample(factor=[dt, 1, ds, ds]),
            ToCountFrame(T=chunk_size, size=size),
            ToTensor()
        ])

        if target_transform is None:
            target_transform = Compose(
                [Repeat(chunk_size), toOneHot(num_ways)])

        loss_function = F.cross_entropy

        meta_train_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_train=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                           num_train_per_class=num_shots,
                                           num_test_per_class=num_shots_test)
        meta_val_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_val=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                         num_train_per_class=num_shots,
                                         num_test_per_class=num_shots_test)
        meta_test_dataset = ClassSplitter(DoubleNMNIST(
            root=root,
            meta_test=True,
            transform=transform,
            target_transform=target_transform,
            chunk_size=chunk_size,
            num_classes_per_task=num_ways),
                                          num_train_per_class=num_shots,
                                          num_test_per_class=num_shots_test)

        model = ModelDECOLLE(num_ways)

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
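
For context, a minimal sketch of how the tasks produced by these ClassSplitter-wrapped datasets are typically consumed; the 'train'/'test' keys are ClassSplitter's default split names, and the shapes in the comments assume 84x84 MiniImagenet images with the placeholder settings shown:

benchmark = get_benchmark_by_name('miniimagenet', 'data',
                                  num_ways=5, num_shots=1, num_shots_test=15,
                                  hidden_size=64)
dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                 batch_size=4, shuffle=True, num_workers=2)

for batch in dataloader:
    # Support ("train") and query ("test") sets for a batch of 4 tasks.
    support_inputs, support_targets = batch['train']   # (4, 5*1, 3, 84, 84), (4, 5*1)
    query_inputs, query_targets = batch['test']        # (4, 5*15, 3, 84, 84), (4, 5*15)
    break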
Example no. 29
def get_dataset(options):
    # Choose the learning dataset
    if options.dataset == 'miniImageNet':
        from torchmeta.datasets import MiniImagenet
        mean_pix = [
            x / 255 for x in [120.39586422, 115.59361427, 104.54012653]
        ]
        std_pix = [x / 255 for x in [70.68188272, 68.27635443, 72.54505529]]
        if options.network == 'ResNet18':
            train_start = RandomResizedCrop(224)
        else:
            train_start = RandomCrop(84, padding=8)
        dataset_train = MiniImagenet(
            "data",
            num_classes_per_task=options.train_way,
            transform=Compose([
                train_start,
                ColorJitter(brightness=.4, contrast=.4, saturation=.4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.train_way),
            meta_train=True,
            download=True)
        dataset_train = ClassSplitter(dataset_train,
                                      shuffle=True,
                                      num_train_per_class=options.train_shot,
                                      num_test_per_class=options.train_query)
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=options.episodes_per_batch,
            num_workers=options.num_workers)
        if options.network == 'ResNet18':
            dataset_val = MiniImagenet(
                "data",
                num_classes_per_task=options.val_way,
                transform=Compose([
                    Resize(224),
                    ToTensor(),
                    Normalize(mean=mean_pix, std=std_pix),
                ]),
                target_transform=Categorical(num_classes=options.val_way),
                meta_val=True,
                download=False)
        else:
            dataset_val = MiniImagenet(
                "data",
                num_classes_per_task=options.val_way,
                transform=Compose([
                    ToTensor(),
                    Normalize(mean=mean_pix, std=std_pix),
                ]),
                target_transform=Categorical(num_classes=options.val_way),
                meta_val=True,
                download=False)
        dataset_val = ClassSplitter(dataset_val,
                                    shuffle=True,
                                    num_train_per_class=options.val_shot,
                                    num_test_per_class=options.val_query)
        dataloader_val = BatchMetaDataLoader(dataset_val,
                                             batch_size=1,
                                             num_workers=options.num_workers)
    elif options.dataset == 'tieredImageNet':
        from torchmeta.datasets import TieredImagenet
        mean_pix = [
            x / 255 for x in [120.39586422, 115.59361427, 104.54012653]
        ]
        std_pix = [x / 255 for x in [70.68188272, 68.27635443, 72.54505529]]
        dataset_train = TieredImagenet(
            "data",
            num_classes_per_task=options.train_way,
            transform=Compose([
                RandomCrop(84, padding=8),
                ColorJitter(brightness=.4, contrast=.4, saturation=.4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.train_way),
            meta_train=True,
            download=True)
        dataset_train = ClassSplitter(dataset_train,
                                      shuffle=True,
                                      num_train_per_class=options.train_shot,
                                      num_test_per_class=options.train_query)
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=options.episodes_per_batch,
            num_workers=options.num_workers)
        dataset_val = TieredImagenet(
            "data",
            num_classes_per_task=options.val_way,
            transform=Compose([
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.val_way),
            meta_val=True,
            download=False)
        dataset_val = ClassSplitter(dataset_val,
                                    shuffle=True,
                                    num_train_per_class=options.val_shot,
                                    num_test_per_class=options.val_query)
        dataloader_val = BatchMetaDataLoader(dataset_val,
                                             batch_size=1,
                                             num_workers=options.num_workers)
    elif options.dataset == 'CIFAR_FS':
        from torchmeta.datasets import CIFARFS
        mean_pix = [
            x / 255 for x in [129.37731888, 124.10583864, 112.47758569]
        ]
        std_pix = [x / 255 for x in [68.20947949, 65.43124043, 70.45866994]]
        if options.coarse:
            dataset_train = CIFARFS("data",
                                    num_classes_per_task=1,
                                    meta_train=True,
                                    download=True)
            dataset_train = ClassSplitter(dataset_train,
                                          shuffle=False,
                                          num_train_per_class=1,
                                          num_test_per_class=1)
            li = {}
            for i in range(len(dataset_train)):
                li[i] = dataset_train[(i, )]['train'].__getitem__(0)[1][0][0]
            sli = list(li.values())
            dli = {x: ix for ix, x in enumerate(set(sli))}
            if options.super_coarse:
                dli['aquatic_mammals'] = 21
                dli['fish'] = 21
                dli['flowers'] = 22
                dli['fruit_and_vegetables'] = 22
                dli['food_containers'] = 23
                dli['household_electrical_devices'] = 23
                dli['household_furniture'] = 23
                dli['insects'] = 24
                dli['non-insect_invertebrates'] = 24
                dli['large_carnivores'] = 25
                dli['reptiles'] = 25
                dli['large_natural_outdoor_scenes'] = 26
                dli['trees'] = 26
                dli['large_omnivores_and_herbivores'] = 27
                dli['medium_mammals'] = 27
                dli['people'] = 27
                dli['vehicles_1'] = 28
                dli['vehicles_2'] = 28
            nli = rankdata([dli[item] for item in sli], 'dense')
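            # li maps each class index to its CIFAR-100 coarse-label string (read from
            # the target of a one-class probe task); dli/nli collapse those strings into
            # dense group ids, optionally merged into super-groups when
            # options.super_coarse is set. The three patched methods below restrict
            # __iter__, sample_task and __len__ so that every sampled task draws all of
            # its classes from a single coarse group.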

            def new__iter__(self):
                num_coarse = max(nli) + 1
                for ix in range(1, num_coarse):
                    for index in combinations(
                        [n for n in range(len(li)) if nli[n] == ix],
                            self.num_classes_per_task):
                        yield self[index]

            def newsample_task(self):
                num = self.np_random.randint(1, max(nli) + 1)
                sample = [n for n in range(len(li)) if nli[n] == num]
                index = self.np_random.choice(sample,
                                              size=self.num_classes_per_task,
                                              replace=False)
                return self[tuple(index)]

            def new__len__(self):
                total_length = 0
                num_coarse = max(nli) + 1
                for jx in range(1, num_coarse):
                    num_classes, length = len(
                        [n for n in range(len(li)) if nli[n] == jx]), 1
                    for ix in range(1, self.num_classes_per_task + 1):
                        length *= (num_classes - ix + 1) / ix
                    total_length += length
                return int(total_length)

            CIFARFS.__iter__ = new__iter__
            CIFARFS.sample_task = newsample_task
            CIFARFS.__len__ = new__len__
        dataset_train = CIFARFS(
            "data",
            num_classes_per_task=options.train_way,
            transform=Compose([
                RandomCrop(32, padding=4),
                ColorJitter(brightness=.4, contrast=.4, saturation=.4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.train_way),
            meta_train=True,
            download=True)
        dataset_train = ClassSplitter(dataset_train,
                                      shuffle=True,
                                      num_train_per_class=options.train_shot,
                                      num_test_per_class=options.train_query)
        if options.coarse_weights:
            dataloader_train = BatchMetaDataLoaderWithLabels(
                dataset_train,
                batch_size=options.episodes_per_batch,
                num_workers=options.num_workers)
        else:
            dataloader_train = BatchMetaDataLoader(
                dataset_train,
                batch_size=options.episodes_per_batch,
                num_workers=options.num_workers)
        dataset_val = CIFARFS(
            "data",
            num_classes_per_task=options.val_way,
            transform=Compose([
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.val_way),
            meta_val=True,
            download=False)
        dataset_val = ClassSplitter(dataset_val,
                                    shuffle=True,
                                    num_train_per_class=options.val_shot,
                                    num_test_per_class=options.val_query)
        dataloader_val = BatchMetaDataLoader(dataset_val,
                                             batch_size=1,
                                             num_workers=options.num_workers)
    elif options.dataset == 'FC100':
        from torchmeta.datasets import FC100
        mean_pix = [
            x / 255 for x in [129.37731888, 124.10583864, 112.47758569]
        ]
        std_pix = [x / 255 for x in [68.20947949, 65.43124043, 70.45866994]]
        if options.coarse:
            dataset_train = FC100("data",
                                  num_classes_per_task=1,
                                  meta_train=True,
                                  download=True)
            dataset_train = ClassSplitter(dataset_train,
                                          shuffle=False,
                                          num_train_per_class=1,
                                          num_test_per_class=1)
            li = {}
            for i in range(len(dataset_train)):
                li[i] = dataset_train[(i, )]['train'].__getitem__(0)[1][0][0]
            sli = list(li.values())
            dli = {x: ix for ix, x in enumerate(set(sli))}
            if options.super_coarse:
                dli['aquatic_mammals'] = 21
                dli['fish'] = 21
                dli['flowers'] = 22
                dli['fruit_and_vegetables'] = 22
                dli['food_containers'] = 23
                dli['household_electrical_devices'] = 23
                dli['household_furniture'] = 23
                dli['insects'] = 24
                dli['non-insect_invertebrates'] = 24
                dli['large_carnivores'] = 25
                dli['reptiles'] = 25
                dli['large_natural_outdoor_scenes'] = 26
                dli['trees'] = 26
                dli['large_omnivores_and_herbivores'] = 27
                dli['medium_mammals'] = 27
                dli['people'] = 27
                dli['vehicles_1'] = 28
                dli['vehicles_2'] = 28
            nli = rankdata([dli[item] for item in sli], 'dense')
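            # Same coarse-grouping trick as in the CIFAR_FS branch above: nli assigns
            # each FC100 class a coarse-group id, and the patched methods below keep
            # every sampled task inside a single coarse group.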

            def new__iter__(self):
                num_coarse = max(nli) + 1
                for ix in range(1, num_coarse):
                    for index in combinations(
                        [n for n in range(len(li)) if nli[n] == ix],
                            self.num_classes_per_task):
                        yield self[index]

            def newsample_task(self):
                num = self.np_random.randint(1, max(nli) + 1)
                sample = [n for n in range(len(li)) if nli[n] == num]
                index = self.np_random.choice(sample,
                                              size=self.num_classes_per_task,
                                              replace=False)
                return self[tuple(index)]

            def new__len__(self):
                total_length = 0
                num_coarse = max(nli) + 1
                for jx in range(1, num_coarse):
                    num_classes, length = len(
                        [n for n in range(len(li)) if nli[n] == jx]), 1
                    for ix in range(1, self.num_classes_per_task + 1):
                        length *= (num_classes - ix + 1) / ix
                    total_length += length
                return int(total_length)

            FC100.__iter__ = new__iter__
            FC100.sample_task = newsample_task
            FC100.__len__ = new__len__
        dataset_train = FC100(
            "data",
            num_classes_per_task=options.train_way,
            transform=Compose([
                RandomCrop(32, padding=4),
                ColorJitter(brightness=.4, contrast=.4, saturation=.4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.train_way),
            meta_train=True,
            download=True)
        dataset_train = ClassSplitter(dataset_train,
                                      shuffle=True,
                                      num_train_per_class=options.train_shot,
                                      num_test_per_class=options.train_query)
        if options.coarse_weights:
            dataloader_train = BatchMetaDataLoaderWithLabels(
                dataset_train,
                batch_size=options.episodes_per_batch,
                num_workers=options.num_workers)
        else:
            dataloader_train = BatchMetaDataLoader(
                dataset_train,
                batch_size=options.episodes_per_batch,
                num_workers=options.num_workers)
        dataset_val = FC100(
            "data",
            num_classes_per_task=options.val_way,
            transform=Compose([
                ToTensor(),
                Normalize(mean=mean_pix, std=std_pix),
            ]),
            target_transform=Categorical(num_classes=options.val_way),
            meta_val=True,
            download=False)
        dataset_val = ClassSplitter(dataset_val,
                                    shuffle=True,
                                    num_train_per_class=options.val_shot,
                                    num_test_per_class=options.val_query)
        dataloader_val = BatchMetaDataLoader(dataset_val,
                                             batch_size=1,
                                             num_workers=options.num_workers)
    else:
        print("Cannot recognize the dataset type")
        assert (False)

    return (dataloader_train, dataloader_val)
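
A hedged sketch of the attribute bundle this helper expects; SimpleNamespace stands in for the argparse options object of the original script, and every value below is a placeholder (only the attribute names are taken from the code above):

from types import SimpleNamespace

options = SimpleNamespace(
    dataset='CIFAR_FS',            # 'miniImageNet', 'tieredImageNet', 'CIFAR_FS' or 'FC100'
    network='ProtoNet',            # only checked for miniImageNet ('ResNet18' switches to 224px crops)
    train_way=5, val_way=5,
    train_shot=1, train_query=15,
    val_shot=1, val_query=15,
    episodes_per_batch=8,
    num_workers=4,
    coarse=False, super_coarse=False, coarse_weights=False,
)
dataloader_train, dataloader_val = get_dataset(options)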
Example no. 30
def main(args):
    print(args)
    logging.info(args)

    if args.alg == 'MAML':
        model = MAML(args)
    elif args.alg == 'Reptile':
        model = Reptile(args)
    elif args.alg == 'FOMAML':
        model = FOMAML(args)
    elif args.alg == 'Neumann':
        model = Neumann(args)
    elif args.alg == 'CAVIA':
        model = CAVIA(args)
    elif args.alg == 'iMAML':
        model = iMAML(args)
    else:
        raise ValueError('Not implemented Meta-Learning Algorithm')

    if args.load:
        model.load()

    train_dataset = setData(
        root=args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='train',
        transform=transforms.Compose([
            transforms.RandomCrop(80, padding=8),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]),
        target_transform=Categorical(num_classes=args.num_way),
        download=True)
    train_dataset = ClassSplitter(train_dataset,
                                  shuffle=True,
                                  num_train_per_class=args.num_shot,
                                  num_test_per_class=args.num_query)
    train_loader = BatchMetaDataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=args.num_workers)

    valid_dataset = setData(
        root=args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='val',
        transform=transforms.Compose(
            [transforms.CenterCrop(80),
             transforms.ToTensor()]),
        target_transform=Categorical(num_classes=args.num_way))
    valid_dataset = ClassSplitter(valid_dataset,
                                  shuffle=True,
                                  num_train_per_class=args.num_shot,
                                  num_test_per_class=args.num_query)
    valid_loader = BatchMetaDataLoader(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=args.num_workers)

    test_dataset = setData(
        root=args.data_path,
        num_classes_per_task=args.num_way,
        meta_split='test',
        transform=transforms.Compose(
            [transforms.CenterCrop(80),
             transforms.ToTensor()]),
        target_transform=Categorical(num_classes=args.num_way))
    test_dataset = ClassSplitter(test_dataset,
                                 shuffle=True,
                                 num_train_per_class=args.num_shot,
                                 num_test_per_class=args.num_query)
    test_loader = BatchMetaDataLoader(test_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      pin_memory=True,
                                      num_workers=args.num_workers)
    global writer
    global train_iter
    global test_iter
    writer = SummaryWriter(comment=args.exp)
    train_iter = 0
    test_iter = 0
    counter = 0
    patience = 15

    for epoch in range(args.num_epoch):

        res, is_best = run_epoch(epoch, args, model, train_loader,
                                 valid_loader, test_loader)
        dict2tsv(
            res,
            os.path.join(args.result_path, args.alg,
                         str(args.num_shot) + '_' + str(args.num_way),
                         args.log_path))

        if is_best:
            logging.info('- Found new best accuracy')
            counter = 0  # reset
            model.save()
        else:
            counter += 1

        # disable early stopping
        # if counter > patience:
        #     logging.info('- No improvement in a while, stopping training...')
        #     break

        if args.lr_sched:
            model.lr_sched(res['train_loss'])

    return None
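
A hedged sketch of the driver code this main presumably sits under; setData is assumed here to be an alias for one of torchmeta's dataset classes, and the argparse defaults are placeholders matching the attributes read above:

import argparse
from torchmeta.datasets import MiniImagenet as setData  # assumption: setData aliases a torchmeta dataset

parser = argparse.ArgumentParser()
parser.add_argument('--alg', default='MAML',
                    choices=['MAML', 'Reptile', 'FOMAML', 'Neumann', 'CAVIA', 'iMAML'])
parser.add_argument('--load', action='store_true')
parser.add_argument('--lr_sched', action='store_true')
parser.add_argument('--data_path', default='data')
parser.add_argument('--num_way', type=int, default=5)
parser.add_argument('--num_shot', type=int, default=1)
parser.add_argument('--num_query', type=int, default=15)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--num_epoch', type=int, default=100)
parser.add_argument('--result_path', default='results')
parser.add_argument('--log_path', default='log.tsv')
parser.add_argument('--exp', default='maml_miniimagenet_5w1s')
args = parser.parse_args()

main(args)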