def params_rnet(self, dorig):
    """Compute the hand/object distance features used by the refine net.

    Two right hands are processed the same way: the ``_f`` hand
    (``dorig['verts_rhand_f']``; presumably the perturbed hand the refine net
    starts from -- confirm) and the reference hand (``dorig['verts_rhand']``).
    For each, per-vertex normals are computed on the hand mesh (faces
    ``self.rh_f``) and ``point2point_signed`` is evaluated against
    ``dorig['verts_object']``.

    Args:
        dorig: batch dict providing ``'verts_rhand_f'``, ``'verts_rhand'`` and
            ``'verts_object'``.

    Returns:
        dict with
            ``'h2o_dist'``: absolute hand-to-object distances of the ``_f`` hand,
            ``'h2o_gt'``: absolute hand-to-object distances of the reference hand,
            ``'o2h_gt'``: signed object-to-hand distances of the reference hand
            (first output of ``point2point_signed``).
    """
    def normals_of(hand_verts):
        # Per-vertex normals, reshaped assuming 778 vertices per hand.
        mesh = Meshes(verts=hand_verts, faces=self.rh_f).to(self.device)
        return mesh.verts_normals_packed().view(-1, 778, 3)

    hand_f = dorig['verts_rhand_f']
    hand_ref = dorig['verts_rhand']
    obj_verts = dorig['verts_object']

    normals_f = normals_of(hand_f)
    normals_ref = normals_of(hand_ref)

    _, h2o_f, _ = point2point_signed(hand_f, obj_verts, normals_f)
    o2h_ref, h2o_ref, _ = point2point_signed(hand_ref, obj_verts, normals_ref)

    return {
        'h2o_dist': h2o_f.abs(),
        'h2o_gt': h2o_ref.abs(),
        'o2h_gt': o2h_ref,
    }
def loss_rnet(self, dorig, drec, ds_name='train'):
    """Weighted reconstruction loss for the refine net.

    The parameters in ``drec`` are decoded with ``self.rhm_train`` and the
    resulting hand is compared to the reference hand ``dorig['verts_rhand']``.
    Each term is scaled by ``(1 - self.cfg.kl_coef)``:

    * ``loss_dist_h_r`` (x35): absolute difference between |h2o| of the
      predicted hand and |``dorig['h2o_gt']``|, weighted per hand vertex by
      ``self.v_weights2``;
    * ``loss_mesh_rec_r`` (x20): absolute per-vertex position error, weighted
      per vertex by ``self.v_weights2``;
    * ``loss_edge_r`` (x10): ``self.LossL1`` between ``self.edges_for(..., self.vpe)``
      of the predicted and reference hands (presumably edge vectors -- confirm).

    Args:
        dorig: batch dict with ``'verts_rhand'``, ``'verts_object'`` and ``'h2o_gt'``.
        drec: keyword arguments forwarded to ``self.rhm_train``.
        ds_name: not used by this method.

    Returns:
        ``(loss_total, loss_dict)``; ``loss_dict`` holds the three terms and
        ``'loss_total'`` (their sum).
    """
    weight = 1. - self.cfg.kl_coef
    pred = self.rhm_train(**drec).vertices
    ref = dorig['verts_rhand']

    pred_normals = Meshes(verts=pred, faces=self.rh_f).to(
        self.device).verts_normals_packed().view(-1, 778, 3)
    _, h2o_pred, _ = point2point_signed(pred, dorig['verts_object'], pred_normals)
    h2o_ref = dorig['h2o_gt']

    # Hand-to-object distance term, weighted per hand vertex.
    dist_err = torch.abs(h2o_pred.abs() - h2o_ref.abs())
    loss_dist = 35 * weight * torch.mean(
        torch.einsum('ij,j->ij', dist_err, self.v_weights2))

    # Vertex position term, weighted per hand vertex.
    vert_err = torch.abs(ref - pred)
    loss_verts = 20 * weight * torch.mean(
        torch.einsum('ijk,j->ijk', vert_err, self.v_weights2))

    # Edge term: L1 between edges_for() outputs of prediction and reference.
    loss_edges = 10 * weight * self.LossL1(
        self.edges_for(pred, self.vpe), self.edges_for(ref, self.vpe))

    terms = {
        'loss_edge_r': loss_edges,
        'loss_mesh_rec_r': loss_verts,
        'loss_dist_h_r': loss_dist,
    }
    total = torch.stack(list(terms.values())).sum()
    terms['loss_total'] = total
    return total, terms
