def main(args):
    """Meta-train a MAML learner on Omniglot N-way K-shot episodes.

    Builds the conv-net layer config, instantiates ``Meta``, then alternates
    meta-training steps with periodic evaluation on the test split.
    """
    # Pin every RNG so runs are repeatable.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Learner architecture: four conv->relu->bn stages, flatten, linear head.
    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Total trainable parameter count.
    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot(
        'omniglot',
        batchsz=args.task_num,   # meta-batch size, 32
        n_way=args.n_way,        # n-way, 5
        k_shot=args.k_spt,       # k-shot for support set, 1
        k_query=args.k_qry,      # k-shot for query set, 15
        imgsz=args.imgsz)        # image size, 28 (28x28)

    for step in range(args.epoch):
        train_batch = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(a).to(device) for a in train_batch)

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):  # test
                test_batch = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(a).to(device) for a in test_batch)

                # Fine-tune/evaluate one task at a time.
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)  # [b, update_step+1]

            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def main():
    """Parse CLI hyper-parameters and meta-train a MAML model on Omniglot."""
    # Command-line hyper-parameters (values arrive as strings; converted below).
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', help='n way', default=5)
    parser.add_argument('-k', help='k shot', default=1)
    parser.add_argument('-b', help='batch size', default=32)
    parser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = parser.parse_args()

    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)

    # Fixed settings.
    train_lr = 0.4   # inner-loop (task-level) learning rate
    k_query = 15     # query examples per class
    imgsz = 84       # input image resolution fed to the network
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr, device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot', batchsz=meta_batchsz, n_way=n_way,
                       k_shot=k_shot, k_query=k_query, imgsz=imgsz)

    def as_images(arr):
        # numpy [b, setsz, h, w, c] -> float tensor [b, setsz, 3, h, w] on device
        return (torch.from_numpy(arr).float()
                .transpose(2, 4).transpose(3, 4)
                .repeat(1, 1, 3, 1, 1).to(device))

    for step in range(10000000):
        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = as_images(support_x)
        query_x = as_images(query_x)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # test
            pass
def main(args):
    """Meta-train a ``Meta`` learner on Omniglot and periodically test it."""
    # Pin RNGs for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    db_train = OmniglotNShot(root='E:/meta_learning',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    def to_tensors(xs, ys, xq, yq):
        # numpy arrays -> float images / long labels on the training device
        return (torch.FloatTensor(xs).to(device), torch.LongTensor(ys).to(device),
                torch.FloatTensor(xq).to(device), torch.LongTensor(yq).to(device))

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = to_tensors(*db_train.next())

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)  # task_batch=20

        if step % 50 == 0:
            print('step:', step, '\t training acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):  # test
                x_spt, y_spt, x_qry, y_qry = to_tensors(*db_train.next('test'))

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1] -> [update_step+1,]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def main():
    """Meta-train a MetaLearner on Omniglot or mini-ImageNet episodes.

    Fixes applied in review:
      * ``if episode_num % 100 = 0`` was a SyntaxError (assignment inside an
        ``if``); it is now an equality comparison.
      * After an exhausted DataLoader is rebuilt, the batch is re-fetched;
        previously ``batch_test`` was left undefined/stale in that path.
      * ``iter(loader).next()`` (Python 2 idiom) replaced with
        ``next(iter(loader))``, which works on Python 3.
    """
    meta_batchsz = 32
    n_way = 20
    k_shot = 1
    k_query = k_shot
    meta_lr = 1e-3
    # meta_lr = 1
    num_updates = 5
    dataset = 'omniglot'

    if dataset == 'omniglot':
        imgsz = 28
        db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way,
                           k_shot=k_shot, k_query=k_query, imgsz=imgsz)
    elif dataset == 'mini-imagenet':
        imgsz = 84
        # the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use
        # get_batch(train or test) to get different batch.
        # for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader.
        mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                 k_query=k_query, batchsz=1000, resize=imgsz)
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
    else:
        raise NotImplementedError

    meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot,
                       meta_batchsz=meta_batchsz, beta=meta_lr,
                       num_updates=num_updates).cuda()

    tb = SummaryWriter('runs')

    # main loop
    for episode_num in range(1500):

        # 1. train
        if dataset == 'omniglot':
            # NOTE(review): this fetches the 'test' split during training,
            # exactly as the original code did -- confirm whether 'train' was intended.
            support_x, support_y, query_x, query_y = db.get_batch('test')
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        elif dataset == 'mini-imagenet':
            try:
                batch_test = next(iter(db))
            except StopIteration:
                # Loader exhausted: rebuild it and fetch the batch again.
                mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                                    k_query=k_query, batchsz=10000, resize=imgsz)
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
                batch_test = next(iter(db))
            support_x = Variable(batch_test[0]).cuda()
            support_y = Variable(batch_test[1]).cuda()
            query_x = Variable(batch_test[2]).cuda()
            query_y = Variable(batch_test[3]).cuda()

        # backprop has been embeded in forward func.
        if episode_num % 100 == 0:  # was `= 0`: a SyntaxError
            meta.prv_angle = 0
        accs = meta(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == 'omniglot':
                    support_x, support_y, query_x, query_y = db.get_batch('test')
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    support_y = Variable(torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()
                elif dataset == 'mini-imagenet':
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                                 k_query=k_query, batchsz=1000, resize=imgsz)
                        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
                        batch_test = next(iter(db_test))
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()

                # get accuracy
                test_acc = meta.pred(support_x, support_y, query_x, query_y)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc,
                  '\t\ttest acc:%.6f' % test_acc)
            tb.add_scalar('test-acc', test_acc)
            tb.add_scalar('finetune-acc', train_acc)
