def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=32)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()

    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 0.4
    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' % (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr, device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot,
                       k_query=k_query, imgsz=imgsz)

    for step in range(10000000):

        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        query_x = torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # test
            pass
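# The chained transpose/repeat conversion above turns an OmniglotNShot numpy batch of shape
# [b, setsz, h, w, 1] into a [b, setsz, 3, h, w] float tensor on the GPU; the same pattern
# recurs in every script below. A hypothetical helper (not in the original scripts, name
# assumed) that captures what it does:
def to_device_batch(x, device):
    x = torch.from_numpy(x).float()        # [b, setsz, h, w, 1]
    x = x.transpose(2, 4).transpose(3, 4)  # -> [b, setsz, 1, h, w]
    return x.repeat(1, 1, 3, 1, 1).to(device)  # gray -> 3 channels: [b, setsz, 3, h, w]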
def main():
    meta_batchsz = 32
    n_way = 20
    k_shot = 1
    k_query = k_shot
    meta_lr = 1e-3
    # meta_lr = 1
    num_updates = 5
    dataset = 'omniglot'

    if dataset == 'omniglot':
        imgsz = 28
        db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot,
                           k_query=k_query, imgsz=imgsz)
    elif dataset == 'mini-imagenet':
        imgsz = 84
        # The dataset loaders differ between omniglot and mini-imagenet: omniglot uses a single
        # loader and get_batch('train'/'test') to switch splits, while mini-imagenet needs two
        # DataLoaders, one for train and one for test.
        mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                 k_query=k_query, batchsz=1000, resize=imgsz)
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
    else:
        raise NotImplementedError

    meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot, meta_batchsz=meta_batchsz,
                       beta=meta_lr, num_updates=num_updates).cuda()

    tb = SummaryWriter('runs')

    # main loop
    for episode_num in range(1500):

        # 1. train
        if dataset == 'omniglot':
            support_x, support_y, query_x, query_y = db.get_batch('test')
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        elif dataset == 'mini-imagenet':
            try:
                batch_train = next(iter(db))
            except StopIteration:
                # the train loader is exhausted: rebuild it and fetch a fresh batch
                mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                                    k_query=k_query, batchsz=10000, resize=imgsz)
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
                batch_train = next(iter(db))
            support_x = Variable(batch_train[0]).cuda()
            support_y = Variable(batch_train[1]).cuda()
            query_x = Variable(batch_train[2]).cuda()
            query_y = Variable(batch_train[3]).cuda()

        # backprop has been embedded in the forward func.
        if episode_num % 100 == 0:
            meta.prv_angle = 0
        accs = meta(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == 'omniglot':
                    support_x, support_y, query_x, query_y = db.get_batch('test')
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    support_y = Variable(torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()
                elif dataset == 'mini-imagenet':
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        # the test loader is exhausted: rebuild it and fetch a fresh batch
                        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                                 k_query=k_query, batchsz=1000, resize=imgsz)
                        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
                        batch_test = next(iter(db_test))
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()

                # get accuracy
                test_acc = meta.pred(support_x, support_y, query_x, query_y)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)
            tb.add_scalar('test-acc', test_acc)
            tb.add_scalar('finetune-acc', train_acc)
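# The try/except StopIteration blocks above rebuild the DataLoader whenever it runs out of
# batches. A hypothetical generator (not part of the original code, names assumed) that wraps
# the same re-creation pattern and makes the intent clearer:
def endless_batches(make_loader):
    # yields batches forever, re-creating the loader each time it is exhausted
    while True:
        for batch in make_loader():
            yield batch

# usage sketch:
# train_batches = endless_batches(lambda: DataLoader(mini, meta_batchsz, shuffle=True,
#                                                    num_workers=4, pin_memory=True))
# support_x, support_y, query_x, query_y = next(train_batches)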
def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threshold, mdl_file):
    """
    Follow the experiment setting of MAML and Learning to Compare: randomly sample 600 episodes
    with 15 query images per query set.
    :param net:
    :param batchsz:
    :return:
    """
    k_query = 15
    db = OmniglotNShot('dataset', batchsz=batchsz, n_way=n_way, k_shot=k_shot,
                       k_query=k_query, imgsz=imgsz)

    accs = []
    episode_num = 0  # record tested num of episodes

    for i in range(600 // batchsz):
        support_x, support_y, query_x, query_y = db.get_batch('test')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        # split the query set into 15 chunks.
        # query_x   : [batch, 15*way, c_, h, w]
        # query_x_b : tuple, 15 * [b, way, c_, h, w]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        # query_y   : [batch, 15*way]
        # query_y_b : 15 * [b, way]
        query_y_b = torch.chunk(query_y, k_query, dim=1)

        preds = []
        net.eval()
        # we don't need the total acc over all 600 episodes, but we do need the acc per set of 15*n_way setsz.
        total_correct = 0
        total_num = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size())
            pred, correct = net(support_x, support_y, query_x_mini.contiguous(), query_y_mini, False)
            correct = correct.sum()  # multi-gpu
            # pred: [b, nway]
            preds.append(pred)
            total_correct += correct.data[0]
            total_num += query_y_mini.size(0) * query_y_mini.size(1)
        # # 15 * [b, nway] => [b, 15*nway]
        # preds = torch.cat(preds, dim=1)

        acc = total_correct / total_num
        print('%.5f,' % acc, end=' ')
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        if episode_num > episodesz:
            # check the acc over the episodes tested so far.
            acc = np.array(accs).mean()
            if acc >= threshold:
                # current acc is high enough: keep going and test all 600 episodes.
                continue
            else:
                # current acc is low: stop after `episodesz` episodes.
                break

    # compute the distribution of acc over the 600 (or `episodesz`) episodes.
    global best_accuracy
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print('\naccuracy:', accuracy, 'sem:', sem)
    print('<<<<<<<<< accuracy:', accuracy, 'best accuracy:', best_accuracy, '>>>>>>>>')

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print('Saved to checkpoint:', mdl_file)

    return accuracy, sem
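# evaluation() relies on a mean_confidence_interval helper and a module-level best_accuracy
# global that are not shown in this excerpt. A minimal sketch of what they might look like,
# assuming the usual scipy-based 95% confidence interval over the per-batch accuracies
# (names and details are assumptions, not taken from the original code):
import numpy as np
import scipy.stats

best_accuracy = 0.0  # assumed module-level global updated by evaluation()

def mean_confidence_interval(accs, confidence=0.95):
    n = accs.shape[0]
    m, se = accs.mean(), scipy.stats.sem(accs)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)  # half-width of the confidence interval
    return m, h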
def main():
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = "mini-imagenet"

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot(
            "dataset",
            batchsz=meta_batchsz,
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            imgsz=imgsz,
        )
    elif dataset == "mini-imagenet":
        imgsz = 84
        # The dataset loaders differ between omniglot and mini-imagenet: omniglot uses a single
        # loader and get_batch('train'/'test') to switch splits, while mini-imagenet needs two
        # DataLoaders, one for train and one for test.
        mini = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
        mini_test = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="test",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=1000,
            resize=imgsz,
        )
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()

    tb = SummaryWriter("runs")

    # main loop
    for episode_num in range(200000):

        # 1. train
        if dataset == "omniglot":
            support_x, support_y, query_x, query_y = db.get_batch("test")
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        elif dataset == "mini-imagenet":
            try:
                batch_train = next(iter(db))
            except StopIteration:
                # the train loader is exhausted: rebuild it and fetch a fresh batch
                mini = MiniImagenet(
                    "../../hdd1/meta/mini-imagenet/",
                    mode="train",
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=10000,
                    resize=imgsz,
                )
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
                batch_train = next(iter(db))
            support_x = Variable(batch_train[0])
            support_y = Variable(batch_train[1])
            query_x = Variable(batch_train[2])
            query_y = Variable(batch_train[3])

        print(support_x.size(), support_y.size())

        # backprop has been embedded in the forward func.
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = db.get_batch("test")
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    support_y = Variable(torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()
                elif dataset == "mini-imagenet":
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        # the test loader is exhausted: rebuild it and fetch a fresh batch
                        mini_test = MiniImagenet(
                            "../../hdd1/meta/mini-imagenet/",
                            mode="test",
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=1000,
                            resize=imgsz,
                        )
                        db_test = DataLoader(
                            mini_test,
                            meta_batchsz,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                        )
                        batch_test = next(iter(db_test))
                    support_x = Variable(batch_test[0])
                    support_y = Variable(batch_test[1])
                    query_x = Variable(batch_test[2])
                    query_y = Variable(batch_test[3])

                # get accuracy
                test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print(
                "episode:",
                episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            tb.add_scalar("test-acc", test_acc)
            tb.add_scalar("finetune-acc", train_acc)
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way')
    argparser.add_argument('-k', help='k shot')
    argparser.add_argument('-b', help='batch size')
    argparser.add_argument('-l', help='learning rate', default=1e-3)
    argparser.add_argument('-t', help='threshold to test all episodes', default=0.97)
    args = argparser.parse_args()

    n_way = int(args.n)
    k_shot = int(args.k)
    k_query = 1
    batchsz = int(args.b)
    imgsz = 84
    lr = float(args.l)
    threshold = float(args.t)

    db = OmniglotNShot('dataset', batchsz=batchsz, n_way=n_way, k_shot=k_shot,
                       k_query=k_query, imgsz=imgsz)
    print('Omniglot: no rotate! %d-way %d-shot lr:%f' % (n_way, k_shot, lr))

    net = NaiveRN(n_way, k_shot, imgsz).cuda()
    print(net)

    mdl_file = 'ckpt/omni%d%d.mdl' % (n_way, k_shot)
    if os.path.exists(mdl_file):
        print('recover from state: ', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Total params:', params)

    input, input_y, query, query_y = db.get_batch('train')  # (batch, n_way*k_shot, img)
    print('get batch:', input.shape, query.shape, input_y.shape, query_y.shape)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=15, verbose=True)

    total_train_loss = 0
    for step in range(100000000):

        # 1. test
        if step % 400 == 0:
            accuracy, _ = 0, 0  # evaluation(net, batchsz, n_way, k_shot, imgsz, 300, threshold, mdl_file)
            scheduler.step(accuracy)

        # 2. train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        loss = net(support_x, support_y, query_x, query_y)
        total_train_loss += loss.data[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. print
        if step % 20 == 0 and step != 0:
            print('%d-way %d-shot %d batch> step:%d, loss:%f' % (n_way, k_shot, batchsz, step, total_train_loss))
            total_train_loss = 0
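# Each of these scripts presumably calls its main() from the standard entry-point guard when
# executed directly; a minimal sketch (not shown in the excerpt):
if __name__ == '__main__':
    main()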