def main():
    """Meta-train a DataParallel-wrapped ``Meta`` learner with periodic testing.

    Each loop iteration draws one batch of ``args.task_num`` training tasks
    from an endless iterator and takes one Adam step on the mean meta-loss,
    with gradients clipped element-wise to [-10, 10]. Whenever
    ``epoch % 2000 == 999``, a deep copy of the model is run on 600 test
    batches with ``meta_train=False``. The per-column mean of the stacked
    outputs is recorded under that epoch in ``Param.out_path + 'test.json'``.
    A checkpoint is saved whenever the last entry beats the best seen so far.
    """
    # Fix the CPU, all-GPU and NumPy RNG seeds.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    history = {}
    best_score = 0.0

    learner = torch.nn.DataParallel(Meta(args, Param.config).to(Param.device))
    optimizer = optim.Adam(learner.parameters(), lr=args.meta_lr)

    trainable = sum(np.prod(p.shape) for p in learner.parameters() if p.requires_grad)
    print(learner)
    print('Total trainable tensors:', trainable)

    episode_opts = dict(n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    trainset = MiniImagenet(Param.root, mode='train', **episode_opts)
    testset = MiniImagenet(Param.root, mode='test', **episode_opts)
    loader_opts = dict(shuffle=True, num_workers=4, drop_last=True)
    trainloader = DataLoader(trainset, batch_size=args.task_num, **loader_opts)
    testloader = DataLoader(testset, batch_size=4, **loader_opts)
    train_stream = inf_get(trainloader)
    test_stream = inf_get(testloader)

    def run_test(model):
        """Run `model` on 600 test batches; return the mean over all rows as a list."""
        outputs = []
        for _ in range(600):
            x_s, y_s, x_q, y_q = (t.to(Param.device) for t in next(test_stream))
            outputs.append(model(x_s, y_s, x_q, y_q, meta_train=False))
        return torch.cat(outputs, dim=0).mean(dim=0).tolist()

    for epoch in range(args.epoch):
        spt_x, spt_y, qry_x, qry_y = (t.to(Param.device) for t in next(train_stream))
        meta_loss = learner(spt_x, spt_y, qry_x, qry_y).mean()
        optimizer.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(learner.parameters(), clip_value=10.0)
        optimizer.step()
        plot.plot('meta_loss', meta_loss.item())

        if epoch % 2000 == 999:
            # Test a throw-away copy, presumably so that any test-time
            # adaptation leaves the weights being trained untouched.
            snapshot = deepcopy(learner)
            scores = run_test(snapshot)
            history[epoch] = scores
            if scores[-1] > best_score:
                best_score = scores[-1]
                ckpt_name = 'net_' + str(epoch) + '_' + str(best_score) + '.pkl'
                torch.save(learner.state_dict(), Param.out_path + ckpt_name)
            del snapshot
            print(str(epoch) + ': ' + str(scores))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(history, f)

        if epoch < 5 or epoch % 100 == 99:
            plot.flush()
        plot.tick()
def main():
    """Evaluate a pickled meta-learner on test episodes from ./flower/.

    Loads the whole model object from ``model.pkl`` with cloudpickle and moves
    it to ``cuda:<args.gpu_index>``. It then calls ``maml.finetunning`` on each
    of the sampled test episodes (``batchsz=100``) and prints the mean of the
    returned accuracies over episodes.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)

    device = torch.device('cuda:' + str(args.gpu_index))

    # NOTE(review): cloudpickle.load can execute arbitrary code stored in the
    # file; only load model.pkl files you produced yourself.
    with open('model.pkl', 'rb') as f:
        maml = cloudpickle.load(f).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print('Total trainable tensors:', num)

    mini_test = MiniImagenet('./flower/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    accs_all_test = []
    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader adds a leading batch dimension of size 1; drop it.
        x_spt, y_spt, x_qry, y_qry = (t.squeeze(0).to(device) for t in (x_spt, y_spt, x_qry, y_qry))
        accs_all_test.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))

    # Average over episodes (rows).
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    # The previous print referenced an undefined `step`, which raised a
    # NameError only after the entire evaluation had run; this script has
    # no step counter.
    print('test acc:', accs)
def main():
    """Evaluate a saved ``Meta`` checkpoint on 600 batches of 4 test tasks.

    Loads ``Param.out_path + args.ckpt`` into a model that is wrapped in
    ``torch.nn.DataParallel`` when ``n_gpus > 1``. Running means are printed
    every 100 batches. The final per-column mean of the stacked outputs is
    printed and written to ``Param.out_path + 'test.json'``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    maml = Meta(args, Param.config).to(Param.device)
    if n_gpus > 1:
        maml = torch.nn.DataParallel(maml)

    ckpt_path = Param.out_path + args.ckpt
    state_dict = torch.load(ckpt_path)
    print(state_dict.keys())

    # DataParallel prefixes every state-dict key with 'module.'; a plain model
    # does not. Normalise each key to its bare form, then re-add the prefix
    # only if this model is wrapped. The old code sliced k[7:] whenever
    # n_gpus == 1, mangling unprefixed keys. It also kept prefixed keys
    # for n_gpus == 0, where the model is not wrapped, so loading failed.
    prefix = 'module.'
    wrapped = isinstance(maml, torch.nn.DataParallel)
    pretrained_dict = OrderedDict()
    for k, v in state_dict.items():
        bare = k[len(prefix):] if k.startswith(prefix) else k
        pretrained_dict[prefix + bare if wrapped else bare] = v
    maml.load_state_dict(pretrained_dict)
    print("Load from ckpt:", ckpt_path)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                           k_query=args.k_qry, resize=args.imgsz)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    test_data = inf_get(testloader)

    # Test for 600 iterations of 4 tasks each (600 x 4 test tasks).
    maml_clone = deepcopy(maml)
    results = []
    for itr in range(600):
        support_x, support_y, qx, qy = (t.to(Param.device) for t in next(test_data))
        results.append(maml_clone(support_x, support_y, qx, qy, meta_train=False))
        if itr % 100 == 0:
            print(itr, torch.cat(results, dim=0).mean(dim=0).tolist())

    # Concatenate once at the end instead of growing a tensor every iteration.
    ans = torch.cat(results, dim=0).mean(dim=0).tolist()
    print('Acc: ' + str(ans))
    with open(Param.out_path + 'test.json', 'w') as f:
        json.dump(ans, f)
def main():
    """Meta-train MAML on mini-ImageNet, resuming from a checkpoint if present.

    Command-line flags are ``-n`` (ways), ``-k`` (shots), ``-b`` (meta batch
    size) and ``-l`` (learning rate). They are read as given and converted
    below. The checkpoint path is ``ckpt/maml<n><k>.mdl``. Each of the 1000
    outer epochs builds a new training set with ``batchsz=10000``.
    """
    parser = argparse.ArgumentParser()
    for flag, text, fallback in (
        ('-n', 'n way', 5),
        ('-k', 'k shot', 1),
        ('-b', 'batch size', 4),
        ('-l', 'learning rate', 1e-3),
    ):
        parser.add_argument(flag, help=text, default=fallback)
    opts = parser.parse_args()

    n_way, k_shot, meta_batchsz = int(opts.n), int(opts.k), int(opts.b)
    lr = float(opts.l)
    k_query, imgsz = 1, 84
    # Threshold for when to test the full set of episodes; this function only
    # prints it.
    threhold = 0.584 if k_shot != 5 else 0.699
    mdl_file = 'ckpt/maml{}{}.mdl'.format(n_way, k_shot)

    print('mini-imagnet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threhold))

    device = torch.device('cuda')
    net = MAML(n_way, k_shot, k_query, meta_batchsz=meta_batchsz, K=5, device=device)
    print(net)

    if not os.path.exists(mdl_file):
        print('training from scratch.')
    else:
        print('load from checkpoint ...', mdl_file)
        net.load_state_dict(torch.load(mdl_file))

    # Count every element of every trainable parameter.
    n_params = sum(np.prod(p.size()) for p in net.parameters() if p.requires_grad)
    print('Total params:', n_params)

    for _ in range(1000):
        # batchsz here means total episode number
        episodes = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/', mode='train', n_way=n_way,
                                k_shot=k_shot, k_query=k_query, batchsz=10000, resize=imgsz)
        # each loader batch holds meta_batchsz episodes
        loader = DataLoader(episodes, meta_batchsz, shuffle=True, num_workers=8, pin_memory=True)
        for step, batch in enumerate(loader):
            support_x, support_y, query_x, query_y = (batch[i].to(device) for i in range(4))
            accs = net(support_x, support_y, query_x, query_y, training=True)
            if not step % 10:
                print(accs)
