class trainer(BaseTrain):

    def init(self):
        # Create training and testing datasets
        self.train_ds = MGNDataset(self.options, split='train')
        self.test_ds = MGNDataset(self.options, split='test')

        # The test data loader is fixed; shuffling is disabled as it is unnecessary.
        self.test_data_loader = CheckpointDataLoader(
            self.test_ds,
            batch_size=self.options.batch_size,
            num_workers=self.options.num_workers,
            pin_memory=self.options.pin_memory,
            shuffle=False)

        # Create the SMPL mesh (graph) object for the GCN
        self.mesh = Mesh(self.options, self.options.num_downsampling)
        self.faces = torch.cat(
            self.options.batch_size * [self.mesh.faces.to(self.device)[None]],
            dim=0)
        self.faces = self.faces.type(torch.LongTensor).to(self.device)

        # Create the SMPL blending model and edges
        self.smpl = SMPL(self.options.smpl_model_path, self.device)
        # self.smplEdge = torch.Tensor(np.load(self.options.smpl_edges_path)) \
        #                      .long().to(self.device)

        # Create the SMPL+D blending model
        self.smplD = Smpl(self.options.smpl_model_path)

        # Read the SMPL .obj file to get the UV coordinates
        _, self.smpl_tri_ind, uv_coord, tri_uv_ind = \
            read_Obj(self.options.smpl_objfile_path)
        uv_coord[:, 1] = 1 - uv_coord[:, 1]
        expUV = uv_coord[tri_uv_ind.flatten()]
        unique, index = np.unique(self.smpl_tri_ind.flatten(), return_index=True)
        self.smpl_verts_uvs = torch.as_tensor(expUV[index, :]).float().to(self.device)
        self.smpl_tri_ind = torch.as_tensor(self.smpl_tri_ind).to(self.device)

        # Camera for projection
        self.perspCam = perspCamera()

        # Mean and std of the displacements
        self.dispPara = \
            torch.Tensor(np.load(self.options.MGN_offsMeanStd_path)).to(self.device)

        # Load the average pose and shape and convert them to the camera
        # coordinate system; the average pose is decided by the image id we
        # use for training (0-11).
        avgPose_objCoord = np.load(self.options.MGN_avgPose_path)
        avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(    # for 0 and 6, front only
            torch.tensor([[[1,  0,  0],
                           [0, -1,  0],
                           [0,  0, -1]]]))
        self.avgPose = \
            axisAngle_to_Rot6d(
                torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
            ).reshape(1, -1).to(self.device)
        self.avgBeta = \
            torch.Tensor(
                np.load(self.options.MGN_avgBeta_path)[None]).to(self.device)
        # 1.2755 is the camera scale for our settings
        self.avgCam = torch.Tensor([1.2755, 0, 0])[None].to(self.device)

        self.model = frameVIBE(
            self.options.smpl_model_path,
            self.mesh,
            self.avgPose,
            self.avgBeta,
            self.avgCam,
            self.options.num_channels,
            self.options.num_layers,
            self.smpl_verts_uvs,
            self.smpl_tri_ind
        ).to(self.device)

        # Set up an optimizer for the model
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        self.criterion = VIBELoss(
            e_loss_weight=50,        # for 2D keypoints; helps estimate camera and global orientation
            e_3d_loss_weight=1,      # for 3D keypoints and body vertices
            e_pose_loss_weight=10,   # for body pose parameters
            e_shape_loss_weight=1,   # for body shape parameters
            e_disp_loss_weight=1,    # for displacements
            e_tex_loss_weight=1,     # for the UV image
            d_motion_loss_weight=0)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {self.options.model: self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}

    def convert_GT(self, input_batch):
        smplGT = torch.cat([
            # GTcamera is not used; this is a placeholder
            torch.zeros([self.options.batch_size, 3]).to(self.device),
            # The global rotation is incorrect as it is under the original
            # coordinate system. The rest of the pose and the betas are fine.
            rot6d_to_axisAngle(input_batch['GTsmplParas']['pose']).reshape(-1, 72),    # 24 * 6 = 144
            input_batch['GTsmplParas']['betas']],
            dim=1).float()

        # Vertices in the original coordinates
        vertices = self.smpl(
            pose=rot6d_to_axisAngle(
                input_batch['GTsmplParas']['pose']).reshape(self.options.batch_size, 72),
            beta=input_batch['GTsmplParas']['betas'].float(),
            trans=input_batch['GTsmplParas']['trans'])

        # Get the joints in 3D and 2D in the current camera coordinates
        joints_3d = self.smpl.get_joints(vertices.float())
        joints_2d, _, joints_3d = self.perspCam(
            fx=input_batch['GTcamera']['f_rot'][:, 0, 0],
            fy=input_batch['GTcamera']['f_rot'][:, 0, 0],
            cx=112,
            cy=112,
            rotation=rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:, 0, 1:]).float(),
            translation=input_batch['GTcamera']['t'][:, None, :].float(),
            points=joints_3d,
            visibilityOn=False,
            output3d=True)
        # Remove the camera shift
        joints_3d = (joints_3d - input_batch['GTcamera']['t'][:, None, :]).float()

        # Debug visualization of the projected joints:
        # visind = 1
        # img = torch.zeros([224, 224])
        # joints_2d[visind][joints_2d[visind].round() >= 224] = 223
        # joints_2d[visind][joints_2d[visind].round() < 0] = 0
        # img[joints_2d[visind][:, 1].round().long(),
        #     joints_2d[visind][:, 0].round().long()] = 1
        # plt.imshow(img)
        # plt.imshow(input_batch['img'][visind].cpu().permute(1, 2, 0))

        # Convert to [-1, +1]
        joints_2d = (torch.cat(joints_2d, dim=0)
                     .reshape([self.options.batch_size, 24, 2]) - 112) / 112
        joints_2d = joints_2d.float()

        # Convert the vertices to the current camera coordinates
        _, _, vertices = self.perspCam(
            fx=input_batch['GTcamera']['f_rot'][:, 0, 0],
            fy=input_batch['GTcamera']['f_rot'][:, 0, 0],
            cx=112,
            cy=112,
            rotation=rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:, 0, 1:]).float(),
            translation=input_batch['GTcamera']['t'][:, None, :].float(),
            points=vertices,
            visibilityOn=False,
            output3d=True)
        vertices = (vertices - input_batch['GTcamera']['t'][:, None, :]).float()

        # Sanity check for the correctness of the coordinate system:
        # points = vertices (vertices before the shift)
        # localcam = perspCamera(smpl_obj=False)    # disable the additional translation
        # img_points, _ = localcam(
        #     fx=input_batch['GTcamera']['f_rot'][:, 0, 0],
        #     fy=input_batch['GTcamera']['f_rot'][:, 0, 0],
        #     cx=112,
        #     cy=112,
        #     rotation=torch.eye(3)[None].repeat_interleave(points.shape[0], dim=0).to('cuda'),
        #     translation=torch.zeros([points.shape[0], 1, 3]).to('cuda'),
        #     points=points.float(),
        #     visibilityOn=False,
        #     output3d=False)
        # img = torch.zeros([224, 224])
        # img_points[visind][img_points[visind].round() >= 224] = 223
        # img_points[visind][img_points[visind].round() < 0] = 0
        # img[img_points[visind][:, 1].round().long(),
        #     img_points[visind][:, 0].round().long()] = 1
        # plt.imshow(img)

        # Sanity check for the correctness of the displacements:
        # import open3d as o3d
        # points = vertices
        # ind = 0
        # o3dMesh = o3d.geometry.TriangleMesh()
        # o3dMesh.vertices = o3d.utility.Vector3dVector(points[ind].cpu())
        # o3dMesh.triangles = o3d.utility.Vector3iVector(self.faces[0].cpu())
        # o3dMesh.compute_vertex_normals()
        # o3d.visualization.draw_geometries([o3dMesh])
        # self.renderer = renderer(batch_size=1)
        # images = self.renderer(
        #     verts=vertices[0][None],
        #     faces=self.faces[0][None],
        #     verts_uvs=self.smpl_verts_uvs[None].repeat_interleave(2, dim=0)[0][None],
        #     faces_uvs=self.smpl_tri_ind[None].repeat_interleave(2, dim=0)[0][None],
        #     tex_image=tex[None],
        #     R=torch.eye(3)[None].repeat_interleave(2, dim=0).to('cuda')[0][None],
        #     T=input_batch['GTcamera']['t'].float()[0][None],
        #     f=input_batch['GTcamera']['f_rot'][:, 0, 0][:, None].float()[0][None],
        #     C=torch.ones([2, 2]).to('cuda')[0][None] * 112,
        #     imgres=224)

        GT = {'img': input_batch['img'].float(),
              'img_orig': input_batch['img_orig'].float(),
              'imgname': input_batch['imgname'],
              'camera': input_batch['GTcamera'],    # wcd
              'theta': smplGT.float(),
              # joints_2d is (col, row) == (x, y); joints_3d is (x, y, z)
              'target_2d': joints_2d.float(),
              'target_3d': joints_3d.float(),
              'target_bvt': vertices.float(),       # body vertices
              'target_dp': input_batch['GToffsets_t']['offsets'].float(),  # SMPL+D t-pose offsets
              'target_uv': input_batch['GTtextureMap'].float()}
        return GT

    def train_step(self, input_batch):
        """Training step."""
        self.model.train()

        # Prepare data
        GT = self.convert_GT(input_batch)

        # Forward pass
        pred = self.model(GT['img'], GT['img_orig'])

        # Loss
        gen_loss, loss_dict = self.criterion(
            generator_outputs=pred,
            data_2d=GT['target_2d'],
            data_3d=GT)
        # print(gen_loss)

        out_args = loss_dict
        out_args['loss'] = gen_loss
        out_args['prediction'] = pred

        # Save one training sample for visualization
        if (self.step_count + 1) % self.options.summary_steps == 0:
            with torch.no_grad():
                self.save_sample(GT, pred[0], saveTo='train')

        return out_args

    def test(self):
        """Testing process."""
        self.model.eval()

        test_loss, test_tex_loss = 0, 0
        test_loss_pose, test_loss_shape = 0, 0
        test_loss_kp_2d, test_loss_kp_3d = 0, 0
        test_loss_dsp_3d, test_loss_bvt_3d = 0, 0
        for step, batch in enumerate(
                tqdm(self.test_data_loader,
                     desc='Test',
                     total=len(self.test_ds) // self.options.batch_size,
                     initial=self.test_data_loader.checkpoint_batch_idx),
                self.test_data_loader.checkpoint_batch_idx):

            # Move the batch to the target device
            batch_toDEV = {}
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch_toDEV[key] = val.to(self.device)
                else:
                    batch_toDEV[key] = val
                if isinstance(val, dict):
                    batch_toDEV[key] = {}
                    for k, v in val.items():
                        if isinstance(v, torch.Tensor):
                            batch_toDEV[key][k] = v.to(self.device)

            # Prepare data
            GT = self.convert_GT(batch_toDEV)

            with torch.no_grad():    # disable gradients
                # Forward pass
                pred = self.model(GT['img'], GT['img_orig'])

                # Loss
                gen_loss, loss_dict = self.criterion(
                    generator_outputs=pred,
                    data_2d=GT['target_2d'],
                    data_3d=GT)

            # Save the first sample for comparison
            if step == 0:
                self.save_sample(GT, pred[0])

            test_loss += gen_loss
            test_loss_pose += loss_dict['loss_pose']
            test_loss_shape += loss_dict['loss_shape']
            test_loss_kp_2d += loss_dict['loss_kp_2d']
            test_loss_kp_3d += loss_dict['loss_kp_3d']
            test_loss_bvt_3d += loss_dict['loss_bvt_3d']
            test_loss_dsp_3d += loss_dict['loss_dsp_3d']
            test_tex_loss += loss_dict['loss_tex']

        num_batches = len(self.test_data_loader)
        test_loss = test_loss / num_batches
        test_loss_pose = test_loss_pose / num_batches
        test_loss_shape = test_loss_shape / num_batches
        test_loss_kp_2d = test_loss_kp_2d / num_batches
        test_loss_kp_3d = test_loss_kp_3d / num_batches
        test_loss_bvt_3d = test_loss_bvt_3d / num_batches
        test_loss_dsp_3d = test_loss_dsp_3d / num_batches
        test_tex_loss = test_tex_loss / num_batches

        lossSummary = {
            'test_loss': test_loss,
            'test_loss_pose': test_loss_pose,
            'test_loss_shape': test_loss_shape,
            'test_loss_kp_2d': test_loss_kp_2d,
            'test_loss_kp_3d': test_loss_kp_3d,
            'test_loss_bvt_3d': test_loss_bvt_3d,
            'test_loss_dsp_3d': test_loss_dsp_3d,
            'test_tex_loss': test_tex_loss}
        self.test_summaries(lossSummary)

        # Save a checkpoint whenever the test loss improves
        if test_loss < self.last_test_loss:
            self.last_test_loss = test_loss
            self.saver.save_checkpoint(
                self.models_dict,
                self.optimizers_dict,
                0,
                step + 1,
                self.options.batch_size,
                self.test_data_loader.sampler.dataset_perm,
                self.step_count)
            tqdm.write('Better test checkpoint saved')

    def save_sample(self, data, prediction, ind=0, saveTo='test'):
        """Save a sample for visualization and comparison."""
        assert saveTo in ('test', 'train'), 'save to train or test folder'

        folder = pjn(self.options.summary_dir, '%s/' % saveTo)
        _input = pjn(folder, 'input_image.png')
        # batchs = self.options.batch_size

        # Save the input image, if not saved yet
        if not isfile(_input):
            plt.imsave(_input,
                       data['img'][ind].cpu().permute(1, 2, 0).clamp(0, 1).numpy())

        # Overlay the prediction onto the real image; as the mesh .obj has
        # different ft/uv coordinates between the MPI lib and MGN, we flip
        # the predicted texture.
        save_renderer = simple_renderer(batch_size=1)
        predTrans = torch.stack(
            [prediction['theta'][ind, 1],
             prediction['theta'][ind, 2],
             2 * 1000. / (224. * prediction['theta'][ind, 0] + 1e-9)],
            dim=-1)
        tex = prediction['tex_image'][ind].flip(dims=(0,))[None]
        pred_img = save_renderer(
            verts=prediction['verts'][ind][None],
            faces=self.faces[ind][None],
            verts_uvs=self.smpl_verts_uvs[None],
            faces_uvs=self.smpl_tri_ind[None],
            tex_image=tex,
            R=torch.eye(3)[None].to('cuda'),
            T=predTrans,
            f=torch.ones([1, 1]).to('cuda') * 1000,
            C=torch.ones([1, 2]).to('cuda') * 112,
            imgres=224)
        overlayimg = 0.9 * pred_img[0, :, :, :3] + 0.1 * data['img'][ind].permute(1, 2, 0)

        imgname = data['imgname'][ind].split('/')[-1][:-4]
        save_path = pjn(folder, 'overlay_%s_iters%d.png' % (imgname, self.step_count))
        plt.imsave(save_path,
                   (overlayimg.clamp(0, 1) * 255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder, 'tex_%s_iters%d.png' % (imgname, self.step_count))
        plt.imsave(save_path,
                   (pred_img[0, :, :, :3].clamp(0, 1) * 255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder, 'unwarptex_%s_iters%d.png' % (imgname, self.step_count))
        plt.imsave(save_path,
                   (prediction['unwarp_tex'][ind].clamp(0, 1) * 255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder, 'predtex_%s_iters%d.png' % (imgname, self.step_count))
        plt.imsave(save_path,
                   (prediction['tex_image'][ind].clamp(0, 1) * 255).cpu().numpy().astype('uint8'))

        # Create the predicted posed, undressed body vertices
        offPred_t = (prediction['verts_disp'][ind] * self.dispPara[1]
                     + self.dispPara[0]).cpu()[None, :]
        predDressbody = create_smplD_psbody(
            self.smplD, offPred_t,
            prediction['theta'][ind][3:75][None].cpu(),
            prediction['theta'][ind][75:][None].cpu(),
            0,
            rtnMesh=True)[1]

        # Create the mesh and save it
        savepath = pjn(folder, '%s_iters%d.obj' % (imgname, self.step_count))
        predDressbody.write_obj(savepath)
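# The trainer above round-trips poses through a 6D rotation representation
# (axisAngle_to_Rot6d / rot6d_to_rotmat / rot6d_to_axisAngle). As a hedged
# illustration only - an assumption about those helpers' internals, not this
# repo's implementation - the 6D representation is commonly decoded into a
# rotation matrix via Gram-Schmidt orthogonalization (Zhou et al., CVPR 2019):
#
#     import torch
#     import torch.nn.functional as F
#
#     def rot6d_to_rotmat_sketch(x):
#         """Decode [N, 6] 6D rotations into [N, 3, 3] rotation matrices."""
#         x = x.view(-1, 3, 2)
#         a1, a2 = x[:, :, 0], x[:, :, 1]
#         b1 = F.normalize(a1, dim=1)                                  # first column
#         b2 = F.normalize(a2 - (b1 * a2).sum(dim=1, keepdim=True) * b1, dim=1)
#         b3 = torch.cross(b1, b2, dim=1)                              # right-handed frame
#         return torch.stack((b1, b2, b3), dim=-1)
#
# The 6D form avoids the discontinuities of axis-angle and quaternions, which
# is why it is a popular regression target for networks like the one above.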
def run_evaluation(model, dataset_name, dataset, mesh,
                   batch_size=32, img_res=224,
                   num_workers=32, shuffle=False, log_freq=50):
    """Run evaluation on the datasets and metrics reported in the paper."""

    renderer = PartRenderer()

    # Create the SMPL model
    smpl = SMPL().cuda()

    # Create the dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=shuffle, num_workers=num_workers)

    # Transfer the model to the GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()

    # 2D pose metrics
    pose_2d_m = np.zeros(len(dataset))
    pose_2d_m_smpl = np.zeros(len(dataset))

    # Shape metric: mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    eval_2d_pose = False
    eval_shape = False

    # Choose the appropriate evaluation for each dataset
    if dataset_name == 'up-3d':
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_2d_pose = True

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get the ground-truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]
        gt_keypoints_2d = batch['keypoints'].cpu().numpy()

        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(images)

        # If mask or part evaluation is requested, render the mask and part images
        if eval_2d_pose:
            mask, parts = renderer(pred_vertices, camera)

        # 2D pose evaluation (for LSP)
        if eval_2d_pose:
            # Get the 3D and projected 2D keypoints from the regressed shape
            # (hoisted out of the per-example loop; they do not depend on i)
            pred_keypoints_3d = smpl.get_joints(pred_vertices)
            pred_keypoints_2d = orthographic_projection(pred_keypoints_3d, camera)[:, :, :2]
            pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
            pred_keypoints_2d_smpl = orthographic_projection(
                pred_keypoints_3d_smpl, camera.detach())[:, :, :2]
            for i in range(curr_batch_size):
                gt_kp = gt_keypoints_2d[i, list(range(14))]
                pred_kp = pred_keypoints_2d.cpu().numpy()[i, list(range(14))]
                pred_kp_smpl = pred_keypoints_2d_smpl.cpu().numpy()[i, list(range(14))]

                # Compute the 2D pose losses
                loss_2d_pose = np.sum((gt_kp[:, :2] - pred_kp) ** 2)
                loss_2d_pose_smpl = np.sum((gt_kp[:, :2] - pred_kp_smpl) ** 2)
                # print(gt_kp)
                # print()
                # print(pred_kp_smpl)
                # raw_input("Press Enter to continue...")
                pose_2d_m[step * batch_size + i] = loss_2d_pose
                pose_2d_m_smpl[step * batch_size + i] = loss_2d_pose_smpl

        # Shape evaluation (mean per-vertex error)
        if eval_shape:
            se = torch.sqrt(
                ((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            se_smpl = torch.sqrt(
                ((pred_vertices_smpl - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size + curr_batch_size] = se
            shape_err_smpl[step * batch_size:step * batch_size + curr_batch_size] = se_smpl

        # Print intermediate results during the evaluation
        if step % log_freq == log_freq - 1:
            if eval_2d_pose:
                print('2D keypoints (NonParam): ' + str(1000 * pose_2d_m[:step * batch_size].mean()))
                print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error (NonParam): ' + str(1000 * shape_err[:step * batch_size].mean()))
                print('Shape Error (Param): ' + str(1000 * shape_err_smpl[:step * batch_size].mean()))
                print()

    # Print and store the final results of the evaluation
    print('*** Final Results ***')
    print()
    if eval_2d_pose:
        print('2D keypoints (NonParam): '
              + str(1000 * pose_2d_m.mean()))
        print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl.mean()))
        print()
        # Store the results
        # np.savez("../eval-output/CMR_no_extra_lsp_model_alt.npz",
        #          imgnames=dataset.imgname,
        #          kp_2d_err_graph=pose_2d_m,
        #          kp_2d_err_smpl=pose_2d_m_smpl)
    if eval_shape:
        print('Shape Error (NonParam): ' + str(1000 * shape_err.mean()))
        print('Shape Error (Param): ' + str(1000 * shape_err_smpl.mean()))
        print()
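# The shape evaluation above is the mean per-vertex error (MPVE): the
# Euclidean distance between corresponding predicted and ground-truth
# vertices, averaged over the mesh and reported in mm (hence the factor of
# 1000). A minimal, self-contained sketch of the same computation; the tensor
# shapes are illustrative assumptions:
#
#     import torch
#
#     def mean_per_vertex_error(pred_vertices, gt_vertices):
#         """pred_vertices, gt_vertices: [B, V, 3] -> per-example MPVE, [B]."""
#         return torch.sqrt(((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1)
#
#     # Example with two SMPL-sized meshes (6890 vertices each):
#     # mpve_mm = 1000 * mean_per_vertex_error(torch.rand(2, 6890, 3),
#     #                                        torch.rand(2, 6890, 3))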
class Trainer(BaseTrainer):
    """Trainer object.

    Inherits from BaseTrainer, which sets up logging, saving/restoring
    checkpoints, etc.
    """

    def init_fn(self):
        # Create the training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # Create the Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # Create the GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(self.device)

        # SMPL parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Set up a joint optimizer for the two models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
                   list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create the loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor}
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from the full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint.
        # Note that this is different from resuming training;
        # for the latter use --resume.
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d):
        """Compute the 2D reprojection loss on the keypoints.

        The confidence is binary and indicates whether the keypoints exist
        or not. The available keypoints differ for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
        """Compute the 3D keypoint loss for examples with 3D keypoint annotations.
        The loss is weighted by the per-keypoint confidence.
        """
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            # Align predictions and ground truth at the pelvis
            # (midpoint of the two hip joints)
            gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute the per-vertex loss on the shape for examples with
        SMPL annotations."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl):
        """Compute the SMPL parameter loss for examples with SMPL annotations."""
        pred_rotmat_valid = pred_rotmat[has_smpl == 1].view(-1, 3, 3)
        gt_rotmat_valid = rodrigues(gt_pose[has_smpl == 1].view(-1, 3))
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        """Training step."""
        self.graph_cnn.train()
        self.smpl_param_regressor.train()

        # Grab the data from the batch
        gt_keypoints_2d = input_batch['keypoints']
        gt_keypoints_3d = input_batch['pose_3d']
        gt_pose = input_batch['pose']
        gt_betas = input_batch['betas']
        has_smpl = input_batch['has_smpl']
        has_pose_3d = input_batch['has_pose_3d']
        images = input_batch['img']

        # Render the vertices using the SMPL parameters
        gt_vertices = self.smpl(gt_pose, gt_betas)
        batch_size = gt_vertices.shape[0]

        # Feed the image to the GraphCNN;
        # it returns the subsampled mesh and the camera parameters
        pred_vertices_sub, pred_camera = self.graph_cnn(images)

        # Upsample the mesh to the original size
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))

        # Prepare the input for the SMPL parameter regressor.
        # The input is the predicted and template vertices, subsampled by a
        # factor of 4. Note that we detach the GraphCNN.
        x = pred_vertices_sub.transpose(1, 2).detach()
        x = torch.cat(
            [x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
            dim=-1)

        # Estimate the SMPL parameters and render the vertices
        pred_rotmat, pred_shape = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_shape)

        # Get the 3D and projected 2D keypoints from the regressed shape
        pred_keypoints_3d = self.smpl.get_joints(pred_vertices)
        pred_keypoints_2d = orthographic_projection(pred_keypoints_3d,
                                                    pred_camera)[:, :, :2]
        pred_keypoints_3d_smpl = self.smpl.get_joints(pred_vertices_smpl)
        pred_keypoints_2d_smpl = orthographic_projection(
            pred_keypoints_3d_smpl, pred_camera.detach())[:, :, :2]

        # Compute the losses
        # GraphCNN losses
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d,
                                                  gt_keypoints_3d, has_pose_3d)
        loss_shape = self.shape_loss(pred_vertices, gt_vertices, has_smpl)

        # SMPL regressor losses
        loss_keypoints_smpl = self.keypoint_loss(pred_keypoints_2d_smpl,
                                                 gt_keypoints_2d)
        loss_keypoints_3d_smpl = self.keypoint_3d_loss(pred_keypoints_3d_smpl,
                                                       gt_keypoints_3d, has_pose_3d)
        loss_shape_smpl = self.shape_loss(pred_vertices_smpl, gt_vertices, has_smpl)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_shape, gt_pose, gt_betas, has_smpl)

        # Add up the losses to compute the total loss
        loss = loss_shape_smpl + loss_keypoints_smpl + loss_keypoints_3d_smpl + \
            loss_regr_pose + 0.1 * loss_regr_betas + \
            loss_shape + loss_keypoints + loss_keypoints_3d

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack the output arguments used for visualization in a list
        out_args = [
            pred_vertices, pred_vertices_smpl, pred_camera,
            pred_keypoints_2d, pred_keypoints_2d_smpl,
            loss_shape, loss_shape_smpl,
            loss_keypoints, loss_keypoints_smpl,
            loss_keypoints_3d, loss_keypoints_3d_smpl,
            loss_regr_pose, loss_regr_betas, loss]
        out_args = [arg.detach() for arg in out_args]
        return out_args

    def train_summaries(self, input_batch, pred_vertices, pred_vertices_smpl,
                        pred_camera, pred_keypoints_2d, pred_keypoints_2d_smpl,
                        loss_shape, loss_shape_smpl, loss_keypoints,
                        loss_keypoints_smpl, loss_keypoints_3d,
                        loss_keypoints_3d_smpl, loss_regr_pose,
                        loss_regr_betas, loss):
        """Tensorboard logging."""
        gt_keypoints_2d = input_batch['keypoints'].cpu().numpy()

        rend_imgs = []
        rend_imgs_smpl = []
        batch_size = pred_vertices.shape[0]
        # Do the visualization for the first 4 images of the batch
        for i in range(min(batch_size, 4)):
            img = input_batch['img_orig'][i].cpu().numpy().transpose(1, 2, 0)
            # Get the LSP keypoints from the full list of keypoints
            gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
            pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, self.to_lsp]
            pred_keypoints_2d_smpl_ = pred_keypoints_2d_smpl.cpu().numpy()[i, self.to_lsp]
            # Get the GraphCNN and SMPL vertices for this example
            vertices = pred_vertices[i].cpu().numpy()
            vertices_smpl = pred_vertices_smpl[i].cpu().numpy()
            cam = pred_camera[i].cpu().numpy()
            # Visualize the reconstruction and the detected pose
            rend_img = visualize_reconstruction(
                img, self.options.img_res, gt_keypoints_2d_, vertices,
                pred_keypoints_2d_, cam, self.renderer)
            rend_img_smpl = visualize_reconstruction(
                img, self.options.img_res, gt_keypoints_2d_, vertices_smpl,
                pred_keypoints_2d_smpl_, cam, self.renderer)
            rend_img = rend_img.transpose(2, 0, 1)
            rend_img_smpl = rend_img_smpl.transpose(2, 0, 1)
            rend_imgs.append(torch.from_numpy(rend_img))
            rend_imgs_smpl.append(torch.from_numpy(rend_img_smpl))
        rend_imgs = make_grid(rend_imgs, nrow=1)
        rend_imgs_smpl = make_grid(rend_imgs_smpl, nrow=1)

        # Save the results in Tensorboard
        self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
        self.summary_writer.add_image('imgs_smpl', rend_imgs_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_shape', loss_shape, self.step_count)
        self.summary_writer.add_scalar('loss_shape_smpl', loss_shape_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_regr_pose', loss_regr_pose, self.step_count)
        self.summary_writer.add_scalar('loss_regr_betas', loss_regr_betas, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints', loss_keypoints, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_smpl', loss_keypoints_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d', loss_keypoints_3d, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d_smpl', loss_keypoints_3d_smpl, self.step_count)
        self.summary_writer.add_scalar('loss', loss, self.step_count)
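# The orthographic_projection helper used throughout takes a weak-perspective
# camera, which - judging from how it is called above - is a [B, 3] tensor of
# (scale s, translation tx, ty), so a 3D point projects as s * (X_xy + t).
# A sketch consistent with that calling convention; the exact implementation
# in this repo may differ:
#
#     import torch
#
#     def orthographic_projection_sketch(X, camera):
#         """X: [B, N, 3] 3D points; camera: [B, 3] as (s, tx, ty) -> [B, N, 2]."""
#         camera = camera.view(-1, 1, 3)
#         # Broadcast: s is [B, 1, 1], the translated xy-coordinates are [B, N, 2]
#         return camera[:, :, 0:1] * (X[:, :, :2] + camera[:, :, 1:])
#
# Dropping the depth coordinate this way is what lets the GraphCNN predict a
# camera with only three parameters instead of a full perspective model.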
def run_evaluation(model, args, dataset, mesh):
    """Run evaluation on the datasets and metrics reported in the paper."""

    # Create the SMPL model
    smpl = SMPL().cuda()

    # Regressor for the H36M joints
    J_regressor = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M)).float()

    # Create the dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers)

    # Transfer the model to the GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()

    # Predictions
    all_kps = {}

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=args.num_imgs)):
        # Get the ground-truth annotations from the batch
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(images)

            pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
            pred_keypoints_2d_smpl = orthographic_projection(
                pred_keypoints_3d_smpl, camera.detach())[:, :, :2].cpu().data.numpy()

        # We use custom keypoints for the evaluation: MPII + COCO face joints;
        # see the paper / supplementary material for details.
        eval_part = np.zeros((1, 19, 2))
        eval_part[0, :14, :] = pred_keypoints_2d_smpl[0][:14]
        eval_part[0, 14:, :] = pred_keypoints_2d_smpl[0][19:]
        all_kps[step] = eval_part

        if args.write_imgs == 'True':
            renderer = Renderer(faces=smpl.faces.cpu().numpy())
            write_imgs(batch, pred_keypoints_2d_smpl, pred_vertices_smpl,
                       camera, args, renderer)

    if args.eval_pck == 'True':
        gt_kp_path = os.path.join(cfg.BASE_DATA_DIR, args.dataset,
                                  args.crop_setting, 'keypoints.pkl')
        log_dir = os.path.join(cfg.BASE_DATA_DIR, 'cmr_pck_results.txt')
        with open(gt_kp_path, 'rb') as f:
            gt = pkl.load(f)

        calc = CalcPCK(
            all_kps, gt,
            num_imgs=cfg.DATASET_SIZES[args.dataset][args.crop_setting],
            log_dir=log_dir,
            dataset=args.dataset,
            crop_setting=args.crop_setting,
            pck_eval_threshold=args.pck_eval_threshold)
        calc.eval()
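# CalcPCK above encapsulates the PCK evaluation. As a hedged reference for
# what that metric computes: the Percentage of Correct Keypoints counts a
# predicted keypoint as correct when its distance to the ground truth falls
# below a threshold (cf. pck_eval_threshold). This standalone sketch is
# illustrative only; CalcPCK's exact normalization and bookkeeping may differ.
#
#     import numpy as np
#
#     def pck_sketch(pred_kps, gt_kps, threshold):
#         """pred_kps, gt_kps: [N, K, 2] arrays in the same pixel units;
#         returns the fraction of keypoints with Euclidean error < threshold."""
#         dists = np.linalg.norm(pred_kps - gt_kps, axis=-1)   # [N, K]
#         return float((dists < threshold).mean())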