def run():
    """Train a Matching Network on the dataset selected by ``args.dataset``.

    Configuration is read from the module-level ``args`` namespace together
    with the module-level ``device`` and ``PATH``. The function:

    1. picks dataset-specific settings (only ``'miniImageNet'`` is supported);
    2. builds a ``MatchingNetwork`` in double precision on ``device``;
    3. builds train/eval episode loaders. Training uses ``NShotTaskSampler``
       when ``args.sampling_method`` is falsy, otherwise ``ImportanceSampler``.
       Evaluation always uses ``NShotTaskSampler``;
    4. trains with ``fit``. Callbacks cover few-shot evaluation, checkpointing
       (``save_best_only=True``), LR reduction on plateau and CSV logging, all
       keyed on the same validation-accuracy metric.

    Raises:
        ValueError: if ``args.dataset`` is anything other than ``'miniImageNet'``.
    """
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        raise ValueError('need to make other datasets module')

    # Run identifier shared by the checkpoint file and the CSV log file.
    param_str = (
        f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_'
        f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_'
        f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_'
        f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'
    )

    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    model = MatchingNetwork(
        args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
        lstm_layers=args.lstm_layers,
        lstm_input_size=lstm_input_size,
        unrolling_steps=args.unrolling_steps,
        device=device,
    )
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    if not args.sampling_method:
        # Original (uniform) episode sampling.
        train_sampler = NShotTaskSampler(
            train_dataset, episodes_per_epoch,
            args.n_train, args.k_train, args.q_train,
        )
    else:
        # Importance sampling. The sampler is handed the model itself, presumably
        # so it can score candidate episodes with the current weights (verify in
        # ImportanceSampler).
        train_sampler = ImportanceSampler(
            train_dataset, model, episodes_per_epoch, n_epochs,
            args.n_train, args.k_train, args.q_train,
            args.num_s_candidates, args.init_temperature, args.is_diversity,
        )
    train_dataset_taskloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=4,
    )

    # Evaluation episodes are sampled the same way whichever training sampler is used.
    eval_dataset_taskloader = DataLoader(
        eval_dataset,
        batch_sampler=NShotTaskSampler(
            eval_dataset, episodes_per_epoch,
            args.n_test, args.k_test, args.q_test,
        ),
        num_workers=4,
    )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    # Follow the model's device instead of hard-coding CUDA.
    loss_fn = torch.nn.NLLLoss().to(device)

    # Metric name shared by checkpointing and LR scheduling.
    val_acc_metric = f'val_{args.n_test}-shot_{args.k_test}-way_acc'
    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance,
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=val_acc_metric,
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=val_acc_metric),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'fce': args.fce,
            'distance': args.distance,
        },
    )
f'dist={args.distance}_fce={args.fce}' ######### # Model # ######### from few_shot.models import MatchingNetwork model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels, lstm_layers=args.lstm_layers, lstm_input_size=lstm_input_size, unrolling_steps=args.unrolling_steps, device=device) model.to(device, dtype=torch.double) ################### # Create datasets # ################### background = dataset_class('background') background_taskloader = DataLoader( background, batch_sampler=NShotTaskSampler(background, episodes_per_epoch, args.n_train, args.k_train, args.q_train), num_workers=4) evaluation = dataset_class('evaluation') evaluation_taskloader = DataLoader(evaluation, batch_sampler=NShotTaskSampler( evaluation, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
#####
# experiments/matching_nets.py
# Rebuild a Matching Network from a saved state dict and set up an evaluation
# episode loader for it.
model = MatchingNetwork(
    globals.N_TRAIN, globals.K_TRAIN, globals.Q_TRAIN, globals.FCE,
    num_input_channels,
    lstm_layers=globals.LSTM_LAYERS,
    lstm_input_size=lstm_input_size,
    unrolling_steps=globals.UNROLLING_STEPS,
    device=device,
)

model_path = 'models/matching_nets/omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=False.pth'
# map_location remaps stored tensors onto `device`, so a checkpoint written
# from a GPU run can still be loaded when `device` is the CPU.
loaded_model = torch.load(model_path, map_location=device)
model.load_state_dict(loaded_model)
model.to(device, dtype=torch.double)

evaluation = dataset_class('evaluation')
dataloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(
        evaluation, episodes_per_epoch,
        n=globals.N_TEST, k=globals.K_TEST, q=globals.Q_TEST,
    ),
    num_workers=4,
)
prepare_batch = prepare_nshot_task(globals.N_TEST, globals.K_TEST, globals.Q_TEST)