def main():
    n_way = 5
    k_shot = 1
    k_query = 15
    batchsz = 5
    mdfile1 = './ckpy/feature-%d-way-%d-shot.pkl' % (n_way, k_shot)
    mdfile2 = './ckpy/relation-%d-way-%d-shot.pkl' % (n_way, k_shot)

    mini = MiniImagenet('./mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                        k_query=k_query, batchsz=2000, resize=84)  # during training, batchsz = 200
    db = DataLoader(mini, batch_size=batchsz, num_workers=0, pin_memory=False)

    feature_embed = CNNEncoder().cuda()
    Relation_score = RelationNetWork(64, 8).cuda()  # relation_dim == 8?

    if os.path.exists(mdfile1):
        print("feature encoder checkpoint found, loading...")
        feature_embed.load_state_dict(torch.load(mdfile1))
    if os.path.exists(mdfile2):
        print("relation network checkpoint found, loading...")
        Relation_score.load_state_dict(torch.load(mdfile2))

    for ts in range(3):
        total_correct = 0
        total_num = 0
        accuracy = 0
        accuracies = []
        for i, batch in enumerate(db):
            support_x = Variable(batch[0]).cuda()  # [batch_size, n_way*k_shot, c, h, w]
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()    # [b, n_way*k_query]

            bh, set1, c, h, w = support_x.size()
            set2 = query_x.size(1)

            support_xf = feature_embed(support_x.view(bh * set1, c, h, w)).view(bh, set1, 64, 19, 19)
            query_xf = feature_embed(query_x.view(bh * set2, c, h, w)).view(bh, set2, 64, 19, 19)

            support_xf = support_xf.unsqueeze(1).expand(bh, set2, set1, 64, 19, 19)
            query_xf = query_xf.unsqueeze(2).expand(bh, set2, set1, 64, 19, 19)
            comb = torch.cat((support_xf, query_xf), dim=3)
            score = Relation_score(comb.view(bh * set2 * set1, 64 * 2, 19, 19)).view(bh, set2, set1)

            support_y_np = support_y.cpu().data.numpy()
            rn_score_np = score.cpu().data.numpy()  # move scores to CPU as numpy
            pred = []
            for ii, tb in enumerate(rn_score_np):
                for jj, tset in enumerate(tb):
                    sim = []
                    for way in range(n_way):
                        sim.append(np.sum(tset[way * k_shot:(way + 1) * k_shot]))
                    idx = np.array(sim).argmax()
                    # all shots of a class share the same label; note the extra batch dimension.
                    # idx * k_shot is used because the sum above collapsed the k_shot dimension.
                    pred.append(support_y_np[ii, idx * k_shot])
            # at this point pred has shape [bh, set2]
            pred = Variable(torch.from_numpy(np.array(pred).reshape(bh, set2))).cuda()

            correct = torch.eq(pred, query_y).sum()
            total_correct += correct.data[0]
            total_num += query_y.size(0) * query_y.size(1)
            accuracy = total_correct / float(total_num)
            print("epoch", ts, "acc:", accuracy)
            accuracies.append(accuracy)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print("test accuracy:", test_accuracy, "h:", h)
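# mean_confidence_interval() is called above but not defined in this snippet. A minimal
# sketch, assuming the usual scipy-based 95% confidence interval over a list of accuracies
# (the actual helper in the source may differ):
import numpy as np
import scipy.stats


def mean_confidence_interval(accuracies, confidence=0.95):
    # returns (mean accuracy, half-width of the confidence interval)
    a = 1.0 * np.array(accuracies)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h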
def main():
    n_way = 5
    k_shot = 1
    k_query = 15
    batchsz = 5
    best_acc = 0
    mdfile1 = './ckpy/feature-%d-way-%d-shot.pkl' % (n_way, k_shot)
    mdfile2 = './ckpy/relation-%d-way-%d-shot.pkl' % (n_way, k_shot)

    feature_embed = CNNEncoder().cuda()
    Relation_score = RelationNetWork(64, 8).cuda()  # relation_dim == 8?
    feature_embed.apply(weight_init)
    Relation_score.apply(weight_init)

    feature_optim = torch.optim.Adam(feature_embed.parameters(), lr=0.001)
    relation_optim = torch.optim.Adam(Relation_score.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss().cuda()

    if os.path.exists(mdfile1):
        print("load mdfile1...")
        feature_embed.load_state_dict(torch.load(mdfile1))
    if os.path.exists(mdfile2):
        print("load mdfile2...")
        Relation_score.load_state_dict(torch.load(mdfile2))

    for epoch in range(1000):
        mini = MiniImagenet('./mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=1000, resize=84)  # 38400
        db = DataLoader(mini, batch_size=batchsz, shuffle=True, num_workers=4,
                        pin_memory=True)  # 64, 5*(1+15), c, h, w
        mini_val = MiniImagenet('./mini-imagenet/', mode='val', n_way=n_way, k_shot=k_shot,
                                k_query=k_query, batchsz=200, resize=84)  # 9600
        db_val = DataLoader(mini_val, batch_size=batchsz, shuffle=True, num_workers=4,
                            pin_memory=True)

        for step, batch in enumerate(db):
            support_x = Variable(batch[0]).cuda()  # [batch_size, n_way*k_shot, c, h, w]
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()
            bh, set1, c, h, w = support_x.size()
            set2 = query_x.size(1)

            feature_embed.train()
            Relation_score.train()

            # this embedding step is repeated in the test routine
            support_xf = feature_embed(support_x.view(bh * set1, c, h, w)).view(bh, set1, 64, 19, 19)
            query_xf = feature_embed(query_x.view(bh * set2, c, h, w)).view(bh, set2, 64, 19, 19)
            # print("query_f:", query_xf.size())
            support_xf = support_xf.unsqueeze(1).expand(bh, set2, set1, 64, 19, 19)
            query_xf = query_xf.unsqueeze(2).expand(bh, set2, set1, 64, 19, 19)
            comb = torch.cat((support_xf, query_xf), dim=3)  # bh, set2, set1, 2c, h, w
            score = Relation_score(comb.view(bh * set2 * set1, 2 * 64, 19, 19)).view(bh, set2, set1, 1).squeeze(3)

            support_yf = support_y.unsqueeze(1).expand(bh, set2, set1)
            query_yf = query_y.unsqueeze(2).expand(bh, set2, set1)
            label = torch.eq(support_yf, query_yf).float()

            feature_optim.zero_grad()
            relation_optim.zero_grad()
            loss = loss_fn(score, label)
            loss.backward()
            # torch.nn.utils.clip_grad_norm(feature_embed.parameters(), 0.5)   # gradient clipping? lower the learning rate?
            # torch.nn.utils.clip_grad_norm(Relation_score.parameters(), 0.5)
            feature_optim.step()
            relation_optim.step()

            # if step % 100 == 0:
            #     print("step:", epoch + 1, "train_loss:", loss.data[0])
            logger.log_value('{}-way-{}-shot loss:'.format(n_way, k_shot), loss.data[0])

            if step % 200 == 0:
                print("---------test--------")
                total_correct = 0
                total_num = 0
                accuracy = 0
                for j, batch_test in enumerate(db_val):
                    # if j % 100 == 0:
                    #     print(j, '-------------')
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()
                    bh, set1, c, h, w = support_x.size()
                    set2 = query_x.size(1)

                    feature_embed.eval()
                    Relation_score.eval()

                    support_xf = feature_embed(support_x.view(bh * set1, c, h, w)).view(bh, set1, 64, 19, 19)
                    query_xf = feature_embed(query_x.view(bh * set2, c, h, w)).view(bh, set2, 64, 19, 19)
                    support_xf = support_xf.unsqueeze(1).expand(bh, set2, set1, 64, 19, 19)
                    query_xf = query_xf.unsqueeze(2).expand(bh, set2, set1, 64, 19, 19)
                    comb = torch.cat((support_xf, query_xf), dim=3)  # bh, set2, set1, 2c, h, w
                    score = Relation_score(comb.view(bh * set2 * set1, 2 * 64, 19, 19)).view(bh, set2, set1, 1).squeeze(3)

                    rn_score_np = score.cpu().data.numpy()  # move scores to CPU as numpy
                    support_y_np = support_y.cpu().data.numpy()
                    pred = []
                    for ii, tb in enumerate(rn_score_np):
                        for jj, tset in enumerate(tb):
                            sim = []
                            for way in range(n_way):
                                sim.append(np.sum(tset[way * k_shot:(way + 1) * k_shot]))
                            idx = np.array(sim).argmax()
                            # all shots of a class share the same label; note the extra batch dimension.
                            # idx * k_shot because the sum above collapsed the k_shot dimension.
                            pred.append(support_y_np[ii, idx * k_shot])
                    # at this point pred has shape [bh, set2]
                    # print("pred.size=", np.array(pred).shape)
                    pred = Variable(torch.from_numpy(np.array(pred).reshape(bh, set2))).cuda()

                    correct = torch.eq(pred, query_y).sum()
                    total_correct += correct.data[0]
                    total_num += query_y.size(0) * query_y.size(1)

                accuracy = total_correct / float(total_num)
                logger.log_value('acc : ', accuracy)
                print("epoch:", epoch, "acc:", accuracy)
                if accuracy > best_acc:
                    print("-------------------epoch", epoch, "step:", step, "acc:", accuracy,
                          "---------------------------------------")
                    best_acc = accuracy
                    torch.save(feature_embed.state_dict(), mdfile1)
                    torch.save(Relation_score.state_dict(), mdfile2)

            # if step% == 0 and step != 0:
            #     print("%d-way %d-shot %d batch | epoch:%d step:%d, loss:%f" % (n_way, k_shot, batchsz, epoch, step, loss.cpu().data[0]))
            logger.step()
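# weight_init() is applied to both networks above but is not defined in this snippet.
# A minimal sketch, assuming a small-normal init for conv/linear weights and an identity
# init for BatchNorm, written in a form that works on older PyTorch versions too
# (the actual initializer in the source may differ):
def weight_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.01)
        m.bias.data.fill_(0)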
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)

    state_dict = torch.load(Param.out_path + args.ckpt)
    print(state_dict.keys())
    pretrained_dict = OrderedDict()
    for k in state_dict.keys():
        if n_gpus == -1:
            pretrained_dict[k[7:]] = deepcopy(state_dict[k])  # strip the 'module.' prefix left by DataParallel
        else:
            pretrained_dict[k] = deepcopy(state_dict[k])
    maml.load_state_dict(pretrained_dict)
    print("Load from ckpt:", Param.out_path + args.ckpt)

    #opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet
        testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way,
                               k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
        testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4,
                                worker_init_fn=worker_init_fn, drop_last=True)
        test_data = inf_get(testloader)
    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_test.load_data_pkl()

    # Test for 600 iterations (each has 4 tasks)
    ans = None
    maml_clone = deepcopy(maml)
    for itr in range(600):  # 600x4 test tasks
        if args.loader in [0, 1]:
            support_x, support_y, qx, qy = test_data.__next__()
        elif args.loader == 2:
            support_x, support_y, qx, qy = get_data(loader_test)
        support_x, support_y, qx, qy = (support_x.to(Param.device), support_y.to(Param.device),
                                        qx.to(Param.device), qy.to(Param.device))

        temp = maml_clone(support_x, support_y, qx, qy, meta_train=False)
        if ans is None:
            ans = temp
        else:
            ans = torch.cat([ans, temp], dim=0)
        if itr % 100 == 0:
            print(itr, ans.mean(dim=0).tolist())

    meanacc = np.array(ans.mean(dim=0).tolist())
    stdacc = np.array(ans.std(dim=0).tolist())
    ci95 = 1.96 * stdacc / np.sqrt(600)
    print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}')
    with open(Param.out_path + 'test.txt', 'w') as f:
        print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}', file=f)
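# inf_get() is used above (and in the training scripts below) to draw episodes endlessly
# from a DataLoader, but it is not defined in this snippet. A minimal sketch of such a
# generator (the real helper in the source may differ):
def inf_get(loader):
    """Yield batches from `loader` forever, restarting it whenever it is exhausted."""
    while True:
        for batch in loader:
            yield batch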
        # tail of CNNEncoder.forward(x):
        x = self.layer1(x)
        # print("layer 1", x.size())
        x = self.layer2(x)
        # print("layer 2", x.size())
        x = self.layer3(x)
        # print("layer 3", x.size())
        x = self.layer4(x)
        return x  # [bz * (way*(s+q)), 64, 19, 19]


from torch.autograd import Variable

if __name__ == "__main__":
    mini = MiniImagenet(root='./mini-imagenet/', mode='train', batchsz=100,
                        n_way=5, k_shot=5, k_query=5, resize=84, startidx=0)
    for i, m in enumerate(mini):
        support_x, support_y, query_x, query_y = m
        print(i, support_x.size())
        support_x = Variable(support_x).cuda()
        net = CNNEncoder().cuda()
        ans = net(support_x)
        print(ans.size())
        print("--------")
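# The layer1..layer4 blocks used in the forward pass above are defined elsewhere in the
# source. A plausible sketch of the encoder, assuming the standard four-block CNN from the
# Relation Network ("Learning to Compare") setup, which maps 3x84x84 inputs to the
# 64x19x19 feature maps expected by the rest of the code (the class name here is
# hypothetical; the source uses CNNEncoder):
import torch.nn as nn


class CNNEncoderSketch(nn.Module):
    def __init__(self):
        super(CNNEncoderSketch, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2))   # 84 -> 41
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2))   # 41 -> 19
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))                    # 19 -> 19
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))                    # 19 -> 19

    def forward(self, x):
        return self.layer4(self.layer3(self.layer2(self.layer1(x))))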
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)

    config = [('conv2d', [32, 3, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]),
              ('relu', [True]),
              ('max_pool2d', [2, 2, 0]),
              ('flatten', []),
              ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                           k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True,
                             num_workers=4, worker_init_fn=worker_init_fn, drop_last=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=True, num_workers=1,
                            worker_init_fn=worker_init_fn, drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    best_acc = 0.0
    if not os.path.exists('ckpt/{}'.format(args.exp)):
        os.mkdir('ckpt/{}'.format(args.exp))

    for epoch in range(args.epoch):
        np.random.seed()
        x_spt, y_spt, x_qry, y_qry = train_data.__next__()
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)

        if epoch % 2500 == 0:  # evaluation
            # save checkpoint
            torch.save(maml.state_dict(), 'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = test_data.__next__()
                x_spt, y_spt, x_qry, y_qry = (x_spt.squeeze().to(device), y_spt.squeeze().to(device),
                                              x_qry.squeeze().to(device), y_qry.squeeze().to(device))
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)  # [b, update_step+1]

            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]), file=f)
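# worker_init_fn is passed to the DataLoaders above but is not defined in this snippet.
# A minimal sketch, assuming it only needs to give each DataLoader worker its own numpy
# seed so that episode sampling differs across workers (the actual helper may differ):
import numpy as np
import torch


def worker_init_fn(worker_id):
    # derive a distinct, reproducible numpy seed for each worker
    np.random.seed((torch.initial_seed() + worker_id) % (2 ** 32))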
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet
        trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way,
                                k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
        testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way,
                               k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
        trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True,
                                 num_workers=4, worker_init_fn=worker_init_fn, drop_last=True)
        testloader = DataLoader(testset, batch_size=1, shuffle=True, num_workers=1,
                                worker_init_fn=worker_init_fn, drop_last=True)
        train_data = inf_get(trainloader)
        test_data = inf_get(testloader)
    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_train = dataset_mini(600, 100, 'train', args_data)
        #loader_val = dataset_mini(600, 100, 'val', args_data)
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_train.load_data_pkl()
        #loader_val.load_data_pkl()
        loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        np.random.seed()
        if args.loader in [0, 1]:
            support_x, support_y, meta_x, meta_y = train_data.__next__()
        elif args.loader == 2:
            support_x, support_y, meta_x, meta_y = get_data(loader_train)
        support_x, support_y, meta_x, meta_y = (support_x.to(Param.device), support_y.to(Param.device),
                                                meta_x.to(Param.device), meta_y.to(Param.device))

        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if epoch % 2500 == 0:
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                if args.loader in [0, 1]:
                    support_x, support_y, qx, qy = test_data.__next__()
                elif args.loader == 2:
                    support_x, support_y, qx, qy = get_data(loader_test)
                support_x, support_y, qx, qy = (support_x.to(Param.device), support_y.to(Param.device),
                                                qx.to(Param.device), qy.to(Param.device))
                temp = maml_clone(support_x, support_y, qx, qy, meta_train=False)
                if ans is None:
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)

            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if ans[-1] > best_acc:
                best_acc = ans[-1]
                torch.save(maml.state_dict(),
                           Param.out_path + 'net_' + str(epoch) + '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)

        if (epoch < 5) or (epoch % 100 == 0):
            plot.flush()
        plot.tick()
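# The `plot` module used above (plot.plot / plot.flush / plot.tick) is an external logging
# helper that is not shown here. A minimal stand-in with the same three calls, assuming it
# only needs to accumulate scalars per iteration and print their means on flush (the real
# module may also write log files or figures); it would live in a module named `plot`:
import collections

_since_last_flush = collections.defaultdict(list)
_iter = [0]


def plot(name, value):
    # record one scalar value under `name` for the current iteration window
    _since_last_flush[name].append(value)


def flush():
    # print the mean of every recorded scalar since the last flush, then reset
    for name, values in _since_last_flush.items():
        print('iter {}\t{}\t{:.5f}'.format(_iter[0], name, sum(values) / len(values)))
    _since_last_flush.clear()


def tick():
    # advance the global iteration counter
    _iter[0] += 1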