Example #1
def get_miniimagenet_dataloaders_torchmeta(args):
    args.training_with_epochs = False
    args.data_path = Path(
        '~/data/').expanduser()  # for some datasets this is enough
    args.criterion = nn.CrossEntropyLoss()
    # args.image_size = 84  # do we need this?
    from torchmeta.datasets.helpers import miniimagenet
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_augmentation_transforms = transforms.Compose([
        transforms.RandomResizedCrop(84),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(), normalize
    ])
    dataset_train = miniimagenet(args.data_path,
                                 transform=data_augmentation_transforms,
                                 ways=args.n_classes,
                                 shots=args.k_shots,
                                 test_shots=args.k_eval,
                                 meta_split='train',
                                 download=True)
    dataset_val = miniimagenet(args.data_path,
                               ways=args.n_classes,
                               shots=args.k_shots,
                               test_shots=args.k_eval,
                               meta_split='val',
                               download=True)
    dataset_test = miniimagenet(args.data_path,
                                ways=args.n_classes,
                                shots=args.k_shots,
                                test_shots=args.k_eval,
                                meta_split='test',
                                download=True)

    meta_train_dataloader = BatchMetaDataLoader(
        dataset_train,
        batch_size=args.meta_batch_size_train,
        num_workers=args.num_workers)
    meta_val_dataloader = BatchMetaDataLoader(
        dataset_val,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    meta_test_dataloader = BatchMetaDataLoader(
        dataset_test,
        batch_size=args.meta_batch_size_eval,
        num_workers=args.num_workers)
    return meta_train_dataloader, meta_val_dataloader, meta_test_dataloader
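
A minimal driver sketch for the helper above, assuming the module-level imports the snippet relies on (nn, Path, transforms, BatchMetaDataLoader) and hypothetical values for the args attributes it reads:

# Hypothetical usage; every attribute name below is inferred from how the
# function reads `args`, not taken from the original codebase.
from types import SimpleNamespace

args = SimpleNamespace(n_classes=5, k_shots=1, k_eval=15,
                       meta_batch_size_train=16, meta_batch_size_eval=8,
                       num_workers=4)
train_loader, val_loader, test_loader = get_miniimagenet_dataloaders_torchmeta(args)
batch = next(iter(train_loader))
support_x, support_y = batch['train']  # [16, 5*1, 3, 84, 84], [16, 5*1]
query_x, query_y = batch['test']       # [16, 5*15, 3, 84, 84], [16, 5*15]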
Example #2
def create_miniimagenet_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    n_query,
    batch_size,
    num_workers,
    download=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for MiniImagenet

    Args:
        root: Path to the MiniImagenet root folder (containing a 'miniimagenet'
            subfolder with the preprocessed JSON files, or the downloaded tar.gz file).
        meta_split: See torchmeta.datasets.MiniImagenet.
        k_way: Number of classes per task
        n_shot: Number of samples per class
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Whether to download the dataset (and run the dataset-specific
            preprocessing on the downloaded files).
        seed: Seed to be used in the meta-dataset.

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = miniimagenet(
        root,
        n_shot,
        k_way,
        meta_split=meta_split,
        test_shots=n_query,
        download=download,
        seed=seed,
    )
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader
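
A usage sketch for the helper above (the root path and sizes are placeholders):

# Hypothetical call for a 5-way 1-shot setup with 15 query shots per class.
loader = create_miniimagenet_data_loader(root='data',
                                         meta_split='train',
                                         k_way=5,
                                         n_shot=1,
                                         n_query=15,
                                         batch_size=16,
                                         num_workers=4,
                                         download=True,
                                         seed=0)
batch = next(iter(loader))
support_x, support_y = batch['train']  # [16, 5*1, 3, 84, 84], [16, 5*1]
query_x, query_y = batch['test']       # [16, 5*15, 3, 84, 84], [16, 5*15]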
Example #3
def dataset_f(args, meta_split: Literal['train', 'val', 'test'] = 'train'):
    meta_train = meta_split == 'train'
    meta_val = meta_split == 'val'
    meta_test = meta_split == 'test'
    dataset = args.dataset
    if dataset == 'miniimagenet' and meta_val and args.num_classes > 16:
        args.num_classes = 16
        print(
            'Setting num_classes for the miniimagenet val split to 16 '
            '(the maximum number of classes in that split)'
        )
    dataset_kwargs = dict(
        folder=DATAFOLDER,
        shots=args.support_samples,
        ways=args.num_classes,
        shuffle=True,
        test_shots=args.query_samples,
        seed=args.seed,
        target_transform=Categorical(num_classes=args.num_classes),
        download=True,
        meta_train=meta_train,
        meta_val=meta_val,
        meta_test=meta_test)
    if dataset == 'omniglot':
        return omniglot(
            **dataset_kwargs,
            class_augmentations=[Rotation([90, 180, 270])],
        )
    elif dataset == 'miniimagenet':
        tg.set_dim('NUM_FEATURES', 1600)
        return miniimagenet(**dataset_kwargs)
    elif dataset.upper() == 'CUB':
        if args.support_samples == 0:
            from cub_dataset import CubDatasetEmbeddingsZeroShot
            print('Instantiating CubDatasetEmbeddingsZeroShot')
            return CubDatasetEmbeddingsZeroShot(DATAFOLDER, meta_split,
                                                args.query_samples,
                                                args.num_classes)
        else:
            return cub(**dataset_kwargs)
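
A hypothetical call, assuming DATAFOLDER, tg, and the torchmeta helpers are already defined in the enclosing module as the snippet implies:

from types import SimpleNamespace

# Illustrative args; the attribute names are inferred from how dataset_f reads them.
args = SimpleNamespace(dataset='miniimagenet', support_samples=5,
                       query_samples=15, num_classes=5, seed=0)
train_dataset = dataset_f(args)                  # meta_split defaults to 'train'
val_dataset = dataset_f(args, meta_split='val')  # num_classes is capped at 16 here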
Example #4
def main():

    parser = argparse.ArgumentParser(description='Data HyperCleaner')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dataset',
                        type=str,
                        default='omniglot',
                        metavar='N',
                        help='omniglot or miniimagenet')
    parser.add_argument('--hg-mode',
                        type=str,
                        default='CG',
                        metavar='N',
                        help='hypergradient approximation: CG or fixed_point')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args = parser.parse_args()

    log_interval = 100
    eval_interval = 500
    inner_log_interval = None
    inner_log_interval_test = None
    ways = 5
    batch_size = 16
    n_tasks_test = 1000  # usually 1000 tasks are used for testing
    if args.dataset == 'omniglot':
        reg_param = 2
        T, K = 16, 5
    elif args.dataset == 'miniimagenet':
        reg_param = 0.5
        T, K = 10, 5
    else:
        raise NotImplementedError('{} not implemented!'.format(args.dataset))

    T_test = T
    inner_lr = .1

    loc = locals()
    del loc['parser']
    del loc['args']

    print(args, '\n', loc, '\n')

    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    # the following are for reproducibility on GPU, see https://pytorch.org/docs/master/notes/randomness.html
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == 'omniglot':
        dataset = omniglot("data",
                           ways=ways,
                           shots=1,
                           test_shots=15,
                           meta_train=True,
                           download=True)
        test_dataset = omniglot("data",
                                ways=ways,
                                shots=1,
                                test_shots=15,
                                meta_test=True,
                                download=True)

        meta_model = get_cnn_omniglot(64, ways).to(device)
    elif args.dataset == 'miniimagenet':
        dataset = miniimagenet("data",
                               ways=ways,
                               shots=1,
                               test_shots=15,
                               meta_train=True,
                               download=True)
        test_dataset = miniimagenet("data",
                                    ways=ways,
                                    shots=1,
                                    test_shots=15,
                                    meta_test=True,
                                    download=True)

        meta_model = get_cnn_miniimagenet(32, ways).to(device)
    else:
        raise NotImplementedError(
            "Dataset not implemented! Only omniglot and miniimagenet are supported.")

    dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, **kwargs)
    test_dataloader = BatchMetaDataLoader(test_dataset,
                                          batch_size=batch_size,
                                          **kwargs)

    outer_opt = torch.optim.Adam(params=meta_model.parameters())
    # outer_opt = torch.optim.SGD(lr=0.1, params=meta_model.parameters())
    inner_opt_class = hg.GradientDescent
    inner_opt_kwargs = {'step_size': inner_lr}

    def get_inner_opt(train_loss):
        return inner_opt_class(train_loss, **inner_opt_kwargs)

    for k, batch in enumerate(dataloader):
        start_time = time.time()
        meta_model.train()

        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(
            device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(
            device)

        outer_opt.zero_grad()

        val_loss, val_acc = 0, 0
        forward_time, backward_time = 0, 0
        for t_idx, (tr_x, tr_y, tst_x,
                    tst_y) in enumerate(zip(tr_xs, tr_ys, tst_xs, tst_ys)):
            start_time_task = time.time()

            # single task set up
            task = Task(reg_param,
                        meta_model, (tr_x, tr_y, tst_x, tst_y),
                        batch_size=tr_xs.shape[0])
            inner_opt = get_inner_opt(task.train_loss_f)

            # single task inner loop
            params = [
                p.detach().clone().requires_grad_(True)
                for p in meta_model.parameters()
            ]
            last_param = inner_loop(meta_model.parameters(),
                                    params,
                                    inner_opt,
                                    T,
                                    log_interval=inner_log_interval)[-1]
            forward_time_task = time.time() - start_time_task

            # single task hypergradient computation
            if args.hg_mode == 'CG':
                # CG (conjugate gradient) is the hypergradient approximation used in the paper.
                cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f,
                                               step_size=1.)
                hg.CG(last_param,
                      list(meta_model.parameters()),
                      K=K,
                      fp_map=cg_fp_map,
                      outer_loss=task.val_loss_f)
            elif args.hg_mode == 'fixed_point':
                hg.fixed_point(last_param,
                               list(meta_model.parameters()),
                               K=K,
                               fp_map=inner_opt,
                               outer_loss=task.val_loss_f)

            backward_time_task = (time.time() - start_time_task -
                                  forward_time_task)

            val_loss += task.val_loss
            val_acc += task.val_acc / task.batch_size

            forward_time += forward_time_task
            backward_time += backward_time_task

        outer_opt.step()
        step_time = time.time() - start_time

        if k % log_interval == 0:
            print(
                'MT k={} ({:.3f}s F: {:.3f}s, B: {:.3f}s) Val Loss: {:.2e}, Val Acc: {:.2f}.'
                .format(k, step_time, forward_time, backward_time, val_loss,
                        100. * val_acc))

        if k % eval_interval == 0:
            test_losses, test_accs = evaluate(
                n_tasks_test,
                test_dataloader,
                meta_model,
                T_test,
                get_inner_opt,
                reg_param,
                log_interval=inner_log_interval_test)

            print(
                "Test loss {:.2e} +- {:.2e}: Test acc: {:.2f} +- {:.2e} (mean +- std over {} tasks)."
                .format(test_losses.mean(), test_losses.std(),
                        100. * test_accs.mean(), 100. * test_accs.std(),
                        len(test_losses)))
Example #5
import csv

import torch

from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
from maml import MAML
from train import adaptation, test
import pickle

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# dataset

trainset = miniimagenet("data",
                        ways=5,
                        shots=5,
                        test_shots=15,
                        meta_train=True,
                        download=True)
trainloader = BatchMetaDataLoader(trainset,
                                  batch_size=2,
                                  num_workers=4,
                                  shuffle=True)

testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
testloader = BatchMetaDataLoader(testset,
                                 batch_size=2,
                                 num_workers=4,
                                 shuffle=True)
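
For reference, a minimal sketch of what one meta-batch from these loaders looks like (5-way 5-shot, 15 query shots, meta batch size 2):

batch = next(iter(trainloader))
support_x, support_y = batch['train']  # [2, 25, 3, 84, 84], [2, 25]
query_x, query_y = batch['test']       # [2, 75, 3, 84, 84], [2, 75]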
Example #6
def generate_dataloaders(args):
    # Create dataset
    if args.dataset == "omniglot":
        dataset_train = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_train,
            seed=args.seed,
            download=True,  # Only downloads if not in data_path
        )
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_val = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_val,
            seed=args.seed,
            download=True,
        )
        dataloader_val = BatchMetaDataLoader(
            dataset_val,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_test = omniglot(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_test,
            download=True,
        )
        dataloader_test = BatchMetaDataLoader(
            dataset_test,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )
    elif args.dataset == "miniimagenet":
        dataset_train = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_train,
            seed=args.seed,
            download=True,  # Only downloads if not in data_path
        )
        dataloader_train = BatchMetaDataLoader(
            dataset_train,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_val = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_val,
            seed=args.seed,
            download=True,
        )
        dataloader_val = BatchMetaDataLoader(
            dataset_val,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

        dataset_test = miniimagenet(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=args.base_dataset_test,
            download=True,
        )
        dataloader_test = BatchMetaDataLoader(
            dataset_test,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )
    else:
        raise Exception("Dataset {} not implemented".format(args.dataset))

    return dataloader_train, dataloader_val, dataloader_test
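
The three per-split blocks above differ only in the split name, and the two dataset branches only in the helper being called, so the function can be collapsed into a loop. A behavior-equivalent sketch, with the caveat that it also passes seed to the test split, which the original omits:

def generate_dataloaders_compact(args):
    # One loop over the three splits instead of three hand-copied blocks.
    helpers = {'omniglot': omniglot, 'miniimagenet': miniimagenet}
    if args.dataset not in helpers:
        raise Exception("Dataset {} not implemented".format(args.dataset))
    loaders = []
    for split in (args.base_dataset_train, args.base_dataset_val,
                  args.base_dataset_test):
        dataset = helpers[args.dataset](
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=split,
            seed=args.seed,
            download=True,  # Only downloads if not in data_path
        )
        loaders.append(BatchMetaDataLoader(
            dataset,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        ))
    return tuple(loaders)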
Example #7
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
from meta_dataloader import get_meta_loader

data_path = 'few_data/'

dataset = miniimagenet(data_path,
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_train=True,
                       download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=1, num_workers=4)

# for i, batch in enumerate(dataloader):
#     train_inputs, train_targets = batch["train"]
#     print('Train inputs shape: {0}'.format(train_inputs.shape))    # (1, 25, 3, 84, 84)
#     print('Train targets shape: {0}'.format(train_targets.shape))  # (1, 25)
#
#     test_inputs, test_targets = batch["test"]
#     print('Test inputs shape: {0}'.format(test_inputs.shape))      # (1, 75, 3, 84, 84)
#     print('Test targets shape: {0}'.format(test_targets.shape))    # (1, 75)
#
#     print(train_targets)
#     if i > 5:
#         break

metadataloader = get_meta_loader(data_path,
                                 'miniimagenet',
                                 ways=5,
                                 shots=5,
                                 # the remaining keyword arguments are assumed,
                                 # mirroring the miniimagenet(...) call above
                                 test_shots=15,
                                 meta_train=True)
Example #8
def load_dataset(args, mode):
    folder = args.folder
    ways = args.num_ways
    shots = args.num_shots
    test_shots = 15
    download = args.download
    shuffle = True

    if mode == 'meta_train':
        args.meta_train = True
        args.meta_val = False
        args.meta_test = False
    elif mode == 'meta_valid':
        args.meta_train = False
        args.meta_val = True
        args.meta_test = False
    elif mode == 'meta_test':
        args.meta_train = False
        args.meta_val = False
        args.meta_test = True
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    if args.dataset == 'miniimagenet':
        dataset = miniimagenet(folder=folder,
                               shots=shots,
                               ways=ways,
                               shuffle=shuffle,
                               test_shots=test_shots,
                               meta_train=args.meta_train,
                               meta_val=args.meta_val,
                               meta_test=args.meta_test,
                               download=download)
    elif args.dataset == 'tieredimagenet':
        dataset = tieredimagenet(folder=folder,
                                 shots=shots,
                                 ways=ways,
                                 shuffle=shuffle,
                                 test_shots=test_shots,
                                 meta_train=args.meta_train,
                                 meta_val=args.meta_val,
                                 meta_test=args.meta_test,
                                 download=download)
    elif args.dataset == 'cifar_fs':
        dataset = cifar_fs(folder=folder,
                           shots=shots,
                           ways=ways,
                           shuffle=shuffle,
                           test_shots=test_shots,
                           meta_train=args.meta_train,
                           meta_val=args.meta_val,
                           meta_test=args.meta_test,
                           download=download)
    elif args.dataset == 'fc100':
        dataset = fc100(folder=folder,
                        shots=shots,
                        ways=ways,
                        shuffle=shuffle,
                        test_shots=test_shots,
                        meta_train=args.meta_train,
                        meta_val=args.meta_val,
                        meta_test=args.meta_test,
                        download=download)
    elif args.dataset == 'cub':
        dataset = cub(folder=folder,
                      shots=shots,
                      ways=ways,
                      shuffle=shuffle,
                      test_shots=test_shots,
                      meta_train=args.meta_train,
                      meta_val=args.meta_val,
                      meta_test=args.meta_test,
                      download=download)
    elif args.dataset == 'vgg_flower':
        dataset = vgg_flower(folder=folder,
                             shots=shots,
                             ways=ways,
                             shuffle=shuffle,
                             test_shots=test_shots,
                             meta_train=args.meta_train,
                             meta_val=args.meta_val,
                             meta_test=args.meta_test,
                             download=download)
    elif args.dataset == 'aircraft':
        dataset = aircraft(folder=folder,
                           shots=shots,
                           ways=ways,
                           shuffle=shuffle,
                           test_shots=test_shots,
                           meta_train=args.meta_train,
                           meta_val=args.meta_val,
                           meta_test=args.meta_test,
                           download=download)
    elif args.dataset == 'traffic_sign':
        dataset = traffic_sign(folder=folder,
                               shots=shots,
                               ways=ways,
                               shuffle=shuffle,
                               test_shots=test_shots,
                               meta_train=args.meta_train,
                               meta_val=args.meta_val,
                               meta_test=args.meta_test,
                               download=download)
    elif args.dataset == 'svhn':
        dataset = svhn(folder=folder,
                       shots=shots,
                       ways=ways,
                       shuffle=shuffle,
                       test_shots=test_shots,
                       meta_train=args.meta_train,
                       meta_val=args.meta_val,
                       meta_test=args.meta_test,
                       download=download)
    elif args.dataset == 'cars':
        dataset = cars(folder=folder,
                       shots=shots,
                       ways=ways,
                       shuffle=shuffle,
                       test_shots=test_shots,
                       meta_train=args.meta_train,
                       meta_val=args.meta_val,
                       meta_test=args.meta_test,
                       download=download)

    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    return dataset
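
The dataset branches above differ only in which torchmeta helper is called, so a name-to-helper mapping removes the duplication; a sketch under the same imports:

DATASET_HELPERS = {
    'miniimagenet': miniimagenet, 'tieredimagenet': tieredimagenet,
    'cifar_fs': cifar_fs, 'fc100': fc100, 'cub': cub,
    'vgg_flower': vgg_flower, 'aircraft': aircraft,
    'traffic_sign': traffic_sign, 'svhn': svhn, 'cars': cars,
}

def load_dataset_compact(args, mode):
    # Map the mode string onto the three mutually exclusive split flags.
    flags = {'meta_train': (True, False, False),
             'meta_valid': (False, True, False),
             'meta_test': (False, False, True)}
    args.meta_train, args.meta_val, args.meta_test = flags[mode]
    helper = DATASET_HELPERS[args.dataset]
    return helper(folder=args.folder,
                  shots=args.num_shots,
                  ways=args.num_ways,
                  shuffle=True,
                  test_shots=15,
                  meta_train=args.meta_train,
                  meta_val=args.meta_val,
                  meta_test=args.meta_test,
                  download=args.download)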
Example #9
def train(args):
    dataset = miniimagenet(args.folder,
                           shots=args.num_shots,
                           ways=args.num_ways,
                           shuffle=True,
                           test_shots=15,
                           meta_train=True,
                           download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(3,
                                       84,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):

                train_logit = model(train_input)

                inner_loss = F.cross_entropy(train_logit, train_target)
                # writer.add_scalar('Loss/inner_loss', np.random.random(), task_idx)
                grid = torchvision.utils.make_grid(train_input)
                writer.add_image('images', grid, 0)
                writer.add_graph(model, train_input)

                model.zero_grad()
                params = update_parameters(model,
                                           inner_loss,
                                           step_size=args.step_size,
                                           first_order=args.first_order)
                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)
                # writer.add_scalar('Loss/outer_loss', np.random.random(), n_iter)
                # Log a histogram of each meta-parameter.
                for name, param in model.meta_named_parameters():
                    writer.add_histogram(name, param, batch_idx)
                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()
            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            writer.add_scalar('Accuracy/test', accuracy.item(), batch_idx)
            if batch_idx >= args.num_batches:
                break

    writer.close()

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_miniimagenet_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
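
The snippet relies on update_parameters and get_accuracy defined elsewhere in its module (torchmeta's MAML example ships small helpers with these names); a minimal get_accuracy consistent with how it is used here:

def get_accuracy(logits, targets):
    # Fraction of query examples whose argmax prediction matches the target.
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())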
Example #10
def train(args):
    dataset = miniimagenet(args.folder,
                           shots=args.num_shots,
                           ways=args.num_ways,
                           shuffle=True,
                           test_shots=15,
                           meta_train=True,
                           download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(3,
                                       args.num_ways,
                                       hidden_size=args.hidden_size,
                                       fc_in_size=32 * 5 * 5,
                                       conv_kernel=[3, 3, 3, 2])
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    max_acc = 0
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input, test_target) in \
                    enumerate(zip(train_inputs, train_targets, test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = update_parameters(model,
                                           inner_loss,
                                           step_size=args.step_size,
                                           first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            max_acc = max(max_acc, accuracy.item())
            if batch_idx >= args.num_batches:
                break

    print('Max accuracy during training:', max_acc)
    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_miniimagenet_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
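
To reuse the saved weights, the checkpoint is a plain state dict; a loading sketch, assuming the model is rebuilt with the same constructor arguments as in training:

# Hypothetical restore; `filename` is the path produced by the save above.
model = ConvolutionalNeuralNetwork(3,
                                   args.num_ways,
                                   hidden_size=args.hidden_size,
                                   fc_in_size=32 * 5 * 5,
                                   conv_kernel=[3, 3, 3, 2])
model.load_state_dict(torch.load(filename, map_location=args.device))
model.eval()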