class Trainer(BaseTrainer):

    def init_fn(self):
        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS,
                         pretrained=True).to(self.device)

        # Switch the optimizer
        if self.options.optimizer == 'adam':
            print('Using adam')
            self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                              lr=self.options.lr,
                                              weight_decay=0)
        elif self.options.optimizer == 'sgd':
            print('Using sgd')
            self.optimizer = torch.optim.SGD(params=self.model.parameters(),
                                             lr=self.options.lr,
                                             weight_decay=0)
        elif self.options.optimizer == 'momentum':
            print('Using momentum')
            self.optimizer = torch.optim.SGD(params=self.model.parameters(),
                                             lr=self.options.lr,
                                             momentum=self.options.momentum,
                                             weight_decay=0)
        else:
            raise ValueError('Optimizer {} not found'.format(
                self.options.optimizer))

        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False).to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        # Initialize SMPLify fitting module
        self.smplify = SMPLify(step_size=1e-2,
                               batch_size=self.options.batch_size,
                               num_iters=self.options.num_smplify_iters,
                               focal_length=self.focal_length)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length,
                                 img_res=self.options.img_res,
                                 faces=self.smpl.faces)

    def finalize(self):
        self.fits_dict.save()

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d,
                      openpose_weight, gt_weight):
        """Compute 2D reprojection loss on the keypoints.

        The loss is weighted by the confidence. The available keypoints are
        different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        conf[:, :25] *= openpose_weight
        conf[:, 25:] *= gt_weight
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        """Compute 3D keypoint loss for the examples that have 3D keypoint
        annotations. The loss is weighted by the confidence.
        """
        pred_keypoints_3d = pred_keypoints_3d[:, 25:, :]
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            # Align both skeletons at the pelvis (midpoint of the hips)
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that have
        SMPL annotations."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        pred_rotmat_valid = pred_rotmat[has_smpl == 1]
        gt_rotmat_valid = batch_rodrigues(gt_pose.view(-1, 3)).view(
            -1, 24, 3, 3)[has_smpl == 1]
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte()  # flag: SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte()  # flag: 3D pose is valid
        is_flipped = input_batch['is_flipped']  # flag: image was flipped during data augmentation
        rot_angle = input_batch['rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch['dataset_name']  # name of the dataset the image comes from
        indices = input_batch['sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        if opt_vertices.shape != (self.options.batch_size, 6890, 3):
            opt_vertices = torch.zeros_like(opt_vertices, device=self.device)
        opt_joints = opt_output.joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least-squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t,
            0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images into the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        if pred_vertices.shape != (self.options.batch_size, 6890, 3):
            pred_vertices = torch.zeros_like(pred_vertices,
                                             device=self.device)
        pred_joints = pred_output.joints

        # Convert the weak-perspective camera [s, tx, ty] to a camera
        # translation [tx, ty, tz] in 3D given the bounding box size.
        # This camera translation can be used in a full perspective projection.
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ], dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        if self.options.run_smplify:
            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3),
                torch.tensor([0, 0, 1],
                             dtype=torch.float32,
                             device=self.device).view(1, 3, 1).expand(
                                 batch_size * 24, -1, -1)
            ], dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation,
            # so zero those entries manually
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints, \
                new_opt_pose, new_opt_betas, \
                new_opt_cam_t, new_opt_joint_loss = self.smplify(
                    pred_pose.detach(), pred_betas.detach(),
                    pred_cam_t.detach(),
                    0.5 * self.options.img_res *
                    torch.ones(batch_size, 2, device=self.device),
                    gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Update the dictionary for the examples where the new loss is
            # lower than the current one
            update = (new_opt_joint_loss < opt_joint_loss)
            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                            is_flipped.cpu(),
                            update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())
        else:
            update = torch.zeros(batch_size, device=self.device).byte()

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters,
        # if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # A fit is valid if its joint loss is below the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = valid_fit.to(torch.uint8) | has_smpl

        opt_keypoints_2d = perspective_projection(
            opt_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(
            pred_keypoints_2d, gt_keypoints_2d,
            self.options.openpose_train_weight, self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                                  has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        # Compute total loss
        # The last component forces the network to predict positive depth values
        loss = self.options.shape_loss_weight * loss_shape + \
            self.options.keypoint_loss_weight * loss_keypoints + \
            self.options.keypoint_loss_weight * loss_keypoints_3d + \
            loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas + \
            ((torch.exp(-pred_camera[:, 0] * 10)) ** 2).mean()
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses

    def train_summaries(self, input_batch, output, losses):
        # Save the fits dictionary whenever summaries are written
        self.finalize()

        # Undo the ImageNet normalization for visualization
        images = input_batch['img']
        images = images * torch.tensor(
            [0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor(
            [0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)

        pred_vertices = output['pred_vertices']
        assert pred_vertices.shape == (self.options.batch_size, 6890, 3)
        opt_vertices = output['opt_vertices']
        pred_cam_t = output['pred_cam_t']
        opt_cam_t = output['opt_cam_t']
        images_pred = self.renderer.visualize_tb(pred_vertices, pred_cam_t,
                                                 images)
        images_opt = self.renderer.visualize_tb(opt_vertices, opt_cam_t,
                                                images)
        self.summary_writer.add_image('pred_shape', images_pred,
                                      self.step_count)
        self.summary_writer.add_image('opt_shape', images_opt,
                                      self.step_count)
        for loss_name, val in losses.items():
            self.summary_writer.add_scalar(loss_name, val, self.step_count)
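# --- Illustrative sketch (not part of the original trainers) ----------------
# train_step above converts the weak-perspective camera [s, tx, ty] predicted
# by the network into a full-perspective translation [tx, ty, tz] by solving
# s ~= 2 * f / (img_res * tz). The hypothetical helper below just isolates
# that torch.stack expression so it can be inspected or unit-tested on its
# own; the name and signature are ours, not part of the codebase.
def weak_perspective_to_cam_t(pred_camera, focal_length, img_res):
    """Convert weak-perspective [s, tx, ty] (B, 3) to translation [tx, ty, tz] (B, 3)."""
    # The epsilon mirrors the 1e-9 guard used in train_step to avoid
    # division by zero when the predicted scale s is 0.
    tz = 2 * focal_length / (img_res * pred_camera[:, 0] + 1e-9)
    return torch.stack([pred_camera[:, 1], pred_camera[:, 2], tz], dim=-1)
# -----------------------------------------------------------------------------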
class Trainer_li(BaseTrainer):

    def init_fn(self):
        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS, pretrained=True).to(
            self.device)  # feature extraction model
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr,
                                          weight_decay=0)
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=16,
                         create_transl=False).to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss on SMPL parameters, when annotations are available
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        # Initialize SMPLify
        self.smplify = SMPLify(step_size=1e-2,
                               batch_size=16,
                               num_iters=100,
                               focal_length=self.focal_length)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length,
                                 img_res=224,
                                 faces=self.smpl.faces)

    def finalize(self):
        self.fits_dict.save()

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d,
                      openpose_weight, gt_weight):
        """Compute 2D reprojection loss on the keypoints.

        The loss is weighted by the confidence.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        conf[:, :25] *= openpose_weight
        conf[:, 25:] *= gt_weight
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        pred_keypoints_3d = pred_keypoints_3d[:, 25:, :]
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that have
        SMPL annotations."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        pred_rotmat_valid = pred_rotmat[has_smpl == 1]
        gt_rotmat_valid = batch_rodrigues(gt_pose.view(-1, 3)).view(
            -1, 24, 3, 3)[has_smpl == 1]
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        has_smpl = input_batch['has_smpl'].bool()
        has_pose_3d = input_batch['has_pose_3d'].bool()
        gt_pose1 = input_batch['pose']  # SMPL pose parameters
        gt_betas1 = input_batch['betas']  # SMPL beta parameters
        dataset_name = input_batch['dataset_name']
        indices = input_batch['sample_index']  # index of example inside its dataset
        is_flipped = input_batch['is_flipped']  # flag: image was flipped during data augmentation
        rot_angle = input_batch['rot_angle']  # rotation angle used for data augmentation

        # Get GT vertices and model joints; the ground truth is replicated
        # four times, once per view.
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_betas = torch.cat((gt_betas1, gt_betas1, gt_betas1, gt_betas1), 0)
        gt_pose = torch.cat((gt_pose1, gt_pose1, gt_pose1, gt_pose1), 0)
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose1, opt_betas1 = self.fits_dict[(dataset_name, indices.cpu(),
                                                rot_angle.cpu(),
                                                is_flipped.cpu())]
        opt_pose = torch.cat(
            (opt_pose1.to(self.device), opt_pose1.to(self.device),
             opt_pose1.to(self.device), opt_pose1.to(self.device)), 0)
        opt_betas = torch.cat(
            (opt_betas1.to(self.device), opt_betas1.to(self.device),
             opt_betas1.to(self.device), opt_betas1.to(self.device)), 0)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        # Stack the four camera views into a single batch
        images = torch.cat((input_batch['img_0'], input_batch['img_1'],
                            input_batch['img_2'], input_batch['img_3']), 0)
        batch_size = input_batch['img_0'].shape[0]

        # Output of the CNN
        pred_rotmat, pred_betas, pred_camera = self.model(images)
        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert the weak-perspective camera to a full-perspective translation
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ], dim=-1)

        camera_center = torch.zeros(batch_size * 4, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size * 4, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        # 2D keypoints for the four views
        gt_keypoints_2d = torch.cat(
            (input_batch['keypoints_0'], input_batch['keypoints_1'],
             input_batch['keypoints_2'], input_batch['keypoints_3']), 0)
        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t,
            0.5 * self.options.img_res *
            torch.ones(batch_size * 4, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)

        if self.options.run_smplify:
            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3),
                torch.tensor([0, 0, 1],
                             dtype=torch.float32,
                             device=self.device).view(1, 3, 1).expand(
                                 batch_size * 4 * 24, -1, -1)
            ], dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size * 4, -1)
            # rotation_matrix_to_angle_axis returns NaN for 0 rotation,
            # so zero those entries manually
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints, \
                new_opt_pose, new_opt_betas, \
                new_opt_cam_t, new_opt_joint_loss = self.smplify(
                    pred_pose.detach(), pred_betas.detach(),
                    pred_cam_t.detach(),
                    0.5 * self.options.img_res *
                    torch.ones(batch_size * 4, 2, device=self.device),
                    gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Update the dictionary for the examples where the new loss is
            # lower than the current one
            update = (new_opt_joint_loss < opt_joint_loss)
            update1 = torch.cat((update, update, update, update), 0)
            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_joints[update1, :] = new_opt_joints[update1, :]
            opt_betas[update1, :] = new_opt_betas[update1, :]
            opt_pose[update1, :] = new_opt_pose[update1, :]
            opt_vertices[update1, :] = new_opt_vertices[update1, :]
            opt_cam_t[update1, :] = new_opt_cam_t[update1, :]

        # Now compute the loss on the four views.
        # Replace the optimized parameters with the ground truth parameters,
        # if available
        has_smpl1 = torch.cat((has_smpl, has_smpl, has_smpl, has_smpl), 0)
        opt_vertices[has_smpl1, :, :] = gt_vertices[has_smpl1, :, :]
        opt_pose[has_smpl1, :] = gt_pose[has_smpl1, :]
        opt_cam_t[has_smpl1, :] = gt_cam_t[has_smpl1, :]
        opt_joints[has_smpl1, :, :] = gt_model_joints[has_smpl1, :, :]
        opt_betas[has_smpl1, :] = gt_betas[has_smpl1, :]

        # A fit is valid if its joint loss is below the threshold
        valid_fit1 = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = torch.cat(
            (valid_fit1, valid_fit1, valid_fit1, valid_fit1), 0) | has_smpl1

        # 2D reprojection loss (openpose_weight=0, gt_weight=1)
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d,
                                            gt_keypoints_2d, 0, 1)

        # 3D keypoint loss (weighted 0 in the total loss below, but still
        # computed for logging)
        gt_joints = torch.cat(
            (input_batch['pose_3d_0'], input_batch['pose_3d_1'],
             input_batch['pose_3d_2'], input_batch['pose_3d_3']), 0)
        loss_keypoints_3d = self.keypoint_3d_loss(
            pred_joints, gt_joints,
            torch.cat((has_pose_3d, has_pose_3d, has_pose_3d, has_pose_3d),
                      0))

        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        # Compute total loss
        # The last component forces the network to predict positive depth values
        loss_all = 0 * loss_shape + \
            5. * loss_keypoints + \
            0. * loss_keypoints_3d + \
            loss_regr_pose + 0.001 * loss_regr_betas + \
            ((torch.exp(-pred_camera[:, 0] * 10)) ** 2).mean()
        loss_all *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss_all.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss_all.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses

    def train_summaries(self, input_batch, output, losses):
        pred_vertices = output['pred_vertices']
        opt_vertices = output['opt_vertices']
        pred_cam_t = output['pred_cam_t']
        opt_cam_t = output['opt_cam_t']
        # Assemble and de-normalize the four views for visualization
        images = torch.cat((input_batch['img_0'], input_batch['img_1'],
                            input_batch['img_2'], input_batch['img_3']), 0)
        images = images * torch.tensor(
            [0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor(
            [0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
        images_pred = self.renderer.visualize_tb(pred_vertices, pred_cam_t,
                                                 images)
        images_opt = self.renderer.visualize_tb(opt_vertices, opt_cam_t,
                                                images)
        self.summary_writer.add_image('pred_shape', images_pred,
                                      self.step_count)
        self.summary_writer.add_image('opt_shape', images_opt,
                                      self.step_count)
        for loss_name, val in losses.items():
            self.summary_writer.add_scalar(loss_name, val, self.step_count)
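# --- Illustrative sketch (not part of the original trainers) ----------------
# keypoint_3d_loss in the trainers above aligns prediction and ground truth at
# the pelvis (midpoint of joints 2 and 3, the hips in the 24-joint ground-truth
# convention) before taking the confidence-weighted MSE, which makes the loss
# invariant to global translation. The hypothetical helper below mirrors that
# centering step; the name is ours, not part of the codebase.
def center_at_pelvis(keypoints_3d):
    """Subtract the hip midpoint from a (B, J, 3) batch of 3D keypoints."""
    pelvis = (keypoints_3d[:, 2, :] + keypoints_3d[:, 3, :]) / 2
    return keypoints_3d - pelvis[:, None, :]
# -----------------------------------------------------------------------------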
class Trainer(BaseTrainer):

    def init_fn(self):
        self.options.img_res = cfg.DANET.INIMG_SIZE
        self.options.heatmap_size = cfg.DANET.HEATMAP_SIZE
        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        self.model = DaNet(options=self.options,
                           smpl_mean_params=path_config.SMPL_MEAN_PARAMS).to(
                               self.device)
        self.smpl = self.model.iuv2smpl.smpl

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.SOLVER.BASE_LR,
                                          weight_decay=0)

        # Loss criteria used by the loss helpers below
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        # Load dictionary of fits of SPIN
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Create renderer
        try:
            self.renderer = Renderer(focal_length=self.focal_length,
                                     img_res=self.options.img_res,
                                     faces=self.smpl.faces)
        except Exception:
            print('Warning: no renderer for visualization.')
            self.renderer = None

        self.decay_steps_ind = 1

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d,
                      openpose_weight, gt_weight):
        """Compute 2D reprojection loss on the keypoints.

        The loss is weighted by the confidence. The available keypoints are
        different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        conf[:, :25] *= openpose_weight
        conf[:, 25:] *= gt_weight
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        """Compute 3D keypoint loss for the examples that have 3D keypoint
        annotations. The loss is weighted by the confidence.
        """
        pred_keypoints_3d = pred_keypoints_3d[:, 25:, :]
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that have
        SMPL annotations."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        pred_rotmat_valid = pred_rotmat[has_smpl == 1]
        gt_rotmat_valid = batch_rodrigues(gt_pose.view(-1, 3)).view(
            -1, 24, 3, 3)[has_smpl == 1]
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        # Learning rate decay
        if self.decay_steps_ind < len(cfg.SOLVER.STEPS) and input_batch[
                'step_count'] == cfg.SOLVER.STEPS[self.decay_steps_ind]:
            lr = self.optimizer.param_groups[0]['lr']
            lr_new = lr * cfg.SOLVER.GAMMA
            print('Decay the learning rate on step {} from {} to {}'.format(
                input_batch['step_count'], lr, lr_new))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr_new
            lr = self.optimizer.param_groups[0]['lr']
            assert lr == lr_new
            self.decay_steps_ind += 1

        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte()  # flag: SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte()  # flag: 3D pose is valid
        is_flipped = input_batch['is_flipped']  # flag: image was flipped during data augmentation
        rot_angle = input_batch['rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch['dataset_name']  # name of the dataset the image comes from
        indices = input_batch['sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current pseudo labels (final fits of SPIN) from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters,
        # if available
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least-squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        if self.options.train_data in ['h36m_coco_itw']:
            valid_fit = self.fits_dict.get_vaild_state(
                dataset_name, indices.cpu()).to(self.device)
            valid_fit = valid_fit | has_smpl
        else:
            valid_fit = has_smpl

        # Prepare the inputs for the network, which predicts camera and SMPL
        # parameters from the images
        input_batch['opt_pose'] = opt_pose
        input_batch['opt_betas'] = opt_betas
        input_batch['valid_fit'] = valid_fit
        input_batch['dp_dict'] = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in input_batch['dp_dict'].items()
        }
        has_iuv = torch.tensor([dn not in ['dp_coco'] for dn in dataset_name],
                               dtype=torch.uint8).to(self.device)
        has_iuv = has_iuv & valid_fit
        input_batch['has_iuv'] = has_iuv
        has_dp = input_batch['has_dp']
        target_smpl_kps = torch.zeros(
            (batch_size, 24, 3)).to(opt_output.smpl_joints.device)
        target_smpl_kps[:, :, :2] = perspective_projection(
            opt_output.smpl_joints.detach().clone(),
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=torch.zeros(batch_size, 2, device=self.device) +
            (0.5 * self.options.img_res))
        target_smpl_kps[:, :, :2] = target_smpl_kps[:, :, :2] / (
            0.5 * self.options.img_res) - 1
        target_smpl_kps[has_iuv == 1, :, 2] = 1
        target_smpl_kps[has_dp == 1] = input_batch['smpl_2dkps'][has_dp == 1]
        input_batch['target_smpl_kps'] = target_smpl_kps  # [B, 24, 3]
        input_batch['target_verts'] = opt_vertices.detach().clone()  # [B, 6890, 3]

        # Camera translation for the neural renderer
        gt_cam_t_nr = opt_cam_t.detach().clone()
        gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
        gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
        gt_camera[:, 0] = (2. * self.focal_length /
                           self.options.img_res) / gt_cam_t_nr[:, 2]
        input_batch['target_cam'] = gt_camera

        # Do forward
        danet_return_dict = self.model(input_batch)

        loss_total = 0
        losses_dict = {}
        for loss_key in danet_return_dict['losses']:
            loss_total += danet_return_dict['losses'][loss_key]
            losses_dict['loss_{}'.format(loss_key)] = danet_return_dict[
                'losses'][loss_key].detach().item()

        # Do backprop
        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()

        if input_batch['pretrain_mode']:
            pred_vertices = None
            pred_cam_t = None
        else:
            pred_vertices = danet_return_dict['prediction']['vertices'].detach()
            pred_cam_t = danet_return_dict['prediction']['cam_t'].detach()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t,
            'visualization': danet_return_dict['visualization']
        }

        losses_dict.update({'loss_total': loss_total.detach().item()})

        return output, losses_dict

    def train_summaries(self, input_batch, output, losses):
        for loss_name, val in losses.items():
            self.summary_writer.add_scalar(loss_name, val, self.step_count)

    def visualize(self, input_batch, output, losses):
        # Undo the ImageNet normalization for visualization
        images = input_batch['img']
        images = images * torch.tensor(
            [0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor(
            [0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)

        pred_vertices = output['pred_vertices']
        opt_vertices = output['opt_vertices']
        pred_cam_t = output['pred_cam_t']
        opt_cam_t = output['opt_cam_t']

        if self.renderer is not None:
            images_opt = self.renderer.visualize_tb(opt_vertices, opt_cam_t,
                                                    images)
            self.summary_writer.add_image('opt_shape', images_opt,
                                          self.step_count)
            if pred_vertices is not None:
                images_pred = self.renderer.visualize_tb(
                    pred_vertices, pred_cam_t, images)
                self.summary_writer.add_image('pred_shape', images_pred,
                                              self.step_count)

        for key_name in [
                'pred_uv', 'gt_uv', 'part_uvi_pred', 'part_uvi_gt',
                'skps_hm_pred', 'skps_hm_pred_soft', 'skps_hm_gt',
                'skps_hm_gt_soft'
        ]:
            if key_name in output['visualization']:
                vis_uv_raw = output['visualization'][key_name]
                if key_name in ['pred_uv', 'gt_uv']:
                    iuv = F.interpolate(vis_uv_raw,
                                        scale_factor=4.,
                                        mode='nearest')
                    img_iuv = images.clone()
                    img_iuv[iuv > 0] = iuv[iuv > 0]
                    vis_uv = make_grid(img_iuv, padding=1, pad_value=1)
                else:
                    vis_uv = make_grid(vis_uv_raw, padding=1, pad_value=1)
                self.summary_writer.add_image(key_name, vis_uv,
                                              self.step_count)

        if 'target_smpl_kps' in input_batch:
            smpl_kps = input_batch['target_smpl_kps'].detach()
            smpl_kps[:, :, :2] *= images.size(-1) / 2.
            smpl_kps[:, :, :2] += images.size(-1) / 2.
            img_smpl_hm = images.detach().clone()
            img_with_smpljoints = vis_utils.vis_batch_image_with_joints(
                img_smpl_hm.data, smpl_kps.cpu().numpy(),
                np.ones((smpl_kps.shape[0], smpl_kps.shape[1], 1)))
            img_with_smpljoints = np.transpose(img_with_smpljoints, (2, 0, 1))
            self.summary_writer.add_image('stn_centers_gt',
                                          img_with_smpljoints,
                                          self.step_count)

        if 'stn_kps_pred' in output['visualization']:
            smpl_kps = output['visualization']['stn_kps_pred']
            smpl_kps[:, :, :2] *= images.size(-1) / 2.
            smpl_kps[:, :, :2] += images.size(-1) / 2.
            img_smpl_hm = images.detach().clone()
            if 'skps_hm_gt' in output['visualization']:
                smpl_hm = output['visualization']['skps_hm_gt'].expand(
                    -1, 3, -1, -1)
                smpl_hm = F.interpolate(smpl_hm,
                                        scale_factor=images.size(-1) /
                                        smpl_hm.size(-1))
                img_smpl_hm[smpl_hm > 0.1] = smpl_hm[smpl_hm > 0.1]
            img_with_smpljoints = vis_utils.vis_batch_image_with_joints(
                img_smpl_hm.data, smpl_kps.cpu().numpy(),
                np.ones((smpl_kps.shape[0], smpl_kps.shape[1], 1)))
            img_with_smpljoints = np.transpose(img_with_smpljoints, (2, 0, 1))
            self.summary_writer.add_image('stn_centers_pred',
                                          img_with_smpljoints,
                                          self.step_count)
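# --- Illustrative sketch (not part of the original trainers) ----------------
# The summary/visualization methods above undo the ImageNet normalization
# (multiply by the channel std, add the channel mean) before rendering meshes
# on top of the input images. The hypothetical helper below factors out that
# de-normalization; the name is ours, not part of the codebase.
def denormalize_imagenet(images):
    """Map (B, 3, H, W) ImageNet-normalized images back to approximately [0, 1] RGB."""
    std = torch.tensor([0.229, 0.224, 0.225],
                       device=images.device).reshape(1, 3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406],
                        device=images.device).reshape(1, 3, 1, 1)
    return images * std + mean
# -----------------------------------------------------------------------------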