def __init__(self, modelpath=None, device=None):
    """Create a 10-way ``ModelConvMiniImagenet`` and optionally restore weights.

    Parameters
    ----------
    modelpath : str or None
        Checkpoint handed to ``torch.load``. When falsy, no weights are
        loaded and the freshly constructed model is used.
    device : str or torch.device or None
        Device stored on ``self.device``; ``None`` selects ``'cpu'``.
    """
    if device is None:
        device = 'cpu'
    self.device = device
    self.modelpath = modelpath
    self.model = ModelConvMiniImagenet(10, hidden_size=64)
    if modelpath:
        # map_location lets a checkpoint that was saved on another device
        # (e.g. a GPU) be restored on the device requested here.
        state_dict = torch.load(self.modelpath, map_location=self.device)
        self.model.load_state_dict(state_dict)
    # Keep the weights on the device recorded in self.device.
    # NOTE(review): assumes the model exposes torch.nn.Module.to() - it is
    # already used through load_state_dict(), so this should hold; confirm.
    self.model.to(self.device)
    self.num_adaptation_steps = 1
    self.loss_function = F.cross_entropy
class MAML():
    """Few-shot wrapper around a 10-way ``ModelConvMiniImagenet``.

    ``train`` adapts the model on the first batch of a dataloader with
    ``maml.utils.update_parameters`` and returns the model bundled with the
    adapted parameters inside a ``ModelWrapper``.
    """

    def __init__(self, modelpath=None, device=None):
        """Build the model and optionally restore weights from ``modelpath``.

        Parameters
        ----------
        modelpath : str or None
            Checkpoint handed to ``torch.load``. When falsy, no weights are
            loaded.
        device : str or torch.device or None
            Device used for the model and for the batches in ``train``;
            ``None`` selects ``'cpu'``.
        """
        if device is None:
            device = 'cpu'
        self.device = device
        self.modelpath = modelpath
        self.model = ModelConvMiniImagenet(10, hidden_size=64)
        if modelpath:
            # map_location lets a checkpoint that was saved on another device
            # (e.g. a GPU) be restored on the device requested here.
            state_dict = torch.load(self.modelpath, map_location=self.device)
            self.model.load_state_dict(state_dict)
        # train() moves every batch to self.device, so the weights have to
        # live there too; otherwise any non-CPU device fails with a device
        # mismatch in the forward pass.
        self.model.to(self.device)
        self.num_adaptation_steps = 1
        self.loss_function = F.cross_entropy

    @property
    def train_transforms(self):
        """Transforms applied to training images: ``Resize(84)`` then ``ToTensor()``."""
        return [Resize(84), ToTensor()]

    @property
    def inference_transforms(self):
        """Transforms applied at inference time; same list as ``train_transforms``."""
        return [Resize(84), ToTensor()]

    def train(self, dataloader, log_dir=None):
        """Adapt the model on the first batch yielded by ``dataloader``.

        Parameters
        ----------
        dataloader : iterable
            Yields ``(imgs, labels)`` pairs; both are moved to ``self.device``.
            Only the first batch of each pass is used (see the TODO below).
        log_dir : str or None
            Stored on ``self.log_dir``; not otherwise used by this method.

        Returns
        -------
        tuple
            ``(ModelWrapper(self.model, params), None)`` where ``params`` is
            the result of the last ``update_parameters`` call, or ``None``
            when the dataloader yielded nothing.
        """
        self.log_dir = log_dir
        params = None
        self.model.train(True)
        for _ in range(self.num_adaptation_steps):
            for imgs, labels in dataloader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                logits = self.model(imgs, params=params)
                inner_loss = self.loss_function(logits, labels)
                self.model.zero_grad()
                # Continue from the parameters adapted so far. The loss above
                # was computed with `params`, so the update must start from
                # them too. On the first step `params` is None, so the
                # default single-step behaviour is unchanged.
                # NOTE(review): assumes update_parameters starts from the
                # model's own parameters when params is None - confirm
                # against maml.utils.
                params = maml.utils.update_parameters(self.model,
                                                      inner_loss,
                                                      params=params,
                                                      step_size=0.1,
                                                      first_order=True)
                # TODO: Multiple batches did not work (yet). Therefore training batch size needs to be big enough to cover all training samples
                break
        self.model.eval()
        return ModelWrapper(self.model, params), None
