# assumed imports for the variants below; the module paths follow the common
# MAML-Pytorch project layout and are an assumption, since the original
# files' import sections are not shown here.
import numpy as np
import torch

from meta import Meta
from omniglotNShot import OmniglotNShot


def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,  # meta-batch size, 32
                             n_way=args.n_way,       # n-way, 5
                             k_shot=args.k_spt,      # k-shot for support set, 1
                             k_query=args.k_qry,     # k-shot for query set, 15
                             imgsz=args.imgsz)       # image size, 28 (28x28)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
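# ---------------------------------------------------------------------------
# A minimal entry-point sketch for the main() variants in this file. The
# argument names mirror the args.* attributes the code reads; the default
# values are assumptions (chosen to match the inline comments above:
# 32 tasks, 5-way, 1-shot, 15 queries, 28x28 images), not values taken from
# any of the original scripts.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=40000, help='number of meta-training steps')
    parser.add_argument('--n_way', type=int, default=5, help='classes per task')
    parser.add_argument('--k_spt', type=int, default=1, help='support examples per class')
    parser.add_argument('--k_qry', type=int, default=15, help='query examples per class')
    parser.add_argument('--imgsz', type=int, default=28, help='image side length')
    parser.add_argument('--task_num', type=int, default=32, help='meta-batch size (tasks per step)')
    parser.add_argument('--meta_lr', type=float, default=1e-3, help='outer-loop learning rate')
    parser.add_argument('--update_lr', type=float, default=0.4, help='inner-loop learning rate')
    parser.add_argument('--update_step', type=int, default=5, help='inner-loop updates at train time')
    parser.add_argument('--update_step_test', type=int, default=10, help='inner-loop updates at test time')

    main(parser.parse_args())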
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    db_train = OmniglotNShot(root='E:/meta_learning',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                     torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)  # task_batch=20

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                             torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1] -> [update_step+1,]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
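# ---------------------------------------------------------------------------
# Note on the conversion in the variant above: torch.FloatTensor(x) /
# torch.LongTensor(y) copy the numpy arrays and force float32/int64, whereas
# torch.from_numpy() (used by the other variants) shares memory and keeps the
# numpy dtype. A self-contained sketch of the difference (the shape here is
# illustrative only):
# ---------------------------------------------------------------------------
import numpy as np
import torch

x = np.random.rand(32, 5, 1, 28, 28)  # numpy defaults to float64
a = torch.FloatTensor(x)              # float32 copy
b = torch.from_numpy(x)               # float64 view, shares memory with x
assert a.dtype == torch.float32 and b.dtype == torch.float64
# an equivalent explicit form of the copy-and-cast: torch.from_numpy(x).float()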
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    shared_config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    ]

    nway_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('relu', [True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 2, 2, 1, 0]),
                   ('relu', [True]),
                   ('bn', [64]),
                   ('flatten', []),
                   ('linear', [args.n_way, 64])]

    # reads in image
    discriminator_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 2, 2, 1, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('flatten', []),
                            ('linear', [1, 64])]  # don't use a sigmoid at the end

    # new gen_config
    # starts from image and convolves it into new ones
    gen_config = [
        ('convt2d', [1, 64, 3, 3, 1, 1]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('random_proj', [100, 28, 64]),
        ('convt2d', [128, 64, 3, 3, 1, 1]),
        # ('convt2d', [1, 128, 4, 4, 2, 1]),  # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        # ('encode', [1024, 64*28*28]),
        # ('decode', [64*28*28, 1024]),
        ('relu', [.2, True]),
        ('conv2d', [64, 64, 3, 3, 1, 1]),
        # ('convt2d', [1, 128, 4, 4, 2, 1]),  # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [1, 64, 3, 3, 1, 1]),
        ('sigmoid', [True])
    ]

    # old gen_config
    # gen_config = [
    #     ('random_proj', [100, 512, 64, 7]),  # [latent_dim, emb_size, ch_out, h_out/w_out]
    #     # img: (64, 7, 7)
    #     ('convt2d', [64, 32, 4, 4, 2, 1]),   # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
    #     ('bn', [32]),
    #     ('relu', [True]),
    #     # img: (32, 14, 14)
    #     ('convt2d', [32, 1, 4, 4, 2, 1]),
    #     # img: (1, 28, 28)
    #     ('sigmoid', [True])
    # ]

    # if args.condition_discrim:
    #     discriminator_config = [
    #         ('condition', [512, 1, 6]),  # [emb_dim, emb_ch_out, h_out/w_out]
    #         ('conv2d', [128, 65, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('conv2d', [128, 128, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('flatten', []),
    #         ('linear', [1, 2048])
    #     ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config, gen_config).to(device)

    tmp = filter(lambda x: x.requires_grad, mamlGAN.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(mamlGAN)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.tasks_per_batch,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             img_sz=args.img_sz)

    save_model = not args.no_save
    if save_model:
        now = datetime.now().replace(second=0, microsecond=0)
        path = "results/" + str(now) + "_omni"
        mkdir_p(path)
        file = open(path + '/architecture.txt', 'w+')
        file.write("shared_config = " + json.dumps(shared_config) + "\n" +
                   "nway_config = " + json.dumps(nway_config) + "\n" +
                   "discriminator_config = " + json.dumps(discriminator_config) + "\n" +
                   "gen_config = " + json.dumps(gen_config) + "\n" +
                   "learn_inner_lr = " + str(args.learn_inner_lr) + "\n" +
                   "condition_discrim = " + str(args.condition_discrim))
        file.close()

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print("step " + str(step))
            for key in accs.keys():
                print(key + ": " + str(accs[key]))
            if save_model:
                save_train_accs(path, accs, int(step))

        if step % 500 == 0:
            print("testing")
            accs = []
            imgs = []
            for _ in range(1000 // args.tasks_per_batch):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc, ims = mamlGAN.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)
                    imgs.append(x_spt_one.cpu().detach().numpy())
                    imgs.append(ims.cpu().detach().numpy())
                    if args.single_fast_test:
                        break
                if args.single_fast_test:
                    break

            # [b, update_steps+1] -> [update_steps+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)

            if save_model:
                save_test_accs(path, accs, int(step))
                imgs = np.array(imgs)
                save_imgs(path, imgs, step)
                torch.save({'model_state_dict': mamlGAN.state_dict()}, path + "/model_step" + str(step))
                # to load, do this:
                # checkpoint = torch.load(path + "/model_step" + str(step))
                # mamlGAN.load_state_dict(checkpoint['model_state_dict'])

            print('Test acc:', accs)
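# ---------------------------------------------------------------------------
# mkdir_p() is called in the MetaGAN variant above but not defined in the
# excerpt. A minimal sketch of the assumed helper (equivalent to `mkdir -p`):
# ---------------------------------------------------------------------------
import os

def mkdir_p(path):
    """Create path, including parents, without failing if it already exists."""
    os.makedirs(path, exist_ok=True)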
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way + 1, 64])]  # note: n_way + 1 output units (one extra class)

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    save_path = os.getcwd() + '/data/omniglot/model_batchsz' + str(args.k_spt) + \
                '_stepsz' + str(args.update_lr) + '_epoch'

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs, al_accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs, '\tAL acc:', al_accs)

        if step % 500 == 0:
            al_accs, accs = [], []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    al_accs.append(maml.al_test(x_qry_one, y_qry_one))
                    accs.append(test_acc)

            # [b, update_step+1]
            # pdb.set_trace()  # debugging leftover, disabled so the test loop runs unattended
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            al_accs = np.array(al_accs).mean(axis=0).astype(np.float16)
            print('AL acc:', al_accs)
            torch.save(maml.state_dict(), save_path + str(step) + "_al.pt")
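# ---------------------------------------------------------------------------
# A minimal sketch of restoring one of the "_al.pt" checkpoints saved by the
# loop above; the helper name and the example step value are hypothetical.
# ---------------------------------------------------------------------------
import torch

def load_al_checkpoint(maml, save_path, step):
    """Restore the Meta weights saved at a given test step, e.g. step=500."""
    maml.load_state_dict(torch.load(save_path + str(step) + "_al.pt"))
    return maml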
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    print(args.task_num)  # 32
    print(args.n_way)     # 5
    print(args.k_spt)     # 1
    print(args.k_qry)     # 15
    print(args.imgsz)     # 28

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # cast labels to int64 for the cross-entropy loss; y_spt/y_qry are
        # already tensors, so .to() avoids the extra copy and the UserWarning
        # that re-wrapping with torch.tensor(...) would produce
        y_spt = y_spt.to(dtype=torch.int64)
        y_qry = y_qry.to(dtype=torch.int64)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
                y_spt = y_spt.to(dtype=torch.int64)
                y_qry = y_qry.to(dtype=torch.int64)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    # debug leftover, disabled: it prints shapes for every single task
                    # print(x_spt_one.size(), y_spt_one.size(), x_qry_one.size(), y_qry_one.size())
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def main(args):
    config = [('conv2d', [64, 1, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]),
              ('relu', [True]),
              ('bn', [64]),
              ('flatten', []),
              ('linear', [args.n_way, 64])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config, device).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

    ##############################
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = db_train.getHoleTrain()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
        maml.prune(x_spt, y_spt, x_qry, y_qry, args.prune_number_one_epoch, args.max_prune_number)

        # fine-tuning; train twice as long after the final prune iteration
        print("start finetuning....")
        finetune_epoch = args.finetune_epoch
        finetune_epoch = finetune_epoch * (2 if i == args.prune_iteration - 1 else 1)
        for step in range(finetune_epoch):
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)
            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        accs = []
        for _ in range(1000 // args.task_num):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                accs.append(test_acc)

        # [b, update_step+1]
        accs = np.array(accs).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
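# ---------------------------------------------------------------------------
# Meta.prune() is not shown in the excerpt above. As an illustration only, a
# minimal sketch of one plausible scheme, global magnitude pruning: zero the
# smallest-magnitude weights and return masks so later fine-tuning can keep
# pruned entries at zero. This is an assumption, not the original method.
# ---------------------------------------------------------------------------
import torch

def magnitude_prune(parameters, prune_number):
    """Zero (at least) the `prune_number` smallest-magnitude weights overall."""
    parameters = list(parameters)
    flat = torch.cat([p.detach().abs().flatten() for p in parameters])
    threshold = flat.kthvalue(prune_number).values  # k-th smallest magnitude
    masks = []
    for p in parameters:
        mask = (p.detach().abs() > threshold).to(p.dtype)
        p.data.mul_(mask)   # zero the pruned weights in place
        masks.append(mask)  # multiply grads by these masks during fine-tuning
    return masks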
def test_progress(args, net, device, viz=None, global_step=0):
    """
    to plot ami/ars/acc with respect to training epochs.
    :param args:
    :param net:
    :param device:
    :param viz:
    :return:
    """
    if args.resume is None:
        print('No ckpt file specified! make sure you are training!')

    exp = args.exp

    if viz is None:
        viz = visdom.Visdom(env='test')
    visualh = VisualH(viz)
    print('Testing now...')

    output_dir = os.path.join(args.test_dir, args.exp)
    # create test_dir
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)
    # create test_dir/exp
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # clustering, visualization and classification
    db_test = OmniglotNShot('db/omniglot',
                            batchsz=1,
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            imgsz=args.imgsz)

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = 0, 0, 0, 0
    acc0, acc1 = [], []

    for batchidx in range(args.test_episode_num):

        spt_x, spt_y, qry_x, qry_y = db_test.next('test')
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)
        assert spt_x.size(0) == 1
        spt_x, spt_y, qry_x, qry_y = spt_x.squeeze(0), spt_y.squeeze(0), qry_x.squeeze(0), qry_y.squeeze(0)

        # we can get the representation before first update, after k update
        # and test the representation on merged(test_spt, test_qry) set
        h_spt0, h_spt1, h_qry0, h_qry1, _, new_net = net.finetuning(spt_x, spt_y, qry_x, qry_y,
                                                                    args.finetuning_steps, None)

        if batchidx == 0:
            visualh.update(h_spt0, h_spt1, h_qry0, h_qry1, spt_y, qry_y, global_step)

        # we will use the acquired representation to cluster.
        # h_spt: [sptsz, h_dim]
        # h_qry: [qrysz, h_dim]
        h_qry0_np = h_qry0.detach().cpu().numpy()
        h_qry1_np = h_qry1.detach().cpu().numpy()
        qry_y_np = qry_y.detach().cpu().numpy()

        h_qry0_pred = cluster.KMeans(n_clusters=args.n_way, random_state=0).fit(h_qry0_np).labels_
        h_qry1_pred = cluster.KMeans(n_clusters=args.n_way, random_state=0).fit(h_qry1_np).labels_
        h_qry0_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry0_pred)
        h_qry0_ars += metrics.adjusted_rand_score(qry_y_np, h_qry0_pred)
        h_qry1_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry1_pred)
        h_qry1_ars += metrics.adjusted_rand_score(qry_y_np, h_qry1_pred)

        h_qry0_cm = metrics.cluster.contingency_matrix(h_qry0_pred, qry_y_np)
        h_qry1_cm = metrics.cluster.contingency_matrix(h_qry1_pred, qry_y_np)
        # viz.heatmap(X=h_qry0_cm, win=args.exp+' h_qry0_cm', opts=dict(title=args.exp+' h_qry0_cm:%d'%batchidx,
        #                                                               colormap='Electric'))
        # viz.heatmap(X=h_qry1_cm, win=args.exp+' h_qry1_cm', opts=dict(title=args.exp+' h_qry1_cm:%d'%batchidx,
        #                                                               colormap='Electric'))

        # return is a list of [acc_step0, acc_step1 ,...]
        acc0.append(net.classify_train(h_spt0, spt_y, h_qry0, qry_y, use_h=True, train_step=args.classify_steps))
        acc1.append(net.classify_train(h_spt1, spt_y, h_qry1, qry_y, use_h=True, train_step=args.classify_steps))

        if batchidx == 0:
            spt_x_hat0 = net.forward_ae(spt_x[:64])
            qry_x_hat0 = net.forward_ae(qry_x[:64])
            spt_x_hat1 = new_net.forward_ae(spt_x[:64])
            qry_x_hat1 = new_net.forward_ae(qry_x[:64])
            viz.images(qry_x[:64], nrow=8, win=exp + 'qry_x', opts=dict(title=exp + 'qry_x'))
            # viz.images(spt_x_hat0, nrow=8, win=exp+'spt_x_hat0', opts=dict(title=exp+'spt_x_hat0'))
            viz.images(qry_x_hat0, nrow=8, win=exp + 'qry_x_hat0', opts=dict(title=exp + 'qry_x_hat0'))
            # viz.images(spt_x_hat1, nrow=8, win=exp+'spt_x_hat1', opts=dict(title=exp+'spt_x_hat1'))
            viz.images(qry_x_hat1, nrow=8, win=exp + 'qry_x_hat1', opts=dict(title=exp + 'qry_x_hat1'))

        if batchidx > 0:
            break

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = h_qry0_ami / (batchidx + 1), h_qry0_ars / (batchidx + 1), \
                                                     h_qry1_ami / (batchidx + 1), h_qry1_ars / (batchidx + 1)
    # [[episode1], [episode2], ...] = [N, steps] => [steps]
    acc0, acc1 = np.array(acc0).mean(axis=0), np.array(acc1).mean(axis=0)

    print('ami:', h_qry0_ami, h_qry1_ami)
    print('ars:', h_qry0_ars, h_qry1_ars)
    viz.line([[h_qry0_ami, h_qry1_ami]], [global_step], win=exp + 'ami_on_qry01', update='append')
    viz.line([[h_qry0_ars, h_qry1_ars]], [global_step], win=exp + 'ars_on_qry01', update='append')

    print('acc:\n', acc0, '\n', acc1)
    viz.line([[acc0[-1], acc1[-1]]], [global_step], win=exp + 'acc_on_qry01', update='append')
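# ---------------------------------------------------------------------------
# The AMI/ARS scores accumulated above come from scikit-learn. A tiny
# self-contained sketch of the same calls on toy labels, to show what the two
# metrics measure: chance-corrected agreement between a clustering and the
# ground truth, invariant to how cluster ids are numbered.
# ---------------------------------------------------------------------------
from sklearn import metrics

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 0, 0, 2, 2]  # same partition, permuted cluster ids
print(metrics.adjusted_mutual_info_score(y_true, y_pred))  # 1.0, perfect match
print(metrics.adjusted_rand_score(y_true, y_pred))         # 1.0, perfect match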
def main(args):

    args = update_args(args)

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    device = torch.device('cuda')

    if args.is_meta:
        # optimizer has been embedded in model.
        net = MetaAE(args)
        # model_parameters = filter(lambda p: p.requires_grad, net.learner.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print('Total params:', params)
        tmp = filter(lambda x: x.requires_grad, net.learner.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)
    else:
        net = AE(args, use_logits=True)
        optimizer = optim.Adam(list(net.encoder.parameters()) + list(net.decoder.parameters()),
                               lr=args.meta_lr)
        tmp = filter(lambda x: x.requires_grad,
                     list(net.encoder.parameters()) + list(net.decoder.parameters()))
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)

    net.to(device)
    print(net)
    print('=' * 15, 'Experiment:', args.exp, '=' * 15)
    print(args)

    if args.h_dim == 2:
        # borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
        h_range = np.rollaxis(np.mgrid[args.h_range:-args.h_range:args.h_nrow * 1j,
                                       args.h_range:-args.h_range:args.h_nrow * 1j], 0, 3)
        # [b, q_h]
        h_manifold = torch.from_numpy(h_range.reshape([-1, 2])).to(device).float()
        print('h_manifold:', h_manifold.shape)
    else:
        h_manifold = None

    # try to resume from ckpt.mdl file
    epoch_start = 0
    if args.resume is not None:
        # ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl
        mdl_file = args.resume
        epoch_start = int(mdl_file.split('_')[-3])
        net.load_state_dict(torch.load(mdl_file))
        print('Resume from:', args.resume, 'epoch/batches:', epoch_start)
    else:
        print('Training/test from scratch...')

    if args.test:
        assert args.resume is not None
        test.test_ft_steps(args, net, device)
        return

    vis = visdom.Visdom(env=args.exp)
    visualh = VisualH(vis)
    vis.line([[0, 0, 0]], [epoch_start], win=args.exp + 'train_loss',
             opts=dict(title=args.exp + 'train_qloss', legend=['loss', '-lklh', 'kld'],
                       xlabel='global_step'))
    # for test_progress
    vis.line([[0, 0]], [epoch_start], win=args.exp + 'acc_on_qry01',
             opts=dict(title=args.exp + 'acc_on_qry01', legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start], win=args.exp + 'ami_on_qry01',
             opts=dict(title=args.exp + 'ami_on_qry01', legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start], win=args.exp + 'ars_on_qry01',
             opts=dict(title=args.exp + 'ars_on_qry01', legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))

    # 1. train
    db_train = OmniglotNShot('db/omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # epoch = batch number here.
    for epoch in range(epoch_start, args.train_episode_num):

        spt_x, spt_y, qry_x, qry_y = db_train.next()
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)

        if args.is_meta:
            # for meta
            loss_optim, losses_q, likelihoods_q, klds_q = net(spt_x, spt_y, qry_x, qry_y)

            if epoch % 300 == 0:
                if args.is_vae:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[losses_q[-1].item(), -likelihoods_q[-1].item(), klds_q[-1].item()]],
                             [epoch], win=args.exp + 'train_loss', update='append')
                    print(epoch)
                    print('loss_q:', torch.stack(losses_q).detach().cpu().numpy().astype(np.float16))
                    print('lkhd_q:', torch.stack(likelihoods_q).detach().cpu().numpy().astype(np.float16))
                    print('klds_q:', torch.stack(klds_q).cpu().detach().numpy().astype(np.float16))
                else:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[losses_q[-1].item(), 0, 0]], [epoch],
                             win=args.exp + 'train_loss', update='append')
                    print(epoch, torch.stack(losses_q).detach().cpu().numpy().astype(np.float16))

        else:
            # for normal vae/ae
            loss_optim, _, likelihood, kld = net(spt_x, spt_y, qry_x, qry_y)

            optimizer.zero_grad()
            loss_optim.backward()
            torch.nn.utils.clip_grad_norm_(list(net.encoder.parameters()) + list(net.decoder.parameters()), 10)
            optimizer.step()

            if epoch % 300 == 0:
                print(epoch, loss_optim.item())
                if not args.is_vae:
                    vis.line([[loss_optim.item(), 0, 0]], [epoch], win='train_loss', update='append')
                else:
                    vis.line([[loss_optim.item(), -likelihood.item(), kld.item()]], [epoch],
                             win='train_loss', update='append')

        if epoch % 3000 == 0:
            # [qrysz, 1, 64, 64] => [qrysz, 1, 64, 64]
            x_hat = net.forward_ae(qry_x[0])
            vis.images(x_hat, nrow=args.k_qry, win='train_x_hat', opts=dict(title='train_x_hat'))

            test.test_progress(args, net, device, vis, epoch)

        # save checkpoint.
        if epoch % 10000 == 0:
            date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            mdl_file = os.path.join(args.ckpt_dir, args.exp + '_%d' % epoch + '_' + date_str + '.mdl')
            torch.save(net.state_dict(), mdl_file)
            print('Saved into ckpt file:', mdl_file)

    # save checkpoint.
    date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    mdl_file = os.path.join(args.ckpt_dir, args.exp + '_%d' % args.epoch + '_' + date_str + '.mdl')
    torch.save(net.state_dict(), mdl_file)
    print('Saved Last state ckpt file:', mdl_file)
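# ---------------------------------------------------------------------------
# The np.mgrid[start:stop:n*1j] expression used for h_range in the function
# above relies on numpy's convention that a complex step means "this many
# evenly spaced points" (like linspace). A tiny sketch with h_range=2 and
# h_nrow=3 (illustrative values only):
# ---------------------------------------------------------------------------
import numpy as np

grid = np.rollaxis(np.mgrid[2:-2:3j, 2:-2:3j], 0, 3)  # shape (3, 3, 2)
print(grid.reshape(-1, 2))  # 9 evenly spaced 2-D latent points from +2 to -2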
def main(args):

    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    logfile = os.path.sep.join(('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json'))
    if args.write_log:
        log_fp = open(logfile, 'w')  # text mode, since JSON strings will be written

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    path = os.sep.join((os.path.dirname(__file__), 'dataset', 'omniglot'))
    db_train = OmniglotNShot(path,
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        # fetch one meta-batch of episodes; the sampling logic lives in the OmniglotNShot class
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
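# ---------------------------------------------------------------------------
# The log file opened at the top of the variant above is never written in the
# excerpt. A minimal sketch of one way it might be used (the helper name and
# record fields are assumptions): append one JSON line per test evaluation.
# ---------------------------------------------------------------------------
import json

def write_log_entry(log_fp, step, accs):
    """Append a single JSON record with the per-update-step test accuracies."""
    record = {'step': int(step), 'test_acc': [float(a) for a in accs]}
    log_fp.write(json.dumps(record) + '\n')
    log_fp.flush()  # keep the log readable while training is still running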
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    torch.backends.cudnn.benchmark = True

    print(args)

    config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, config).to(device)

    db_train = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(x_spt).to(device),
            torch.from_numpy(y_spt).to(device),
            torch.from_numpy(x_qry).to(device),
            torch.from_numpy(y_qry).to(device),
        )
        # x_spt shape: [32, 5, 1, 28, 28]
        # y_spt shape: [32, 5]
        # x_qry shape: [32, 75, 1, 28, 28]
        # y_qry shape: [32, 75]

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next("test")
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(x_spt).to(device),
                    torch.from_numpy(y_spt).to(device),
                    torch.from_numpy(x_qry).to(device),
                    torch.from_numpy(y_qry).to(device),
                )

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)