def lr_schedule(epoch, lr):
    # Halve the learning rate every `drop_lr_every` epochs
    if epoch % drop_lr_every == 0:
        return lr / 2
    else:
        return lr

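# Aside (editor's sketch, not in the original script): LearningRateScheduler
# is assumed to pass the *current* lr into the schedule each epoch,
# Keras-style, so the halvings compound:
#
#   lr = 1e-3
#   for epoch in range(1, 401):       # with drop_lr_every = 100
#       lr = lr_schedule(epoch, lr)
#   # lr == 1e-3 / 2 ** 4  (halved at epochs 100, 200, 300 and 400)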

callbacks = [
    EvaluateFewShot(
        eval_fn=proto_net_episode,
        num_tasks=evaluation_episodes,
        n_shot=args.n_test,
        k_way=args.k_test,
        q_queries=args.q_test,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
        distance=args.distance
    ),
    ModelCheckpoint(
        filepath=PATH + f'/models/proto_nets/{param_str}.pth',
        monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'
    ),
    LearningRateScheduler(schedule=lr_schedule),
    CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
]

fit(
    model,
    optimiser,
    loss_fn,
    epochs=n_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=proto_net_episode,
    # Remaining arguments completed to mirror the identical proto-net
    # fit(...) calls later on this page.
    fit_function_kwargs={
        'n_shot': args.n_train,
        'k_way': args.k_train,
        'q_queries': args.q_train,
        'train': True,
        'distance': args.distance,
    },
)
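# For reference, a sketch of the prepare_nshot_task helper used throughout
# these examples (an assumption about few_shot.core, not code from this
# script): it moves the episode to the GPU and replaces the raw class ids
# with the canonical [0, k) query labels.
#
#   def prepare_nshot_task(n, k, q):
#       def prepare_nshot_task_(batch):
#           x, y = batch
#           x = x.double().cuda()
#           y = create_nshot_task_label(k, q).cuda()
#           return x, y
#       return prepare_nshot_task_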
Example #2
# Training #
############
print('Training Matching Network on {}...'.format(args.dataset))
if args.stn:
    print('Training with STN')
optimiser = Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.NLLLoss().cuda()

callbacks = [
    EvaluateFewShot(eval_fn=matching_net_episode,
                    num_tasks=evaluation_episodes,
                    n_shot=args.n_test,
                    k_way=args.k_test,
                    q_queries=args.q_test,
                    taskloader=evaluation_taskloader,
                    prepare_batch=prepare_nshot_task(args.n_test, args.k_test,
                                                     args.q_test),
                    fce=args.fce,
                    args=args,
                    stnmodel=None,
                    stnoptim=None,
                    distance=args.distance),
    ModelCheckpoint(
        filepath=PATH + '/models/matching_nets/{}.pth'.format(param_str),
        monitor='val_{}-shot_{}-way_acc'.format(args.n_test, args.k_test),
        # monitor=f'val_loss',
    ),
    ReduceLROnPlateau(patience=20,
                      factor=0.5,
                      monitor='val_{}-shot_{}-way_acc'.format(
                          args.n_test, args.k_test)),
]
Example #3
    for i in range(len(meta_optimisers))
]
ReduceLRCallback = CallbackList(ReduceLRCallback)

# Random three-letter run identifier (note: shadows the builtin `hash`)
hash = ''.join([chr(random.randint(97, 122)) for _ in range(3)])
callbacks = [
    ModelLoader(names=names),
    EvaluateFewShot(
        eval_fn=fit_fn,
        num_tasks=args.eval_batches,
        n_shot=args.n,
        k_way=args.k,
        q_queries=args.q,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_meta_batch(args.n, args.k, args.q,
                                         args.meta_batch_size),
        loss_fn=loss_fn,
        # MAML kwargs
        inner_train_steps=args.inner_val_steps,
        inner_lr=args.inner_lr,
        device=device,
        order=args.order,
        model_params=model_params,
        pred_fn=test_pred_fn),
    EvaluateFewShot(
        eval_fn=fit_fn,
        num_tasks=args.eval_batches,
        n_shot=args.n,
        k_way=args.k,
        q_queries=args.q,
        taskloader=evaluation_taskloader,
Example #4
def train_proto_net(
    args,
    model,
    device,
    n_epochs,
    background_taskloader,
    evaluation_taskloader,
    path='.',
    lr=3e-3,
    drop_lr_every=100,
    evaluation_episodes=100,
    episodes_per_epoch=100,
):
    # Prepare model
    model.to(device, dtype=torch.float)
    model.train(True)

    # Prepare training etc.
    optimizer = Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.NLLLoss().cuda()
    ensure_folder(path + '/models')
    ensure_folder(path + '/logs')

    def lr_schedule(epoch, lr):
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=path + '/models/' + args.param_str + '_e{epoch:02d}.pth',
            monitor=args.checkpoint_monitor
            or f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            period=args.checkpoint_period or 100,
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(path + f'/logs/{args.param_str}.csv'),
    ]

    fit(
        model,
        optimizer,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        epoch_metrics=[f'val_{args.n_test}-shot_{args.k_test}-way_acc'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
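# Usage sketch for train_proto_net (illustrative; `args`, `model` and the two
# taskloaders are assumed to be built as in the surrounding examples):
#
#   train_proto_net(
#       args, model, device,
#       n_epochs=80,
#       background_taskloader=background_taskloader,
#       evaluation_taskloader=evaluation_taskloader,
#       path='.', lr=3e-3, drop_lr_every=100,
#   )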
Example #5
        # Create label
        y = create_nshot_task_label(k, q).cuda().repeat(meta_batch_size)
        return x, y

    return prepare_meta_batch_

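# Sketch of the assumed create_nshot_task_label helper used above: a k-way
# episode with q queries per class gets query labels 0..k-1, each repeated q
# times.
#
#   def create_nshot_task_label(k, q):
#       return torch.arange(0, k, 1 / q).long()
#
# e.g. create_nshot_task_label(k=3, q=2) -> tensor([0, 0, 1, 1, 2, 2])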

callbacks = [
    EvaluateFewShot(
        eval_fn=meta_gradient_step,
        num_tasks=args.eval_batches,
        n_shot=args.n,
        k_way=args.k,
        q_queries=args.q,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_meta_batch(args.n, args.k, args.q,
                                         args.meta_batch_size),
        # MAML kwargs
        inner_train_steps=args.inner_val_steps,
        inner_lr=args.inner_lr,
        device=device,
        order=args.order,
    ),
    ModelCheckpoint(filepath=PATH + f'/models/maml/{param_str}.pth',
                    monitor=f'val_{args.n}-shot_{args.k}-way_acc'),
    #ReduceLROnPlateau(patience=10, factor=0.5, monitor=f'val_loss'),
    CSVLogger(PATH + f'/logs/maml/{param_str}.csv'),
]

fit(
    meta_model,
Example #6
def run():
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        raise ValueError('need to make other datasets module')

    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'


    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)


    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original_sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch, args.n_train, args.k_train, args.q_train),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )
    # Importance sampling
    else:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()


    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True,
                             'fce': args.fce, 'distance': args.distance}
    )
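# Aside (editor's illustration, not part of the example): DataLoader's
# batch_sampler only needs an iterable that yields one list of dataset
# indices per episode -- that is the whole contract NShotTaskSampler and
# ImportanceSampler fulfil. A minimal hypothetical episode sampler:
import random
from torch.utils.data import Sampler

class RandomEpisodeSampler(Sampler):
    """Yield `episodes_per_epoch` random index batches of `batch_size`."""

    def __init__(self, dataset_len, episodes_per_epoch, batch_size):
        self.dataset_len = dataset_len
        self.episodes_per_epoch = episodes_per_epoch
        self.batch_size = batch_size

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        for _ in range(self.episodes_per_epoch):
            yield random.sample(range(self.dataset_len), self.batch_size)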
Example #7
def lr_schedule(epoch, lr):
    # Halve the learning rate every `drop_lr_every` epochs
    if epoch % drop_lr_every == 0:
        return lr / 2
    else:
        return lr


callbacks = [
    EvaluateFewShot(
        eval_fn=proto_net_episode,
        num_tasks=evaluation_episodes,
        n_shot=args.n_test,
        k_way=args.k_test,
        q_queries=args.q_test,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_nshot_task(
            args.n_test, args.k_test, args.q_test
        ),  # prepare_nshot_task relabels the episode's classes onto [0, k)
        distance=args.distance),
    ModelCheckpoint(filepath=PATH + f'/models/proto_nets/{param_str}.pth',
                    monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
    LearningRateScheduler(schedule=lr_schedule),
    CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
]

fit(
    model,
    optimiser,
    loss_fn,
Example #8
                                   num_workers=4)

############
# Training #
############
print(f'Training Matching Network on {globals.DATASET}...')
optimiser = Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.NLLLoss().cuda()

callbacks = [
    EvaluateFewShot(eval_fn=matching_net_episode,
                    num_tasks=evaluation_episodes,
                    n_shot=globals.N_TEST,
                    k_way=globals.K_TEST,
                    q_queries=globals.Q_TEST,
                    taskloader=evaluation_taskloader,
                    prepare_batch=prepare_nshot_task(globals.N_TEST,
                                                     globals.K_TEST,
                                                     globals.Q_TEST),
                    fce=globals.FCE,
                    distance=globals.DISTANCE),
    ModelCheckpoint(
        filepath=PATH + f'/models/matching_nets/{param_str}.pth',
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc',
        # monitor=f'val_loss',
    ),
    ReduceLROnPlateau(
        patience=20,
        factor=0.5,
        monitor=f'val_{globals.N_TEST}-shot_{globals.K_TEST}-way_acc'),
    CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
]
Example #9
        batch_sampler=NShotTaskSampler(
            evaluation,
            episodes_per_epoch,
            args.n_test,
            args.k_test,
            args.q_test,
            eval_classes=None
        ),  # TODO: check whether q_test is actually required for ProtoNets here
        num_workers=4)
    callbacks = [
        EvaluateFewShot(
            eval_fn=eval_fn,
            num_tasks=evaluation_episodes,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=evaluation_taskloader,
            prepare_batch=prepare_nshot_task(
                args.n_test, args.k_test, args.q_test
            ),  # prepare_nshot_task relabels the episode's classes onto [0, k)
            distance=args.distance),
    ]
elif args.network == 'matching':
    from few_shot.models import MatchingNetwork
    n_epochs = 200
    dataset_class = FashionDataset
    num_input_channels = 3
    lstm_input_size = 1600

    evaluation_taskloader = DataLoader(
        evaluation,
Example #10
def train_sweep():

    from torch.optim import Adam
    from torch.utils.data import DataLoader
    import argparse

    from few_shot.datasets import OmniglotDataset, MiniImageNet, ClinicDataset, SNIPSDataset, CustomDataset
    from few_shot.models import XLNetForEmbedding
    from few_shot.core import NShotTaskSampler, EvaluateFewShot, prepare_nshot_task
    from few_shot.proto import proto_net_episode
    from few_shot.train_with_prints import fit
    from few_shot.callbacks import CallbackList, Callback, DefaultCallback, ProgressBarLogger, CSVLogger, EvaluateMetrics, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
    from few_shot.utils import setup_dirs
    from few_shot.utils import get_gpu_info
    from config import PATH
    import wandb
    from transformers import AdamW

    import torch

    gpu_dict = get_gpu_info()
    print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
        gpu_dict['mem_total'], gpu_dict['mem_used'],
        gpu_dict['mem_used_percent']))

    setup_dirs()
    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True

    ##############
    # Parameters #
    ##############
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='Custom')
    parser.add_argument('--distance', default='l2')
    parser.add_argument('--n-train', default=2, type=int)
    parser.add_argument('--n-test', default=2, type=int)
    parser.add_argument('--k-train', default=2, type=int)
    parser.add_argument('--k-test', default=2, type=int)
    parser.add_argument('--q-train', default=2, type=int)
    parser.add_argument('--q-test', default=2, type=int)
    args = parser.parse_args()

    evaluation_episodes = 100
    episodes_per_epoch = 10

    if args.dataset == 'omniglot':
        n_epochs = 40
        dataset_class = OmniglotDataset
        num_input_channels = 1
        drop_lr_every = 20
    elif args.dataset == 'miniImageNet':
        n_epochs = 80
        dataset_class = MiniImageNet
        num_input_channels = 3
        drop_lr_every = 40
    elif args.dataset == 'clinic150':
        n_epochs = 5
        dataset_class = ClinicDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'SNIPS':
        n_epochs = 5
        dataset_class = SNIPSDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'Custom':
        n_epochs = 20
        dataset_class = CustomDataset
        num_input_channels = 150
        drop_lr_every = 5
    else:
        raise ValueError('Unsupported dataset')

    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

    print(param_str)

    ###################
    # Create datasets #
    ###################

    train_df = dataset_class('train')

    train_taskloader = DataLoader(train_df,
                                  batch_sampler=NShotTaskSampler(
                                      train_df, episodes_per_epoch,
                                      args.n_train, args.k_train,
                                      args.q_train))

    val_df = dataset_class('val')

    evaluation_taskloader = DataLoader(
        val_df,
        batch_sampler=NShotTaskSampler(val_df, episodes_per_epoch, args.n_test,
                                       args.k_test, args.q_test))

    #train_iter = iter(train_taskloader)
    #train_taskloader = next(train_iter)

    #val_iter = iter(evaluation_taskloader)
    #evaluation_taskloader = next(val_iter)

    #########
    # Wandb #
    #########

    config_defaults = {
        'lr': 0.00001,
        'optimiser': 'adam',
        'batch_size': 16,
    }

    wandb.init(config=config_defaults)
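    # Aside: train_sweep is presumably launched by a W&B agent, which
    # overrides config_defaults with the sweep's values. A hypothetical
    # sweep definition (parameter names must match the keys above):
    #
    #   sweep_config = {
    #       'method': 'grid',
    #       'parameters': {'lr': {'values': [1e-5, 3e-5]},
    #                      'optimiser': {'values': ['adam', 'adamw']}},
    #   }
    #   wandb.agent(wandb.sweep(sweep_config), function=train_sweep)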

    #########
    # Model #
    #########

    torch.cuda.empty_cache()

    try:
        print('Before Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    #from transformers import XLNetForSequenceClassification, AdamW

    #model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=150)
    #model.cuda()

    try:
        del model
    except NameError:
        print("Cannot delete model. No model with name 'model' exists")

    model = XLNetForEmbedding(num_input_channels)
    model.to(device, dtype=torch.double)

    #param_optimizer = list(model.named_parameters())
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    #optimizer_grouped_parameters = [
    #                                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #                                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay':0.0}
    #]

    try:
        print('After Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    wandb.watch(model)

    ############
    # Training #
    ############

    print(f'Training Prototypical network on {args.dataset}...')
    if wandb.config.optimiser == 'adam':
        optimiser = Adam(model.parameters(), lr=wandb.config.lr)
    else:
        optimiser = AdamW(model.parameters(), lr=wandb.config.lr)

    #optimiser = AdamW(optimizer_grouped_parameters, lr=3e-5)
    #loss_fn = torch.nn.NLLLoss().cuda()

    #loss_fn = torch.nn.CrossEntropyLoss()

    #max_grad_norm = 1.0

    loss_fn = torch.nn.NLLLoss()

    def lr_schedule(epoch, lr):
        # Halve the learning rate every `drop_lr_every` epochs
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    try:
        print('Before Fit')
        print('optimiser :', optimiser)
        print('Learning Rate: ', wandb.config.lr)
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
Example #11
def run():
    episodes_per_epoch = 600
    '''
    ###### LearningRateScheduler ######
    drop_lr_every = 20
    def lr_schedule(epoch, lr):
        # Halve the learning rate every `drop_lr_every` epochs
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr
    # callbacks add: LearningRateScheduler(schedule=lr_schedule)
    '''

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise ValueError('need to make other datasets module')


    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_sampling_method={args.sampling_method}_is_diversity={args.is_diversity}'

    print(param_str)

    #########
    # Model #
    #########
    model = get_few_shot_encoder(num_input_channels)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch,
                                           args.n_train, args.k_train,
                                           args.q_train),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)
    # Importance sampling
    else:
        # ImportanceSampler: Latent space of model
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=eval_dataset_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(
            patience=40,
            factor=0.5,
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
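# Aside (editor's sketch): the `distance` argument selects how query
# embeddings are compared to the class prototypes. A minimal l2 version of
# that pairwise distance (the repo's own implementation may differ):
import torch

def pairwise_l2(queries: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """(n_queries, d) x (k_way, d) -> (n_queries, k_way) squared distances."""
    return (queries.unsqueeze(1) - prototypes.unsqueeze(0)).pow(2).sum(dim=2)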
Example #12
                                       args.n_test, args.k_test, args.q_test),
                                   num_workers=4)

############
# Training #
############
print(f'Training Matching Network on {args.dataset}...')
optimiser = Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.NLLLoss().cuda()

callbacks = [
    EvaluateFewShot(eval_fn=matching_net_episode,
                    num_tasks=evaluation_episodes,
                    n_shot=args.n_test,
                    k_way=args.k_test,
                    q_queries=args.q_test,
                    taskloader=evaluation_taskloader,
                    prepare_batch=prepare_nshot_task(args.n_test, args.k_test,
                                                     args.q_test),
                    fce=args.fce,
                    distance=args.distance),
    ModelCheckpoint(filepath=PATH + f'/models/matching_nets/{param_str}.pth',
                    monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
                    save_best_only=True
                    # monitor=f'val_loss',
                    ),
    ReduceLROnPlateau(patience=20,
                      factor=0.5,
                      monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
    CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
]