def main():
    """Load ./checkpoint_miniimage.pth into a fresh ``Meta`` learner and test it.

    Calls ``maml.finetunning`` on every episode of the mini-ImageNet 'test'
    split, built with ``batchsz=1``. It prints the mean over episodes of the
    returned accuracies.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(222)
    print(args)

    # Four conv -> relu -> bn -> max-pool stages; the last pool uses stride 1.
    config = []
    for in_channels, pool_stride in ((3, 2), (32, 2), (32, 2), (32, 1)):
        config.extend([
            ('conv2d', [32, in_channels, 3, 3, 1, 0]),
            ('relu', [True]),
            ('bn', [32]),
            ('max_pool2d', [2, pool_stride, 0]),
        ])
    config.extend([('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    checkpoint_path = "./checkpoint_miniimage.pth"
    print("Load trained model")
    maml.load_state_dict(torch.load(checkpoint_path)['model'])

    mini_test = MiniImagenet("F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\", mode='test',
                             n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=1, resize=args.imgsz)
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    per_episode = []
    for batch in db_test:
        # Drop the DataLoader's leading batch dimension of size 1.
        x_spt, y_spt, x_qry, y_qry = [t.squeeze(0).to(device) for t in batch]
        per_episode.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))

    # Rows are episodes; average them column-wise.
    accs = np.array(per_episode).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
def main():
    """Run meta-architecture search, logging and checkpointing every epoch.

    Each epoch does the following:
    - runs ``meta_train``, passing it an ``Architect`` built on ``maml.model``;
    - runs ``meta_test``;
    - logs the current genotype and the softmaxed ``alphas_reduce``;
    - hands a checkpoint to ``Saver``, with ``is_best`` set when the last
      entry of the evaluation accuracies improves on the best so far.

    Exits with status 1 when CUDA is unavailable.
    """
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver = Saver(args)

    # Log to <experiment_dir>/log.txt and mirror INFO messages to the console.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'), filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    # Check for CUDA *before* selecting a GPU. torch.cuda.set_device raises
    # when CUDA is unavailable, so calling it first (as this function used
    # to) crashed with a traceback before this check could ever run.
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(args.gpu)

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') + glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    best_pred = 0
    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss().to(device)
    maml = Meta(args, criterion).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    mini = MiniImagenet(args.data_path, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batch_size=args.batch_size, resize=args.img_size,
                        split=[0, args.train_portion])
    mini_valid = MiniImagenet(args.data_path, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                              k_query=args.k_qry, batch_size=args.batch_size, resize=args.img_size,
                              split=[args.train_portion, 1])
    # NOTE(review): the "test" episodes come from the same held-out part of
    # the train classes as mini_valid (mode='train', same split), not from
    # the test classes. Confirm this is intended for the search phase.
    mini_test = MiniImagenet(args.data_path, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batch_size=args.test_batch_size, resize=args.img_size,
                             split=[args.train_portion, 1])
    train_queue = DataLoader(mini, args.meta_batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True)
    valid_queue = DataLoader(mini_valid, args.meta_batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True)
    test_queue = DataLoader(mini_test, args.meta_test_batch_size, shuffle=True,
                            num_workers=args.num_workers, pin_memory=True)

    architect = Architect(maml.model, args)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))
        train_accs = meta_train(train_queue, valid_queue, maml, architect, device, criterion, epoch, writer)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
        valid_accs = meta_test(test_queue, maml, device, epoch, writer)
        logging.info('[Epoch: {}]\t Test acc: {}'.format(epoch, valid_accs))

        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the meta model; flag it as best when the last accuracy improves.
        new_pred = valid_accs[-1]
        is_best = bool(new_pred > best_pred)
        if is_best:
            best_pred = new_pred
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': maml.module.state_dict() if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'best_pred': best_pred,
            },
            is_best)
def main():
    """Train ``Meta`` on mini-ImageNet, or test a saved model when ``args.mode == 'test'``.

    ``args.domain`` selects the test set: 'mini' (mini-ImageNet), 'cub',
    'traffic' (GTSRB) or 'flower' (102flowers). Any other value prints
    "Dataset Error" and returns.

    In training mode, every 200 steps the whole model is saved to
    ``args.modelfile``. It is then evaluated on the mini-ImageNet val split and
    on the test set, and saved again under a name carrying epoch, step and
    the last entry of both mean accuracy arrays.
    """
    start_time = time.time()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    print(args)
    print(argv)

    # Create the checkpoint's parent directory. The old
    # os.makedirs(args.modelfile.split('/')[0]) turned a bare filename into
    # a directory of the same name, so torch.save then failed, and it called
    # os.makedirs('') for absolute paths.
    model_dir = os.path.dirname(args.modelfile)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # Four conv(pad 1) -> relu -> bn -> max-pool stages, then a linear head.
    config = []
    for in_channels in (3, 32, 32, 32):
        config += [('conv2d', [32, in_channels, 3, 3, 1, 1]),
                   ('relu', [True]),
                   ('bn', [32]),
                   ('max_pool2d', [2, 2, 0])]
    config += [('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    def episodes(root, mode, batchsz):
        """Build a MiniImagenet-format episode set under `root` with the run's settings."""
        return MiniImagenet(root, mode=mode, n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry, batchsz=batchsz, resize=args.imgsz)

    mini = episodes('./dataset/mini-imagenet/', 'train', 10000)

    # domain -> (dataset root, banner printed before loading)
    test_sets = {
        'mini': ('./dataset/mini-imagenet/', None),
        'cub': ('./dataset/CUB_200_2011/', "CUB dataset"),
        'traffic': ('./dataset/GTSRB/', "Traffic dataset"),
        'flower': ('./dataset/102flowers/', "flower dataset"),
    }
    if args.domain not in test_sets:
        print("Dataset Error")
        return
    test_root, banner = test_sets[args.domain]
    if banner is not None:
        print(banner)
    mini_test = episodes(test_root, 'test', args.test_iter)

    # Validation uses the mini-ImageNet val split. It used to be built only
    # for domain == 'mini', so training with any other domain crashed with
    # UnboundLocalError at step 0. Test-only runs on other domains still skip it.
    if args.domain == 'mini' or args.mode != 'test':
        mini_val = episodes('./dataset/mini-imagenet/', 'val', args.test_iter)

    def mean_finetune_accs(dataset, num_workers, *ft_args, progress=False, **ft_kwargs):
        """Fine-tune on each episode of `dataset`; return the per-column mean as float16.

        Extra positional/keyword arguments are forwarded to ``maml.finetunning``.
        With ``progress=True`` the episode index is printed before each episode.
        """
        loader = DataLoader(dataset, 1, shuffle=True, num_workers=num_workers, pin_memory=True)
        per_episode = []
        for count, batch in enumerate(loader):
            if progress:
                print(count)
            # Drop the loader's leading batch dimension of size 1.
            x_spt, y_spt, x_qry, y_qry = (t.squeeze(0).to(device) for t in batch)
            per_episode.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry, *ft_args, **ft_kwargs))
        return np.array(per_episode).mean(axis=0).astype(np.float16)

    if args.mode == 'test':
        accs = mean_finetune_accs(mini_test, 6, 'test', args.modelfile, progress=True,
                                  pertub_scale=args.pertub_scale,
                                  num_ensemble=args.num_ensemble,
                                  fgsm_epsilon=args.fgsm_epsilon)
        np.set_printoptions(linewidth=1000)
        print("Running Time:", time.time() - start_time)
        print(accs)
        return

    for epoch in range(args.epoch // 10000):
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=4, pin_memory=True)
        for step, batch in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (t.to(device) for t in batch)
            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('epoch:', epoch, 'step:', step, '\ttraining acc:', accs)

            if step % 200 == 0:
                print("Save model", args.modelfile)
                torch.save(maml, args.modelfile)
                accs_val = mean_finetune_accs(mini_val, 4, 'train_test')
                accs = mean_finetune_accs(mini_test, 4)  # [update_step + 1]
                save_modelfile = "{}_{}_{}_{:0.4f}_{:0.4f}.pth".format(
                    args.modelfile, epoch, step, accs_val[-1], accs[-1])
                print(save_modelfile)
                torch.save(maml, save_modelfile)
                print("Val:", accs_val)
                print("Test:", accs)
