else: raise (ValueError, 'Unsupported dataset') param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \ f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_'\ f'dist={args.distance}_fce={args.fce}' ######### # Model # ######### from few_shot.models import MatchingNetwork model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels, lstm_layers=args.lstm_layers, lstm_input_size=lstm_input_size, unrolling_steps=args.unrolling_steps, device=device) model.to(device, dtype=torch.double) ################### # Create datasets # ################### background = dataset_class('background') background_taskloader = DataLoader( background, batch_sampler=NShotTaskSampler(background, episodes_per_epoch, args.n_train, args.k_train, args.q_train), num_workers=4)
def run():
    """Train a Matching Network on the configured dataset.

    Reads hyperparameters from the module-level ``args`` namespace and
    ``device``, builds the train/eval episodic dataloaders (uniform or
    importance sampling), then trains with the project's ``fit`` loop.

    Raises:
        ValueError: if ``args.dataset`` is not ``'miniImageNet'`` —
            other datasets are not wired up yet.
    """
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        # `raise` is a statement, not a callable: use the idiomatic form.
        raise ValueError('need to make other datasets module')

    # Encodes the full experiment configuration; used for checkpoint/log names.
    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'

    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce,
                            num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Evaluation episodes are always drawn uniformly at random,
    # regardless of the training sampling method (identical in both
    # branches of the original code, so built once here).
    eval_dataset_taskloader = DataLoader(
        eval_dataset,
        batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                       args.n_test, args.k_test, args.q_test),
        num_workers=4
    )

    if not args.sampling_method:
        # Original sampling: uniformly random n-shot, k-way episodes.
        train_batch_sampler = NShotTaskSampler(train_dataset, episodes_per_epoch,
                                               args.n_train, args.k_train,
                                               args.q_train)
    else:
        # Importance sampling: episodes drawn from `num_s_candidates`
        # candidates scored with the current model.
        train_batch_sampler = ImportanceSampler(train_dataset, model,
                                                episodes_per_epoch, n_epochs,
                                                args.n_train, args.k_train,
                                                args.q_train,
                                                args.num_s_candidates,
                                                args.init_temperature,
                                                args.is_diversity)

    train_dataset_taskloader = DataLoader(train_dataset,
                                          batch_sampler=train_batch_sampler,
                                          num_workers=4)

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        # Keep only the weights with the best validation accuracy.
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5,
                          monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train,
                             'q_queries': args.q_train, 'train': True,
                             'fce': args.fce, 'distance': args.distance}
    )
