def camera_fitting_loss(model_keypoints, rotation, camera_t, focal_length, camera_center, keypoints_2d, keypoints_conf, distortion=None):
    """Confidence-weighted robust reprojection loss for camera fitting.

    Projects the model keypoints through the given camera, zeroes the
    confidence of keypoints 5-6 ("wing tips") so they do not contribute,
    and returns the scalar sum of the Geman-McClure-weighted residuals.
    NOTE(review): a second, unrelated `camera_fitting_loss` is defined
    later in this file and shadows this one at import time.
    """
    # Project model keypoints into the image plane.
    projected = perspective_projection(model_keypoints, rotation, camera_t,
                                       focal_length, camera_center, distortion)
    # Disable wing tips: clone so the caller's confidence tensor is untouched.
    conf = keypoints_conf.detach().clone()
    conf[:, 5:7] = 0
    # Robust (Geman-McClure) residual, weighted by squared confidence.
    sigma = 50
    residual = gmof(projected - keypoints_2d, sigma)
    weighted = (conf**2) * residual.sum(dim=-1)
    return weighted.sum(dim=-1).sum()
def kpts_fitting_loss(model_keypoints, focal_length, camera_center, keypoints_2d, keypoints_conf, body_pose, bone_length, prior_weight=1, pose_init=None, bone_init=None, sigma=100):
    """Confidence-weighted robust reprojection loss for keypoint fitting.

    Projects the model keypoints (no extra rotation/translation), computes a
    Geman-McClure reprojection error weighted by squared confidence, and —
    when both ``pose_init`` and ``bone_init`` are given — adds an L1 penalty
    on deviation from that initialization, scaled by ``prior_weight``.

    Returns a scalar loss tensor.
    """
    # Project model keypoints into the image plane.
    projected_keypoints = perspective_projection(model_keypoints, None, None,
                                                 focal_length, camera_center)
    # Weighted robust reprojection loss.
    reprojection_error = gmof(projected_keypoints - keypoints_2d, sigma)
    reprojection_loss = (keypoints_conf**2) * reprojection_error.sum(dim=-1)
    # Fix: compare to None with `is`, not `==` — `tensor == None` invokes
    # elementwise comparison semantics and is fragile across PyTorch versions.
    if pose_init is None or bone_init is None:
        total_loss = reprojection_loss.sum(dim=-1)
    else:
        # Constrain the objective by deviation from the initialization.
        init_loss = (body_pose - pose_init).abs().sum() + (
            bone_length - bone_init).abs().sum()
        init_loss = init_loss * prior_weight
        total_loss = reprojection_loss.sum(dim=-1) + init_loss.sum()
    return total_loss.sum()
def reproject_keypoints(mesh_keypoints, frames):
    """Project one set of mesh keypoints into every frame's camera.

    Replicates ``mesh_keypoints`` across the batch of ``frames``, fetches
    each frame's camera parameters via ``mutils.get_cam``, and returns the
    perspective-projected 2D keypoints.
    """
    rot, trans, focal, center, dist = mutils.get_cam(frames)
    # One copy of the keypoints per frame.
    batched_kpts = mesh_keypoints.repeat([len(frames), 1, 1])
    return perspective_projection(batched_kpts, rot, trans, focal, center, dist)
def body_fitting_loss(model_keypoints, rotation, camera_t, focal_length, camera_center, keypoints_2d, keypoints_conf, body_pose, bone_length, sigma=50, lim_weight=1, prior_weight=1, bone_weight=1, distortion=None, pose_init=None, bone_init=None):
    """Full body-fitting objective: reprojection + joint-limit + prior + bone-limit.

    Combines (1) a confidence-weighted Geman-McClure reprojection loss,
    (2) a hinge penalty for joint angles outside ``constants.min_lim``/
    ``constants.max_lim``, (3) an L1 prior (on raw pose, or on deviation
    from ``pose_init``/``bone_init`` when both are provided), and (4) a
    hinge penalty for bone lengths outside ``constants.min_bone``/
    ``constants.max_bone``. Returns a scalar loss tensor.

    NOTE(review): a second, unrelated `body_fitting_loss` is defined later
    in this file and shadows this one at import time.
    """
    device = body_pose.device
    # Project model keypoints into the image plane.
    projected_keypoints = perspective_projection(model_keypoints, rotation,
                                                 camera_t, focal_length,
                                                 camera_center, distortion)
    # Weighted robust reprojection loss.
    reprojection_error = gmof(projected_keypoints - keypoints_2d, sigma)
    reprojection_loss = (keypoints_conf**2) * reprojection_error.sum(dim=-1)
    # Joint angle limit loss: hinge penalty outside [min_lim, max_lim].
    max_lim = torch.tensor(constants.max_lim).repeat(1, 1).to(device)
    min_lim = torch.tensor(constants.min_lim).repeat(1, 1).to(device)
    lim_loss = (body_pose - max_lim).clamp(
        0, float("Inf")) + (min_lim - body_pose).clamp(0, float("Inf"))
    lim_loss = lim_weight * lim_loss
    # Prior loss. Fix: compare to None with `is`, not `==` — `tensor == None`
    # invokes elementwise comparison and is fragile across PyTorch versions.
    if pose_init is None or bone_init is None:
        prior_loss = body_pose.abs()
        prior_loss = prior_weight * prior_loss
    else:
        prior_loss = (body_pose - pose_init).abs().sum() + (
            bone_length - bone_init).abs().sum()
        prior_loss = prior_weight * prior_loss
    # Bone length limit loss: hinge penalty outside [min_bone, max_bone].
    max_bone = torch.tensor(constants.max_bone).repeat(1, 1).to(device)
    min_bone = torch.tensor(constants.min_bone).repeat(1, 1).to(device)
    bone_loss = (bone_length - max_bone).clamp(
        0, float("Inf")) + (min_bone - bone_length).clamp(0, float("Inf"))
    bone_loss = bone_weight * bone_loss
    total_loss = (reprojection_loss.sum(dim=-1) + lim_loss.sum() +
                  prior_loss.sum() + bone_loss.sum())
    return total_loss.sum()
def projection(cam, s3d, eps=1e-9):
    """Project 3D points with a weak-perspective camera [s, tx, ty].

    Converts the weak-perspective camera into a 3D translation using the
    canonical focal length and image resolution, projects ``s3d`` with an
    identity rotation, and returns both pixel-space ('ori') and
    [-1, 1]-normalized ('normed') 2D coordinates.
    """
    # Fix: `device` was an undefined free variable (relied on a global not
    # visible in this scope); derive it from the input tensor instead.
    device = s3d.device
    # Weak-perspective [s, tx, ty] -> full-perspective translation [tx, ty, tz].
    cam_t = torch.stack([
        cam[:, 1], cam[:, 2],
        2 * constants.FOCAL_LENGTH / (constants.IMG_RES * cam[:, 0] + eps)
    ], dim=-1)
    camera_center = torch.zeros(s3d.shape[0], 2, device=device)
    s2d = perspective_projection(s3d,
                                 rotation=torch.eye(
                                     3, device=device).unsqueeze(0).expand(
                                         s3d.shape[0], -1, -1),
                                 translation=cam_t,
                                 focal_length=constants.FOCAL_LENGTH,
                                 camera_center=camera_center)
    s2d_norm = s2d / (constants.IMG_RES / 2.)  # to [-1, 1]
    return {'ori': s2d, 'normed': s2d_norm}