print('load pretrained mdl ...', pretrain_mdl_file) meta.load_state_dict(torch.load(pretrain_mdl_file)) model_parameters = filter(lambda p: p.requires_grad, meta.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print('total params:', params) optimizer = optim.Adam(meta.parameters(), lr=1e-5, weight_decay=1e-6) scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True) best_train_acc = 0 for epoch in range(1000): mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=500, resize=resize) db = DataLoader(mini, batchsz, shuffle=True, num_workers=6) for step, batch in enumerate(db): # batch : ([10,10,3,84,84], [10,10], [10,75,3,84,84], [10,75]) support_x = Variable(batch[0]).cuda() support_y = Variable(batch[1]).cuda() query_x = Variable(batch[2]).cuda() query_y = Variable(batch[3]).cuda() meta.train() cls_loss, train_acc = meta.pretrain(support_x, support_y, query_x, query_y)
def main():
    """Meta-train CSML on mini-ImageNet (or Omniglot) and test it periodically.

    Each episode draws one meta-batch of training tasks and calls
    ``net.train``; backprop happens inside that call. Every 30 episodes the
    net is scored on ``min(episode_num // 5000 + 3, 10)`` test meta-batches.
    The mean test accuracy and the latest training accuracy are then printed
    and logged to TensorBoard under ``runs``.
    """
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    dataset = "mini-imagenet"
    mini_root = "../../hdd1/meta/mini-imagenet/"
    streams = {}  # mini-ImageNet split name -> live DataLoader iterator
    # (batchsz, num_workers) used whenever a split's loader is (re)built
    mini_settings = {"train": (10000, 4), "test": (1000, 2)}

    def mini_loader(mode, batchsz, num_workers):
        """Build a shuffled mini-ImageNet loader over a newly constructed episode set."""
        episodes = MiniImagenet(mini_root, mode=mode, n_way=n_way, k_shot=k_shot,
                                k_query=k_query, batchsz=batchsz, resize=imgsz)
        return DataLoader(episodes, meta_batchsz, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

    def next_mini_batch(mode):
        """Return the next (support_x, support_y, query_x, query_y) of split `mode`.

        When the current loader is exhausted, a new episode set and loader are
        built and iteration continues from the new one.
        """
        try:
            batch = next(streams[mode])
        except StopIteration:
            streams[mode] = iter(mini_loader(mode, *mini_settings[mode]))
            batch = next(streams[mode])
        return tuple(Variable(t) for t in batch[:4])

    def omniglot_batch(split):
        """Fetch an Omniglot batch as CUDA tensors (images repeated 3x along dim 2)."""
        support_x, support_y, query_x, query_y = db.get_batch(split)

        def images(x):
            return Variable(torch.from_numpy(x).float().transpose(2, 4).transpose(3, 4)
                            .repeat(1, 1, 3, 1, 1)).cuda()

        def labels(y):
            return Variable(torch.from_numpy(y).long()).cuda()

        return images(support_x), labels(support_y), images(query_x), labels(query_y)

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot("dataset", batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot,
                           k_query=k_query, imgsz=imgsz)
    elif dataset == "mini-imagenet":
        imgsz = 84
        # Omniglot has one loader whose get_batch(train or test) picks the split;
        # mini-ImageNet needs separate train and test loaders.
        db = mini_loader("train", *mini_settings["train"])
        db_test = mini_loader("test", *mini_settings["test"])
        # Keep persistent iterators. Calling iter(loader).next() for every
        # batch (as before) restarted the loader and its worker processes
        # each time, so the StopIteration refresh could never fire. The
        # test refresh also read from the *train* loader.
        streams["train"] = iter(db)
        streams["test"] = iter(db_test)
    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()

    tb = SummaryWriter("runs")

    for episode_num in range(200000):
        # 1. train
        if dataset == "omniglot":
            # Was get_batch("test"), i.e. training on the test split.
            support_x, support_y, query_x, query_y = omniglot_batch("train")
        else:
            support_x, support_y, query_x, query_y = next_mini_batch("train")
            print(support_x.size(), support_y.size())

        # backprop has been embedded in net.train
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test -- was `episode_num % 30 == 220`, which can never be true.
        if episode_num % 30 == 0:
            test_accs = []
            for _ in range(min(episode_num // 5000 + 3, 10)):  # average over several batches
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = omniglot_batch("test")
                else:
                    support_x, support_y, query_x, query_y = next_mini_batch("test")
                # NOTE(review): restored from the commented-out call (test_acc was
                # otherwise undefined); confirm CSML.train accepts train=False.
                test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
                test_accs.append(test_acc)
            test_acc = np.array(test_accs).mean()
            print(
                "episode:",
                episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            # Pass the step so points are spread over episodes instead of piled at one step.
            tb.add_scalar("test-acc", test_acc, episode_num)
            tb.add_scalar("finetune-acc", train_acc, episode_num)
def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file):
    """Evaluate `net` on mini-ImageNet test episodes and checkpoint on improvement.

    Follows the experiment setting of MAML and Learning-to-Compare: up to 600
    randomly sampled episodes with k_query = 15 query images. Each loader
    batch's query set is scored in 15 chunks, and one accuracy per batch is
    recorded. Once more than `episodesz` episodes have been scored, evaluation
    stops early unless the running mean accuracy is at least `threhold`.

    :param net: model called as ``net(support_x, support_y, query_x, query_y, False)``,
        returning ``(loss, pred, correct)``.
    :param batchsz: episodes per loader batch.
    :param n_way: classes per episode.
    :param k_shot: support images per class.
    :param imgsz: image resize passed to MiniImagenet.
    :param episodesz: episodes to score before the early-stop check applies.
    :param threhold: running accuracy at or above which testing continues.
    :param mdl_file: where ``net.state_dict()`` is saved on a new best accuracy.
    :return: ``(accuracy, sem)`` from ``mean_confidence_interval`` over the
        per-batch accuracies.

    Side effects: updates the globals ``best_accuracy``,
    ``global_test_loss_buff`` and ``global_test_acc_buff``, and calls
    ``write2file(n_way, k_shot)``.
    """
    global best_accuracy, global_test_loss_buff, global_test_acc_buff

    k_query = 15
    mini_val = MiniImagenet(
        "../mini-imagenet/",
        mode="test",
        n_way=n_way,
        k_shot=k_shot,
        k_query=k_query,
        batchsz=600,
        resize=imgsz,
    )
    db_val = DataLoader(mini_val, batchsz, shuffle=True, num_workers=2, pin_memory=True)

    accs = []
    total_loss = 0  # defined even when db_val yields no batches
    episode_num = 0  # record tested num of episodes
    for batch_test in db_val:
        support_x = Variable(batch_test[0]).cuda()
        support_y = Variable(batch_test[1]).cuda()
        query_x = Variable(batch_test[2]).cuda()
        query_y = Variable(batch_test[3]).cuda()

        # Split the query set into k_query chunks along dim 1:
        # query_x [batch, 15*way, c_, h, w] -> 15 x [b, way, c_, h, w]
        # query_y [batch, 15*way]           -> 15 x [b, way]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        query_y_b = torch.chunk(query_y, k_query, dim=1)

        net.eval()
        # Accuracy is taken over this batch's whole query set, not per episode.
        total_correct = 0
        total_num = 0
        total_loss = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            loss, _, correct = net(
                support_x, support_y, query_x_mini.contiguous(), query_y_mini, False
            )
            # .sum() also merges per-GPU counts (multi-gpu). .item() replaces
            # .data[0], which raises on the 0-dim tensors of PyTorch >= 0.4.
            total_correct += correct.sum().item()
            total_num += query_y_mini.size(0) * query_y_mini.size(1)
            # reshape(-1)[0] keeps the old "first element" semantics for a
            # per-GPU loss vector and also works for a 0-dim loss.
            total_loss += loss.detach().reshape(-1)[0].item()

        acc = total_correct / total_num
        print("%.5f," % acc, end=" ")
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        # Past `episodesz` episodes: if the running accuracy is high, keep
        # testing all 600 episodes; otherwise stop here.
        if episode_num > episodesz and np.array(accs).mean() < threhold:
            break

    # compute the distribution of the tested batches' accuracies
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print("\naccuracy:", accuracy, "sem:", sem)
    print("<<<<<<<<< accuracy:", accuracy, "best accuracy:", best_accuracy, ">>>>>>>>")

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print("Saved to checkpoint:", mdl_file)

    # we only take the last batch's loss as avg_loss
    total_loss = total_loss / n_way / k_query
    global_test_loss_buff = total_loss
    global_test_acc_buff = accuracy
    write2file(n_way, k_shot)

    return accuracy, sem
def main(): torch.manual_seed(222) torch.cuda.manual_seed_all(222) np.random.seed(222) ## Task Learner Setup task_config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]), ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])] last_epoch = 0 suffix = "_v0" save_path = os.getcwd() + '/data/model_batchsz' + str( args.k_model) + '_stepsz' + str( args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt' while os.path.isfile(save_path): valid_epoch = last_epoch last_epoch += 500 save_path = os.getcwd() + '/data/model_batchsz' + str( args.k_model) + '_stepsz' + str( args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt' save_path = os.getcwd() + '/data/model_batchsz' + str( args.k_model) + '_stepsz' + str( args.update_lr) + '_epoch' + str(valid_epoch) + suffix + '.pt' device = torch.device('cuda') task_mod = Meta(args, task_config).to(device) task_mod.load_state_dict(torch.load(save_path)) task_mod.eval() ## AL Learner Setup print(args) al_config = [('linear', [1, 32 * 5 * 5])] device = torch.device('cuda') maml = AL_Learner(args, al_config, task_mod).to(device) tmp = filter(lambda x: x.requires_grad, maml.parameters()) num = sum(map(lambda x: np.prod(x.shape), tmp)) print(maml) print('Total trainable tensors:', num) # batchsz here means total episode number mini = MiniImagenet('/home/tesca/data/miniimagenet/', mode='train', n_way=args.n_way, k_shot=1, k_query=args.k_qry, batchsz=10000, resize=args.imgsz) mini_test = MiniImagenet('/home/tesca/data/miniimagenet/', mode='test', n_way=args.n_way, k_shot=1, k_query=args.k_qry, batchsz=100, resize=args.imgsz) save_path = os.getcwd() + '/data/model_batchsz' + str( args.k_model) + '_stepsz' + 
str(args.update_lr) + '_epoch' for epoch in range(args.epoch // 10000): # fetch meta_batchsz num of episode each time db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True) for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db): x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to( device), x_qry.to(device), y_qry.to(device) al_accs = maml(x_spt, y_spt, x_qry, y_qry) if step % 30 == 0: print('step:', step, '\tAL acc:', al_accs) if step % 500 == 0: # evaluation torch.save(maml.state_dict(), save_path + str(step) + "_al_net.pt") '''db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
def main():
    """Meta-train on mini-ImageNet, testing and checkpointing every 500 steps.

    The test result is the per-column mean of ``maml.finetunning`` outputs over
    the test episode set (``batchsz=100``). The checkpoint dict (epoch, step,
    state dict) is overwritten at ./model/checkpoint.pth each time.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(222)
    print(args)

    # Four conv -> relu -> bn -> max-pool stages; the last pool uses stride 1.
    config = []
    for in_channels, pool_stride in ((3, 2), (32, 2), (32, 2), (32, 1)):
        config.extend([
            ('conv2d', [32, in_channels, 3, 3, 1, 0]),
            ('relu', [True]),
            ('bn', [32]),
            ('max_pool2d', [2, pool_stride, 0]),
        ])
    config.extend([('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    n_trainable = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', n_trainable)

    data_root = 'F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\'
    common = dict(n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    # batchsz here means total episode number
    mini = MiniImagenet(data_root, mode='train', batchsz=10000, **common)
    mini_test = MiniImagenet(data_root, mode='test', batchsz=100, **common)
    ckpt_dir = "./model/"

    def test_accuracy():
        """Fine-tune on every test episode; return the per-column mean as float16."""
        loader = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
        results = []
        for episode in loader:
            # Drop the loader's leading batch dimension of size 1.
            xs, ys, xq, yq = [t.squeeze(0).to(device) for t in episode]
            results.append(maml.finetunning(xs, ys, xq, yq))
        return np.array(results).mean(axis=0).astype(np.float16)

    for epoch in range(args.epoch // 10000):
        # each loader batch holds args.task_num episodes
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        for step, batch in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = [t.to(device) for t in batch]
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 500 != 0:
                continue

            # evaluation, then overwrite the single checkpoint file
            print('Test acc:', test_accuracy())
            os.makedirs(ckpt_dir, exist_ok=True)
            print('Saving the model as a checkpoint...')
            state = {'epoch': epoch, 'Steps': step, 'model': maml.state_dict()}
            torch.save(state, os.path.join(ckpt_dir, 'checkpoint.pth'))
def main():
    """Evaluate a (possibly pre-trained) ``Meta`` learner in 10 passes over val episodes.

    Loads ./models/miniimagenet_maml<n>way_<k>shot.pkl when it exists. Each of
    the ten passes fine-tunes on every episode of the 'val' split
    (``batchsz=100``) and prints the per-column mean of the returned
    accuracies. Finally it prints the average of each pass's last entry.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are reproducible
    torch.cuda.manual_seed_all(222)  # seed every GPU RNG for the same reason
    np.random.seed(222)
    print(args)

    # Four conv -> relu -> bn -> max-pool stages on 1-channel input; the last
    # pool uses stride 1.
    config = []
    for in_channels, pool_stride in ((1, 2), (32, 2), (32, 2), (32, 1)):
        config += [('conv2d', [32, in_channels, 3, 3, 1, 0]),
                   ('relu', [True]),
                   ('bn', [32]),
                   ('max_pool2d', [2, pool_stride, 0])]
    config += [('flatten', []), ('linear', [args.n_way, 7040])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # The old code checked "./models/./models/miniimagenet_maml..." (prefix
    # repeated) and then passed that *path string* to load_state_dict, which
    # raises. Try the intended path first, fall back to the legacy doubled
    # one, and actually deserialise the checkpoint.
    model_name = "miniimagenet_maml" + str(args.n_way) + "way_" + str(args.k_spt) + "shot.pkl"
    candidates = [os.path.join("./models", model_name), "./models/" + "./models/" + model_name]
    path = next((p for p in candidates if os.path.exists(p)), None)
    if path is not None:
        state = torch.load(path, map_location=device)
        if isinstance(state, torch.nn.Module):  # whole model pickled instead of a state dict
            state = state.state_dict()
        maml.load_state_dict(state)
        print("load model success")
    else:
        print("no pretrained model found; evaluating untrained weights")

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number. The train set is unused, but it
    # is kept so any RNG draws its constructor makes (and hence the val
    # episodes sampled next) stay as before.
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='val', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100)
    # A shuffling DataLoader reshuffles on every iteration, so one is enough.
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    test_accuracy = []
    for _ in range(10):
        accs_all_test = []
        for x_spt, y_spt, x_qry, y_qry in db_test:
            # Drop the loader's leading batch dimension of size 1.
            x_spt, y_spt, x_qry, y_qry = (t.squeeze(0).to(device) for t in (x_spt, y_spt, x_qry, y_qry))
            accs, _ = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
            accs_all_test.append(accs)

        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
        # float() so the average below is accumulated in double precision, not float16.
        test_accuracy.append(float(accs[-1]))

    average_accuracy = sum(test_accuracy) / len(test_accuracy)
    print("average accuracy:{}".format(average_accuracy))
def main():
    """Meta-train MAML on mini-ImageNet, printing periodic train/test accuracy.

    Optional command-line flags: ``-n`` ways, ``-k`` shots, ``-b`` meta batch
    size, ``-l`` meta learning rate.  Every 500 steps (after the first) the
    learner is scored on 600 freshly sampled test episodes, and the mean
    accuracy per update step plus the interval half-width returned by
    ``mean_confidence_interval`` for step K are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', help='n way', default=5)
    parser.add_argument('-k', help='k shot', default=1)
    parser.add_argument('-b', help='meta batch size', default=4)
    parser.add_argument('-l', help='meta learning rate', default=1e-3)
    cli = parser.parse_args()

    n_way, k_shot = int(cli.n), int(cli.k)
    meta_batchsz, meta_lr = int(cli.b), float(cli.l)
    train_lr = 1e-2
    k_query = 15
    imgsz = 84
    K = 5  # update steps
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)  # not used below
    banner = 'mini-imagenet: %d-way %d-shot meta-lr:%f, train-lr:%f K-steps:%d'
    print(banner % (n_way, k_shot, meta_lr, train_lr, K))

    device = torch.device('cuda:2')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, K, meta_lr, train_lr).to(device)
    print(net)

    def episode_tensors(batch):
        # (support_x, support_y, query_x, query_y) moved onto the device, in order.
        return tuple(batch[idx].to(device) for idx in range(4))

    for epoch in range(1000):
        # batchsz here means total episode number; a new episode set each epoch.
        train_set = MiniImagenet('/data/miniimagenet/', mode='train', n_way=n_way,
                                 k_shot=k_shot, k_query=k_query, batchsz=10000,
                                 resize=imgsz)
        train_loader = DataLoader(train_set, meta_batchsz, shuffle=True,
                                  num_workers=meta_batchsz, pin_memory=True)
        for step, batch in enumerate(train_loader):
            sx, sy, qx, qy = episode_tensors(batch)
            accs = net(sx, sy, qx, qy, training=True)
            if step % 50 == 0:
                print("epoch: {}, step: {}, {}accuracy: {}".format(epoch, step, '\t', accs))

            # Evaluate only on non-zero multiples of 500.
            if step == 0 or step % 500 != 0:
                continue
            test_set = MiniImagenet('/data/miniimagenet/', mode='test', n_way=n_way,
                                    k_shot=k_shot, k_query=k_query, batchsz=600,
                                    resize=imgsz)
            test_loader = DataLoader(test_set, meta_batchsz, shuffle=True,
                                     num_workers=meta_batchsz, pin_memory=True)
            # One row per test batch, one column per update step: [*, K+1]
            scores = np.array([net(*episode_tensors(tb), training=False)
                               for tb in test_loader])
            means = scores.mean(axis=0)
            _, h = mean_confidence_interval(scores[:, K])
            print('>>Test:\t', means, 'variance[K]: %.4f' % h, '<<')
def main():
    """Meta-train MAML on the flower dataset and pickle the trained meta-learner.

    Every 30 steps the training accuracy is printed and recorded.  The model
    is fine-tuned and evaluated on the test episodes every 500 steps, and
    once more at the very last step of training.  The final test accuracy
    (followed by ``args.task_num``) is appended to
    ``data/result_natural(acc).txt``, and the model is saved to
    ``model.pkl`` with cloudpickle.  Uses the module-level ``args``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5]),
    ]

    device = torch.device('cuda:' + str(args.gpu_index))
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('./flower/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('./flower/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)

    accs_list_tr = []
    accs_list_ts = []
    n_epochs = args.epoch // 10000
    for epoch in range(n_epochs):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        # DataLoader keeps a partial final batch (drop_last defaults to False), so
        # the last step index is len(db) - 1 == ceil(len(mini) / task_num) - 1.
        # `10000 // task_num - 1` is one step early when task_num does not divide 10000.
        last_step = len(db) - 1
        is_last_epoch = epoch == n_epochs - 1
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                          x_qry.to(device), y_qry.to(device))
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                accs_list_tr.append(accs)

            is_final_step = is_last_epoch and step == last_step
            if step % 500 == 0 or is_final_step:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
                    x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)
                # [b, update_step+1] -> [update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('step:', step, '\ttest acc:', accs)
                accs_list_ts.append(accs)
                if is_final_step:
                    with open('data/result_natural(acc).txt', mode='a') as f:
                        f.write(str(accs) + str(args.task_num) + '\n')

    with open('model.pkl', 'wb') as f:
        cloudpickle.dump(maml, f)
