def main():
    """Meta-train MAML on MiniImagenet and periodically meta-test.

    Trains for ``args.epoch`` meta-iterations; every 2000 iterations
    (offset 999) evaluates on 600 test batches, tracks the best final-step
    accuracy, checkpoints on improvement, and dumps all results to JSON.
    Relies on module-level ``args``, ``Param``, ``plot``, ``inf_get``,
    ``Meta`` and ``MiniImagenet``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # epoch -> per-step accuracy list from meta-test
    best_acc = 0.0
    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)
    # Count trainable parameters (filter consumes lazily; used once below).
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4, drop_last=True)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    # inf_get presumably wraps the loaders into infinite iterators -- TODO confirm.
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)
    for epoch in range(args.epoch):
        support_x, support_y, meta_x, meta_y = train_data.__next__()
        support_x, support_y, meta_x, meta_y = support_x.to(Param.device), support_y.to(Param.device), meta_x.to(Param.device), meta_y.to(Param.device)
        # .mean() reduces the per-GPU losses returned by DataParallel.
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        # Clip by value to keep meta-gradients bounded.
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value = 10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())
        if(epoch % 2000 == 999):  # evaluate at epoch 999, 2999, ...
            ans = None
            # Deep-copy so test-time adaptation cannot touch the live model.
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = test_data.__next__()
                support_x, support_y, qx, qy = support_x.to(Param.device), support_y.to(Param.device), qx.to(Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy, meta_train = False)
                if(ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim = 0)
            # Mean over all 600 evaluation batches, one value per inner step.
            ans = ans.mean(dim = 0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(maml.state_dict(), Param.out_path + 'net_'+ str(epoch) + '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': '+str(ans))
            with open(Param.out_path+'test.json','w') as f:
                json.dump(test_result,f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
def main():
    """Meta-train MAML on synthetic ``Gen`` regression tasks.

    Every 50 epochs saves the weights and fine-tunes on a randomly chosen
    held-out task, printing pre/post inner-loop losses. Relies on
    module-level ``args``, ``Meta`` and ``Gen``.
    """
    print(args)
    device = torch.device('cuda')
    maml = Meta(args).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    trainset = Gen(args.task_num, args.k_spt, args.k_qry)
    # Test set uses 10x the query points for a lower-variance estimate.
    testset = Gen(args.task_num, args.k_spt, args.k_qry * 10)
    for epoch in range(args.epoch):
        # Reshuffle task order each epoch.
        ind = [i for i in range(trainset.xs.shape[0])]
        np.random.shuffle(ind)
        xs, ys = torch.Tensor(trainset.xs[ind]).to(device), torch.Tensor(
            trainset.ys[ind]).to(device)
        xq, yq = torch.Tensor(trainset.xq[ind]).to(device), torch.Tensor(
            trainset.yq[ind]).to(device)
        maml.train()
        # loss is indexable per inner step -- presumably summed over tasks,
        # hence the division by args.task_num below; TODO confirm in Meta.
        loss = maml(xs, ys, xq, yq, epoch)
        print('Epoch: {} Initial loss: {} Train loss: {}'.format(
            epoch, loss[0] / args.task_num, loss[-1] / args.task_num))
        if (epoch + 1) % 50 == 0:
            print("Evaling the model...")
            torch.save(maml.state_dict(), 'save.pt')
            # del(maml)
            # maml = Meta(args).to(device)
            # maml.load_state_dict(torch.load('save.pt'))
            maml.eval()
            # Pick one random held-out task for fine-tuning evaluation.
            i = random.randint(0, testset.xs.shape[0] - 1)
            xs, ys = torch.Tensor(testset.xs[i]).to(device), torch.Tensor(
                testset.ys[i]).to(device)
            xq, yq = torch.Tensor(testset.xq[i]).to(device), torch.Tensor(
                testset.yq[i]).to(device)
            losses, losses_q, logits_q, _ = maml.finetunning(xs, ys, xq, yq)
            print('Epoch: {} Initial loss: {} Test loss: {}'.format(
                epoch, losses_q[0], losses_q[-1]))
query_y) optimizer.zero_grad() euc_loss.backward() optimizer.step() if step % 100 == 0: val_acc = eval(db_val, meta) tb.add_scalar('accuracy', val_acc) print('accuracy:', val_acc, 'best accuracy:', best_val_acc) # update learning rate per epoch # scheduler.step(total_val_loss) if val_acc > best_val_acc: best_val_acc = val_acc torch.save(meta.state_dict(), mdl_file) print('saved to checkpoint:', mdl_file) if val_acc > 0.4: print('now conduct test performance...') mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, k_query=1, batchsz=200, resize=resize) db_test = DataLoader(mini_test, batchsz, shuffle=True) test_acc, _ = eval(db_test, meta) print('>>>>>>>>>>>> test accuracy:', test_acc, '<<<<<<<<<<<<<<')
def main():
    """Meta-train MAML on MiniImagenet (Windows data path variant).

    Builds the standard 4-conv MAML backbone, trains in episodes of
    ``args.task_num`` tasks, evaluates every 500 steps over the whole test
    split and checkpoints afterwards. Relies on module-level ``args``,
    ``Meta`` and ``MiniImagenet``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    # Backbone layer spec: conv2d args are
    # [out_ch, in_ch, k_h, k_w, stride, padding].
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    # batchsz here means total episode number
    mini = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)
    ckpt_dir = "./model/"
    for epoch in range(args.epoch//10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # squeeze(0) drops the DataLoader batch dim of 1.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)
                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
                # save checkpoints
                os.makedirs(ckpt_dir, exist_ok=True)
                print('Saving the model as a checkpoint...')
                torch.save({'epoch': epoch, 'Steps': step, 'model': maml.state_dict()},
                           os.path.join(ckpt_dir, 'checkpoint.pth'))
def main():
    """Meta-train MAML on MiniImagenet, saving step-tagged checkpoints.

    Same 4-conv backbone as the reference implementation; every 500 steps
    evaluates on the test split and saves weights to
    ``<cwd>/data/model_batchsz..._og.pt``. Relies on module-level ``args``,
    ``Meta`` and ``MiniImagenet``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    # conv2d spec: [out_ch, in_ch, k_h, k_w, stride, padding].
    config = [('conv2d', [32, 3, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 1, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    # batchsz here means total episode number
    mini = MiniImagenet('/home/tesca/data/miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('/home/tesca/data/miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_spt) + '_stepsz' + str(args.update_lr) + '_epoch'
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # squeeze(0) drops the DataLoader batch dim of 1.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    # finetunning returns (accs, ...); second value unused here.
                    accs, _ = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)
                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
                torch.save(maml.state_dict(), save_path + str(step) + "_og.pt")
def main():
    """Meta-train MAML on MiniImagenet with validation-driven checkpointing.

    Trains in episodes; every 1000 steps runs the full validation split,
    and when validation improves (or every 5000 steps) saves a checkpoint
    and runs the test split, appending results to per-experiment text
    files. Relies on module-level ``args``, ``Meta``, ``MiniImagenet`` and
    ``cal_conf``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    # conv2d spec: [out_ch, in_ch, k_h, k_w, stride, padding].
    config = [('conv2d', [32, 3, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 1, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    # batchsz here means total episode number
    mini = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='train', n_way=args.n_way, k_shot=args.k_spt,
        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_val = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='val', n_way=args.n_way, k_shot=args.k_spt,
        k_query=args.k_qry, batchsz=600, resize=args.imgsz)
    mini_test = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='test', n_way=args.n_way, k_shot=args.k_spt,
        k_query=args.k_qry, batchsz=600, resize=args.imgsz)
    best_acc = 0.0
    # BUGFIX: os.mkdir('ckpt/<exp>') raised FileNotFoundError when the
    # parent 'ckpt/' directory did not exist; makedirs creates the whole
    # chain and exist_ok avoids the racy exists-then-create pattern.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1,
                        pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 1000 == 0:  # evaluation
                db_val = DataLoader(mini_val, 1, shuffle=True,
                                    num_workers=1, pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_val:
                    # squeeze(0) drops the DataLoader batch dim of 1.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_val.append(accs)
                mean, std, ci95 = cal_conf(np.array(accs_all_val))
                print('Val acc:{}, std:{}. ci95:{}'.format(
                    mean[-1], std[-1], ci95[-1]))
                # NOTE(review): the step%5000 branch overwrites best_acc
                # even when the new mean is lower, so best_acc is "last
                # checkpointed acc", not a true maximum. Preserved as-is.
                if mean[-1] > best_acc or step % 5000 == 0:
                    best_acc = mean[-1]
                    torch.save(
                        maml.state_dict(),
                        'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(
                            args.exp, epoch, step, best_acc))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print(
                            'val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, best_acc, ci95[-1]),
                            file=f)
                    ## Test
                    db_test = DataLoader(mini_test, 1, shuffle=True,
                                         num_workers=1, pin_memory=True)
                    accs_all_test = []
                    for x_spt, y_spt, x_qry, y_qry in db_test:
                        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                            x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                        accs_all_test.append(accs)
                    mean, std, ci95 = cal_conf(np.array(accs_all_test))
                    print('Test acc:{}, std:{}, ci95:{}'.format(
                        mean[-1], std[-1], ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print(
                            'test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, mean[-1], ci95[-1]),
                            file=f)
def main(args):
    """Cross-validated MAML over SGC features of a citation graph.

    For every 2-class combination held out as the meta-test label set,
    meta-trains on the remaining classes and, every 100 steps, evaluates a
    reloaded copy of the saved weights, appending results to
    ``citeseer.txt`` / ``cora.txt``. Relies on module-level ``Meta``,
    ``set_seed``, ``load_citation``, ``sgc_precompute`` and
    ``sgc_data_generator``.
    """
    step = args.step  # number of meta-test repetitions per evaluation
    set_seed(args.seed)
    adj, features, labels = load_citation(args.dataset, args.normalization)
    # Precompute SGC features: repeated neighborhood averaging, degree hops.
    features = sgc_precompute(features, adj, args.degree)
    if args.dataset == 'citeseer':
        node_num = 3327
        class_label = [0, 1, 2, 3, 4, 5]
        combination = list(combinations(class_label, 2))
    elif args.dataset == 'cora':
        node_num = 2708
        class_label = [0, 1, 2, 3, 4, 5, 6]
        combination = list(combinations(class_label, 2))
    # Two-layer linear head on the precomputed features.
    config = [('linear', [args.hidden, features.size(1)]),
              ('linear', [args.n_way, args.hidden])]
    device = torch.device('cuda')
    for i in range(len(combination)):
        print("Cross Validation: {}".format((i + 1)))
        # Fresh model per fold.
        maml = Meta(args, config).to(device)
        test_label = list(combination[i])
        train_label = [n for n in class_label if n not in test_label]
        print('Cross Validation {} Train_Label_List: {} '.format(
            i + 1, train_label))
        print('Cross Validation {} Test_Label_List: {} '.format(
            i + 1, test_label))
        for j in range(args.epoch):
            x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                features, labels, node_num, train_label, args.task_num,
                args.n_way, args.k_spt, args.k_qry)
            accs = maml.forward(x_spt, y_spt, x_qry, y_qry)
            print('Step:', j, '\tMeta_Training_Accuracy:', accs)
            if j % 100 == 0:
                # Snapshot weights, then evaluate a reloaded copy so the
                # training instance is untouched by eval.
                torch.save(maml.state_dict(), 'maml.pkl')
                meta_test_acc = []
                for k in range(step):
                    model_meta_trained = Meta(args, config).to(device)
                    model_meta_trained.load_state_dict(torch.load('maml.pkl'))
                    model_meta_trained.eval()
                    x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                        features, labels, node_num, test_label,
                        args.task_num, args.n_way, args.k_spt, args.k_qry)
                    accs = model_meta_trained.forward(x_spt, y_spt, x_qry, y_qry)
                    meta_test_acc.append(accs)
                if args.dataset == 'citeseer':
                    with open('citeseer.txt', 'a') as f:
                        f.write(
                            'Cross Validation:{}, Step: {}, Meta-Test_Accuracy: {}'
                            .format(
                                i + 1, j,
                                np.array(meta_test_acc).mean(axis=0).astype(
                                    np.float16)))
                        f.write('\n')
                elif args.dataset == 'cora':
                    with open('cora.txt', 'a') as f:
                        f.write(
                            'Cross Validation:{}, Step: {}, Meta-Test_Accuracy: {}'
                            .format(
                                i + 1, j,
                                np.array(meta_test_acc).mean(axis=0).astype(
                                    np.float16)))
                        f.write('\n')
def main():
    """Meta-train MAML on MiniImagenet loaded from pickled arrays.

    Same training/eval scheme as the DataLoader variant but feeds batches
    from ``dataset_mini`` pkl loaders via ``get_data``. Every 2000 epochs
    (offset 999) meta-tests on 600 batches, checkpointing on improvement.
    Relies on module-level ``args``, ``Param``, ``plot``, ``get_data``,
    ``dataset_mini`` and ``Meta``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # epoch -> per-step accuracy list from meta-test
    best_acc = 0.0
    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    # Load pkl dataset
    args_data = {}
    args_data['x_dim'] = "84,84,3"
    args_data['ratio'] = 1.0
    args_data['seed'] = 222
    loader_train = dataset_mini(600, 100, 'train', args_data)
    #loader_val = dataset_mini(600, 100, 'val', args_data)
    loader_test = dataset_mini(600, 100, 'test', args_data)
    loader_train.load_data_pkl()
    #loader_val.load_data_pkl()
    loader_test.load_data_pkl()
    for epoch in range(args.epoch):
        support_x, support_y, meta_x, meta_y = get_data(loader_train)
        support_x, support_y, meta_x, meta_y = support_x.to(
            Param.device), support_y.to(Param.device), meta_x.to(
                Param.device), meta_y.to(Param.device)
        # .mean() reduces the per-GPU losses returned by DataParallel.
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        # Clip by value to keep meta-gradients bounded.
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())
        if (epoch % 2000 == 999):  # evaluate at epoch 999, 2999, ...
            ans = None
            # Deep-copy so test-time adaptation cannot touch the live model.
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = get_data(loader_test)
                support_x, support_y, qx, qy = support_x.to(
                    Param.device), support_y.to(Param.device), qx.to(
                        Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            # Mean over the 600 evaluation batches, one value per inner step.
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
def main():
    """Meta-train MAML on MiniImagenet with per-experiment checkpoints.

    Uses padded 3x3 convs (padding 1) with worker-seeded DataLoaders;
    reseeds NumPy each epoch so episode sampling differs across epochs.
    Every 2500 epochs saves a checkpoint and meta-tests on 600 episodes,
    appending results to ``ckpt/<exp>/test.txt``. Relies on module-level
    ``args``, ``Meta``, ``MiniImagenet``, ``inf_get`` and
    ``worker_init_fn``.
    """
    # Fix torch RNG seeds; NumPy is deliberately reseeded per epoch below.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    # conv2d spec: [out_ch, in_ch, k_h, k_w, stride, padding].
    config = [('conv2d', [32, 3, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root, mode='train', n_way=args.n_way,
                            k_shot=args.k_spt, k_query=args.k_qry,
                            resize=args.imgsz)
    testset = MiniImagenet(root, mode='test', n_way=args.n_way,
                           k_shot=args.k_spt, k_query=args.k_qry,
                           resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num,
                             shuffle=True, num_workers=4,
                             worker_init_fn=worker_init_fn, drop_last=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=True,
                            num_workers=1, worker_init_fn=worker_init_fn,
                            drop_last=True)
    # inf_get presumably wraps the loaders into infinite iterators -- TODO confirm.
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)
    best_acc = 0.0  # NOTE(review): assigned but never read in this block
    # BUGFIX: os.mkdir('ckpt/<exp>') raised FileNotFoundError when the
    # parent 'ckpt/' directory did not exist; makedirs creates the whole
    # chain and exist_ok avoids the racy exists-then-create pattern.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch):
        # Reseed from OS entropy so each epoch samples fresh episodes.
        np.random.seed()
        x_spt, y_spt, x_qry, y_qry = train_data.__next__()
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
            device), x_qry.to(device), y_qry.to(device)
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)
        if epoch % 2500 == 0:  # evaluation
            # save checkpoint
            torch.save(maml.state_dict(),
                       'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = test_data.__next__()
                # squeeze() drops the batch dim of 1 from the test loader.
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze().to(
                    device), y_spt.squeeze().to(device), x_qry.squeeze().to(
                        device), y_qry.squeeze().to(device)
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)
            # [b, update_step+1]
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]),
                      file=f)
def main():
    """Meta-train a NAS-style MAML (architecture + weights) on MiniImagenet.

    Sets up file+console logging, seeds RNGs, snapshots the experiment
    scripts, then alternates ``meta_train``/``meta_test`` per epoch,
    logging the discovered genotype and architecture softmaxes and saving
    a checkpoint (flagged when validation improves). Relies on
    module-level ``args``, ``Saver``, ``TensorboardSummary``, ``Meta``,
    ``MiniImagenet``, ``meta_train`` and ``meta_test``.
    """
    saver = Saver(args)
    # set log: file via basicConfig, plus an echoing console handler.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    # Seed every RNG this script touches.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    # Archive the experiment's source/config files for reproducibility.
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0
    logging.info(args)
    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))
    # batch_size here means total episode number
    # NOTE(review): mini_valid also uses mode='train' -- validation episodes
    # are drawn from the training split; confirm this is intended.
    mini = MiniImagenet(args.data_path, mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry,
                        batch_size=args.batch_size, resize=args.img_size)
    mini_valid = MiniImagenet(args.data_path, mode='train', n_way=args.n_way,
                              k_shot=args.k_spt, k_query=args.k_qry,
                              batch_size=args.test_batch_size,
                              resize=args.img_size)
    train_loader = DataLoader(mini, args.meta_batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    valid_loader = DataLoader(mini_valid, args.meta_test_batch_size,
                              shuffle=True, num_workers=args.num_workers,
                              pin_memory=True)
    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))
        train_accs_theta, train_accs_w = meta_train(train_loader, maml,
                                                    device, epoch, writer)
        logging.info(
            '[Epoch: {}]\t Train acc_theta: {}\t Train acc_w: {}'.format(
                epoch, train_accs_theta, train_accs_w))
        test_accs_theta, test_accs_w = meta_test(valid_loader, maml, device,
                                                 epoch, writer)
        logging.info(
            '[Epoch: {}]\t Test acc_theta: {}\t Test acc_w: {}'.format(
                epoch, test_accs_theta, test_accs_w))
        # Log the currently-discovered architecture.
        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)
        logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))
        # Save the best meta model.
        new_pred = test_accs_w[-1]
        if new_pred > best_pred:
            is_best = True
            best_pred = new_pred
        else:
            is_best = False
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict_w': maml.module.state_dict()
                if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'state_dict_theta': maml.model.arch_parameters(),
                'best_pred': best_pred,
            }, is_best)
def main(args):
    """Meta-train a MAML U-Net on RCWA data with resumable checkpoints.

    Builds an encoder/decoder conv config for ``args.arch == "Unet"``,
    optionally restores a previous checkpoint, then trains with
    ``maml.transference`` and periodically meta-validates with
    ``maml.finetunning``. Relies on module-level ``Meta`` and
    ``RCWA_data_loader``.

    Raises:
        NotImplementedError: if ``args.arch`` is not ``"Unet"``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    config = []
    if args.arch == "Unet":
        # Encoder: NUM_DOWN_CONV stages, channels doubling each stage.
        for block in range(args.NUM_DOWN_CONV):
            out_channels = (2**block) * args.HIDDEN_DIM
            if (block == 0):
                config += [(
                    'conv2d', [out_channels, args.imgc, 3, 3, 1, 1]
                )  # out_c, in_c, k_h, k_w, stride, padding, also only conv, without bias
                           ]
            else:
                config += [
                    ('conv2d', [out_channels, out_channels // 2, 3, 3, 1,
                                1]),  # out_c, in_c, k_h, k_w, stride, padding
                ]
            config += [
                ('leakyrelu', [0.2, False]),  # alpha; if true then executes relu in place
                ('bn', [out_channels])
            ]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('max_pool2d', [2, 2, 0])]  # kernel_size, stride, padding
        # Decoder: upsample + convs; in_channels = out*3 presumably due to
        # skip-connection concatenation -- TODO confirm in Meta/Learner.
        for block in range(args.NUM_DOWN_CONV - 1):
            out_channels = (2**(args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
            in_channels = out_channels * 3
            config += [('upsample', [2])]
            config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
        config += [
            ('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])
        ]  # all the conv2d before are without bias, and this conv_b is with bias
    else:
        # BUGFIX: `raise ("...")` raised a TypeError (exceptions must derive
        # from BaseException); raise a proper exception instead.
        raise NotImplementedError(
            "architectures other than Unet hasn't been added!!")
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)
    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    EPOCH_SAVE_INTERVAL = 5
    model_path = "/scratch/users/chenkaim/pytorch-models/pytorch_" + args.model_name + "_k_shot_" + str(
        args.k_spt) + "_task_num_" + str(args.task_num) + "_meta_lr_" + str(
            args.meta_lr) + "_inner_lr_" + str(
                args.update_lr) + "_num_inner_updates_" + str(args.update_step)
    # makedirs(exist_ok=True) also creates missing parent directories.
    os.makedirs(model_path, exist_ok=True)
    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/epoch_" + str(args.continue_epoch) + ".pt")
        checkpoint = torch.load(model_path + "/epoch_" +
                                str(args.continue_epoch) + ".pt")
        print(checkpoint.keys())
        print(checkpoint.items())
        maml.load_state_dict(checkpoint['state_dict'])
        maml.lr_scheduler.load_state_dict(checkpoint['scheduler'])
        maml.meta_optim.load_state_dict(checkpoint['optimizer'])
        start_epoch = args.continue_epoch
    db = RCWA_data_loader(batchsz=args.task_num, n_way=args.n_way,
                          k_shot=args.k_spt, k_query=args.k_qry,
                          imgsz=args.imgsz, data_folder=args.data_folder)
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            # BUGFIX: the checkpoint previously stored only state_dict(),
            # but the restore path above reads 'state_dict'/'scheduler'/
            # 'optimizer' keys, so --continue_train crashed with KeyError.
            # Save the matching dict so training is actually resumable.
            torch.save(
                {
                    'state_dict': maml.state_dict(),
                    'scheduler': maml.lr_scheduler.state_dict(),
                    'optimizer': maml.meta_optim.state_dict(),
                }, model_path + "/epoch_" + str(step) + ".pt")
        # 70% of the data is used per epoch, in task-sized chunks.
        for itr in range(
                int(0.7 * db.total_data_samples /
                    ((args.k_spt + args.k_qry) * args.task_num))):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
            # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q, ave_trans, min_trans = maml.transference(
                x_spt, y_spt, x_qry, y_qry)
            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-iner-loop test accuracy: %.5f, train_loss: %.5f, ave_trans: %.2f, min_trans: %.2f" % (
                    itr, accs[0], accs[-1], loss_q, ave_trans, min_trans)
                print(print_str)
            if itr % TEST_PRINT_INTERVAL == 0:
                accs = []
                for _ in range(10):  # test
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                        torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)
                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print(
                    'Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                    % (accs[0], accs[-1]))
        # Anneal the meta learning rate once per epoch.
        maml.lr_scheduler.step()
def main():
    """Meta-train MAML on single-channel MiniImagenet-style data, with plots.

    Evaluates on the test split every 50 steps, saves the model whenever
    mean test accuracy improves, and writes accuracy/loss curves to
    ``./drawing/``. Relies on module-level ``args``, ``Meta`` and
    ``MiniImagenet``.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are deterministic
    np.random.seed(222)
    print(args)
    # Backbone: note 1 input channel (grayscale) and a 7040-dim flattened
    # feature, unlike the usual 3-channel / 32*5*5 MiniImagenet config.
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry, batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=100)
    last_accuracy = 0
    # Curves accumulated for the matplotlib figures below.
    plt_train_loss = []
    plt_train_acc = []
    plt_test_loss = []
    plt_test_acc =[]
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1,
                        pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                # Detach before numpy conversion; tensor requires grad.
                d = loss_q.cpu()
                dd = d.detach().numpy()
                plt_train_loss.append(dd)
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)
            if step % 50 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True,
                                     num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # squeeze(0) drops the DataLoader batch dim of 1.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs, loss_test= maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)
                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                avg_loss = np.mean(np.array(loss_all_test))
                plt_test_loss.append(avg_loss)
                print('Test acc:', accs)
                test_accuracy = np.mean(np.array(accs))
                if test_accuracy > last_accuracy:
                    # save networks
                    torch.save(maml.state_dict(), str(
                        "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(
                            args.k_spt) + "shot.pkl"))
                    last_accuracy = test_accuracy
    # NOTE(review): plotting reconstructed at function level (after the
    # training loops); confirm original indentation intent.
    plt.figure()
    plt.title("testing info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_test_loss, label='Loss')
    plt.plot(plt_test_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/test.png')
    plt.show()
    plt.figure()
    plt.title("training info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_train_loss, label='Loss')
    plt.plot(plt_train_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/train.png')
    plt.show()
def main():
    """Meta-train MAML on MiniImagenet with selectable data backends.

    ``args.loader`` chooses between two DataLoader-based datasets (0/1) or
    the pickled ``dataset_mini`` backend (2). Every 2500 epochs meta-tests
    a deep-copied model on 600 batches, checkpointing on improvement and
    dumping all results to JSON. Relies on module-level ``args``,
    ``Param``, ``plot``, ``inf_get``, ``get_data``, ``worker_init_fn``,
    ``dataset_mini`` and ``Meta``.
    """
    # Fix torch RNG seeds; NumPy is deliberately reseeded per epoch below.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}  # epoch -> per-step accuracy list from meta-test
    best_acc = 0.0
    maml = Meta(args, Param.config).to(Param.device)
    # Wrap in DataParallel only when several GPUs were requested.
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)
    # Count trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)
    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet
        trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way,
                                k_shot=args.k_spt, k_query=args.k_qry,
                                resize=args.imgsz)
        testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way,
                               k_shot=args.k_spt, k_query=args.k_qry,
                               resize=args.imgsz)
        trainloader = DataLoader(trainset, batch_size=args.task_num,
                                 shuffle=True, num_workers=4,
                                 worker_init_fn=worker_init_fn,
                                 drop_last=True)
        testloader = DataLoader(testset, batch_size=1, shuffle=True,
                                num_workers=1,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        # inf_get presumably wraps the loaders into infinite iterators -- TODO confirm.
        train_data = inf_get(trainloader)
        test_data = inf_get(testloader)
    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_train = dataset_mini(600, 100, 'train', args_data)
        #loader_val = dataset_mini(600, 100, 'val', args_data)
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_train.load_data_pkl()
        #loader_val.load_data_pkl()
        loader_test.load_data_pkl()
    for epoch in range(args.epoch):
        # Reseed from OS entropy so each epoch samples fresh episodes.
        np.random.seed()
        if args.loader in [0, 1]:
            support_x, support_y, meta_x, meta_y = train_data.__next__()
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, meta_x, meta_y = get_data(loader_train)
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)
        # .mean() reduces per-GPU losses when DataParallel is active.
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        # Clip by value to keep meta-gradients bounded.
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())
        if (epoch % 2500 == 0):  # evaluation
            ans = None
            # Deep-copy so test-time adaptation cannot touch the live model.
            maml_clone = deepcopy(maml)
            for _ in range(600):
                if args.loader in [0, 1]:
                    support_x, support_y, qx, qy = test_data.__next__()
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)
                elif args.loader == 2:
                    support_x, support_y, qx, qy = get_data(loader_test)
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            # Mean over the 600 evaluation batches, one value per inner step.
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 0):
            plot.flush()
        plot.tick()