else: raise (ValueError, 'Unsupported dataset') param_str = f'{globals.DATASET}_n={globals.N_TRAIN}_k={globals.K_TRAIN}_q={globals.Q_TRAIN}_' \ f'nv={globals.N_TEST}_kv={globals.K_TEST}_qv={globals.Q_TEST}_'\ f'dist={globals.DISTANCE}_fce={globals.FCE}' ######### # Model # ######### from few_shot.models import MatchingNetwork model = MatchingNetwork(globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE, num_input_channels, lstm_layers=globals.LSTM_LAYERS, lstm_input_size=lstm_input_size, unrolling_steps=globals.UNROLLING_STEPS, device=device) model.to(device, dtype=torch.double) ################### # Create datasets # ################### background = dataset_class('background') background_taskloader = DataLoader(background, batch_sampler=NShotTaskSampler( background, episodes_per_epoch, globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN),
evaluation_taskloader = DataLoader( evaluation, batch_sampler=NShotTaskSampler( evaluation, episodes_per_epoch, args.n_test, args.k_test, args.q_test, eval_classes=eval_classes ), # why is qtest needed for protonet i think its not rquired for protonet check it num_workers=4) model = MatchingNetwork(args.n_train, args.k_train, args.q_train, True, num_input_channels, lstm_layers=1, lstm_input_size=lstm_input_size, unrolling_steps=2, device=device) optimiser = Adam(model.parameters(), lr=1e-3) loss_fn = torch.nn.NLLLoss().cuda() eval_fn = matching_net_episode callbacks = [ EvaluateFewShot(eval_fn=matching_net_episode, num_tasks=evaluation_episodes, n_shot=args.n_test, k_way=args.k_test, q_queries=args.q_test,
def main():
    """Evaluate a trained Matching Network checkpoint on the test split.

    Parses CLI arguments, loads the checkpoint from ``--model_path``,
    and runs episodic inference for ``n_epochs`` passes over the test
    taskloader, logging per-epoch categorical accuracy and appending it
    to the CSV at ``--result_path``.

    Raises:
        ValueError: if ``--dataset`` is not ``'miniImageNet'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str,
        default="./models/matching_nets/miniImageNet_n=5_k=5_q=10_nv=5_kv=5_qv=10_dist=cosine_fce=True_sampling_method=False_is_diversity=None_epi_candidate=20.pth",
        help="model path")
    parser.add_argument(
        "--result_path", type=str,
        default="./results/matching_nets/5shot_training_5shot_inference_randomsampling.csv",
        help="Directory for evaluation report result (for experiments)")
    parser.add_argument('--dataset', type=str, required=True)
    # 't'/'true'/'T...' (case-insensitive first letter) -> True, else False.
    parser.add_argument('--fce', type=lambda x: x.lower()[0] == 't')
    parser.add_argument('--distance', default='cosine')
    parser.add_argument('--n_train', default=1, type=int)
    parser.add_argument('--n_test', default=1, type=int)
    parser.add_argument('--k_train', default=5, type=int)
    parser.add_argument('--k_test', default=5, type=int)
    parser.add_argument('--q_train', default=15, type=int)
    parser.add_argument('--q_test', default=15, type=int)
    parser.add_argument('--lstm_layers', default=1, type=int)
    parser.add_argument('--unrolling_steps', default=2, type=int)
    parser.add_argument(
        "--debug",
        action="store_true",
        help="set logging level DEBUG",
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG if args.debug else logging.INFO,
    )

    ###################
    # Create datasets #
    ###################
    episodes_per_epoch = 600
    if args.dataset == 'miniImageNet':
        n_epochs = 5
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        # `raise` is a statement, not a callable: use the idiomatic form.
        raise ValueError('need to make other datasets module')

    test_dataset = dataset_class('test')
    test_dataset_taskloader = DataLoader(
        test_dataset,
        batch_sampler=NShotTaskSampler(test_dataset, episodes_per_epoch,
                                       args.n_test, args.k_test, args.q_test),
        num_workers=4
    )

    #########
    # Model #
    #########
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce,
                            num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device
                            ).to(device, dtype=torch.double)
    # strict=False: tolerate missing/unexpected keys in the checkpoint.
    model.load_state_dict(torch.load(args.model_path), strict=False)
    model.eval()

    #############
    # Inference #
    #############
    logger.info("***** Epochs = %d *****", n_epochs)
    logger.info("***** Num episodes per epoch = %d *****", episodes_per_epoch)

    result_writer = ResultWriter(args.result_path)
    # just argument (function: matching_net_episode)
    prepare_batch = prepare_nshot_task(args.n_test, args.k_test, args.q_test)
    # Required by matching_net_episode's signature; no parameter updates
    # occur because train=False is passed below.
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    train_iterator = trange(0, int(n_epochs), desc="Epoch",)
    for i_epoch in train_iterator:
        epoch_iterator = tqdm(test_dataset_taskloader, desc="Iteration",)
        seen = 0
        metric_name = f'test_{args.n_test}-shot_{args.k_test}-way_acc'
        metric = {metric_name: 0.0}
        for batch in epoch_iterator:  # index was unused; enumerate dropped
            x, y = prepare_batch(batch)
            loss, y_pred = matching_net_episode(
                model,
                optimiser,
                loss_fn,
                x, y,
                n_shot=args.n_test,
                k_way=args.k_test,
                q_queries=args.q_test,
                train=False,
                fce=args.fce,
                distance=args.distance
            )
            # Weight each episode's accuracy by its number of query predictions.
            seen += y_pred.shape[0]
            metric[metric_name] += categorical_accuracy(y, y_pred) * y_pred.shape[0]

        # Normalise once per epoch to get the prediction-weighted mean accuracy.
        metric[metric_name] = metric[metric_name] / seen
        logger.info("epoch: {}, categorical_accuracy: {}".format(i_epoch, metric[metric_name]))
        result_writer.update(**metric)
lstm_input_size = 1600 else: raise(ValueError, 'Unsupported dataset') assert torch.cuda.is_available() device = torch.device('cuda') torch.backends.cudnn.benchmark = True evaluation_episodes = 1000 episodes_per_epoch = 100 ##### # experiments/matching_nets.py model = MatchingNetwork(globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE, num_input_channels, lstm_layers=globals.LSTM_LAYERS, lstm_input_size=lstm_input_size, unrolling_steps=globals.UNROLLING_STEPS, device=device) model_path = 'models/matching_nets/omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=False.pth' loaded_model = torch.load(model_path) model.load_state_dict(loaded_model) model.to(device) model.double() # print("###########################") # for param in model.parameters(): # print(param.data) # print("###########################")