Exemple #1
0
def run():
    """Train a Matching Network on the dataset named by ``args.dataset``.

    Configuration is read from the module-level ``args`` and ``device``.
    The function builds the model, the train/eval episode loaders and the
    callbacks, then hands everything to ``fit``. Only ``'miniImageNet'`` is
    handled.

    Raises:
        ValueError: If ``args.dataset`` is not ``'miniImageNet'``.
    """
    if args.dataset != 'miniImageNet':
        raise ValueError('need to make other datasets module')

    episodes_per_epoch = 600
    n_epochs = 500
    dataset_class = MiniImageNet
    num_input_channels = 3
    lstm_input_size = 1600

    # Used as the file stem for both the checkpoint and the CSV log.
    param_str = (
        f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_'
        f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_'
        f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_'
        f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'
    )

    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Only the training sampler depends on args.sampling_method; the
    # evaluation loader is the same in both cases and is built once below.
    if args.sampling_method:
        # Importance sampling: this sampler is also given the model itself.
        train_sampler = ImportanceSampler(
            train_dataset, model,
            episodes_per_epoch, n_epochs, args.n_train, args.k_train, args.q_train,
            args.num_s_candidates, args.init_temperature, args.is_diversity)
    else:
        # Original sampling
        train_sampler = NShotTaskSampler(
            train_dataset, episodes_per_epoch, args.n_train, args.k_train, args.q_train)

    train_dataset_taskloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=4
    )
    eval_dataset_taskloader = DataLoader(
        eval_dataset,
        batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
        num_workers=4
    )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    # Follow the configured device like the model does, instead of
    # hard-coding CUDA.
    loss_fn = torch.nn.NLLLoss().to(device)

    # Metric name shared by checkpointing and LR scheduling.
    monitored_metric = f'val_{args.n_test}-shot_{args.k_test}-way_acc'

    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=monitored_metric,
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=monitored_metric),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True,
                            'fce': args.fce, 'distance': args.distance}
    )
Exemple #2
0
            f'dist={args.distance}_fce={args.fce}'

#########
# Model #
#########
from few_shot.models import MatchingNetwork

# Network is built from the training-episode settings, then moved to the
# target device and cast to double precision in a single call.
model = MatchingNetwork(
    args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
    lstm_layers=args.lstm_layers,
    lstm_input_size=lstm_input_size,
    unrolling_steps=args.unrolling_steps,
    device=device,
)
model.to(device, dtype=torch.double)

###################
# Create datasets #
###################
# 'background' split: episodes are drawn with the *_train settings.
background = dataset_class('background')
background_taskloader = DataLoader(
    background,
    batch_sampler=NShotTaskSampler(
        background, episodes_per_epoch, args.n_train, args.k_train, args.q_train,
    ),
    num_workers=4,
)

# 'evaluation' split
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(evaluation,
                                   batch_sampler=NShotTaskSampler(
                                       evaluation, episodes_per_epoch,
                                       args.n_test, args.k_test, args.q_test),
Exemple #3
0
#####
# experiments/matching_nets.py
model = MatchingNetwork(globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE,
                        num_input_channels,
                        lstm_layers=globals.LSTM_LAYERS,
                        lstm_input_size=lstm_input_size,
                        unrolling_steps=globals.UNROLLING_STEPS,
                        device=device)

# Place and cast the model BEFORE restoring weights: load_state_dict copies
# checkpoint values into the existing parameters, so loading into
# single-precision parameters and casting afterwards would round a
# double-precision checkpoint (presumably how this one was saved -- confirm
# against the training script).
model.to(device)
model.double()

model_path = 'models/matching_nets/omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=False.pth'
# map_location lets a checkpoint saved on another device (e.g. a GPU) be
# restored onto whatever `device` is here.
loaded_model = torch.load(model_path, map_location=device)
model.load_state_dict(loaded_model)


# Episodes for evaluation are drawn with the *_TEST settings.
evaluation = dataset_class('evaluation')
dataloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(evaluation, episodes_per_epoch,
                                   n=globals.N_TEST, k=globals.K_TEST, q=globals.Q_TEST),
    num_workers=4,  # DataLoader argument, not a sampler argument
)
prepare_batch = prepare_nshot_task(globals.N_TEST, globals.K_TEST, globals.Q_TEST)