def main():
    """Meta-train MAML on mini-ImageNet with validation-based checkpointing.

    Every 1000 steps the model is fine-tuned on the 600 validation episodes.
    A checkpoint is written to ``ckpt/<args.exp>/`` when validation accuracy
    beats the best seen so far, and also unconditionally every 5000 steps.
    Each time one is written, the model is also evaluated on the 600 test
    episodes.  Results are appended to ``val.txt`` / ``test.txt`` in the
    same directory.  Uses the module-level ``args``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)

    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    data_root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/'
    # batchsz here means total episode number
    mini = MiniImagenet(data_root, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000, resize=args.imgsz)
    mini_val = MiniImagenet(data_root, mode='val', n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry, batchsz=600, resize=args.imgsz)
    mini_test = MiniImagenet(data_root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=600, resize=args.imgsz)

    def evaluate(dataset):
        # Fine-tune on every episode of `dataset` (one per loader batch) and
        # return cal_conf's (mean, std, ci95) over the per-episode accuracies.
        loader = DataLoader(dataset, 1, shuffle=True, num_workers=1, pin_memory=True)
        accs_all = []
        for x_spt, y_spt, x_qry, y_qry in loader:
            x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
            x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
            accs_all.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))
        return cal_conf(np.array(accs_all))

    best_acc = 0.0
    # makedirs also creates a missing "ckpt/" parent, which os.mkdir cannot do.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                          x_qry.to(device), y_qry.to(device))
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 1000 == 0:
                # evaluation
                mean, std, ci95 = evaluate(mini_val)
                print('Val acc:{}, std:{}. ci95:{}'.format(mean[-1], std[-1], ci95[-1]))
                val_acc = mean[-1]
                improved = val_acc > best_acc
                if improved or step % 5000 == 0:
                    # Only a real improvement raises best_acc.  Periodic snapshots
                    # must not overwrite it with a possibly lower value, or later
                    # worse models would be treated as improvements.
                    if improved:
                        best_acc = val_acc
                    torch.save(
                        maml.state_dict(),
                        'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(args.exp, epoch, step, val_acc))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print('val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                              .format(epoch, step, val_acc, ci95[-1]), file=f)

                    ## Test
                    mean, std, ci95 = evaluate(mini_test)
                    print('Test acc:{}, std:{}, ci95:{}'.format(mean[-1], std[-1], ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print('test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                              .format(epoch, step, mean[-1], ci95[-1]), file=f)
def main():
    """Evaluate a saved MAML checkpoint with active-learning query ordering.

    Loads the newest existing checkpoint
    ``data/model_batchsz{k_model}_stepsz{update_lr}_epoch{N}{suffix}.pt``,
    probing N = 0, 500, 1000, ...  For every test task, ``QuerySelector``
    picks support examples one at a time, and the model (a fresh copy per
    ordering) is fine-tuned on the growing query set after each pick.  Every
    ordering the selector proposes is scored.  The accuracy curve, averaged
    over orderings and tasks, is printed.  Uses the module-level ``args``.

    Raises:
        FileNotFoundError: if no checkpoint exists even for epoch 0.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    # With train_al == 1 the classifier has one extra output unit.
    n_out = args.n_way + 1 if args.train_al == 1 else args.n_way
    last_layer = ('linear', [n_out, 32 * 5 * 5])
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), last_layer]

    suffix = "_al" if args.train_al == 1 else "_og"

    def ckpt_path(epoch):
        # Checkpoint file name for a given epoch.
        return (os.getcwd() + '/data/model_batchsz' + str(args.k_model) + '_stepsz'
                + str(args.update_lr) + '_epoch' + str(epoch) + suffix + '.pt')

    # Probe epochs 0, 500, 1000, ... (presumably the training save interval) and
    # keep the last one that exists.
    valid_epoch = None
    last_epoch = 0
    while os.path.isfile(ckpt_path(last_epoch)):
        valid_epoch = last_epoch
        last_epoch += 500
    if valid_epoch is None:
        # Fail with a clear message instead of an undefined-name error.
        raise FileNotFoundError('no checkpoint found at ' + ckpt_path(0))
    save_path = ckpt_path(valid_epoch)

    device = torch.device('cuda')
    mod = Meta(args, config).to(device)
    mod.load_state_dict(torch.load(save_path))
    mod.eval()

    # batchsz here means total episode number.  When support and query are
    # merged, the support set is replaced by the query set below, so one shot
    # is sampled.
    k_shot = 1 if args.merge_spt_qry == 1 else args.k_spt
    mini_test = MiniImagenet('/home/tesca/data/miniimagenet/', mode='test', n_way=args.n_way,
                             k_shot=k_shot, k_query=args.k_qry, batchsz=10, resize=args.imgsz)
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    total_accs = []  # per task: accuracy curve averaged over all orderings tried
    best_accs = []   # per task: query order with the highest summed accuracy
    for it, (x_spt, y_spt, x_qry, y_qry) in enumerate(db_test):
        if args.merge_spt_qry == 1:
            x_spt, y_spt = x_qry, y_qry
        sys.stdout.write("\rTest %i" % it)
        sys.stdout.flush()

        # Drop the leading batch axis added by the loader (batch size 1).
        x_spt, y_spt = x_spt.squeeze(0), y_spt.squeeze(0)
        x_qry_pt, y_qry_pt = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
        # x_spt is already squeezed; squeezing again would collapse a size-1 support set.
        x_spt_pt = x_spt.to(device)

        qs = QuerySelector()
        all_queries, all_orderings = [], []
        finished = False
        while not finished:
            maml = deepcopy(mod)  # fresh copy of the loaded model for each ordering
            queries = None
            avail_cands = list(range(x_spt.shape[0]))
            test_accs = []
            q_idx = []
            w = None
            for _ in range(x_spt.shape[0]):
                q = qs.query(args, maml.net, w=w, cands=x_spt_pt, avail_cands=avail_cands,
                             targets=x_qry_pt, k=1, method=args.al_method,
                             classification=True)
                q_idx.append(q)
                avail_cands.remove(q)
                x_new = np.array(x_spt[q].unsqueeze(0))
                y_new = np.array(y_spt[q].unsqueeze(0))
                if queries is None:
                    queries = [x_new, y_new]
                else:
                    queries = [np.concatenate([queries[0], x_new]),
                               np.concatenate([queries[1], y_new])]
                xs = torch.from_numpy(queries[0]).to(device)
                ys = torch.from_numpy(queries[1]).to(device)
                accs, w = maml.finetunning(xs, ys, x_qry_pt, y_qry_pt)
                # Curve: accs[0] once at the start, then the last-step accuracy per query.
                if not test_accs:
                    test_accs.append(accs[0])
                test_accs.append(accs[-1])
            all_orderings.append(test_accs)
            all_queries.append(q_idx)
            del maml
            finished = not qs.next_order()

        orderings = np.array(all_orderings)
        total_accs.append(np.mean(orderings, axis=0))
        best_accs.append(all_queries[np.argmax(np.sum(orderings, axis=1))])

    # [b, update_step+1]
    accs = np.array(total_accs).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
