    def test_evaluation_sampler(self):
        episodes_per_epoch = 10
        n_test = 1
        k_test = 5
        q_test = 1

        evaluation_sampler = NShotTaskSampler(self.evaluation,
                                              episodes_per_epoch,
                                              n_test, k_test, q_test)

        evaluation_loader = DataLoader(
            self.evaluation,
            batch_sampler=evaluation_sampler,
            num_workers=4
        )

        prepare_batch = prepare_nshot_task(n_test, k_test, q_test)

        for batch_index, batch in enumerate(evaluation_loader):
            x, y = prepare_batch(batch)

            loss, y_pred = dummy_fit_function(
                dummy_model,
                torch.nn.NLLLoss().to(device),
                x.to(device),
                y.to(device),
                n_shot=n_test,
                k_way=k_test,
                q_queries=q_test,
                train=False,
            )
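
            # Sanity checks (not in the original test): prepare_nshot_task is
            # assumed to order each batch as n*k support samples followed by
            # q*k query samples, with labels created only for the queries.
            assert x.shape[0] == n_test * k_test + q_test * k_test
            assert y.shape[0] == q_test * k_test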
    def test_background_sampler(self):
        episodes_per_epoch = 10
        n_train = 5
        k_train = 15
        q_train = self.background_class_count.min() - n_train

        background_sampler = NShotTaskSampler(self.background,
                                              episodes_per_epoch,
                                              n_train, k_train, q_train)

        background_loader = DataLoader(
            self.background,
            batch_sampler=background_sampler,
            num_workers=4
        )

        prepare_batch = prepare_nshot_task(n_train, k_train, q_train)

        for batch_index, batch in enumerate(background_loader):
            x, y = prepare_batch(batch)

            loss, y_pred = dummy_fit_function(
                dummy_model,
                torch.nn.NLLLoss().to(device),
                x.to(device),
                y.to(device),
                n_shot=n_train,
                k_way=k_train,
                q_queries=q_train,
                train=False,
            )
Example #3
    def __init__(self, dataset: str, num_tasks: int, n_shot: int, k_way: int,
                 q_queries: int, distance_metric: str,
                 open_world_testing: bool):
        self.dataset_name = dataset
        self.num_tasks = num_tasks
        self.n_shot = n_shot
        self.k_way = k_way
        self.q_queries = q_queries
        self.episodes_per_epoch = 10
        self.distance_metric = distance_metric
        self.open_world_testing = open_world_testing
        self.prepare_batch = prepare_nshot_task(self.n_shot, self.k_way,
                                                self.q_queries)

        self.num_different_models = 0

        if dataset == 'whoas':
            self.evaluation_dataset = Whoas('evaluation')
            #self.evaluation_dataset = Whoas('background')
        elif dataset == 'kaggle':
            self.evaluation_dataset = Kaggle('evaluation')
        elif dataset == 'miniImageNet':
            self.evaluation_dataset = MiniImageNet('evaluation')
        else:
            raise ValueError('Unsupported dataset')

        self.batch_sampler = NShotCustomTaskSampler(self.evaluation_dataset,
                                                    self.episodes_per_epoch,
                                                    n_shot, k_way, q_queries,
                                                    num_tasks, None,
                                                    open_world_testing)
        self.evaluation_taskloader = DataLoader(
            self.evaluation_dataset, batch_sampler=self.batch_sampler)

        assert torch.cuda.is_available()
        self.device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True

        self.model = get_few_shot_encoder(
            self.evaluation_dataset.num_input_channels)
        self.model.to(self.device, dtype=torch.double)

        self.optimiser = Adam(self.model.parameters(), lr=1e-3)
        self.loss_fn = torch.nn.NLLLoss().cuda()
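
    # Hypothetical usage (the enclosing class is not named in this snippet;
    # `Evaluator` is an assumed stand-in):
    #
    #   evaluator = Evaluator(dataset='miniImageNet', num_tasks=5, n_shot=1,
    #                         k_way=5, q_queries=1, distance_metric='l2',
    #                         open_world_testing=False)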
def lr_schedule(epoch, lr):
    # Drop lr every 2000 episodes (i.e. every `drop_lr_every` epochs)
    if epoch % drop_lr_every == 0:
        return lr / 2
    else:
        return lr


callbacks = [
    EvaluateFewShot(
        eval_fn=proto_net_episode,
        num_tasks=evaluation_episodes,
        n_shot=args.n_test,
        k_way=args.k_test,
        q_queries=args.q_test,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
        distance=args.distance
    ),
    ModelCheckpoint(
        filepath=PATH + f'/models/proto_nets/{param_str}.pth',
        monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'
    ),
    LearningRateScheduler(schedule=lr_schedule),
    CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
]

fit(
    model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
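    # assumed continuation: the snippet breaks off here, and these remaining
    # arguments mirror the complete fit(...) call in train_proto_net below
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=proto_net_episode,
    fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train,
                         'q_queries': args.q_train, 'train': True,
                         'distance': args.distance},
)
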
def train_proto_net(
    args,
    model,
    device,
    n_epochs,
    background_taskloader,
    evaluation_taskloader,
    path='.',
    lr=3e-3,
    drop_lr_every=100,
    evaluation_episodes=100,
    episodes_per_epoch=100,
):
    # Prepare model
    model.to(device, dtype=torch.float)
    model.train(True)

    # Prepare training etc.
    optimizer = Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.NLLLoss().cuda()
    ensure_folder(path + '/models')
    ensure_folder(path + '/logs')

    def lr_schedule(epoch, lr):
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=path + '/models/' + args.param_str + '_e{epoch:02d}.pth',
            monitor=args.checkpoint_monitor
            or f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            period=args.checkpoint_period or 100,
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(path + f'/logs/{args.param_str}.csv'),
    ]

    fit(
        model,
        optimizer,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        epoch_metrics=[f'val_{args.n_test}-shot_{args.k_test}-way_acc'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
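
# Minimal usage sketch for train_proto_net (hypothetical values; assumes an
# `args` namespace carrying the n/k/q, distance, param_str and checkpoint
# fields referenced above):
#
#   device = torch.device('cuda')
#   model = get_few_shot_encoder(num_input_channels=3)
#   train_proto_net(args, model, device, n_epochs=80,
#                   background_taskloader=background_taskloader,
#                   evaluation_taskloader=evaluation_taskloader,
#                   path='./experiments/proto_nets')
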
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default=
        "./models/proto_nets/miniImageNet_nt=5_kt=5_qt=10_nv=5_kv=5_qv=10_dist=l2_sampling_method=True_is_diverisity=True.pth",
        help="model path")
    parser.add_argument(
        "--result_path",
        type=str,
        default="./results/proto_nets/5shot_training_5shot_diverisity.csv",
        help="Directory for evaluation report result (for experiments)")
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--distance', default='cosine')
    parser.add_argument('--n_train', default=1, type=int)
    parser.add_argument('--n_test', default=1, type=int)
    parser.add_argument('--k_train', default=5, type=int)
    parser.add_argument('--k_test', default=5, type=int)
    parser.add_argument('--q_train', default=15, type=int)
    parser.add_argument('--q_test', default=15, type=int)
    parser.add_argument(
        "--debug",
        action="store_true",
        help="set logging level DEBUG",
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG if args.debug else logging.INFO,
    )
    logger = logging.getLogger(__name__)

    assert torch.cuda.is_available()
    device = torch.device('cuda')

    ###################
    # Create datasets #
    ###################
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 5
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise ValueError('Unsupported dataset: only miniImageNet is implemented here')

    test_dataset = dataset_class('test')
    test_dataset_taskloader = DataLoader(
        test_dataset,
        batch_sampler=NShotTaskSampler(test_dataset, episodes_per_epoch,
                                       args.n_test, args.k_test, args.q_test),
        num_workers=4)

    #########
    # Model #
    #########
    model = get_few_shot_encoder(num_input_channels).to(device,
                                                        dtype=torch.double)

    model.load_state_dict(torch.load(args.model_path), strict=False)
    model.eval()

    #############
    # Inference #
    #############
    logger.info("***** Epochs = %d *****", n_epochs)
    logger.info("***** Num episodes per epoch = %d *****", episodes_per_epoch)

    result_writer = ResultWriter(args.result_path)

    # Arguments for proto_net_episode
    prepare_batch = prepare_nshot_task(args.n_test, args.k_test, args.q_test)
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    eval_iterator = trange(
        0,
        int(n_epochs),
        desc="Epoch",
    )
    for i_epoch in eval_iterator:
        epoch_iterator = tqdm(
            test_dataset_taskloader,
            desc="Iteration",
        )
        seen = 0
        metric_name = f'test_{args.n_test}-shot_{args.k_test}-way_acc'
        metric = {metric_name: 0.0}
        for _, batch in enumerate(epoch_iterator):
            x, y = prepare_batch(batch)

            loss, y_pred = proto_net_episode(model,
                                             optimiser,
                                             loss_fn,
                                             x,
                                             y,
                                             n_shot=args.n_test,
                                             k_way=args.k_test,
                                             q_queries=args.q_test,
                                             train=False,
                                             distance=args.distance)

            seen += y_pred.shape[0]
            metric[metric_name] += categorical_accuracy(
                y, y_pred) * y_pred.shape[0]

        metric[metric_name] = metric[metric_name] / seen

        logger.info("epoch: {},     categorical_accuracy: {}".format(
            i_epoch, metric[metric_name]))
        result_writer.update(**metric)
Example #7
def lr_schedule(epoch, lr):
    # Drop lr every 2000 episodes (i.e. every `drop_lr_every` epochs)
    if epoch % drop_lr_every == 0:
        return lr / 2
    else:
        return lr


callbacks = [
    EvaluateFewShot(
        eval_fn=proto_net_episode,
        num_tasks=evaluation_episodes,
        n_shot=args.n_test,
        k_way=args.k_test,
        q_queries=args.q_test,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_nshot_task(
            args.n_test, args.k_test, args.q_test
        ),  # relabels the k sampled classes to 0..k-1 within each task
        distance=args.distance),
    ModelCheckpoint(filepath=PATH + f'/models/proto_nets/{param_str}.pth',
                    monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
    LearningRateScheduler(schedule=lr_schedule),
    CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
]

fit(
    model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
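    # assumed continuation: the snippet breaks off here; these arguments
    # mirror the complete fit(...) calls elsewhere in this listing
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=proto_net_episode,
    fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train,
                         'q_queries': args.q_train, 'train': True,
                         'distance': args.distance},
)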
Example #8
def run():
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        raise ValueError('Unsupported dataset: only miniImageNet is implemented here')

    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'


    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)


    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original_sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch, args.n_train, args.k_train, args.q_train),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )
    # Importance sampling
    else:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()


    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train,
                             'q_queries': args.q_train, 'train': True,
                             'fce': args.fce, 'distance': args.distance}
    )
Example #9
############
# Training #
############
print(f'Training Matching Network on {globals.DATASET}...')
optimiser = Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.NLLLoss().cuda()

callbacks = [
    EvaluateFewShot(eval_fn=matching_net_episode,
                    num_tasks=evaluation_episodes,
                    n_shot=globals.N_TEST,
                    k_way=globals.K_TEST,
                    q_queries=globals.Q_TEST,
                    taskloader=evaluation_taskloader,
                    prepare_batch=prepare_nshot_task(globals.N_TEST,
                                                     globals.K_TEST,
                                                     globals.Q_TEST),
                    fce=globals.FCE,
                    distance=globals.DISTANCE),
    ModelCheckpoint(
        filepath=PATH + f'/models/matching_nets/{param_str}.pth',
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc',
        # monitor=f'val_loss',
    ),
    ReduceLROnPlateau(
        patience=20,
        factor=0.5,
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc'),
    CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
]
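
# A fit(...) call in the style of the matching-network training example above
# would presumably follow the callbacks; sketch assuming globals.N_TRAIN /
# K_TRAIN / Q_TRAIN and a background_taskloader exist alongside their
# test-time counterparts:
fit(
    model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(globals.N_TRAIN, globals.K_TRAIN,
                                     globals.Q_TRAIN),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=matching_net_episode,
    fit_function_kwargs={'n_shot': globals.N_TRAIN, 'k_way': globals.K_TRAIN,
                         'q_queries': globals.Q_TRAIN, 'train': True,
                         'fce': globals.FCE, 'distance': globals.DISTANCE},
)
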
def train_sweep():

    from torch.optim import Adam
    from torch.utils.data import DataLoader
    import argparse

    from few_shot.datasets import OmniglotDataset, MiniImageNet, ClinicDataset, SNIPSDataset, CustomDataset
    from few_shot.models import XLNetForEmbedding
    from few_shot.core import NShotTaskSampler, EvaluateFewShot, prepare_nshot_task
    from few_shot.proto import proto_net_episode
    from few_shot.train_with_prints import fit
    from few_shot.callbacks import CallbackList, Callback, DefaultCallback, ProgressBarLogger, CSVLogger, EvaluateMetrics, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
    from few_shot.utils import setup_dirs
    from few_shot.utils import get_gpu_info
    from config import PATH
    import wandb
    from transformers import AdamW

    import torch

    gpu_dict = get_gpu_info()
    print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
        gpu_dict['mem_total'], gpu_dict['mem_used'],
        gpu_dict['mem_used_percent']))

    setup_dirs()
    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True

    ##############
    # Parameters #
    ##############
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='Custom')
    parser.add_argument('--distance', default='l2')
    parser.add_argument('--n-train', default=2, type=int)
    parser.add_argument('--n-test', default=2, type=int)
    parser.add_argument('--k-train', default=2, type=int)
    parser.add_argument('--k-test', default=2, type=int)
    parser.add_argument('--q-train', default=2, type=int)
    parser.add_argument('--q-test', default=2, type=int)
    args = parser.parse_args()

    evaluation_episodes = 100
    episodes_per_epoch = 10

    if args.dataset == 'omniglot':
        n_epochs = 40
        dataset_class = OmniglotDataset
        num_input_channels = 1
        drop_lr_every = 20
    elif args.dataset == 'miniImageNet':
        n_epochs = 80
        dataset_class = MiniImageNet
        num_input_channels = 3
        drop_lr_every = 40
    elif args.dataset == 'clinic150':
        n_epochs = 5
        dataset_class = ClinicDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'SNIPS':
        n_epochs = 5
        dataset_class = SNIPSDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'Custom':
        n_epochs = 20
        dataset_class = CustomDataset
        num_input_channels = 150
        drop_lr_every = 5
    else:
        raise ValueError('Unsupported dataset')

    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

    print(param_str)


    ###################
    # Create datasets #
    ###################

    train_df = dataset_class('train')

    train_taskloader = DataLoader(train_df,
                                  batch_sampler=NShotTaskSampler(
                                      train_df, episodes_per_epoch,
                                      args.n_train, args.k_train,
                                      args.q_train))

    val_df = dataset_class('val')

    evaluation_taskloader = DataLoader(
        val_df,
        batch_sampler=NShotTaskSampler(val_df, episodes_per_epoch, args.n_test,
                                       args.k_test, args.q_test))

    #train_iter = iter(train_taskloader)
    #train_taskloader = next(train_iter)

    #val_iter = iter(evaluation_taskloader)
    #evaluation_taskloader = next(val_iter)

    #########
    # Wandb #
    #########

    config_defaults = {
        'lr': 0.00001,
        'optimiser': 'adam',
        'batch_size': 16,
    }

    wandb.init(config=config_defaults)

    #########
    # Model #
    #########

    torch.cuda.empty_cache()

    try:
        print('Before Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    #from transformers import XLNetForSequenceClassification, AdamW

    #model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=150)
    #model.cuda()

    try:
        del model
    except NameError:
        print("Cannot delete model. No model with name 'model' exists")

    model = XLNetForEmbedding(num_input_channels)
    model.to(device, dtype=torch.double)

    #param_optimizer = list(model.named_parameters())
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    #optimizer_grouped_parameters = [
    #                                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #                                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay':0.0}
    #]

    try:
        print('After Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    wandb.watch(model)

    ############
    # Training #
    ############

    print(f'Training Prototypical network on {args.dataset}...')
    if wandb.config.optimiser == 'adam':
        optimiser = Adam(model.parameters(), lr=wandb.config.lr)
    else:
        optimiser = AdamW(model.parameters(), lr=wandb.config.lr)

    #optimiser = AdamW(optimizer_grouped_parameters, lr=3e-5)
    #loss_fn = torch.nn.NLLLoss().cuda()

    #loss_fn = torch.nn.CrossEntropyLoss()

    #max_grad_norm = 1.0

    loss_fn = torch.nn.NLLLoss()

    def lr_schedule(epoch, lr):
        # Drop lr every `drop_lr_every` epochs
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    try:
        print('Before Fit')
        print('optimiser :', optimiser)
        print('Learning Rate: ', wandb.config.lr)
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
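
# train_sweep reads its hyperparameters from wandb.config, so it is meant to
# run under a W&B sweep agent. A minimal launch sketch (sweep values are
# illustrative, not from the original):
#
#   sweep_config = {
#       'method': 'grid',
#       'parameters': {
#           'lr': {'values': [1e-5, 3e-5]},
#           'optimiser': {'values': ['adam', 'adamw']},
#       },
#   }
#   sweep_id = wandb.sweep(sweep_config, project='few-shot')
#   wandb.agent(sweep_id, function=train_sweep)
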
def run():
    episodes_per_epoch = 600
    '''
    ###### LearningRateScheduler ######
    drop_lr_every = 20
    def lr_schedule(epoch, lr):
        # Drop lr every 2000 episodes
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr
    # callbacks add: LearningRateScheduler(schedule=lr_schedule)
    '''

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise ValueError('Unsupported dataset: only miniImageNet is implemented here')


    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_sampling_method={args.sampling_method}_is_diverisity={args.is_diversity}'

    print(param_str)

    #########
    # Model #
    #########
    model = get_few_shot_encoder(num_input_channels)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch,
                                           args.n_train, args.k_train,
                                           args.q_train),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)
    # Importance sampling
    else:
        # ImportanceSampler: Latent space of model
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=eval_dataset_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(
            patience=40,
            factor=0.5,
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
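
# run() relies on module-level names defined elsewhere in the script (args,
# device, PATH, ImportanceSampler, ...). A sketch of the setup it appears to
# assume, with argument names taken from the usages above and illustrative
# defaults:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--dataset', default='miniImageNet')
#   parser.add_argument('--distance', default='l2')
#   parser.add_argument('--sampling_method', action='store_true')
#   parser.add_argument('--is_diversity', action='store_true')
#   parser.add_argument('--num_s_candidates', default=100, type=int)
#   parser.add_argument('--init_temperature', default=1.0, type=float)
#   parser.add_argument('--n_train', default=1, type=int)
#   parser.add_argument('--k_train', default=5, type=int)
#   parser.add_argument('--q_train', default=15, type=int)
#   parser.add_argument('--n_test', default=1, type=int)
#   parser.add_argument('--k_test', default=5, type=int)
#   parser.add_argument('--q_test', default=15, type=int)
#   args = parser.parse_args()
#
#   assert torch.cuda.is_available()
#   device = torch.device('cuda')
#
#   run()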