def next(self, n_tasks, mode="train"):
    # Sample a batch of `n_tasks` tasks from the split selected by `mode`
    # and move the support/query tensors to `self.device`.
    if mode == 'train':
        dataset = self.train_dataset
    elif mode == 'val':
        dataset = self.val_dataset
    elif mode == 'test':
        dataset = self.test_dataset
    else:
        raise ValueError("mode must be 'train', 'val' or 'test', got {!r}".format(mode))
    dataloader = BatchMetaDataLoader(dataset, batch_size=n_tasks, num_workers=0)
    # Use the built-in next(); the `.next()` method was removed from newer DataLoader iterators.
    data = next(iter(dataloader))
    x_spts, y_spts = data["train"]
    x_qrys, y_qrys = data["test"]
    return [x_spts.to(self.device), y_spts.to(self.device),
            x_qrys.to(self.device), y_qrys.to(self.device)]
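A minimal usage sketch for the `next` method above, assuming it belongs to a task-sampler class (called `MetaTaskSampler` here purely for illustration) whose constructor builds `train_dataset`, `val_dataset`, `test_dataset` and sets `device`:

sampler = MetaTaskSampler(...)  # hypothetical wrapper; only its `next` method is shown above
x_spt, y_spt, x_qry, y_qry = sampler.next(n_tasks=4, mode='train')
# x_spt/y_spt: support inputs and labels; x_qry/y_qry: query inputs and labels, all on sampler.device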
Example #2
def get_torchmeta_rand_fnn_dataloaders(args):
    # get data
    dataset_train = RandFNN(args.data_path, 'train')
    dataset_val = RandFNN(args.data_path, 'val')
    dataset_test = RandFNN(args.data_path, 'test')
    # get meta-sets
    metaset_train = ClassSplitter(dataset_train,
                                  num_train_per_class=args.k_shots,
                                  num_test_per_class=args.k_eval,
                                  shuffle=True)
    metaset_val = ClassSplitter(dataset_val,
                                num_train_per_class=args.k_shots,
                                num_test_per_class=args.k_eval,
                                shuffle=True)
    metaset_test = ClassSplitter(dataset_test,
                                 num_train_per_class=args.k_shots,
                                 num_test_per_class=args.k_eval,
                                 shuffle=True)
    # get meta-dataloader
    meta_train_dataloader = BatchMetaDataLoader(
        metaset_train,
        batch_size=args.meta_batch_size_train,
        num_workers=args.num_workers)
    meta_val_dataloader = BatchMetaDataLoader(
        metaset_val,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    meta_test_dataloader = BatchMetaDataLoader(
        metaset_test,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    return meta_train_dataloader, meta_val_dataloader, meta_test_dataloader
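A hedged usage sketch for the helper above; the Namespace fields simply mirror the attributes the function reads, and the values (including the data path) are illustrative:

from argparse import Namespace

args = Namespace(data_path='~/data/randfnn', k_shots=5, k_eval=15,
                 meta_batch_size_train=16, meta_batch_size_eval=8, num_workers=0)
train_dl, val_dl, test_dl = get_torchmeta_rand_fnn_dataloaders(args)
batch = next(iter(train_dl))      # dict with 'train' (support) and 'test' (query) splits
x_spt, y_spt = batch['train']
x_qry, y_qry = batch['test']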
Example #3
def get_miniimagenet_dataloaders_torchmeta(args):
    args.trainin_with_epochs = False
    args.data_path = Path(
        '~/data/').expanduser()  # for some datasets this is enough
    args.criterion = nn.CrossEntropyLoss()
    # args.image_size = 84  # do we need this?
    from torchmeta.datasets.helpers import miniimagenet
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_augmentation_transforms = transforms.Compose([
        transforms.RandomResizedCrop(84),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(), normalize
    ])
    dataset_train = miniimagenet(args.data_path,
                                 transform=data_augmentation_transforms,
                                 ways=args.n_classes,
                                 shots=args.k_shots,
                                 test_shots=args.k_eval,
                                 meta_split='train',
                                 download=True)
    dataset_val = miniimagenet(args.data_path,
                               ways=args.n_classes,
                               shots=args.k_shots,
                               test_shots=args.k_eval,
                               meta_split='val',
                               download=True)
    dataset_test = miniimagenet(args.data_path,
                                ways=args.n_classes,
                                shots=args.k_shots,
                                test_shots=args.k_eval,
                                meta_split='test',
                                download=True)

    meta_train_dataloader = BatchMetaDataLoader(
        dataset_train,
        batch_size=args.meta_batch_size_train,
        num_workers=args.num_workers)
    meta_val_dataloader = BatchMetaDataLoader(
        dataset_val,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    meta_test_dataloader = BatchMetaDataLoader(
        dataset_test,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    return meta_train_dataloader, meta_val_dataloader, meta_test_dataloader
    def __init__(self,
                 dataset_name,
                 folder,
                 num_shots=5,
                 num_ways=5,
                 test_shots=15,
                 batch_size=20,
                 download=False,
                 dataloader_shuffle=True,
                 num_workers=0,
                 **kwargs):
        datasetnamedict = {
            'omniglot': omniglot,
            'miniimagenet': miniimagenet,
            'tieredimagenet': tieredimagenet
        }
        dataset = datasetnamedict[dataset_name]

        # input: [metatrain, metatest]
        test_shots = [test_shots, test_shots] if isinstance(
            test_shots, int) else test_shots
        batch_size = [batch_size, batch_size] if isinstance(
            batch_size, int) else batch_size

        #meta-train
        self.train_dataset = dataset(folder,
                                     shots=num_shots,
                                     ways=num_ways,
                                     shuffle=dataloader_shuffle,
                                     test_shots=test_shots[0],
                                     meta_train=True,
                                     download=download)
        self.train_dataloader = BatchMetaDataLoader(self.train_dataset,
                                                    batch_size=batch_size[0],
                                                    shuffle=dataloader_shuffle,
                                                    num_workers=num_workers)

        #meta-test
        self.test_dataset = dataset(folder,
                                    shots=num_shots,
                                    ways=num_ways,
                                    shuffle=dataloader_shuffle,
                                    test_shots=test_shots[1],
                                    meta_test=True,
                                    download=download)
        self.test_dataloader = BatchMetaDataLoader(self.test_dataset,
                                                   batch_size=batch_size[1],
                                                   shuffle=dataloader_shuffle,
                                                   num_workers=num_workers)
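A short usage sketch for the wrapper whose `__init__` appears above; the class name `FewShotDataWrapper` is a placeholder (the listing omits the class header), and the shapes assume the default Omniglot transform (single-channel 28x28 images):

wrapper = FewShotDataWrapper('omniglot', folder='data', num_shots=1, num_ways=5,
                             test_shots=15, batch_size=8, download=True)
batch = next(iter(wrapper.train_dataloader))
x_spt, y_spt = batch['train']     # roughly (8, 5 * 1, 1, 28, 28) support images and labels
x_qry, y_qry = batch['test']      # roughly (8, 5 * 15, 1, 28, 28) query images and labels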
def load_meta_trainset():
    '''
    dataset = Omniglot("data",
                   # Number of ways
                   num_classes_per_task=5,
                   # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)
                   transform=Compose([Resize(28), ToTensor()]),
                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
                   target_transform=Categorical(num_classes=5),
                   # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
                   class_augmentations=[Rotation([90, 180, 270])],
                   meta_train=True,
                   download=True)
    '''
    
    trainset = omniglot("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_train=True, download=True)
    trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=False, num_workers=0)
    
    #trainset = Pascal5i("data", num_classes_per_task=config.n, meta_train=True, download=True)
    #trainloader = BatchMetaDataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0)
    
    #trainset = CIFARFS("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_train=True, download=True)
    #trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    return trainset, trainloader
Example #6
def test_datasets_helpers_dataloader(name, shots, split):
    function = getattr(helpers, name)
    folder = os.getenv('TORCHMETA_DATA_FOLDER')
    download = bool(os.getenv('TORCHMETA_DOWNLOAD', False))

    dataset = function(folder,
                       ways=5,
                       shots=shots,
                       test_shots=15,
                       meta_split=split,
                       download=download)

    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)

    batch = next(iter(meta_dataloader))
    assert isinstance(batch, dict)
    assert 'train' in batch
    assert 'test' in batch

    train_inputs, train_targets = batch['train']
    test_inputs, test_targets = batch['test']

    assert isinstance(train_inputs, torch.Tensor)
    assert isinstance(train_targets, torch.Tensor)
    assert train_inputs.ndim == 5
    assert train_inputs.shape[:2] == (4, 5 * shots)
    assert train_targets.ndim == 2
    assert train_targets.shape[:2] == (4, 5 * shots)

    assert isinstance(test_inputs, torch.Tensor)
    assert isinstance(test_targets, torch.Tensor)
    assert test_inputs.ndim == 5
    assert test_inputs.shape[:2] == (4, 5 * 15)  # test_shots
    assert test_targets.ndim == 2
    assert test_targets.shape[:2] == (4, 5 * 15)
Example #7
def train(args):
    logger.warning('This script is an example to showcase the extensions and '
                   'data-loading features of Torchmeta, and as such has been '
                   'very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = PrototypicalNetwork(1,
                                args.embedding_size,
                                hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)
            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)
            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                dataset.num_classes_per_task)
            loss = prototypical_loss(prototypes, test_embeddings, test_targets)

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings, test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'protonet_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example #8
def get_torchmeta_sinusoid_dataloaders(args):
    # tran = transforms.Compose([torch.tensor])
    # dataset = sinusoid(shots=args.k_eval, test_shots=args.k_shots, transform=tran)
    dataset = sinusoid(shots=args.k_eval, test_shots=args.k_eval)
    meta_train_dataloader = BatchMetaDataLoader(
        dataset,
        batch_size=args.meta_batch_size_train,
        num_workers=args.num_workers)
    meta_val_dataloader = BatchMetaDataLoader(
        dataset,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    meta_test_dataloader = BatchMetaDataLoader(
        dataset,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    return meta_train_dataloader, meta_val_dataloader, meta_test_dataloader
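A hedged usage sketch; only the attributes the function actually reads are set, and the values are illustrative. Note that all three returned loaders wrap the same sinusoid dataset:

from argparse import Namespace

args = Namespace(k_eval=15, meta_batch_size_train=16, meta_batch_size_eval=8, num_workers=0)
train_dl, val_dl, test_dl = get_torchmeta_sinusoid_dataloaders(args)
batch = next(iter(train_dl))
x_spt, y_spt = batch['train']     # shape (16, args.k_eval, 1) with this helper's settings
x_qry, y_qry = batch['test']      # shape (16, args.k_eval, 1)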
Example #9
def prepare_dataset(args, transform):
    dataset = get_fewshotsen12msdataset(args.dataset_path, shots=args.num_shots, ways=args.num_ways, transform=transform,
                      target_transform=None,
                      meta_split="train", shuffle=True)

    dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
                                     shuffle=False, num_workers=args.num_workers,
                                     sampler=CombinationSubsetRandomSampler(dataset))

    valdataset = get_fewshotsen12msdataset(args.dataset_path, shots=args.num_shots, ways=args.num_ways, transform=transform,
                         target_transform=None,
                         meta_split="val", shuffle=True)

    valdataloader = BatchMetaDataLoader(valdataset, batch_size=args.batch_size,
                                        shuffle=False, num_workers=args.num_workers,
                                        sampler=CombinationSubsetRandomSampler(valdataset))

    return dataloader, valdataloader
def train(args):
    logger.warning(
        'This script is an example to showcase the MetaModule and '
        'data-loading features of Torchmeta, and as such has been '
        'very lightly tested. For a better tested implementation of '
        'Model-Agnostic Meta-Learning (MAML) using Torchmeta with '
        'more features (including multi-step adaptation and '
        'different datasets), please check `https://github.com/'
        'tristandeleu/pytorch-maml`.')

    inputs = [
        "new_models/210429/dataset/BedBathingBaxterHuman-v0217_0-v1-human-coop-robot-coop_10k"
    ]
    env_name = "BedBathingBaxterHuman-v0217_0-v1"
    env = gym.make('assistive_gym:' + env_name)

    dataset = behaviour(inputs, shots=400, test_shots=1)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = PolicyNetwork(env.observation_space_human.shape[0],
                          env.action_space_human.shape[0])
    for key, v in model.features.named_parameters():
        v.data = torch.nn.init.zeros_(v)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    pbar = tqdm(total=args.num_batches)
    batch_idx = 0
    while batch_idx < args.num_batches:
        for batch in dataloader:
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device).float()
            train_targets = train_targets.to(device=args.device).float()

            loss = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target) in enumerate(
                    zip(train_inputs, train_targets)):
                train_output = model(train_input)
                loss += get_loss(train_output, train_target)

            model.zero_grad()
            loss.div_(len(dataloader))
            loss.backward()
            meta_optimizer.step()

            pbar.update(1)
            pbar.set_postfix(loss='{0:.4f}'.format(loss.item()))
            batch_idx += 1
Example #11
def main(
    shots=10,
    tasks_per_batch=16,
    num_tasks=160000,
    adapt_lr=0.01,
    meta_lr=0.001,
    adapt_steps=5,
    hidden_dim=32,
):
    # load the dataset
    tasksets = Sinusoid(num_samples_per_task=2 * shots, num_tasks=num_tasks)
    dataloader = BatchMetaDataLoader(tasksets, batch_size=tasks_per_batch)

    # create the model
    model = SineModel(dim=hidden_dim)
    maml = l2l.algorithms.MAML(model,
                               lr=adapt_lr,
                               first_order=False,
                               allow_unused=True)
    opt = optim.Adam(maml.parameters(), meta_lr)
    lossfn = nn.MSELoss(reduction='mean')

    # for each iteration
    for iter, batch in enumerate(dataloader):  # num_tasks/batch_size
        meta_train_loss = 0.0

        # for each task in the batch
        effective_batch_size = batch[0].shape[0]
        for i in range(effective_batch_size):
            learner = maml.clone()

            # divide the data into support and query sets
            train_inputs, train_targets = batch[0][i].float(
            ), batch[1][i].float()
            x_support, y_support = train_inputs[::2], train_targets[::2]
            x_query, y_query = train_inputs[1::2], train_targets[1::2]

            for _ in range(adapt_steps):  # adaptation_steps
                support_preds = learner(x_support)
                support_loss = lossfn(support_preds, y_support)
                learner.adapt(support_loss)

            query_preds = learner(x_query)
            query_loss = lossfn(query_preds, y_query)
            meta_train_loss += query_loss

        meta_train_loss = meta_train_loss / effective_batch_size

        if iter % 200 == 0:
            print('Iteration:', iter, 'Meta Train Loss',
                  meta_train_loss.item())

        opt.zero_grad()
        meta_train_loss.backward()
        opt.step()
Example #12
def prepare_fewshotdataloader(root, shots, ways, fold, transform, shuffle=True, num_tasks=1,
                              num_workers=0, download=False):
    dataset = get_fewshotsen12msdataset(root, shots=shots, ways=ways, transform=transform,
                      target_transform=None,
                      meta_split=fold, shuffle=shuffle, download=download)

    dataloader = BatchMetaDataLoader(dataset, batch_size=num_tasks,
                                     shuffle=False, num_workers=num_workers,
                                     sampler=CombinationSubsetRandomSampler(dataset))

    return dataloader
def load_meta_testset():
    testset = omniglot("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_test=True, download=True)
    testloader = BatchMetaDataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=0)
    
    #testset = Pascal5i("data", num_classes_per_task=config.n, meta_test=True, download=True)
    #testloader = BatchMetaDataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0)

    #testset = CIFARFS("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_test=True, download=True)
    #testloader = BatchMetaDataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=0)
    
    return testset, testloader
Example #14
def test_batch_meta_dataloader():
    dataset = Sinusoid(10, num_tasks=1000, noise_std=None)
    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)
    assert isinstance(meta_dataloader, DataLoader)
    assert len(meta_dataloader) == 250  # 1000 / 4

    inputs, targets = next(iter(meta_dataloader))
    assert isinstance(inputs, torch.Tensor)
    assert isinstance(targets, torch.Tensor)
    assert inputs.shape == (4, 10, 1)
    assert targets.shape == (4, 10, 1)
Example #15
def create_og_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    input_size,
    n_query,
    batch_size,
    num_workers,
    download=False,
    use_vinyals_split=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for Omniglot

    Args:
        root: Path to the Omniglot data root folder (containing an 'omniglot' subfolder with the
            preprocessed json files or the downloaded zip files).
        meta_split: see torchmeta.datasets.Omniglot
        k_way: Number of classes per task
        n_shot: Number of samples per class
        input_size: Images are resized to this size.
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to be done on the
            downloaded files).
        use_vinyals_split: see torchmeta.datasets.Omniglot
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = Omniglot(
        root,
        num_classes_per_task=k_way,
        transform=Compose([Resize(input_size), ToTensor()]),
        target_transform=Categorical(num_classes=k_way),
        class_augmentations=[Rotation([90, 180, 270])],
        meta_split=meta_split,
        download=download,
        use_vinyals_split=use_vinyals_split,
    )
    dataset = ClassSplitter(dataset,
                            shuffle=True,
                            num_train_per_class=n_shot,
                            num_test_per_class=n_query)
    dataset.seed = seed
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader
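A hedged usage sketch for `create_og_data_loader`; the argument values are illustrative only:

loader = create_og_data_loader(root='data', meta_split='train', k_way=5, n_shot=1,
                               input_size=28, n_query=15, batch_size=16,
                               num_workers=0, download=True)
batch = next(iter(loader))
x_spt, y_spt = batch['train']     # (16, 5 * 1, 1, 28, 28) support images and labels
x_qry, y_qry = batch['test']      # (16, 5 * 15, 1, 28, 28) query images and labels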
def main(args):
    test_dataset = dataset_f(args, 'test')
    if args.support_samples > 0:
        test_dloader = BatchMetaDataLoader(test_dataset, args.batch_size)
    else:
        test_dloader = DataLoader(test_dataset, args.batch_size)
    model_checkpoint = load_best_model_checkpoint(args.run_path, args.device)

    zero_shot = args.support_samples == 0
    if zero_shot:
        model = PrototypicalNetworkZeroShot(
            args.distance,
            num_classes=args.num_classes,
            meta_features=args.train_config['metadata_features'],
            img_features=args.train_config['image_features'])
    else:
        input_channels = 1 if args.dataset == 'omniglot' else 3
        model = PrototypicalNetwork(args.distance,
                                    args.num_classes,
                                    input_channels=input_channels)
    model.load_state_dict(model_checkpoint)
    model = model.float().to(args.device)

    evaluator = Trainer(model,
                        None,
                        test_dloader,
                        args.distance,
                        args.run_path,
                        train_epochs=-1,
                        opt=None,
                        device=args.device,
                        eval_steps=args.steps,
                        zero_shot=zero_shot)
    test_results = evaluator.eval()['val']
    print('evaluation done')
    eval_folder: Path = args.run_path / 'evaluation'
    eval_folder.mkdir_p()
    results = dict(num_classes=args.num_classes,
                   support_samples=args.support_samples,
                   query_samples=args.query_samples,
                   batch_size=args.batch_size,
                   steps=args.steps,
                   eval_duration_seconds=test_results.eval_duration_seconds,
                   avg_accuracy=test_results.metrics['accuracy'],
                   avg_loss=test_results.metrics['avg_loss'])
    dst_file = eval_folder / \
               f'num_classes={args.num_classes}_' \
               f'support={args.support_samples}_' \
               f'query={args.query_samples}.json'
    with open(dst_file, 'w') as f:
        json.dump(results, f, indent=2)
    print('saved results to', dst_file)
Example #17
def main(args):
    with open(args.config, 'r') as f:
        config = json.load(f)

    if args.folder is not None:
        config['folder'] = args.folder
    if args.num_steps > 0:
        config['num_steps'] = args.num_steps
    if args.num_batches > 0:
        config['num_batches'] = args.num_batches
    config_dir = path.dirname(args.config)
    config['folder'] = path.join(config_dir, config['folder'])
    config['model_path'] = path.join(config_dir, config['model_path'])

    device = torch.device('cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    benchmark = get_benchmark_by_name(config['dataset'],
                                      config['folder'],
                                      config['num_ways'],
                                      config['num_shots'],
                                      config['num_shots_test'],
                                      config['no_max_pool'],
                                      hidden_size=config['hidden_size'])

    with open(config['model_path'], 'rb') as f:
        benchmark.model.load_state_dict(torch.load(f, map_location=device))

    meta_test_dataloader = BatchMetaDataLoader(benchmark.meta_test_dataset,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                            first_order=config['first_order'],
                                            num_adaptation_steps=config['num_steps'],
                                            step_size=config['step_size'],
                                            loss_function=benchmark.loss_function,
                                            device=device)

    results = metalearner.evaluate(meta_test_dataloader,
                                   max_batches=config['num_batches'],
                                   silent=args.silent,
                                   desc='Test')

    # Save results
    dirname = os.path.dirname(config['model_path'])
    with open(os.path.join(dirname, 'results.json'), 'w') as f:
        json.dump(results, f)
Example #18
def get_dataset_loader(dataset_id,
                       folder,
                       shot,
                       query_size,
                       batch_size,
                       shuffle,
                       train=False,
                       val=False,
                       test=False):
    dataset = get_dataset(dataset_id, folder, shot, query_size, shuffle, train,
                          val, test)
    loader = BatchMetaDataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle)

    return loader
Example #19
def get_sine_loader(batch_size, num_steps, shots=10, test_shots=15):
    dataset_transform = ClassSplitter(
        shuffle=True, num_train_per_class=shots, num_test_per_class=test_shots
    )
    transform = ToTensor1D()
    dataset = Sinusoid(
        shots + test_shots,
        num_tasks=batch_size * num_steps,
        transform=transform,
        target_transform=transform,
        dataset_transform=dataset_transform,
    )
    loader = BatchMetaDataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
    )
    return loader
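A hedged usage sketch for `get_sine_loader` with its default shot counts; the shapes follow torchmeta's (meta-batch, samples, 1) convention for the Sinusoid toy dataset:

loader = get_sine_loader(batch_size=16, num_steps=100)   # 16 * 100 = 1600 tasks in total
batch = next(iter(loader))
x_spt, y_spt = batch['train']     # (16, 10, 1): 10 support points per sinusoid task
x_qry, y_qry = batch['test']      # (16, 15, 1): 15 query points per task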
Example #20
    def helper_meta_imagelist(self, batch_size, input_size, num_shots,
                              num_shots_test, num_ways, num_workers):
        with tempfile.TemporaryDirectory(
        ) as folder, tempfile.NamedTemporaryFile(mode='w+t') as fp:
            create_random_imagelist(folder, fp, input_size, create_image)
            dataset = data.ImagelistMetaDataset(
                imagelistname=fp.name,
                root='',
                transform=transforms.Compose(
                    [transforms.Resize(input_size),
                     transforms.ToTensor()]))
            meta_dataset = CombinationMetaDataset(
                dataset,
                num_classes_per_task=num_ways,
                target_transform=Categorical(num_ways),
                dataset_transform=ClassSplitter(
                    shuffle=True,
                    num_train_per_class=num_shots,
                    num_test_per_class=num_shots_test))
            meta_dataloader = BatchMetaDataLoader(meta_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  pin_memory=True)

            for batch in meta_dataloader:
                batch_data, batch_label = batch['train']
                for img_train, label_train, img_test, label_test in zip(
                        *batch['train'], *batch['test']):
                    classmap = {}
                    for idx in range(img_train.shape[0]):
                        npimg = img_train[idx, ...].detach().numpy()
                        npimg[npimg < 0.001] = 0
                        imagesum = int(np.sum(npimg) * 255)
                        if imagesum not in classmap:
                            classmap[imagesum] = int(label_train[idx])
                        self.assertEqual(classmap[imagesum],
                                         int(label_train[idx]))

                    for idx in range(img_test.shape[0]):
                        npimg = img_test[idx, ...].detach().numpy()
                        npimg[npimg < 0.001] = 0
                        imagesum = int(np.sum(npimg) * 255)
                        self.assertEqual(classmap[imagesum],
                                         int(label_test[idx]),
                                         "Error on {}".format(idx))
Example #21
def test_overflow_length_dataloader():
    folder = os.getenv('TORCHMETA_DATA_FOLDER')
    download = bool(os.getenv('TORCHMETA_DOWNLOAD', False))

    # The number of tasks is C(4112, 20), which exceeds machine precision
    dataset = helpers.omniglot(folder,
                               ways=20,
                               shots=1,
                               test_shots=5,
                               meta_train=True,
                               download=download)

    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)

    batch = next(iter(meta_dataloader))
    assert isinstance(batch, dict)
    assert 'train' in batch
    assert 'test' in batch
Example #22
def create_miniimagenet_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    n_query,
    batch_size,
    num_workers,
    download=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for MiniImagenet

    Args:
        root: Path to the MiniImagenet data root folder (containing a 'miniimagenet' subfolder with the
            preprocessed json files or the downloaded tar.gz file).
        meta_split: see torchmeta.datasets.MiniImagenet
        k_way: Number of classes per task
        n_shot: Number of samples per class
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to be done on the
            downloaded files).
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = miniimagenet(
        root,
        n_shot,
        k_way,
        meta_split=meta_split,
        test_shots=n_query,
        download=download,
        seed=seed,
    )
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader
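A hedged usage sketch for `create_miniimagenet_data_loader`; the argument values are illustrative, and the shapes assume torchmeta's default MiniImagenet transform (84x84 RGB images):

loader = create_miniimagenet_data_loader(root='data', meta_split='val', k_way=5,
                                         n_shot=5, n_query=15, batch_size=4,
                                         num_workers=0, download=True)
batch = next(iter(loader))
x_spt, y_spt = batch['train']     # (4, 5 * 5, 3, 84, 84) support set
x_qry, y_qry = batch['test']      # (4, 5 * 15, 3, 84, 84) query set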
Example #23
def test_batch_meta_dataloader_splitter():
    dataset = Sinusoid(20, num_tasks=1000, noise_std=None)
    dataset = ClassSplitter(dataset, num_train_per_class=5,
        num_test_per_class=15)
    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)

    batch = next(iter(meta_dataloader))
    assert isinstance(batch, dict)
    assert 'train' in batch
    assert 'test' in batch

    train_inputs, train_targets = batch['train']
    test_inputs, test_targets = batch['test']
    assert isinstance(train_inputs, torch.Tensor)
    assert isinstance(train_targets, torch.Tensor)
    assert train_inputs.shape == (4, 5, 1)
    assert train_targets.shape == (4, 5, 1)
    assert isinstance(test_inputs, torch.Tensor)
    assert isinstance(test_targets, torch.Tensor)
    assert test_inputs.shape == (4, 15, 1)
    assert test_targets.shape == (4, 15, 1)
    # meta_test_ds = Omniglot(root_path, num_classes_per_task=5, meta_test=True,
    #                          use_vinyals_split=False, download=True, transform=trans)
    return (meta_train_ds, meta_test_ds)


if __name__ == '__main__':
    # (meta_train_ds, meta_test_ds) = get_omniglot_for_fl()
    from torchmeta.datasets.helpers import omniglot
    from torchmeta.utils.data import BatchMetaDataLoader

    dataset = omniglot(root_path,
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_train=True,
                       download=True)
    dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

    for batch in dataloader:
        train_inputs, train_targets = batch["train"]
        print('Train inputs shape: {0}'.format(
            train_inputs.shape))  # (16, 25, 1, 28, 28)
        print('Train targets shape: {0}'.format(
            train_targets.shape))  # (16, 25)

        test_inputs, test_targets = batch["test"]
        print('Test inputs shape: {0}'.format(
            test_inputs.shape))  # (16, 75, 1, 28, 28)
        print('Test targets shape: {0}'.format(test_targets.shape))  # (16, 75)
    g = 5
Example #25
def main(args, mode, iteration=None):
    dataset = load_dataset(args, mode)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model.to(device=args.device)
    model.train()

    # To control outer update parameter
    # If you want to control inner update parameter, please see update_parameters function in ./maml/utils.py
    freeze_params = [
        p for name, p in model.named_parameters() if 'classifier' in name
    ]
    learnable_params = [
        p for name, p in model.named_parameters() if 'classifier' not in name
    ]
    if args.outer_fix:
        meta_optimizer = torch.optim.Adam([{
            'params': freeze_params,
            'lr': 0
        }, {
            'params': learnable_params,
            'lr': args.meta_lr
        }])
    else:
        meta_optimizer = torch.optim.Adam([{
            'params': freeze_params,
            'lr': args.meta_lr
        }, {
            'params': learnable_params,
            'lr': args.meta_lr
        }])

    if args.meta_train:
        total = args.train_batches
    elif args.meta_val:
        total = args.valid_batches
    elif args.meta_test:
        total = args.test_batches

    loss_logs, accuracy_logs = [], []

    # Training loop
    with tqdm(dataloader, total=total, leave=False) as pbar:
        for batch_idx, batch in enumerate(pbar):
            if args.centering:
                fc_weight_mean = torch.mean(model.classifier.weight.data,
                                            dim=0)
                model.classifier.weight.data -= fc_weight_mean

            model.zero_grad()

            support_inputs, support_targets = batch['train']
            support_inputs = support_inputs.to(device=args.device)
            support_targets = support_targets.to(device=args.device)

            query_inputs, query_targets = batch['test']
            query_inputs = query_inputs.to(device=args.device)
            query_targets = query_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)

            for task_idx, (support_input, support_target, query_input,
                           query_target) in enumerate(
                               zip(support_inputs, support_targets,
                                   query_inputs, query_targets)):
                support_features, support_logit = model(support_input)
                inner_loss = F.cross_entropy(support_logit, support_target)

                model.zero_grad()

                params = update_parameters(
                    model,
                    inner_loss,
                    extractor_step_size=args.extractor_step_size,
                    classifier_step_size=args.classifier_step_size,
                    first_order=args.first_order)

                query_features, query_logit = model(query_input, params=params)
                outer_loss += F.cross_entropy(query_logit, query_target)

                with torch.no_grad():
                    accuracy += get_accuracy(query_logit, query_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)
            loss_logs.append(outer_loss.item())
            accuracy_logs.append(accuracy.item())

            if args.meta_train:
                outer_loss.backward()
                meta_optimizer.step()

            postfix = {
                'mode': mode,
                'iter': iteration,
                'acc': round(accuracy.item(), 5)
            }
            pbar.set_postfix(postfix)
            if batch_idx + 1 == total:
                break

    # Save model
    if args.meta_train:
        filename = os.path.join(args.output_folder,
                                args.dataset + '_' + args.save_name, 'models',
                                'epochs_{}.pt'.format((iteration + 1) * total))
        if (iteration + 1) * total % 5000 == 0:
            with open(filename, 'wb') as f:
                state_dict = model.state_dict()
                torch.save(state_dict, f)

    return loss_logs, accuracy_logs
Example #26
def main(args):
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        with open(os.path.join(folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(folder, 'config.json'))))

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)
    class_augmentations = [Rotation([90, 180, 270])]
    if args.dataset == 'sinusoid':
        transform = ToTensor()

        meta_train_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_val_dataset = Sinusoid(args.num_shots + args.num_shots_test,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=dataset_transform)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif args.dataset == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(args.folder,
                                      transform=transform,
                                      target_transform=Categorical(
                                          args.num_ways),
                                      num_classes_per_task=args.num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_val_dataset = Omniglot(args.folder,
                                    transform=transform,
                                    target_transform=Categorical(
                                        args.num_ways),
                                    num_classes_per_task=args.num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=dataset_transform)

        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    elif args.dataset == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_train=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform,
            download=True)
        meta_val_dataset = MiniImagenet(
            args.folder,
            transform=transform,
            target_transform=Categorical(args.num_ways),
            num_classes_per_task=args.num_ways,
            meta_val=True,
            class_augmentations=class_augmentations,
            dataset_transform=dataset_transform)

        model = ModelConvMiniImagenet(args.num_ways,
                                      hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=loss_function,
        device=device)

    best_val_accuracy = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))

        if (best_val_accuracy is None) \
                or (best_val_accuracy < results['accuracies_after']):
            best_val_accuracy = results['accuracies_after']
            if args.output_folder is not None:
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()
Example #27
def main(
    shots=10,
    tasks_per_batch=16,
    num_tasks=16000,
    num_test_tasks=32,
    adapt_lr=0.01,
    meta_lr=0.001,
    adapt_steps=5,
    hidden_dim=32,
):
    exp_name = input(
        "Enter Experiment NAME (should be unique, else overwrites): ")
    EXPERIMENT_DIR = "./experiments/MAML_Sine_exps" + exp_name
    if os.path.isdir(EXPERIMENT_DIR):
        print("Experiment folder opened ...")
    else:
        os.mkdir(EXPERIMENT_DIR)
        print("New Experiment folder started ...")

    MODEL_CHECKPOINT_DIR = EXPERIMENT_DIR + "/model"
    if os.path.isdir(MODEL_CHECKPOINT_DIR):
        print("Model Checkpoint folder opened ...")
    else:
        os.mkdir(MODEL_CHECKPOINT_DIR)
        print("New Model checkpoint folder made ...")

    PLOT_RESULTS_DIR = EXPERIMENT_DIR + "/plot_results"
    if os.path.isdir(PLOT_RESULTS_DIR):
        print("Image results folder opened ...")
    else:
        os.mkdir(PLOT_RESULTS_DIR)
        print("New Image results folder made ...")

    # load the dataset
    tasksets = Sinusoid(num_samples_per_task=2 * shots, num_tasks=num_tasks)
    dataloader = BatchMetaDataLoader(tasksets, batch_size=tasks_per_batch)

    # create the model
    model = SineModel(dim=hidden_dim, experiment_dir=EXPERIMENT_DIR)
    maml = l2l.algorithms.MAML(model,
                               lr=adapt_lr,
                               first_order=False,
                               allow_unused=True)
    opt = optim.Adam(maml.parameters(), meta_lr)
    lossfn = nn.MSELoss(reduction='mean')

    # for each iteration
    for iter, batch in enumerate(dataloader):  # num_tasks/batch_size
        meta_train_loss = 0.0

        # for each task in the batch
        effective_batch_size = batch[0].shape[0]
        for i in range(effective_batch_size):
            learner = maml.clone()

            # divide the data into support and query sets
            train_inputs, train_targets = batch[0][i].float(
            ), batch[1][i].float()
            x_support, y_support = train_inputs[::2], train_targets[::2]
            x_query, y_query = train_inputs[1::2], train_targets[1::2]

            for _ in range(adapt_steps):  # adaptation_steps
                support_preds = learner(x_support)
                support_loss = lossfn(support_preds, y_support)
                learner.adapt(support_loss)

            query_preds = learner(x_query)
            query_loss = lossfn(query_preds, y_query)
            meta_train_loss += query_loss

        meta_train_loss = meta_train_loss / effective_batch_size

        opt.zero_grad()
        meta_train_loss.backward()
        opt.step()

        if iter % 100 == 0:
            print('Iteration:', iter, 'Meta Train Loss',
                  meta_train_loss.item())
            # print(x_query.requires_grad, y_query.requires_grad,query_preds.requires_grad,meta_train_loss.item())
            plotter(x_query, y_query,
                    query_preds.detach().numpy(), iter, 'Train',
                    meta_train_loss.item(), model.plot_results)

    #save current model
    model.save_checkpoint()

    #meta-testing
    test_tasks = Sinusoid(num_samples_per_task=shots, num_tasks=num_test_tasks)
    test_dataloader = BatchMetaDataLoader(test_tasks,
                                          batch_size=tasks_per_batch)

    #load learned model
    test_model = SineModel(dim=hidden_dim, experiment_dir=EXPERIMENT_DIR)
    test_model.load_checkpoint()

    for iter, batch in enumerate(test_dataloader):
        meta_test_loss = 0.0

        # for each task in the batch
        effective_batch_size = batch[0].shape[0]
        for i in range(effective_batch_size):
            learner = maml.clone()

            # divide the data into support and query sets
            test_inputs, test_targets = batch[0][i].float(), batch[1][i].float(
            )

            test_preds = test_model(test_inputs)
            test_loss = lossfn(test_preds, test_targets)
            meta_test_loss += test_loss

        meta_test_loss = meta_test_loss / effective_batch_size

        if iter % 20 == 0:
            print('Iteration:', iter, 'Meta Test Loss', meta_test_loss.item())
            plotter(test_inputs, test_targets,
                    test_preds.detach().numpy(), iter, 'Test',
                    meta_test_loss.item(), test_model.plot_results)
Example #28
    def do_train(self, train_data, test_data, batch_size=0, accumulation_steps=1, num_epochs=1, test_strategy='all', 
                test_batch_size=0, include_query_point=True, mixup_alpha=None, verbose=True):
        """
        Performs `num_epochs` training iterations of meta-NML on a batch of data.

        Parameters
        ----------
        train_data : tuple[np.array, np.array]
            Inputs and labels to sample from for adaptation (inner loop).
            Each task will consist of one of these inputs with a proposed class label.

        test_data : tuple[np.array, np.array]
            Inputs and labels to use for the test loss (outer loop).
            The loss will be computed over ALL of these points for each task.

        batch_size : int
            Number of tasks to use per batch for meta-learning

        num_epochs : int
            Number of passes through the entire `train_data` dataset

        test_strategy : str in {'all', 'sample', 'cycle'}
            The strategy to use for evaluating the test loss across tasks in a batch.
            By default, we use all the points in `test_data`.

        test_batch_size : int, optional
            Number of points to use in a test batch. 
            Only used if test_strategy is 'sample' or 'cycle'.

        include_query_point : bool, optional
            Whether to include the downweighted query point in every batch during testing. Only used
            if the test_strategy is 'sample' or 'cycle'. Default: True

        mixup_alpha : float, optional
            Alpha parameter to use for mixup (https://arxiv.org/pdf/1710.09412.pdf)
            (only affects the test set, not the tasks themselves).

        verbose : bool
            Whether to show the MAML training progress bar for each epoch. Default: True
        """
        batch_size = batch_size or len(train_data[0])
        self.model.train()

        if not self._do_metalearning:
            # Do standard MLE training on the meta-test set
            ds = TensorDataset(torch.Tensor(test_data[0]), torch.Tensor(test_data[1]).long())
            loader = DataLoader(ds, batch_size=64, shuffle=True)
            epoch_results = []
            for _ in range(num_epochs):
                all_losses = []
                for inputs, labels in loader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    logits = self.model(inputs)
                    loss = F.cross_entropy(logits, labels)
                    self.meta_optimizer.zero_grad()
                    loss.backward()
                    self.meta_optimizer.step()
                    all_losses.append(loss.item())
                epoch_results.append({
                    'mean_loss': np.array(all_losses).mean(),
                    'all_losses': all_losses
                })
            return epoch_results    

        epoch_results = []
        for _ in range(num_epochs):
            if self.embedding_type == 'features':
                train_features = self.model.embedding(torch.Tensor(train_data[0]).cuda()).cpu().detach()
                test_features = self.model.embedding(torch.Tensor(test_data[0]).cuda()).cpu().detach()
            elif self.embedding_type == 'vae':
                train_features = self.model_vae(torch.Tensor(train_data[0]).cuda())[1].cpu().detach()
                test_features = self.model_vae(torch.Tensor(test_data[0]).cuda())[1].cpu().detach()
            elif self.embedding_type == 'custom':
                train_features = train_data[2]
                test_features = test_data[2]
            else:
                train_features = train_data[0]
                test_features = test_data[0]
            train_data = (train_data[0], train_data[1], train_features)
            test_data = (test_data[0], test_data[1], test_features)
            dataset = NMLDataset(train_data, test_data, mixup_alpha=mixup_alpha,
                                 points_per_task=self.points_per_task, num_classes=self.num_classes,
                                 test_strategy=test_strategy,
                                 test_batch_size=test_batch_size, include_query_point=include_query_point,
                                 dist_weight_thresh=self.dist_weight_thresh, equal_pos_neg_test=self.equal_pos_neg_test,
                                 query_point_weight=self.query_point_weight, kernel=self.kernel)
            trainloader = BatchMetaDataLoader(dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=self.num_workers, pin_memory=True)
            num_batches = ceil(dataset.num_tasks / batch_size)
            results = self.metalearner.train(trainloader, accumulation_steps=accumulation_steps, max_batches=num_batches, is_classification_task=True,
                verbose=verbose, desc='Training', leave=False)
            epoch_results.append(results)

        return epoch_results
Example #29
        global_task_count = 0

    # save the args into .json file
    with open(os.path.join(args.record_folder, 'args.json'), 'w') as f:
        json.dump(vars(args), f)

    # get datasets and dataloaders
    train_dataset = get_dataset(args,
                                dataset_name=args.train_data,
                                phase='train')
    val_dataset = get_dataset(args, dataset_name=args.test_data, phase='val')
    test_dataset = get_dataset(args, dataset_name=args.test_data, phase='test')

    train_loader = BatchMetaDataLoader(train_dataset,
                                       batch_size=args.batch_tasks,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

    val_loader = BatchMetaDataLoader(val_dataset,
                                     batch_size=args.batch_tasks,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

    test_loader = BatchMetaDataLoader(test_dataset,
                                      batch_size=args.batch_tasks,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
def main(args):

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    # Create output folder
    if (args.output_folder is not None):
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        output_folder = os.path.join(args.output_folder,
                                     time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(output_folder)
        logging.debug('Creating folder `{0}`'.format(output_folder))

        args.datafolder = os.path.abspath(args.datafolder)
        args.model_path = os.path.abspath(
            os.path.join(output_folder, 'model.th'))

        # Save the configuration in a config.json file
        with open(os.path.join(output_folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(output_folder, 'config.json'))))

    # Get datasets and load into meta learning format
    meta_train_dataset, meta_val_dataset, _ = get_datasets(
        args.dataset,
        args.datafolder,
        args.num_ways,
        args.num_shots,
        args.num_shots_test,
        augment=augment,
        fold=args.fold,
        download=download_data)

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)

    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # Define model
    model = Unet(device=device, feature_scale=args.feature_scale)
    model = model.to(device)
    print(f'Using device: {device}')

    # Define optimizer
    meta_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=args.meta_lr)  #, weight_decay=1e-5)
    #meta_optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, momentum = 0.99)

    # Define meta learner
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_adaption_steps,
        step_size=args.step_size,
        learn_step_size=False,
        loss_function=loss_function,
        device=device)

    best_value = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    train_losses = []
    val_losses = []
    train_ious = []
    train_accuracies = []
    val_accuracies = []
    val_ious = []

    start_time = time.time()

    for epoch in range(args.num_epochs):
        print('start epoch ', epoch + 1)
        print('start train---------------------------------------------------')
        train_loss, train_accuracy, train_iou = metalearner.train(
            meta_train_dataloader,
            max_batches=args.num_batches,
            verbose=args.verbose,
            desc='Training',
            leave=False)
        print(f'\n train accuracy: {train_accuracy}, train loss: {train_loss}')
        print('end train---------------------------------------------------')
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        train_ious.append(train_iou)

        # Evaluate in given intervals
        if epoch % args.val_step_size == 0:
            print(
                'start evaluate-------------------------------------------------'
            )
            results = metalearner.evaluate(meta_val_dataloader,
                                           max_batches=args.num_batches,
                                           verbose=args.verbose,
                                           desc=epoch_desc.format(epoch + 1),
                                           is_test=False)
            val_acc = results['accuracy']
            val_loss = results['mean_outer_loss']
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            val_ious.append(results['iou'])
            print(
                f'\n validation accuracy: {val_acc}, validation loss: {val_loss}'
            )
            print(
                'end evaluate-------------------------------------------------'
            )

            # Save best model
            if 'accuracies_after' in results:
                if (best_value is None) or (best_value <
                                            results['accuracies_after']):
                    best_value = results['accuracies_after']
                    save_model = True
            elif (best_value is None) or (best_value >
                                          results['mean_outer_loss']):
                best_value = results['mean_outer_loss']
                save_model = True
            else:
                save_model = False

            if save_model and (args.output_folder is not None):
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

        print('end epoch ', epoch + 1)

    elapsed_time = time.time() - start_time
    print('Finished after ',
          time.strftime('%H:%M:%S', time.gmtime(elapsed_time)))

    r = {}
    r['train_losses'] = train_losses
    r['train_accuracies'] = train_accuracies
    r['train_ious'] = train_ious
    r['val_losses'] = val_losses
    r['val_accuracies'] = val_accuracies
    r['val_ious'] = val_ious
    r['time'] = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))
    with open(os.path.join(output_folder, 'train_results.json'), 'w') as g:
        json.dump(r, g)
        logging.info('Saving results dict in `{0}`'.format(
            os.path.abspath(os.path.join(output_folder,
                                         'train_results.json'))))

    # Plot results
    plot_errors(args.num_epochs,
                train_losses,
                val_losses,
                val_step_size=args.val_step_size,
                output_folder=output_folder,
                save=True,
                bce_dice_focal=bce_dice_focal)
    plot_accuracy(args.num_epochs,
                  train_accuracies,
                  val_accuracies,
                  val_step_size=args.val_step_size,
                  output_folder=output_folder,
                  save=True)
    plot_iou(args.num_epochs,
             train_ious,
             val_ious,
             val_step_size=args.val_step_size,
             output_folder=output_folder,
             save=True)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()