Example #1
        # (tail of a `callbacks = [...]` list; the snippet is cut off above)
        PATH +
        f'/logs/maml_ens/finetune_{logs_name}_{args.train_pred_mode}_{args.epochs}.csv',
        hash=hash),
]

fit(
    meta_models,
    meta_optimisers,
    loss_fn,
    epochs=args.epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_meta_batch(args.n, args.k, args.q,
                                     args.meta_batch_size),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=fit_fn,
    n_models=args.n_models,
    fit_function_kwargs={
        'n_shot': args.n,
        'k_way': args.k,
        'q_queries': args.q,
        'train': True,
        'order': args.order,
        'device': device,
        'inner_train_steps': args.inner_train_steps,
        'inner_lr': args.inner_lr,
        'model_params': model_params,
        'pred_fn': train_pred_fn
    },
)
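
For context, `prepare_meta_batch` is expected to return a closure that reshapes the flat episode batch coming out of the taskloader into `meta_batch_size` separate tasks and builds dummy episode labels. A minimal sketch of that contract, assuming module-level `device` and `num_input_channels` globals and the `create_nshot_task_label` helper seen in the later examples (not verified against this exact codebase):

def prepare_meta_batch(n, k, q, meta_batch_size):
    def prepare_meta_batch_(batch):
        x, y = batch
        # Split the flat batch into `meta_batch_size` tasks, each holding
        # n*k support samples followed by q*k query samples
        x = x.reshape(meta_batch_size, n * k + q * k,
                      num_input_channels, x.shape[-2], x.shape[-1])
        x = x.double().to(device)
        # Dummy 0..k-1 query labels, repeated once per task
        y = create_nshot_task_label(k, q).to(device).repeat(meta_batch_size)
        return x, y
    return prepare_meta_batch_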
Example #2
callbacks = [
    EvaluateFewShot(
        eval_fn=proto_net_episode,
        num_tasks=evaluation_episodes,
        n_shot=args.n_test,
        k_way=args.k_test,
        q_queries=args.q_test,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
        distance=args.distance
    ),
    ModelCheckpoint(
        filepath=PATH + f'/models/proto_nets/{param_str}.pth',
        monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'
    ),
    LearningRateScheduler(schedule=lr_schedule),
    CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
]

fit(
    model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=proto_net_episode,
    fit_function_kwargs={
        'n_shot': args.n_train,
        'k_way': args.k_train,
        'q_queries': args.q_train,
        'train': True,
        'distance': args.distance
    },
)
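
The `proto_net_episode` passed as `fit_function` implements the standard Prototypical Networks step: embed the whole episode, average each class's support embeddings into a prototype, and score queries by negative distance. A stripped-down sketch of that logic (squared Euclidean distance only; the real function also handles the `train` flag, the optimiser step, and the configurable `distance`):

def proto_net_episode_sketch(model, loss_fn, x, y, n_shot, k_way, q_queries):
    embeddings = model(x)
    # The first n*k embeddings are the support set, the rest are queries
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]
    # Class prototypes = per-class mean of the support embeddings
    prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)
    # Squared Euclidean distance from every query to every prototype
    distances = (queries.unsqueeze(1) - prototypes.unsqueeze(0)).pow(2).sum(dim=2)
    # Closer prototype -> higher log-probability; pairs with the NLLLoss above
    log_p_y = (-distances).log_softmax(dim=1)
    return loss_fn(log_p_y, y), log_p_y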
Example #3
        # (tail of a `callbacks = [...]` list; the opening of the first callback is cut off above)
        device=device,
        order=args.order,
    ),
    ModelCheckpoint(
        filepath=PATH + '/models/maml/{}.pth'.format(param_str),
        monitor='val_{}-shot_{}-way_acc'.format(args.n, args.k)
    ),
    ReduceLROnPlateau(patience=10, factor=0.5, monitor='val_loss'),
    CSVLogger(PATH + '/logs/maml/{}.csv'.format(param_str)),
]


fit(
    meta_model,
    meta_optimiser,
    loss_fn,
    epochs=args.epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_meta_batch(args.n, args.k, args.q, args.meta_batch_size),
    callbacks=callbacks,
    stnmodel=stnmodel,
    stnoptim=stnoptim,
    args=args,
    metrics=['categorical_accuracy'],
    fit_function=meta_gradient_step,
    fit_function_kwargs={
        'n_shot': args.n,
        'k_way': args.k,
        'q_queries': args.q,
        'train': True,
        'order': args.order,
        'device': device,
        'inner_train_steps': args.inner_train_steps,
        'inner_lr': args.inner_lr
    },
)
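
`meta_gradient_step` is MAML's inner/outer update, steered here by `order`, `inner_train_steps` and `inner_lr`. The inner loop below sketches the adaptation step; it assumes the model exposes a `functional_forward(x, weights)` that runs the forward pass with externally supplied weights, a common convention in MAML code but not verified against this implementation:

import torch
from collections import OrderedDict

def inner_loop(model, loss_fn, x_support, y_support,
               inner_train_steps, inner_lr, order):
    # Start the task-specific ("fast") weights from the meta-model
    fast_weights = OrderedDict(model.named_parameters())
    for _ in range(inner_train_steps):
        logits = model.functional_forward(x_support, fast_weights)
        loss = loss_fn(logits, y_support)
        # create_graph=True keeps second-order terms for order=2 MAML
        grads = torch.autograd.grad(loss, fast_weights.values(),
                                    create_graph=(order == 2))
        # Manual SGD step on the fast weights
        fast_weights = OrderedDict(
            (name, param - inner_lr * grad)
            for (name, param), grad in zip(fast_weights.items(), grads))
    # The outer step evaluates the query loss under these adapted weights
    return fast_weights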
Example #4

def train_proto_net(
    args,
    model,
    device,
    n_epochs,
    background_taskloader,
    evaluation_taskloader,
    path='.',
    lr=3e-3,
    drop_lr_every=100,
    evaluation_episodes=100,
    episodes_per_epoch=100,
):
    # Prepare model
    model.to(device, dtype=torch.float)
    model.train(True)

    # Prepare training etc.
    optimizer = Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.NLLLoss().to(device)  # keep the loss on the same device as the model
    ensure_folder(path + '/models')
    ensure_folder(path + '/logs')

    def lr_schedule(epoch, lr):
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=path + '/models/' + args.param_str + '_e{epoch:02d}.pth',
            monitor=args.checkpoint_monitor
            or f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            period=args.checkpoint_period or 100,
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(path + f'/logs/{args.param_str}.csv'),
    ]

    fit(
        model,
        optimizer,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        epoch_metrics=[f'val_{args.n_test}-shot_{args.k_test}-way_acc'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
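
A hypothetical invocation, for reference; the `args` namespace must supply the fields read inside the function (`n_train`/`k_train`/`q_train`, `n_test`/`k_test`/`q_test`, `distance`, `param_str`, `checkpoint_monitor`, `checkpoint_period`):

train_proto_net(
    args, model, device,
    n_epochs=80,
    background_taskloader=background_taskloader,
    evaluation_taskloader=evaluation_taskloader,
    path='./experiments',  # models/ and logs/ are created under this path
    lr=3e-3,
    drop_lr_every=100,
)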
Example #5
# (earlier entries of this `callbacks = [...]` list are cut off above)
    ModelCheckpoint(
        filepath=PATH + f'/models/matching_nets/{param_str}.pth',
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc',
        # monitor=f'val_loss',
    ),
    ReduceLROnPlateau(
        patience=20,
        factor=0.5,
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc'),
    CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
]

fit(model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(globals.N_TRAIN, globals.K_TRAIN,
                                     globals.Q_TRAIN),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=matching_net_episode,
    fit_function_kwargs={
        'n_shot': globals.N_TRAIN,
        'k_way': globals.K_TRAIN,
        'q_queries': globals.Q_TRAIN,
        'train': True,
        'fce': globals.FCE,
        'distance': globals.DISTANCE
    })
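
`matching_net_episode` classifies each query by attention over the support set: softmaxed query-to-support similarities are aggregated against one-hot support labels. A sketch of that aggregation step, assuming support sample `i` belongs to class `i // n` as the dummy episode labels in these examples imply:

import torch

def matching_net_predictions(attention, n, k):
    """attention: (q*k, n*k) softmaxed query-to-support similarities."""
    # One-hot support labels; row i encodes class i // n
    y_onehot = torch.eye(k).repeat_interleave(n, dim=0)  # (n*k, k)
    # Each query's class probabilities are an attention-weighted label average
    return attention @ y_onehot                          # (q*k, k)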
Example #6
def run():
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        raise ValueError('need to add modules for other datasets')

    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'


    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)


    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch, args.n_train, args.k_train, args.q_train),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )
    # Importance sampling
    else:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()


    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'fce': args.fce,
            'distance': args.distance
        }
    )
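
`NShotTaskSampler` drives the episodic DataLoaders above: every batch it yields indexes n*k support samples followed by q*k query samples for k freshly drawn classes. A simplified sketch under the assumption that the dataset exposes a pandas dataframe `df` with `class_id` and `id` columns (the Omniglot example below suggests this layout); the real sampler also excludes support samples from the query pool:

import numpy as np
from torch.utils.data import Sampler

class SimpleNShotSampler(Sampler):
    def __init__(self, dataset, episodes_per_epoch, n, k, q):
        self.df = dataset.df
        self.episodes_per_epoch = episodes_per_epoch
        self.n, self.k, self.q = n, k, q

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        for _ in range(self.episodes_per_epoch):
            classes = np.random.choice(self.df['class_id'].unique(),
                                       self.k, replace=False)
            batch = []
            for cls in classes:  # n support samples per class
                rows = self.df[self.df['class_id'] == cls].sample(self.n)
                batch.extend(rows['id'].tolist())
            for cls in classes:  # q query samples per class
                rows = self.df[self.df['class_id'] == cls].sample(self.q)
                batch.extend(rows['id'].tolist())
            yield batch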
Example #7
    callbacks = [
        evalmetrics, progressbar,
        ModelCheckpoint(filepath=os.path.join(PATH, 'models',
                                              'semantic_classifier',
                                              str(param_str) + '.pth'),
                        monitor='val_' + str(args.n) + '-shot_' + str(args.k) +
                        '-way_acc'),
        ReduceLROnPlateau(patience=10, factor=0.5, monitor='val_loss'),
        CSVLogger(
            os.path.join(PATH, 'logs', 'semantic_classifier',
                         str(param_str) + '.csv'))
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=args.epochs,
        dataloader=train_dataloader,
        prepare_batch=prepare_batch(args.n, args.k),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=gradient_step_fn,
        fit_function_kwargs={
            'n_shot': args.n,
            'k_way': args.k,
            'device': device
        },
    )
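
`gradient_step_fn` here is just a conventional supervised update wrapped in the `fit_function` interface. A minimal version (the signature is inferred from how `fit` forwards its `fit_function_kwargs`; episode-specific kwargs such as `n_shot` are absorbed by `**kwargs`):

def gradient_step(model, optimiser, loss_fn, x, y, **kwargs):
    """One plain supervised update on the batch (x, y)."""
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()
    return loss, y_pred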
Example #8

    def test(self):
        k = 200
        n = 5
        epochs = 20
        size_binary_layer = 10
        stochastic = True
        n_conv_layers = 4
        lr = 0.01

        model_name = 'Omniglot__n=5_k=20_epochs=1000__lr=__size_binary_layer=10__size_continue_layer=10__stochastic__simplified_encoder'
        validation_split = .2

        setup_dirs()
        assert torch.cuda.is_available()

        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True

        # model = SemanticBinaryClassifier(1, k, size_binary_layer=size_binary_layer, stochastic=stochastic,
        #                                  size_dense_layer_before_binary=None,
        #                                  n_conv_layers=n_conv_layers)
        #model = FewShotClassifier(1, k)
        model = SemanticBinaryEncoder(1, 10, 10, stochastic=True)
        model.load_state_dict(torch.load(os.path.join("models", "semantic_gan",
                                                     model_name+".pth")))

        evaluation = OmniglotDataset('evaluation')

        # sample k distinct classes (replace=False avoids duplicates)
        classes = np.random.choice(evaluation.df['class_id'].unique(), size=k, replace=False)
        for i in classes:
            evaluation.df[evaluation.df['class_id'] == i] = evaluation.df[evaluation.df['class_id'] == i].sample(frac=1)

        train_dataloader = DataLoader(
            evaluation,
            batch_sampler=BasicSampler(evaluation, validation_split, True, classes, n=n),
            num_workers=8
        )

        eval_dataloader = DataLoader(
            evaluation,
            batch_sampler=BasicSampler(evaluation, validation_split, False, classes, n=n),
            num_workers=8
        )

        test_model = TestSemanticBinaryClassifier(k, model, size_binary_layer=size_binary_layer).to(device, dtype=torch.double)
        loss_fn = nn.CrossEntropyLoss().to(device)
        optimiser = torch.optim.Adam(test_model.parameters(), lr=lr)

        def prepare_batch(n, k):
            def prepare_batch_(batch):
                x, y = batch
                x = x.double().cuda()
                # Create dummy 0-(num_classes - 1) label
                y = create_nshot_task_label(k, n).cuda()
                return x, y
            return prepare_batch_

        evalmetrics = EvaluateMetrics(eval_dataloader)
        evalmetrics.set_params({'metrics': ['categorical_accuracy'],
                                'prepare_batch': prepare_batch(n, k),
                                'loss_fn': loss_fn})

        callbacks = [
            evalmetrics,

            ModelCheckpoint(
                filepath=os.path.join(PATH, 'models', 'semantic_classifier', model_name + 'test_other_class.pth'),
                monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc'
            ),
            ReduceLROnPlateau(patience=10, factor=0.5, monitor='val_loss'),
            CSVLogger(os.path.join(PATH, 'logs', 'semantic_classifier', model_name + 'test_other_class.csv'))
        ]

        #print(summary(model, (1, 28, 28)))
        for param in model.parameters():
            param.requires_grad = False
        fit(
            test_model,
            optimiser,
            loss_fn,
            epochs=100,
            dataloader=train_dataloader,
            prepare_batch=prepare_batch(n, k),
            callbacks=callbacks,
            metrics=['categorical_accuracy'],
            fit_function=gradient_step,
            fit_function_kwargs={'n_shot': n, 'k_way': k, 'device': device},
        )
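
The dummy labels built by `create_nshot_task_label` inside `prepare_batch` above follow a fixed layout: class indices 0..k-1, each repeated once per sample. A plausible one-liner for it, consistent with the episode layout assumed throughout these examples (not verified against the library source):

import torch

def create_nshot_task_label(k, q):
    # e.g. k=3, q=2 -> [0, 0, 1, 1, 2, 2]
    return torch.arange(0, k, 1 / q).long()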
Example #9
def few_shot_training(datadir=DATA_PATH,
                      dataset='fashion',
                      num_input_channels=3,
                      drop_lr_every=20,
                      validation_episodes=200,
                      evaluation_episodes=1000,
                      episodes_per_epoch=100,
                      n_epochs=80,
                      small_dataset=False,
                      n_train=1,
                      n_test=1,
                      k_train=30,
                      k_test=5,
                      q_train=5,
                      q_test=1,
                      distance='l2',
                      pretrained=False,
                      monitor_validation=False,
                      n_val_classes=10,
                      architecture='resnet18',
                      gpu=None):
    setup_dirs()

    if dataset == 'fashion':
        dataset_class = FashionProductImagesSmall if small_dataset \
            else FashionProductImages
    else:
        raise ValueError('Unsupported dataset')

    param_str = f'{dataset}_nt={n_train}_kt={k_train}_qt={q_train}_' \
                f'nv={n_test}_kv={k_test}_qv={q_test}_small={small_dataset}_' \
                f'pretrained={pretrained}_validate={monitor_validation}'

    print(param_str)

    ###################
    # Create datasets #
    ###################

    # ADAPTED: data transforms including augmentation
    resize = (80, 60) if small_dataset else (400, 300)

    background_transform = transforms.Compose([
        transforms.RandomResizedCrop(resize, scale=(0.8, 1.0)),
        # transforms.RandomGrayscale(),
        transforms.RandomPerspective(),
        transforms.RandomHorizontalFlip(),
        # transforms.Resize(resize),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    evaluation_transform = transforms.Compose([
        transforms.Resize(resize),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    if monitor_validation:
        if n_val_classes < k_test:
            n_val_classes = k_test
            print("Warning: `n_val_classes` < `k_test`; increasing the number"
                  " of validation classes to `k_test`.")

        # class structure for background (training), validation (validation),
        # evaluation (test): take a random subset of background classes
        validation_classes = list(
            np.random.choice(dataset_class.background_classes, n_val_classes))
        background_classes = list(
            set(dataset_class.background_classes).difference(
                set(validation_classes)))

        # use keyword for evaluation classes
        evaluation_classes = 'evaluation'

        # Meta-validation set
        validation = dataset_class(datadir,
                                   split='all',
                                   classes=validation_classes,
                                   transform=evaluation_transform)
        # ADAPTED: in the original code, `episodes_per_epoch` was provided to
        # `NShotTaskSampler` instead of `validation_episodes`.
        validation_sampler = NShotTaskSampler(validation, validation_episodes,
                                              n_test, k_test, q_test)
        validation_taskloader = DataLoader(validation,
                                           batch_sampler=validation_sampler,
                                           num_workers=4)
    else:
        # use keyword for both background and evaluation classes
        background_classes = 'background'
        evaluation_classes = 'evaluation'

    # Meta-training set
    background = dataset_class(datadir,
                               split='all',
                               classes=background_classes,
                               transform=background_transform)
    background_sampler = NShotTaskSampler(background, episodes_per_epoch,
                                          n_train, k_train, q_train)
    background_taskloader = DataLoader(background,
                                       batch_sampler=background_sampler,
                                       num_workers=4)

    # Meta-test set
    evaluation = dataset_class(datadir,
                               split='all',
                               classes=evaluation_classes,
                               transform=evaluation_transform)
    # ADAPTED: in the original code, `episodes_per_epoch` was provided to
    # `NShotTaskSampler` instead of `evaluation_episodes`.
    evaluation_sampler = NShotTaskSampler(evaluation, evaluation_episodes,
                                          n_test, k_test, q_test)
    evaluation_taskloader = DataLoader(evaluation,
                                       batch_sampler=evaluation_sampler,
                                       num_workers=4)

    #########
    # Model #
    #########

    if torch.cuda.is_available():
        if gpu is not None:
            device = torch.device('cuda', gpu)
        else:
            device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    if not pretrained:
        model = get_few_shot_encoder(num_input_channels)
        # ADAPTED
        model.to(device)
        # BEFORE
        # model.to(device, dtype=torch.double)
    else:
        assert torch.cuda.is_available()
        model = models.__dict__[architecture](pretrained=True)
        model.fc = Identity()
        if gpu is not None:
            model = model.cuda(gpu)
        else:
            model = model.cuda()
        # TODO this is too risky: I'm not sure that this can work, since in
        #  the few-shot github repo the batch axis is actually split into
        #  support and query samples
        # model = torch.nn.DataParallel(model).cuda()

    def lr_schedule(epoch, lr):
        # Halve the learning rate every `drop_lr_every` epochs
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().to(device)

    callbacks = [
        # ADAPTED: this is the test monitoring now - and is only done at the
        # end of training.
        EvaluateFewShot(
            eval_fn=proto_net_episode,
            num_tasks=evaluation_episodes,  # THIS IS NOT USED
            n_shot=n_test,
            k_way=k_test,
            q_queries=q_test,
            taskloader=evaluation_taskloader,
            prepare_batch=prepare_nshot_task(n_test,
                                             k_test,
                                             q_test,
                                             device=device),
            distance=distance,
            on_epoch_end=False,
            on_train_end=True,
            prefix='test_')
    ]
    if monitor_validation:
        callbacks.append(
            # ADAPTED: this is the validation monitoring now - computed
            # after every epoch.
            EvaluateFewShot(
                eval_fn=proto_net_episode,
                num_tasks=evaluation_episodes,  # THIS IS NOT USED
                n_shot=n_test,
                k_way=k_test,
                q_queries=q_test,
                # BEFORE taskloader=evaluation_taskloader,
                taskloader=validation_taskloader,  # ADAPTED
                prepare_batch=prepare_nshot_task(n_test,
                                                 k_test,
                                                 q_test,
                                                 device=device),
                distance=distance,
                on_epoch_end=True,  # ADAPTED
                on_train_end=False,  # ADAPTED
                prefix='val_'))
    callbacks.extend([
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{n_test}-shot_{k_test}-way_acc',
            verbose=1,  # ADAPTED
            save_best_only=monitor_validation  # ADAPTED
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ])

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_nshot_task(n_train,
                                         k_train,
                                         q_train,
                                         device=device),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': n_train,
            'k_way': k_train,
            'q_queries': q_train,
            'train': True,
            'distance': distance
        },
    )
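
The `Identity` module substituted for `model.fc` above passes features straight through, turning the pretrained ResNet into a feature extractor. If the helper is not already defined, it is one line of `nn.Module` (recent PyTorch versions also ship `torch.nn.Identity`):

import torch.nn as nn

class Identity(nn.Module):
    def forward(self, x):
        return x  # leave the backbone's features untouched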
Example #10

            # (tail of a `callbacks = [...]` list; the opening of the evaluation callback is cut off above)
            prepare_batch=prepare_classifier_task(),
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}_classifier.pth',
            monitor=f'val_acc',
            save_best_only=True
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}_classifier.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_classifier_task(),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
    )
else:
    optimiser = Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss().cuda()
    test(
        model,
        optimiser,
        loss_fn,
        dataloader=test_taskloader,
        prepare_batch=prepare_classifier_task(),
    )
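
`prepare_classifier_task()` is not shown in this snippet; for a plain classification head its role is presumably just device placement of a standard (x, y) batch. A hypothetical stand-in (both the name and the behaviour are guesses from how it is used above):

import torch

def prepare_classifier_task(device=torch.device('cuda')):
    def prepare_batch_(batch):
        x, y = batch
        return x.to(device), y.to(device)
    return prepare_batch_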
Example #11
def run():
    episodes_per_epoch = 600
    '''
    ###### LearningRateScheduler ######
    drop_lr_every = 20
    def lr_schedule(epoch, lr):
        # Halve the learning rate every `drop_lr_every` epochs
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr
    # callbacks add: LearningRateScheduler(schedule=lr_schedule)
    '''

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise ValueError('need to add modules for other datasets')


    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_sampling_method={args.sampling_method}_is_diversity={args.is_diversity}'

    print(param_str)

    #########
    # Model #
    #########
    model = get_few_shot_encoder(num_input_channels)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch,
                                           args.n_train, args.k_train,
                                           args.q_train),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)
    # Importance sampling
    else:
        # ImportanceSampler: Latent space of model
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=eval_dataset_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(
            patience=40,
            factor=0.5,
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
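
`ImportanceSampler`'s internals are not shown here, but its `init_temperature` argument suggests candidate episodes are drawn from a temperature-scaled distribution over difficulty scores rather than uniformly. Purely as an illustration of that mechanism (not this sampler's actual code):

import numpy as np

def sample_episode_index(scores, temperature):
    """Pick one of len(scores) candidate episodes with p proportional to exp(score / T)."""
    scores = np.asarray(scores, dtype=np.float64)
    logits = (scores - scores.max()) / temperature  # shift for numerical stability
    p = np.exp(logits)
    p /= p.sum()
    # High T -> near-uniform sampling; low T -> concentrates on high-scoring episodes
    return np.random.choice(len(scores), p=p)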