Beispiel #3
0
 def test_custom_db(self):
     """Run ``train.main`` end to end on a randomly generated image list.

     A temporary folder and image-list file are filled by
     ``tests.datatests.create_random_imagelist``; the resulting dataset is
     used for the train, validation and test slots of the benchmark.
     """
     # Only INFO and above from PIL's logger.
     logging.getLogger('PIL').setLevel(logging.INFO)

     input_size = 40
     num_ways = 10
     num_shots = 4
     num_shots_test = 4
     # NOTE(review): batch_size and num_workers are never read in this test.
     batch_size = 1
     num_workers = 0

     with tempfile.TemporaryDirectory() as folder:
         with tempfile.NamedTemporaryFile(mode='w+t') as fp:
             tests.datatests.create_random_imagelist(folder, fp, input_size)

             image_transform = transforms.Compose(
                 [transforms.Resize(84), transforms.ToTensor()])
             dataset = data.ImagelistMetaDataset(imagelistname=fp.name,
                                                 root='',
                                                 transform=image_transform)

             label_transform = Categorical(num_ways)
             splitter = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
             meta_dataset = CombinationMetaDataset(
                 dataset,
                 num_classes_per_task=num_ways,
                 target_transform=label_transform,
                 dataset_transform=splitter)

             args = ArgWrapper()
             args.output_folder = folder
             args.dataset = None

             # The model size comes from the ArgWrapper attributes, not from
             # the local num_ways above.
             model = ModelConvMiniImagenet(args.num_ways,
                                           hidden_size=args.hidden_size)
             benchmark = Benchmark(meta_train_dataset=meta_dataset,
                                   meta_val_dataset=meta_dataset,
                                   meta_test_dataset=meta_dataset,
                                   model=model,
                                   loss_function=F.cross_entropy)
             train.main(args, benchmark)
