def main():
    """Episodic training loop with per-epoch validation for LabelPropagation.

    Reads its configuration from module-level globals (x_dim, ratio, seed,
    dataset, n_examples, n_episodes, pkl, args, lr, step_size, gamma, iters,
    alg, n_epochs, n_way, n_shot, n_query, n_test_way, n_test_shot,
    n_test_query, patience) — presumably populated by a CLI parser elsewhere
    in this file; confirm against the argument parser.

    Side effects: writes model checkpoints and appends to a log file under
    checkpoints/<exp_name>/.
    """

    def _prep_episode(support, s_labels, query, q_labels, way, shot, n_q):
        """Convert one sampled episode (NHWC numpy arrays) into model inputs.

        Returns (support NCHW tensor, support one-hot labels,
                 query NCHW tensor, query one-hot labels).
        """
        # Merge (way, shot/query, H, W, C) into a single batch dim, then NHWC -> NCHW.
        support = np.reshape(
            support, (support.shape[0] * support.shape[1],) + support.shape[2:])
        support = torch.from_numpy(np.transpose(support, (0, 3, 1, 2)))
        query = np.reshape(
            query, (query.shape[0] * query.shape[1],) + query.shape[2:])
        query = torch.from_numpy(np.transpose(query, (0, 3, 1, 2)))
        s_labels = torch.from_numpy(np.reshape(s_labels, (-1,))).type(torch.LongTensor)
        q_labels = torch.from_numpy(np.reshape(q_labels, (-1,))).type(torch.LongTensor)
        s_onehot = torch.zeros(way * shot, way).scatter_(1, s_labels.view(-1, 1), 1)
        q_onehot = torch.zeros(way * n_q, way).scatter_(1, q_labels.view(-1, 1), 1)
        return support, s_onehot, query, q_onehot

    # Step 1: init dataloader
    print("init data loader")
    args_data = {'x_dim': x_dim, 'ratio': ratio, 'seed': seed}
    if dataset == 'mini':
        loader_train = dataset_mini(n_examples, n_episodes, 'train', args_data)
        loader_val = dataset_mini(n_examples, n_episodes, 'val', args_data)
    elif dataset == 'tiered':
        loader_train = dataset_tiered(n_examples, n_episodes, 'train', args_data)
        loader_val = dataset_tiered(n_examples, n_episodes, 'val', args_data)
    else:
        # Fail fast instead of a NameError on loader_train below.
        raise ValueError('unknown dataset: %s' % dataset)

    if not pkl:
        loader_train.load_data()
        loader_val.load_data()
    else:
        loader_train.load_data_pkl()
        loader_val.load_data_pkl()

    # Step 2: init neural networks
    print("init neural networks")
    model = models.LabelPropagation(args)
    model.cuda(0)

    # Optimizer and LR schedule.
    model_optim = torch.optim.Adam(model.parameters(), lr=lr)
    model_scheduler = StepLR(model_optim, step_size=step_size, gamma=gamma)

    # Resume from a saved checkpoint when requested.
    if iters > 0:
        model.load_state_dict(
            torch.load('checkpoints/%s/models/%s_%d_model.t7' %
                       (args['exp_name'], alg, iters)))
        print('Loading Parameters from %s: %d' % (args['exp_name'], iters))

    # Step 3: Train and validation
    print("Training...")
    best_acc = 0.0
    best_loss = np.inf
    wait = 0

    for ep in range(iters, n_epochs):
        loss_tr, acc_tr = [], []
        loss_val, acc_val = [], []

        for epi in tqdm(range(n_episodes), desc='train_epoc:{}'.format(ep)):
            # NOTE(review): passing an explicit step index to scheduler.step()
            # is deprecated in recent PyTorch; kept to preserve the original
            # per-episode LR schedule.
            model_scheduler.step(ep * n_episodes + epi)

            model.train()

            # Sample and prepare the next training episode.
            support, s_labels, query, q_labels, unlabel = \
                loader_train.next_data(n_way, n_shot, n_query)
            support, s_onehot, query, q_onehot = _prep_episode(
                support, s_labels, query, q_labels, n_way, n_shot, n_query)

            inputs = [support.cuda(0), s_onehot.cuda(0),
                      query.cuda(0), q_onehot.cuda(0)]
            loss, acc = model(inputs)
            loss_tr.append(loss.item())
            acc_tr.append(acc.item())

            model.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(model.parameters(), 4.0)
            model_optim.step()

        for epi in tqdm(range(n_episodes), desc='val epoc:{}'.format(ep)):
            model.eval()

            # Validation uses the test-time episode configuration.
            support, s_labels, query, q_labels, unlabel = \
                loader_val.next_data(n_test_way, n_test_shot, n_test_query)
            support, s_onehot, query, q_onehot = _prep_episode(
                support, s_labels, query, q_labels,
                n_test_way, n_test_shot, n_test_query)

            with torch.no_grad():
                inputs = [support.cuda(0), s_onehot.cuda(0),
                          query.cuda(0), q_onehot.cuda(0)]
                loss, acc = model(inputs)
                loss_val.append(loss.item())
                acc_val.append(acc.item())

        print('epoch:{}, loss_tr:{:.5f}, acc_tr:{:.5f}, loss_val:{:.5f}, acc_val:{:.5f}'
              .format(ep, np.mean(loss_tr), np.mean(acc_tr),
                      np.mean(loss_val), np.mean(acc_val)))

        # Model Save and Stop Criterion: an improvement on EITHER metric
        # counts as a new best and resets the patience counter.
        cond1 = (np.mean(acc_val) > best_acc)
        cond2 = (np.mean(loss_val) < best_loss)
        if cond1 or cond2:
            best_acc = np.mean(acc_val)
            best_loss = np.mean(loss_val)
            print('best val loss:{:.5f}, acc:{:.5f}'.format(best_loss, best_acc))
            torch.save(model.state_dict(),
                       'checkpoints/%s/models/%s_%d_model.t7' %
                       (args['exp_name'], alg, (ep + 1) * n_episodes))
            # Context manager replaces the original open()/close() pair.
            with open('checkpoints/' + args['exp_name'] + '/log.txt', 'a') as f:
                print('{} {:.5f} {:.5f}'.format((ep + 1) * n_episodes,
                                                best_loss, best_acc), file=f)
            wait = 0
        else:
            wait += 1
            # Periodic checkpoint even without improvement.
            if ep % 100 == 0:
                torch.save(model.state_dict(),
                           'checkpoints/%s/models/%s_%d_model.t7' %
                           (args['exp_name'], alg, (ep + 1) * n_episodes))
                with open('checkpoints/' + args['exp_name'] + '/log.txt', 'a') as f:
                    print('{} {:.5f} {:.5f}'.format((ep + 1) * n_episodes,
                                                    np.mean(loss_val),
                                                    np.mean(acc_val)), file=f)

        # BUG FIX: the original guard was `wait > patience and ep > n_epochs`,
        # but ep < n_epochs always holds inside this loop, so early stopping
        # could never trigger and `patience` was dead. Stop on patience alone.
        if wait > patience:
            break
