def forward(self,
            sample,
            scaletrans=None,
            scale=None,
            trans=None,
            rotaxisang=None):
    """
    Args:
        scaletrans: torch.Tensor of shape [batch_size, channels] with
            channels == 6: the predicted scale in channel 1, the predicted
            translation in channels 2-3, and the global rotation encoded as
            axis-angle in channels 4-6
        scale, trans, rotaxisang: optional tensors that override the
            corresponding channels of scaletrans
    """
    if scaletrans is None:
        batch_size = scale.shape[0]
    else:
        batch_size = scaletrans.shape[0]
    if scale is None:
        scale = scaletrans[:, :1]
    if trans is None:
        trans = scaletrans[:, 1:3]
    if rotaxisang is None:
        rotaxisang = scaletrans[:, 3:]

    # Get rotation matrices from axis-angles
    rotmat = rodrigues_layer.batch_rodrigues(rotaxisang).view(
        rotaxisang.shape[0], 3, 3)
    # Rotate the canonical object vertices
    canobjverts = sample[BaseQueries.OBJCANVERTS].cuda()
    rotobjverts = rotmat.bmm(canobjverts.float().transpose(1, 2)).transpose(1, 2)

    final_trans = trans.unsqueeze(1) * self.trans_factor
    final_scale = scale.view(batch_size, 1, 1) * self.scale_factor
    height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
    camintr = sample[TransQueries.CAMINTR].cuda()
    objverts3d, center3d = project.recover_3d_proj(
        rotobjverts,
        camintr,
        final_scale,
        final_trans,
        input_res=(width, height))
    # Recover 2D positions given camera intrinsic parameters and object vertex
    # coordinates in the camera coordinate reference
    pred_objverts2d = camproject.batch_proj2d(objverts3d, camintr)

    if BaseQueries.OBJCORNERS3D in sample:
        canobjcorners = sample[BaseQueries.OBJCANCORNERS].cuda()
        rotobjcorners = rotmat.bmm(
            canobjcorners.float().transpose(1, 2)).transpose(1, 2)
        recov_objcorners3d = rotobjcorners + center3d
        pred_objcorners2d = camproject.batch_proj2d(
            rotobjcorners + center3d, camintr)
    else:
        pred_objcorners2d = None
        recov_objcorners3d = None
        rotobjcorners = None
    return {
        "obj_verts2d": pred_objverts2d,
        "obj_verts3d": rotobjverts,
        "recov_objverts3d": objverts3d,
        "recov_objcorners3d": recov_objcorners3d,
        "obj_scale": final_scale,
        "obj_prescale": scale,
        "obj_prerot": rotaxisang,
        "obj_trans": final_trans,
        "obj_pretrans": trans,
        "obj_corners2d": pred_objcorners2d,
        "obj_corners3d": rotobjcorners,
    }
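# NOTE: hedged reference sketch, not part of the original module. It illustrates
# the standard Rodrigues formula that rodrigues_layer.batch_rodrigues is expected
# to implement above: axis-angle vectors [batch_size, 3] are mapped to rotation
# matrices [batch_size, 3, 3] via R = I + sin(theta) * K + (1 - cos(theta)) * K^2,
# where theta is the rotation angle (vector norm) and K the skew-symmetric matrix
# of the unit axis. The helper name is hypothetical; torch is assumed imported at
# module level (it is used throughout this file).
def _batch_rodrigues_sketch(axisang, eps=1e-8):
    angle = axisang.norm(dim=1, keepdim=True).clamp(min=eps)  # [B, 1]
    axis = axisang / angle  # unit rotation axes, [B, 3]
    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    zeros = torch.zeros_like(x)
    # Skew-symmetric cross-product matrices K, flattened row-major then reshaped
    skew = torch.stack([zeros, -z, y,
                        z, zeros, -x,
                        -y, x, zeros], dim=1).view(-1, 3, 3)
    eye = torch.eye(3, dtype=axisang.dtype, device=axisang.device).unsqueeze(0)
    sin = angle.sin().view(-1, 1, 1)
    cos = angle.cos().view(-1, 1, 1)
    return eye + sin * skew + (1 - cos) * skew.bmm(skew)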
def get_opticalflow(
    verts_cam: List[torch.Tensor],
    faces: torch.Tensor,
    camintrs: List[torch.Tensor],
    neurenderer,
    orig_img_size=None,
    mask_occlusions: bool = True,
    detach_textures: bool = False,
    detach_renders: bool = True,
    ignore_face_idxs=None,
):
    """
    Compute optical flow in image space given the displacement of the vertices
    in verts_cam between the first and second frames.

    If detach_renders is False, gradients will be computed to update the
    'shape' of the rendered flow image.
    If detach_textures is False, gradients will be computed to update the
    'colors' of the rendered flow image.
    When detach_renders is True and detach_textures is False (the setting we
    use), the gradients only flow through the differences (as rendered by the
    neural renderer) between the vertex positions of the pair (no gradients
    flow to update the position of the first mesh).

    Args:
        verts_cam: Pair of vertex positions as list of tensors of shape
            [batch_size, point_nb, 3 (spatial coordinates)] of len 2
        faces: Faces as tensor of vertex indices of shape
            [batch_size, face_nb, 3 (vertex indices)]
        camintrs: Pair of intrinsic camera parameters as list of tensors of
            shape [batch_size, 3, 3] of len 2
        neurenderer: differentiable neural renderer used to rasterize the
            per-vertex flow values
        orig_img_size: (width, height) to which the rendered flow maps are
            cropped
        mask_occlusions: mask out pixels that fail the forward-backward
            consistency check
        ignore_face_idxs: Indices of faces for which the optical flow should
            not be computed
        detach_textures: Do not backpropagate through the optical flow offset
            *values*
        detach_renders: Do not backpropagate through the optical flow rendered
            positions

    Returns:
        pred_flow12: flow values rendered at the locations of the first
            vertices, with flow values from 1 to 2
        pred_flow21: flow values rendered at the locations of the second
            vertices, with flow values from 2 to 1
    """
    # Project ground truth vertices on image plane
    gt_locs2d_1 = project.batch_proj2d(verts_cam[0], camintrs[0])
    gt_locs2d_2 = project.batch_proj2d(verts_cam[1], camintrs[1])

    # Get ground truth forward optical flow
    verts_displ2d_12 = gt_locs2d_2 - gt_locs2d_1
    sample_flows = torch.cat(
        [verts_displ2d_12, torch.ones_like(verts_displ2d_12[:, :, :1])], -1)
    all_textures = textutils.batch_vertex_textures(faces, sample_flows)
    if detach_textures:
        all_textures = all_textures.detach()
    # Only keep locations with valid flow pixel predictions
    renderout = neurenderer(verts_cam[0],
                            faces,
                            all_textures,
                            K=camintrs[0],
                            detach_renders=detach_renders)
    # Keep only pixels fully covered by the rendered mesh
    mask_flow1 = (renderout["alpha"].unsqueeze(1) > 0.99999).float()
    if ignore_face_idxs is not None:
        ignore_mask = (renderout["face_index_map"].unsqueeze(-1) -
                       renderout["face_index_map"].new(ignore_face_idxs)
                       ).abs().min(-1)[0] != 0
        # Flip along the vertical axis to align the face index map with the
        # rendered image
        ignore_mask = ignore_mask[:, list(reversed(range(ignore_mask.shape[1])))]
        ignore_mask = ignore_mask.float().unsqueeze(1)
        mask_flow1 = mask_flow1 * ignore_mask
    pred_flow12 = renderout["rgb"] * mask_flow1

    # Get ground truth backward optical flow
    verts_displ2d_21 = gt_locs2d_1 - gt_locs2d_2
    sample_flows = torch.cat(
        [verts_displ2d_21, torch.ones_like(verts_displ2d_21[:, :, :1])], -1)
    all_textures = textutils.batch_vertex_textures(faces, sample_flows)
    if detach_textures:
        # Mirrors the detaching of the forward-flow textures above
        all_textures = all_textures.detach()
    # Only keep locations with valid flow pixel predictions
    renderout = neurenderer(verts_cam[1],
                            faces,
                            all_textures,
                            K=camintrs[1],
                            detach_renders=detach_renders)
    mask_flow2 = (renderout["alpha"].unsqueeze(1) > 0.99999).float()
    if ignore_face_idxs is not None:
        ignore_mask = (renderout["face_index_map"].unsqueeze(-1) -
                       renderout["face_index_map"].new(ignore_face_idxs)
                       ).abs().min(-1)[0] != 0
        ignore_mask = ignore_mask[:, list(reversed(range(ignore_mask.shape[1])))]
        ignore_mask = ignore_mask.float().unsqueeze(1)
        mask_flow2 = mask_flow2 * ignore_mask
    pred_flow21 = renderout["rgb"] * mask_flow2

    if mask_occlusions:
        with torch.no_grad():
            mask_flow2 = renderout["alpha"].unsqueeze(1)
            # Compute pixels which are visible in both frames by
            # performing a forward-backward consistency check
            occl_mask1, occl_mask2 = imgflowarp.get_occlusion_mask(
                mask_flow1, mask_flow2, pred_flow12, pred_flow21)
        mask_flow1 = mask_flow1 * occl_mask1.unsqueeze(1)
        mask_flow2 = mask_flow2 * occl_mask2.unsqueeze(1)
        pred_flow12 = pred_flow12 * mask_flow1
        pred_flow21 = pred_flow21 * mask_flow2
    # Drop the constant third channel and move channels last
    pred_flow12 = pred_flow12.permute(0, 2, 3, 1)[:, :, :, :2]
    pred_flow21 = pred_flow21.permute(0, 2, 3, 1)[:, :, :, :2]
    if orig_img_size is not None:
        pred_flow12 = pred_flow12[:, :orig_img_size[1], :orig_img_size[0]]
        pred_flow21 = pred_flow21[:, :orig_img_size[1], :orig_img_size[0]]
    pred_flows = [pred_flow12, pred_flow21]
    return pred_flows
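# NOTE: hedged sketch, not the actual imgflowarp.get_occlusion_mask (which also
# takes the rendered visibility masks). It shows the usual forward-backward
# consistency check that mask_occlusions relies on above: a frame-1 pixel is kept
# only if following flow12 to frame 2 and sampling the backward flow there brings
# it (approximately) back to its starting position. Function and threshold names
# are hypothetical; flows are assumed to be [B, 2, H, W] in pixels, x-channel first.
def _fb_consistency_mask_sketch(flow12, flow21, thresh=1.0):
    batch_size, _, height, width = flow12.shape
    ys, xs = torch.meshgrid(
        torch.arange(height, dtype=flow12.dtype, device=flow12.device),
        torch.arange(width, dtype=flow12.dtype, device=flow12.device),
        indexing="ij")
    grid = torch.stack([xs, ys], dim=0).unsqueeze(0)  # pixel grid, [1, 2, H, W]
    target = grid + flow12  # where each frame-1 pixel lands in frame 2
    # Normalize target coordinates to [-1, 1] for grid_sample ((x, y) order)
    norm_x = 2 * target[:, 0] / max(width - 1, 1) - 1
    norm_y = 2 * target[:, 1] / max(height - 1, 1) - 1
    sample_grid = torch.stack([norm_x, norm_y], dim=-1)  # [B, H, W, 2]
    flow21_warped = torch.nn.functional.grid_sample(
        flow21, sample_grid, align_corners=True)
    # The round trip flow12 + warped flow21 should be close to zero wherever
    # the pixel is visible in both frames
    fb_err = (flow12 + flow21_warped).norm(dim=1)  # [B, H, W]
    return (fb_err < thresh).float()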
def recover_mano(
    self,
    sample,
    encoder_output=None,
    pose=None,
    shape=None,
    no_loss=False,
    total_loss=None,
    scale=None,
    trans=None,
):
    # Get hand projection, centered
    mano_results = self.mano_branch(encoder_output,
                                    sides=sample[BaseQueries.SIDE],
                                    pose=pose,
                                    shape=shape)
    if self.adaptor:
        adapt_joints, _ = self.adaptor(mano_results["verts3d"])
        adapt_joints = adapt_joints.transpose(1, 2)
        center_joint = adapt_joints[:, self.mano_center_idx].unsqueeze(1)
        mano_results["joints3d"] = adapt_joints - center_joint
        mano_results["verts3d"] = mano_results["verts3d"] - center_joint
    # Losses are accumulated here; stays empty when no_loss is True
    mano_losses = {}
    if not no_loss:
        mano_total_loss, mano_losses = self.mano_loss.compute_loss(
            mano_results, sample)
        if total_loss is None:
            total_loss = mano_total_loss
        else:
            total_loss += mano_total_loss
        mano_losses["mano_total_loss"] = mano_total_loss.clone()

    # Recover hand position in camera coordinates
    if (self.mano_lambda_joints2d or self.mano_lambda_verts2d
            or self.mano_lambda_recov_joints3d
            or self.mano_lambda_recov_verts3d):
        if scale is None or trans is None:
            scaletrans = self.scaletrans_branch(encoder_output)
            if trans is None:
                trans = scaletrans[:, 1:]
            if scale is None:
                scale = scaletrans[:, :1]
        final_trans = trans.unsqueeze(1) * self.obj_trans_factor
        final_scale = scale.view(scale.shape[0], 1, 1) * self.obj_scale_factor
        height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
        camintr = sample[TransQueries.CAMINTR].cuda()
        recov_joints3d, center3d = project.recover_3d_proj(
            mano_results["joints3d"],
            camintr,
            final_scale,
            final_trans,
            input_res=(width, height))
        recov_hand_verts3d = mano_results["verts3d"] + center3d
        proj_joints2d = camproject.batch_proj2d(recov_joints3d, camintr)
        proj_verts2d = camproject.batch_proj2d(recov_hand_verts3d, camintr)

        mano_results["joints2d"] = proj_joints2d
        mano_results["recov_joints3d"] = recov_joints3d
        mano_results["recov_handverts3d"] = recov_hand_verts3d
        mano_results["hand_center3d"] = center3d
        mano_results["verts2d"] = proj_verts2d
        mano_results["hand_pretrans"] = trans
        mano_results["hand_prescale"] = scale
        mano_results["hand_trans"] = final_trans
        mano_results["hand_scale"] = final_scale

        if not no_loss:
            # Compute hand losses in pixel space and camera coordinates
            if self.mano_lambda_joints2d is not None and TransQueries.JOINTS2D in sample:
                gt_joints2d = sample[TransQueries.JOINTS2D].cuda().float()
                if self.criterion2d == "l2":
                    # Normalize predictions in pixel space so that results are
                    # roughly centered and have magnitude ~1
                    norm_joints2d_pred = normalize_pixel_out(proj_joints2d)
                    norm_joints2d_gt = normalize_pixel_out(gt_joints2d)
                    joints2d_loss = torch_f.mse_loss(norm_joints2d_pred,
                                                     norm_joints2d_gt)
                elif self.criterion2d == "l1":
                    joints2d_loss = torch_f.l1_loss(proj_joints2d, gt_joints2d)
                elif self.criterion2d == "smoothl1":
                    joints2d_loss = torch_f.smooth_l1_loss(
                        proj_joints2d, gt_joints2d)
                total_loss += self.mano_lambda_joints2d * joints2d_loss
                mano_losses["joints2d"] = joints2d_loss
            if self.mano_lambda_verts2d is not None and TransQueries.HANDVERTS2D in sample:
                gt_verts2d = sample[TransQueries.HANDVERTS2D].cuda().float()
                verts2d_loss = torch_f.mse_loss(
                    normalize_pixel_out(proj_verts2d, self.inp_res),
                    normalize_pixel_out(gt_verts2d, self.inp_res),
                )
                total_loss += self.mano_lambda_verts2d * verts2d_loss
                mano_losses["verts2d"] = verts2d_loss
            if self.mano_lambda_recov_joints3d is not None and BaseQueries.JOINTS3D in sample:
                joints3d_gt = sample[BaseQueries.JOINTS3D].cuda()
                recov_loss = torch_f.mse_loss(recov_joints3d, joints3d_gt)
                total_loss += self.mano_lambda_recov_joints3d * recov_loss
                mano_losses["recov_joint3d"] = recov_loss
            if self.mano_lambda_recov_verts3d is not None and BaseQueries.HANDVERTS3D in sample:
                hand_verts3d_gt = sample[BaseQueries.HANDVERTS3D].cuda()
                recov_loss = torch_f.mse_loss(recov_hand_verts3d,
                                              hand_verts3d_gt)
                total_loss += self.mano_lambda_recov_verts3d * recov_loss
    return mano_results, total_loss, mano_losses
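# NOTE: hedged sketch, not the original normalize_pixel_out. Per the comment in
# recover_mano above, the l2 criterion compares 2D points only after normalizing
# them so that they are roughly centered with magnitude ~1; one way to do this,
# assuming inp_res is the (width, height) of the network input, is to shift by
# half the resolution and divide by it. The helper name, signature, and default
# resolution are hypothetical.
def _normalize_pixel_out_sketch(points2d, inp_res=(256, 256)):
    # points2d: [batch_size, point_nb, 2] pixel coordinates
    res = points2d.new_tensor(inp_res)  # [2], broadcast over points
    return (points2d - res / 2) / (res / 2)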