Beispiel #4
0
def main(args):
    """Meta-train a model with MAML and keep the best-validating weights.

    The whole configuration is read from ``args``. When
    ``args.output_folder`` is set, a timestamped sub-folder is created in it
    holding ``config.json`` (a dump of ``vars(args)``) and ``model.th`` (the
    state dict with the best validation accuracy seen so far).

    Parameters
    ----------
    args : namespace
        Attributes read here: ``verbose``, ``use_cuda``, ``output_folder``,
        ``folder``, ``dataset``, ``num_ways``, ``num_shots``,
        ``num_shots_test``, ``hidden_size``, ``batch_size``, ``num_workers``,
        ``meta_lr``, ``first_order``, ``num_steps``, ``step_size``,
        ``num_epochs`` and ``num_batches``. When an output folder is used,
        ``folder`` is rewritten as an absolute path and ``model_path`` is
        added.

    Raises
    ------
    NotImplementedError
        If ``args.dataset`` is not 'sinusoid', 'omniglot' or 'miniimagenet'.
    """
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    if args.output_folder is not None:
        if not os.path.exists(args.output_folder):
            # exist_ok guards against another process creating the folder
            # between the check above and this call.
            os.makedirs(args.output_folder, exist_ok=True)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        # One sub-folder per run, named after the start time.
        folder = os.path.join(args.output_folder,
                              time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(folder)
        logging.debug('Creating folder `{0}`'.format(folder))

        args.folder = os.path.abspath(args.folder)
        args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
        # Save the configuration in a config.json file
        config_path = os.path.join(folder, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(config_path)))

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots,
                                      num_test_per_class=args.num_shots_test)
    class_augmentations = [Rotation([90, 180, 270])]

    def image_split(dataset_cls, transform, **split_flags):
        # Arguments common to the train and validation splits of the image
        # datasets. Each split gets its own Categorical instance.
        return dataset_cls(args.folder,
                           transform=transform,
                           target_transform=Categorical(args.num_ways),
                           num_classes_per_task=args.num_ways,
                           class_augmentations=class_augmentations,
                           dataset_transform=dataset_transform,
                           **split_flags)

    if args.dataset == 'sinusoid':
        transform = ToTensor()

        sinusoid_kwargs = dict(num_tasks=1000000,
                               transform=transform,
                               target_transform=transform,
                               dataset_transform=dataset_transform)
        num_samples = args.num_shots + args.num_shots_test
        meta_train_dataset = Sinusoid(num_samples, **sinusoid_kwargs)
        meta_val_dataset = Sinusoid(num_samples, **sinusoid_kwargs)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif args.dataset == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = image_split(Omniglot, transform,
                                         meta_train=True, download=True)
        meta_val_dataset = image_split(Omniglot, transform, meta_val=True)

        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    elif args.dataset == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = image_split(MiniImagenet, transform,
                                         meta_train=True, download=True)
        meta_val_dataset = image_split(MiniImagenet, transform, meta_val=True)

        model = ModelConvMiniImagenet(args.num_ways,
                                      hidden_size=args.hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            args.dataset))

    loader_kwargs = dict(batch_size=args.batch_size,
                         shuffle=True,
                         num_workers=args.num_workers,
                         pin_memory=True)
    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                **loader_kwargs)
    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              **loader_kwargs)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_steps,
        step_size=args.step_size,
        loss_function=loss_function,
        device=device)

    best_val_accuracy = None

    # Training loop
    # The epoch counter is padded to the number of digits of num_epochs.
    # log10 is undefined for values <= 0, so clamp to 1: with zero (or
    # negative) epochs the loop below simply does not run instead of the
    # function crashing with a math domain error.
    num_digits = 1 + int(math.log10(max(args.num_epochs, 1)))
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(num_digits)
    for epoch in range(args.num_epochs):
        metalearner.train(meta_train_dataloader,
                          max_batches=args.num_batches,
                          verbose=args.verbose,
                          desc='Training',
                          leave=False)
        results = metalearner.evaluate(meta_val_dataloader,
                                       max_batches=args.num_batches,
                                       verbose=args.verbose,
                                       desc=epoch_desc.format(epoch + 1))

        # NOTE(review): relies on evaluate() always reporting
        # 'accuracies_after' - verify this also holds for the regression
        # ('sinusoid') setting.
        val_accuracy = results['accuracies_after']
        if (best_val_accuracy is None) or (best_val_accuracy < val_accuracy):
            best_val_accuracy = val_accuracy
            if args.output_folder is not None:
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()
Beispiel #5
0
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None,
                          meta_batch_size=1,
                          ensemble_size=0
                          ):
    """Assemble datasets, model and loss for the benchmark called ``name``.

    Parameters
    ----------
    name : str
        One of 'sinusoid', 'omniglot' or 'miniimagenet'.
    folder : str
        Root folder passed to the Omniglot / MiniImagenet datasets.
    num_ways : int
        Classes per task; also the first argument of the conv models.
    num_shots : int
        Training examples per class given to ``ClassSplitter``.
    num_shots_test : int
        Test examples per class given to ``ClassSplitter``.
    hidden_size : int or None
        Forwarded to the convolutional models.
    meta_batch_size : int
        Forwarded to every model constructor.
    ensemble_size : int
        Forwarded to every model constructor.

    Returns
    -------
    Benchmark
        Built from the train/val/test datasets, the model and the loss.

    Raises
    ------
    NotImplementedError
        For any other ``name``.
    """
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)
    # Every model receives the same two ensemble-related keyword arguments.
    ensemble_kwargs = dict(meta_batch_size=meta_batch_size,
                           ensemble_size=ensemble_size)

    def image_split(dataset_cls, transform, **split_flags):
        # Arguments common to all image-dataset splits. Each split gets its
        # own Categorical instance.
        return dataset_cls(folder,
                           transform=transform,
                           target_transform=Categorical(num_ways),
                           num_classes_per_task=num_ways,
                           dataset_transform=dataset_transform,
                           **split_flags)

    if name == 'sinusoid':
        transform = ToTensor1D()

        samples_per_task = num_shots + num_shots_test
        sinusoid_kwargs = dict(num_tasks=1000000,
                               transform=transform,
                               target_transform=transform,
                               dataset_transform=dataset_transform)
        meta_train_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_val_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_test_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40], **ensemble_kwargs)
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        # Rotation augmentations go to the train and val splits only.
        meta_train_dataset = image_split(
            Omniglot, transform,
            meta_train=True,
            class_augmentations=class_augmentations,
            download=True)
        meta_val_dataset = image_split(
            Omniglot, transform,
            meta_val=True,
            class_augmentations=class_augmentations)
        meta_test_dataset = image_split(Omniglot, transform, meta_test=True)

        model = ModelConvOmniglot(num_ways,
                                  hidden_size=hidden_size,
                                  **ensemble_kwargs)
        loss_function = batch_cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = image_split(MiniImagenet, transform,
                                         meta_train=True, download=True)
        meta_val_dataset = image_split(MiniImagenet, transform, meta_val=True)
        meta_test_dataset = image_split(MiniImagenet, transform,
                                        meta_test=True)

        model = ModelConvMiniImagenet(num_ways,
                                      hidden_size=hidden_size,
                                      **ensemble_kwargs)
        loss_function = batch_cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
