def __call__(self, init_pose, init_betas, init_cam_t, camera_center, keypoints_2d):
    """Perform body fitting.
    Input:
        init_pose: SMPL pose estimate
        init_betas: SMPL betas estimate
        init_cam_t: Camera translation estimate
        camera_center: Camera center location
        keypoints_2d: Keypoints used for the optimization
    Returns:
        vertices: Vertices of optimized shape
        joints: 3D joints of optimized shape
        pose: SMPL pose parameters of optimized shape
        betas: SMPL beta parameters of optimized shape
        camera_translation: Camera translation
        reprojection_loss: Final joint reprojection loss
    """
    # Make camera translation a learnable parameter
    camera_translation = init_cam_t.clone()

    # Get joint confidence
    joints_2d = keypoints_2d[:, :, :2]
    joints_conf = keypoints_2d[:, :, -1]

    # Split SMPL pose into body pose and global orientation
    body_pose = init_pose[:, 3:].detach().clone()
    global_orient = init_pose[:, :3].detach().clone()
    betas = init_betas.detach().clone()

    # Step 1: Optimize only camera translation and body orientation
    body_pose.requires_grad = False
    betas.requires_grad = False
    global_orient.requires_grad = True
    camera_translation.requires_grad = True

    camera_opt_params = [global_orient, camera_translation]

    if self.use_lbfgs:
        camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.max_iter,
                                             lr=self.step_size, line_search_fn='strong_wolfe')
        for i in range(self.num_iters):
            def closure():
                camera_optimizer.zero_grad()
                betas_ext = arrange_betas(body_pose, betas)
                smpl_output = self.smpl(global_orient=global_orient,
                                        body_pose=body_pose,
                                        betas=betas_ext)
                model_joints = smpl_output.joints
                loss = temporal_camera_fitting_loss(model_joints, camera_translation,
                                                    init_cam_t, camera_center,
                                                    joints_2d, joints_conf,
                                                    focal_length=self.focal_length)
                loss.backward()
                return loss
            camera_optimizer.step(closure)
    else:
        camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size,
                                            betas=(0.9, 0.999))
        for i in range(self.num_iters):
            betas_ext = arrange_betas(body_pose, betas)
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas_ext)
            model_joints = smpl_output.joints
            loss = temporal_camera_fitting_loss(model_joints, camera_translation,
                                                init_cam_t, camera_center,
                                                joints_2d, joints_conf,
                                                focal_length=self.focal_length)
            camera_optimizer.zero_grad()
            loss.backward()
            camera_optimizer.step()

    # Fix camera translation after optimizing camera
    camera_translation.requires_grad = False

    # Step 2: Optimize body pose, shape and global orientation
    body_pose.requires_grad = True
    betas.requires_grad = True
    global_orient.requires_grad = True
    camera_translation.requires_grad = False
    body_opt_params = [body_pose, betas, global_orient]

    # For joints ignored during fitting, set the confidence to 0
    joints_conf[:, self.ign_joints] = 0.
    if self.use_lbfgs:
        body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.max_iter,
                                           lr=self.step_size, line_search_fn='strong_wolfe')
        for i in range(self.num_iters):
            def closure():
                body_optimizer.zero_grad()
                betas_ext = arrange_betas(body_pose, betas)
                smpl_output = self.smpl(global_orient=global_orient,
                                        body_pose=body_pose,
                                        betas=betas_ext)
                model_joints = smpl_output.joints
                loss = temporal_body_fitting_loss(body_pose, betas, model_joints,
                                                  camera_translation, camera_center,
                                                  joints_2d, joints_conf, self.pose_prior,
                                                  focal_length=self.focal_length)
                loss.backward()
                return loss
            body_optimizer.step(closure)
    else:
        body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size,
                                          betas=(0.9, 0.999))
        for i in range(self.num_iters):
            betas_ext = arrange_betas(body_pose, betas)
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas_ext)
            model_joints = smpl_output.joints
            loss = temporal_body_fitting_loss(body_pose, betas, model_joints,
                                              camera_translation, camera_center,
                                              joints_2d, joints_conf, self.pose_prior,
                                              focal_length=self.focal_length)
            body_optimizer.zero_grad()
            loss.backward()
            body_optimizer.step()
            # scheduler.step(epoch=i)

    # Get final loss value
    with torch.no_grad():
        betas_ext = arrange_betas(body_pose, betas)
        smpl_output = self.smpl(global_orient=global_orient,
                                body_pose=body_pose,
                                betas=betas_ext,
                                return_full_pose=True)
        model_joints = smpl_output.joints
        reprojection_loss = temporal_body_fitting_loss(body_pose, betas, model_joints,
                                                       camera_translation, camera_center,
                                                       joints_2d, joints_conf, self.pose_prior,
                                                       focal_length=self.focal_length,
                                                       output='reprojection')

    vertices = smpl_output.vertices.detach()
    joints = smpl_output.joints.detach()
    pose = torch.cat([global_orient, body_pose], dim=-1).detach()
    betas = betas.detach()

    # Back to weak perspective camera
    camera_translation = torch.stack([
        2 * 5000. / (224 * camera_translation[:, 2] + 1e-9),
        camera_translation[:, 0],
        camera_translation[:, 1]
    ], dim=-1)

    betas = betas.repeat(pose.shape[0], 1)

    output = {
        'theta': torch.cat([camera_translation, pose, betas], dim=1),
        'verts': vertices,
        'kp_3d': joints,
    }

    return output, reprojection_loss
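

# --- Hedged illustration (not part of the original fitting code) ------------------
# The end of __call__ maps the optimized full-perspective translation [tx, ty, tz]
# to a weak-perspective camera [s, tx, ty] via s = 2 * focal / (img_res * tz),
# using the hard-coded focal length 5000 and crop resolution 224 that appear above.
# The two helpers below are a minimal sketch of that mapping and its (approximate)
# inverse; the function names and default arguments are illustrative, not part of
# the repo's API. torch is assumed to be imported at the top of this module, as it
# is used throughout the file.

def full_to_weak_perspective(cam_t, focal_length=5000., img_res=224.):
    """Map per-frame full-perspective translations [T, 3] = [tx, ty, tz]
    to weak-perspective cameras [T, 3] = [s, tx, ty]."""
    s = 2 * focal_length / (img_res * cam_t[:, 2] + 1e-9)
    return torch.stack([s, cam_t[:, 0], cam_t[:, 1]], dim=-1)


def weak_to_full_perspective(weak_cam, focal_length=5000., img_res=224.):
    """Approximate inverse (ignores the 1e-9 stabilizer): recover
    [tx, ty, tz] from [s, tx, ty]."""
    tz = 2 * focal_length / (img_res * weak_cam[:, 0])
    return torch.stack([weak_cam[:, 1], weak_cam[:, 2], tz], dim=-1)

# Example usage (assumed shape [T, 3] for cam_t):
#   weak = full_to_weak_perspective(cam_t)      # [s, tx, ty] per frame
#   cam_t_rec = weak_to_full_perspective(weak)  # ~ cam_t up to the epsilon term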