def main():
    """Meta-train a CSML model on mini-ImageNet (or Omniglot) episodes.

    Fixes applied in review:
      * The test branch guard ``episode_num % 30 == 220`` could never be true
        (a remainder mod 30 is < 30); restored to ``== 0`` so testing runs.
      * Inside the test branch, ``test_acc`` was appended while the call that
        produces it was commented out (NameError when the branch ran); the
        call is restored.
      * After rebuilding the exhausted test loader, the retry batch was drawn
        from ``db`` (the TRAIN loader) instead of ``db_test``; corrected.
      * ``iter(loader).next()`` (Python 2 idiom) replaced with
        ``next(iter(loader))``.
    """
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = "mini-imagenet"

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot(
            "dataset",
            batchsz=meta_batchsz,
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            imgsz=imgsz,
        )
    elif dataset == "mini-imagenet":
        imgsz = 84
        # the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use
        # get_batch(train or test) to get different batch.
        # for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader.
        mini = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
        mini_test = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="test",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=1000,
            resize=imgsz,
        )
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()
    tb = SummaryWriter("runs")

    # main loop
    for episode_num in range(200000):

        # 1. train
        if dataset == "omniglot":
            support_x, support_y, query_x, query_y = db.get_batch("test")
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        elif dataset == "mini-imagenet":
            try:
                batch_train = next(iter(db))
            except StopIteration:
                # Loader exhausted: rebuild and fetch again.
                mini = MiniImagenet(
                    "../../hdd1/meta/mini-imagenet/",
                    mode="train",
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=10000,
                    resize=imgsz,
                )
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
                batch_train = next(iter(db))
            support_x = Variable(batch_train[0])
            support_y = Variable(batch_train[1])
            query_x = Variable(batch_train[2])
            query_y = Variable(batch_train[3])

        print(support_x.size(), support_y.size())

        # backprop has been embeded in forward func.
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:  # was `== 220`: unreachable
            test_accs = []
            for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = db.get_batch("test")
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    support_y = Variable(
                        torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()
                elif dataset == "mini-imagenet":
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        mini_test = MiniImagenet(
                            "../../hdd1/meta/mini-imagenet/",
                            mode="test",
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=1000,
                            resize=imgsz,
                        )
                        db_test = DataLoader(
                            mini_test,
                            meta_batchsz,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                        )
                        # Retry must read from the TEST loader, not `db`.
                        batch_test = next(iter(db_test))
                    support_x = Variable(batch_test[0])
                    support_y = Variable(batch_test[1])
                    query_x = Variable(batch_test[2])
                    query_y = Variable(batch_test[3])

                # get accuracy (was commented out, leaving test_acc undefined)
                test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print(
                "episode:", episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            tb.add_scalar("test-acc", test_acc)
            tb.add_scalar("finetune-acc", train_acc)
def main(args):
    """Meta-train a MetaGAN model on Omniglot episodes.

    Builds layer-config lists (shared trunk, n-way classifier head,
    discriminator, generator), instantiates MetaGAN, then runs the
    meta-training loop with periodic testing, image dumping and checkpointing.
    """
    # Pin RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Trunk shared between the classifier and discriminator heads.
    shared_config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    ]
    # N-way classification head: conv stack ending in a linear layer sized to n_way.
    nway_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('relu', [True]),
                   ('bn', [64]),
                   ('conv2d', [64, 64, 2, 2, 1, 0]),
                   ('relu', [True]),
                   ('bn', [64]),
                   ('flatten', []),
                   ('linear', [args.n_way, 64])]
    # reads in image
    # Discriminator: single-logit output (real/fake score).
    discriminator_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('conv2d', [64, 64, 2, 2, 1, 0]),
                            ('leakyrelu', [.2, True]),
                            ('bn', [64]),
                            ('flatten', []),
                            ('linear', [1, 64])  # don't use a sigmoid at the end
                            ]
    # new gen_config
    # starts from image and convolves it into new ones
    gen_config = [
        ('convt2d', [1, 64, 3, 3, 1, 1]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('random_proj', [100, 28, 64]),
        ('convt2d', [128, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]),  # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        # ('encode', [1024, 64*28*28]),
        # ('decode', [64*28*28, 1024]),
        ('relu', [.2, True]),
        ('conv2d', [64, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]),  # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [1, 64, 3, 3, 1, 1]),
        ('sigmoid', [True])
    ]
    # old gen_config
    # gen_config = [
    #     ('random_proj', [100, 512, 64, 7]),  # [latent_dim, emb_size, ch_out, h_out/w_out]
    #     # img: (64, 7, 7)
    #     ('convt2d', [64, 32, 4, 4, 2, 1]),  # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
    #     ('bn', [32]),
    #     ('relu', [True]),
    #     # img: (32, 14, 14)
    #     ('convt2d', [32, 1, 4, 4, 2, 1]),
    #     # img: (1, 28, 28)
    #     ('sigmoid', [True])
    # ]
    # if args.condition_discrim:
    #     discriminator_config = [
    #         ('condition', [512, 1, 6]),  # [emb_dim, emb_ch_out, h_out/w_out]
    #         ('conv2d', [128, 65, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('conv2d', [128, 128, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('flatten', []),
    #         ('linear', [1, 2048])
    #     ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config,
                      gen_config).to(device)

    # Count trainable parameters for the log.
    tmp = filter(lambda x: x.requires_grad, mamlGAN.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(mamlGAN)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.tasks_per_batch,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             img_sz=args.img_sz)

    # Optionally persist the architecture description and later checkpoints.
    save_model = not args.no_save
    if save_model:
        now = datetime.now().replace(second=0, microsecond=0)
        path = "results/" + str(now) + "_omni"
        mkdir_p(path)
        file = open(path + '/architecture.txt', 'w+')
        file.write("shared_config = " + json.dumps(shared_config) + "\n" +
                   "nway_config = " + json.dumps(nway_config) + "\n" +
                   "discriminator_config = " + json.dumps(discriminator_config) + "\n" +
                   "gen_config = " + json.dumps(gen_config) + "\n" +
                   "learn_inner_lr = " + str(args.learn_inner_lr) + "\n" +
                   "condition_discrim = " + str(args.condition_discrim))
        file.close()

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print("step " + str(step))
            # `accs` here is a dict of named accuracy metrics.
            for key in accs.keys():
                print(key + ": " + str(accs[key]))
            if save_model:
                save_train_accs(path, accs, int(step))

        if step % 500 == 0:
            print("testing")
            accs = []
            imgs = []
            for _ in range(1000 // args.tasks_per_batch):  # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    # finetunning also returns generated images for inspection
                    test_acc, ims = mamlGAN.finetunning(
                        x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)
                    imgs.append(x_spt_one.cpu().detach().numpy())
                    imgs.append(ims.cpu().detach().numpy())
                    if args.single_fast_test:
                        break
                if args.single_fast_test:
                    break

            accs = np.array(accs).mean(axis=0).astype(np.float16)
            if save_model:
                save_test_accs(path, accs, int(step))
                imgs = np.array(imgs)
                save_imgs(path, imgs, step)
                torch.save({'model_state_dict': mamlGAN.state_dict()},
                           path + "/model_step" + str(step))
                # to load, do this:
                # checkpoint = torch.load(path + "/model_step" + str(step))
                # mamlGAN.load_state_dict(checkpoint['model_state_dict'])

            # [b, update_steps+1]
            # NOTE(review): `accs` was already averaged above; this second
            # mean(axis=0) reduces it to a scalar -- confirm this is intended.
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file):
    """
    obey the expriment setting of MAML and Learning2Compare, we randomly sample 600 episodes and 15 query images per query set.

    Evaluates `net` on Omniglot test episodes, printing a running per-batch
    accuracy. If the running mean accuracy after `episodesz` episodes is below
    `threhold`, evaluation stops early; otherwise all 600 episodes are used.
    Saves the model to `mdl_file` when a new best accuracy is reached.

    :param net: model whose forward returns (pred, correct) and exposes eval()
    :param batchsz: episodes per evaluation batch
    :param n_way: classes per episode
    :param k_shot: support examples per class
    :param imgsz: image resolution for the data loader
    :param episodesz: minimum episode count before the early-stop check
    :param threhold: accuracy threshold controlling early stop (sic)
    :param mdl_file: checkpoint path used when best accuracy improves
    :return: (accuracy, sem) mean accuracy and its confidence half-width
    """
    k_query = 15
    db = OmniglotNShot('dataset', batchsz=batchsz, n_way=n_way, k_shot=k_shot,
                       k_query=k_query, imgsz=imgsz)

    accs = []
    episode_num = 0  # record tested num of episodes

    for i in range(600 // batchsz):
        support_x, support_y, query_x, query_y = db.get_batch('test')
        # numpy [b, setsz, h, w, c] -> cuda float [b, setsz, 3, h, w]
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        # we will split query set into 15 splits.
        # query_x : [batch, 15*way, c_, h, w]
        # query_x_b : tuple, 15 * [b, way, c_, h, w]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        # query_y : [batch, 15*way]
        # query_y_b: 15* [b, way]
        query_y_b = torch.chunk(query_y, k_query, dim=1)

        preds = []
        net.eval()
        # we don't need the total acc on 600 episodes, but we need the acc per sets of 15*nway setsz.
        total_correct = 0
        total_num = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size())
            pred, correct = net(support_x, support_y, query_x_mini.contiguous(),
                                query_y_mini, False)
            correct = correct.sum()  # multi-gpu
            # pred: [b, nway]
            preds.append(pred)
            # NOTE: `.data[0]` is the legacy (pre-0.4) PyTorch scalar access.
            total_correct += correct.data[0]
            total_num += query_y_mini.size(0) * query_y_mini.size(1)
        # # 15 * [b, nway] => [b, 15*nway]
        # preds = torch.cat(preds, dim= 1)
        acc = total_correct / total_num
        print('%.5f,' % acc, end=' ')
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        if episode_num > episodesz:
            # test current tested episodes acc.
            acc = np.array(accs).mean()
            if acc >= threhold:
                # if current acc is very high, we conduct all 600 episodes testing.
                continue
            else:
                # current acc is low, just conduct `episodesz` num of episodes.
                break

    # compute the distribution of 600/episodesz episodes acc.
    global best_accuracy
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print('\naccuracy:', accuracy, 'sem:', sem)
    print('<<<<<<<<< accuracy:', accuracy, 'best accuracy:', best_accuracy, '>>>>>>>>')

    # Checkpoint when the mean accuracy improves on the best seen so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print('Saved to checkpoint:', mdl_file)

    return accuracy, sem