def main():
    """Meta-train a ``MetaLearner`` over ``Naive`` on Omniglot (or mini-ImageNet).

    Runs 1500 meta-training episodes; the backward pass happens inside the
    learner's forward call.  Every 30 episodes the learner is scored on a few
    test batches (3 at first, growing by one every 5000 episodes, capped at
    10).  The fine-tune and test accuracies are printed and logged to
    TensorBoard under ``runs``.
    """
    meta_batchsz = 32
    n_way = 20
    k_shot = 1
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = 'omniglot'

    if dataset == 'omniglot':
        imgsz = 28
        db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot,
                           k_query=k_query, imgsz=imgsz)
    elif dataset == 'mini-imagenet':
        imgsz = 84
        # The loaders differ between the two datasets.  Omniglot has a single
        # loader whose get_batch('train' / 'test') returns either split.
        # mini-imagenet needs two DataLoaders: one for training, one for testing.
        mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                 k_query=k_query, batchsz=1000, resize=imgsz)
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2,
                             pin_memory=True)
    else:
        raise NotImplementedError

    meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot,
                       meta_batchsz=meta_batchsz, beta=meta_lr,
                       num_updates=num_updates).cuda()
    tb = SummaryWriter('runs')

    def omniglot_batch(split):
        # Fetch one Omniglot meta-batch from `split` and move it to the GPU.
        # Image axes are reordered and the channel axis is repeated 3 times
        # (presumably so the 3-channel Naive net accepts grayscale input).
        support_x, support_y, query_x, query_y = db.get_batch(split)
        support_x = Variable(torch.from_numpy(support_x).float().transpose(2, 4)
                             .transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(torch.from_numpy(query_x).float().transpose(2, 4)
                           .transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).long()).cuda()
        query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        return support_x, support_y, query_x, query_y

    def to_cuda(batch):
        # Move a (support_x, support_y, query_x, query_y) loader batch to the GPU.
        return tuple(Variable(t).cuda() for t in batch[:4])

    if dataset == 'mini-imagenet':
        # Keep one iterator alive per loader.  Creating a new iterator each
        # episode restarts the loader (and its workers) just to take one batch.
        train_iter, test_iter = iter(db), iter(db_test)

    # main loop
    for episode_num in range(1500):
        # 1. train
        if dataset == 'omniglot':
            support_x, support_y, query_x, query_y = omniglot_batch('train')
        else:  # mini-imagenet
            try:
                batch = next(train_iter)
            except StopIteration:
                # Episodes exhausted: sample a fresh training set and start over.
                mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way,
                                    k_shot=k_shot, k_query=k_query, batchsz=10000,
                                    resize=imgsz)
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4,
                                pin_memory=True)
                train_iter = iter(db)
                batch = next(train_iter)
            support_x, support_y, query_x, query_y = to_cuda(batch)

        # backprop has been embedded in the forward func.
        if episode_num % 100 == 0:
            meta.prv_angle = 0
        accs = meta(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for _ in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == 'omniglot':
                    support_x, support_y, query_x, query_y = omniglot_batch('test')
                else:  # mini-imagenet
                    try:
                        batch = next(test_iter)
                    except StopIteration:
                        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way,
                                                 k_shot=k_shot, k_query=k_query, batchsz=1000,
                                                 resize=imgsz)
                        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True,
                                             num_workers=2, pin_memory=True)
                        test_iter = iter(db_test)
                        batch = next(test_iter)
                    support_x, support_y, query_x, query_y = to_cuda(batch)

                # get accuracy
                test_accs.append(meta.pred(support_x, support_y, query_x, query_y))

            test_acc = np.array(test_accs).mean()
            print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc,
                  '\t\ttest acc:%.6f' % test_acc)
            # global_step makes TensorBoard plot the points against the episode.
            tb.add_scalar('test-acc', test_acc, episode_num)
            tb.add_scalar('finetune-acc', train_acc, episode_num)
