def __init__(self, device, problem="default", task_num=16, n_way=5, imgsz=28, k_spt=1, k_qry=19):
    """Build Omniglot meta-train/val/test task samplers.

    Args:
        device: Stored as ``self.device``; not used inside this method.
        problem: Accepted but not used inside this method.
        task_num: Stored as ``self.task_num`` (presumably the number of tasks
            per meta-batch -- confirm with the caller).
        n_way: Number of classes per task.
        imgsz: Size passed to ``Resize`` for every image.
        k_spt: Support (train) examples per class in each task.
        k_qry: Query (test) examples per class in each task.

    Raises:
        ValueError: if ``k_spt + k_qry`` exceeds 20 samples per class.
    """
    # Omniglot has 20 drawings per character, so support + query per class
    # cannot exceed 20. Raise instead of assert so the check survives -O.
    if k_spt + k_qry > 20:
        raise ValueError(
            "Max 20 k_spt + k_qry per class, got k_spt={0} + k_qry={1}".format(
                k_spt, k_qry))

    self.device = device
    self.task_num = task_num
    self.n_way, self.imgsz = n_way, imgsz
    self.k_spt, self.k_qry = k_spt, k_qry

    # Rotated copies (90/180/270 degrees) act as additional virtual classes;
    # the same list is shared by all three splits.
    class_augmentations = [Rotation([90, 180, 270])]

    def make_split(**split_kwargs):
        # Each split gets its own transform and label-encoder instances and is
        # wrapped in a ClassSplitter producing k_spt/k_qry examples per class.
        dataset = Omniglot("data",
                           transform=Compose([Resize(self.imgsz), ToTensor()]),
                           target_transform=Categorical(num_classes=self.n_way),
                           num_classes_per_task=self.n_way,
                           class_augmentations=class_augmentations,
                           **split_kwargs)
        return ClassSplitter(dataset, shuffle=True,
                             num_train_per_class=k_spt,
                             num_test_per_class=k_qry)

    # Only the train split triggers the download.
    self.train_dataset = make_split(meta_train=True, download=True)
    self.val_dataset = make_split(meta_val=True)
    self.test_dataset = make_split(meta_test=True)
def omniglot(folder, shots, ways, shuffle=True, test_shots=None,
             seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Omniglot dataset.

    Parameters
    ----------
    folder : string
        Root directory where the dataset folder `omniglot` exists.

    shots : int
        Number of (training) examples per class in each task. This corresponds
        to `k` in `k-shot` classification.

    ways : int
        Number of classes per task. This corresponds to `N` in `N-way`
        classification. Overridden by `num_classes_per_task` if that keyword
        is also given.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the
        number of test examples is equal to the number of training examples per
        class.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `Omniglot` class. Defaults are
        filled in for `transform` (resize to 28 + `ToTensor`),
        `target_transform` (`Categorical(ways)`) and `class_augmentations`
        (rotations by 90/180/270 degrees) when they are not supplied.

    Returns
    -------
    ClassSplitter
        The `Omniglot` meta-dataset wrapped in a `ClassSplitter` with `shots`
        training and `test_shots` test examples per class, seeded with `seed`.

    See also
    --------
    `datasets.Omniglot` : Meta-dataset for the Omniglot dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        # Pop rather than read: `num_classes_per_task` is passed explicitly to
        # `Omniglot` below, and leaving it in kwargs would raise
        # "TypeError: got multiple values for keyword argument".
        ways = kwargs.pop('num_classes_per_task')
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(28), ToTensor()])
    if 'target_transform' not in kwargs:
        # Built after `ways` is resolved so the label encoder matches it.
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        test_shots = shots

    dataset = Omniglot(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
# Example #3
def create_og_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    input_size,
    n_query,
    batch_size,
    num_workers,
    download=False,
    use_vinyals_split=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for Omniglot

    Args:
        root: Path to Omniglot data root folder (containing an 'omniglot' subfolder with the
            preprocessed json-files or downloaded zip-files).
        meta_split: see torchmeta.datasets.Omniglot
        k_way: Number of classes per task
        n_shot: Number of samples per class
        input_size: Images are resized to this size.
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to be done on the
            downloaded files).
        use_vinyals_split: see torchmeta.datasets.Omniglot
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = Omniglot(
        root,
        num_classes_per_task=k_way,
        transform=Compose([Resize(input_size), ToTensor()]),
        target_transform=Categorical(num_classes=k_way),
        # Rotated copies of each character become extra virtual classes.
        class_augmentations=[Rotation([90, 180, 270])],
        meta_split=meta_split,
        download=download,
        use_vinyals_split=use_vinyals_split,
    )
    dataset = ClassSplitter(dataset,
                            shuffle=True,
                            num_train_per_class=n_shot,
                            num_test_per_class=n_query)
    # `seed` is a method: call it. Assigning to `dataset.seed` would replace
    # the method with the value and the seed would never be applied.
    dataset.seed(seed)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader
