import torch
import torch.nn as nn
from torchvision.utils import make_grid

# Repo-local imports used below (module paths are assumptions based on the
# GraphCMR-style layout; adjust them to your checkout):
# from datasets import create_dataset
# from models import GraphCNN, SMPL, SMPLParamRegressor
# from models.geometric_layers import orthographic_projection, rodrigues
# from utils.mesh import Mesh
# from utils.renderer import Renderer, visualize_reconstruction
# from utils.base_trainer import BaseTrainer
# import config


def __init__(self, mesh, num_layers, num_channels, batch_size=1,
             pretrained_checkpoint=None, device=None):
    """CMR constructor variant that builds the SMPL layer from
    config.SMPL_MODEL_DIR and maps the optional checkpoint onto `device`."""
    super(CMR, self).__init__()
    self.graph_cnn = GraphCNN(mesh.adjmat, mesh.ref_vertices.t(),
                              num_layers, num_channels)
    self.smpl_param_regressor = SMPLParamRegressor()
    self.smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=batch_size)
    self.mesh = mesh
    if pretrained_checkpoint is not None:
        checkpoint = torch.load(pretrained_checkpoint, map_location=device)
        try:
            self.graph_cnn.load_state_dict(checkpoint['graph_cnn'])
        except KeyError:
            print('Warning: graph_cnn was not found in checkpoint')
        try:
            self.smpl_param_regressor.load_state_dict(
                checkpoint['smpl_param_regressor'])
        except KeyError:
            print('Warning: smpl_param_regressor was not found in checkpoint')
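
# Sketch of the checkpoint layout this constructor expects, inferred from the
# keys it reads above. The helper name and default path are hypothetical.
def save_cmr_checkpoint(model, path='data/models/cmr_checkpoint.pt'):
    torch.save({'graph_cnn': model.graph_cnn.state_dict(),
                'smpl_param_regressor': model.smpl_param_regressor.state_dict()},
               path)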
class CMR(nn.Module):

    def __init__(self, mesh, num_layers, num_channels, pretrained_checkpoint=None):
        super(CMR, self).__init__()
        self.graph_cnn = GraphCNN(mesh.adjmat, mesh.ref_vertices.t(),
                                  num_layers, num_channels)
        self.smpl_param_regressor = SMPLParamRegressor()
        self.smpl = SMPL()
        self.mesh = mesh
        if pretrained_checkpoint is not None:
            checkpoint = torch.load(pretrained_checkpoint)
            try:
                self.graph_cnn.load_state_dict(checkpoint['graph_cnn'])
            except KeyError:
                print('Warning: graph_cnn was not found in checkpoint')
            try:
                self.smpl_param_regressor.load_state_dict(
                    checkpoint['smpl_param_regressor'])
            except KeyError:
                print('Warning: smpl_param_regressor was not found in checkpoint')

    def forward(self, image, train_graph_cnn=True,
                train_smpl_param_regressor=True, detach=True):
        """Fused forward pass for the 2 networks.

        Inputs:
            image: size = (B, 3, 224, 224)
        Returns:
            Regressed non-parametric shape: size = (B, 6890, 3)
            Regressed SMPL shape: size = (B, 6890, 3)
            Weak-perspective camera: size = (B, 3)
            SMPL pose parameters (as rotation matrices): size = (B, 24, 3, 3)
            SMPL shape parameters: size = (B, 10)
        """
        batch_size = image.shape[0]
        if not train_graph_cnn:
            with torch.no_grad():
                pred_vertices_sub, camera = self.graph_cnn(image)
        else:
            pred_vertices_sub, camera = self.graph_cnn(image)
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))
        if detach:
            x = pred_vertices_sub.transpose(1, 2).detach()
        else:
            x = pred_vertices_sub.transpose(1, 2)
        x = torch.cat([x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
                      dim=-1)
        if not train_smpl_param_regressor:
            with torch.no_grad():
                pred_rotmat, pred_betas = self.smpl_param_regressor(x)
        else:
            pred_rotmat, pred_betas = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_betas)
        return pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas
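
# Usage sketch (hypothetical settings): run inference on a dummy batch and
# check the shapes documented in forward(). num_layers=5 and num_channels=256
# are illustrative choices, not values mandated by the code above.
mesh = Mesh()
model = CMR(mesh, num_layers=5, num_channels=256)
model.eval()
with torch.no_grad():
    image = torch.zeros(1, 3, 224, 224)
    pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(image)
print(pred_vertices.shape)       # torch.Size([1, 6890, 3])
print(pred_vertices_smpl.shape)  # torch.Size([1, 6890, 3])
print(camera.shape)              # torch.Size([1, 3])
print(pred_rotmat.shape)         # torch.Size([1, 24, 3, 3])
print(pred_betas.shape)          # torch.Size([1, 10])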
class Trainer(BaseTrainer):
    """Trainer object.

    Inherits from BaseTrainer that sets up logging,
    saving/restoring checkpoints etc.
    """

    def init_fn(self):
        # Create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)
        # Create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)
        # Create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(self.device)
        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)
        # Set up a joint optimizer for the 2 models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
                   list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)
        # SMPL model
        self.smpl = SMPL().to(self.device)
        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)
        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {'graph_cnn': self.graph_cnn,
                            'smpl_param_regressor': self.smpl_param_regressor}
        self.optimizers_dict = {'optimizer': self.optimizer}
        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())
        # LSP indices from the full list of keypoints
        self.to_lsp = list(range(14))
        # Optionally start training from a pretrained checkpoint.
        # Note that this is different from resuming training;
        # for the latter use --resume.
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(checkpoint_file=self.options.pretrained_checkpoint)

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d):
        """Compute 2D reprojection loss on the keypoints.

        The confidence is binary and indicates whether the keypoints exist or not.
        The available keypoints are different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss
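
    # Illustration of the confidence weighting above (hypothetical shapes,
    # kept as comments so the class body stays valid Python): with
    # gt_keypoints_2d of size (B, K, 3), the last channel is a binary
    # confidence, so unlabeled joints contribute zero loss:
    #
    #   pred = torch.rand(2, 24, 2)
    #   gt = torch.cat([torch.rand(2, 24, 2), torch.ones(2, 24, 1)], dim=-1)
    #   gt[:, 14:, -1] = 0.                # joints beyond LSP are unlabeled
    #   conf = gt[:, :, -1].unsqueeze(-1)  # (2, 24, 1) mask
    #   loss = (conf * nn.MSELoss(reduction='none')(pred, gt[:, :, :-1])).mean()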
    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
        """Compute 3D keypoint loss for the examples that 3D keypoint annotations are available.

        The loss is weighted by the confidence.
        """
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that SMPL annotations are available."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl):
        """Compute SMPL parameter loss for the examples that SMPL annotations are available."""
        pred_rotmat_valid = pred_rotmat[has_smpl == 1].view(-1, 3, 3)
        gt_rotmat_valid = rodrigues(gt_pose[has_smpl == 1].view(-1, 3))
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas
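
    # Note on the pattern shared by the three losses above (explanatory sketch,
    # kept as comments so the class body stays valid Python): has_pose_3d and
    # has_smpl are per-example binary flags, so indexing such as
    #
    #   pred_vertices[has_smpl == 1]
    #
    # keeps only the annotated examples in the batch, and when none are
    # annotated the torch.FloatTensor(1).fill_(0.) fallback contributes a
    # well-defined zero to the total loss. In keypoint_3d_loss, both skeletons
    # are additionally centered on the pelvis (the midpoint of joints 2 and 3,
    # the hips) before comparison, making the loss invariant to global
    # translation.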
    def train_step(self, input_batch):
        """Training step."""
        self.graph_cnn.train()
        self.smpl_param_regressor.train()

        # Grab data from the batch
        gt_keypoints_2d = input_batch['keypoints']
        gt_keypoints_3d = input_batch['pose_3d']
        gt_pose = input_batch['pose']
        gt_betas = input_batch['betas']
        has_smpl = input_batch['has_smpl']
        has_pose_3d = input_batch['has_pose_3d']
        images = input_batch['img']

        # Render vertices using SMPL parameters
        gt_vertices = self.smpl(gt_pose, gt_betas)
        batch_size = gt_vertices.shape[0]

        # Feed the image to the GraphCNN;
        # it returns the subsampled mesh and camera parameters
        pred_vertices_sub, pred_camera = self.graph_cnn(images)

        # Upsample the mesh to its original size
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))

        # Prepare input for the SMPL Parameter regressor:
        # the predicted and template vertices subsampled by a factor of 4.
        # Notice that we detach the GraphCNN.
        x = pred_vertices_sub.transpose(1, 2).detach()
        x = torch.cat([x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
                      dim=-1)

        # Estimate SMPL parameters and render vertices
        pred_rotmat, pred_shape = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_shape)

        # Get 3D and projected 2D keypoints from the regressed shape
        pred_keypoints_3d = self.smpl.get_joints(pred_vertices)
        pred_keypoints_2d = orthographic_projection(pred_keypoints_3d,
                                                    pred_camera)[:, :, :2]
        pred_keypoints_3d_smpl = self.smpl.get_joints(pred_vertices_smpl)
        pred_keypoints_2d_smpl = orthographic_projection(
            pred_keypoints_3d_smpl, pred_camera.detach())[:, :, :2]

        # Compute losses
        # GraphCNN losses
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d,
                                                  gt_keypoints_3d, has_pose_3d)
        loss_shape = self.shape_loss(pred_vertices, gt_vertices, has_smpl)

        # SMPL regressor losses
        loss_keypoints_smpl = self.keypoint_loss(pred_keypoints_2d_smpl,
                                                 gt_keypoints_2d)
        loss_keypoints_3d_smpl = self.keypoint_3d_loss(pred_keypoints_3d_smpl,
                                                       gt_keypoints_3d, has_pose_3d)
        loss_shape_smpl = self.shape_loss(pred_vertices_smpl, gt_vertices, has_smpl)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_shape, gt_pose, gt_betas, has_smpl)

        # Add losses to compute the total loss
        loss = loss_shape_smpl + loss_keypoints_smpl + loss_keypoints_3d_smpl + \
            loss_regr_pose + 0.1 * loss_regr_betas + \
            loss_shape + loss_keypoints + loss_keypoints_3d

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments to be used for visualization in a list
        out_args = [pred_vertices, pred_vertices_smpl, pred_camera,
                    pred_keypoints_2d, pred_keypoints_2d_smpl,
                    loss_shape, loss_shape_smpl,
                    loss_keypoints, loss_keypoints_smpl,
                    loss_keypoints_3d, loss_keypoints_3d_smpl,
                    loss_regr_pose, loss_regr_betas, loss]
        out_args = [arg.detach() for arg in out_args]
        return out_args
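
    # Expected input_batch layout for train_step (the keys come from the code
    # above; shapes are illustrative assumptions, kept as comments so the
    # class body stays valid Python):
    #
    #   input_batch = {
    #       'img':         (B, 3, H, W) normalized image crop,
    #       'img_orig':    (B, 3, H, W) unnormalized crop (used for logging),
    #       'keypoints':   (B, K, 3) 2D keypoints, last channel = confidence,
    #       'pose_3d':     (B, K, 4) 3D keypoints, last channel = confidence,
    #       'pose':        (B, 72) SMPL axis-angle pose,
    #       'betas':       (B, 10) SMPL shape coefficients,
    #       'has_smpl':    (B,) binary flag for SMPL annotations,
    #       'has_pose_3d': (B,) binary flag for 3D keypoint annotations,
    #   }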
    def train_summaries(self, input_batch, pred_vertices, pred_vertices_smpl,
                        pred_camera, pred_keypoints_2d, pred_keypoints_2d_smpl,
                        loss_shape, loss_shape_smpl, loss_keypoints,
                        loss_keypoints_smpl, loss_keypoints_3d,
                        loss_keypoints_3d_smpl, loss_regr_pose,
                        loss_regr_betas, loss):
        """Tensorboard logging."""
        gt_keypoints_2d = input_batch['keypoints'].cpu().numpy()

        rend_imgs = []
        rend_imgs_smpl = []
        batch_size = pred_vertices.shape[0]
        # Do visualization for the first 4 images of the batch
        for i in range(min(batch_size, 4)):
            img = input_batch['img_orig'][i].cpu().numpy().transpose(1, 2, 0)
            # Get LSP keypoints from the full list of keypoints
            gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
            pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, self.to_lsp]
            pred_keypoints_2d_smpl_ = pred_keypoints_2d_smpl.cpu().numpy()[i, self.to_lsp]
            # Get GraphCNN and SMPL vertices for the particular example
            vertices = pred_vertices[i].cpu().numpy()
            vertices_smpl = pred_vertices_smpl[i].cpu().numpy()
            cam = pred_camera[i].cpu().numpy()
            # Visualize reconstruction and detected pose
            rend_img = visualize_reconstruction(img, self.options.img_res,
                                                gt_keypoints_2d_, vertices,
                                                pred_keypoints_2d_, cam,
                                                self.renderer)
            rend_img_smpl = visualize_reconstruction(img, self.options.img_res,
                                                     gt_keypoints_2d_, vertices_smpl,
                                                     pred_keypoints_2d_smpl_, cam,
                                                     self.renderer)
            rend_img = rend_img.transpose(2, 0, 1)
            rend_img_smpl = rend_img_smpl.transpose(2, 0, 1)
            rend_imgs.append(torch.from_numpy(rend_img))
            rend_imgs_smpl.append(torch.from_numpy(rend_img_smpl))
        rend_imgs = make_grid(rend_imgs, nrow=1)
        rend_imgs_smpl = make_grid(rend_imgs_smpl, nrow=1)

        # Save results in Tensorboard
        self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
        self.summary_writer.add_image('imgs_smpl', rend_imgs_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_shape', loss_shape, self.step_count)
        self.summary_writer.add_scalar('loss_shape_smpl', loss_shape_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_regr_pose', loss_regr_pose, self.step_count)
        self.summary_writer.add_scalar('loss_regr_betas', loss_regr_betas, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints', loss_keypoints, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_smpl', loss_keypoints_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d', loss_keypoints_3d, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d_smpl', loss_keypoints_3d_smpl, self.step_count)
        self.summary_writer.add_scalar('loss', loss, self.step_count)
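
# Entry-point sketch (hypothetical): wiring the Trainer into a script. The
# TrainOptions import path and the train() method inherited from BaseTrainer
# follow the GraphCMR-style layout; treat both names as assumptions if your
# checkout differs.
if __name__ == '__main__':
    from utils import TrainOptions  # assumed module path
    options = TrainOptions().parse_args()
    trainer = Trainer(options)
    trainer.train()  # main loop provided by BaseTrainer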