def main():
    """Meta-train MAML on mini-ImageNet, keep the best checkpoint and plot curves.

    Every 30 steps the training loss/accuracy is recorded.  Every 50 steps
    the model is fine-tuned on the test episodes.  When the final-step test
    accuracy beats the best so far, the weights are saved to
    ``./models/miniimagenet_maml<N>way_<K>shot.pkl``.  After training, loss
    and accuracy curves are saved under ``./drawing/`` and shown.  Uses the
    module-level ``args``.
    """
    import os  # local import: needed for the directory creation below

    torch.manual_seed(222)  # seed the CPU RNG so results are reproducible
    torch.cuda.manual_seed_all(222)  # seed every GPU RNG so results are reproducible
    np.random.seed(222)
    print(args)

    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100)

    model_path = ("./models/miniimagenet_maml" + str(args.n_way) + "way_"
                  + str(args.k_spt) + "shot.pkl")
    last_accuracy = 0
    plt_train_loss = []
    plt_train_acc = []
    plt_test_loss = []
    plt_test_acc = []
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                          x_qry.to(device), y_qry.to(device))
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                plt_train_loss.append(loss_q.detach().cpu().numpy())
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)

            if step % 50 == 0:
                # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
                    x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs, loss_test = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)
                # [b, update_step+1] -> [update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                plt_test_loss.append(np.mean(np.array(loss_all_test)))
                print('Test acc:', accs)

                # Select on the accuracy after the last fine-tuning step (the
                # value plotted above).  The mean over all steps would also mix
                # in step 0, presumably the accuracy before any adaptation.
                test_accuracy = accs[-1]
                if test_accuracy > last_accuracy:
                    # save networks; create ./models on first use so torch.save cannot fail
                    os.makedirs(os.path.dirname(model_path), exist_ok=True)
                    torch.save(maml.state_dict(), model_path)
                    last_accuracy = test_accuracy

    def plot_curves(title, losses, accuracies, path):
        # Draw one loss/accuracy figure, save it to `path`, then show it.
        plt.figure()
        plt.title(title)
        plt.xlabel("episode")
        plt.ylabel("Acc/loss")
        plt.plot(losses, label='Loss')
        plt.plot(accuracies, label='Acc')
        plt.legend(loc='upper right')
        plt.savefig(path)
        plt.show()

    os.makedirs('./drawing', exist_ok=True)
    plot_curves("testing info", plt_test_loss, plt_test_acc, './drawing/test.png')
    plot_curves("training info", plt_train_loss, plt_train_acc, './drawing/train.png')
