# NOTE: this excerpt assumes the usual module-level imports of the surrounding
# repository (pickle, torch, torch.nn as nn, torch.nn.functional as torch_f,
# and the project-internal resnet / branch / loss / query helpers used below).


class MeshRegNet(nn.Module):
    def __init__(
        self,
        fc_dropout=0,
        resnet_version=18,
        criterion2d="l2",
        mano_neurons=[512, 512],
        mano_comps=15,
        mano_use_shape=False,
        mano_lambda_pose_reg=0,
        mano_use_pca=True,
        mano_center_idx=9,
        mano_root="assets/mano",
        mano_lambda_joints3d=None,
        mano_lambda_recov_joints3d=None,
        mano_lambda_recov_verts3d=None,
        mano_lambda_joints2d=None,
        mano_lambda_verts3d=None,
        mano_lambda_verts2d=None,
        mano_lambda_shape=None,
        mano_pose_coeff: int = 1,
        mano_fhb_hand: bool = False,
        obj_lambda_verts3d=None,
        obj_lambda_verts2d=None,
        obj_lambda_recov_verts3d=None,
        obj_trans_factor=1,
        obj_scale_factor=1,
        inp_res=256,
        uncertainty_pnp=True,
        domain_norm=False,
    ):
        """
        Args:
            mano_fhb_hand: use the pre-computed mapping from MANO joints to the
                First Person Hand Action Benchmark hand skeleton
            mano_root (path): directory containing the MANO pickle files
            mano_neurons: number of neurons in each layer of the base MANO decoder
            mano_use_pca: predict PCA parameters directly instead of rotation angles
            mano_comps (int): number of principal components to use if mano_use_pca
            mano_lambda_pose_reg: weight to supervise hand pose in axis-angle space
            mano_lambda_verts3d: weight to supervise hand vertex distances
            mano_lambda_joints3d: weight to supervise hand joint distances
        """
        super().__init__()
        self.inp_res = inp_res
        self.uncertainty_pnp = uncertainty_pnp
        self.domain_norm = domain_norm
        self.compute_pnp = False
        if int(resnet_version) == 18:
            img_feature_size = 512
            base_net = resnet.resnet18(pretrained=True,
                                       return_cuda_inter=True,
                                       domain_norm=self.domain_norm)
        elif int(resnet_version) == 50:
            img_feature_size = 2048
            base_net = resnet.resnet50(pretrained=True,
                                       return_cuda_inter=True,
                                       domain_norm=self.domain_norm)
        else:
            raise NotImplementedError(
                "Resnet {} not supported".format(resnet_version))
        self.criterion2d = criterion2d
        mano_base_neurons = [img_feature_size] + mano_neurons
        self.mano_fhb_hand = mano_fhb_hand
        self.base_net = base_net
        # Predict translation and scaling for hand
        self.scaletrans_branch = AbsoluteBranch(
            base_neurons=[img_feature_size, int(img_feature_size / 2)],
            out_dim=3)
        # Initialize object branch
        self.obj_branch = PVNetDecoder(ver_dim=18,
                                       seg_dim=2,
                                       uncertainty_pnp=uncertainty_pnp,
                                       resnet_inplanes=base_net.inplanes,
                                       domain_norm=self.domain_norm)
        self.obj_scale_factor = obj_scale_factor
        self.obj_trans_factor = obj_trans_factor
        self.obj_vote_crit = nn.functional.smooth_l1_loss
        self.obj_seg_crit = nn.functional.cross_entropy
        # Initialize mano branch
        self.mano_branch = ManoBranch(
            ncomps=mano_comps,
            base_neurons=mano_base_neurons,
            dropout=fc_dropout,
            mano_pose_coeff=mano_pose_coeff,
            mano_root=mano_root,
            center_idx=mano_center_idx,
            use_shape=mano_use_shape,
            use_pca=mano_use_pca,
        )
        self.mano_center_idx = mano_center_idx
        if self.mano_fhb_hand:
            load_fhb_path = f"assets/mano/fhb_skel_centeridx{mano_center_idx}.pkl"
            with open(load_fhb_path, "rb") as p_f:
                exp_data = pickle.load(p_f)
            self.register_buffer("fhb_shape", torch.Tensor(exp_data["shape"]))
            self.adaptor = ManoAdaptor(self.mano_branch.mano_layer_right,
                                       load_fhb_path)
            rec_freeze(self.adaptor)
        else:
            self.adaptor = None
        if (mano_lambda_verts2d or mano_lambda_verts3d or mano_lambda_joints3d
                or mano_lambda_joints2d or mano_lambda_recov_joints3d
                or mano_lambda_recov_verts3d):
            self.mano_lambdas = True
        else:
            self.mano_lambdas = False
        if obj_lambda_verts2d or obj_lambda_verts3d or obj_lambda_recov_verts3d:
            self.obj_lambdas = True
        else:
            self.obj_lambdas = False
        self.mano_loss = ManoLoss(
            lambda_verts3d=mano_lambda_verts3d,
            lambda_joints3d=mano_lambda_joints3d,
            lambda_shape=mano_lambda_shape,
            lambda_pose_reg=mano_lambda_pose_reg,
        )
        # Store loss weights
        self.mano_lambda_joints2d = mano_lambda_joints2d
        self.mano_lambda_recov_joints3d = mano_lambda_recov_joints3d
        self.mano_lambda_recov_verts3d = mano_lambda_recov_verts3d
        self.mano_lambda_verts2d = mano_lambda_verts2d
        self.obj_lambda_verts2d = obj_lambda_verts2d
        self.obj_lambda_verts3d = obj_lambda_verts3d
        self.obj_lambda_recov_verts3d = obj_lambda_recov_verts3d
        self.obj_lambda_seg = 0.01
        self.obj_lambda_vote = 0.1

    def recover_mano(
        self,
        sample,
        encoder_output=None,
        pose=None,
        shape=None,
        no_loss=False,
        total_loss=None,
        scale=None,
        trans=None,
    ):
        # Get hand projection, centered
        mano_results = self.mano_branch(encoder_output,
                                        sides=sample[BaseQueries.SIDE],
                                        pose=pose,
                                        shape=shape)
        if self.adaptor:
            adapt_joints, _ = self.adaptor(mano_results["verts3d"])
            adapt_joints = adapt_joints.transpose(1, 2)
            mano_results["joints3d"] = adapt_joints - adapt_joints[
                :, self.mano_center_idx].unsqueeze(1)
            mano_results["verts3d"] = mano_results["verts3d"] - adapt_joints[
                :, self.mano_center_idx].unsqueeze(1)
        # Keep a dict to return even when losses are skipped
        mano_losses = {}
        if not no_loss:
            mano_total_loss, mano_losses = self.mano_loss.compute_loss(
                mano_results, sample)
            if total_loss is None:
                total_loss = mano_total_loss
            else:
                total_loss += mano_total_loss
            mano_losses["mano_total_loss"] = mano_total_loss.clone()

        # Recover hand position in camera coordinates
        if (self.mano_lambda_joints2d or self.mano_lambda_verts2d
                or self.mano_lambda_recov_joints3d
                or self.mano_lambda_recov_verts3d):
            if scale is None and trans is None:
                scaletrans = self.scaletrans_branch(encoder_output)
                if trans is None:
                    trans = scaletrans[:, 1:]
                if scale is None:
                    scale = scaletrans[:, :1]
            final_trans = trans.unsqueeze(1) * self.obj_trans_factor
            final_scale = scale.view(scale.shape[0], 1,
                                     1) * self.obj_scale_factor
            height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
            camintr = sample[TransQueries.CAMINTR].cuda()
            recov_joints3d, center3d = project.recover_3d_proj(
                mano_results["joints3d"],
                camintr,
                final_scale,
                final_trans,
                input_res=(width, height))
            recov_hand_verts3d = mano_results["verts3d"] + center3d
            proj_joints2d = camproject.batch_proj2d(recov_joints3d, camintr)
            proj_verts2d = camproject.batch_proj2d(recov_hand_verts3d, camintr)

            mano_results["joints2d"] = proj_joints2d
            mano_results["recov_joints3d"] = recov_joints3d
            mano_results["recov_handverts3d"] = recov_hand_verts3d
            mano_results["hand_center3d"] = center3d
            mano_results["verts2d"] = proj_verts2d
            mano_results["hand_pretrans"] = trans
            mano_results["hand_prescale"] = scale
            mano_results["hand_trans"] = final_trans
            mano_results["hand_scale"] = final_scale

            if not no_loss:
                # Compute hand losses in pixel space and camera coordinates
                if self.mano_lambda_joints2d is not None and TransQueries.JOINTS2D in sample:
                    gt_joints2d = sample[TransQueries.JOINTS2D].cuda().float()
                    if self.criterion2d == "l2":
                        # Normalize predictions in pixel space so that results
                        # are roughly centered and have magnitude ~1
                        norm_joints2d_pred = normalize_pixel_out(proj_joints2d)
                        norm_joints2d_gt = normalize_pixel_out(gt_joints2d)
                        joints2d_loss = torch_f.mse_loss(norm_joints2d_pred,
                                                         norm_joints2d_gt)
                    elif self.criterion2d == "l1":
                        joints2d_loss = torch_f.l1_loss(proj_joints2d,
                                                        gt_joints2d)
                    elif self.criterion2d == "smoothl1":
                        joints2d_loss = torch_f.smooth_l1_loss(proj_joints2d,
                                                               gt_joints2d)
                    total_loss += self.mano_lambda_joints2d * joints2d_loss
                    mano_losses["joints2d"] = joints2d_loss
                if self.mano_lambda_verts2d is not None and TransQueries.HANDVERTS2D in sample:
                    gt_verts2d = sample[TransQueries.HANDVERTS2D].cuda().float()
                    verts2d_loss = torch_f.mse_loss(
                        normalize_pixel_out(proj_verts2d, self.inp_res),
                        normalize_pixel_out(gt_verts2d, self.inp_res),
                    )
                    total_loss += self.mano_lambda_verts2d * verts2d_loss
                    mano_losses["verts2d"] = verts2d_loss
                if self.mano_lambda_recov_joints3d is not None and BaseQueries.JOINTS3D in sample:
                    joints3d_gt = sample[BaseQueries.JOINTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_joints3d, joints3d_gt)
                    total_loss += self.mano_lambda_recov_joints3d * recov_loss
                    mano_losses["recov_joint3d"] = recov_loss
                if self.mano_lambda_recov_verts3d is not None and BaseQueries.HANDVERTS3D in sample:
                    hand_verts3d_gt = sample[BaseQueries.HANDVERTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_hand_verts3d,
                                                  hand_verts3d_gt)
                    total_loss += self.mano_lambda_recov_verts3d * recov_loss
        return mano_results, total_loss, mano_losses

    def recover_object(self,
                       sample,
                       image,
                       encoder_output,
                       encoder_features,
                       no_loss=False,
                       total_loss=None,
                       scale=None,
                       trans=None,
                       rotaxisang=None):
        """
        Recover object vertex positions in camera coordinates: the decoder
        predicts a per-pixel keypoint vector field and a segmentation mask,
        and, when compute_pnp is enabled, the 6DoF object pose is recovered
        with (optionally uncertainty-aware) RANSAC/PnP given the known object
        model
        """
        obj_results = self.obj_branch(image,
                                      encoder_output,
                                      encoder_features,
                                      compute_pnp=self.compute_pnp)
        # Compute losses
        obj_losses = {}
        if not no_loss:
            if BaseQueries.OBJFPSVECFIELD in sample and BaseQueries.OBJMASK in sample:
                weight = sample[BaseQueries.OBJMASK][:, None].float().cuda()
                vec_field = sample[BaseQueries.OBJFPSVECFIELD].cuda()
                vote_loss = self.obj_vote_crit(obj_results['vertex'] * weight,
                                               vec_field * weight,
                                               reduction='sum')
                vote_loss = vote_loss / weight.sum() / vec_field.size(1)
                obj_losses.update({'obj_vote_loss': vote_loss})
                if total_loss is None:
                    total_loss = self.obj_lambda_vote * vote_loss
                else:
                    total_loss += self.obj_lambda_vote * vote_loss
            if BaseQueries.OBJMASK in sample:
                mask = sample[BaseQueries.OBJMASK].long().cuda()
                seg_loss = self.obj_seg_crit(obj_results['seg'], mask)
                obj_losses.update({'obj_seg_loss': seg_loss})
                if total_loss is None:
                    total_loss = self.obj_lambda_seg * seg_loss
                else:
                    total_loss += self.obj_lambda_seg * seg_loss
        # Compute obj pose via RANSAC and PnP
        if self.compute_pnp:
            if (BaseQueries.OBJFPS3D in sample
                    and BaseQueries.OBJCORNERS3D in sample
                    and BaseQueries.CAMINTR in sample
                    and BaseQueries.OBJCANVERTS in sample):
                kpt_3d = sample[BaseQueries.OBJFPS3D].cpu().numpy()
                cam_intr = sample[BaseQueries.CAMINTR]
                verts = sample[BaseQueries.OBJCANVERTS]
                pred_kpt_2d = obj_results['kpt_2d'].cpu().numpy()
                var = None
                if self.uncertainty_pnp:
                    var = obj_results['var'].detach().cpu().numpy()
                poses = batched_pnp(kpt_3d,
                                    pred_kpt_2d,
                                    cam_intr.cpu().numpy(),
                                    var=var)
                poses = torch.Tensor(poses).cuda()
                obj_results['obj_pose'] = poses
                # Transform canonical vertices with the recovered pose, then
                # project them with the camera intrinsics
                verts_3d_hom = torch.cat(
                    [verts, torch.ones(verts.shape[:-1] + (1, ))], axis=2)
                pred_verts3d = poses.bmm(verts_3d_hom.transpose(1, 2).cuda())
                obj_results['recov_objverts3d'] = pred_verts3d.transpose(1, 2)
                pred_verts2d = cam_intr.cuda().bmm(pred_verts3d)
                pred_verts2d = pred_verts2d[:, :2] / pred_verts2d[:, 2:]
                obj_results['obj_verts2d'] = pred_verts2d.transpose(1, 2)
        return obj_results, total_loss, obj_losses

    def forward(self, sample, no_loss=False, step=0, preparams=None):
        total_loss = torch.Tensor([0]).cuda()
        results = {}
        losses = {}
        # Get input
        image = sample[TransQueries.IMAGE].cuda()
        # Feed input into shared encoder
        encoder_output, encoder_features = self.base_net(image)
        has_mano_super = one_query_in(
            sample.keys(),
            [
                TransQueries.JOINTS3D,
                TransQueries.JOINTS2D,
                TransQueries.HANDVERTS2D,
                TransQueries.HANDVERTS3D,
            ],
        )
        # Note: the hand branch is always run here (the condition below
        # short-circuits to True)
        if True or (has_mano_super and self.mano_lambdas):
            if preparams is not None:
                hand_scale = preparams["hand_prescale"]
                hand_pose = preparams["pose"]
                hand_shape = preparams["shape"]
                hand_trans = preparams["hand_pretrans"]
            else:
                hand_scale = None
                hand_pose = None
                hand_shape = None
                hand_trans = None
            # Hand branch
            mano_results, total_loss, mano_losses = self.recover_mano(
                sample,
                encoder_output=encoder_output,
                no_loss=no_loss,
                total_loss=total_loss,
                trans=hand_trans,
                scale=hand_scale,
                pose=hand_pose,
                shape=hand_shape,
            )
            losses.update(mano_losses)
            results.update(mano_results)
        has_obj_super = one_query_in(
            sample.keys(), [TransQueries.OBJVERTS2D, TransQueries.OBJVERTS3D])
        if has_obj_super and self.obj_lambdas:
            if preparams is not None:
                obj_scale = preparams["obj_prescale"]
                obj_rot = preparams["obj_prerot"]
                obj_trans = preparams["obj_pretrans"]
            else:
                obj_scale = None
                obj_rot = None
                obj_trans = None
            # Object branch
            obj_results, total_loss, obj_losses = self.recover_object(
                sample,
                image,
                encoder_output,
                encoder_features,
                no_loss=no_loss,
                total_loss=total_loss,
                scale=obj_scale,
                trans=obj_trans,
                rotaxisang=obj_rot)
            losses.update(obj_losses)
            results.update(obj_results)
        if total_loss is not None:
            losses["total_loss"] = total_loss
        else:
            losses["total_loss"] = None
        return total_loss, results, losses
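
# ---------------------------------------------------------------------------
# Usage sketch (assumption-heavy): the commented snippet below only illustrates
# how a training step over MeshRegNet could look. The dataloader, the loss
# weights, and the exact keys of `sample` (TransQueries / BaseQueries entries)
# are provided by the surrounding repository and are assumed here; treat this
# as a sketch, not the project's actual training code.
# ---------------------------------------------------------------------------
# model = MeshRegNet(mano_lambda_joints2d=1,
#                    mano_lambda_recov_joints3d=1,
#                    obj_lambda_recov_verts3d=1).cuda()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# for sample in dataloader:  # dict of TransQueries/BaseQueries tensors
#     total_loss, results, losses = model(sample)
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()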
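
# The class below is an alternative MeshRegNet variant: instead of the PVNet
# keypoint-voting decoder and PnP used above, it regresses object scale,
# translation and rotation directly through an extra AbsoluteBranch
# (scaletrans_branch_obj) that feeds an ObjBranch.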


class MeshRegNet(nn.Module):
    def __init__(
        self,
        fc_dropout=0,
        resnet_version=18,
        criterion2d="l2",
        mano_neurons=[512, 512],
        mano_comps=15,
        mano_use_shape=False,
        mano_lambda_pose_reg=0,
        mano_use_pca=True,
        mano_center_idx=9,
        mano_root="assets/mano",
        mano_lambda_joints3d=None,
        mano_lambda_recov_joints3d=None,
        mano_lambda_recov_verts3d=None,
        mano_lambda_joints2d=None,
        mano_lambda_verts3d=None,
        mano_lambda_verts2d=None,
        mano_lambda_shape=None,
        mano_pose_coeff: int = 1,
        mano_fhb_hand: bool = False,
        obj_lambda_verts3d=None,
        obj_lambda_verts2d=None,
        obj_lambda_recov_verts3d=None,
        obj_trans_factor=1,
        obj_scale_factor=1,
        inp_res=256,
    ):
        """
        Args:
            mano_fhb_hand: use the pre-computed mapping from MANO joints to the
                First Person Hand Action Benchmark hand skeleton
            mano_root (path): directory containing the MANO pickle files
            mano_neurons: number of neurons in each layer of the base MANO decoder
            mano_use_pca: predict PCA parameters directly instead of rotation angles
            mano_comps (int): number of principal components to use if mano_use_pca
            mano_lambda_pose_reg: weight to supervise hand pose in axis-angle space
            mano_lambda_verts3d: weight to supervise hand vertex distances
            mano_lambda_joints3d: weight to supervise hand joint distances
        """
        super().__init__()
        self.inp_res = inp_res
        if int(resnet_version) == 18:
            img_feature_size = 512
            base_net = resnet.resnet18(pretrained=True)
        elif int(resnet_version) == 50:
            img_feature_size = 2048
            base_net = resnet.resnet50(pretrained=True)
        else:
            raise NotImplementedError(
                "Resnet {} not supported".format(resnet_version))
        self.criterion2d = criterion2d
        mano_base_neurons = [img_feature_size] + mano_neurons
        self.mano_fhb_hand = mano_fhb_hand
        self.base_net = base_net
        # Predict translation and scaling for hand
        self.scaletrans_branch = AbsoluteBranch(
            base_neurons=[img_feature_size, int(img_feature_size / 2)],
            out_dim=3)
        # Predict translation, scaling and rotation for object
        self.scaletrans_branch_obj = AbsoluteBranch(
            base_neurons=[img_feature_size, int(img_feature_size / 2)],
            out_dim=6)
        # Initialize object branch
        self.obj_branch = ObjBranch(trans_factor=obj_trans_factor,
                                    scale_factor=obj_scale_factor)
        self.obj_scale_factor = obj_scale_factor
        self.obj_trans_factor = obj_trans_factor
        # Initialize mano branch
        self.mano_branch = ManoBranch(
            ncomps=mano_comps,
            base_neurons=mano_base_neurons,
            dropout=fc_dropout,
            mano_pose_coeff=mano_pose_coeff,
            mano_root=mano_root,
            center_idx=mano_center_idx,
            use_shape=mano_use_shape,
            use_pca=mano_use_pca,
        )
        self.mano_center_idx = mano_center_idx
        if self.mano_fhb_hand:
            load_fhb_path = f"assets/mano/fhb_skel_centeridx{mano_center_idx}.pkl"
            with open(load_fhb_path, "rb") as p_f:
                exp_data = pickle.load(p_f)
            self.register_buffer("fhb_shape", torch.Tensor(exp_data["shape"]))
            self.adaptor = ManoAdaptor(self.mano_branch.mano_layer_right,
                                       load_fhb_path)
            rec_freeze(self.adaptor)
        else:
            self.adaptor = None
        if (mano_lambda_verts2d or mano_lambda_verts3d or mano_lambda_joints3d
                or mano_lambda_joints2d or mano_lambda_recov_joints3d
                or mano_lambda_recov_verts3d):
            self.mano_lambdas = True
        else:
            self.mano_lambdas = False
        if obj_lambda_verts2d or obj_lambda_verts3d or obj_lambda_recov_verts3d:
            self.obj_lambdas = True
        else:
            self.obj_lambdas = False
        self.mano_loss = ManoLoss(
            lambda_verts3d=mano_lambda_verts3d,
            lambda_joints3d=mano_lambda_joints3d,
            lambda_shape=mano_lambda_shape,
            lambda_pose_reg=mano_lambda_pose_reg,
        )
        # Store loss weights
        self.mano_lambda_joints2d = mano_lambda_joints2d
        self.mano_lambda_recov_joints3d = mano_lambda_recov_joints3d
        self.mano_lambda_recov_verts3d = mano_lambda_recov_verts3d
        self.mano_lambda_verts2d = mano_lambda_verts2d
        self.obj_lambda_verts2d = obj_lambda_verts2d
        self.obj_lambda_verts3d = obj_lambda_verts3d
        self.obj_lambda_recov_verts3d = obj_lambda_recov_verts3d

    def recover_mano(
        self,
        sample,
        features=None,
        pose=None,
        shape=None,
        no_loss=False,
        total_loss=None,
        scale=None,
        trans=None,
    ):
        # Get hand projection, centered
        mano_results = self.mano_branch(features,
                                        sides=sample[BaseQueries.SIDE],
                                        pose=pose,
                                        shape=shape)
        if self.adaptor:
            adapt_joints, _ = self.adaptor(mano_results["verts3d"])
            adapt_joints = adapt_joints.transpose(1, 2)
            mano_results["joints3d"] = adapt_joints - adapt_joints[
                :, self.mano_center_idx].unsqueeze(1)
            mano_results["verts3d"] = mano_results["verts3d"] - adapt_joints[
                :, self.mano_center_idx].unsqueeze(1)
            # Save translation
            mano_results['mano_adapt_trans'] = adapt_joints[
                :, self.mano_center_idx]
        # Keep a dict to return even when losses are skipped
        mano_losses = {}
        if not no_loss:
            mano_total_loss, mano_losses = self.mano_loss.compute_loss(
                mano_results, sample)
            if total_loss is None:
                total_loss = mano_total_loss
            else:
                total_loss += mano_total_loss
            mano_losses["mano_total_loss"] = mano_total_loss.clone()

        # Recover hand position in camera coordinates
        if (self.mano_lambda_joints2d or self.mano_lambda_verts2d
                or self.mano_lambda_recov_joints3d
                or self.mano_lambda_recov_verts3d):
            if scale is None and trans is None:
                scaletrans = self.scaletrans_branch(features)
                if trans is None:
                    trans = scaletrans[:, 1:]
                if scale is None:
                    scale = scaletrans[:, :1]
            final_trans = trans.unsqueeze(1) * self.obj_trans_factor
            final_scale = scale.view(scale.shape[0], 1,
                                     1) * self.obj_scale_factor
            height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
            camintr = sample[TransQueries.CAMINTR].cuda()
            recov_joints3d, center3d = project.recover_3d_proj(
                mano_results["joints3d"],
                camintr,
                final_scale,
                final_trans,
                input_res=(width, height))
            recov_hand_verts3d = mano_results["verts3d"] + center3d
            proj_joints2d = camproject.batch_proj2d(recov_joints3d, camintr)
            proj_verts2d = camproject.batch_proj2d(recov_hand_verts3d, camintr)

            mano_results["joints2d"] = proj_joints2d
            mano_results["recov_joints3d"] = recov_joints3d
            mano_results["recov_handverts3d"] = recov_hand_verts3d
            mano_results["mano_center_trans"] = center3d
            mano_results["verts2d"] = proj_verts2d
            mano_results["hand_pretrans"] = trans
            mano_results["hand_prescale"] = scale
            mano_results["hand_trans"] = final_trans
            mano_results["hand_scale"] = final_scale

            if not no_loss:
                # Compute hand losses in pixel space and camera coordinates
                if self.mano_lambda_joints2d is not None and TransQueries.JOINTS2D in sample:
                    gt_joints2d = sample[TransQueries.JOINTS2D].cuda().float()
                    if self.criterion2d == "l2":
                        # Normalize predictions in pixel space so that results
                        # are roughly centered and have magnitude ~1
                        norm_joints2d_pred = normalize_pixel_out(proj_joints2d)
                        norm_joints2d_gt = normalize_pixel_out(gt_joints2d)
                        joints2d_loss = torch_f.mse_loss(norm_joints2d_pred,
                                                         norm_joints2d_gt)
                    elif self.criterion2d == "l1":
                        joints2d_loss = torch_f.l1_loss(proj_joints2d,
                                                        gt_joints2d)
                    elif self.criterion2d == "smoothl1":
                        joints2d_loss = torch_f.smooth_l1_loss(proj_joints2d,
                                                               gt_joints2d)
                    total_loss += self.mano_lambda_joints2d * joints2d_loss
                    mano_losses["joints2d"] = joints2d_loss
                if self.mano_lambda_verts2d is not None and TransQueries.HANDVERTS2D in sample:
                    gt_verts2d = sample[TransQueries.HANDVERTS2D].cuda().float()
                    verts2d_loss = torch_f.mse_loss(
                        normalize_pixel_out(proj_verts2d, self.inp_res),
                        normalize_pixel_out(gt_verts2d, self.inp_res),
                    )
                    total_loss += self.mano_lambda_verts2d * verts2d_loss
                    mano_losses["verts2d"] = verts2d_loss
                if self.mano_lambda_recov_joints3d is not None and BaseQueries.JOINTS3D in sample:
                    joints3d_gt = sample[BaseQueries.JOINTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_joints3d, joints3d_gt)
                    total_loss += self.mano_lambda_recov_joints3d * recov_loss
                    mano_losses["recov_joint3d"] = recov_loss
                if self.mano_lambda_recov_verts3d is not None and BaseQueries.HANDVERTS3D in sample:
                    hand_verts3d_gt = sample[BaseQueries.HANDVERTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_hand_verts3d,
                                                  hand_verts3d_gt)
                    total_loss += self.mano_lambda_recov_verts3d * recov_loss
        return mano_results, total_loss, mano_losses

    def recover_object(self,
                       sample,
                       features=None,
                       total_loss=None,
                       scale=None,
                       trans=None,
                       rotaxisang=None):
        """
        Compute object vertex and corner positions in camera coordinates by
        predicting object translation, scaling and rotation, and recovering
        3D positions given the known object model
        """
        if features is None:
            scaletrans_obj = None
        else:
            scaletrans_obj = self.scaletrans_branch_obj(features)
        obj_results = self.obj_branch(sample,
                                      scaletrans_obj,
                                      scale=scale,
                                      trans=trans,
                                      rotaxisang=rotaxisang)
        obj_losses = {}
        if self.criterion2d == "l2" and TransQueries.OBJVERTS2D in sample:
            obj2d_loss = torch_f.mse_loss(
                normalize_pixel_out(obj_results["obj_verts2d"], self.inp_res),
                normalize_pixel_out(sample[TransQueries.OBJVERTS2D].cuda(),
                                    self.inp_res),
            )
            obj_losses["objverts2d"] = obj2d_loss
            total_loss += self.obj_lambda_verts2d * obj2d_loss
        elif self.criterion2d == "l1" and TransQueries.OBJVERTS2D in sample:
            obj2d_loss = torch_f.l1_loss(
                normalize_pixel_out(obj_results["obj_verts2d"], self.inp_res),
                normalize_pixel_out(sample[TransQueries.OBJVERTS2D].cuda(),
                                    self.inp_res),
            )
            obj_losses["objverts2d"] = obj2d_loss
            total_loss += self.obj_lambda_verts2d * obj2d_loss
        elif self.criterion2d == "smoothl1" and TransQueries.OBJVERTS2D in sample:
            obj2d_loss = torch_f.smooth_l1_loss(
                normalize_pixel_out(obj_results["obj_verts2d"], self.inp_res),
                normalize_pixel_out(sample[TransQueries.OBJVERTS2D].cuda(),
                                    self.inp_res),
            )
            obj_losses["objverts2d"] = obj2d_loss
            total_loss += self.obj_lambda_verts2d * obj2d_loss
        if TransQueries.OBJCANROTVERTS in sample:
            obj3d_loss = torch_f.smooth_l1_loss(
                obj_results["obj_verts3d"],
                sample[TransQueries.OBJCANROTVERTS].float().cuda())
            obj_losses["objverts3d"] = obj3d_loss
            total_loss += self.obj_lambda_verts3d * obj3d_loss
        if self.obj_lambda_recov_verts3d is not None and BaseQueries.OBJVERTS3D in sample:
            objverts3d_gt = sample[BaseQueries.OBJVERTS3D].cuda()
            recov_verts3d = obj_results["recov_objverts3d"]
            obj_recov_loss = torch_f.mse_loss(recov_verts3d, objverts3d_gt)
            if total_loss is None:
                total_loss = self.obj_lambda_recov_verts3d * obj_recov_loss
            else:
                total_loss += self.obj_lambda_recov_verts3d * obj_recov_loss
            obj_losses["recov_objverts3d"] = obj_recov_loss
        return obj_results, total_loss, obj_losses

    def forward(self, sample, no_loss=False, step=0, preparams=None):
        total_loss = torch.Tensor([0]).cuda()
        results = {}
        losses = {}
        image = sample[TransQueries.IMAGE].cuda()
        features, _ = self.base_net(image)
        has_mano_super = one_query_in(
            sample.keys(),
            [
                TransQueries.JOINTS3D,
                TransQueries.JOINTS2D,
                TransQueries.HANDVERTS2D,
                TransQueries.HANDVERTS3D,
            ],
        )
        if has_mano_super and self.mano_lambdas:
            if preparams is not None:
                hand_scale = preparams["hand_prescale"]
                hand_pose = preparams["pose"]
                hand_shape = preparams["shape"]
                hand_trans = preparams["hand_pretrans"]
            else:
                hand_scale = None
                hand_pose = None
                hand_shape = None
                hand_trans = None
            mano_results, total_loss, mano_losses = self.recover_mano(
                sample,
                features=features,
                no_loss=no_loss,
                total_loss=total_loss,
                trans=hand_trans,
                scale=hand_scale,
                pose=hand_pose,
                shape=hand_shape,
            )
            losses.update(mano_losses)
            results.update(mano_results)
        has_obj_super = one_query_in(
            sample.keys(), [TransQueries.OBJVERTS2D, TransQueries.OBJVERTS3D])
        if has_obj_super and self.obj_lambdas:
            if preparams is not None:
                obj_scale = preparams["obj_prescale"]
                obj_rot = preparams["obj_prerot"]
                obj_trans = preparams["obj_pretrans"]
            else:
                obj_scale = None
                obj_rot = None
                obj_trans = None
            obj_results, total_loss, obj_losses = self.recover_object(
                sample,
                features,
                total_loss=total_loss,
                scale=obj_scale,
                trans=obj_trans,
                rotaxisang=obj_rot)
            losses.update(obj_losses)
            results.update(obj_results)
        if total_loss is not None:
            losses["total_loss"] = total_loss
        else:
            losses["total_loss"] = None
        return total_loss, results, losses