def forward(self, sample, scaletrans=None, scale=None, trans=None, rotaxisang=None):
    """
    Args:
        scaletrans: torch.Tensor of shape [batch_size, 6] where channel 0 holds
            the predicted scale, channels 1-2 the predicted translation and
            channels 3-5 the global rotation encoded as an axis-angle
    """
    if scaletrans is None:
        batch_size = scale.shape[0]
    else:
        batch_size = scaletrans.shape[0]
    if scale is None:
        scale = scaletrans[:, :1]
    if trans is None:
        trans = scaletrans[:, 1:3]
    if rotaxisang is None:
        rotaxisang = scaletrans[:, 3:]

    # Get rotation matrices from axis-angles
    rotmat = rodrigues_layer.batch_rodrigues(rotaxisang).view(rotaxisang.shape[0], 3, 3)
    # Apply the predicted global rotation to the canonical object vertices
    canobjverts = sample[BaseQueries.OBJCANVERTS].cuda()
    rotobjverts = rotmat.bmm(canobjverts.float().transpose(1, 2)).transpose(1, 2)

    final_trans = trans.unsqueeze(1) * self.trans_factor
    final_scale = scale.view(batch_size, 1, 1) * self.scale_factor
    height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
    camintr = sample[TransQueries.CAMINTR].cuda()
    # Lift the rotated vertices to camera coordinates from the predicted scale and translation
    objverts3d, center3d = project.recover_3d_proj(
        rotobjverts, camintr, final_scale, final_trans, input_res=(width, height))
    # Recover 2D positions given camera intrinsic parameters and object vertex
    # coordinates in camera coordinate reference
    pred_objverts2d = camproject.batch_proj2d(objverts3d, camintr)

    if BaseQueries.OBJCORNERS3D in sample:
        # Apply the same rigid transform to the canonical object corners
        canobjcorners = sample[BaseQueries.OBJCANCORNERS].cuda()
        rotobjcorners = rotmat.bmm(canobjcorners.float().transpose(1, 2)).transpose(1, 2)
        recov_objcorners3d = rotobjcorners + center3d
        pred_objcorners2d = camproject.batch_proj2d(rotobjcorners + center3d, camintr)
    else:
        pred_objcorners2d = None
        recov_objcorners3d = None
        rotobjcorners = None
    return {
        "obj_verts2d": pred_objverts2d,
        "obj_verts3d": rotobjverts,
        "recov_objverts3d": objverts3d,
        "recov_objcorners3d": recov_objcorners3d,
        "obj_scale": final_scale,
        "obj_prescale": scale,
        "obj_prerot": rotaxisang,
        "obj_trans": final_trans,
        "obj_pretrans": trans,
        "obj_corners2d": pred_objcorners2d,
        "obj_corners3d": rotobjcorners,
    }
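
# --- Illustrative sketch (not part of the original code) ---------------------
# `camproject.batch_proj2d` is assumed above to perform a standard batched
# pinhole projection: multiply camera-space points by the intrinsics, then
# divide by depth. A minimal sketch of that assumed behaviour is given below;
# the function name and arguments are hypothetical, not the repository's API.
def batch_pinhole_proj2d_sketch(points_cam, camintr):
    # points_cam: [batch_size, n_points, 3] points in camera coordinates
    # camintr:    [batch_size, 3, 3] pinhole intrinsics
    hom2d = camintr.bmm(points_cam.transpose(1, 2)).transpose(1, 2)
    # Perspective division by depth yields pixel coordinates
    return hom2d[:, :, :2] / hom2d[:, :, 2:]
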
def recover_mano(
    self,
    sample,
    encoder_output=None,
    pose=None,
    shape=None,
    no_loss=False,
    total_loss=None,
    scale=None,
    trans=None,
):
    # Get hand projection, centered
    mano_results = self.mano_branch(
        encoder_output, sides=sample[BaseQueries.SIDE], pose=pose, shape=shape)
    if self.adaptor:
        # Refine joint locations from the predicted vertices, then re-center
        # both joints and vertices on the reference joint
        adapt_joints, _ = self.adaptor(mano_results["verts3d"])
        adapt_joints = adapt_joints.transpose(1, 2)
        center_joint = adapt_joints[:, self.mano_center_idx].unsqueeze(1)
        mano_results["joints3d"] = adapt_joints - center_joint
        mano_results["verts3d"] = mano_results["verts3d"] - center_joint
    # Keep mano_losses defined even when loss computation is skipped
    mano_losses = {}
    if not no_loss:
        mano_total_loss, mano_losses = self.mano_loss.compute_loss(mano_results, sample)
        if total_loss is None:
            total_loss = mano_total_loss
        else:
            total_loss += mano_total_loss
        mano_losses["mano_total_loss"] = mano_total_loss.clone()

    # Recover hand position in camera coordinates
    if (self.mano_lambda_joints2d or self.mano_lambda_verts2d
            or self.mano_lambda_recov_joints3d or self.mano_lambda_recov_verts3d):
        if scale is None and trans is None:
            scaletrans = self.scaletrans_branch(encoder_output)
            if trans is None:
                trans = scaletrans[:, 1:]
            if scale is None:
                scale = scaletrans[:, :1]
        final_trans = trans.unsqueeze(1) * self.obj_trans_factor
        final_scale = scale.view(scale.shape[0], 1, 1) * self.obj_scale_factor
        height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
        camintr = sample[TransQueries.CAMINTR].cuda()
        recov_joints3d, center3d = project.recover_3d_proj(
            mano_results["joints3d"], camintr, final_scale, final_trans,
            input_res=(width, height))
        recov_hand_verts3d = mano_results["verts3d"] + center3d
        proj_joints2d = camproject.batch_proj2d(recov_joints3d, camintr)
        proj_verts2d = camproject.batch_proj2d(recov_hand_verts3d, camintr)

        mano_results["joints2d"] = proj_joints2d
        mano_results["recov_joints3d"] = recov_joints3d
        mano_results["recov_handverts3d"] = recov_hand_verts3d
        mano_results["hand_center3d"] = center3d
        mano_results["verts2d"] = proj_verts2d
        mano_results["hand_pretrans"] = trans
        mano_results["hand_prescale"] = scale
        mano_results["hand_trans"] = final_trans
        mano_results["hand_scale"] = final_scale

        if not no_loss:
            # Compute hand losses in pixel space and camera coordinates
            if self.mano_lambda_joints2d is not None and TransQueries.JOINTS2D in sample:
                gt_joints2d = sample[TransQueries.JOINTS2D].cuda().float()
                if self.criterion2d == "l2":
                    # Normalize predictions in pixel space so that results are
                    # roughly centered and have magnitude ~1
                    norm_joints2d_pred = normalize_pixel_out(proj_joints2d)
                    norm_joints2d_gt = normalize_pixel_out(gt_joints2d)
                    joints2d_loss = torch_f.mse_loss(norm_joints2d_pred, norm_joints2d_gt)
                elif self.criterion2d == "l1":
                    joints2d_loss = torch_f.l1_loss(proj_joints2d, gt_joints2d)
                elif self.criterion2d == "smoothl1":
                    joints2d_loss = torch_f.smooth_l1_loss(proj_joints2d, gt_joints2d)
                total_loss += self.mano_lambda_joints2d * joints2d_loss
                mano_losses["joints2d"] = joints2d_loss
            if self.mano_lambda_verts2d is not None and TransQueries.HANDVERTS2D in sample:
                gt_verts2d = sample[TransQueries.HANDVERTS2D].cuda().float()
                verts2d_loss = torch_f.mse_loss(
                    normalize_pixel_out(proj_verts2d, self.inp_res),
                    normalize_pixel_out(gt_verts2d, self.inp_res),
                )
                total_loss += self.mano_lambda_verts2d * verts2d_loss
                mano_losses["verts2d"] = verts2d_loss
            if self.mano_lambda_recov_joints3d is not None and BaseQueries.JOINTS3D in sample:
                joints3d_gt = sample[BaseQueries.JOINTS3D].cuda()
                recov_loss = torch_f.mse_loss(recov_joints3d, joints3d_gt)
                total_loss += self.mano_lambda_recov_joints3d * recov_loss
                mano_losses["recov_joint3d"] = recov_loss
            if self.mano_lambda_recov_verts3d is not None and BaseQueries.HANDVERTS3D in sample:
                hand_verts3d_gt = sample[BaseQueries.HANDVERTS3D].cuda()
                recov_loss = torch_f.mse_loss(recov_hand_verts3d, hand_verts3d_gt)
                total_loss += self.mano_lambda_recov_verts3d * recov_loss
    return mano_results, total_loss, mano_losses
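
# --- Illustrative sketch (not part of the original code) ---------------------
# The "l2" branch above relies on `normalize_pixel_out` to map pixel
# coordinates into a roughly centered range of magnitude ~1 before the MSE is
# computed. A plausible minimal version of such a normalization is sketched
# below under that assumption; the function name, signature and default
# resolution are hypothetical, not the repository's API.
def normalize_pixel_out_sketch(points2d, inp_res=(256, 256)):
    # points2d: [batch_size, n_points, 2] pixel coordinates
    # Center on half the input resolution and rescale so values lie roughly in [-1, 1]
    half_res = points2d.new_tensor(inp_res).view(1, 1, 2) / 2
    return (points2d - half_res) / half_res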