# NOTE(review): orphaned top-level fragment.  There is no enclosing `def`, and
# `net`, `mdl_file`, `n_way`, `k_shot`, `k_query` and `batchsz` are not defined
# here, so running this at import time raises NameError unless those names
# exist at module scope.  It looks like a truncated copy of a training `main`;
# confirm whether it should be deleted or folded into a function.
# Restore weights from the checkpoint file.
net.load_state_dict(torch.load(mdl_file))
# Count trainable parameters.
model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print('total params:', params)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
# NOTE(review): if SummaryWriter is tensorboardX / torch.utils.tensorboard, the
# second positional argument is `comment`, which is ignored when a log dir is
# given, so the timestamp has no effect.  Confirm.
tb = SummaryWriter('runs', str(datetime.now()))
best_accuracy = 0
for epoch in range(1000):
    # New training / validation episode sets (224px images) for each epoch.
    mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=10000, resize=224)
    db = DataLoader(mini, batchsz, shuffle=True, num_workers=8, pin_memory=True)
    mini_val = MiniImagenet('../mini-imagenet/', mode='val', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=200, resize=224)
def main():
    """Train ``Naive5`` on mini-ImageNet with periodic evaluation and LR scheduling.

    Flags: ``-n`` ways, ``-k`` shots, ``-b`` batch size, ``-l`` learning rate.
    Resumes from ``ckpt/naive5_3x3<N><K>.mdl`` and the ``mini<N><K>.pkl``
    buffer when they exist.  Every 300 steps the model is evaluated (600
    episodes) and the plateau scheduler is stepped on that accuracy.  Every
    20 steps the running loss/accuracy is printed and ``write2file`` is
    called with updated module-level buffers.
    """
    global global_buff, global_train_loss_buff, global_train_acc_buff

    argparser = argparse.ArgumentParser()
    argparser.add_argument("-n", help="n way")
    argparser.add_argument("-k", help="k shot")
    argparser.add_argument("-b", help="batch size")
    argparser.add_argument("-l", help="learning rate", default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    batchsz = int(args.b)
    lr = float(args.l)

    k_query = 1
    imgsz = 224
    # threshold for when to test full version of episode
    threhold = 0.699 if k_shot == 5 else 0.584
    mdl_file = "ckpt/naive5_3x3%d%d.mdl" % (n_way, k_shot)
    print("mini-imagnet: %d-way %d-shot lr:%f, threshold:%f" % (n_way, k_shot, lr, threhold))

    pkl_file = "mini%d%d.pkl" % (n_way, k_shot)
    if os.path.exists(pkl_file):
        # NOTE: unpickling can execute arbitrary code; only load buffers this script wrote.
        with open(pkl_file, "rb") as fh:
            global_buff = pickle.load(fh)
        print("load pkl buff:", len(global_buff))

    net = nn.DataParallel(Naive5(n_way, k_shot, imgsz), device_ids=[0, 1, 2]).cuda()
    print(net)

    if os.path.exists(mdl_file):
        print("load from checkpoint ...", mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print("training from scratch.")

    # whole parameters number
    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Total params:", params)

    # build optimizer and lr scheduler
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, "max", factor=0.5, patience=25,
                                               verbose=True)

    for epoch in range(1000):
        mini = MiniImagenet("../mini-imagenet/", mode="train", n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        db = DataLoader(mini, batchsz, shuffle=True, num_workers=8, pin_memory=True)
        total_train_loss = 0
        total_train_correct = 0
        total_train_num = 0
        for step, batch in enumerate(db):
            # 1. test
            if step % 300 == 0:
                # evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file)
                accuracy, _sem = evaluation(net, batchsz, n_way, k_shot, imgsz, 600,
                                            threhold, mdl_file)
                scheduler.step(accuracy)

            # 2. train
            support_x = Variable(batch[0]).cuda()
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()

            net.train()
            loss, pred, correct = net(support_x, support_y, query_x, query_y)
            # DataParallel gathers replica outputs along dim 0, so `loss` and
            # `correct` presumably hold one entry per GPU: reduce over all of
            # them.  `.data[0]` would count only the first replica, and it
            # raises IndexError on 0-dim tensors in current PyTorch.
            loss = loss.sum() / support_x.size(0)  # multi-gpu, divide by total batchsz
            total_train_loss += loss.item()
            total_train_correct += correct.sum().item()
            total_train_num += support_y.size(0) * n_way  # k_query = 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 3. print
            if step % 20 == 0 and step != 0:
                acc = total_train_correct / total_train_num
                total_train_correct = 0
                total_train_num = 0

                print("%d-way %d-shot %d batch> epoch:%d step:%d, loss:%.4f, train acc:%.4f"
                      % (n_way, k_shot, batchsz, epoch, step, total_train_loss, acc))
                total_train_loss = 0

                global_train_loss_buff = loss.item() / (n_way * k_shot)
                global_train_acc_buff = acc
                write2file(n_way, k_shot)
def main():
    """Meta-train MAML on mini-ImageNet and report test accuracy every 500 steps.

    Seeds all RNGs with 222 and builds a four-stage conv learner.  For each
    of ``args.epoch // 10000`` passes over the sampled training episodes it
    runs one meta-update per batch of ``args.task_num`` tasks.  Training
    accuracy is printed every 30 steps and fine-tuned test accuracy every
    500 steps.  Uses the module-level ``args``.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(222)
    print(args)

    def conv_stage(in_channels, pool_stride):
        # One conv2d -> relu -> bn -> max_pool2d stage in Meta's config format.
        return [('conv2d', [32, in_channels, 3, 3, 1, 0]),
                ('relu', [True]),
                ('bn', [32]),
                ('max_pool2d', [2, pool_stride, 0])]

    config = (conv_stage(3, 2) + conv_stage(32, 2) + conv_stage(32, 2) + conv_stage(32, 1)
              + [('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    shared = dict(n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.img_sz)
    # batchsz here means total episode number
    mini = MiniImagenet('./data/', mode='train', batchsz=10000, **shared)
    mini_test = MiniImagenet('./data/', mode='test', batchsz=100, **shared)

    def run_test():
        # Fine-tune on every test episode (one per loader batch) and return the
        # mean accuracy per update step: [b, update_step+1] -> [update_step+1].
        loader = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
        per_episode = []
        for sx, sy, qx, qy in loader:
            sx, sy = sx.squeeze(0).to(device), sy.squeeze(0).to(device)
            qx, qy = qx.squeeze(0).to(device), qy.squeeze(0).to(device)
            per_episode.append(maml.finetunning(sx, sy, qx, qy))
        return np.array(per_episode).mean(axis=0).astype(np.float16)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        train_loader = DataLoader(mini, args.task_num, shuffle=True, num_workers=1,
                                  pin_memory=True)
        for step, batch in enumerate(train_loader):
            x_spt, y_spt, x_qry, y_qry = (t.to(device) for t in batch)
            accs = maml(x_spt, y_spt, x_qry, y_qry)
            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 500 == 0:
                print('Test acc:', run_test())
def main():
    """Set up logging, seeding and data, then meta-train (or only evaluate) ``Meta``.

    Logs go to ``<experiment_dir>/log.txt`` and the console.  Exits with
    status 1 when no GPU is available.  With ``args.pretrained_model`` the
    checkpoint is loaded first; adding ``args.evaluate`` runs a single
    meta-test pass and returns without training.
    """
    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'), filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    # set saver
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') + glob.glob('*.sh')
                         + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)

    # Compute FLOPs and Params on a throwaway instance.
    maml = Meta(args, criterion)
    flops, params = get_model_complexity_info(maml.model, (3, 84, 84), as_strings=False,
                                              print_per_layer_stat=True, verbose=True)
    logging.info('FLOPs: {} MMac Params: {}'.format(flops / 10**6, params))

    maml = Meta(args, criterion).to(device)
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    mini = MiniImagenet(args.data_path, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batch_size=args.batch_size, resize=args.img_size)
    mini_test = MiniImagenet(args.data_path, mode='val', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batch_size=args.test_batch_size,
                             resize=args.img_size)
    train_loader = DataLoader(mini, args.meta_batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    test_loader = DataLoader(mini_test, args.meta_test_batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True)

    # load pretrained model and inference
    if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model)
        if isinstance(maml.model, torch.nn.DataParallel):
            # Only `maml.model` is the DataParallel wrapper.  `maml` is a Meta,
            # which presumably has no `.module`, so load into the wrapped network.
            # NOTE(review): assumes the checkpoint then stores that inner
            # network's state dict; confirm against how meta_train saves it.
            maml.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            maml.load_state_dict(checkpoint['state_dict'])

        if args.evaluate:
            test_accs = meta_test(test_loader, maml, device, checkpoint['epoch'])
            logging.info('[Epoch: {}]\t Test acc: {}'.format(checkpoint['epoch'], test_accs))
            return

    # Start training
    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))
        train_accs = meta_train(train_loader, maml, device, epoch, writer, test_loader, saver)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
def main():
    """Meta-train MAML on mini-ImageNet and periodically report test accuracy.

    Flags: ``-n`` ways, ``-k`` shots, ``-b`` meta batch size, ``-l`` meta
    learning rate.  Every 1000 steps (after the first) the learner is scored
    on 600 freshly sampled test episodes, and the mean accuracy per update
    step is printed.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=4)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 1e-2

    k_query = 15
    imgsz = 84
    K = 5  # update steps
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot meta-lr:%f, train-lr:%f'
          % (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, K, meta_lr, train_lr, device)
    print(net)

    data_root = '/hdd1/liangqu/datasets/miniimagenet/'
    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet(data_root, mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)

        for step, batch in enumerate(db):
            # 2. train
            support_x = batch[0].to(device)
            support_y = batch[1].to(device)
            query_x = batch[2].to(device)
            query_y = batch[3].to(device)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 50 == 0:
                print(epoch, step, '\t', accs)

            if step % 1000 == 0 and step != 0:
                # batchsz here means total episode number
                mini_test = MiniImagenet(data_root, mode='test', n_way=n_way, k_shot=k_shot,
                                         k_query=k_query, batchsz=600, resize=imgsz)
                # fetch meta_batchsz num of episode each time
                db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=4,
                                     pin_memory=True)

                accs_all_test = []
                for test_batch in db_test:
                    support_x = test_batch[0].to(device)
                    support_y = test_batch[1].to(device)
                    query_x = test_batch[2].to(device)
                    query_y = test_batch[3].to(device)
                    # training=False: with training=True the learner presumably
                    # performs a meta-update on the test tasks, leaking test
                    # data into the model.
                    accs = net(support_x, support_y, query_x, query_y, training=False)
                    accs_all_test.append(accs)
                # [600, K+1] => [K+1]
                accs_all_test = np.array(accs_all_test).mean(axis=0)
                print('>>Test:\t', accs_all_test, '<<')
