예제 #1
0
else:
    raise (ValueError, 'Unsupported dataset')

# Experiment identifier used to name checkpoint/log files; encodes the
# train (n/k/q) and validation (nv/kv/qv) episode shapes plus distance/FCE.
param_str = (f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_'
             f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_'
             f'dist={args.distance}_fce={args.fce}')

#########
# Model #
#########
from few_shot.models import MatchingNetwork

# Build the Matching Network from the parsed command-line arguments and
# move it onto the target device in double precision.
model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                        lstm_layers=args.lstm_layers,
                        lstm_input_size=lstm_input_size,
                        unrolling_steps=args.unrolling_steps,
                        device=device)
model.to(device, dtype=torch.double)

###################
# Create datasets #
###################
background = dataset_class('background')
# Episodic sampler: each batch is one n-shot / k-way / q-query task.
background_sampler = NShotTaskSampler(background, episodes_per_epoch,
                                      args.n_train, args.k_train, args.q_train)
background_taskloader = DataLoader(background, batch_sampler=background_sampler, num_workers=4)
예제 #2
0
def run():
    """Train a Matching Network on the dataset selected via ``args``.

    Builds the model, constructs episodic train/eval dataloaders (the
    training sampler is either the uniform ``NShotTaskSampler`` or the
    model-guided ``ImportanceSampler`` depending on
    ``args.sampling_method``), and drives training through ``fit`` with
    evaluation, checkpointing, LR scheduling and CSV logging callbacks.

    Relies on module-level ``args``, ``device`` and ``PATH``.
    """
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        # Only miniImageNet is wired up so far.  (Raise an exception
        # instance, not a parenthesised expression.)
        raise ValueError('need to make other datasets module')

    # Experiment identifier used for checkpoint and log file names.
    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'

    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Pick the training batch sampler; the eval loader is identical in
    # both cases, so it is built once below instead of being duplicated.
    if not args.sampling_method:
        # Original (uniform) episode sampling.
        train_sampler = NShotTaskSampler(train_dataset, episodes_per_epoch,
                                         args.n_train, args.k_train, args.q_train)
    else:
        # Importance sampling of training episodes -- presumably uses the
        # model to score candidate episodes; see ImportanceSampler.
        train_sampler = ImportanceSampler(train_dataset, model,
                                          episodes_per_epoch, n_epochs, args.n_train, args.k_train, args.q_train,
                                          args.num_s_candidates, args.init_temperature, args.is_diversity)
    train_dataset_taskloader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=4)
    # Evaluation always uses uniform episode sampling.
    eval_dataset_taskloader = DataLoader(
        eval_dataset,
        batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
        num_workers=4
    )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    # NOTE(review): .cuda() hard-codes the GPU; kept as-is since the rest
    # of the file assumes CUDA -- confirm before running CPU-only.
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        # Periodically evaluate on held-out episodes.
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        # Keep only the best checkpoint by validation accuracy.
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True,
                             'fce': args.fce, 'distance': args.distance}
    )
예제 #3
0
else:
    raise (ValueError, 'Unsupported dataset')

# Experiment identifier built from the project-level `globals` module;
# encodes train (n/k/q) and validation (nv/kv/qv) episode shapes.
param_str = (f'{globals.DATASET}_n={globals.N_TRAIN}_k={globals.K_TRAIN}_q={globals.Q_TRAIN}_'
             f'nv={globals.N_TEST}_kv={globals.K_TEST}_qv={globals.Q_TEST}_'
             f'dist={globals.DISTANCE}_fce={globals.FCE}')

#########
# Model #
#########
from few_shot.models import MatchingNetwork

# Matching Network configured entirely from the project `globals` module,
# then moved onto the target device in double precision.
model = MatchingNetwork(globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE,
                        num_input_channels,
                        lstm_layers=globals.LSTM_LAYERS,
                        lstm_input_size=lstm_input_size,
                        unrolling_steps=globals.UNROLLING_STEPS,
                        device=device)
model.to(device, dtype=torch.double)