def main():
    """Evaluate a trained LabelPropagation model on the test split.

    Reads its configuration from module-level globals (seed, dataset,
    n_examples, n_episodes, pkl, args, iters, alg, part, repeat,
    n_test_episodes, n_test_way, n_test_shot, n_test_query) — presumably
    populated by a CLI parser elsewhere in this file; confirm against the
    argument parser.

    Prints per-run accuracy statistics and the max/average over all runs.
    """
    # init dataloader
    print("init data loader")
    args_data = {'x_dim': '84,84,3', 'ratio': 1.0, 'seed': seed}
    print('seed:', seed)
    if dataset == 'mini':
        loader_test = dataset_mini(n_examples, n_episodes, 'test', args_data)
    elif dataset == 'tiered':
        loader_test = dataset_tiered(n_examples, n_episodes, 'test', args_data)
    else:
        # Fail fast instead of a NameError on loader_test below.
        raise ValueError('unknown dataset: %s' % dataset)

    if not pkl:
        loader_test.load_data()
    else:
        loader_test.load_data_pkl()

    # Step 2: init neural networks
    print("init neural networks")
    model = models.LabelPropagation(args)
    model.cuda(0)

    # Load the checkpoint at iteration `iters`, or the best saved model.
    if iters > 0:
        model.load_state_dict(
            torch.load('checkpoints/%s/models/%s_%d_model.t7' %
                       (args['exp_name'], alg, iters)))
    else:
        # BUG FIX: the original called torch.load twice on the same file,
        # binding the first (duplicate) result to an unused variable; load once.
        model.load_state_dict(
            torch.load('checkpoints/%s/models/%s_model_best_%s.t7' %
                       (args['exp_name'], alg, part)))
    print('Loading Parameters from %s' % (args['exp_name']))

    # Step 3: build graph
    print("Testing...")
    all_acc = []
    all_std = []
    all_ci95 = []

    for rep in range(repeat):
        list_acc = []
        for epi in tqdm(range(n_test_episodes), desc='test:{}'.format(rep)):
            model.eval()

            # Sample the next test episode.
            support, s_labels, query, q_labels, unlabel = loader_test.next_data(
                n_test_way, n_test_shot, n_test_query, train=False)

            # Merge (way, shot/query) dims, then NHWC -> NCHW tensors.
            support = np.reshape(
                support, (support.shape[0] * support.shape[1],) + support.shape[2:])
            support = torch.from_numpy(np.transpose(support, (0, 3, 1, 2)))
            query = np.reshape(
                query, (query.shape[0] * query.shape[1],) + query.shape[2:])
            query = torch.from_numpy(np.transpose(query, (0, 3, 1, 2)))
            s_labels = torch.from_numpy(
                np.reshape(s_labels, (-1,))).type(torch.LongTensor)
            q_labels = torch.from_numpy(
                np.reshape(q_labels, (-1,))).type(torch.LongTensor)
            s_onehot = torch.zeros(n_test_way * n_test_shot, n_test_way).scatter_(
                1, s_labels.view(-1, 1), 1)
            q_onehot = torch.zeros(n_test_way * n_test_query, n_test_way).scatter_(
                1, q_labels.view(-1, 1), 1)

            with torch.no_grad():
                inputs = [support.cuda(0), s_onehot.cuda(0),
                          query.cuda(0), q_onehot.cuda(0)]
                loss, acc = model(inputs)
                list_acc.append(acc.item())

        mean_acc = np.mean(list_acc)
        std_acc = np.std(list_acc)
        ci95 = 1.96 * std_acc / np.sqrt(n_test_episodes)
        m, ci = mean_confidence_interval(list_acc)
        print('label, acc:{:.4f},std:{:.4f},ci95:{:.4f},ci:{:.4f}'.format(
            mean_acc, std_acc, ci95, ci))
        all_acc.append(mean_acc)
        all_std.append(std_acc)
        all_ci95.append(ci95)

    # Report the best single run and the average across runs.
    ind = np.argmax(all_acc)
    print('Max acc:{:.5f}, std:{:.5f}, ci95: {:.5f}'.format(
        all_acc[ind], all_std[ind], all_ci95[ind]))
    print('Avg over {} runs: mean:{:.5f}, std:{:.5f}, ci95: {:.5f}'.format(
        repeat, np.mean(all_acc), np.mean(all_std), np.mean(all_ci95)))