def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way, 256])]

    # device = torch.device('cuda')
    device = torch.device('cpu')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = SvhnNShot(batchsz=args.task_num,
                         n_way=args.n_way,
                         k_shot=args.k_spt,
                         k_query=args.k_qry,
                         imgsz=args.imgsz)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt.long(), x_qry, y_qry.long())

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one.long(), x_qry_one, y_qry_one.long())
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
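# Shape check for the SVHN config above: with 32x32 inputs, three 3x3/stride-2
# convolutions and one 2x2/stride-1 convolution (all unpadded) shrink the map
# 32 -> 15 -> 7 -> 3 -> 2, so the flattened feature is 64 * 2 * 2 = 256, which
# is exactly the linear layer's input size. A minimal sketch (not part of the
# original script) that verifies the arithmetic:
def _conv_out(size, kernel, stride, padding=0):
    # standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def _check_svhn_flatten_dim():
    size = 32
    for kernel, stride in [(3, 2), (3, 2), (3, 2), (2, 1)]:
        size = _conv_out(size, kernel, stride)
    assert 64 * size * size == 256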
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [32, 3, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 1, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    ckpt_dir = "./checkpoint_miniimage.pth"
    print("Load trained model")
    ckpt = torch.load(ckpt_dir)
    maml.load_state_dict(ckpt['model'])

    mini_test = MiniImagenet("F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\", mode='test',
                             n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=1, resize=args.imgsz)
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    accs_all_test = []
    for x_spt, y_spt, x_qry, y_qry in db_test:
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)

    # [b, update_step+1]
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
def main():
    torch.manual_seed(121)
    torch.cuda.manual_seed_all(121)
    np.random.seed(121)

    nshot = SinwaveNShot(all_numbers_class=2000, batch_size=20, n_way=5,
                         k_shot=5, k_query=15, root='data')
    maml = Meta(hid_dim=64, meta_lr=1e-3, update_lr=0.004)

    for step in range(10000):
        x_spt, y_spt, x_qry, y_qry, param_spt, param_qry = nshot.next('train')
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt), torch.from_numpy(y_spt), \
                                     torch.from_numpy(x_qry), torch.from_numpy(y_qry)

        loss = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 20 == 0:
            print('step:', step, '\ttraining loss:', loss)

        if step % 500 == 0:
            loss = []
            for _ in range(1000 // 20):
                # test
                x_spt, y_spt, x_qry, y_qry, param_spt, param_qry = nshot.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt), torch.from_numpy(y_spt), \
                                             torch.from_numpy(x_qry), torch.from_numpy(y_qry)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one, param_spt_one, param_qry_one in \
                        zip(x_spt, y_spt, x_qry, y_qry, param_spt, param_qry):
                    test_loss = maml.finetunning(x_spt_one, y_spt_one, x_qry_one,
                                                 y_qry_one, param_spt_one, param_qry_one)
                    loss.append(test_loss)

            # [b, update_step+1]
            loss = np.array(loss).mean(axis=0).astype(np.float16)
            print('Test loss:', loss)
def main():
    print(args)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    trainset = Gen(args.task_num, args.k_spt, args.k_qry)
    testset = Gen(args.task_num, args.k_spt, args.k_qry * 10)

    for epoch in range(args.epoch):
        ind = [i for i in range(trainset.xs.shape[0])]
        np.random.shuffle(ind)
        xs, ys = torch.Tensor(trainset.xs[ind]).to(device), torch.Tensor(trainset.ys[ind]).to(device)
        xq, yq = torch.Tensor(trainset.xq[ind]).to(device), torch.Tensor(trainset.yq[ind]).to(device)

        maml.train()
        loss = maml(xs, ys, xq, yq, epoch)
        print('Epoch: {} Initial loss: {} Train loss: {}'.format(
            epoch, loss[0] / args.task_num, loss[-1] / args.task_num))

        if (epoch + 1) % 50 == 0:
            print("Evaluating the model...")
            torch.save(maml.state_dict(), 'save.pt')
            # del(maml)
            # maml = Meta(args).to(device)
            # maml.load_state_dict(torch.load('save.pt'))
            maml.eval()

            i = random.randint(0, testset.xs.shape[0] - 1)
            xs, ys = torch.Tensor(testset.xs[i]).to(device), torch.Tensor(testset.ys[i]).to(device)
            xq, yq = torch.Tensor(testset.xq[i]).to(device), torch.Tensor(testset.yq[i]).to(device)

            losses, losses_q, logits_q, _ = maml.finetunning(xs, ys, xq, yq)
            print('Epoch: {} Initial loss: {} Test loss: {}'.format(
                epoch, losses_q[0], losses_q[-1]))
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='train',
                        n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='test',
                             n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)
    ckpt_dir = "./model/"

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)

                # save checkpoints
                os.makedirs(ckpt_dir, exist_ok=True)
                print('Saving the model as a checkpoint...')
                torch.save({'epoch': epoch, 'Steps': step, 'model': maml.state_dict()},
                           os.path.join(ckpt_dir, 'checkpoint.pth'))
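# The (name, params) config tuples used throughout these scripts are consumed by
# the repository's Learner, which keeps parameters functional for the inner loop.
# The sketch below is NOT that Learner; it is a minimal illustration, assuming
# only the parameter order visible in the tuples themselves:
# conv2d = [ch_out, ch_in, k_h, k_w, stride, padding],
# max_pool2d = [kernel, stride, padding], linear = [out_features, in_features].
import torch.nn as nn

def build_from_config(config):
    layers = []
    for name, params in config:
        if name == 'conv2d':
            ch_out, ch_in, k_h, k_w, stride, pad = params
            layers.append(nn.Conv2d(ch_in, ch_out, (k_h, k_w), stride=stride, padding=pad))
        elif name == 'relu':
            layers.append(nn.ReLU(inplace=params[0]))
        elif name == 'bn':
            layers.append(nn.BatchNorm2d(params[0]))
        elif name == 'max_pool2d':
            kernel, stride, pad = params
            layers.append(nn.MaxPool2d(kernel, stride=stride, padding=pad))
        elif name == 'flatten':
            layers.append(nn.Flatten())
        elif name == 'linear':
            out_features, in_features = params
            layers.append(nn.Linear(in_features, out_features))
        else:
            raise NotImplementedError(name)
    return nn.Sequential(*layers)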
def main():
    mem_usage = memory_usage(-1, interval=.5, timeout=1)
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    root = args.data_dir
    feat = np.load(root + 'features.npy', allow_pickle=True)
    with open(root + '/graph_dgl.pkl', 'rb') as f:
        dgl_graph = pickle.load(f)

    if args.task_setup == 'Disjoint':
        with open(root + 'label.pkl', 'rb') as f:
            info = pickle.load(f)
    elif args.task_setup == 'Shared':
        if args.task_mode == 'True':
            root = root + '/task' + str(args.task_n) + '/'
        with open(root + 'label.pkl', 'rb') as f:
            info = pickle.load(f)

    total_class = len(np.unique(np.array(list(info.values()))))
    print('There are {} classes'.format(total_class))

    if args.task_setup == 'Disjoint':
        labels_num = args.n_way
    elif args.task_setup == 'Shared':
        labels_num = total_class

    if len(feat.shape) == 2:
        # single graph, to make it compatible with multiple-graph retrieval.
        feat = [feat]

    config = [('GraphConv', [feat[0].shape[1], args.hidden_dim])]
    if args.h > 1:
        config = config + [('GraphConv', [args.hidden_dim, args.hidden_dim])] * (args.h - 1)
    config = config + [('Linear', [args.hidden_dim, labels_num])]

    if args.link_pred_mode == 'True':
        config.append(('LinkPred', [True]))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    max_acc = 0
    model_max = copy.deepcopy(maml)

    db_train = Subgraphs(root, 'train', info, n_way=args.n_way, k_shot=args.k_spt,
                         k_query=args.k_qry, batchsz=args.batchsz, args=args,
                         adjs=dgl_graph, h=args.h)
    db_val = Subgraphs(root, 'val', info, n_way=args.n_way, k_shot=args.k_spt,
                       k_query=args.k_qry, batchsz=100, args=args, adjs=dgl_graph, h=args.h)
    db_test = Subgraphs(root, 'test', info, n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=100, args=args, adjs=dgl_graph, h=args.h)

    print('------ Start Training ------')
    s_start = time.time()
    max_memory = 0

    for epoch in range(args.epoch):
        db = DataLoader(db_train, args.task_num, shuffle=True,
                        num_workers=args.num_workers, pin_memory=True, collate_fn=collate)
        s_f = time.time()
        for step, (x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry) in enumerate(db):
            nodes_len = 0
            if step >= 1:
                data_loading_time = time.time() - s_r
            else:
                data_loading_time = time.time() - s_f
            s = time.time()

            # x_spt: a list of #task_num tasks, where each task is a mini-batch of k-shot * n_way subgraphs
            # y_spt: a list of #task_num lists of labels. Each list is of length k-shot * n_way int.
            nodes_len += sum([sum([len(j) for j in i]) for i in n_spt])
            accs = maml(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
            max_memory = max(max_memory, float(psutil.virtual_memory().used / (1024**3)))

            if step % args.train_result_report_steps == 0:
                print('Epoch:', epoch + 1,
                      ' Step:', step,
                      ' training acc:', str(accs[-1])[:5],
                      ' time elapsed:', str(time.time() - s)[:5],
                      ' data loading takes:', str(data_loading_time)[:5],
                      ' Memory usage:', str(float(psutil.virtual_memory().used / (1024**3)))[:5])
            s_r = time.time()

        # validation per epoch
        db_v = DataLoader(db_val, 1, shuffle=True, num_workers=args.num_workers,
                          pin_memory=True, collate_fn=collate)
        accs_all_test = []
        for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_v:
            accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry,
                                    n_spt, n_qry, g_spt, g_qry, feat)
            accs_all_test.append(accs)

        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Epoch:', epoch + 1, ' Val acc:', str(accs[-1])[:5])
        if accs[-1] > max_acc:
            max_acc = accs[-1]
            model_max = copy.deepcopy(maml)

    # test the final model
    db_t = DataLoader(db_test, 1, shuffle=True, num_workers=args.num_workers,
                      pin_memory=True, collate_fn=collate)
    accs_all_test = []
    for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t:
        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry,
                                n_spt, n_qry, g_spt, g_qry, feat)
        accs_all_test.append(accs)

    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', str(accs[-1])[:5])

    # test the early-stopped (best-validation) model on a fresh accumulator
    accs_all_test = []
    for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t:
        accs = model_max.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry,
                                     n_spt, n_qry, g_spt, g_qry, feat)
        accs_all_test.append(accs)

    # torch.save(model_max.state_dict(), './model.pt')
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Early Stopped Test acc:', str(accs[-1])[:5])
    print('Total Time:', str(time.time() - s_start)[:5])
    print('Max Memory:', str(max_memory)[:5])
def main():
    # print(args)
    TARGET_MODEL = 3

    config = [
        ('conv2d', [16, 1, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [16]),
        ('conv2d', [32, 16, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('conv2d', [64, 32, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [64]),
        ('convt2d', [64, 32, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('convt2d', [32, 16, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [16]),
        ('convt2d', [16, 8, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [8]),
        ('convt2d', [8, 1, 3, 3, 1, 1]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    # print('Total trainable tensors:', num)

    # initiate different datasets
    minis = []
    for i in range(args.task_num):
        path = osp.join("./zoo_cw_grad_mnist/train", MODELS[i] + "_mnist.npy")
        mini = mnist(path, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                     k_query=args.k_qry, batchsz=100, resize=args.imgsz)
        db = DataLoader(mini, args.batchsize, shuffle=True, num_workers=0, pin_memory=True)
        minis.append(db)

    path_test = osp.join("./zoo_cw_grad_mnist/test", MODELS[TARGET_MODEL] + "_mnist.npy")
    mini_test = mnist(path_test, mode='test', n_way=1, k_shot=args.k_spt,
                      k_query=args.k_qry, batchsz=100, resize=args.imgsz)
    mini_test = DataLoader(mini_test, 10, shuffle=True, num_workers=0, pin_memory=True)

    # start training
    step_number = len(minis[0])
    test_step_number = len(mini_test)
    BEST_ACC = 1.0
    target_model = get_target_model(TARGET_MODEL).to(device)

    def save_model(model, acc):
        model_file_path = './checkpoint/mnist'
        if not os.path.exists(model_file_path):
            os.makedirs(model_file_path)
        file_name = str(acc) + 'mnist_' + MODELS[TARGET_MODEL] + '.pt'
        save_model_path = os.path.join(model_file_path, file_name)
        torch.save(model.state_dict(), save_model_path)

    def load_model(model, acc):
        model_checkpoint_path = './checkpoint/mnist/' + str(acc) + 'mnist_' + MODELS[TARGET_MODEL] + '.pt'
        assert os.path.exists(model_checkpoint_path)
        model.load_state_dict(torch.load(model_checkpoint_path))
        return model

    for epoch in range(args.epoch // 100):
        minis_iter = [iter(db) for db in minis]
        mini_test_iter = iter(mini_test)

        if args.resume:
            maml = load_model(maml, "0.7231071")

        for step in range(step_number):
            batch_data = [next(each) for each in minis_iter]
            accs = maml(batch_data, device)

            if (step + 1) % step_number == 0:
                print('Training acc:', accs)
                if accs[0] < BEST_ACC:
                    BEST_ACC = accs[0]
                    save_model(maml, BEST_ACC)

            if (epoch + 1) % 15 == 0 and step == 0:
                # evaluation
                accs_all_test = []
                for i in range(3):
                    test_data = next(mini_test_iter)
                    accs = maml.finetunning(test_data, target_model, device)
                    accs_all_test.append(accs)

                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maml = Meta(args, STSTNet()).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('/home/lf/miniImagenet/', mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('/home/lf/miniImagenet/', mode='test', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, batchsz=100, resize=args.imgsz)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in tqdm(enumerate(db)):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), \
                                         x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
                torch.save(maml.net, 'maml_ststnet.pth')
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    cuda = 'cuda:' + args.gpu_index
    device = torch.device(cuda)
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    print('Total trainable tensors:', num)

    if args.mode == 0:
        mode_val_test = 'val'
        train = 'train'
    else:
        mode_val_test = 'test'
        train = 'train_ts'

    # batchsz here means total episode number
    mini = MiniImagenet('../flower/', mode=train, n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz,
                        cross_val_idx=args.cross_val_idx)
    mini_test = MiniImagenet('../flower/', mode=mode_val_test, n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, batchsz=100,
                             resize=args.imgsz, cross_val_idx=args.cross_val_idx)

    accs_list_tr = []
    accs_list_ts = []
    last_epoch = args.epoch // 10000 - 1

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                accs_list_tr.append(accs)

            is_last_step = (step == 10000 // args.task_num - 1) and (epoch == last_epoch)
            if step % 500 == 0 or is_last_step:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('step:', step, '\ttest acc:', accs)
                accs_list_ts.append(accs)

                if is_last_step:
                    with open('acc_cv2/natural2_' + mode_val_test + '(task_num_' +
                              str(args.task_num) + ').txt', mode='a') as f:
                        f.write(str(accs[-1]) + ',')
def main(args):
    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    logfile = os.path.sep.join(('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json'))
    if args.write_log:
        log_fp = open(logfile, 'wb')

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    path = os.sep.join((os.path.dirname(__file__), 'dataset', 'omniglot'))
    db_train = OmniglotNShot(path, batchsz=args.task_num, n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, imgsz=args.imgsz)

    for step in range(args.epoch):
        # fetch one batch of episode data; how it is assembled is defined in the Omniglot NShot class
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')

    config = []
    if args.arch == "Unet":
        for block in range(args.NUM_DOWN_CONV):
            out_channels = (2**block) * args.HIDDEN_DIM
            if block == 0:
                # out_c, in_c, k_h, k_w, stride, padding; conv only, without bias
                config += [('conv2d', [out_channels, args.imgc, 3, 3, 1, 1])]
            else:
                # out_c, in_c, k_h, k_w, stride, padding
                config += [('conv2d', [out_channels, out_channels // 2, 3, 3, 1, 1])]
            config += [
                ('leakyrelu', [0.2, False]),  # alpha; if True, executes relu in place
                ('bn', [out_channels])
            ]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
            config += [('max_pool2d', [2, 2, 0])]  # kernel_size, stride, padding

        for block in range(args.NUM_DOWN_CONV - 1):
            out_channels = (2**(args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
            in_channels = out_channels * 3
            config += [('upsample', [2])]
            config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]

        # all the conv2d layers above are without bias; this conv2d_b has a bias
        config += [('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])]
    else:
        raise NotImplementedError("architectures other than Unet haven't been added!")

    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = "/scratch/users/chenkaim/pytorch-models/pytorch_" + args.model_name + \
                 "_k_shot_" + str(args.k_spt) + "_task_num_" + str(args.task_num) + \
                 "_meta_lr_" + str(args.meta_lr) + "_inner_lr_" + str(args.update_lr) + \
                 "_num_inner_updates_" + str(args.update_step)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    start_epoch = 0
    if args.continue_train:
        print("Restoring weights from ", model_path + "/epoch_" + str(args.continue_epoch) + ".pt")
        checkpoint = torch.load(model_path + "/epoch_" + str(args.continue_epoch) + ".pt")
        maml = checkpoint['model']
        maml.lr_scheduler = checkpoint['lr_scheduler']
        maml.meta_optim = checkpoint['optimizer']
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num, n_way=args.n_way, k_shot=args.k_spt,
                          k_query=args.k_qry, imgsz=args.imgsz, data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            checkpoint = {
                'epoch': step,
                'model': maml,
                'optimizer': maml.meta_optim,
                'lr_scheduler': maml.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/epoch_" + str(step) + ".pt")

        for itr in range(int(0.7 * db.total_data_samples / ((args.k_spt + args.k_qry) * args.task_num))):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set training=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, " \
                            "post-inner-loop test accuracy: %.5f, train_loss: %.5f" % (
                                itr, accs[0], accs[-1], loss_q)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                accs = []
                for _ in range(10):
                    # test
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                        accs.append(test_acc)

                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print('Meta-validation pre-inner-loop train accuracy: %.5f, '
                      'meta-validation post-inner-loop test accuracy: %.5f' % (accs[0], accs[-1]))

        maml.lr_scheduler.step()
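# A note on the up-path channel arithmetic in the Unet config above (a reading
# of the config, not code from the original script): at up-block b,
# out_channels = 2**(NUM_DOWN_CONV - b - 2) * HIDDEN_DIM. The upsampled feature
# arriving from the coarser scale carries 2 * out_channels channels, and the
# skip connection from the matching down-block carries out_channels more, so
# the first conv of the block sees in_channels = out_channels * 3.
def _up_block_in_channels(hidden_dim, num_down_conv, block):
    out_channels = (2 ** (num_down_conv - block - 2)) * hidden_dim
    from_below = 2 * out_channels  # previous (coarser) scale, one octave more channels
    from_skip = out_channels       # matching down-path feature map
    return from_below + from_skip  # == out_channels * 3, as in the config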
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [32, 3, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 1, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/'
    mini = MiniImagenet(root, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_val = MiniImagenet(root, mode='val', n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry, batchsz=600, resize=args.imgsz)
    mini_test = MiniImagenet(root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=600, resize=args.imgsz)

    best_acc = 0.0
    if not os.path.exists('ckpt/{}'.format(args.exp)):
        os.mkdir('ckpt/{}'.format(args.exp))

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), \
                                         x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 1000 == 0:
                # evaluation
                db_val = DataLoader(mini_val, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_val:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_val.append(accs)

                mean, std, ci95 = cal_conf(np.array(accs_all_val))
                print('Val acc:{}, std:{}. ci95:{}'.format(mean[-1], std[-1], ci95[-1]))

                if mean[-1] > best_acc or step % 5000 == 0:
                    best_acc = mean[-1]
                    torch.save(maml.state_dict(),
                               'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(args.exp, epoch, step, best_acc))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print('val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                              .format(epoch, step, best_acc, ci95[-1]), file=f)

                    ## Test
                    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                    accs_all_test = []
                    for x_spt, y_spt, x_qry, y_qry in db_test:
                        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                        accs_all_test.append(accs)

                    mean, std, ci95 = cal_conf(np.array(accs_all_test))
                    print('Test acc:{}, std:{}, ci95:{}'.format(mean[-1], std[-1], ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print('test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                              .format(epoch, step, mean[-1], ci95[-1]), file=f)
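# cal_conf is not defined in this script. A minimal sketch of what it plausibly
# computes, given how (mean, std, ci95) are unpacked and indexed above: the
# per-update-step mean, standard deviation, and 95% confidence half-width over
# the episode axis of a [num_episodes, update_step + 1] accuracy array.
import numpy as np

def cal_conf(accs):
    mean = accs.mean(axis=0)
    std = accs.std(axis=0)
    ci95 = 1.96 * std / np.sqrt(accs.shape[0])  # normal-approximation 95% CI
    return mean, std, ci95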
def main(args):
    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way, 64])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config, device).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./', batchsz=args.task_num, n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, imgsz=args.imgsz)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

    ##############################
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = db_train.getHoleTrain()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
        maml.prune(x_spt, y_spt, x_qry, y_qry, args.prune_number_one_epoch, args.max_prune_number)

        # fine-tuning; run twice as long on the last prune iteration
        print("start finetuning....")
        finetune_epoch = args.finetune_epoch
        finetune_epoch = finetune_epoch * (2 if i == args.prune_iteration - 1 else 1)
        for step in range(finetune_epoch):
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)
            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        accs = []
        for _ in range(1000 // args.task_num):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                accs.append(test_acc)

        # [b, update_step+1]
        accs = np.array(accs).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    # np.random.seed(222)

    config = [('conv2d', [32, 3, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                           k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4,
                             worker_init_fn=worker_init_fn, drop_last=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=True, num_workers=1,
                            worker_init_fn=worker_init_fn, drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    best_acc = 0.0
    if not os.path.exists('ckpt/{}'.format(args.exp)):
        os.mkdir('ckpt/{}'.format(args.exp))

    for epoch in range(args.epoch):
        np.random.seed()
        x_spt, y_spt, x_qry, y_qry = next(train_data)
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), \
                                     x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)

        if epoch % 2500 == 0:
            # evaluation; save checkpoint first
            torch.save(maml.state_dict(), 'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = next(test_data)
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze().to(device), y_spt.squeeze().to(device), \
                                             x_qry.squeeze().to(device), y_qry.squeeze().to(device)

                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1]
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]), file=f)
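# inf_get and worker_init_fn are assumed helpers, not defined in this script.
# Minimal sketches consistent with how they are used above: inf_get cycles a
# DataLoader forever so the training loop can pull one meta-batch per epoch,
# and worker_init_fn gives each loader worker its own numpy seed (episode
# sampling uses np.random, which forked workers would otherwise share).
import numpy as np
import torch

def inf_get(loader):
    while True:
        for batch in loader:
            yield batch

def worker_init_fn(worker_id):
    # torch assigns each worker a distinct initial seed; reuse it for numpy
    np.random.seed(torch.initial_seed() % 2**32)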
def main():
    torch.manual_seed(222)           # seed the CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are deterministic
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry, batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, batchsz=100)

    last_accuracy = 0
    plt_train_loss = []
    plt_train_acc = []
    plt_test_loss = []
    plt_test_acc = []

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                plt_train_loss.append(loss_q.cpu().detach().numpy())
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)

            if step % 50 == 0:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs, loss_test = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                avg_loss = np.mean(np.array(loss_all_test))
                plt_test_loss.append(avg_loss)
                print('Test acc:', accs)

                test_accuracy = np.mean(np.array(accs))
                if test_accuracy > last_accuracy:
                    # save networks
                    torch.save(maml.state_dict(),
                               str("./models/miniimagenet_maml" + str(args.n_way) + "way_" +
                                   str(args.k_spt) + "shot.pkl"))
                    last_accuracy = test_accuracy

    plt.figure()
    plt.title("testing info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_test_loss, label='Loss')
    plt.plot(plt_test_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/test.png')
    plt.show()

    plt.figure()
    plt.title("training info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_train_loss, label='Loss')
    plt.plot(plt_train_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/train.png')
    plt.show()
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [32, 3, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]),
              ('bn', [32]),
              ('max_pool2d', [2, 1, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('./data/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.img_sz)
    mini_test = MiniImagenet('./data/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.img_sz)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), \
                                         x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
def main():
    start_time = time.time()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    print(args)
    print(argv)
    os.makedirs(args.modelfile.split('/')[0], exist_ok=True)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    mini = MiniImagenet('./dataset/mini-imagenet/', mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry, batchsz=10000, resize=args.imgsz)

    if args.domain == 'mini':
        mini_test = MiniImagenet('./dataset/mini-imagenet/', mode='test', n_way=args.n_way,
                                 k_shot=args.k_spt, k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
        mini_val = MiniImagenet('./dataset/mini-imagenet/', mode='val', n_way=args.n_way,
                                k_shot=args.k_spt, k_query=args.k_qry,
                                batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'cub':
        print("CUB dataset")
        mini_test = MiniImagenet('./dataset/CUB_200_2011/', mode='test', n_way=args.n_way,
                                 k_shot=args.k_spt, k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'traffic':
        print("Traffic dataset")
        mini_test = MiniImagenet('./dataset/GTSRB/', mode='test', n_way=args.n_way,
                                 k_shot=args.k_spt, k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'flower':
        print("flower dataset")
        mini_test = MiniImagenet('./dataset/102flowers/', mode='test', n_way=args.n_way,
                                 k_shot=args.k_spt, k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    else:
        print("Dataset Error")
        return

    if args.mode == 'test':
        count = 0
        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=6, pin_memory=True)
        accs_all_test = []
        for x_spt, y_spt, x_qry, y_qry in db_test:
            print(count)
            count += 1
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

            accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, 'test', args.modelfile,
                                    pertub_scale=args.pertub_scale,
                                    num_ensemble=args.num_ensemble,
                                    fgsm_epsilon=args.fgsm_epsilon)
            accs_all_test.append(accs)

        # [b, update_step+1]
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        np.set_printoptions(linewidth=1000)
        print("Running Time:", time.time() - start_time)
        print(accs)
        return

    for epoch in range(args.epoch // 10000):
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=4, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('epoch:', epoch, 'step:', step, '\ttraining acc:', accs)

            if step % 200 == 0:
                print("Save model", args.modelfile)
                torch.save(maml, args.modelfile)

                db_val = DataLoader(mini_val, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_val:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, 'train_test')
                    accs_all_val.append(accs)

                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                accs_val = np.array(accs_all_val).mean(axis=0).astype(np.float16)
                save_modelfile = "{}_{}_{}_{:0.4f}_{:0.4f}.pth".format(
                    args.modelfile, epoch, step, accs_val[-1], accs[-1])
                print(save_modelfile)
                torch.save(maml, save_modelfile)
                print("Val:", accs_val)
                print("Test:", accs)
def main():
    torch.manual_seed(222)           # seed the CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are deterministic
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # load the checkpoint written by the training script
    path = "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(args.k_spt) + "shot.pkl"
    if os.path.exists(path):
        maml.load_state_dict(torch.load(path))
        print("load model success")

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry, batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='val', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry, batchsz=100)

    test_accuracy = []
    for epoch in range(10):
        # fetch meta_batchsz num of episode each time
        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
        accs_all_test = []
        for x_spt, y_spt, x_qry, y_qry in db_test:
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

            accs, loss_t = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
            accs_all_test.append(accs)

        # [b, update_step+1]
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
        test_accuracy.append(accs[-1])

    average_accuracy = sum(test_accuracy) / len(test_accuracy)
    print("average accuracy:{}".format(average_accuracy))
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    torch.backends.cudnn.benchmark = True

    print(args)

    config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, config).to(device)

    db_train = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(x_spt).to(device),
            torch.from_numpy(y_spt).to(device),
            torch.from_numpy(x_qry).to(device),
            torch.from_numpy(y_qry).to(device),
        )
        # x_spt shape: [32, 5, 1, 28, 28]
        # y_spt shape: [32, 5]
        # x_qry shape: [32, 75, 1, 28, 28]
        # y_qry shape: [32, 75]

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next("test")
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(x_spt).to(device),
                    torch.from_numpy(y_spt).to(device),
                    torch.from_numpy(x_qry).to(device),
                    torch.from_numpy(y_qry).to(device),
                )

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)
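# All of the mains above consume an `args` namespace with the same handful of
# fields. A minimal entry-point sketch for the Omniglot variant directly above;
# the defaults are illustrative assumptions, not values taken from any script here.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=40000)
    parser.add_argument("--n_way", type=int, default=5)
    parser.add_argument("--k_spt", type=int, default=1)       # support shots per class
    parser.add_argument("--k_qry", type=int, default=15)      # query shots per class
    parser.add_argument("--imgsz", type=int, default=28)
    parser.add_argument("--task_num", type=int, default=32)   # meta batch size
    parser.add_argument("--meta_lr", type=float, default=1e-3)
    parser.add_argument("--update_lr", type=float, default=0.4)
    parser.add_argument("--update_step", type=int, default=5)

    args = parser.parse_args()
    main(args)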