def main():
    """Train a NaiveRN relation-style network on Omniglot episodes.

    Parses CLI hyper-parameters, optionally restores a checkpoint, then loops:
    (1) a periodic (currently stubbed) evaluation feeding an LR scheduler,
    (2) one training step, (3) periodic loss logging.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way')
    argparser.add_argument('-k', help='k shot')
    argparser.add_argument('-b', help='batch size')
    argparser.add_argument('-l', help='learning rate', default=1e-3)
    argparser.add_argument('-t', help='threshold to test all episodes', default=0.97)
    args = argparser.parse_args()

    n_way = int(args.n)
    k_shot = int(args.k)
    k_query = 1
    batchsz = int(args.b)
    imgsz = 84
    lr = float(args.l)
    threshold = float(args.t)

    db = OmniglotNShot('dataset', batchsz=batchsz, n_way=n_way, k_shot=k_shot,
                       k_query=k_query, imgsz=imgsz)

    print('Omniglot: no rotate! %d-way %d-shot lr:%f' % (n_way, k_shot, lr))
    net = NaiveRN(n_way, k_shot, imgsz).cuda()
    print(net)

    # Resume from checkpoint if one exists for this (n_way, k_shot) pair.
    mdl_file = 'ckpt/omni%d%d.mdl' % (n_way, k_shot)
    if os.path.exists(mdl_file):
        print('recover from state: ', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    # Report trainable parameter count.
    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Total params:', params)

    # NOTE: `input` shadows the builtin; kept as-is for byte-compatibility.
    input, input_y, query, query_y = db.get_batch(
        'train')  # (batch, n_way*k_shot, img)
    print('get batch:', input.shape, query.shape, input_y.shape, query_y.shape)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    # Halve LR when (stubbed) validation accuracy plateaus.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'max',
                                               factor=0.5,
                                               patience=15,
                                               verbose=True)

    total_train_loss = 0
    for step in range(100000000):

        # 1. test
        if step % 400 == 0:
            # NOTE(review): evaluation is disabled; accuracy is hard-coded to 0,
            # so the plateau scheduler effectively never sees an improvement.
            accuracy, _ = 0, 0  # evaluation(net, batchsz, n_way, k_shot, imgsz, 300, threshold, mdl_file)
            scheduler.step(accuracy)

        # 2. train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        loss = net(support_x, support_y, query_x, query_y)
        # NOTE: `.data[0]` is the legacy (pre-0.4) PyTorch scalar access.
        total_train_loss += loss.data[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. print
        if step % 20 == 0 and step != 0:
            print('%d-way %d-shot %d batch> step:%d, loss:%f' %
                  (n_way, k_shot, batchsz, step, total_train_loss))
            total_train_loss = 0
def main():
    """Meta-train an active-learning MAML variant on Omniglot.

    The classifier head has ``n_way + 1`` outputs (one extra logit beyond the
    episode classes -- presumably for the AL objective; confirm in ``Meta``).
    Periodically evaluates both fine-tuning accuracy and AL accuracy and saves
    a checkpoint.

    Fix applied in review: a leftover ``pdb.set_trace()`` in the evaluation
    path (which froze every 500th step waiting on the debugger) was removed.

    Note: reads the module-level ``args`` rather than taking a parameter.
    """
    # Pin RNGs for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Four conv->relu->bn stages, flatten, linear head with n_way + 1 outputs.
    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way + 1, 64]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the log.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    save_path = os.getcwd() + '/data/omniglot/model_batchsz' + str(
        args.k_spt) + '_stepsz' + str(args.update_lr) + '_epoch'

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs, al_accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs, '\tAL acc:', al_accs)

        if step % 500 == 0:
            al_accs, accs = [], []
            for _ in range(1000 // args.task_num):  # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    al_accs.append(maml.al_test(x_qry_one, y_qry_one))
                    accs.append(test_acc)  # [b, update_step+1]

            # (removed stray pdb.set_trace() that blocked this path)
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            al_accs = np.array(al_accs).mean(axis=0).astype(np.float16)
            print('AL acc:', al_accs)
            torch.save(maml.state_dict(), save_path + str(step) + "_al.pt")
def main(args):
    """Meta-train a MAML learner on Omniglot with int64 label casting.

    Fixes applied in review:
      * ``torch.tensor(existing_tensor, dtype=..., device=...)`` copies the
        tensor and triggers PyTorch's "copy construct" UserWarning; replaced
        with the equivalent ``.long()`` cast on the already-device tensor.
      * Removed leftover per-task debug ``print(...size())`` calls that
        spammed the test loop.
    """
    # Pin RNGs for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Four conv->relu->bn stages, flatten, linear head.
    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the log.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # Echo the episode configuration once.
    print(args.task_num)  # 32
    print(args.n_way)     # 5
    print(args.k_spt)     # 1
    print(args.k_qry)     # 15
    print(args.imgsz)     # 28

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # Labels must be int64 for the loss; cast in place of the old
        # torch.tensor(...) re-construction.
        y_spt = y_spt.long()
        y_qry = y_qry.long()

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):  # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
                y_spt = y_spt.long()
                y_qry = y_qry.long()

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)  # [b, update_step+1]

            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
def main(args):
    """Meta-train a MAML learner on Omniglot, then iteratively prune and
    fine-tune it, reporting test accuracy after each prune iteration.

    Fix applied in review: the fine-tuning loop iterated
    ``range(args.finetune_epoch)``, silently ignoring the local
    ``finetune_epoch`` that is doubled on the final prune iteration; it now
    uses the computed value.
    """
    # Four conv->relu->bn stages, flatten, linear head.
    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64]),
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config, device).to(device)

    # Count trainable parameters for the log.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    def _to_device(x_spt, y_spt, x_qry, y_qry):
        # numpy arrays -> device tensors; labels cast to int64.
        return (torch.from_numpy(x_spt).to(device),
                torch.from_numpy(y_spt).long().to(device),
                torch.from_numpy(x_qry).to(device),
                torch.from_numpy(y_qry).long().to(device))

    def _run_test():
        # Evaluate by fine-tuning on each test task individually.
        accs = []
        for _ in range(1000 // args.task_num):  # test
            x_spt, y_spt, x_qry, y_qry = _to_device(*db_train.next('test'))
            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                    x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                            x_qry_one, y_qry_one)
                accs.append(test_acc)  # [b, update_step+1]
        accs = np.array(accs).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)

    # ------------------------- meta-training -------------------------
    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = _to_device(*db_train.next())

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            _run_test()

    ##############################
    # ---------------------- prune + fine-tune ------------------------
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = _to_device(*db_train.getHoleTrain())
        maml.prune(x_spt, y_spt, x_qry, y_qry,
                   args.prune_number_one_epoch, args.max_prune_number)

        # fine-tuning
        print("start finetuning....")
        finetune_epoch = args.finetune_epoch
        # Double the budget on the final prune iteration.
        finetune_epoch = finetune_epoch * (2 if i == args.prune_iteration - 1 else 1)
        for step in range(finetune_epoch):  # was args.finetune_epoch: ignored the doubling
            x_spt, y_spt, x_qry, y_qry = _to_device(*db_train.next())
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)
            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        _run_test()
def test_progress(args, net, device, viz=None, global_step=0):
    """
    Plot AMI/ARS/acc of the learned representation w.r.t. training epochs.

    Runs test episodes: finetunes `net` on each episode, clusters the query
    representations before (h_qry0) and after (h_qry1) adaptation with KMeans,
    scores the clusterings, trains a classifier on the representations, and
    pushes curves/reconstructions to visdom.

    :param args: CLI namespace (n_way, k_spt, k_qry, imgsz, exp, test_dir, ...)
    :param net: meta model exposing finetuning / classify_train / forward_ae
    :param device: torch device for the episode tensors
    :param viz: visdom instance; a new env='test' one is created if None
    :param global_step: x-coordinate for the visdom curves
    :return: None
    """
    if args.resume is None:
        print('No ckpt file specified! make sure you are training!')

    exp = args.exp
    if viz is None:
        viz = visdom.Visdom(env='test')
    visualh = VisualH(viz)
    print('Testing now...')

    output_dir = os.path.join(args.test_dir, args.exp)
    # create test_dir
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)
    # create test_dir/exp
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # clustering, visualization and classification
    db_test = OmniglotNShot('db/omniglot', batchsz=1, n_way=args.n_way,
                            k_shot=args.k_spt, k_query=args.k_qry,
                            imgsz=args.imgsz)

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = 0, 0, 0, 0
    acc0, acc1 = [], []
    for batchidx in range(args.test_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_test.next('test')
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)
        # batchsz=1, so drop the leading task dimension.
        assert spt_x.size(0) == 1
        spt_x, spt_y, qry_x, qry_y = spt_x.squeeze(0), spt_y.squeeze(
            0), qry_x.squeeze(0), qry_y.squeeze(0)

        # we can get the representation before first update, after k update
        # and test the representation on merged(test_spt, test_qry) set
        h_spt0, h_spt1, h_qry0, h_qry1, _, new_net = net.finetuning(
            spt_x, spt_y, qry_x, qry_y, args.finetuning_steps, None)

        if batchidx == 0:
            visualh.update(h_spt0, h_spt1, h_qry0, h_qry1, spt_y, qry_y,
                           global_step)

        # we will use the acquired representation to cluster.
        # h_spt: [sptsz, h_dim]
        # h_qry: [qrysz, h_dim]
        h_qry0_np = h_qry0.detach().cpu().numpy()
        h_qry1_np = h_qry1.detach().cpu().numpy()
        qry_y_np = qry_y.detach().cpu().numpy()

        h_qry0_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry0_np).labels_
        h_qry1_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry1_np).labels_
        h_qry0_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry0_pred)
        h_qry0_ars += metrics.adjusted_rand_score(qry_y_np, h_qry0_pred)
        h_qry1_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry1_pred)
        h_qry1_ars += metrics.adjusted_rand_score(qry_y_np, h_qry1_pred)

        # BUG FIX: h_qry1_cm previously reused h_qry0_pred (copy-paste), and
        # both matrices passed the torch tensor qry_y instead of qry_y_np.
        # Currently only consumed by the disabled heatmap plots below.
        h_qry0_cm = metrics.cluster.contingency_matrix(h_qry0_pred, qry_y_np)
        h_qry1_cm = metrics.cluster.contingency_matrix(h_qry1_pred, qry_y_np)
        # viz.heatmap(X=h_qry0_cm, win=args.exp+' h_qry0_cm', opts=dict(title=args.exp+' h_qry0_cm:%d'%batchidx,
        #                                                               colormap='Electric'))
        # viz.heatmap(X=h_qry1_cm, win=args.exp+' h_qry1_cm', opts=dict(title=args.exp+' h_qry1_cm:%d'%batchidx,
        #                                                               colormap='Electric'))

        # return is a list of [acc_step0, acc_step1 ,...]
        acc0.append(
            net.classify_train(h_spt0, spt_y, h_qry0, qry_y, use_h=True,
                               train_step=args.classify_steps))
        acc1.append(
            net.classify_train(h_spt1, spt_y, h_qry1, qry_y, use_h=True,
                               train_step=args.classify_steps))

        if batchidx == 0:
            # Reconstructions before (hat0) and after (hat1) finetuning.
            spt_x_hat0 = net.forward_ae(spt_x[:64])
            qry_x_hat0 = net.forward_ae(qry_x[:64])
            spt_x_hat1 = new_net.forward_ae(spt_x[:64])
            qry_x_hat1 = new_net.forward_ae(qry_x[:64])
            viz.images(qry_x[:64], nrow=8, win=exp + 'qry_x',
                       opts=dict(title=exp + 'qry_x'))
            # viz.images(spt_x_hat0, nrow=8, win=exp+'spt_x_hat0', opts=dict(title=exp+'spt_x_hat0'))
            viz.images(qry_x_hat0, nrow=8, win=exp + 'qry_x_hat0',
                       opts=dict(title=exp + 'qry_x_hat0'))
            # viz.images(spt_x_hat1, nrow=8, win=exp+'spt_x_hat1', opts=dict(title=exp+'spt_x_hat1'))
            viz.images(qry_x_hat1, nrow=8, win=exp + 'qry_x_hat1',
                       opts=dict(title=exp + 'qry_x_hat1'))

        # NOTE(review): only the first two episodes are ever evaluated; remove
        # this to average over all args.test_episode_num episodes.
        if batchidx > 0:
            break

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = h_qry0_ami / (batchidx + 1), h_qry0_ars / (batchidx + 1), \
                                                     h_qry1_ami / (batchidx + 1), h_qry1_ars / (batchidx + 1)
    # [[epsode1], [episode2],...] = [N, steps] => [steps]
    acc0, acc1 = np.array(acc0).mean(axis=0), np.array(acc1).mean(axis=0)

    print('ami:', h_qry0_ami, h_qry1_ami)
    print('ars:', h_qry0_ars, h_qry1_ars)
    viz.line([[h_qry0_ami, h_qry1_ami]], [global_step],
             win=exp + 'ami_on_qry01', update='append')
    viz.line([[h_qry0_ars, h_qry1_ars]], [global_step],
             win=exp + 'ars_on_qry01', update='append')
    print('acc:\n', acc0, '\n', acc1)
    viz.line([[acc0[-1], acc1[-1]]], [global_step],
             win=exp + 'acc_on_qry01', update='append')
def main(args):
    """Train a meta autoencoder (MetaAE) or a plain AE/VAE on Omniglot.

    Handles seeding, model construction, optional checkpoint resume, a
    test-only short-circuit, visdom curve setup, the training loop with
    periodic logging/visualization, and checkpoint saving.
    """
    args = update_args(args)
    # Fixed seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    device = torch.device('cuda')
    if args.is_meta:
        # optimizer has been embedded in model.
        net = MetaAE(args)
        # model_parameters = filter(lambda p: p.requires_grad, net.learner.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print('Total params:', params)
        tmp = filter(lambda x: x.requires_grad, net.learner.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)
    else:
        # Plain AE/VAE: optimizer is managed here, over encoder+decoder params.
        net = AE(args, use_logits=True)
        optimizer = optim.Adam(list(net.encoder.parameters()) +
                               list(net.decoder.parameters()),
                               lr=args.meta_lr)
        tmp = filter(
            lambda x: x.requires_grad,
            list(net.encoder.parameters()) + list(net.decoder.parameters()))
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)
    net.to(device)
    print(net)

    print('=' * 15, 'Experiment:', args.exp, '=' * 15)
    print(args)

    if args.h_dim == 2:
        # Build a 2-D grid of latent codes for manifold visualization.
        # borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
        h_range = np.rollaxis(
            np.mgrid[args.h_range:-args.h_range:args.h_nrow * 1j,
                     args.h_range:-args.h_range:args.h_nrow * 1j], 0, 3)
        # [b, q_h]
        h_manifold = torch.from_numpy(h_range.reshape([-1,
                                                       2])).to(device).float()
        print('h_manifold:', h_manifold.shape)
        # NOTE(review): h_manifold is not referenced again in this function —
        # confirm whether it is meant to be passed to the model/plots.
    else:
        h_manifold = None

    # try to resume from ckpt.mdl file
    epoch_start = 0
    if args.resume is not None:
        # ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl
        mdl_file = args.resume
        # Epoch number is the 3rd-from-last '_'-separated field of the name.
        epoch_start = int(mdl_file.split('_')[-3])
        net.load_state_dict(torch.load(mdl_file))
        print('Resume from:', args.resume, 'epoch/batches:', epoch_start)
    else:
        print('Training/test from scratch...')

    if args.test:
        # Test-only mode requires a checkpoint to evaluate.
        assert args.resume is not None
        test.test_ft_steps(args, net, device)
        return

    # Initialize the visdom curves that will be appended to during training.
    vis = visdom.Visdom(env=args.exp)
    visualh = VisualH(vis)
    vis.line([[0, 0, 0]], [epoch_start],
             win=args.exp + 'train_loss',
             opts=dict(title=args.exp + 'train_qloss',
                       legend=['loss', '-lklh', 'kld'],
                       xlabel='global_step'))

    # for test_progress
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'acc_on_qry01',
             opts=dict(title=args.exp + 'acc_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ami_on_qry01',
             opts=dict(title=args.exp + 'ami_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ars_on_qry01',
             opts=dict(title=args.exp + 'ars_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))

    # 1. train
    db_train = OmniglotNShot('db/omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # epoch = batch number here.
    for epoch in range(epoch_start, args.train_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_train.next()
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)

        if args.is_meta:
            # for meta: the model both computes and applies its own update.
            loss_optim, losses_q, likelihoods_q, klds_q = net(
                spt_x, spt_y, qry_x, qry_y)

            # Log every 300 episodes.
            if epoch % 300 == 0:
                if args.is_vae:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[
                        losses_q[-1].item(), -likelihoods_q[-1].item(),
                        klds_q[-1].item()
                    ]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(epoch)
                    print(
                        'loss_q:',
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))
                    print(
                        'lkhd_q:',
                        torch.stack(
                            likelihoods_q).detach().cpu().numpy().astype(
                                np.float16))
                    print(
                        'klds_q:',
                        torch.stack(klds_q).cpu().detach().numpy().astype(
                            np.float16))
                else:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[losses_q[-1].item(), 0, 0]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(
                        epoch,
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))

        else:
            # for normal vae/ae: explicit optimizer step with gradient clipping.
            loss_optim, _, likelihood, kld = net(spt_x, spt_y, qry_x, qry_y)
            optimizer.zero_grad()
            loss_optim.backward()
            torch.nn.utils.clip_grad_norm_(
                list(net.encoder.parameters()) +
                list(net.decoder.parameters()), 10)
            optimizer.step()

            if epoch % 300 == 0:
                print(epoch, loss_optim.item())
                if not args.is_vae:
                    vis.line([[loss_optim.item(), 0, 0]], [epoch],
                             win='train_loss',
                             update='append')
                else:
                    vis.line(
                        [[loss_optim.item(), -likelihood.item(),
                          kld.item()]], [epoch],
                        win='train_loss',
                        update='append')

        if epoch % 3000 == 0:
            # Visualize reconstructions of the first task's query set and run
            # the clustering/classification progress report.
            # [qrysz, 1, 64, 64] => [qrysz, 1, 64, 64]
            x_hat = net.forward_ae(qry_x[0])
            vis.images(x_hat,
                       nrow=args.k_qry,
                       win='train_x_hat',
                       opts=dict(title='train_x_hat'))
            test.test_progress(args, net, device, vis, epoch)

        # save checkpoint.
        if epoch % 10000 == 0:
            date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            mdl_file = os.path.join(
                args.ckpt_dir,
                args.exp + '_%d' % epoch + '_' + date_str + '.mdl')
            torch.save(net.state_dict(), mdl_file)
            print('Saved into ckpt file:', mdl_file)

    # save checkpoint.
    # NOTE(review): the loop above runs to args.train_episode_num but this
    # final checkpoint is named with args.epoch — confirm which attribute the
    # namespace actually defines; they look inconsistent.
    date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    mdl_file = os.path.join(
        args.ckpt_dir,
        args.exp + '_%d' % args.epoch + '_' + date_str + '.mdl')
    torch.save(net.state_dict(), mdl_file)
    print('Saved Last state ckpt file:', mdl_file)
def main(args):
    """Meta-train MAML on Omniglot n-way k-shot classification.

    Optionally opens a per-configuration log file under ./logs, then runs the
    standard MAML training loop with a test evaluation every 500 steps.

    :param args: parsed CLI namespace (n_way, k_spt, k_qry, imgsz, task_num,
        epoch, write_log, ...).
    """
    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    logfile = os.path.sep.join(
        ('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json'))
    # NOTE(review): opened in binary mode ('wb') although the name suggests a
    # JSON text log — confirm any writers encode to bytes.
    # FIX: the handle was previously never closed (resource leak); it is now
    # tracked and closed in the finally block below.
    log_fp = open(logfile, 'wb') if args.write_log else None

    try:
        # Fixed seeds for reproducibility.
        torch.manual_seed(222)
        torch.cuda.manual_seed_all(222)
        np.random.seed(222)

        print(args)

        # Learner: 4x (conv -> relu -> bn), flatten, linear head.
        # conv2d spec: [out_ch, in_ch, kh, kw, stride, padding].
        config = [
            ('conv2d', [64, 1, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 2, 2, 1, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('flatten', []),
            ('linear', [args.n_way, 64])
        ]

        device = torch.device('cuda')
        maml = Meta(args, config).to(device)

        tmp = filter(lambda x: x.requires_grad, maml.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print(maml)
        print('Total trainable tensors:', num)

        path = os.sep.join((os.path.dirname(__file__), 'dataset', 'omniglot'))
        db_train = OmniglotNShot(path,
                                 batchsz=args.task_num,
                                 n_way=args.n_way,
                                 k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 imgsz=args.imgsz)

        for step in range(args.epoch):
            # Fetch the next meta-batch of episodes (see OmniglotNShot.next()).
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 50 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:
                accs = []
                for _ in range(1000 // args.task_num):
                    # test
                    x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)

                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
    finally:
        if log_fp is not None:
            log_fp.close()
def main(args):
    """Entry point: meta-train a MAML learner on Omniglot N-shot tasks.

    Every 50 steps the running training accuracy is printed; every 500 steps
    the model is evaluated on roughly 1000 held-out test tasks.
    """
    # Deterministic seeds; cudnn autotuner is enabled for fixed-size inputs.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    torch.backends.cudnn.benchmark = True

    print(args)

    # Learner layout: four conv-relu-bn stages, then a linear head.
    # conv2d spec: [out_ch, in_ch, kh, kw, stride, padding].
    learner_config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, learner_config).to(device)

    dataset = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    # Report the trainable parameter count.
    trainable = [p for p in maml.parameters() if p.requires_grad]
    param_count = sum(np.prod(p.shape) for p in trainable)
    print(maml)
    print("Total trainable tensors:", param_count)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(arr).to(device) for arr in dataset.next())
        # presumably x_spt is [task_num, n_way*k_spt, 1, 28, 28] and
        # x_qry is [task_num, n_way*k_qry, 1, 28, 28] — TODO confirm
        # against OmniglotNShot.

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        if step % 500 == 0:
            # Evaluate: finetune on each held-out task independently.
            accs = []
            for _ in range(1000 // args.task_num):
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(arr).to(device)
                    for arr in dataset.next("test"))
                for task in zip(x_spt, y_spt, x_qry, y_qry):
                    accs.append(maml.finetunning(*task))
            # accs is [tasks, update_step + 1]; average over tasks.
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)