# Example #4
def get_dataset(dataset_name, dataset_path=None, image_size=64):
    """Return train/test datasets plus channel and class counts for a benchmark.

    Every branch resizes to ``image_size`` and maps pixel values from the
    ``ToTensor`` range [0, 1] to [-1, 1] (via ``Normalize(0.5, 0.5)`` or
    ``2 * x - 1``).

    Args:
        dataset_name: One of ``'MNIST'``, ``'Omniglot'``, ``'cifar10'``,
            ``'celeba'`` or ``'DoubleMNIST'`` (case-sensitive).
        dataset_path: Root directory passed to the dataset constructors.
        image_size: Size passed to ``transforms.Resize``.

    Returns:
        Tuple ``(dataset_train, dataset_test, num_channels,
        num_train_classes, num_test_classes)``. For ``'celeba'`` both class
        counts are reported as 0.

    Raises:
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    if dataset_name == 'MNIST':
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        dataset_train = datasets.MNIST(root=dataset_path,
                                       download=True,
                                       transform=transform,
                                       train=True)
        dataset_test = datasets.MNIST(root=dataset_path,
                                      download=True,
                                      transform=transform,
                                      train=False)
        num_channels = 1
        num_train_classes = len(dataset_train.classes)
        num_test_classes = len(dataset_test.classes)

    elif dataset_name == 'Omniglot':
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: (x * 2) - 1)
        ])
        # One class per "task": the meta-dataset is used as a plain
        # class-indexed image source here.
        dataset_train = Omniglot(root=dataset_path,
                                 num_classes_per_task=1,
                                 transform=transform,
                                 target_transform=None,
                                 meta_train=True,
                                 download=True,
                                 use_vinyals_split=False)
        dataset_test = Omniglot(root=dataset_path,
                                num_classes_per_task=1,
                                transform=transform,
                                target_transform=None,
                                meta_test=True,
                                download=True,
                                use_vinyals_split=False)
        num_channels = 1

        num_train_classes = dataset_train.dataset.num_classes
        num_test_classes = dataset_test.dataset.num_classes

    elif dataset_name == 'cifar10':
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset_train = datasets.CIFAR10(root=dataset_path,
                                         download=True,
                                         transform=transform,
                                         train=True)
        dataset_test = datasets.CIFAR10(root=dataset_path,
                                        download=True,
                                        transform=transform,
                                        train=False)
        num_channels = 3
        num_train_classes = len(dataset_train.classes)
        num_test_classes = len(dataset_test.classes)

    elif dataset_name == 'celeba':
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset_train = datasets.CelebA(root=dataset_path,
                                        download=True,
                                        transform=transform,
                                        split='train')
        dataset_test = datasets.CelebA(root=dataset_path,
                                       download=True,
                                       transform=transform,
                                       split='test')
        num_channels = 3

        # TODO: revisit this if it's true?
        num_train_classes = 0
        num_test_classes = 0

    elif dataset_name == 'DoubleMNIST':
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: (x * 2) - 1)
        ])
        dataset_train = DoubleMNIST(root=dataset_path,
                                    num_classes_per_task=1,
                                    transform=transform,
                                    target_transform=None,
                                    meta_train=True,
                                    download=True)
        dataset_test = DoubleMNIST(root=dataset_path,
                                   num_classes_per_task=1,
                                   transform=transform,
                                   target_transform=None,
                                   meta_test=True,
                                   download=True)
        num_channels = 1

        num_train_classes = dataset_train.dataset.num_classes
        num_test_classes = dataset_test.dataset.num_classes

    else:
        # Without this, an unknown name fell through to the return statement
        # and surfaced as an UnboundLocalError.
        raise NotImplementedError('Unknown dataset `{0}`.'.format(dataset_name))

    return dataset_train, dataset_test, num_channels, num_train_classes, \
           num_test_classes
# Example #5
def main(args):
    """Meta-train a model with ``ModelAgnosticMetaLearning`` (MAML).

    Builds meta-train and meta-val datasets for ``args.dataset``
    (``'sinusoid'``, ``'omniglot'`` or ``'miniimagenet'``) and trains for
    ``args.num_epochs`` epochs. After every epoch it evaluates on the meta-val
    split. When ``args.output_folder`` is set, a timestamped run folder is
    created containing ``config.json``. The ``state_dict`` of the model with
    the best ``results['accuracies_after']`` seen so far is saved as
    ``model.th`` in that folder.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
    """
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if args.output_folder is not None:
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        # One sub-folder per run, named after the start time.
        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        config_path = os.path.abspath(os.path.join(folder, 'config.json'))
        with open(config_path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(config_path))

    # Shared episode splitter: num_shots support and num_shots_test query
    # examples per class.
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)
    class_augmentations = [Rotation([90, 180, 270])]
    if args.dataset == 'sinusoid':
        transform = ToTensor()

        meta_train_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif args.dataset == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(args.folder,
                                      transform=transform,
                                      target_transform=Categorical(
                                          args.num_ways),
                                      num_classes_per_task=args.num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(args.folder,
                                    transform=transform,
                                    target_transform=Categorical(
                                        args.num_ways),
                                    num_classes_per_task=args.num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)

        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    elif args.dataset == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_train=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_val=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(args.num_ways,
                                      hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    try:
        meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.num_workers,
                                                    pin_memory=True)
        meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)

        meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
        metalearner = ModelAgnosticMetaLearning(
            model,
            meta_optimizer,
            first_order=args.first_order,
            num_adaptation_steps=args.num_steps,
            step_size=args.step_size,
            loss_function=loss_function,
            device=device)

        best_val_accuracy = None

        # Training loop
        # Pad epoch numbers to the digit count of num_epochs; max(..., 1)
        # keeps log10 defined when num_epochs is 0.
        epoch_width = 1 + int(math.log10(max(args.num_epochs, 1)))
        epoch_desc = 'Epoch {{0: <{0}d}}'.format(epoch_width)
        for epoch in range(args.num_epochs):
            metalearner.train(meta_train_dataloader,
                              max_batches=args.num_batches,
                              verbose=args.verbose,
                              desc='Training',
                              leave=False)
            results = metalearner.evaluate(meta_val_dataloader,
                                           max_batches=args.num_batches,
                                           verbose=args.verbose,
                                           desc=epoch_desc.format(epoch + 1))

            # NOTE(review): for the regression 'sinusoid' benchmark, confirm
            # that evaluate() reports 'accuracies_after'; otherwise this
            # lookup raises KeyError.
            if (best_val_accuracy is None) \
                    or (best_val_accuracy < results['accuracies_after']):
                best_val_accuracy = results['accuracies_after']
                if args.output_folder is not None:
                    with open(args.model_path, 'wb') as f:
                        torch.save(model.state_dict(), f)
    finally:
        # Release dataset resources even when training fails or is interrupted.
        for dataset in (meta_train_dataset, meta_val_dataset):
            if hasattr(dataset, 'close'):
                dataset.close()
# Example #6
    def generate_batch(self, test):
        '''
        Build a torchmeta BatchMetaDataLoader of few-shot tasks.

        The data-loaders of torchmeta are fully compatible with standard data
        components of PyTorch, such as Dataset and DataLoader. The pool of
        class candidates is augmented with rotated variants (90/180/270
        degrees) of each class.

        Args:
            test: If truthy, tasks come from the meta-test split (dataset
                folder "metatest"); otherwise from the meta-train split
                (dataset folder "metatrain").

        Returns:
            A BatchMetaDataLoader with batch size ``self.batch_size`` and two
            worker processes. Each task has ``self.N`` classes with ``self.K``
            train and ``self.num_test_per_class`` test examples per class.

        Raises:
            NotImplementedError: if ``self.dataset`` is not one of
                "miniImageNet", "tieredImageNet", "CIFARFS", "FC100",
                "Omniglot".
        '''
        # The split name doubles as the dataset root folder.
        if test:
            meta_train, meta_test, folder = False, True, "metatest"
        else:
            meta_train, meta_test, folder = True, False, "metatrain"

        # Benchmark name -> (dataset class, side length images are resized to).
        if self.dataset == "miniImageNet":
            dataset_cls, image_size = MiniImagenet, 84
        elif self.dataset == "tieredImageNet":
            dataset_cls, image_size = TieredImagenet, 84
        elif self.dataset == "CIFARFS":
            dataset_cls, image_size = CIFARFS, 32
        elif self.dataset == "FC100":
            dataset_cls, image_size = FC100, 32
        elif self.dataset == "Omniglot":
            dataset_cls, image_size = Omniglot, 28
        else:
            # Previously an unknown name left `dataset` unbound and failed
            # later with UnboundLocalError.
            raise NotImplementedError(
                'Unknown dataset `{0}`.'.format(self.dataset))

        dataset = dataset_cls(
            folder,
            # Number of ways
            num_classes_per_task=self.N,
            # Resize the images and convert them to PyTorch tensors
            # (from Torchvision)
            transform=Compose([Resize(image_size), ToTensor()]),
            # Transform the labels to integers
            target_transform=Categorical(num_classes=self.N),
            # Create new virtual classes with rotated versions of the images
            # (from Santoro et al., 2016)
            class_augmentations=[Rotation([90, 180, 270])],
            meta_train=meta_train,
            meta_test=meta_test,
            download=True)

        # Split each task into K train ("support") and num_test_per_class
        # test ("query") examples per class.
        dataset = ClassSplitter(dataset,
                                shuffle=True,
                                num_train_per_class=self.K,
                                num_test_per_class=self.num_test_per_class)

        dataloader = BatchMetaDataLoader(dataset,
                                         batch_size=self.batch_size,
                                         num_workers=2)
        return dataloader
# Example #7
def main(args):
    """Train, validate and test a meta-learner on Omniglot.

    ``args.alg`` selects the meta-learning algorithm. After every epoch the
    results are appended to a TSV file under
    ``args.result_path/args.alg``, and the model is saved whenever
    ``run_epoch`` reports a new best.

    Raises:
        ValueError: if ``args.alg`` names an unsupported algorithm.
    """
    # Builders are deferred (lambdas) so only the chosen class is looked up.
    builders = {
        'MAML': lambda: MAML(args),
        'Reptile': lambda: Reptile(args),
        'Neumann': lambda: Neumann(args),
        'CAVIA': lambda: CAVIA(args),
        'iMAML': lambda: iMAML(args),
        'FOMAML': lambda: FOMAML(args),
    }
    build = builders.get(args.alg)
    if build is None:
        raise ValueError('Not implemented Meta-Learning Algorithm')
    model = build()

    if args.load:
        model.load()
    elif args.load_encoder:
        model.load_encoder()

    def make_loader(split, crop):
        # Omniglot split -> episodic ClassSplitter -> shuffled batch loader.
        raw = Omniglot(args.data_path, num_classes_per_task=args.num_way,
                       meta_split=split,
                       transform=transforms.Compose([crop, transforms.ToTensor()]),
                       target_transform=Categorical(num_classes=args.num_way))
        episodic = ClassSplitter(raw, shuffle=True,
                                 num_train_per_class=args.num_shot,
                                 num_test_per_class=args.num_query)
        return BatchMetaDataLoader(episodic, batch_size=args.batch_size,
                                   shuffle=True, pin_memory=True,
                                   num_workers=args.num_workers)

    # Random padded crops for training; deterministic center crops elsewhere.
    train_loader = make_loader('train', transforms.RandomCrop(80, padding=8))
    valid_loader = make_loader('val', transforms.CenterCrop(80))
    test_loader = make_loader('test', transforms.CenterCrop(80))

    for epoch in range(args.num_epoch):
        res, is_best = run_epoch(epoch, args, model,
                                 train_loader, valid_loader, test_loader)

        run_tag = 'omniglot_{0}shot_{1}way'.format(args.num_shot, args.num_way)
        dict2tsv(res, os.path.join(args.result_path, args.alg,
                                   run_tag + args.log_path))

        if is_best:
            model.save('omniglot_{0}shot_{1}way'.format(args.num_shot,
                                                        args.num_way))
        torch.cuda.empty_cache()

        if args.lr_sched:
            model.lr_sched()
# Example #8
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None,
                          meta_batch_size=1,
                          ensemble_size=0
                          ):
    """Look up a benchmark by name and assemble its datasets, model and loss.

    Args:
        name: ``'sinusoid'``, ``'omniglot'`` or ``'miniimagenet'``.
        folder: Data root for the image benchmarks (not used for sinusoid).
        num_ways: Classes per task for the image benchmarks.
        num_shots: Passed to ``ClassSplitter`` as ``num_train_per_class``.
        num_shots_test: Passed to ``ClassSplitter`` as ``num_test_per_class``.
        hidden_size: Forwarded to the convolutional models.
        meta_batch_size: Forwarded to every model constructor.
        ensemble_size: Forwarded to every model constructor.

    Returns:
        A ``Benchmark`` holding meta-train/val/test datasets, model and loss.

    Raises:
        NotImplementedError: for an unknown ``name``.
    """
    # One episode splitter instance shared by all splits.
    splitter = ClassSplitter(shuffle=True,
                             num_train_per_class=num_shots,
                             num_test_per_class=num_shots_test)
    model_kwargs = {'meta_batch_size': meta_batch_size,
                    'ensemble_size': ensemble_size}

    def build_split(dataset_cls, **split_kwargs):
        # Image benchmarks: `transform` is bound in the branch before this is
        # called; each split gets its own Categorical label encoder.
        return dataset_cls(folder,
                           transform=transform,
                           target_transform=Categorical(num_ways),
                           num_classes_per_task=num_ways,
                           dataset_transform=splitter,
                           **split_kwargs)

    if name == 'sinusoid':
        transform = ToTensor1D()
        samples_per_task = num_shots + num_shots_test
        # Three independent task generators: train, val, test (in that order).
        meta_train_dataset, meta_val_dataset, meta_test_dataset = [
            Sinusoid(samples_per_task,
                     num_tasks=1000000,
                     transform=transform,
                     target_transform=transform,
                     dataset_transform=splitter)
            for _ in range(3)
        ]
        model = ModelMLPSinusoid(hidden_sizes=[40, 40], **model_kwargs)
        loss_function = F.mse_loss

    elif name == 'omniglot':
        rotations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])
        meta_train_dataset = build_split(Omniglot, meta_train=True,
                                         class_augmentations=rotations,
                                         download=True)
        meta_val_dataset = build_split(Omniglot, meta_val=True,
                                       class_augmentations=rotations)
        # Unlike train/val, the test split gets no class_augmentations.
        meta_test_dataset = build_split(Omniglot, meta_test=True)
        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size,
                                  **model_kwargs)
        loss_function = batch_cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])
        meta_train_dataset = build_split(MiniImagenet, meta_train=True,
                                         download=True)
        meta_val_dataset = build_split(MiniImagenet, meta_val=True)
        meta_test_dataset = build_split(MiniImagenet, meta_test=True)
        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size,
                                      **model_kwargs)
        loss_function = batch_cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
# Example #9
def _doublenmnist_meta_datasets(num_ways, num_shots, num_shots_test,
                                use_count_frames=False):
    """Build the (train, val, test) DoubleNMNIST meta-datasets.

    The torchneuromorphic import deliberately lives in this helper rather than
    in ``get_benchmark_by_name``: a function-level ``from ... import Compose,
    ToTensor`` makes those names local to the *entire* enclosing function, which
    would make every other branch that uses ``Compose``/``ToTensor`` fail with
    ``UnboundLocalError``.

    Parameters
    ----------
    num_ways : int
        Number of classes per task (also the width of the one-hot targets).
    num_shots : int
        Number of examples per class in each task's train split.
    num_shots_test : int
        Number of examples per class in each task's test split.
    use_count_frames : bool
        If True, events are turned into frames with ``ToCountFrame``;
        otherwise with ``ToEventSum``.

    Returns
    -------
    tuple
        ``(meta_train_dataset, meta_val_dataset, meta_test_dataset)``, each a
        ``DoubleNMNIST`` wrapped in a ``ClassSplitter``.
    """
    from torchneuromorphic.doublenmnist_torchmeta.doublenmnist_dataloaders import (
        DoubleNMNIST, Compose, CropDims, Downsample, ToCountFrame, ToTensor,
        ToEventSum, Repeat, toOneHot)

    root = 'data/nmnist/n_mnist.hdf5'
    chunk_size = 300
    # Downsampling factors, passed to Downsample as [dt, 1, ds, ds].
    ds = 2
    dt = 1000
    size = [2, 32 // ds, 32 // ds]

    to_frames = ToCountFrame if use_count_frames else ToEventSum
    transform = Compose([
        CropDims(low_crop=[0, 0], high_crop=[32, 32], dims=[2, 3]),
        Downsample(factor=[dt, 1, ds, ds]),
        to_frames(T=chunk_size, size=size),
        ToTensor()
    ])
    target_transform = Compose([Repeat(chunk_size), toOneHot(num_ways)])

    def make_split(**split_flag):
        # One split (meta_train / meta_val / meta_test) with identical settings.
        return ClassSplitter(DoubleNMNIST(root=root,
                                          transform=transform,
                                          target_transform=target_transform,
                                          chunk_size=chunk_size,
                                          num_classes_per_task=num_ways,
                                          **split_flag),
                             num_train_per_class=num_shots,
                             num_test_per_class=num_shots_test)

    return (make_split(meta_train=True),
            make_split(meta_val=True),
            make_split(meta_test=True))


def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None):
    """Return a ``Benchmark`` (train/val/test meta-datasets, model, loss).

    Parameters
    ----------
    name : str
        One of 'sinusoid', 'omniglot', 'miniimagenet', 'doublenmnist' or
        'doublenmnistsequence'.
    folder : str
        Dataset root for Omniglot / MiniImagenet (the meta-train split is
        created with ``download=True``). Not used by the other benchmarks.
    num_ways : int
        Number of classes per task.
    num_shots : int
        Number of examples per class in each task's train split.
    num_shots_test : int
        Number of examples per class in each task's test split.
    hidden_size : int, optional
        Forwarded to the convolutional models (Omniglot, MiniImagenet,
        DoubleNMNIST).

    Raises
    ------
    NotImplementedError
        If ``name`` is not a known benchmark.
    """
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
    if name == 'sinusoid':
        transform = ToTensor1D()

        meta_train_dataset = Sinusoid(num_shots + num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(num_shots + num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Sinusoid(num_shots + num_shots_test,
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(folder,
                                      transform=transform,
                                      target_transform=Categorical(num_ways),
                                      num_classes_per_task=num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(folder,
                                    transform=transform,
                                    target_transform=Categorical(num_ways),
                                    num_classes_per_task=num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)
        meta_test_dataset = Omniglot(folder,
                                     transform=transform,
                                     target_transform=Categorical(num_ways),
                                     num_classes_per_task=num_ways,
                                     meta_test=True,
                                     dataset_transform=dataset_transform)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            folder,
            transform=transform,
            target_transform=Categorical(num_ways),
            num_classes_per_task=num_ways,
            meta_train=True,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(folder,
                                        transform=transform,
                                        target_transform=Categorical(num_ways),
                                        num_classes_per_task=num_ways,
                                        meta_val=True,
                                        dataset_transform=dataset_transform)
        meta_test_dataset = MiniImagenet(
            folder,
            transform=transform,
            target_transform=Categorical(num_ways),
            num_classes_per_task=num_ways,
            meta_test=True,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'doublenmnist':
        meta_train_dataset, meta_val_dataset, meta_test_dataset = \
            _doublenmnist_meta_datasets(num_ways, num_shots, num_shots_test,
                                        use_count_frames=False)
        model = ModelConvDoubleNMNIST(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'doublenmnistsequence':
        meta_train_dataset, meta_val_dataset, meta_test_dataset = \
            _doublenmnist_meta_datasets(num_ways, num_shots, num_shots_test,
                                        use_count_frames=True)
        model = ModelDECOLLE(num_ways)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
Beispiel #10
0
def setData(**attributes):
    """Instantiate the meta-dataset class chosen by the module-level ``args``.

    ``Omniglot`` is used when ``args.dataset == 'omniglot'``; any other value
    selects ``MiniImagenet``. Every keyword argument is forwarded unchanged to
    the selected constructor. ``args`` is presumably the script's parsed
    command line — it must exist as a global when this is called.
    """
    dataset_cls = Omniglot if args.dataset == 'omniglot' else MiniImagenet
    return dataset_cls(**attributes)
Beispiel #11
0
def main(args):
    """Evaluate a trained MAML checkpoint on the meta-test split.

    Loads the JSON training config from ``args.config`` (``folder``,
    ``num_steps`` and ``num_batches`` can be overridden from ``args``), rebuilds
    the meta-test dataset and model for ``config['dataset']``, restores the
    weights from ``config['model_path']``, runs
    ``ModelAgnosticMetaLearning.evaluate`` and writes the returned results to
    ``results.json`` in the directory containing ``config['model_path']``.

    Parameters
    ----------
    args : argparse.Namespace
        Presumably the parsed command line; must provide ``config``,
        ``folder``, ``num_steps``, ``num_batches``, ``use_cuda``,
        ``num_workers`` and ``verbose``.

    Raises
    ------
    NotImplementedError
        If ``config['dataset']`` is not 'sinusoid', 'omniglot' or
        'miniimagenet'.
    """
    with open(args.config, 'r') as f:
        config = json.load(f)

    # Command-line values override the saved config only when provided
    # (non-None folder, positive step/batch counts).
    if args.folder is not None:
        config['folder'] = args.folder
    if args.num_steps > 0:
        config['num_steps'] = args.num_steps
    if args.num_batches > 0:
        config['num_batches'] = args.num_batches
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    dataset_transform = ClassSplitter(
        shuffle=True,
        num_train_per_class=config['num_shots'],
        num_test_per_class=config['num_shots_test'])
    if config['dataset'] == 'sinusoid':
        # Same 1-D transform the other Sinusoid benchmarks in this file use;
        # the image ``ToTensor`` is not intended for sinusoid samples.
        transform = ToTensor1D()
        meta_test_dataset = Sinusoid(config['num_shots'] +
                                     config['num_shots_test'],
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)
        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif config['dataset'] == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])
        # Evaluate on the held-out meta-test classes, not the training ones.
        meta_test_dataset = Omniglot(config['folder'],
                                     transform=transform,
                                     target_transform=Categorical(
                                         config['num_ways']),
                                     num_classes_per_task=config['num_ways'],
                                     meta_test=True,
                                     dataset_transform=dataset_transform,
                                     download=True)
        model = ModelConvOmniglot(config['num_ways'],
                                  hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    elif config['dataset'] == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])
        # Evaluate on the held-out meta-test classes, not the training ones.
        meta_test_dataset = MiniImagenet(
            config['folder'],
            transform=transform,
            target_transform=Categorical(config['num_ways']),
            num_classes_per_task=config['num_ways'],
            meta_test=True,
            dataset_transform=dataset_transform,
            download=True)
        model = ModelConvMiniImagenet(config['num_ways'],
                                      hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            config['dataset']))

    with open(config['model_path'], 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))

    meta_test_dataloader = BatchMetaDataLoader(meta_test_dataset,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    metalearner = ModelAgnosticMetaLearning(
        model,
        first_order=config['first_order'],
        num_adaptation_steps=config['num_steps'],
        step_size=config['step_size'],
        loss_function=loss_function,
        device=device)

    results = metalearner.evaluate(meta_test_dataloader,
                                   max_batches=config['num_batches'],
                                   verbose=args.verbose,
                                   desc='Test')

    # Save results next to the evaluated checkpoint.
    dirname = os.path.dirname(config['model_path'])
    with open(os.path.join(dirname, 'results.json'), 'w') as f:
        json.dump(results, f)
def main(args):
    """Train a Matching Network on Omniglot, evaluating on meta-val each epoch.

    When ``args.output_folder`` is set, a timestamped run folder is created
    inside it; ``args.folder`` is made absolute, ``args.model_path`` is pointed
    at ``<run folder>/model.th`` and the full ``args`` namespace is saved as
    ``<run folder>/config.json``.

    Parameters
    ----------
    args : argparse.Namespace
        Presumably the parsed command line (dataset, folder, output_folder,
        num_ways/num_shots/num_shots_test, batch sizes, learning rates, ...).

    Raises
    ------
    NotImplementedError
        If ``args.dataset`` is not 'omniglot' (the only dataset wired up here).
    """
    import time

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if args.output_folder is not None:
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        # Each run gets its own timestamped sub-folder for config and model.
        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))

        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    if args.dataset == 'omniglot':
        dataset_transform = ClassSplitter(
            shuffle=True,
            num_train_per_class=args.num_shots,
            num_test_per_class=args.num_shots_test)
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(args.folder,
                                      transform=transform,
                                      target_transform=Categorical(
                                          args.num_ways),
                                      num_classes_per_task=args.num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)

        meta_val_dataset = Omniglot(args.folder,
                                    transform=transform,
                                    target_transform=Categorical(
                                        args.num_ways),
                                    num_classes_per_task=args.num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)

        meta_test_dataset = Omniglot(args.folder,
                                     transform=transform,
                                     target_transform=Categorical(
                                         args.num_ways),
                                     num_classes_per_task=args.num_ways,
                                     meta_test=True,
                                     dataset_transform=dataset_transform)

        meta_train_dataloader = BatchMetaDataLoader(
            meta_train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

        meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)

        model = MatchingNetwork(args.num_shots,
                                args.num_ways,
                                args.num_shots_test,
                                fce=True,
                                num_input_channels=1,
                                lstm_layers=1,
                                lstm_input_size=64,
                                unrolling_steps=2,
                                device=device)
        loss_fn = F.nll_loss
    else:
        # Without this, model / loss_fn / dataloaders would be undefined below.
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)

    matching_net_trainer = MatchingNetTrainer(
        args,
        model,
        meta_optimizer,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_fn=loss_fn)

    # Zero-padded epoch label width; max(..., 1) avoids log10(0) when
    # num_epochs is 0.
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(
        1 + int(math.log10(max(args.num_epochs, 1))))
    for epoch in range(args.num_epochs):
        matching_net_trainer.train(meta_train_dataloader,
                                   max_batches=args.num_batches,
                                   verbose=args.verbose,
                                   desc='Training',
                                   leave=False)

        # NOTE(review): validation results are not used to checkpoint the
        # model; nothing in this function writes args.model_path.
        matching_net_trainer.evaluate(meta_val_dataloader,
                                      max_batches=args.num_batches,
                                      verbose=args.verbose,
                                      desc=epoch_desc.format(epoch + 1))
Beispiel #13
0
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None):
    """
    Build the meta-datasets, model and loss function for a named benchmark.

    Parameters
    ----------
    name : str
        Benchmark identifier: 'nmltoy2d', 'noisyduplicates', 'sinusoid',
        'omniglot' or 'miniimagenet'.

    folder : str
        Dataset root for Omniglot / MiniImagenet; their meta-train split is
        created with ``download=True``. Ignored by the toy benchmarks.

    num_ways : int
        Number of classes in each task.

    num_shots : int
        Number of examples per class in each task's train split.

    num_shots_test : int
        Number of examples per class in each task's test split.

    hidden_size : int, optional
        Forwarded to the Omniglot / MiniImagenet convolutional models.

    Returns
    -------
    Benchmark
        With fields ``meta_train_dataset``, ``meta_val_dataset``,
        ``meta_test_dataset``, ``model`` and ``loss_function``.

    Raises
    ------
    NotImplementedError
        If ``name`` is not one of the benchmarks listed above.
    """
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)

    def image_split_kwargs(image_transform, **split_flags):
        # Keyword arguments shared by every split of an image benchmark.
        # A new Categorical is created per call, so each split gets its own.
        kwargs = {
            'transform': image_transform,
            'target_transform': Categorical(num_ways),
            'num_classes_per_task': num_ways,
            'dataset_transform': dataset_transform,
        }
        kwargs.update(split_flags)
        return kwargs

    if name == 'nmltoy2d':
        toy_kwargs = {
            'replay_pool_size': 100,
            'clip_length': 100,
            'from_beginning': False,
        }
        # Validation and testing evaluate the outer loss on the entire
        # dataset; training samples smaller batches (10) for efficiency.
        meta_train_dataset = NMLToy2D(test_strategy='sample',
                                      test_batch_size=10,
                                      **toy_kwargs)
        meta_val_dataset = NMLToy2D(test_strategy='all', **toy_kwargs)
        meta_test_dataset = NMLToy2D(test_strategy='all', **toy_kwargs)

        model = ModelMLPToy2D([1024, 1024])
        loss_function = F.cross_entropy

    elif name == 'noisyduplicates':
        locations = [
            ([-2.5, 2.5], 1, 0),  # Single visit (negative)
            ([2.5, 2.5], 10, 0),  # Many visits
            ([-2.5, -2.5], 2, 15),  # A few negatives, mostly positives
            ([2.5, -2.5], 8, 15),  # More negatives, but still majority positives
        ]
        # Three independent but identically configured problems (train/val/test).
        meta_train_dataset, meta_val_dataset, meta_test_dataset = (
            NoisyDuplicatesProblem(locations, noise_std=0) for _ in range(3))

        model = ModelMLPToy2D([2048, 2048])
        loss_function = F.cross_entropy

    elif name == 'sinusoid':
        to_tensor = ToTensor1D()
        meta_train_dataset, meta_val_dataset, meta_test_dataset = (
            Sinusoid(num_shots + num_shots_test,
                     num_tasks=1000000,
                     transform=to_tensor,
                     target_transform=to_tensor,
                     dataset_transform=dataset_transform)
            for _ in range(3))

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif name == 'omniglot':
        rotations = [Rotation([90, 180, 270])]
        omniglot_transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(
            folder,
            **image_split_kwargs(omniglot_transform,
                                 meta_train=True,
                                 class_augmentations=rotations,
                                 download=True))
        meta_val_dataset = Omniglot(
            folder,
            **image_split_kwargs(omniglot_transform,
                                 meta_val=True,
                                 class_augmentations=rotations))
        meta_test_dataset = Omniglot(
            folder,
            **image_split_kwargs(omniglot_transform, meta_test=True))

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'miniimagenet':
        imagenet_transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            folder,
            **image_split_kwargs(imagenet_transform,
                                 meta_train=True,
                                 download=True))
        meta_val_dataset = MiniImagenet(
            folder,
            **image_split_kwargs(imagenet_transform, meta_val=True))
        meta_test_dataset = MiniImagenet(
            folder,
            **image_split_kwargs(imagenet_transform, meta_test=True))

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
Beispiel #14
0
def dataset(args, datanames):
    #MiniImagenet
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shot,
                                      num_test_per_class=args.num_query)
    transform = Compose([Resize(84), ToTensor()])
    MiniImagenet_train_dataset = MiniImagenet(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_train=True,
        dataset_transform=dataset_transform,
        download=True)

    Imagenet_train_loader = BatchMetaDataLoader(
        MiniImagenet_train_dataset,
        batch_size=args.MiniImagenet_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    MiniImagenet_val_dataset = MiniImagenet(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_val=True,
        dataset_transform=dataset_transform)

    Imagenet_valid_loader = BatchMetaDataLoader(
        MiniImagenet_val_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    MiniImagenet_test_dataset = MiniImagenet(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_test=True,
        dataset_transform=dataset_transform)

    Imagenet_test_loader = BatchMetaDataLoader(
        MiniImagenet_test_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    #CIFARFS
    transform = Compose([Resize(84), ToTensor()])
    CIFARFS_train_dataset = CIFARFS(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_train=True,
        dataset_transform=dataset_transform,
        download=True)

    CIFARFS_train_loader = BatchMetaDataLoader(
        CIFARFS_train_dataset,
        batch_size=args.CIFARFS_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    CIFARFS_val_dataset = CIFARFS(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_val=True,
        dataset_transform=dataset_transform)

    CIFARFS_valid_loader = BatchMetaDataLoader(
        CIFARFS_val_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    CIFARFS_test_dataset = CIFARFS(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_test=True,
        dataset_transform=dataset_transform)
    CIFARFS_test_loader = BatchMetaDataLoader(CIFARFS_test_dataset,
                                              batch_size=args.valid_batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.num_workers)

    #Omniglot
    class_augmentations = [Rotation([90, 180, 270])]
    transform = Compose([Resize(84), ToTensor()])
    Omniglot_train_dataset = Omniglot(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_train=True,
        class_augmentations=class_augmentations,
        dataset_transform=dataset_transform,
        download=True)

    Omniglot_train_loader = BatchMetaDataLoader(
        Omniglot_train_dataset,
        batch_size=args.Omniglot_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    Omniglot_val_dataset = Omniglot(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_val=True,
        class_augmentations=class_augmentations,
        dataset_transform=dataset_transform)

    Omniglot_valid_loader = BatchMetaDataLoader(
        Omniglot_val_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    Omniglot_test_dataset = Omniglot(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_test=True,
        dataset_transform=dataset_transform)
    Omniglot_test_loader = BatchMetaDataLoader(
        Omniglot_test_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    #CUB dataset
    transform = None
    CUB_train_dataset = CUBdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_train=True,
        dataset_transform=dataset_transform,
        download=False)

    CUB_train_loader = BatchMetaDataLoader(CUB_train_dataset,
                                           batch_size=args.CUB_batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=args.num_workers)

    CUB_val_dataset = CUBdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_val=True,
        dataset_transform=dataset_transform)

    CUB_valid_loader = BatchMetaDataLoader(CUB_val_dataset,
                                           batch_size=args.valid_batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=args.num_workers)

    CUB_test_dataset = CUBdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_test=True,
        dataset_transform=dataset_transform)
    CUB_test_loader = BatchMetaDataLoader(CUB_test_dataset,
                                          batch_size=args.valid_batch_size,
                                          shuffle=True,
                                          pin_memory=True,
                                          num_workers=args.num_workers)

    #Aircraftdata
    transform = None
    Aircraft_train_dataset = Aircraftdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_train=True,
        dataset_transform=dataset_transform,
        download=False)

    Aircraft_train_loader = BatchMetaDataLoader(
        Aircraft_train_dataset,
        batch_size=args.Aircraft_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    Aircraft_val_dataset = Aircraftdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_val=True,
        dataset_transform=dataset_transform)

    Aircraft_valid_loader = BatchMetaDataLoader(
        Aircraft_val_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    Aircraft_test_dataset = Aircraftdata(
        args.data_path,
        transform=transform,
        target_transform=Categorical(num_classes=args.num_way),
        num_classes_per_task=args.num_way,
        meta_test=True,
        dataset_transform=dataset_transform)
    Aircraft_test_loader = BatchMetaDataLoader(
        Aircraft_test_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers)

    train_loader_list = []
    valid_loader_list = []
    test_loader_list = []
    for name in datanames:
        if name == 'MiniImagenet':
            train_loader_list.append({name: Imagenet_train_loader})
            valid_loader_list.append({name: Imagenet_valid_loader})
            test_loader_list.append({name: Imagenet_test_loader})
        if name == 'CIFARFS':
            train_loader_list.append({name: CIFARFS_train_loader})
            valid_loader_list.append({name: CIFARFS_valid_loader})
            test_loader_list.append({name: CIFARFS_test_loader})
        if name == 'CUB':
            train_loader_list.append({name: CUB_train_loader})
            valid_loader_list.append({name: CUB_valid_loader})
            test_loader_list.append({name: CUB_test_loader})
        if name == 'Aircraft':
            train_loader_list.append({name: Aircraft_train_loader})
            valid_loader_list.append({name: Aircraft_valid_loader})
            test_loader_list.append({name: Aircraft_test_loader})
        if name == 'Omniglot':
            train_loader_list.append({name: Omniglot_train_loader})
            valid_loader_list.append({name: Omniglot_valid_loader})
            test_loader_list.append({name: Omniglot_test_loader})

    return train_loader_list, valid_loader_list, test_loader_list