Beispiel #6
0
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None):
    """Assemble datasets, model and loss for the benchmark called ``name``.

    Parameters
    ----------
    name : str
        One of 'sinusoid', 'omniglot', 'miniimagenet', 'doublenmnist' or
        'doublenmnistsequence'.
    folder : str
        Root folder passed to the Omniglot / MiniImagenet datasets. The
        DoubleNMNIST variants read from the fixed path
        'data/nmnist/n_mnist.hdf5' instead.
    num_ways : int
        Classes per task; also the first argument of the classifiers.
    num_shots : int
        Training examples per class given to ``ClassSplitter``.
    num_shots_test : int
        Test examples per class given to ``ClassSplitter``.
    hidden_size : int or None
        Forwarded to the convolutional models (not to ``ModelDECOLLE``).

    Returns
    -------
    Benchmark
        Built from the train/val/test datasets, the model and the loss.

    Raises
    ------
    NotImplementedError
        For any other ``name``.
    """
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)

    def image_split(dataset_cls, transform, **split_flags):
        # Arguments common to all image-dataset splits. Each split gets its
        # own Categorical instance.
        return dataset_cls(folder,
                           transform=transform,
                           target_transform=Categorical(num_ways),
                           num_classes_per_task=num_ways,
                           dataset_transform=dataset_transform,
                           **split_flags)

    if name == 'sinusoid':
        transform = ToTensor1D()

        samples_per_task = num_shots + num_shots_test
        sinusoid_kwargs = dict(num_tasks=1000000,
                               transform=transform,
                               target_transform=transform,
                               dataset_transform=dataset_transform)
        meta_train_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_val_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_test_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        # Rotation augmentations go to the train and val splits only.
        meta_train_dataset = image_split(
            Omniglot, transform,
            meta_train=True,
            class_augmentations=class_augmentations,
            download=True)
        meta_val_dataset = image_split(
            Omniglot, transform,
            meta_val=True,
            class_augmentations=class_augmentations)
        meta_test_dataset = image_split(Omniglot, transform, meta_test=True)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = image_split(MiniImagenet, transform,
                                         meta_train=True, download=True)
        meta_val_dataset = image_split(MiniImagenet, transform, meta_val=True)
        meta_test_dataset = image_split(MiniImagenet, transform,
                                        meta_test=True)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name in ('doublenmnist', 'doublenmnistsequence'):
        # Imported lazily so torchneuromorphic is only needed for these two
        # benchmarks. Compose and ToTensor MUST be aliased: a function-scope
        # import makes the bound name local to the whole function, so
        # importing them under their plain names turns the module-level
        # Compose/ToTensor used by the branches above into unassigned locals
        # (UnboundLocalError for 'omniglot' and 'miniimagenet').
        from torchneuromorphic.doublenmnist_torchmeta.doublenmnist_dataloaders import (
            Compose as NMNISTCompose,
            CropDims,
            DoubleNMNIST,
            Downsample,
            Repeat,
            ToCountFrame,
            ToEventSum,
            ToTensor as NMNISTToTensor,
            toOneHot,
        )

        # The two variants differ only in the event-to-frame transform and in
        # the model that consumes the frames.
        is_sequence = name == 'doublenmnistsequence'
        to_frames = ToCountFrame if is_sequence else ToEventSum

        root = 'data/nmnist/n_mnist.hdf5'
        chunk_size = 300
        ds = 2
        dt = 1000
        size = [2, 32 // ds, 32 // ds]

        transform = NMNISTCompose([
            CropDims(low_crop=[0, 0], high_crop=[32, 32], dims=[2, 3]),
            Downsample(factor=[dt, 1, ds, ds]),
            to_frames(T=chunk_size, size=size),
            NMNISTToTensor()
        ])
        target_transform = NMNISTCompose(
            [Repeat(chunk_size), toOneHot(num_ways)])

        loss_function = F.cross_entropy

        nmnist_kwargs = dict(root=root,
                             transform=transform,
                             target_transform=target_transform,
                             chunk_size=chunk_size,
                             num_classes_per_task=num_ways)
        split_kwargs = dict(num_train_per_class=num_shots,
                            num_test_per_class=num_shots_test)
        # Here ClassSplitter wraps each dataset directly rather than being
        # passed as dataset_transform.
        meta_train_dataset = ClassSplitter(
            DoubleNMNIST(meta_train=True, **nmnist_kwargs), **split_kwargs)
        meta_val_dataset = ClassSplitter(
            DoubleNMNIST(meta_val=True, **nmnist_kwargs), **split_kwargs)
        meta_test_dataset = ClassSplitter(
            DoubleNMNIST(meta_test=True, **nmnist_kwargs), **split_kwargs)

        if is_sequence:
            model = ModelDECOLLE(num_ways)
        else:
            model = ModelConvDoubleNMNIST(num_ways, hidden_size=hidden_size)

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)
Beispiel #7
0
def main(args):
    """Evaluate a saved MAML model on the meta-test split and store results.

    The training configuration is loaded from the JSON file ``args.config``;
    a few entries can be overridden from the command line. The model weights
    are read from ``config['model_path']`` and the evaluation results are
    written to ``results.json`` in the same directory.

    Parameters
    ----------
    args : namespace
        Attributes read here: ``config``, ``folder``, ``num_steps``,
        ``num_batches``, ``use_cuda``, ``num_workers`` and ``verbose``.
        ``folder`` overrides the config when not None; ``num_steps`` and
        ``num_batches`` override it when greater than zero.

    Raises
    ------
    NotImplementedError
        If ``config['dataset']`` is not 'sinusoid', 'omniglot' or
        'miniimagenet'.
    """
    with open(args.config, 'r') as f:
        config = json.load(f)

    # Command-line overrides of the stored configuration.
    if args.folder is not None:
        config['folder'] = args.folder
    if args.num_steps > 0:
        config['num_steps'] = args.num_steps
    if args.num_batches > 0:
        config['num_batches'] = args.num_batches
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    dataset_transform = ClassSplitter(
        shuffle=True,
        num_train_per_class=config['num_shots'],
        num_test_per_class=config['num_shots_test'])
    if config['dataset'] == 'sinusoid':
        transform = ToTensor()
        meta_test_dataset = Sinusoid(config['num_shots'] +
                                     config['num_shots_test'],
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)
        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif config['dataset'] == 'omniglot':
        transform = Compose([Resize(28), ToTensor()])
        # meta_test=True: the reported 'Test' numbers must come from the
        # held-out meta-test classes, not from the meta-train split.
        meta_test_dataset = Omniglot(config['folder'],
                                     transform=transform,
                                     target_transform=Categorical(
                                         config['num_ways']),
                                     num_classes_per_task=config['num_ways'],
                                     meta_test=True,
                                     dataset_transform=dataset_transform,
                                     download=True)
        model = ModelConvOmniglot(config['num_ways'],
                                  hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    elif config['dataset'] == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])
        # meta_test=True for the same reason as in the omniglot branch.
        meta_test_dataset = MiniImagenet(
            config['folder'],
            transform=transform,
            target_transform=Categorical(config['num_ways']),
            num_classes_per_task=config['num_ways'],
            meta_test=True,
            dataset_transform=dataset_transform,
            download=True)
        model = ModelConvMiniImagenet(config['num_ways'],
                                      hidden_size=config['hidden_size'])
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(
            config['dataset']))

    # map_location lets weights saved on another device load onto `device`.
    with open(config['model_path'], 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))

    meta_test_dataloader = BatchMetaDataLoader(meta_test_dataset,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    # No optimizer is passed: the learner is only used for evaluation here.
    metalearner = ModelAgnosticMetaLearning(
        model,
        first_order=config['first_order'],
        num_adaptation_steps=config['num_steps'],
        step_size=config['step_size'],
        loss_function=loss_function,
        device=device)

    results = metalearner.evaluate(meta_test_dataloader,
                                   max_batches=config['num_batches'],
                                   verbose=args.verbose,
                                   desc='Test')

    # Save results
    dirname = os.path.dirname(config['model_path'])
    with open(os.path.join(dirname, 'results.json'), 'w') as f:
        json.dump(results, f)
Beispiel #8
0
def get_benchmark_by_name(name,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None):
    """
    Returns a namedtuple with the train/val/test split, model, and loss function
    for the specified task.

    Parameters
    ----------
    name : str
        Name of the dataset to use: 'nmltoy2d', 'noisyduplicates',
        'sinusoid', 'omniglot' or 'miniimagenet'

    folder : str
        Folder where dataset is stored (or will download to this path if not found).
        Only the 'omniglot' and 'miniimagenet' branches read it.

    num_ways : int
        Number of classes for each task

    num_shots : int
        Number of training examples provided per class

    num_shots_test : int
        Number of test examples provided per class (during adaptation)

    hidden_size : int or None
        Forwarded to the convolutional models of 'omniglot' and 'miniimagenet'

    Raises
    ------
    NotImplementedError
        If `name` is not one of the datasets listed above
    """
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=num_shots,
                                      num_test_per_class=num_shots_test)

    def image_split(dataset_cls, transform, **split_flags):
        # Arguments common to all image-dataset splits. Each split gets its
        # own Categorical instance.
        return dataset_cls(folder,
                           transform=transform,
                           target_transform=Categorical(num_ways),
                           num_classes_per_task=num_ways,
                           dataset_transform=dataset_transform,
                           **split_flags)

    if name == 'nmltoy2d':
        toy_kwargs = dict(replay_pool_size=100,
                          clip_length=100,
                          from_beginning=False)

        # The training split samples test batches of 10; the validation and
        # test splits use test_strategy='all'.
        meta_train_dataset = NMLToy2D(test_strategy='sample',
                                      test_batch_size=10,
                                      **toy_kwargs)
        meta_val_dataset = NMLToy2D(test_strategy='all', **toy_kwargs)
        meta_test_dataset = NMLToy2D(test_strategy='all', **toy_kwargs)

        model = ModelMLPToy2D([1024, 1024])
        loss_function = F.cross_entropy

    elif name == 'noisyduplicates':
        locations = [
            ([-2.5, 2.5], 1, 0),  # Single visit (negative)
            ([2.5, 2.5], 10, 0),  # Many visits
            ([-2.5, -2.5], 2, 15),  # A few negatives, mostly positives
            ([2.5, -2.5], 8, 15),  # More negatives, but still majority positives
        ]
        noise_std = 0

        # All three splits are built from the same locations list.
        meta_train_dataset = NoisyDuplicatesProblem(locations,
                                                    noise_std=noise_std)
        meta_val_dataset = NoisyDuplicatesProblem(locations,
                                                  noise_std=noise_std)
        meta_test_dataset = NoisyDuplicatesProblem(locations,
                                                   noise_std=noise_std)

        model = ModelMLPToy2D([2048, 2048])
        loss_function = F.cross_entropy

    elif name == 'sinusoid':
        transform = ToTensor1D()

        samples_per_task = num_shots + num_shots_test
        sinusoid_kwargs = dict(num_tasks=1000000,
                               transform=transform,
                               target_transform=transform,
                               dataset_transform=dataset_transform)
        meta_train_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_val_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)
        meta_test_dataset = Sinusoid(samples_per_task, **sinusoid_kwargs)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif name == 'omniglot':
        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        # Rotation augmentations go to the train and val splits only.
        meta_train_dataset = image_split(
            Omniglot, transform,
            meta_train=True,
            class_augmentations=class_augmentations,
            download=True)
        meta_val_dataset = image_split(
            Omniglot, transform,
            meta_val=True,
            class_augmentations=class_augmentations)
        meta_test_dataset = image_split(Omniglot, transform, meta_test=True)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif name == 'miniimagenet':
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = image_split(MiniImagenet, transform,
                                         meta_train=True, download=True)
        meta_val_dataset = image_split(MiniImagenet, transform, meta_val=True)
        meta_test_dataset = image_split(MiniImagenet, transform,
                                        meta_test=True)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)