def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, joints_2d, joints_conf, focal_length=5000, depth_loss_weight=100):
    """ Loss function for camera optimization.

    Squared reprojection error on the four torso joints (hips/shoulders),
    using OpenPose detections when all four are confident and ground-truth
    detections otherwise, plus a penalty on deviation of the camera depth
    from its estimate. Returns a scalar loss tensor.
    """
    batch_size = model_joints.shape[0]
    identity = torch.eye(3, device=model_joints.device)
    rotation = identity.unsqueeze(0).expand(batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    op_names = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
    gt_names = ['Right Hip', 'Left Hip', 'Right Shoulder', 'Left Shoulder']
    op_idx = [constants.JOINT_IDS[name] for name in op_names]
    gt_idx = [constants.JOINT_IDS[name] for name in gt_names]

    err_op = (joints_2d[:, op_idx] - projected_joints[:, op_idx])**2
    err_gt = (joints_2d[:, gt_idx] - projected_joints[:, gt_idx])**2

    # Per-example switch: use OpenPose joints only when all 4 detections are
    # valid (they are more reliable for this task when present).
    is_valid = (joints_conf[:, op_idx].min(dim=-1)[0][:, None, None] > 0).float()
    reprojection_loss = (is_valid * err_op +
                         (1 - is_valid) * err_gt).sum(dim=(1, 2))

    # Penalize deviation from the depth estimate.
    depth_loss = (depth_loss_weight**2) * (camera_t[:, 2] - camera_t_est[:, 2])**2

    return (reprojection_loss + depth_loss).sum()
def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, joints_2d, joints_conf, pose_prior, focal_length=5000, sigma=100, pose_prior_weight=4.78, shape_prior_weight=5, angle_prior_weight=15.2, output='sum'):
    """ Loss function for body fitting.

    Sum of a confidence-weighted Geman-McClure reprojection error, a pose
    prior term, an angle prior on knees/elbows, and an L2 shape regularizer
    on betas. Returns the scalar total when ``output == 'sum'`` or the raw
    per-joint reprojection loss when ``output == 'reprojection'``.
    """
    batch_size = body_pose.shape[0]
    identity = torch.eye(3, device=body_pose.device)
    rotation = identity.unsqueeze(0).expand(batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust reprojection error.
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf**2) * reprojection_error.sum(dim=-1)

    # Pose prior loss.
    pose_prior_loss = (pose_prior_weight**2) * pose_prior(body_pose, betas)
    # Angle prior for knees and elbows.
    angle_prior_loss = (angle_prior_weight**2) * angle_prior(body_pose).sum(dim=-1)
    # Keep betas from drifting to extreme values.
    shape_prior_loss = (shape_prior_weight**2) * (betas**2).sum(dim=-1)

    total_loss = (reprojection_loss.sum(dim=-1) + pose_prior_loss +
                  angle_prior_loss + shape_prior_loss)

    if output == 'sum':
        return total_loss.sum()
    elif output == 'reprojection':
        return reprojection_loss
def body_fitting_loss_smplify_x(body_pose, betas, pose_embedding, camera_t, camera_center, model_joints, joints_conf, joints_2d, focal_length=5000, sigma=100, body_pose_weight=4.78, shape_prior_weight=5, angle_prior_weight=15.2, output='sum'):
    """SMPLify-X style body fitting loss.

    Like ``body_fitting_loss`` but the pose prior is the squared norm of the
    VPoser latent ``pose_embedding`` instead of an explicit prior model.
    Returns the scalar total when ``output == 'sum'`` or the per-joint
    reprojection loss when ``output == 'reprojection'``.
    """
    batch_size = body_pose.shape[0]
    identity = torch.eye(3, device=body_pose.device)
    rotation = identity.unsqueeze(0).expand(batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust reprojection error.
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf**2) * reprojection_error.sum(dim=-1)

    # Latent pose prior: squared norm of the embedding.
    pose_prior_loss = (pose_embedding.pow(2).sum() * body_pose_weight**2)
    # Regularizer on betas.
    shape_prior_loss = (shape_prior_weight**2) * (betas**2).sum(dim=-1)
    # Angle prior for knees and elbows.
    angle_prior_loss = (angle_prior_weight**2) * angle_prior(body_pose).sum(dim=-1)

    total_loss = (reprojection_loss.sum(dim=-1) + pose_prior_loss +
                  angle_prior_loss + shape_prior_loss)

    if output == 'sum':
        return total_loss.sum()
    elif output == 'reprojection':
        return reprojection_loss
def train_step(self, input_batch):
    """Run one SPIN-style training step.

    Predicts SMPL parameters from images, optionally refines the current
    best fits with in-the-loop SMPLify, updates the fits dictionary, and
    backpropagates the combined keypoint/SMPL/shape loss.

    Returns:
        (output, losses): dict of tensors for visualization and a dict of
        scalar loss values for logging.
    """
    self.model.train()

    # Get data from the batch
    images = input_batch['img']  # input image
    gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
    gt_pose = input_batch['pose']  # SMPL pose parameters
    gt_betas = input_batch['betas']  # SMPL beta parameters
    gt_joints = input_batch['pose_3d']  # 3D pose
    has_smpl = input_batch['has_smpl'].byte(
    )  # flag that indicates whether SMPL parameters are valid
    has_pose_3d = input_batch['has_pose_3d'].byte(
    )  # flag that indicates whether 3D pose is valid
    is_flipped = input_batch[
        'is_flipped']  # flag that indicates whether image was flipped during data augmentation
    rot_angle = input_batch[
        'rot_angle']  # rotation angle used for data augmentation
    dataset_name = input_batch[
        'dataset_name']  # name of the dataset the image comes from
    indices = input_batch[
        'sample_index']  # index of example inside its dataset
    batch_size = images.shape[0]

    # Get GT vertices and model joints
    # Note that gt_model_joints is different from gt_joints as it comes from SMPL
    gt_out = self.smpl(betas=gt_betas,
                       body_pose=gt_pose[:, 3:],
                       global_orient=gt_pose[:, :3])
    gt_model_joints = gt_out.joints
    gt_vertices = gt_out.vertices

    # Get current best fits from the dictionary
    opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                          rot_angle.cpu(), is_flipped.cpu())]
    opt_pose = opt_pose.to(self.device)
    opt_betas = opt_betas.to(self.device)
    opt_output = self.smpl(betas=opt_betas,
                           body_pose=opt_pose[:, 3:],
                           global_orient=opt_pose[:, :3])
    opt_vertices = opt_output.vertices
    # NOTE(review): zeroes the fitted vertices when the batch is not full-size
    # (e.g. the last batch of an epoch) — presumably to avoid shape-mismatched
    # bookkeeping; confirm against the fits-dict implementation.
    if opt_vertices.shape != (self.options.batch_size, 6890, 3):
        opt_vertices = torch.zeros_like(opt_vertices, device=self.device)
    opt_joints = opt_output.joints

    # De-normalize 2D keypoints from [-1,1] to pixel space
    gt_keypoints_2d_orig = gt_keypoints_2d.clone()
    gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
        gt_keypoints_2d_orig[:, :, :-1] + 1)

    # Estimate camera translation given the model joints and 2D keypoints
    # by minimizing a weighted least squares loss
    gt_cam_t = estimate_translation(gt_model_joints,
                                    gt_keypoints_2d_orig,
                                    focal_length=self.focal_length,
                                    img_size=self.options.img_res)
    opt_cam_t = estimate_translation(opt_joints,
                                     gt_keypoints_2d_orig,
                                     focal_length=self.focal_length,
                                     img_size=self.options.img_res)
    opt_joint_loss = self.smplify.get_fitting_loss(
        opt_pose, opt_betas, opt_cam_t,
        0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
        gt_keypoints_2d_orig).mean(dim=-1)

    # Feed images in the network to predict camera and SMPL parameters
    pred_rotmat, pred_betas, pred_camera = self.model(images)
    pred_output = self.smpl(betas=pred_betas,
                            body_pose=pred_rotmat[:, 1:],
                            global_orient=pred_rotmat[:, 0].unsqueeze(1),
                            pose2rot=False)
    pred_vertices = pred_output.vertices
    if pred_vertices.shape != (self.options.batch_size, 6890, 3):
        pred_vertices = torch.zeros_like(pred_vertices, device=self.device)
    pred_joints = pred_output.joints

    # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
    # This camera translation can be used in a full perspective projection
    pred_cam_t = torch.stack([
        pred_camera[:, 1], pred_camera[:, 2],
        2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)
    ], dim=-1)
    camera_center = torch.zeros(batch_size, 2, device=self.device)
    pred_keypoints_2d = perspective_projection(
        pred_joints,
        rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
            batch_size, -1, -1),
        translation=pred_cam_t,
        focal_length=self.focal_length,
        camera_center=camera_center)
    # Normalize keypoints to [-1,1]
    pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

    if self.options.run_smplify:
        # Convert predicted rotation matrices to axis-angle
        pred_rotmat_hom = torch.cat([
            pred_rotmat.detach().view(-1, 3, 3).detach(),
            torch.tensor(
                [0, 0, 1], dtype=torch.float32, device=self.device).view(
                    1, 3, 1).expand(batch_size * 24, -1, -1)
        ], dim=-1)
        pred_pose = rotation_matrix_to_angle_axis(
            pred_rotmat_hom).contiguous().view(batch_size, -1)
        # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
        pred_pose[torch.isnan(pred_pose)] = 0.0
        # Run SMPLify optimization starting from the network prediction
        new_opt_vertices, new_opt_joints,\
        new_opt_pose, new_opt_betas,\
        new_opt_cam_t, new_opt_joint_loss = self.smplify(
            pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(),
            0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
            gt_keypoints_2d_orig)
        new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)
        # Will update the dictionary for the examples where the new loss is less than the current one
        update = (new_opt_joint_loss < opt_joint_loss)
        opt_joint_loss[update] = new_opt_joint_loss[update]
        opt_vertices[update, :] = new_opt_vertices[update, :]
        opt_joints[update, :] = new_opt_joints[update, :]
        opt_pose[update, :] = new_opt_pose[update, :]
        opt_betas[update, :] = new_opt_betas[update, :]
        opt_cam_t[update, :] = new_opt_cam_t[update, :]
        self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                        is_flipped.cpu(),
                        update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())
    else:
        update = torch.zeros(batch_size, device=self.device).byte()

    # Replace extreme betas with zero betas
    opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.

    # Replace the optimized parameters with the ground truth parameters, if available
    opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
    opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
    opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
    opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
    opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

    # Assert whether a fit is valid by comparing the joint loss with the threshold
    valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
        self.device)
    # Add the examples with GT parameters to the list of valid fits
    # print(valid_fit.dtype)
    # Cast to uint8 so the bitwise-or with the byte `has_smpl` flag is valid.
    valid_fit = valid_fit.to(torch.uint8)
    valid_fit = valid_fit | has_smpl

    opt_keypoints_2d = perspective_projection(
        opt_joints,
        rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
            batch_size, -1, -1),
        translation=opt_cam_t,
        focal_length=self.focal_length,
        camera_center=camera_center)
    opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

    # Compute loss on SMPL parameters
    loss_regr_pose, loss_regr_betas = self.smpl_losses(
        pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
    # Compute 2D reprojection loss for the keypoints
    loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                        self.options.openpose_train_weight,
                                        self.options.gt_train_weight)
    # Compute 3D keypoint loss
    loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                              has_pose_3d)
    # Per-vertex loss for the shape
    loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)
    # Compute total loss
    # The last component is a loss that forces the network to predict positive depth values
    loss = self.options.shape_loss_weight * loss_shape +\
        self.options.keypoint_loss_weight * loss_keypoints +\
        self.options.keypoint_loss_weight * loss_keypoints_3d +\
        loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
        ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
    loss *= 60

    # Do backprop
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Pack output arguments for tensorboard logging
    output = {
        'pred_vertices': pred_vertices.detach(),
        'opt_vertices': opt_vertices,
        'pred_cam_t': pred_cam_t.detach(),
        'opt_cam_t': opt_cam_t
    }
    losses = {
        'loss': loss.detach().item(),
        'loss_keypoints': loss_keypoints.detach().item(),
        'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
        'loss_regr_pose': loss_regr_pose.detach().item(),
        'loss_regr_betas': loss_regr_betas.detach().item(),
        'loss_shape': loss_shape.detach().item()
    }
    return output, losses