def main():
    """Meta-train the model on mini-ImageNet with periodic evaluation and early stopping.

    Hyperparameters come from the module-level ``args``. ``args.epoch`` is
    treated as a budget of training episodes: the training set holds 10000
    episodes, and ``args.epoch // 10000`` passes are made over it. Every
    ``args.save_summary_steps`` meta-steps the model is evaluated on the
    train, val and test splits. The metrics are logged to Weights & Biases
    and appended to ``results.csv``. Training stops once validation accuracy
    has failed to improve on ``args.pruning`` consecutive evaluations. This
    check only triggers when ``args.pruning`` is positive.
    """
    # Manually seed torch and numpy for reproducible results
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Episodes per training epoch; also determines how many epochs args.epoch buys.
    episodes_per_epoch = 10000

    # Open csv file to write for metric logging. Mode "w" already creates a
    # missing file, so no fallback is needed. The ``with`` block closes (and
    # flushes) the file even if training raises or is interrupted.
    with open("results.csv", "w") as f:
        f.write("Steps,tr_loss,tr_acc,val_loss,val_acc,te_loss,te_acc\n")

        # Choose PyTorch device and create the model
        device = (torch.device('cuda') if torch.cuda.is_available()
                  else torch.device('cpu'))
        model = Meta(args, defineModel(args)).to(device)

        # Setup Weights and Biases logger, config hyperparams and watch model
        wandb.init(project="Meta-SGD")
        name = f"N{args.n_way}K{args.k_spt}"
        if args.update_step > 0:
            name += "Meta"
        name += f"Ret{args.ret_channels}VVS{args.vvs_depth}KS{args.kernel_size}"
        wandb.run.name = name
        wandb.config.update(args)
        wandb.watch(model)
        print(f"RUN NAME: {name}")

        # Print additional information on the model
        if args.verbose:
            num = sum(np.prod(p.shape) for p in model.parameters() if p.requires_grad)
            print(args)
            print(model)
            print('Total trainable tensors:', num)

        # Create datasets
        # batchsz here means total episode number
        print("\nGathering Datasets:")

        def make_split(mode, batchsz):
            """Build a MiniImagenet split sharing this run's episode settings."""
            return MiniImagenet('miniimagenet/', mode=mode, n_way=args.n_way,
                                k_shot=args.k_spt, k_query=args.k_qry,
                                batchsz=batchsz, resize=args.imgsz)

        # Construction order matches the original (train, train-eval, val, test).
        mini = make_split('train', episodes_per_epoch)
        mini_train_eval = make_split('train', args.eval_steps)
        mini_val = make_split('val', args.eval_steps)
        mini_test = make_split('test', args.eval_steps)

        print("\nMeta-Training:")
        pruning_factor = args.pruning
        best_tr_acc, best_val_acc, best_te_acc = 0, 0, 0

        # fetch task_num episodes each step; the loader is loop-invariant and
        # reshuffles on every pass, so it is built once.
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        num_epochs = args.epoch // episodes_per_epoch
        epoch_bar = tqdm(range(num_epochs), desc="Training", total=num_epochs)
        for epoch in epoch_bar:
            task_bar = tqdm(enumerate(db), desc=f"Epoch {epoch}", total=len(db), leave=False)
            for step, (x_spt, y_spt, x_qry, y_qry) in task_bar:
                total_steps = len(db) * epoch + step + 1

                # Perform training for each task
                x_spt, y_spt = x_spt.to(device), y_spt.to(device)
                x_qry, y_qry = x_qry.to(device), y_qry.to(device)
                model(x_spt, y_spt, x_qry, y_qry)

                if total_steps % args.save_summary_steps != 0:
                    continue

                # evaluation
                # Get evaluation metrics
                tr_acc, tr_loss = evaluate(model, mini_train_eval, device, "Eval Train")
                val_acc, val_loss = evaluate(model, mini_val, device, "Eval Val")
                te_acc, te_loss = evaluate(model, mini_test, device)

                # Update Task tqdm bar. NOTE: the 'tr acc' key (with a space)
                # is kept unchanged so the logged wandb series is not renamed.
                metrics = {
                    'tr acc': tr_acc,
                    'val_acc': val_acc,
                    'te_acc': te_acc
                }
                task_bar.set_postfix(metrics)
                metrics['tr_loss'] = tr_loss
                metrics['val_loss'] = val_loss
                metrics['te_loss'] = te_loss
                wandb.log(metrics)
                f.write(
                    f"{total_steps},{tr_loss},{tr_acc},{val_loss},{val_acc},{te_loss},{te_acc}\n"
                )
                # Flush so completed rows survive a hard kill mid-training.
                f.flush()

                # Update best metrics; the early-stopping counter resets
                # whenever validation accuracy improves.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    pruning_factor = args.pruning
                else:
                    pruning_factor -= 1
                best_te_acc = max(te_acc, best_te_acc)
                best_tr_acc = max(tr_acc, best_tr_acc)

                if pruning_factor == 0:
                    break

            # Update tqdm
            epoch_bar.set_postfix({
                'b_tr_acc': best_tr_acc,
                'b_val_acc': best_val_acc,
                'b_te_acc': best_te_acc,
                'prune': pruning_factor
            })
            if pruning_factor == 0:
                break