def forward(self, h2o_dist, fpose_rhand_rotmat_f, trans_rhand_f,
            global_orient_rhand_rotmat_f, verts_object, **kwargs):
    """Iteratively refine hand pose and translation against an object.

    The pose state is the first two columns (``[..., :2]``) of every rotation
    matrix, flattened per sample, with the global orientation placed before the
    finger pose. Each of ``self.n_iters`` iterations:

    1. (from the second iteration on) decodes the current state with
       ``parms_decode``, runs ``self.rhm_train`` and recomputes the
       hand-to-object distances against ``verts_object``;
    2. concatenates ``bn1(distances)``, pose and translation into the input
       features;
    3. passes them through three residual blocks (each re-fed the input
       features) with dropout after each;
    4. adds the ``out_p`` / ``out_t`` outputs to pose and translation.

    Args:
        h2o_dist: hand-to-object distances used at the first iteration.
        fpose_rhand_rotmat_f: finger rotation matrices.
        trans_rhand_f: initial hand translation.
        global_orient_rhand_rotmat_f: global-orientation rotation matrix.
        verts_object: object points used to recompute distances.
        **kwargs: accepted and ignored.

    Returns:
        ``parms_decode`` applied to the refined pose and translation.
    """
    batch = h2o_dist.shape[0]
    finger_feat = fpose_rhand_rotmat_f[..., :2].reshape(batch, -1)
    wrist_feat = global_orient_rhand_rotmat_f[..., :2].reshape(batch, -1)
    pose = torch.cat([wrist_feat, finger_feat], dim=1)
    trans = trans_rhand_f
    dist = h2o_dist

    for step in range(self.n_iters):
        if step > 0:
            # Re-measure distances using the hand from the previous step.
            verts = self.rhm_train(**parms_decode(pose, trans)).vertices
            _, dist, _ = point2point_signed(verts, verts_object)

        features = torch.cat([self.bn1(dist), pose, trans], dim=1)
        hidden = self.dout(self.rb1(features))
        hidden = self.dout(self.rb2(torch.cat([hidden, features], dim=1)))
        hidden = self.dout(self.rb3(torch.cat([hidden, features], dim=1)))

        delta_pose = self.out_p(hidden)
        delta_trans = self.out_t(hidden)
        trans = trans + delta_trans
        pose = pose + delta_pose

    return parms_decode(pose, trans)
def vis_results(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):
    """Sample grasps with the coarse net, refine them and display one per column.

    For each entry of ``dorig['bps_object']`` a grasp is sampled with
    ``coarse_net.sample_poses``. It is refined with ``refine_net``, and the
    refined hand (gray) is shown next to its object in a ``MeshViewers``
    window. When ``dorig['mesh_object']`` is unavailable, the object is drawn
    as small green spheres at ``dorig['verts_object']``. When ``'rotmat'`` is
    in ``dorig``, object and hand are rotated by ``dorig['rotmat'][cId].T``
    before display and saving.

    Args:
        dorig: batch dict with ``'bps_object'``, ``'verts_object'`` and
            optionally ``'mesh_object'`` and ``'rotmat'``.
        coarse_net: network providing ``sample_poses``.
        refine_net: refinement network called with the coarse outputs.
        rh_model: hand model returning ``.vertices`` and exposing ``.faces``.
        save: if True, write the refined hand and object mesh as PLY files
            under ``save_dir/<cId>/``.
        save_dir: output directory; required when ``save`` is True.

    Raises:
        ValueError: if ``save`` is True and ``save_dir`` is None.
    """
    from copy import deepcopy

    if save and save_dir is None:
        raise ValueError('save_dir must be given when save=True')

    with torch.no_grad():
        imw, imh = 1920, 780
        cols = len(dorig['bps_object'])
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        mvs = MeshViewers(window_width=imw, window_height=imh, shape=[1, cols], keepalive=True)

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        verts_object = dorig['verts_object'].to(device)
        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, verts_object)

        # Inputs for refine_net, built from the coarse sample.
        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = verts_object
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices

        for cId in range(cols):
            try:
                src_mesh = dorig['mesh_object'][cId]
            except (KeyError, IndexError, TypeError):
                # No usable object mesh: draw the object points as small spheres.
                obj_mesh = points_to_spheres(to_cpu(dorig['verts_object'][cId]), radius=0.002,
                                             vc=name_to_rgb['green'])
            else:
                # Copy only this mesh so rotating it below leaves the caller's mesh untouched.
                obj_mesh = deepcopy(src_mesh)

            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]), f=rh_model.faces, vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            hand_mesh_gen_rnet.reset_face_normals()

            mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet, obj_mesh], blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                # obj_mesh is a single mesh object, so call write_ply on it directly (not obj_mesh[0]).
                obj_mesh.write_ply(filename=save_path + '/obj_mesh_%d.ply' % cId)