def train_step(self, input_batch):
    """Run one PyMAF-style training step.

    Uses static (pre-fitted) pseudo ground truth from the fits dictionary,
    optionally supervises dense IUV correspondences, computes per-loop SMPL
    / keypoint / shape losses, backpropagates the total, and (every 100
    steps) reduces and logs the losses.

    Returns:
        dict with 'preds' (network outputs) and 'losses' (per-term values).
    """
    self.model.train()

    # Get data from the batch
    images = input_batch['img']  # input image
    gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
    gt_pose = input_batch['pose']  # SMPL pose parameters
    gt_betas = input_batch['betas']  # SMPL beta parameters
    gt_joints = input_batch['pose_3d']  # 3D pose
    has_smpl = input_batch['has_smpl'].to(
        torch.bool
    )  # flag that indicates whether SMPL parameters are valid
    has_pose_3d = input_batch['has_pose_3d'].to(
        torch.bool)  # flag that indicates whether 3D pose is valid
    is_flipped = input_batch[
        'is_flipped']  # flag that indicates whether image was flipped during data augmentation
    rot_angle = input_batch[
        'rot_angle']  # rotation angle used for data augmentation
    dataset_name = input_batch[
        'dataset_name']  # name of the dataset the image comes from
    indices = input_batch[
        'sample_index']  # index of example inside its dataset
    batch_size = images.shape[0]

    # Get GT vertices and model joints
    # Note that gt_model_joints is different from gt_joints as it comes from SMPL
    gt_out = self.smpl(betas=gt_betas,
                       body_pose=gt_pose[:, 3:],
                       global_orient=gt_pose[:, :3])
    gt_model_joints = gt_out.joints
    gt_vertices = gt_out.vertices

    # Get current best fits from the dictionary
    opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                          rot_angle.cpu(), is_flipped.cpu())]
    opt_pose = opt_pose.to(self.device)
    opt_betas = opt_betas.to(self.device)

    # Replace extreme betas with zero betas
    opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.

    # Replace the optimized parameters with the ground truth parameters, if available
    opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
    opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

    opt_output = self.smpl(betas=opt_betas,
                           body_pose=opt_pose[:, 3:],
                           global_orient=opt_pose[:, :3])
    opt_vertices = opt_output.vertices
    opt_joints = opt_output.joints
    input_batch['verts'] = opt_vertices

    # De-normalize 2D keypoints from [-1,1] to pixel space
    gt_keypoints_2d_orig = gt_keypoints_2d.clone()
    gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
        gt_keypoints_2d_orig[:, :, :-1] + 1)

    # Estimate camera translation given the model joints and 2D keypoints
    # by minimizing a weighted least squares loss
    gt_cam_t = estimate_translation(gt_model_joints,
                                    gt_keypoints_2d_orig,
                                    focal_length=self.focal_length,
                                    img_size=self.options.img_res)
    opt_cam_t = estimate_translation(opt_joints,
                                     gt_keypoints_2d_orig,
                                     focal_length=self.focal_length,
                                     img_size=self.options.img_res)

    # get fitted smpl parameters as pseudo ground truth
    # NOTE(review): `get_vaild_state` is the (typo'd) name of the external
    # fits-dict API; it cannot be renamed here.
    valid_fit = self.fits_dict.get_vaild_state(
        dataset_name, indices.cpu()).to(torch.bool).to(self.device)
    # Bitwise-or of bool tensors; fall back to byte arithmetic on older
    # PyTorch versions where bool | bool raises.
    try:
        valid_fit = valid_fit | has_smpl
    except RuntimeError:
        valid_fit = (valid_fit.byte() | has_smpl.byte()).to(torch.bool)

    # Render Dense Correspondences
    if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
        gt_cam_t_nr = opt_cam_t.detach().clone()
        # Convert translation back to weak-perspective [s, tx, ty] for the renderer.
        gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
        gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
        gt_camera[:, 0] = (2. * self.focal_length /
                           self.options.img_res) / gt_cam_t_nr[:, 2]
        iuv_image_gt = torch.zeros(
            (batch_size, 3, cfg.MODEL.PyMAF.DP_HEATMAP_SIZE,
             cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)).to(self.device)
        if torch.sum(valid_fit.float()) > 0:
            iuv_image_gt[valid_fit] = self.iuv_maker.verts2iuvimg(
                opt_vertices[valid_fit],
                cam=gt_camera[valid_fit])  # [B, 3, 56, 56]
        input_batch['iuv_image_gt'] = iuv_image_gt
        uvia_list = iuv_img2map(iuv_image_gt)

    # Feed images in the network to predict camera and SMPL parameters
    if self.options.regressor == 'hmr':
        pred_rotmat, pred_betas, pred_camera = self.model(images)
        # torch.Size([32, 24, 3, 3]) torch.Size([32, 10]) torch.Size([32, 3])
    elif self.options.regressor == 'pymaf_net':
        preds_dict, _ = self.model(images)

    output = preds_dict
    loss_dict = {}

    # Dense correspondence (IUV) supervision for every decoder stage.
    if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
        dp_out = preds_dict['dp_out']
        for i in range(len(dp_out)):
            r_i = i - len(dp_out)  # negative index keeps stage keys stable
            u_pred, v_pred, index_pred, ann_pred = dp_out[r_i][
                'predict_u'], dp_out[r_i]['predict_v'], dp_out[r_i][
                    'predict_uv_index'], dp_out[r_i]['predict_ann_index']
            # Match the GT IUV map resolution to this stage's prediction.
            if index_pred.shape[-1] == iuv_image_gt.shape[-1]:
                uvia_list_i = uvia_list
            else:
                iuv_image_gt_i = F.interpolate(iuv_image_gt,
                                               u_pred.shape[-1],
                                               mode='nearest')
                uvia_list_i = iuv_img2map(iuv_image_gt_i)
            loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(
                u_pred, v_pred, index_pred, ann_pred, uvia_list_i, valid_fit)
            loss_dict[f'loss_U{r_i}'] = loss_U
            loss_dict[f'loss_V{r_i}'] = loss_V
            loss_dict[f'loss_IndexUV{r_i}'] = loss_IndexUV
            loss_dict[f'loss_segAnn{r_i}'] = loss_segAnn

    len_loop = len(preds_dict['smpl_out']
                   ) if self.options.regressor == 'pymaf_net' else 1

    for l_i in range(len_loop):
        if self.options.regressor == 'pymaf_net':
            if l_i == 0:
                # initial parameters (mean poses)
                continue
            pred_rotmat = preds_dict['smpl_out'][l_i]['rotmat']
            pred_betas = preds_dict['smpl_out'][l_i]['theta'][:, 3:13]
            pred_camera = preds_dict['smpl_out'][l_i]['theta'][:, :3]

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2],
            2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ], dim=-1)
        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
        loss_regr_pose *= cfg.LOSS.POSE_W
        loss_regr_betas *= cfg.LOSS.SHAPE_W
        loss_dict['loss_regr_pose_{}'.format(l_i)] = loss_regr_pose
        loss_dict['loss_regr_betas_{}'.format(l_i)] = loss_regr_betas

        # Compute 2D reprojection loss for the keypoints
        if cfg.LOSS.KP_2D_W > 0:
            loss_keypoints = self.keypoint_loss(
                pred_keypoints_2d, gt_keypoints_2d,
                self.options.openpose_train_weight,
                self.options.gt_train_weight) * cfg.LOSS.KP_2D_W
            loss_dict['loss_keypoints_{}'.format(l_i)] = loss_keypoints

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(
            pred_joints, gt_joints, has_pose_3d) * cfg.LOSS.KP_3D_W
        loss_dict['loss_keypoints_3d_{}'.format(l_i)] = loss_keypoints_3d

        # Per-vertex loss for the shape
        if cfg.LOSS.VERT_W > 0:
            loss_shape = self.shape_loss(pred_vertices, opt_vertices,
                                         valid_fit) * cfg.LOSS.VERT_W
            loss_dict['loss_shape_{}'.format(l_i)] = loss_shape

        # Camera
        # force the network to predict positive depth values
        loss_cam = ((torch.exp(-pred_camera[:, 0] * 10))**2).mean()
        loss_dict['loss_cam_{}'.format(l_i)] = loss_cam

    # Collapse any non-scalar loss tensors to their first element.
    for key in loss_dict:
        if len(loss_dict[key].shape) > 0:
            loss_dict[key] = loss_dict[key][0]

    # Compute total loss
    loss = torch.stack(list(loss_dict.values())).sum()

    # Do backprop
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Pack output arguments for tensorboard logging
    output.update({
        'pred_vertices': pred_vertices.detach(),
        'opt_vertices': opt_vertices,
        'pred_cam_t': pred_cam_t.detach(),
        'opt_cam_t': opt_cam_t
    })
    loss_dict['loss'] = loss.detach().item()

    if self.step_count % 100 == 0:
        if self.options.multiprocessing_distributed:
            # Average each loss across workers before logging.
            for loss_name, val in loss_dict.items():
                val = val / self.options.world_size
                if not torch.is_tensor(val):
                    val = torch.Tensor([val]).to(self.device)
                dist.all_reduce(val)
                loss_dict[loss_name] = val
        if self.options.rank == 0:
            for loss_name, val in loss_dict.items():
                self.summary_writer.add_scalar(
                    'losses/{}'.format(loss_name), val, self.step_count)

    return {'preds': output, 'losses': loss_dict}
