def load_meta_trainset():
    '''
    dataset = Omniglot("data",
                   # Number of ways
                   num_classes_per_task=5,
                   # Resize the images to 28x28 and convert them to PyTorch tensors (from Torchvision)
                   transform=Compose([Resize(28), ToTensor()]),
                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
                   target_transform=Categorical(num_classes=5),
                   # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
                   class_augmentations=[Rotation([90, 180, 270])],
                   meta_train=True,
                   download=True)
    '''
    
    trainset = omniglot("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_train=True, download=True)
    trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    #trainset = Pascal5i("data", num_classes_per_task=config.n, meta_train=True, download=True)
    #trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=0)

    #trainset = CIFARFS("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_train=True, download=True)
    #trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    return trainset, trainloader
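
As a usage sketch (not part of the original snippet), the pair returned by load_meta_trainset can be consumed as follows; config.n, config.k, and config.batch_size are assumed to be defined as above, and each meta-batch is a dictionary with 'train' (support) and 'test' (query) splits:

# Sketch only: peek at a few meta-batches from the loader returned above.
trainset, trainloader = load_meta_trainset()

for batch_idx, batch in enumerate(trainloader):
    support_inputs, support_targets = batch['train']  # (batch_size, n * k, 1, 28, 28), (batch_size, n * k)
    query_inputs, query_targets = batch['test']       # (batch_size, n * 15, 1, 28, 28), (batch_size, n * 15)
    print(support_inputs.shape, query_inputs.shape)
    if batch_idx >= 2:  # stop after a few meta-batches
        break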
Example 2
def train(args):
    logger.warning('This script is an example to showcase the extensions and '
                   'data-loading features of Torchmeta, and as such has been '
                   'very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = PrototypicalNetwork(1,
                                args.embedding_size,
                                hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)
            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)
            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                dataset.num_classes_per_task)
            loss = prototypical_loss(prototypes, test_embeddings, test_targets)

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings, test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'protonet_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
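
The train() above expects an argparse-style namespace; as a hedged sketch, it could be invoked programmatically like this (the attribute names mirror those accessed in the function, while the values are illustrative and not taken from the original script):

# Sketch only: programmatic invocation of the prototypical-network train() above.
from argparse import Namespace

import torch

args = Namespace(folder='data',
                 num_shots=5,
                 num_ways=5,
                 download=True,
                 batch_size=16,
                 num_workers=0,
                 embedding_size=64,
                 hidden_size=64,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 num_batches=100,
                 output_folder=None)
train(args)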
def load_meta_testset():
    testset = omniglot("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_test=True, download=True)
    testloader = BatchMetaDataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    #testset = Pascal5i("data", num_classes_per_task=config.n, meta_test=True, download=True)
    #testloader = BatchMetaDataLoader(testset, batch_size=config.batch_size, shuffle=True, num_workers=0)

    #testset = CIFARFS("data", ways=config.n, shots=config.k, test_shots=15, shuffle=False, meta_test=True, download=True)
    #testloader = BatchMetaDataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=0)
    
    return testset, testloader
Example 4
	def __init__(self, 
				num_ways,
				num_shots,
				meta_split,
				num_test_shots = 3):
		
		self.num_ways = num_ways
		self.num_shots = num_shots
		self.num_test_shots = num_test_shots
		self.meta_split = meta_split

		self.dataset = omniglot('data', 
								ways = self.num_ways, 
								shots = self.num_shots, 
								test_shots = self.num_test_shots, 
								meta_split = self.meta_split, 
								download = True)
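
The class name is not visible in this excerpt; as a hedged sketch, a hypothetical get_dataloader method like the one below could wrap self.dataset in a BatchMetaDataLoader (the method name and the batch size are illustrative, not from the original class):

	# Hypothetical helper (sketch only): wrap the task dataset in a meta data loader.
	def get_dataloader(self, batch_size=16, shuffle=True, num_workers=0):
		return BatchMetaDataLoader(
			self.dataset,
			batch_size=batch_size,
			shuffle=shuffle,
			num_workers=num_workers)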
Example 5
def test_overflow_length_dataloader():
    folder = os.getenv('TORCHMETA_DATA_FOLDER')
    download = bool(os.getenv('TORCHMETA_DOWNLOAD', False))

    # The number of tasks is C(4112, 20), which exceeds machine precision
    dataset = helpers.omniglot(folder,
                               ways=20,
                               shots=1,
                               test_shots=5,
                               meta_train=True,
                               download=download)

    meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)

    batch = next(iter(meta_dataloader))
    assert isinstance(batch, dict)
    assert 'train' in batch
    assert 'test' in batch
Example 6
def dataset_f(args, meta_split: Optional[Literal['train', 'val', 'test']] = None):
    if meta_split is None:
        meta_split = 'train'
    meta_train = meta_split == 'train'
    meta_val = meta_split == 'val'
    meta_test = meta_split == 'test'
    dataset = args.dataset
    if dataset == 'miniimagenet' and meta_val and args.num_classes > 16:
        args.num_classes = 16
        print(
            'setting num_classes for miniimagenet val to 16 because that is the maximum'
        )
    dataset_kwargs = dict(
        folder=DATAFOLDER,
        shots=args.support_samples,
        ways=args.num_classes,
        shuffle=True,
        test_shots=args.query_samples,
        seed=args.seed,
        target_transform=Categorical(num_classes=args.num_classes),
        download=True,
        meta_train=meta_train,
        meta_val=meta_val,
        meta_test=meta_test)
    if dataset == 'omniglot':
        return omniglot(
            **dataset_kwargs,
            class_augmentations=[Rotation([90, 180, 270])],
        )
    elif dataset == 'miniimagenet':
        tg.set_dim('NUM_FEATURES', 1600)
        return miniimagenet(**dataset_kwargs)
    elif dataset.upper() == 'CUB':
        if args.support_samples == 0:
            from cub_dataset import CubDatasetEmbeddingsZeroShot
            print('Instantiating CubDatasetEmbeddingsZeroShot')
            return CubDatasetEmbeddingsZeroShot(DATAFOLDER, meta_split,
                                                args.query_samples,
                                                args.num_classes)
        else:
            return cub(**dataset_kwargs)
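
As a hedged usage sketch, the meta-dataset returned by dataset_f is typically wrapped in a BatchMetaDataLoader before training; args is the same namespace used above, and the batch size of 16 is illustrative:

# Sketch only: wrap the returned meta-dataset and draw one meta-batch.
train_dataset = dataset_f(args, meta_split='train')
train_loader = BatchMetaDataLoader(train_dataset,
                                   batch_size=16,
                                   shuffle=True,
                                   num_workers=0)
batch = next(iter(train_loader))
support_x, support_y = batch['train']
query_x, query_y = batch['test']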
Example 7
                             use_vinyals_split=False,
                             download=True,
                             transform=trans)
    meta_test_ds = Omniglot(root_path, num_classes_per_task=5, meta_test=True,
                            use_vinyals_split=False, download=True, transform=trans)
    return (meta_train_ds, meta_test_ds)


if __name__ == '__main__':
    # (meta_train_ds, meta_test_ds) = get_omniglot_for_fl()
    from torchmeta.datasets.helpers import omniglot
    from torchmeta.utils.data import BatchMetaDataLoader

    dataset = omniglot(root_path,
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_train=True,
                       download=True)
    dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

    for batch in dataloader:
        train_inputs, train_targets = batch["train"]
        print('Train inputs shape: {0}'.format(
            train_inputs.shape))  # (16, 25, 1, 28, 28)
        print('Train targets shape: {0}'.format(
            train_targets.shape))  # (16, 25)

        test_inputs, test_targets = batch["test"]
        print('Test inputs shape: {0}'.format(
            test_inputs.shape))  # (16, 75, 1, 28, 28)
        print('Test targets shape: {0}'.format(test_targets.shape))  # (16, 75)
Example 8
def train(args):
    logger.warning(
        'This script is an example to showcase the MetaModule and '
        'data-loading features of Torchmeta, and as such has been '
        'very lightly tested. For a better tested implementation of '
        'Model-Agnostic Meta-Learning (MAML) using Torchmeta with '
        'more features (including multi-step adaptation and '
        'different datasets), please check `https://github.com/'
        'tristandeleu/pytorch-maml`.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(
                    model,
                    inner_loss,
                    step_size=args.step_size,
                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example 9
def main():

    parser = argparse.ArgumentParser(description='Data HyperCleaner')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dataset',
                        type=str,
                        default='omniglot',
                        metavar='N',
                        help='omniglot or miniimagenet')
    parser.add_argument('--hg-mode',
                        type=str,
                        default='CG',
                        metavar='N',
                        help='hypergradient approximation: CG or fixed_point')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args = parser.parse_args()

    log_interval = 100
    eval_interval = 500
    inner_log_interval = None
    inner_log_interval_test = None
    ways = 5
    batch_size = 16
    n_tasks_test = 1000  # usually 1000 tasks are used for testing
    if args.dataset == 'omniglot':
        reg_param = 2
        T, K = 16, 5
    elif args.dataset == 'miniimagenet':
        reg_param = 0.5
        T, K = 10, 5
    else:
        raise NotImplementedError('{} not implemented!'.format(args.dataset))

    T_test = T
    inner_lr = .1

    loc = locals()
    del loc['parser']
    del loc['args']

    print(args, '\n', loc, '\n')

    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    # the following are for reproducibility on GPU, see https://pytorch.org/docs/master/notes/randomness.html
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == 'omniglot':
        dataset = omniglot("data",
                           ways=ways,
                           shots=1,
                           test_shots=15,
                           meta_train=True,
                           download=True)
        test_dataset = omniglot("data",
                                ways=ways,
                                shots=1,
                                test_shots=15,
                                meta_test=True,
                                download=True)

        meta_model = get_cnn_omniglot(64, ways).to(device)
    elif args.dataset == 'miniimagenet':
        dataset = miniimagenet("data",
                               ways=ways,
                               shots=1,
                               test_shots=15,
                               meta_train=True,
                               download=True)
        test_dataset = miniimagenet("data",
                                    ways=ways,
                                    shots=1,
                                    test_shots=15,
                                    meta_test=True,
                                    download=True)

        meta_model = get_cnn_miniimagenet(32, ways).to(device)
    else:
        raise NotImplementedError(
            'Dataset not implemented! Only omniglot and miniimagenet are supported.')

    dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, **kwargs)
    test_dataloader = BatchMetaDataLoader(test_dataset,
                                          batch_size=batch_size,
                                          **kwargs)

    outer_opt = torch.optim.Adam(params=meta_model.parameters())
    # outer_opt = torch.optim.SGD(lr=0.1, params=meta_model.parameters())
    inner_opt_class = hg.GradientDescent
    inner_opt_kwargs = {'step_size': inner_lr}

    def get_inner_opt(train_loss):
        return inner_opt_class(train_loss, **inner_opt_kwargs)

    for k, batch in enumerate(dataloader):
        start_time = time.time()
        meta_model.train()

        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(device)

        outer_opt.zero_grad()

        val_loss, val_acc = 0, 0
        forward_time, backward_time = 0, 0
        for t_idx, (tr_x, tr_y, tst_x,
                    tst_y) in enumerate(zip(tr_xs, tr_ys, tst_xs, tst_ys)):
            start_time_task = time.time()

            # single task set up
            task = Task(reg_param,
                        meta_model, (tr_x, tr_y, tst_x, tst_y),
                        batch_size=tr_xs.shape[0])
            inner_opt = get_inner_opt(task.train_loss_f)

            # single task inner loop
            params = [
                p.detach().clone().requires_grad_(True)
                for p in meta_model.parameters()
            ]
            last_param = inner_loop(meta_model.parameters(),
                                    params,
                                    inner_opt,
                                    T,
                                    log_interval=inner_log_interval)[-1]
            forward_time_task = time.time() - start_time_task

            # single task hypergradient computation
            if args.hg_mode == 'CG':
                # This is the approximation used in the paper; CG stands for conjugate gradient.
                cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f,
                                               step_size=1.)
                hg.CG(last_param,
                      list(meta_model.parameters()),
                      K=K,
                      fp_map=cg_fp_map,
                      outer_loss=task.val_loss_f)
            elif args.hg_mode == 'fixed_point':
                hg.fixed_point(last_param,
                               list(meta_model.parameters()),
                               K=K,
                               fp_map=inner_opt,
                               outer_loss=task.val_loss_f)

            backward_time_task = (time.time() - start_time_task
                                  - forward_time_task)

            val_loss += task.val_loss
            val_acc += task.val_acc / task.batch_size

            forward_time += forward_time_task
            backward_time += backward_time_task

        outer_opt.step()
        step_time = time.time() - start_time

        if k % log_interval == 0:
            print(
                'MT k={} ({:.3f}s F: {:.3f}s, B: {:.3f}s) Val Loss: {:.2e}, Val Acc: {:.2f}.'
                .format(k, step_time, forward_time, backward_time, val_loss,
                        100. * val_acc))

        if k % eval_interval == 0:
            test_losses, test_accs = evaluate(
                n_tasks_test,
                test_dataloader,
                meta_model,
                T_test,
                get_inner_opt,
                reg_param,
                log_interval=inner_log_interval_test)

            print(
                "Test loss {:.2e} +- {:.2e}: Test acc: {:.2f} +- {:.2e} (mean +- std over {} tasks)."
                .format(test_losses.mean(), test_losses.std(),
                        100. * test_accs.mean(), 100. * test_accs.std(),
                        len(test_losses)))
Example 10
def generate_dataloaders(args):
    # Create dataset
    if args.dataset == "omniglot":
        dataset_train = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_train,
            seed=args.seed,
            download=True,  # Only downloads if not in data_path
        )
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_val = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_val,
            seed=args.seed,
            download=True,
        )
        dataloader_val = BatchMetaDataLoader(
            dataset_val,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_test = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_test,
            download=True,
        )
        dataloader_test = BatchMetaDataLoader(
            dataset_test,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )
    elif args.dataset == "miniimagenet":
        dataset_train = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_train,
            seed=args.seed,
            download=True,  # Only downloads if not in data_path
        )
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_val = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_val,
            seed=args.seed,
            download=True,
        )
        dataloader_val = BatchMetaDataLoader(
            dataset_val,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_test = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_test,
            download=True,
        )
        dataloader_test = BatchMetaDataLoader(
            dataset_test,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )
    else:
        raise Exception("Dataset {} not implemented".format(args.dataset))

    return dataloader_train, dataloader_val, dataloader_test
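
As a hedged sketch, generate_dataloaders can be driven with an argparse-style namespace whose attribute names mirror those accessed above (all values are illustrative):

# Sketch only: invoking generate_dataloaders without a CLI.
from argparse import Namespace

args = Namespace(dataset='omniglot',
                 data_path='data',
                 k_shot=1,
                 n_way=5,
                 k_query=15,
                 base_dataset_train='train',
                 base_dataset_val='val',
                 base_dataset_test='test',
                 seed=0,
                 tasks_per_metaupdate=16,
                 n_workers=0)
dataloader_train, dataloader_val, dataloader_test = generate_dataloaders(args)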
Example 11
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

dataset = omniglot("data",
                   ways=3,
                   shots=5,
                   test_shots=15,
                   meta_train=True,
                   download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(
        train_inputs.shape))  # (16, 15, 1, 28, 28)
    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 15)

    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(
        test_inputs.shape))  # (16, 45, 1, 28, 28)
    print('Test targets shape: {0}'.format(test_targets.shape))  # (16, 45)
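
Note that the loop above runs over a very large number of sampled tasks; as a hedged sketch, example scripts usually cap the number of meta-batches, for instance:

# Sketch only: bound the number of meta-batches drawn from the loader above.
for batch_idx, batch in enumerate(dataloader):
    train_inputs, train_targets = batch["train"]
    test_inputs, test_targets = batch["test"]
    # ... one meta-training step on this batch of tasks would go here ...
    if batch_idx >= 100:  # illustrative cap
        break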
Example 12
def train(args):
    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = update_parameters(model,
                                           inner_loss,
                                           step_size=args.step_size,
                                           first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example 13
def train(args):
    logger.warning('This script is an example to showcase the data-loading '
                   'features of Torchmeta in conjunction with using higher to '
                   'make models "unrollable" and optimizers differentiable, '
                   'and as such has been very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    inner_optimiser = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)

            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                with higher.innerloop_ctx(
                        model, inner_optimiser,
                        copy_initial_weights=False) as (fmodel, diffopt):
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)

                    diffopt.step(inner_loss)

                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example 14
    lr=0.001,
    no_cuda=False,
    save_interval=30000,
    seed=22,
    test_N_shots=1,
    test_N_way=5,
    train_N_shots=1,
    train_N_way=5,
    unlabeled_extra=0,
)

train_dataset = omniglot(
    folder=Path(args.folder),
    shots=args.train_N_shots,
    ways=args.train_N_way,
    shuffle=False,
    test_shots=args.test_N_shots,
    meta_split="train",
    download=True,
)
test_dataset = omniglot(
    folder=Path(args.folder),
    shots=args.train_N_shots,
    ways=args.train_N_way,
    shuffle=False,
    test_shots=args.test_N_shots,
    meta_split="test",
    download=True,
)
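
As a hedged sketch, the two meta-datasets above would typically be wrapped in loaders before batching tasks; BatchMetaDataLoader comes from torchmeta.utils.data, and the batch size of 16 is illustrative (it is not one of the args fields shown):

# Sketch only: wrap the datasets defined above in meta data loaders.
from torchmeta.utils.data import BatchMetaDataLoader

train_loader = BatchMetaDataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = BatchMetaDataLoader(test_dataset, batch_size=16, shuffle=False)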


def load_batch_inputs(self, batch, device):