def get_meshes(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):
    """Sample grasps, refine them and return ``[object_mesh, hand_mesh]`` pairs.

    For each entry of ``dorig['bps_object']`` a grasp is sampled with
    ``coarse_net.sample_poses`` and refined with ``refine_net``. The refined
    hand is wrapped in a ``Mesh`` with color ``[245, 191, 177]``. The object
    comes from ``dorig['mesh_object'][cId]``, or falls back to yellow spheres
    at ``dorig['verts_object'][cId]``. When ``'rotmat'`` is in ``dorig``, both
    meshes are rotated by ``dorig['rotmat'][cId].T``.

    Args:
        dorig: batch dict with ``'bps_object'``, ``'verts_object'`` and
            optionally ``'mesh_object'`` and ``'rotmat'``.
        coarse_net: network providing ``sample_poses``.
        refine_net: refinement network called with the coarse outputs.
        rh_model: hand model returning ``.vertices`` and exposing ``.faces``.
        save: if True, export both meshes as PLY under ``save_dir/<cId>/``.
        save_dir: output directory; required when ``save`` is True.

    Returns:
        list of ``[obj_mesh, hand_mesh]``, one per sample.

    Raises:
        ValueError: if ``save`` is True and ``save_dir`` is None.
    """
    if save and save_dir is None:
        raise ValueError('save_dir must be given when save=True')

    with torch.no_grad():
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        verts_object = dorig['verts_object'].to(device)
        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, verts_object)

        # Inputs for refine_net, built from the coarse sample.
        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = verts_object
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices

        gen_meshes = []
        for cId in range(len(dorig['bps_object'])):
            try:
                # NOTE(review): if rotate_vertices below works in place, this
                # rotates the caller's mesh in dorig as well -- confirm.
                obj_mesh = dorig['mesh_object'][cId]
            except (KeyError, IndexError, TypeError):
                # No usable object mesh: draw the object points as small spheres.
                obj_mesh = points2sphere(points=to_cpu(dorig['verts_object'][cId]), radius=0.002,
                                         vc=name_to_rgb['yellow'])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]), faces=rh_model.faces,
                                      vc=[245, 191, 177])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.export(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh.export(filename=save_path + '/obj_mesh_%d.ply' % cId)

    return gen_meshes
def loss_cnet(self, dorig, drec, ds_name='train'):
    """Reconstruction and KL loss for the coarse net.

    The parameters in ``drec`` are decoded with ``self.rhm_train`` and the
    resulting hand is compared to the reference hand ``dorig['verts_rhand']``.
    Reconstruction terms are scaled by ``(1 - self.cfg.kl_coef)``:

    * ``loss_dist_h`` (x35): |h2o| error, weighted per hand vertex by ``self.v_weights2``;
    * ``loss_dist_o`` (x30): signed object-to-hand error, weighted per object
      vertex by an adaptive copy of ``self.w_dist``;
    * ``loss_mesh_rec`` (x35): per-vertex position error, weighted by ``self.v_weights``;
    * ``loss_edge`` (x30): ``self.LossL1`` between ``self.edges_for(..., self.vpe)`` outputs.

    ``loss_kl`` is ``self.cfg.kl_coef`` times the batch mean of the summed
    KL(N(mean, std) || N(0, 1)), using ``drec['mean']`` and ``drec['std']``.

    Args:
        dorig: batch dict with ``'verts_rhand'`` and ``'verts_object'``.
        drec: keyword arguments for ``self.rhm_train``; must also contain
            ``'mean'`` and ``'std'``.
        ds_name: not used by this method.

    Returns:
        ``(loss_total, loss_dict)``; ``loss_dict`` holds every term and
        ``'loss_total'`` (their sum).
    """
    device = dorig['verts_rhand'].device
    dtype = dorig['verts_rhand'].dtype

    q_z = torch.distributions.normal.Normal(drec['mean'], drec['std'])

    verts_rhand = self.rhm_train(**drec).vertices
    verts_gt = dorig['verts_rhand']

    rh_mesh = Meshes(verts=verts_rhand, faces=self.rh_f).to(
        self.device).verts_normals_packed().view(-1, 778, 3)
    rh_mesh_gt = Meshes(verts=verts_gt, faces=self.rh_f).to(
        self.device).verts_normals_packed().view(-1, 778, 3)

    o2h_signed, h2o, _ = point2point_signed(verts_rhand, dorig['verts_object'], rh_mesh)
    o2h_signed_gt, h2o_gt, _ = point2point_signed(verts_gt, dorig['verts_object'], rh_mesh_gt)

    # Adaptive weights for contact and penetration vertices.
    near_contact = (o2h_signed_gt < 0.01) & (o2h_signed_gt > -0.005)
    penetrating = o2h_signed < 0.
    # NOTE(review): self.w_dist must match o2h_signed_gt's shape for the boolean
    # indexing below; if it is sized by cfg.batch_size, a short final batch
    # would fail here -- confirm.
    w = self.w_dist.clone()
    w[~near_contact] = .1  # less weight for far away vertices
    w[penetrating] = 1.5   # more weight for penetration

    ######### dist loss
    loss_dist_h = 35 * (1. - self.cfg.kl_coef) * torch.mean(
        torch.einsum('ij,j->ij', torch.abs(h2o.abs() - h2o_gt.abs()), self.v_weights2))
    loss_dist_o = 30 * (1. - self.cfg.kl_coef) * torch.mean(
        torch.einsum('ij,ij->ij', torch.abs(o2h_signed - o2h_signed_gt), w))

    ########## verts loss
    loss_mesh_rec_w = 35 * (1. - self.cfg.kl_coef) * torch.mean(
        torch.einsum('ijk,j->ijk', torch.abs(verts_gt - verts_rhand), self.v_weights))

    ########## edge loss
    loss_edge = 30 * (1. - self.cfg.kl_coef) * self.LossL1(
        self.edges_for(verts_rhand, self.vpe), self.edges_for(verts_gt, self.vpe))

    ########## KL loss
    # Standard-normal prior shaped like the posterior. The previous prior was
    # fixed to (cfg.batch_size, cfg.latentD) and broke on a smaller final batch.
    p_z = torch.distributions.normal.Normal(
        loc=torch.zeros_like(drec['mean'], dtype=dtype, device=device),
        scale=torch.ones_like(drec['mean'], dtype=dtype, device=device))
    loss_kl = self.cfg.kl_coef * torch.mean(
        torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1]))

    loss_dict = {
        'loss_kl': loss_kl,
        'loss_edge': loss_edge,
        'loss_mesh_rec': loss_mesh_rec_w,
        'loss_dist_h': loss_dist_h,
        'loss_dist_o': loss_dist_o,
    }
    loss_total = torch.stack(list(loss_dict.values())).sum()
    loss_dict['loss_total'] = loss_total
    return loss_total, loss_dict
