import itertools
import os

import torch

# `dataloader`, `PointNet`, `test`, and the argparse namespace `args` are
# assumed to be defined elsewhere in this repository.


def train(modelin=args.model, modelout=args.out, device=args.device,
          opt=args.opt, ft=args.ft):
    # define model, dataloader, 3dmm eigenvectors, optimization method
    calib_net = PointNet(n=1, feature_transform=ft)
    sfm_net = PointNet(n=199, feature_transform=ft)
    if modelin != "":
        calib_path = os.path.join('model', 'calib_' + modelin)
        sfm_path = os.path.join('model', 'sfm_' + modelin)
        pretrained1 = torch.load(calib_path)
        pretrained2 = torch.load(sfm_path)
        calib_dict = calib_net.state_dict()
        sfm_dict = sfm_net.state_dict()
        pretrained1 = {k: v for k, v in pretrained1.items() if k in calib_dict}
        pretrained2 = {k: v for k, v in pretrained2.items() if k in sfm_dict}
        calib_dict.update(pretrained1)
        sfm_dict.update(pretrained2)
        # load the merged dicts (not the raw checkpoint subsets) so keys
        # absent from the checkpoint keep their freshly initialized values
        calib_net.load_state_dict(calib_dict)
        sfm_net.load_state_dict(sfm_dict)
    calib_net.to(device=device)
    sfm_net.to(device=device)
    opt1 = torch.optim.Adam(calib_net.parameters(), lr=1e-3)
    opt2 = torch.optim.Adam(sfm_net.parameters(), lr=1e-3)

    # dataloader
    loader = dataloader.SyntheticLoader()
    batch_size = 100
    M = loader.M
    N = loader.N

    # mean shape and eigenvectors for 3dmm; 300 = 3 stacked 100-frame videos
    mu_lm = torch.from_numpy(loader.mu_lm).float()
    mu_lm[:, 2] = mu_lm[:, 2] * -1
    mu_lm = torch.stack(300 * [mu_lm.to(device=device)])
    shape = mu_lm
    lm_eigenvec = torch.from_numpy(loader.lm_eigenvec).float().to(device=device)
    sigma = torch.from_numpy(loader.sigma).float().detach().to(device=device)
    sigma = torch.diag(sigma.squeeze())
    # scale eigenvectors by their standard deviations so the predicted
    # coefficients are expressed in units of standard deviations
    lm_eigenvec = torch.mm(lm_eigenvec, sigma)
    lm_eigenvec = torch.stack(300 * [lm_eigenvec])

    # main training loop
    best = 10000
    for epoch in itertools.count():
        for i in range(len(loader)):
            if i < 3:
                continue
            v1 = loader[i]
            v2 = loader[i - 1]
            v3 = loader[i - 2]
            batch = stackVideos([v1, v2, v3], 100, 68, device=device)

            # get the input and gt values
            alpha_gt = batch['alpha']
            x_cam_gt = batch['x_cam_gt']
            shape_gt = batch['x_w_gt']
            fgt = batch['f_gt']
            x = batch['x_img']
            M = x.shape[0]
            N = x.shape[-1]

            # calibration: predict focal length and build the intrinsic matrix
            f = torch.squeeze(calib_net(x) + 300)
            K = torch.zeros((M, 3, 3)).float().to(device=device)
            K[:, 0, 0] = f
            K[:, 1, 1] = f
            K[:, 2, 2] = 1

            # sfm: predict 3dmm coefficients and reconstruct the shape
            alpha = sfm_net(x)
            alpha = alpha.unsqueeze(-1)
            shape = mu_lm + torch.bmm(lm_eigenvec, alpha).squeeze().view(M, N, 3)
            # zero-center each of the three videos independently
            shape[0:100] = shape[0:100] - shape[0:100].mean(1).unsqueeze(1)
            shape[100:200] = shape[100:200] - shape[100:200].mean(1).unsqueeze(1)
            shape[200:300] = shape[200:300] - shape[200:300].mean(1).unsqueeze(1)

            opt1.zero_grad()
            opt2.zero_grad()
            f1_error = torch.mean(torch.abs(f[0:100] - fgt[0]))
            f2_error = torch.mean(torch.abs(f[100:200] - fgt[1]))
            f3_error = torch.mean(torch.abs(f[200:300] - fgt[2]))
            #a1_error = torch.mean(torch.abs(alpha[0:100] - alpha_gt[0]))
            #a2_error = torch.mean(torch.abs(alpha[100:200] - alpha_gt[1]))
            #a3_error = torch.mean(torch.abs(alpha[200:300] - alpha_gt[2]))
            s1_error = torch.mean(torch.abs(shape[0:100] - shape_gt[0].unsqueeze(0)))
            s2_error = torch.mean(torch.abs(shape[100:200] - shape_gt[1].unsqueeze(0)))
            s3_error = torch.mean(torch.abs(shape[200:300] - shape_gt[2].unsqueeze(0)))

            ferror = f1_error + f2_error + f3_error
            #aerror = a1_error + a2_error + a3_error
            serror = s1_error + s2_error + s3_error

            #error = ferror + aerror
            error = ferror + serror
            error.backward()
            opt1.step()
            opt2.step()
            print(f"iter: {i} | best: {best:.2f} | "
                  f"f_error: {ferror.item():.3f} | serror: {serror.item():.3f}")
            if i == 1000:
                break

        # checkpoint, evaluate, and keep the best model by focal-length error
        torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_model.pt'))
        torch.save(calib_net.state_dict(), os.path.join('model', 'calib_model.pt'))
        ferror = test(modelin='model.pt', outfile=args.out, optimize=False)
        if ferror < best:
            best = ferror
            print("saving!")
            torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_' + modelout))
            torch.save(calib_net.state_dict(), os.path.join('model', 'calib_' + modelout))
        sfm_net.train()
        calib_net.train()
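# The loop above relies on a `stackVideos` helper that is not shown. Below is
# a minimal sketch of what it plausibly does, inferred from how `batch` is
# consumed: stack three 100-frame videos of 68 landmarks into one 300-sample
# batch. The key split, shapes, and dtype/device handling here are
# assumptions, not the repository's actual implementation.
def stackVideos(videos, frames, landmarks, device='cpu'):
    batch = {}
    per_video = ['alpha', 'x_w_gt', 'f_gt']   # one ground-truth entry per video
    per_frame = ['x_img', 'x_cam_gt']         # frame-level entries, concatenated
    for k in per_video:
        batch[k] = torch.stack(
            [torch.as_tensor(v[k]).float() for v in videos]).to(device=device)
    for k in per_frame:
        batch[k] = torch.cat(
            [torch.as_tensor(v[k]).float() for v in videos], dim=0).to(device=device)
    return batch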
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

# `Jet_PointCloudDataSet`, `PointNet`, `FrameImageModel`, `JetClassifier`,
# and the `cal_loss` criterion are assumed to come from this repository's
# own modules.


def train(args, io):
    file_list = [
        'preprocess1_HHESTIA_ttbar',
        'preprocess2_HHESTIA_RadionToZZ',
        'preprocess_HHESTIA_QCD1800To2400',
        'preprocess0_HHESTIA_HH_4B',
        'preprocess2_HHESTIA_ZprimeWW',
    ]
    train_grp_list = ['train_group' for _ in range(len(file_list))]
    test_grp_list = ['test_group' for _ in range(len(file_list))]
    data_dir = '/afs/crc.nd.edu/user/a/adas/DGCNN'
    pcl_train_dataset = Jet_PointCloudDataSet(file_list, train_grp_list, data_dir)
    pcl_test_dataset = Jet_PointCloudDataSet(file_list, test_grp_list, data_dir)
    train_loader = DataLoader(pcl_train_dataset, num_workers=8,
                              batch_size=args.batch_size, shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(pcl_test_dataset, num_workers=8,
                             batch_size=args.test_batch_size, shuffle=True,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    if args.model == 'pointnet':
        model = PointNet(args).to(device)
    elif args.model == 'dgcnn':
        model = FrameImageModel(args).to(device)
    elif args.model == 'jetClassifier':
        model = JetClassifier(args).to(device)
    else:
        raise Exception("Not implemented")

    model = model.double()
    model = nn.DataParallel(model)
    torch.save(model.state_dict(), 'model.pt')
    print("Let's use", torch.cuda.device_count(), "GPUs!")

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(), lr=args.lr * 100,
                        momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)
    criterion = cal_loss

    best_test_acc = 0
    record_file = open("note.txt", "w")
    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        print("########## training on epoch number ------>> ", epoch)
        train_loss = 0.0
        count = 0.0
        model.train()
        train_pred = []
        train_true = []
        batch_number = 1
        for data, label in train_loader:
            data, label = data.to(device), label.to(device).squeeze()
            data = data.permute(0, 2, 1).double()
            batch_size = data.size()[0]
            opt.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            batch_loss = loss.detach().cpu().numpy()
            print("### batch number ", batch_number, " ", loss)
            # one-hot encoding of the 5 classes (currently unused downstream)
            binarized_label = label_binarize(label.cpu().numpy(),
                                             classes=[0, 1, 2, 3, 4])
            batch_number = batch_number + 1
            loss.backward()
            opt.step()
            preds = logits.max(dim=1)[1]
            batch_label = label.detach().cpu().numpy()
            batch_preds = preds.detach().cpu().numpy()
            batch_acc = metrics.accuracy_score(batch_label, batch_preds)
            balanced_batch_acc = metrics.balanced_accuracy_score(
                batch_label, batch_preds)
            print("### batch accuracy scores ", batch_acc, " ", balanced_batch_acc)
            # trailing newline added so each batch record lands on its own line
            record_file.write(str(epoch) + " " + str(batch_number) + " " +
                              str(batch_loss) + " " + str(batch_acc) + " " +
                              str(balanced_batch_acc) + "\n")
            count += batch_size
            train_loss += loss.item() * batch_size
            train_true.append(label.cpu().numpy())
            train_pred.append(preds.detach().cpu().numpy())
        # step the LR schedule once per epoch, after the optimizer steps,
        # as current PyTorch versions expect
        scheduler.step()
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        accuracy_score = metrics.accuracy_score(train_true, train_pred)
        balanced_accuracy_score = metrics.balanced_accuracy_score(
            train_true, train_pred)
        print(accuracy_score, balanced_accuracy_score)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (
            epoch, train_loss * 1.0 / count, accuracy_score,
            balanced_accuracy_score)
        io.cprint(outstr)

        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_pred = []
        test_true = []
        with torch.no_grad():  # no gradients needed during evaluation
            for data, label in test_loader:
                data, label = data.to(device), label.to(device).squeeze()
                # cast to double here too: the model was converted with .double()
                data = data.permute(0, 2, 1).double()
                batch_size = data.size()[0]
                logits = model(data)
                loss = criterion(logits, label)
                preds = logits.max(dim=1)[1]
                batch_label = label.detach().cpu().numpy()
                batch_preds = preds.detach().cpu().numpy()
                batch_acc = metrics.accuracy_score(batch_label, batch_preds)
                balanced_batch_acc = metrics.balanced_accuracy_score(
                    batch_label, batch_preds)
                print("### test batch accuracy scores ", batch_acc, " ",
                      balanced_batch_acc)
                count += batch_size
                test_loss += loss.item() * batch_size
                test_true.append(label.cpu().numpy())
                test_pred.append(preds.detach().cpu().numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        test_acc = metrics.accuracy_score(test_true, test_pred)
        avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc)
        io.cprint(outstr)
        if test_acc >= best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(),
                       'checkpoints/%s/models/model.t7' % args.exp_name)
    record_file.close()
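# `cal_loss` is used as the criterion above but is not defined here. A common
# definition in public DGCNN training code is cross-entropy with label
# smoothing; this is a sketch under that assumption, not necessarily the
# exact function this repository uses.
import torch.nn.functional as F

def cal_loss(pred, gold, smoothing=True, eps=0.2):
    gold = gold.contiguous().view(-1)
    if smoothing:
        n_class = pred.size(1)
        # spread eps of the probability mass uniformly over the wrong classes
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')
    return loss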
# This variant assumes the same imports and project modules as the first
# function above.
def train(modelin=args.model, modelout=args.out, device=args.device, opt=args.opt):
    # define model, dataloader, 3dmm eigenvectors, optimization method
    calib_net = PointNet(n=1)
    sfm_net = PointNet(n=199)
    if modelin != "":
        calib_path = os.path.join('model', 'calib_' + modelin)
        sfm_path = os.path.join('model', 'sfm_' + modelin)
        pretrained1 = torch.load(calib_path)
        pretrained2 = torch.load(sfm_path)
        calib_dict = calib_net.state_dict()
        sfm_dict = sfm_net.state_dict()
        pretrained1 = {k: v for k, v in pretrained1.items() if k in calib_dict}
        pretrained2 = {k: v for k, v in pretrained2.items() if k in sfm_dict}
        calib_dict.update(pretrained1)
        sfm_dict.update(pretrained2)
        # load the merged dicts (not the raw checkpoint subsets)
        calib_net.load_state_dict(calib_dict)
        sfm_net.load_state_dict(sfm_dict)
    calib_net.to(device=device)
    sfm_net.to(device=device)
    opt1 = torch.optim.Adam(calib_net.parameters(), lr=1e-4)
    opt2 = torch.optim.Adam(sfm_net.parameters(), lr=1e-2)

    # dataloader
    data = dataloader.Data()
    loader = data.batchloader
    batch_size = data.batchsize

    # mean shape and eigenvectors for 3dmm
    mu_lm = torch.from_numpy(data.mu_lm).float()
    mu_lm[:, 2] = mu_lm[:, 2] * -1
    mu_lm = torch.stack(batch_size * [mu_lm.to(device=device)])
    shape = mu_lm
    lm_eigenvec = torch.from_numpy(data.lm_eigenvec).float().to(device=device)
    lm_eigenvec = torch.stack(batch_size * [lm_eigenvec])

    M = data.M
    N = data.N

    # main training loop
    best = 10000
    for epoch in itertools.count():
        for j, batch in enumerate(loader):
            # get the input and gt values
            x_cam_gt = batch['x_cam_gt'].to(device=device)
            shape_gt = batch['x_w_gt'].to(device=device)
            fgt = batch['f_gt'].to(device=device)
            x_img = batch['x_img'].to(device=device)
            #beta_gt = batch['beta_gt'].to(device=device)
            #x_img_norm = batch['x_img_norm']
            x_img_gt = batch['x_img_gt'].to(device=device).permute(0, 2, 1, 3)
            batch_size = fgt.shape[0]

            one = torch.ones(batch_size, M * N, 1).to(device=device)
            x_img_one = torch.cat([x_img, one], dim=2)
            x_cam_pt = x_cam_gt.permute(0, 1, 3, 2).reshape(batch_size, 6800, 3)
            x = x_img.permute(0, 2, 1)
            #x = x_img.permute(0,2,1).reshape(batch_size,2,M,N)
            ptsI = x_img_one.reshape(batch_size, M, N, 3).permute(0, 1, 3, 2)[:, :, :2, :]

            # standard supervised step (skipped when test-time optimization
            # is enabled via `opt`)
            if not opt:
                # calibration: predict focal length and build the intrinsics
                f = calib_net(x) + 300
                K = torch.zeros((batch_size, 3, 3)).float().to(device=device)
                K[:, 0, 0] = f.squeeze()
                K[:, 1, 1] = f.squeeze()
                K[:, 2, 2] = 1

                # sfm: predict 3dmm coefficients and reconstruct the shape
                betas = sfm_net(x)
                betas = betas.unsqueeze(-1)
                shape = mu_lm + torch.bmm(lm_eigenvec, betas).squeeze().view(batch_size, N, 3)

                opt1.zero_grad()
                opt2.zero_grad()
                f_error = torch.mean(torch.abs(f - fgt))
                #error2d = torch.mean(torch.abs(pred - x_img_gt))
                error3d = torch.mean(torch.abs(shape - shape_gt))
                error = f_error + error3d
                error.backward()
                opt1.step()
                opt2.step()
                print(f"{best:.2f} | f_error: {f_error.item():.3f} | "
                      f"error3d: {error3d.item():.3f} | "
                      f"f/fgt: {f[0].item():.1f}/{fgt[0].item():.1f} | "
                      f"f/fgt: {f[1].item():.1f}/{fgt[1].item():.1f} | "
                      f"f/fgt: {f[2].item():.1f}/{fgt[2].item():.1f} | "
                      f"f/fgt: {f[3].item():.1f}/{fgt[3].item():.1f}")
                continue

        # checkpoint, evaluate, and keep the best model by focal-length error
        torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_model.pt'))
        torch.save(calib_net.state_dict(), os.path.join('model', 'calib_model.pt'))
        ferror = test(modelin='model.pt', outfile=args.out, optimize=False)
        if ferror < best:
            best = ferror
            print("saving!")
            torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_' + modelout))
            torch.save(calib_net.state_dict(), os.path.join('model', 'calib_' + modelout))
        sfm_net.train()
        calib_net.train()
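# The commented-out `error2d` term in the variant above presupposes a 2D
# reprojection `pred`, which would come from projecting camera-space points
# through the predicted intrinsics K. A sketch of standard pinhole projection
# (the helper name `project` and the (B, N, 3) input layout are assumptions):
def project(x_cam, K):
    proj = torch.bmm(K, x_cam.permute(0, 2, 1))  # (B, 3, N) homogeneous coords
    proj = proj / proj[:, 2:3, :]                # perspective divide by depth
    return proj[:, :2, :].permute(0, 2, 1)       # (B, N, 2) pixel coordinates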
# This variant also assumes the same imports and project modules as the first
# function above.
def train(modelin=args.model, modelout=args.out, device=args.device,
          opt=args.opt, ft=args.ft):
    # define model, dataloader, 3dmm eigenvectors, optimization method
    calib_net = PointNet(n=1, feature_transform=ft)
    sfm_net = PointNet(n=199, feature_transform=ft)
    if modelin != "":
        calib_path = os.path.join('model', 'calib_' + modelin)
        sfm_path = os.path.join('model', 'sfm_' + modelin)
        pretrained1 = torch.load(calib_path)
        pretrained2 = torch.load(sfm_path)
        calib_dict = calib_net.state_dict()
        sfm_dict = sfm_net.state_dict()
        pretrained1 = {k: v for k, v in pretrained1.items() if k in calib_dict}
        pretrained2 = {k: v for k, v in pretrained2.items() if k in sfm_dict}
        calib_dict.update(pretrained1)
        sfm_dict.update(pretrained2)
        # load the merged dicts (not the raw checkpoint subsets)
        calib_net.load_state_dict(calib_dict)
        sfm_net.load_state_dict(sfm_dict)
    calib_net.to(device=device)
    sfm_net.to(device=device)
    opt1 = torch.optim.Adam(calib_net.parameters(), lr=1e-3)
    opt2 = torch.optim.Adam(sfm_net.parameters(), lr=1e-3)

    # dataloader
    loader = dataloader.SyntheticLoader()
    batch_size = 100
    M = loader.M
    N = loader.N

    # mean shape and eigenvectors for 3dmm
    mu_lm = torch.from_numpy(loader.mu_lm).float()
    mu_lm[:, 2] = mu_lm[:, 2] * -1
    mu_lm = torch.stack(batch_size * [mu_lm.to(device=device)])
    shape = mu_lm
    lm_eigenvec = torch.from_numpy(loader.lm_eigenvec).float().to(device=device)
    sigma = torch.from_numpy(loader.sigma).float().detach().to(device=device)
    sigma = torch.diag(sigma.squeeze())
    # scale eigenvectors by their standard deviations so the predicted
    # coefficients are expressed in units of standard deviations
    lm_eigenvec = torch.mm(lm_eigenvec, sigma)
    lm_eigenvec = torch.stack(M * [lm_eigenvec])

    # main training loop
    best = 10000
    for epoch in itertools.count():
        for j, batch in enumerate(loader):
            # get the input and gt values
            x_cam_gt = batch['x_cam_gt'].to(device=device)
            shape_gt = batch['x_w_gt'].to(device=device)
            fgt = batch['f_gt'].to(device=device)
            x_img = batch['x_img'].to(device=device)
            #beta_gt = batch['beta_gt'].to(device=device)
            #x_img_norm = batch['x_img_norm']
            #x_img_gt = batch['x_img_gt'].to(device=device).permute(0,2,1,3)
            x = x_img.reshape(M, N, 2).permute(0, 2, 1)
            batch_size = fgt.shape[0]
            #x_cam_pt = x_cam_gt.permute(0,1,3,2).reshape(batch_size,6800,3)
            #x = x_img.permute(0,2,1).reshape(batch_size,2,M,N)
            #ptsI = x_img_one.reshape(batch_size,M,N,3).permute(0,1,3,2)[:,:,:2,:]

            # calibration: predict focal length and build the intrinsic matrix
            f = torch.squeeze(calib_net(x) + 300)
            K = torch.zeros((M, 3, 3)).float().to(device=device)
            K[:, 0, 0] = f
            K[:, 1, 1] = f
            K[:, 2, 2] = 1

            # sfm: predict 3dmm coefficients, reconstruct, and zero-center
            betas = sfm_net(x)
            betas = betas.unsqueeze(-1)
            shape = mu_lm + torch.bmm(lm_eigenvec, betas).squeeze().view(M, N, 3)
            shape = shape - shape.mean(1).unsqueeze(1)

            opt1.zero_grad()
            opt2.zero_grad()
            f_error = torch.mean(torch.abs(f - fgt))
            #error2d = torch.mean(torch.abs(pred - x_img_gt))
            error3d = torch.mean(torch.norm(shape - shape_gt, dim=2))
            error = f_error + error3d
            error.backward()
            opt1.step()
            opt2.step()
            print(f"iter: {j} | best: {best:.2f} | "
                  f"f_error: {f_error.item():.3f} | error3d: {error3d.item():.3f}")
            if j == 1000:
                break

        # checkpoint, evaluate, and keep the best model by focal-length error
        torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_model.pt'))
        torch.save(calib_net.state_dict(), os.path.join('model', 'calib_model.pt'))
        ferror = test(modelin='model.pt', outfile=args.out, optimize=False)
        if ferror < best:
            best = ferror
            print("saving!")
            torch.save(sfm_net.state_dict(), os.path.join('model', 'sfm_' + modelout))
            torch.save(calib_net.state_dict(), os.path.join('model', 'calib_' + modelout))
        sfm_net.train()
        calib_net.train()
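# A minimal entry point, assuming `args` is an argparse namespace constructed
# at module level (its fields already supply the keyword defaults above):
if __name__ == '__main__':
    train()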