###################
# Create datasets #
###################
# 'background' is presumably the meta-training split -- verify against
# dataset_class; the taskloader built from it below is truncated here.
background = dataset_class('background')
background_taskloader = DataLoader(background,
                                   batch_sampler=NShotTaskSampler(
                                       background, episodes_per_epoch,
                                       globals.N_TRAIN, globals.K_TRAIN,
                                       globals.Q_TRAIN),
예제 #4
0
    evaluation_taskloader = DataLoader(
        evaluation,
        batch_sampler=NShotTaskSampler(
            evaluation,
            episodes_per_epoch,
            args.n_test,
            args.k_test,
            args.q_test,
            eval_classes=eval_classes
        ),  # why is qtest needed for protonet i think its not rquired for protonet check it
        num_workers=4)
    model = MatchingNetwork(args.n_train,
                            args.k_train,
                            args.q_train,
                            True,
                            num_input_channels,
                            lstm_layers=1,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=2,
                            device=device)
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    eval_fn = matching_net_episode

    callbacks = [
        EvaluateFewShot(eval_fn=matching_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
def main():
    """Evaluate a trained Matching Network checkpoint on the test split.

    Parses CLI arguments, restores the model from ``--model_path``, then
    runs ``n_epochs`` passes over episodic test tasks, logging the mean
    categorical accuracy per epoch and appending it to ``--result_path``.

    Relies on module-level ``device`` and ``logger``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str,
        default="./models/matching_nets/miniImageNet_n=5_k=5_q=10_nv=5_kv=5_qv=10_dist=cosine_fce=True_sampling_method=False_is_diversity=None_epi_candidate=20.pth",
        help="model path")
    parser.add_argument(
        "--result_path", type=str,
        default="./results/matching_nets/5shot_training_5shot_inference_randomsampling.csv",
        help="Directory for evaluation report result (for experiments)")
    parser.add_argument('--dataset', type=str, required=True)
    # Truthy strings starting with 't'/'T' parse as True.
    parser.add_argument('--fce', type=lambda x: x.lower()[0] == 't')
    parser.add_argument('--distance', default='cosine')
    parser.add_argument('--n_train', default=1, type=int)
    parser.add_argument('--n_test', default=1, type=int)
    parser.add_argument('--k_train', default=5, type=int)
    parser.add_argument('--k_test', default=5, type=int)
    parser.add_argument('--q_train', default=15, type=int)
    parser.add_argument('--q_test', default=15, type=int)
    parser.add_argument('--lstm_layers', default=1, type=int)
    parser.add_argument('--unrolling_steps', default=2, type=int)
    parser.add_argument(
        "--debug", action="store_true", help="set logging level DEBUG",
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG if args.debug else logging.INFO,
    )

    ###################
    # Create datasets #
    ###################
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 5
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        # Only miniImageNet is wired up so far.  (Raise an exception
        # instance, not a parenthesised expression.)
        raise ValueError('need to make other datasets module')

    test_dataset = dataset_class('test')
    test_dataset_taskloader = DataLoader(
        test_dataset,
        batch_sampler=NShotTaskSampler(test_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
        num_workers=4
    )

    #########
    # Model #
    #########
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device
    ).to(device, dtype=torch.double)

    # NOTE(review): strict=False silently ignores missing/unexpected keys
    # in the checkpoint -- confirm this mismatch tolerance is intentional.
    model.load_state_dict(torch.load(args.model_path), strict=False)
    model.eval()

    #############
    # Inference #
    #############
    logger.info("***** Epochs = %d *****", n_epochs)
    logger.info("***** Num episodes per epoch = %d *****", episodes_per_epoch)

    result_writer = ResultWriter(args.result_path)

    # just argument (function: matching_net_episode)
    prepare_batch = prepare_nshot_task(args.n_test, args.k_test, args.q_test)
    # The optimiser is required by matching_net_episode's signature, but
    # with train=False below no parameter update is performed.
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    train_iterator = trange(0, int(n_epochs), desc="Epoch",)
    for i_epoch in train_iterator:
        epoch_iterator = tqdm(test_dataset_taskloader, desc="Iteration",)
        seen = 0
        metric_name = f'test_{args.n_test}-shot_{args.k_test}-way_acc'
        metric = {metric_name: 0.0}
        for batch in epoch_iterator:  # enumerate index was unused
            x, y = prepare_batch(batch)

            loss, y_pred = matching_net_episode(
                model,
                optimiser,
                loss_fn,
                x,
                y,
                n_shot=args.n_test,
                k_way=args.k_test,
                q_queries=args.q_test,
                train=False,
                fce=args.fce,
                distance=args.distance
            )

            # Weight each episode's accuracy by its number of predictions
            # so the epoch metric is a per-sample mean.
            seen += y_pred.shape[0]
            metric[metric_name] += categorical_accuracy(y, y_pred) * y_pred.shape[0]

        metric[metric_name] /= seen

        logger.info("epoch: {},     categorical_accuracy: {}".format(i_epoch, metric[metric_name]))
        result_writer.update(**metric)
예제 #6
0
    lstm_input_size = 1600
else:
    raise(ValueError, 'Unsupported dataset')

# CUDA is required: the checkpoint below is loaded onto the GPU.
assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

evaluation_episodes = 1000
episodes_per_epoch = 100


#####
# experiments/matching_nets.py
# Rebuild the architecture exactly as the training script configured it,
# then restore the trained Omniglot weights from disk.
model = MatchingNetwork(globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE,
                        num_input_channels,
                        lstm_layers=globals.LSTM_LAYERS,
                        lstm_input_size=lstm_input_size,
                        unrolling_steps=globals.UNROLLING_STEPS,
                        device=device)

model_path = 'models/matching_nets/omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=False.pth'
loaded_model = torch.load(model_path)
model.load_state_dict(loaded_model)

# Move to the GPU and cast parameters to float64 in one call
# (equivalent to model.to(device) followed by model.double()).
model.to(device, dtype=torch.double)