def vis_results(ho, dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None,
                rh_model_pkl=None, vis=True):
    """Fit the hand model to ``ho``, refine it against the object, optionally display/save.

    The initial hand is not sampled (``coarse_net`` is accepted but unused). It
    is produced by ``util.opt_hand`` from:

    * the pose ``ho.hand_pose[3:]`` passed through ``rh_model_pkl``;
    * the rotation and translation of ``ho.hand_mTc`` (the rotation falls back
      to ``ho.hand_pose[:3]`` when ``rotmat2aa`` yields NaN);
    * target vertices ``ho.hand_verts``.

    The mean per-vertex fitting error is printed, then the fit is refined with
    ``refine_net``. Everything runs on CPU. Gradients are not disabled; the
    ``torch.no_grad()`` wrapper was commented out, presumably because
    ``util.opt_hand`` needs autograd -- confirm.

    Args:
        ho: object exposing ``hand_pose``, ``hand_mTc`` and ``hand_verts``.
        dorig: batch dict with ``'bps_object'``, ``'verts_object'`` and
            optionally ``'mesh_object'`` and ``'rotmat'``.
        coarse_net: unused.
        refine_net: refinement network called with the fitted hand.
        rh_model: hand model returning ``.vertices``/``.joints`` and exposing ``.faces``.
        save: if True (and ``vis``), write the refined hand and object as PLY
            files under ``save_dir/<cId>/``.
        save_dir: output directory; required when ``vis`` and ``save`` are True.
        rh_model_pkl: model used to convert ``ho.hand_pose[3:]``; required.
        vis: if True, open a ``MeshViewers`` window and draw refined hand
            (gray), fitted hand (pink) and object per column.

    Returns:
        ``(vertices, joints)`` of ``rh_model`` evaluated on the refined parameters.

    Raises:
        ValueError: if ``rh_model_pkl`` is None, or if ``vis`` and ``save`` are
            True while ``save_dir`` is None.
    """
    from copy import deepcopy

    if rh_model_pkl is None:
        raise ValueError('rh_model_pkl is required to convert ho.hand_pose')
    if vis and save and save_dir is None:
        raise ValueError('save_dir must be given when save=True')

    imw, imh = 1920, 780
    cols = len(dorig['bps_object'])
    device = torch.device('cpu')
    if vis:
        mvs = MeshViewers(window_width=imw, window_height=imh, shape=[1, cols], keepalive=True)

    # Initial hand pose converted through rh_model_pkl.
    hand_pose_in = torch.Tensor(ho.hand_pose[3:]).unsqueeze(0)
    hand_pose_in = rh_model_pkl(hand_pose=hand_pose_in).hand_pose

    mTc = torch.Tensor(ho.hand_mTc)
    approx_global_orient = rotmat2aa(mTc[:3, :3].unsqueeze(0))
    if torch.isnan(approx_global_orient).any():  # Using honnotate?
        approx_global_orient = torch.Tensor(ho.hand_pose[:3]).unsqueeze(0)
    approx_global_orient = approx_global_orient.squeeze(1).squeeze(1)
    approx_trans = mTc[:3, 3].unsqueeze(0)

    target_verts = torch.Tensor(ho.hand_verts).unsqueeze(0)
    pose, trans, rot = util.opt_hand(rh_model, target_verts, hand_pose_in, approx_trans, approx_global_orient)

    drec_cnet = {'transl': trans, 'global_orient': rot, 'hand_pose': pose}
    verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    verts_object = dorig['verts_object'].to(device)
    _, h2o, _ = point2point_signed(verts_rh_gen_cnet, verts_object)

    # Inputs for refine_net, built from the fitted hand.
    drec_cnet['trans_rhand_f'] = drec_cnet['transl']
    drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
    drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
    drec_cnet['verts_object'] = verts_object
    drec_cnet['h2o_dist'] = h2o.abs()

    fit_err = np.linalg.norm(verts_rh_gen_cnet.squeeze().detach().numpy() - ho.hand_verts, 2, 1).mean()
    print('Hand fitting err', fit_err)

    drec_rnet = refine_net(**drec_cnet)
    mano_out = rh_model(**drec_rnet)
    verts_rh_gen_rnet = mano_out.vertices
    joints_out = mano_out.joints

    if vis:
        for cId in range(cols):
            try:
                src_mesh = dorig['mesh_object'][cId]
            except (KeyError, IndexError, TypeError):
                # No usable object mesh: draw the object points as small spheres.
                obj_mesh = points_to_spheres(to_cpu(dorig['verts_object'][cId]), radius=0.002,
                                             vc=name_to_rgb['green'])
            else:
                # Copy only this mesh so rotating it below leaves the caller's mesh untouched.
                obj_mesh = deepcopy(src_mesh)

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]), f=rh_model.faces, vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]), f=rh_model.faces, vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet, hand_mesh_gen_cnet, obj_mesh], blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                # obj_mesh is a single mesh object, so call write_ply on it directly (not obj_mesh[0]).
                obj_mesh.write_ply(filename=save_path + '/obj_mesh_%d.ply' % cId)

    return verts_rh_gen_rnet, joints_out