def _forward(self, in_dict):
    """DANet forward pass over IUV maps.

    Regresses SMPL parameters from the (part-)IUV maps, and — when
    training — computes orthogonality, rotation, joint-position, camera,
    keypoint, and shape losses against the targets in ``in_dict``.

    Returns:
        ``smpl_out_dict`` directly when ``infer_mode`` is set, otherwise a
        dict with 'losses', 'metrics', 'visualization' and 'prediction'.
    """
    iuv_map = in_dict['iuv_map']
    part_iuv_map = in_dict[
        'part_iuv_map'] if 'part_iuv_map' in in_dict else None
    infer_mode = in_dict['infer_mode'] if 'infer_mode' in in_dict else False

    # NOTE(review): device_id is computed but not used below in this block.
    if cfg.DANET.INPUT_MODE in ['feat', 'iuv_feat', 'iuv_gt_feat']:
        device_id = iuv_map['feat'].get_device()
    elif cfg.DANET.INPUT_MODE == 'seg':
        device_id = iuv_map['index'].get_device()
    else:
        device_id = iuv_map.get_device()

    return_dict = {}
    return_dict['losses'] = {}
    return_dict['metrics'] = {}
    return_dict['visualization'] = {}
    return_dict['prediction'] = {}

    if cfg.DANET.DECOMPOSED:
        smpl_out_dict = self.smpl_para_Outs(iuv_map, part_iuv_map)
    else:
        smpl_out_dict = self.smpl_para_Outs(iuv_map)

    if infer_mode:
        return smpl_out_dict

    # para layout: [:3] camera, [3:13] shape betas, [13:] 24 rotation matrices.
    para = smpl_out_dict['para']

    for k, v in smpl_out_dict['visualization'].items():
        return_dict['visualization'][k] = v

    return_dict['prediction']['cam'] = para[:, :3]
    return_dict['prediction']['shape'] = para[:, 3:13]
    return_dict['prediction']['pose'] = para[:, 13:].reshape(-1, 24, 3, 3).contiguous()

    # losses for Training
    if self.training:
        batch_size = len(para)
        target = in_dict['target']
        target_kps = in_dict['target_kps']
        target_kps3d = in_dict['target_kps3d']
        target_vertices = in_dict['target_verts']
        has_kp3d = in_dict['has_kp3d']
        has_smpl = in_dict['has_smpl']

        if cfg.DANET.ORTHOGONAL_WEIGHTS > 0:
            loss_orth = self.orthogonal_loss(para)
            loss_orth *= cfg.DANET.ORTHOGONAL_WEIGHTS
            return_dict['losses']['Rs_orth'] = loss_orth
            return_dict['metrics']['orth'] = loss_orth.detach()

        if len(smpl_out_dict['joint_rotation']) > 0:
            for stack_i in range(len(smpl_out_dict['joint_rotation'])):
                if torch.sum(has_smpl) > 0:
                    loss_rot = self.criterion_regr(
                        smpl_out_dict['joint_rotation'][stack_i][
                            has_smpl == 1], target[:, 13:][has_smpl == 1])
                    loss_rot *= cfg.DANET.SMPL_POSE_WEIGHTS
                else:
                    # NOTE(review): `pred` is undefined in this scope — this
                    # branch would raise NameError if no sample has SMPL GT.
                    # Likely intended `para.device`; confirm and fix upstream.
                    loss_rot = torch.zeros(1).to(pred.device)
                return_dict['losses']['joint_rotation' + str(stack_i)] = loss_rot

        if cfg.DANET.DECOMPOSED and (
                'joint_position' in smpl_out_dict) and cfg.DANET.JOINT_POSITION_WEIGHTS > 0:
            gt_beta = target[:, 3:13].contiguous().detach()
            gt_Rs = target[:, 13:].contiguous().view(-1, 24, 3, 3).detach()
            smpl_pts = self.smpl(betas=gt_beta,
                                 body_pose=gt_Rs[:, 1:],
                                 global_orient=gt_Rs[:, 0].unsqueeze(1),
                                 pose2rot=False)
            gt_smpl_coord = smpl_pts.smpl_joints
            for stack_i in range(len(smpl_out_dict['joint_position'])):
                loss_pos = self.l1_losses(
                    smpl_out_dict['joint_position'][stack_i], gt_smpl_coord,
                    has_smpl)
                loss_pos *= cfg.DANET.JOINT_POSITION_WEIGHTS
                return_dict['losses']['joint_position' + str(stack_i)] = loss_pos

        pred_camera = para[:, :3]
        pred_betas = para[:, 3:13]
        pred_rotmat = para[:, 13:].reshape(-1, 24, 3, 3).contiguous()
        gt_camera = target[:, :3]
        gt_betas = target[:, 3:13]
        gt_rotmat = target[:, 13:]

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2],
            2 * self.focal_length / (cfg.DANET.INIMG_SIZE * pred_camera[:, 0] + 1e-9)
        ], dim=-1)
        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (cfg.DANET.INIMG_SIZE / 2.)

        # Compute loss on predicted camera
        # NOTE(review): loss_cam is computed but not added to the returned
        # losses — the 'cam' entry below uses the depth penalty instead.
        loss_cam = self.l1_losses(pred_camera, gt_camera, has_smpl)
        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, gt_rotmat, gt_betas, has_smpl)
        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(
            pred_keypoints_2d, target_kps,
            self.options.openpose_train_weight, self.options.gt_train_weight)
        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, target_kps3d,
                                                  has_kp3d)
        # Per-vertex loss for the shape
        loss_verts = self.shape_loss(pred_vertices, target_vertices, has_smpl)

        # The last component is a loss that forces the network to predict positive depth values
        return_dict['losses'].update({
            'keypoints_2d': loss_keypoints * cfg.DANET.PROJ_KPS_WEIGHTS,
            'keypoints_3d': loss_keypoints_3d * cfg.DANET.KPS3D_WEIGHTS,
            'smpl_pose': loss_regr_pose * cfg.DANET.SMPL_POSE_WEIGHTS,
            'smpl_betas': loss_regr_betas * cfg.DANET.SMPL_BETAS_WEIGHTS,
            'smpl_verts': loss_verts * cfg.DANET.VERTS_WEIGHTS,
            'cam': ((torch.exp(-pred_camera[:, 0] * 10))**2).mean()
        })

        return_dict['prediction']['vertices'] = pred_vertices
        return_dict['prediction']['cam_t'] = pred_cam_t

    # handle bug on gathering scalar(0-dim) tensors
    for k, v in return_dict['losses'].items():
        if len(v.shape) == 0:
            return_dict['losses'][k] = v.unsqueeze(0)
    for k, v in return_dict['metrics'].items():
        if len(v.shape) == 0:
            return_dict['metrics'][k] = v.unsqueeze(0)

    return return_dict
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224,
                   num_workers=32, shuffle=False, log_freq=50, options=None):
    """Run evaluation on the datasets and metrics we report in the paper.

    Runs the regressor (`hmr` or `pymaf_net`, chosen via ``options.regressor``)
    over ``dataset``, projects the predicted SMPL joints to 2D with a full
    perspective camera, maps predictions back to original image coordinates,
    and finally calls ``dataset.evaluate`` to compute/print the metrics.

    Args:
        model: network producing SMPL rotation matrices, betas and a weak
            perspective camera (directly, or via a prediction dict for pymaf).
        dataset_name: name used only for the demo-visualization output path.
        dataset: evaluation dataset; must provide an ``evaluate(...)`` method.
        result_file: if not None, predictions are saved to this .npz path
            (and shuffling is disabled so indices stay aligned).
        batch_size / img_res / num_workers / shuffle / log_freq: loader and
            image-resolution settings; log_freq is currently unused here.
        options: namespace with at least ``vis_imname``, ``regressor``,
            ``vis_demo``, ``checkpoint`` and ``output_dir``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Transfer model to the GPU
    model.to(device)
    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)
    renderer = PartRenderer()
    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(path_config.JOINT_REGRESSOR_H36M)).float()
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)
    fits_dict = None
    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    # NOTE(review): these metric buffers are allocated but never written in
    # this function — they appear to be leftovers from a 3D-metric code path.
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))
    # Store SMPL parameters
    # NOTE(review): also never filled below; saved as all-zeros when
    # save_results is True.
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))
    # joint_mapper_coco = constants.H36M_TO_JCOCO
    joint_mapper_gt = constants.J24_TO_JCOCO
    focal_length = 5000
    num_joints = 17
    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    # Accumulators in the COCO-keypoint evaluation format.
    all_preds = np.zeros(
        (num_samples, num_joints, 3), dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
            # Optional filtering: only process batches containing the
            # requested image name (visualization mode).
            if len(options.vis_imname) > 0:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]
                name_hit = False
                for i_n in imgnames:
                    if options.vis_imname in i_n:
                        name_hit = True
                        print('vis: ' + i_n)
                if not name_hit:
                    continue
            images = batch['img'].to(device)
            scale = batch['scale'].numpy()
            center = batch['center'].numpy()
            num_images = images.size(0)
            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)
            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
                # torch.Size([32, 24, 3, 3]) torch.Size([32, 10]) torch.Size([32, 3])
            elif options.regressor == 'pymaf_net':
                # PyMAF returns a list of per-iteration SMPL estimates; use
                # the last (most refined) one.
                preds_dict, _ = model(images)
                pred_rotmat = preds_dict['smpl_out'][-1]['rotmat'].contiguous().view(-1, 24, 3, 3)
                pred_betas = preds_dict['smpl_out'][-1]['theta'][:, 3:13].contiguous()
                pred_camera = preds_dict['smpl_out'][-1]['theta'][:, :3].contiguous()
            pred_output = smpl_neutral(betas=pred_betas,
                                       body_pose=pred_rotmat[:, 1:],
                                       global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                       pose2rot=False)
            # pred_vertices = pred_output.vertices
            # Last 24 joints are the SMPL-native joints; remap to COCO order.
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]
            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([pred_camera[:, 1],
                                      pred_camera[:, 2],
                                      2 * constants.FOCAL_LENGTH / (img_res * pred_camera[:, 0] + 1e-9)],
                                     dim=-1)
            camera_center = torch.zeros(len(pred_JCOCO), 2, device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(pred_JCOCO,
                                                       rotation=torch.eye(3, device=pred_camera.device).unsqueeze(0).expand(len(pred_JCOCO), -1, -1),
                                                       translation=pred_cam_t,
                                                       focal_length=constants.FOCAL_LENGTH,
                                                       camera_center=camera_center)
            # Shift from crop-centered coordinates to pixel coordinates.
            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()
            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.J24_TO_JCOCO]
            # Per-image confidence-weighted mean 2D joint error (pixels),
            # only used for the demo visualization below.
            vert_errors_batch = []
            for i, (gt2d, pred2d) in enumerate(zip(gt_keypoints_coco.cpu().numpy(), coords.copy())):
                vert_error = np.sqrt(np.sum((gt2d[:, :2] - pred2d[:, :2]) ** 2, axis=1))
                vert_error *= gt2d[:, 2]
                vert_mean_error = np.sum(vert_error) / np.sum(gt2d[:, 2] > 0)
                vert_errors_batch.append(10 * vert_mean_error)
            if options.vis_demo:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]
                # NOTE(review): iuv_pred is only bound for the 'hmr'
                # regressor; with 'pymaf_net' + vis_demo this would raise
                # NameError — confirm intended usage.
                if options.regressor == 'hmr':
                    iuv_pred = None
                # Undo ImageNet normalization for display.
                images_vis = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
                images_vis = images_vis + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
                vis_smpl_iuv(images_vis.cpu().numpy(),
                             pred_camera.cpu().numpy(),
                             pred_output.vertices.cpu().numpy(),
                             smpl_neutral.faces,
                             iuv_pred,
                             vert_errors_batch,
                             imgnames,
                             os.path.join('./notebooks/output/demo_results',
                                          dataset_name,
                                          options.checkpoint.split('/')[-3]),
                             options)
            preds = coords.copy()
            scale_ = np.array([scale, scale]).transpose()
            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(
                    coords[i], center[i], scale_[i], [img_res, img_res]
                )
            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            # NOTE(review): only the score column is set here; the sibling
            # run_evaluation below also fills center/scale/area columns.
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])
            idx += num_images
    # Visualization-only runs stop here without computing metrics.
    if len(options.vis_imname) > 0:
        exit()
    # NOTE(review): uses the module-level `args`, not the `options`
    # parameter — relies on a global being defined; confirm intended.
    if args.checkpoint is None or 'model_checkpoint.pt' in args.checkpoint:
        ckp_name = 'spin_model'
    else:
        ckp_name = args.checkpoint.split('/')
        ckp_name = ckp_name[2].split('_')[1] + '_' + ckp_name[-1].split('.')[0]
    name_values, perf_indicator = dataset.evaluate(
        cfg, all_preds, options.output_dir, all_boxes, image_path, ckp_name,
        filenames, imgnums
    )
    model_name = options.regressor
    if isinstance(name_values, list):
        for name_value in name_values:
            _print_name_value(name_value, model_name)
    else:
        _print_name_value(name_values, model_name)
    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
    def train_step(self, input_batch):
        """Run one DANet training step.

        Builds pseudo ground-truth targets (SPIN fits, projected SMPL
        keypoints, neural-renderer camera) into ``input_batch``, runs the
        model forward, sums all returned losses, and backprops.

        Args:
            input_batch: dict with images, 2D/3D keypoints, SMPL params,
                validity flags and augmentation metadata. Mutated in place
                to carry the optimization targets to the model.

        Returns:
            (output, losses_dict): visualization tensors and a dict of
            detached scalar losses for logging.
        """
        # Learning rate decay
        if self.decay_steps_ind < len(cfg.SOLVER.STEPS) and input_batch[
                'step_count'] == cfg.SOLVER.STEPS[self.decay_steps_ind]:
            lr = self.optimizer.param_groups[0]['lr']
            lr_new = lr * cfg.SOLVER.GAMMA
            print('Decay the learning on step {} from {} to {}'.format(
                input_batch['step_count'], lr, lr_new))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr_new
            lr = self.optimizer.param_groups[0]['lr']
            assert lr == lr_new
            self.decay_steps_ind += 1
        self.model.train()
        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte(
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        )  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]
        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices
        # Get current pseudo labels (final fits of SPIN) from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters, if available
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints
        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)
        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)
        # Validity of the pseudo fit: from the fits dictionary for the mixed
        # training set, otherwise just the GT-SMPL flag.
        if self.options.train_data in ['h36m_coco_itw']:
            valid_fit = self.fits_dict.get_vaild_state(dataset_name,
                                                       indices.cpu()).to(
                                                           self.device)
            valid_fit = valid_fit | has_smpl
        else:
            valid_fit = has_smpl
        # Feed images in the network to predict camera and SMPL parameters
        input_batch['opt_pose'] = opt_pose
        input_batch['opt_betas'] = opt_betas
        input_batch['valid_fit'] = valid_fit
        input_batch['dp_dict'] = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in input_batch['dp_dict'].items()
        }
        # IUV supervision is available unless the sample comes from dp_coco
        # and only where the fit is valid.
        has_iuv = torch.tensor([dn not in ['dp_coco'] for dn in dataset_name],
                               dtype=torch.uint8).to(self.device)
        has_iuv = has_iuv & valid_fit
        input_batch['has_iuv'] = has_iuv
        has_dp = input_batch['has_dp']
        # Project the optimized SMPL joints to build 2D keypoint targets,
        # normalized back to [-1, 1].
        target_smpl_kps = torch.zeros(
            (batch_size, 24, 3)).to(opt_output.smpl_joints.device)
        target_smpl_kps[:, :, :2] = perspective_projection(
            opt_output.smpl_joints.detach().clone(),
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=torch.zeros(batch_size, 2, device=self.device) +
            (0.5 * self.options.img_res))
        target_smpl_kps[:, :, :2] = target_smpl_kps[:, :, :2] / (
            0.5 * self.options.img_res) - 1
        target_smpl_kps[has_iuv == 1, :, 2] = 1
        # DensePose samples carry their own 2D SMPL keypoints.
        target_smpl_kps[has_dp == 1] = input_batch['smpl_2dkps'][has_dp == 1]
        input_batch['target_smpl_kps'] = target_smpl_kps  # [B, 24, 3]
        input_batch['target_verts'] = opt_vertices.detach().clone(
        )  # [B, 6890, 3]
        # camera translation for neural renderer
        gt_cam_t_nr = opt_cam_t.detach().clone()
        gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
        gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
        gt_camera[:, 0] = (2. * self.focal_length /
                           self.options.img_res) / gt_cam_t_nr[:, 2]
        input_batch['target_cam'] = gt_camera
        # Do forward
        danet_return_dict = self.model(input_batch)
        # NOTE(review): 'tatal' is a typo for 'total'; the dict key below is
        # a logging string, so renaming it would change the logged metric
        # name — left as-is for log compatibility.
        loss_tatal = 0
        losses_dict = {}
        for loss_key in danet_return_dict['losses']:
            loss_tatal += danet_return_dict['losses'][loss_key]
            losses_dict['loss_{}'.format(loss_key)] = danet_return_dict[
                'losses'][loss_key].detach().item()
        # Do backprop
        self.optimizer.zero_grad()
        loss_tatal.backward()
        self.optimizer.step()
        if input_batch['pretrain_mode']:
            pred_vertices = None
            pred_cam_t = None
        else:
            pred_vertices = danet_return_dict['prediction']['vertices'].detach(
            )
            pred_cam_t = danet_return_dict['prediction']['cam_t'].detach()
        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t,
            'visualization': danet_return_dict['visualization']
        }
        losses_dict.update({'loss_tatal': loss_tatal.detach().item()})
        return output, losses_dict
def train_step(self, input_batch): self.model.train() # get data from batch has_smpl = input_batch['has_smpl'].bool() has_pose_3d = input_batch['has_pose_3d'].bool() gt_pose1 = input_batch['pose'] # SMPL pose parameters gt_betas1 = input_batch['betas'] # SMPL beta parameters dataset_name = input_batch['dataset_name'] indices = input_batch[ 'sample_index'] # index of example inside its dataset is_flipped = input_batch[ 'is_flipped'] # flag that indicates whether image was flipped during data augmentation rot_angle = input_batch[ 'rot_angle'] # rotation angle used for data augmentation #print(rot_angle) # Get GT vertices and model joints # Note that gt_model_joints is different from gt_joints as it comes from SMPL gt_betas = torch.cat((gt_betas1, gt_betas1, gt_betas1, gt_betas1), 0) gt_pose = torch.cat((gt_pose1, gt_pose1, gt_pose1, gt_pose1), 0) gt_out = self.smpl(betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3]) gt_model_joints = gt_out.joints gt_vertices = gt_out.vertices # Get current best fits from the dictionary opt_pose1, opt_betas1 = self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(), is_flipped.cpu())] opt_pose = torch.cat( (opt_pose1.to(self.device), opt_pose1.to(self.device), opt_pose1.to(self.device), opt_pose1.to(self.device)), 0) #print(opt_pose.device) #opt_betas = opt_betas.to(self.device) opt_betas = torch.cat( (opt_betas1.to(self.device), opt_betas1.to(self.device), opt_betas1.to(self.device), opt_betas1.to(self.device)), 0) opt_output = self.smpl(betas=opt_betas, body_pose=opt_pose[:, 3:], global_orient=opt_pose[:, :3]) opt_vertices = opt_output.vertices opt_joints = opt_output.joints # images images = torch.cat((input_batch['img_0'], input_batch['img_1'], input_batch['img_2'], input_batch['img_3']), 0) batch_size = input_batch['img_0'].shape[0] #input() # Output of CNN pred_rotmat, pred_betas, pred_camera = self.model(images) pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:, 1:], 
global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False) pred_vertices = pred_output.vertices pred_joints = pred_output.joints pred_cam_t = torch.stack([ pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9) ], dim=-1) camera_center = torch.zeros(batch_size * 4, 2, device=self.device) pred_keypoints_2d = perspective_projection( pred_joints, rotation=torch.eye(3, device=self.device).unsqueeze(0).expand( batch_size * 4, -1, -1), translation=pred_cam_t, focal_length=self.focal_length, camera_center=camera_center) pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.) # 2d joint points gt_keypoints_2d = torch.cat( (input_batch['keypoints_0'], input_batch['keypoints_1'], input_batch['keypoints_2'], input_batch['keypoints_3']), 0) gt_keypoints_2d_orig = gt_keypoints_2d.clone() gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * ( gt_keypoints_2d_orig[:, :, :-1] + 1) gt_cam_t = estimate_translation(gt_model_joints, gt_keypoints_2d_orig, focal_length=self.focal_length, img_size=self.options.img_res) opt_cam_t = estimate_translation(opt_joints, gt_keypoints_2d_orig, focal_length=self.focal_length, img_size=self.options.img_res) #input() opt_joint_loss = self.smplify.get_fitting_loss( opt_pose, opt_betas, opt_cam_t, 0.5 * self.options.img_res * torch.ones(batch_size * 4, 2, device=self.device), gt_keypoints_2d_orig).mean(dim=-1) if self.options.run_smplify: pred_rotmat_hom = torch.cat([ pred_rotmat.detach().view(-1, 3, 3).detach(), torch.tensor( [0, 0, 1], dtype=torch.float32, device=self.device).view( 1, 3, 1).expand(batch_size * 4 * 24, -1, -1) ], dim=-1) pred_pose = rotation_matrix_to_angle_axis( pred_rotmat_hom).contiguous().view(batch_size * 4, -1) pred_pose[torch.isnan(pred_pose)] = 0.0 #pred_pose_detach = pred_pose.detach() #pred_betas_detach = pred_betas.detach() #pred_cam_t_detach = pred_cam_t.detach() new_opt_vertices, new_opt_joints,\ new_opt_pose, new_opt_betas,\ new_opt_cam_t, 
new_opt_joint_loss = self.smplify( pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(), 0.5 * self.options.img_res * torch.ones(batch_size*4, 2, device=self.device), gt_keypoints_2d_orig) new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1) # Will update the dictionary for the examples where the new loss is less than the current one update = (new_opt_joint_loss < opt_joint_loss) update1 = torch.cat((update, update, update, update), 0) opt_joint_loss[update] = new_opt_joint_loss[update] #print(opt_joints.size(),new_opt_joints.size()) #input() opt_joints[update1, :] = new_opt_joints[update1, :] #print(opt_pose.size(),new_opt_pose.size()) opt_betas[update1, :] = new_opt_betas[update1, :] opt_pose[update1, :] = new_opt_pose[update1, :] #print(i, opt_pose_mv[i]) opt_vertices[update1, :] = new_opt_vertices[update1, :] opt_cam_t[update1, :] = new_opt_cam_t[update1, :] # now we comput the loss on the four images # Replace the optimized parameters with the ground truth parameters, if available #for i in range(4): #print('Here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1') has_smpl1 = torch.cat((has_smpl, has_smpl, has_smpl, has_smpl), 0) opt_vertices[has_smpl1, :, :] = gt_vertices[has_smpl1, :, :] opt_pose[has_smpl1, :] = gt_pose[has_smpl1, :] opt_cam_t[has_smpl1, :] = gt_cam_t[has_smpl1, :] opt_joints[has_smpl1, :, :] = gt_model_joints[has_smpl1, :, :] opt_betas[has_smpl1, :] = gt_betas[has_smpl1, :] #print(opt_cam_t[0:batch_size],opt_cam_t[batch_size:2*batch_size],opt_cam_t[2*batch_size:3*batch_size],opt_cam_t[3*batch_size:4*batch_size]) # Assert whether a fit is valid by comparing the joint loss with the threshold valid_fit1 = (opt_joint_loss < self.options.smplify_threshold).to( self.device) # Add the examples with GT parameters to the list of valid fits valid_fit = torch.cat( (valid_fit1, valid_fit1, valid_fit1, valid_fit1), 0) | has_smpl1 #gt_keypoints_2d = 
torch.cat((input_batch['keypoints_0'],input_batch['keypoints_1'],input_batch['keypoints_2'],input_batch['keypoints_3']),0) loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d, 0, 1) #gt_joints = torch.cat((input_batch['pose_3d_0'],input_batch['pose_3d_1'],input_batch['pose_3d_2'],input_batch['pose_3d_3']),0) #loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints, torch.cat((has_pose_3d,has_pose_3d,has_pose_3d,has_pose_3d),0)) loss_regr_pose, loss_regr_betas = self.smpl_losses( pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit) loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit) #print(loss_shape_sum,loss_keypoints_sum,loss_keypoints_3d_sum,loss_regr_pose_sum,loss_regr_betas_sum) #input() loss_all = 0 * loss_shape +\ 5. * loss_keypoints +\ 0. * loss_keypoints_3d +\ loss_regr_pose + 0.001* loss_regr_betas +\ ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean() loss_all *= 60 #print(loss_all) # Do backprop self.optimizer.zero_grad() loss_all.backward() self.optimizer.step() output = { 'pred_vertices': pred_vertices, 'opt_vertices': opt_vertices, 'pred_cam_t': pred_cam_t, 'opt_cam_t': opt_cam_t } losses = { 'loss': loss_all.detach().item(), 'loss_keypoints': loss_keypoints.detach().item(), 'loss_keypoints_3d': loss_keypoints_3d.detach().item(), 'loss_regr_pose': loss_regr_pose.detach().item(), 'loss_regr_betas': loss_regr_betas.detach().item(), 'loss_shape': loss_shape.detach().item() } return output, losses
def run_evaluation(model, dataset, result_file,
                   batch_size=32, img_res=224,
                   num_workers=32, shuffle=False, options=None):
    """Run evaluation on the datasets and metrics we report in the paper.

    Variant for the `hmr` and `danet` regressors: runs the model over
    ``dataset``, projects predicted SMPL joints with a full-perspective
    camera, maps them back to original image coordinates, and calls
    ``dataset.evaluate`` (COCO-keypoint style) to compute/print metrics.

    Args:
        model: regressor; for 'danet' must expose ``infer_net(images)``
            returning a dict with a flat 'para' tensor [cam(3)|betas(10)|rotmats].
        dataset: evaluation dataset providing an ``evaluate(...)`` method.
        result_file: if not None, SMPL parameter arrays are saved to this
            .npz path (and shuffling is disabled).
        batch_size / img_res / num_workers / shuffle: loader settings.
        options: namespace with at least ``regressor`` and ``output_dir``.
    """
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Transfer model to the GPU
    model.to(device)
    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)
    # Store SMPL parameters
    # NOTE(review): these buffers are never filled below; saved as zeros
    # when save_results is True.
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))
    num_joints = 17
    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    # Accumulators in the COCO-keypoint evaluation format.
    all_preds = np.zeros((num_samples, num_joints, 3), dtype=np.float32)
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for step, batch in enumerate(
                tqdm(data_loader, desc='Eval', total=len(data_loader))):
            images = batch['img'].to(device)
            scale = batch['scale'].numpy()
            center = batch['center'].numpy()
            num_images = images.size(0)
            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (
                gt_keypoints_2d_orig[:, :, :-1] + 1)
            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
            elif options.regressor == 'danet':
                # DANet packs camera, betas and rotation matrices into a
                # single flat parameter vector.
                danet_pred_dict = model.infer_net(images)
                para_pred = danet_pred_dict['para']
                pred_camera = para_pred[:, 0:3].contiguous()
                pred_betas = para_pred[:, 3:13].contiguous()
                pred_rotmat = para_pred[:, 13:].contiguous().view(-1, 24, 3, 3)
            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            # pred_vertices = pred_output.vertices
            # Last 24 joints are the SMPL-native joints; remap to COCO order.
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]
            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 *
                constants.FOCAL_LENGTH / (img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)
            camera_center = torch.zeros(len(pred_JCOCO), 2,
                                        device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(
                pred_JCOCO,
                rotation=torch.eye(
                    3, device=pred_camera.device).unsqueeze(0).expand(
                        len(pred_JCOCO), -1, -1),
                translation=pred_cam_t,
                focal_length=constants.FOCAL_LENGTH,
                camera_center=camera_center)
            # Shift from crop-centered coordinates to pixel coordinates.
            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()
            # Normalize keypoints to [-1,1]
            # pred_keypoints_2d = pred_keypoints_2d / (img_res / 2.)
            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.
                                                              J24_TO_JCOCO]
            preds = coords.copy()
            scale_ = np.array([scale, scale]).transpose()
            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(coords[i], center[i], scale_[i],
                                           [img_res, img_res])
            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            # double check this all_boxes parts
            all_boxes[idx:idx + num_images, 0:2] = center[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = scale_[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(scale_ * 200, 1)
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])
            idx += num_images
    ckp_name = options.regressor
    name_values, perf_indicator = dataset.evaluate(all_preds,
                                                   options.output_dir,
                                                   all_boxes, image_path,
                                                   ckp_name, filenames,
                                                   imgnums)
    model_name = options.regressor
    if isinstance(name_values, list):
        for name_value in name_values:
            _print_name_value(name_value, model_name)
    else:
        _print_name_value(name_values, model_name)
    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
    def train_step(self, input_batch):
        """Run one multi-scale training step with cross-scale consistency.

        Processes the high-resolution image plus progressively
        lower-resolution versions, accumulating SMPL/keypoint/shape losses
        weighted by scale index, then adds a consistency loss between the
        predictions at different scales (with negatives sampled from a
        feature queue).

        Args:
            input_batch: dict with 'img_hr', a list 'img_lr', keypoints,
                SMPL parameters and validity flags.

        Returns:
            (output, losses): last-scale prediction tensors and a dict of
            detached scalar losses for logging.
        """
        self.model.train()
        images_hr = input_batch['img_hr']
        images_lr_list = input_batch['img_lr']
        images_list = [images_hr] + images_lr_list
        # Resolution tags, highest first; truncated to the scales present.
        scale_names = ['224', '224_128', '128_64', '64_40', '40_24']
        scale_names = scale_names[:len(images_list)]
        feat_names = ['layer4']
        # Get data from the batch
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte(
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        )  # flag that indicates whether 3D pose is valid
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch['sample_index'].numpy(
        )  # index of example inside mixed dataset
        batch_size = images_hr.shape[0]
        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices
        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)
        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        loss_shape = 0
        loss_keypoints = 0
        loss_keypoints_3d = 0
        loss_regr_pose = 0
        loss_regr_betas = 0
        loss_regr_cam_t = 0
        smpl_outputs = []
        for i, (images, scale_name) in enumerate(
                zip(images_list, scale_names[:len(images_list)])):
            images = images.to(self.device)
            # Feed images in the network to predict camera and SMPL parameters
            pred_rotmat, pred_betas, pred_camera, feat_list = self.model(
                images, scale=i)
            pred_output = self.smpl(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints
            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
                (self.options.img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)
            camera_center = torch.zeros(batch_size, 2, device=self.device)
            pred_keypoints_2d = perspective_projection(
                pred_joints,
                rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1),
                translation=pred_cam_t,
                focal_length=self.focal_length,
                camera_center=camera_center)
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res /
                                                     2.)
            # Compute loss on SMPL parameters
            # Larger scale index (lower resolution) gets a larger weight (i+1).
            loss_pose, loss_betas, loss_cam_t = self.smpl_losses(
                pred_rotmat, pred_betas, pred_cam_t, gt_pose, gt_betas,
                gt_cam_t, has_smpl)
            loss_regr_pose = loss_regr_pose + (i + 1) * loss_pose
            loss_regr_betas = loss_regr_betas + (i + 1) * loss_betas
            loss_regr_cam_t = loss_regr_cam_t + (i + 1) * loss_cam_t
            # Compute 2D reprojection loss for the keypoints
            loss_keypoints = loss_keypoints + (i + 1) * self.keypoint_loss(
                pred_keypoints_2d, gt_keypoints_2d,
                self.options.openpose_train_weight,
                self.options.gt_train_weight)
            # Compute 3D keypoint loss
            loss_keypoints_3d = loss_keypoints_3d + (
                i + 1) * self.keypoint_3d_loss(pred_joints, gt_joints,
                                               has_pose_3d)
            # Per-vertex loss for the shape
            loss_shape = loss_shape + (i + 1) * self.shape_loss(
                pred_vertices, gt_vertices, has_smpl)
            # save pred_rotmat, pred_betas, pred_cam_t for later, from large images to smaller images
            smpl_outputs.append(
                [pred_rotmat, pred_betas, pred_cam_t, feat_list])
        # update queue size
        self.feat_queue.update_queue_size(batch_size)
        # update the queue
        # NOTE(review): feat_list here is the one from the LAST scale
        # iteration — confirm that only last-scale features should be queued.
        self.feat_queue.update_all([feat.detach() for feat in feat_list],
                                   [name for name in feat_names])
        # update dataset name and index for each scale
        self.feat_queue.update('dataset_names', np.array(dataset_name))
        self.feat_queue.update('dataset_indices', indices)
        # Compute total loss except the consistency loss
        loss = self.options.shape_loss_weight * loss_shape +\
               self.options.keypoint_loss_weight * loss_keypoints + \
               self.options.keypoint_loss_weight * loss_keypoints_3d +\
               self.options.pose_loss_weight * loss_regr_pose + \
               self.options.beta_loss_weight * loss_regr_betas + \
               self.options.cam_loss_weight * loss_regr_cam_t
        loss = loss / len(images_list)
        # compute the consistency loss
        loss_consistency = 0
        for i in range(len(smpl_outputs)):
            # Higher-resolution prediction i acts as the (detached) target
            # for each lower-resolution prediction j > i.
            gt_rotmat, gt_betas, gt_cam_t, gt_feat_list = smpl_outputs[i]
            gt_rotmat = gt_rotmat.detach()
            gt_betas = gt_betas.detach()
            gt_cam_t = gt_cam_t.detach()
            gt_feat_list = [feat.detach() for feat in gt_feat_list]
            # sample negative index
            indices_list = self.feat_queue.select_indices(
                dataset_name, indices, self.options.sample_size)
            neg_feat_list = self.feat_queue.batch_sample_all(indices_list,
                                                             names=feat_names)
            for j in range(i + 1, len(smpl_outputs)):
                # compute the consistency loss from high to low: 1:2, 1:3, 2:3 and weighted by 1/(j-i)
                # NOTE(review): the code actually weights by (j-i)/len(...)
                # (growing with scale gap), which contradicts the 1/(j-i)
                # comment above — confirm which is intended.
                pred_rotmat, pred_betas, pred_cam_t, pred_feat_list = smpl_outputs[
                    j]
                loss_consistency_total, loss_consistency_smpl, loss_consistency_feat = self.consistency_losses(
                    pred_rotmat, pred_betas, pred_cam_t, pred_feat_list,
                    gt_rotmat, gt_betas, gt_cam_t, gt_feat_list,
                    neg_feat_list)
                loss_consistency = loss_consistency + (
                    (j - i) / len(smpl_outputs)) * loss_consistency_total
        loss_consistency = loss_consistency * self.consistency_loss_ramp * self.options.consistency_loss_weight
        loss += loss_consistency
        loss *= 60
        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Pack output arguments
        # NOTE(review): if fewer than 2 scales are present, the inner loop
        # never runs and loss_consistency_smpl/_feat below would raise
        # NameError (and loss_consistency would be an int 0) — confirm the
        # trainer always supplies >= 2 scales.
        output = {
            'pred_vertices': pred_vertices.detach(),
            'pred_cam_t': pred_cam_t.detach()
        }
        losses = {
            'lr': self.optimizer.param_groups[0]['lr'],
            'loss_ramp': self.consistency_loss_ramp,
            'loss': loss.detach().item(),
            'loss_consistency': loss_consistency.detach().item(),
            'loss_consistency_smpl': loss_consistency_smpl.detach().item(),
            'loss_consistency_feat': loss_consistency_feat.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }
        return output, losses
def evaluate_singleview(root, annfile, device, use_mask=False):
    """Evaluate single-view bird reconstruction on the Cowbird dataset.

    Detects keypoints/masks, regresses an initial bird pose, refines it with
    per-instance optimization, re-projects the result, and reports PCK@0.05,
    PCK@0.10 and mask IoU.

    Args:
        root: dataset root directory.
        annfile: annotation file path.
        device: torch device (or device string) to run on.
        use_mask: if True, include a silhouette (mask) term during
            optimization via a differentiable renderer. (BUG FIX: this
            parameter was previously ignored — the code branched on the
            global ``args.use_mask`` instead.)
    """
    # Models and optimizer
    bird = bird_model()
    predictor = load_detector().to(device)
    regressor = load_regressor().to(device)
    if use_mask:
        if device == 'cpu':
            print('Warning: using mask during optimization without GPU acceleration is very slow!')
        silhouette_renderer = base_renderer(size=256, focal=2167, device=device)
        optimizer = OptimizeSV(num_iters=100, prior_weight=1, mask_weight=1,
                               use_mask=True, renderer=silhouette_renderer, device=device)
        print('Using mask for single view optimization')
    else:
        optimizer = OptimizeSV(num_iters=100, prior_weight=1, mask_weight=1,
                               use_mask=False, device=device)
    # Dataset to run on
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = Cowbird_Dataset(root=root, annfile=annfile, scale_factor=0.25,
                              transform=normalize)
    loader = torch.utils.data.DataLoader(dataset, batch_size=30)
    Pose_, Tran_, Bone_ = [], [], []
    GT_kpts, GT_masks, Sizes = [], [], []
    # Run reconstruction
    for i, (imgs, gt_kpts, gt_masks, meta) in enumerate(loader):
        print('Running on batch:', i + 1)
        with torch.no_grad():
            # Prediction
            output = predictor(imgs.to(device))
            pred_kpts, pred_mask = postprocess(output)
            # Regression
            kpts_in = pred_kpts.reshape(pred_kpts.shape[0], -1)
            mask_in = pred_mask
            p_est, b_est = regressor(kpts_in, mask_in)
            pose, tran, bone = regressor.postprocess(p_est, b_est)
        # Optimization (outside no_grad: the optimizer needs gradients)
        # Zero out low-confidence keypoints so they don't pull the fit.
        ignored = pred_kpts[:, :, 2] < 0.3
        opt_kpts = pred_kpts.clone()
        opt_kpts[ignored] = 0
        pose_op, bone_op, tran_op, model_mesh = optimizer(
            pose, bone, tran, focal_length=2167, camera_center=128,
            keypoints=opt_kpts, masks=mask_in.squeeze(1))
        Pose_.append(pose_op)
        Tran_.append(tran_op)
        Bone_.append(bone_op)
        GT_kpts.append(gt_kpts)
        GT_masks.append(gt_masks)
        Sizes.append(meta['size'])
    Pose_ = torch.cat(Pose_)
    Tran_ = torch.cat(Tran_)
    Bone_ = torch.cat(Bone_)
    GT_kpts = torch.cat(GT_kpts)
    GT_masks = torch.cat(GT_masks)
    Sizes = torch.cat(Sizes)
    # Render reprojected kpts and masks
    kpts_3d, vertices = pose_bird(bird, Pose_[:, :3], Pose_[:, 3:], Bone_,
                                  Tran_, pose2rot=True)
    kpts_2d = perspective_projection(kpts_3d, None, None, focal_length=2167,
                                     camera_center=128)
    faces = torch.tensor(bird.dd['F'])
    masks = []
    mask_renderer = Silhouette_Renderer(focal_length=2167, center=(128, 128),
                                        img_w=256, img_h=256)
    for i in range(len(vertices)):
        m = mask_renderer(vertices[i], faces)
        masks.append(m)
    masks = torch.tensor(np.stack(masks)).long()
    # Evaluation (only the first 12 keypoints are annotated in GT)
    PCK05, PCK10 = evaluate_pck(kpts_2d[:, :12, :], GT_kpts, size=Sizes)
    IOU = evaluate_iou(masks, GT_masks)
    avg_PCK05 = torch.mean(torch.stack(PCK05))
    avg_PCK10 = torch.mean(torch.stack(PCK10))
    avg_IOU = torch.mean(torch.stack(IOU))
    print('Average PCK05:', avg_PCK05)
    print('Average PCK10:', avg_PCK10)
    print('Average IOU:', avg_IOU)