def get_meshes(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):
    """Sample and refine grasps; return mesh pairs and optionally save joints/transforms.

    For each entry of ``dorig['bps_object']`` a grasp is sampled with
    ``coarse_net.sample_poses`` and refined with ``refine_net``. The refined
    ``rh_model`` output provides vertices, joints and transforms. When
    ``'rotmat'`` is in ``dorig``, the meshes, joints and the rotation block of
    each transform are rotated into that frame.

    Args:
        dorig: batch dict with ``'bps_object'``, ``'verts_object'`` and
            optionally ``'mesh_object'`` and ``'rotmat'``.
        coarse_net: network providing ``sample_poses``.
        refine_net: refinement network called with the coarse outputs.
        rh_model: hand model returning ``.vertices``, ``.joints``,
            ``.transforms`` and ``.betas``, and exposing ``.faces``.
        save: if True, save ``joints_<cId>.npy`` and ``trans_<cId>.npy`` in ``save_dir``.
        save_dir: output directory; required when ``save`` is True.

    Returns:
        list of ``[obj_mesh, hand_mesh]``, one per sample.

    Raises:
        ValueError: if ``save`` is True and ``save_dir`` is None.
    """
    if save and save_dir is None:
        raise ValueError('save_dir must be given when save=True')

    with torch.no_grad():
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        verts_object = dorig['verts_object'].to(device)
        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, verts_object)

        # Inputs for refine_net, built from the coarse sample.
        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = verts_object
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        output = rh_model(**drec_rnet)
        print("hand shape {} should be idtenty".format(output.betas))
        verts_rh_gen_rnet = output.vertices

        # TODO: reorder joints/transforms to match the visualization utilities
        # (joint_mapper), e.g. joints[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18,
        # 10, 11, 12, 19, 7, 8, 9, 20]] and transforms[:, [0, 13, 14, 15, 1, 2, 3, 4, 5,
        # 6, 10, 11, 12, 7, 8, 9]].
        joints_rh_gen_rnet = to_cpu(output.joints)
        transforms_rh_gen_rnet = to_cpu(output.transforms)

        gen_meshes = []
        for cId in range(len(dorig['bps_object'])):
            try:
                obj_mesh = dorig['mesh_object'][cId]
            except (KeyError, IndexError, TypeError):
                # No usable object mesh: draw the object points as small spheres.
                obj_mesh = points2sphere(points=to_cpu(dorig['verts_object'][cId]), radius=0.002,
                                         vc=[145, 191, 219])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]), faces=rh_model.faces,
                                      vc=[145, 191, 219])
            hand_joint_gen_rnet = joints_rh_gen_rnet[cId]
            hand_transform_gen_rnet = transforms_rh_gen_rnet[cId]

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)
                hand_joint_gen_rnet = hand_joint_gen_rnet @ rotmat.T
                # Rotate the 3x3 rotation block of every transform. The ellipsis
                # keeps this valid whatever the number of leading axes per sample.
                # NOTE(review): the translation column [..., :3, 3] is left
                # unrotated while joints are rotated -- confirm this is intended.
                hand_transform_gen_rnet[..., :3, :3] = np.matmul(
                    rotmat[None, ...], hand_transform_gen_rnet[..., :3, :3])

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])

            if save:
                makepath(save_dir)
                print("saving dir {}".format(save_dir))
                np.save(save_dir + '/joints_%d.npy' % cId, hand_joint_gen_rnet)
                np.save(save_dir + '/trans_%d.npy' % cId, hand_transform_gen_rnet)

    return gen_meshes