def get_vposer_mean_pose(vposer_dir="misc/vposer_dir"):
    pose_embedding = torch.zeros([1, 32])
    vposer, _ = load_vposer(vposer_dir, vp_model="snapshot")
    vposer.eval()
    mean_pose = (vposer.decode(pose_embedding, output_type="aa")
                 .contiguous().view(1, -1))
    return mean_pose
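# A minimal usage sketch (hedged: assumes a VPoser v1 snapshot lives under
# "misc/vposer_dir", as in the default above). Decoding the zero latent gives
# VPoser's mean pose as a [1, 63] axis-angle vector (21 body joints x 3 dims).
import torch
from human_body_prior.tools.model_loader import load_vposer

mean_pose = get_vposer_mean_pose()
print(mean_pose.shape)  # expected: torch.Size([1, 63])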
def __init__(self, trainconfig, lossconfig):
    for key, val in trainconfig.items():
        setattr(self, key, val)
    for key, val in lossconfig.items():
        setattr(self, key, val)

    if not os.path.exists(self.save_dir):
        os.mkdir(self.save_dir)

    if len(self.ckp_dir) > 0:
        self.resume_training = True

    ### define model
    if self.use_cont_rot:
        n_dim_body = 72 + 3
    else:
        n_dim_body = 72

    self.model_h_latentD = 256
    self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                               latentD_l=self.model_h_latentD,
                               scene_model_ckpt=self.scene_model_ckpt,
                               n_dim_body=n_dim_body,
                               n_dim_scene=self.model_h_latentD)
    self.optimizer_h = optim.Adam(self.model_h.parameters(),
                                  lr=self.init_lr_h)

    ### body mesh model
    self.vposer, _ = load_vposer(self.vposer_ckpt_path, vp_model='snapshot')
    self.body_mesh_model = smplx.create(self.human_model_path,
                                        model_type='smplx',
                                        gender='neutral', ext='npz',
                                        num_pca_comps=12,
                                        create_global_orient=True,
                                        create_body_pose=True,
                                        create_betas=True,
                                        create_left_hand_pose=True,
                                        create_right_hand_pose=True,
                                        create_expression=True,
                                        create_jaw_pose=True,
                                        create_leye_pose=True,
                                        create_reye_pose=True,
                                        create_transl=True,
                                        batch_size=self.batch_size)
    self.smplx_face_idx = np.load(
        os.path.join(self.human_model_path, 'smplx/SMPLX_NEUTRAL.npz'),
        allow_pickle=True)['f'].reshape(-1, 3)
    self.smplx_face_idx = torch.tensor(self.smplx_face_idx.astype(np.int64),
                                       device=self.device)

    print('--[INFO] device: ' + str(torch.cuda.get_device_name(self.device)))
def __init__(self, fittingconfig, lossconfig):
    for key, val in fittingconfig.items():
        setattr(self, key, val)
    for key, val in lossconfig.items():
        setattr(self, key, val)

    self.vposer, _ = load_vposer(self.vposer_ckpt_path, vp_model='snapshot')
    self.body_mesh_model = smplx.create(self.human_model_path,
                                        model_type='smplx',
                                        gender='neutral', ext='npz',
                                        num_pca_comps=12,
                                        create_global_orient=True,
                                        create_body_pose=True,
                                        create_betas=True,
                                        create_left_hand_pose=True,
                                        create_right_hand_pose=True,
                                        create_expression=True,
                                        create_jaw_pose=True,
                                        create_leye_pose=True,
                                        create_reye_pose=True,
                                        create_transl=True,
                                        batch_size=self.batch_size)
    self.vposer.to(self.device)
    self.body_mesh_model.to(self.device)

    self.xhr_rec = Variable(torch.randn(1, 75).to(self.device),
                            requires_grad=True)
    self.optimizer = optim.Adam([self.xhr_rec], lr=self.init_lr_h)

    ## read scene sdf
    with open(self.scene_sdf_path + '.json') as f:
        sdf_data = json.load(f)
        grid_min = np.array(sdf_data['min'])
        grid_max = np.array(sdf_data['max'])
        grid_dim = sdf_data['dim']
    sdf = np.load(self.scene_sdf_path + '_sdf.npy').reshape(grid_dim,
                                                            grid_dim,
                                                            grid_dim)
    self.s_grid_min_batch = torch.tensor(grid_min, dtype=torch.float32,
                                         device=self.device).unsqueeze(0)
    self.s_grid_max_batch = torch.tensor(grid_max, dtype=torch.float32,
                                         device=self.device).unsqueeze(0)
    self.s_sdf_batch = torch.tensor(sdf, dtype=torch.float32,
                                    device=self.device).unsqueeze(0)

    ## read scene vertices
    scene_o3d = o3d.io.read_triangle_mesh(self.scene_verts_path)
    scene_verts = np.asarray(scene_o3d.vertices)
    self.s_verts_batch = torch.tensor(scene_verts, dtype=torch.float32,
                                      device=self.device).unsqueeze(0)
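# Sketch of how the SDF grid loaded above is typically queried (it mirrors
# the scoring code further down in this section): vertices are mapped into
# the normalized [-1, 1]^3 grid frame and trilinearly interpolated with
# F.grid_sample; negative values mean a vertex is inside scene geometry.
# query_scene_sdf is an illustrative helper name, not part of the source.
def query_scene_sdf(verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch):
    # verts: [1, N, 3] in the same world frame as the SDF grid
    norm_verts = (verts - s_grid_min_batch) / \
        (s_grid_max_batch - s_grid_min_batch) * 2 - 1
    n_verts = verts.shape[1]
    sdf_vals = F.grid_sample(
        s_sdf_batch.unsqueeze(1),                               # [1, 1, D, D, D]
        norm_verts[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),  # z, y, x order
        padding_mode='border')
    return sdf_vals.view(1, -1)  # [1, N]; negative => penetration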
def sample_vposer(expr_dir, bm, num_samples=5, vp_model='snapshot'):
    from human_body_prior.tools.omni_tools import id_generator, makepath
    from human_body_prior.tools.model_loader import load_vposer
    from human_body_prior.tools.omni_tools import copy2cpu

    vposer_pt, ps = load_vposer(expr_dir, vp_model=vp_model)
    sampled_pose_body = copy2cpu(vposer_pt.sample_poses(num_poses=num_samples))

    out_dir = makepath(os.path.join(ps.work_dir, 'evaluations',
                                    'pose_generation'))
    out_imgpath = os.path.join(out_dir, '%s.png' % id_generator(6))
    dump_vposer_samples(bm, sampled_pose_body, out_imgpath)
    print('Dumped samples at %s' % out_dir)
    return sampled_pose_body
def __init__(
    self,
    debug=True,
    batch_size=0,
    hand_pca_nb=6,
    head_center_idx=8949,
    smpl_root="assets/models",
    vposer_dir="assets/vposer",
    vposer_dim=32,
    parts_path="assets/models/smplx/smplx_parts_segm.pkl",
    mano_corresp_path="assets/models/MANO_SMPLX_vertex_ids.pkl",
):
    super().__init__()
    self.debug = debug
    self.hand_pca_nb = hand_pca_nb
    self.head_center_idx = head_center_idx
    self.smplx_vertex_nb = 10475
    self.vposer_dim = vposer_dim

    # Initialize SMPL-X model
    self.smpl_model = vposeutils.get_smplx(model_root=smpl_root,
                                           vposer_dir=vposer_dir)
    self.smpl_f = self.smpl_model.faces
    # Get vposer
    self.vposer = load_vposer(vposer_dir, vp_model="snapshot")[0]
    self.vposer.eval()
    # Translate human so that head is at camera level
    self.set_head2cam_trans()

    self.armfaces = smplvis.filter_parts(self.smpl_f, parts_path)
    with open(mano_corresp_path, "rb") as p_f:
        self.mano_corresp = pickle.load(p_f)

    # Initialize model parameters
    self.batch_size = batch_size
    left_hand_pose = self.smpl_model.left_hand_pose.repeat(batch_size, 1)
    right_hand_pose = self.smpl_model.right_hand_pose.repeat(batch_size, 1)
    pose_embedding = (self.get_neutral_pose_embedding()
                      .unsqueeze(0).repeat(batch_size, 1))
    self.pose_embedding = torch.nn.Parameter(pose_embedding,
                                             requires_grad=True)
    self.left_hand_pose = torch.nn.Parameter(left_hand_pose,
                                             requires_grad=True)
    self.right_hand_pose = torch.nn.Parameter(right_hand_pose,
                                              requires_grad=True)
def main(**args):
    input_media = args.pop('input_media')
    config = easy_configuration.configure(**args)
    device = torch.device('cpu')
    dtype = torch.float32

    body_model = config['neutral_model']
    body_model.to(device=device)
    camera = config['camera']

    vposer_ckpt = args['vposer_ckpt']
    vposer_ckpt = osp.expandvars(vposer_ckpt)
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
    vposer = vposer.to(device=device)
    vposer.eval()

    viewer = MeshViewer()
    body_color = (1.0, 1.0, 0.9, 1.0)

    img_np = cv2.imread(input_media)
    img = get_img(input_media)

    # Keypoints for the first person
    keypoints = get_keypoints(img_np)[[0]]
    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    gt_joints = gt_joints.to(device=device, dtype=dtype)

    def render_embedding(vposer, pose_embedding, body_model, viewer):
        body_pose = vposer.decode(pose_embedding,
                                  output_type='aa').view(1, -1)
        body_pose = body_pose.to(device=device)  # .to() is not in-place
        body_model_output = body_model(body_pose=body_pose)
        vertices = body_model_output.vertices.detach().cpu().numpy().squeeze()
        viewer.update_mesh(vertices, body_model.faces)

    while True:
        # The huge multiplier pushes the latent samples far outside the
        # prior, so the decoded bodies are extreme.
        pose_embedding = torch.randn([1, 32], dtype=torch.float32,
                                     device=device) * 10929292929
        render_embedding(vposer, pose_embedding, body_model, viewer)
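# For contrast with the scaled samples above, a hedged sketch of drawing
# latents at the prior's natural scale: standard-normal samples decode to
# plausible bodies, while the huge multiplier above produces extreme,
# out-of-distribution poses. sample_plausible_pose is an illustrative name.
def sample_plausible_pose(vposer, device, latent_dim=32):
    with torch.no_grad():
        z = torch.randn([1, latent_dim], dtype=torch.float32, device=device)
        return vposer.decode(z, output_type='aa').view(1, -1)  # [1, 63]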
def __init__(self, testconfig):
    for key, val in testconfig.items():
        setattr(self, key, val)

    if not os.path.exists(self.ckpt_dir):
        print('--[ERROR] checkpoints do not exist')
        sys.exit()

    # define model
    if self.use_cont_rot:
        n_dim_body = 72 + 3
    else:
        n_dim_body = 72

    self.model_h_latentD = 256
    self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                               latentD_l=self.model_h_latentD,
                               n_dim_body=n_dim_body,
                               n_dim_scene=self.model_h_latentD,
                               test=True)

    ### body mesh model
    self.vposer, _ = load_vposer(self.vposer_ckpt_path, vp_model='snapshot')
    self.body_mesh_model = smplx.create(self.human_model_path,
                                        model_type='smplx',
                                        gender='neutral', ext='npz',
                                        num_pca_comps=12,
                                        create_global_orient=True,
                                        create_body_pose=True,
                                        create_betas=True,
                                        create_left_hand_pose=True,
                                        create_right_hand_pose=True,
                                        create_expression=True,
                                        create_jaw_pose=True,
                                        create_leye_pose=True,
                                        create_reye_pose=True,
                                        create_transl=True,
                                        batch_size=self.n_samples)
def load_avakhitov_fits(dp, load_betas=True, load_body_poses=True,
                        load_expressions=False, load_fid_lst=True,
                        vposer_ckpt=None):
    # vposer_ckpt: path to a VPoser snapshot; only needed when the stored
    # poses are 32-dim VPoser latents rather than axis-angle.
    result = dict()
    for flag, k, fn_no_ext in [
            [load_betas, 'betas', 'betas'],
            [load_body_poses, 'body_poses', 'poses'],
            [load_expressions, 'expressions', 'expressions'],
            [load_fid_lst, 'fid_lst', 'fid_lst']]:
        if flag:
            load_fp = osp.join(dp, f'{fn_no_ext}.npy')
            try:
                loaded = np.load(load_fp)
            except Exception:
                print(load_fp)
                raise
            if fn_no_ext == 'poses':
                if loaded.shape[1] == 69:
                    # 69 columns = 63 axis-angle body dims + 6 global params
                    pose_body = loaded[:, 0:63]
                else:
                    # poses are stored as 32-dim latents: load the vposer model
                    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
                    vposer.eval()
                    pose_body_vp = torch.tensor(loaded[:, 0:32])
                    # convert from vposer to rotation matrices
                    pose_body_mats = vposer.decode(pose_body_vp).reshape(
                        len(loaded), -1, 3, 3).detach().cpu().numpy()
                    pose_body = np.zeros((pose_body_mats.shape[0], 63))
                    for i in range(0, pose_body_mats.shape[0]):
                        for j in range(0, pose_body_mats.shape[1]):
                            rot_vec, jac = cv2.Rodrigues(pose_body_mats[i, j])
                            pose_body[i, 3 * j: 3 * j + 3] = rot_vec.reshape(-1)
                result[k] = pose_body
                result['global_rvecs'] = loaded[:, -3:]
                result['global_tvecs'] = loaded[:, -6:-3]
                result['n'] = len(loaded)
            else:
                result[k] = loaded
    return result
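# Small self-contained check of the conversion used above: cv2.Rodrigues maps
# a rotation matrix to its axis-angle vector and back, so decoding VPoser to
# rotation matrices and then converting to axis-angle round-trips cleanly.
import cv2
import numpy as np

rot_vec = np.array([0.1, -0.2, 0.3])
rot_mat, _ = cv2.Rodrigues(rot_vec)        # [3, 3] rotation matrix
rot_vec_back, _ = cv2.Rodrigues(rot_mat)   # back to a [3, 1] axis-angle vector
assert np.allclose(rot_vec, rot_vec_back.ravel())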
                          num_pca_comps=12,
                          create_global_orient=True,
                          create_body_pose=True,
                          create_betas=True,
                          create_left_hand_pose=True,
                          create_right_hand_pose=True,
                          create_expression=True,
                          create_jaw_pose=True,
                          create_leye_pose=True,
                          create_reye_pose=True,
                          create_transl=True,
                          batch_size=batch_size).to(device)
print('[INFO] smplx model loaded.')

vposer_model, _ = load_vposer(vposer_model_path, vp_model='snapshot')
vposer_model = vposer_model.to(device)
print('[INFO] vposer model loaded')

######## set body mesh gen dataloader ###########
scene_name = 'MPH1Library'
dataset = PreprocessLoader(scene_name=scene_name)
dataset.load_body_params(proxd_path)
dataset.load_cam_params(cam2world_path)
dataset.load_scene(scene_mesh_file=scene_mesh_path)
bps_dataloader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=batch_size, shuffle=False,
    num_workers=4,
    drop_last=True)  # drop_last=True: the smplx model needs a predefined batch size

######### get body mesh for all samples ###########
body_verts_list = []
                     **args)
model = smplx.create(**model_params)
model = model.to(device=device)

batch_size = args.get('batch_size', 1)
use_vposer = args.get('use_vposer', True)
vposer, pose_embedding = [None, ] * 2
vposer_ckpt = args.get('vposer_ckpt', '')
if use_vposer:
    pose_embedding = torch.zeros([batch_size, 32],
                                 dtype=dtype, device=device,
                                 requires_grad=True)
    vposer_ckpt = osp.expandvars(vposer_ckpt)
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
    vposer = vposer.to(device=device)
    vposer.eval()

for pkl_path in pkl_paths:
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f, encoding='latin1')

    if use_vposer:
        with torch.no_grad():
            pose_embedding[:] = torch.tensor(
                data['body_pose'], device=device, dtype=dtype)

    est_params = {}
    for key, val in data.items():
        if key == 'body_pose' and use_vposer:
            body_pose = vposer.decode(
                pose_embedding, output_type='aa').view(batch_size, -1)
def optimize_visulize():
    # read scene mesh, scene sdf
    scene, cur_scene_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch = \
        read_mesh_sdf(args.dataset_path, args.dataset, args.scene_name)

    smplx_model = smplx.create(args.smplx_model_path, model_type='smplx',
                               gender='neutral', ext='npz',
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=1).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(args.vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')

    ##################### load optimization results ##################
    shift_list = np.load('{}/{}/shift_list.npy'.format(
        args.optimize_result_dir, args.scene_name))
    rot_angle_list_1 = np.load('{}/{}/rot_angle_list_1.npy'.format(
        args.optimize_result_dir, args.scene_name))
    if args.optimize:
        body_params_opt_list_s1 = np.load(
            '{}/{}/body_params_opt_list_s1.npy'.format(
                args.optimize_result_dir, args.scene_name))
        body_params_opt_list_s2 = np.load(
            '{}/{}/body_params_opt_list_s2.npy'.format(
                args.optimize_result_dir, args.scene_name))
    body_verts_sample_list = np.load(
        '{}/{}/body_verts_sample_list.npy'.format(
            args.optimize_result_dir, args.scene_name))
    n_sample = len(body_verts_sample_list)

    ############### evaluation (contact/collision score) ###############
    loss_non_collision_sample, loss_contact_sample = 0, 0
    loss_non_collision_opt_s1, loss_contact_opt_s1 = 0, 0
    loss_non_collision_opt_s2, loss_contact_opt_s2 = 0, 0
    body_params_prox_list_s1, body_params_prox_list_s2 = [], []
    body_verts_opt_prox_s2_list = []

    for cnt in tqdm(range(0, n_sample)):
        body_verts_sample = body_verts_sample_list[cnt]  # [10475, 3]

        # smplx params --> body mesh
        body_params_opt_s1 = torch.from_numpy(
            body_params_opt_list_s1[cnt]).float().unsqueeze(0).to(device)  # [1, 75]
        body_params_opt_s1 = convert_to_3D_rot(body_params_opt_s1)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(
            body_params_opt_s1[:, 16:48], output_type='aa').view(1, -1)  # [1, 63]
        body_verts_opt_s1 = gen_body_mesh(body_params_opt_s1, body_pose_joint,
                                          smplx_model)[0]  # [n_body_vert, 3]
        body_verts_opt_s1 = body_verts_opt_s1.detach().cpu().numpy()

        body_params_opt_s2 = torch.from_numpy(
            body_params_opt_list_s2[cnt]).float().unsqueeze(0).to(device)
        body_params_opt_s2 = convert_to_3D_rot(body_params_opt_s2)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(
            body_params_opt_s2[:, 16:48], output_type='aa').view(1, -1)
        body_verts_opt_s2 = gen_body_mesh(body_params_opt_s2, body_pose_joint,
                                          smplx_model)[0]
        body_verts_opt_s2 = body_verts_opt_s2.detach().cpu().numpy()

        ####### transform local body verts to prox coordinate system #######
        # generated body verts from cvae, before optimization
        body_verts_sample_prox = np.zeros(body_verts_sample.shape)  # [10475, 3]
        temp = body_verts_sample - shift_list[cnt]
        body_verts_sample_prox[:, 0] = \
            temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
            temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 1] = \
            temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
            temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 2] = temp[:, 2]

        ######### optimized body verts
        trans_matrix_1 = np.array(
            [[math.cos(-rot_angle_list_1[cnt]),
              -math.sin(-rot_angle_list_1[cnt]), 0, 0],
             [math.sin(-rot_angle_list_1[cnt]),
              math.cos(-rot_angle_list_1[cnt]), 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]])
        trans_matrix_2 = np.array([[1, 0, 0, -shift_list[cnt][0]],
                                   [0, 1, 0, -shift_list[cnt][1]],
                                   [0, 0, 1, -shift_list[cnt][2]],
                                   [0, 0, 0, 1]])

        ### stage 1: simple optimization results
        body_verts_opt_prox_s1 = np.zeros(body_verts_opt_s1.shape)  # [10475, 3]
        temp = body_verts_opt_s1 - shift_list[cnt]
        body_verts_opt_prox_s1[:, 0] = \
            temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
            temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 1] = \
            temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
            temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 2] = temp[:, 2]

        # transform local params to prox coordinate system
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_opt_s1[0].cpu().numpy(), smplx_model, trans_matrix_2)  # [72]
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_prox_s1, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s1.append(body_params_prox_s1)

        ### stage 2: advanced optimization results
        body_verts_opt_prox_s2 = np.zeros(body_verts_opt_s2.shape)  # [10475, 3]
        temp = body_verts_opt_s2 - shift_list[cnt]
        body_verts_opt_prox_s2[:, 0] = \
            temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
            temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 1] = \
            temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
            temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 2] = temp[:, 2]

        # transform local params to prox coordinate system
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_opt_s2[0].cpu().numpy(), smplx_model, trans_matrix_2)  # [72]
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_prox_s2, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s2.append(body_params_prox_s2)
        body_verts_opt_prox_s2_list.append(body_verts_opt_prox_s2)

        ########################### visualization ##########################
        if args.visualize:
            body_mesh_sample = o3d.geometry.TriangleMesh()
            body_mesh_sample.vertices = o3d.utility.Vector3dVector(
                body_verts_sample_prox)
            body_mesh_sample.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_sample.compute_vertex_normals()

            body_mesh_opt_s1 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s1.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s1)
            body_mesh_opt_s1.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s1.compute_vertex_normals()

            body_mesh_opt_s2 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s2.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s2)
            body_mesh_opt_s2.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s2.compute_vertex_normals()

            o3d.visualization.draw_geometries(
                [scene, body_mesh_sample])  # generated body mesh by cvae
            o3d.visualization.draw_geometries(
                [scene, body_mesh_opt_s1])  # simple-optimized body mesh
            o3d.visualization.draw_geometries(
                [scene, body_mesh_opt_s2])  # adv-optimized body mesh

        ############## compute non-collision/contact score ##############
        # body verts before optimization
        body_verts_sample_prox_tensor = torch.from_numpy(
            body_verts_sample_prox).float().unsqueeze(0).to(device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_sample_prox_tensor - s_grid_min_batch) / \
            (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(
            s_sdf_batch.unsqueeze(1),
            norm_verts_batch[:, :, [2, 1, 0]].view(-1, 10475, 1, 1, 3),
            padding_mode='border')
        if body_sdf_batch.lt(0).sum().item() < 1:
            # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_sample += 1.0
            loss_contact_sample += 0.0
        else:
            loss_non_collision_sample += \
                (body_sdf_batch > 0).sum().float().item() / 10475.0
            loss_contact_sample += 1.0

        # stage 1: simple optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s1).float().unsqueeze(0).to(device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / \
            (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(
            s_sdf_batch.unsqueeze(1),
            norm_verts_batch[:, :, [2, 1, 0]].view(-1, 10475, 1, 1, 3),
            padding_mode='border')
        if body_sdf_batch.lt(0).sum().item() < 1:
            # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_opt_s1 += 1.0
            loss_contact_opt_s1 += 0.0
        else:
            loss_non_collision_opt_s1 += \
                (body_sdf_batch > 0).sum().float().item() / 10475.0
            loss_contact_opt_s1 += 1.0

        # stage 2: advanced optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s2).float().unsqueeze(0).to(device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / \
            (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(
            s_sdf_batch.unsqueeze(1),
            norm_verts_batch[:, :, [2, 1, 0]].view(-1, 10475, 1, 1, 3),
            padding_mode='border')
        if body_sdf_batch.lt(0).sum().item() < 1:
            # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_opt_s2 += 1.0
            loss_contact_opt_s2 += 0.0
        else:
            loss_non_collision_opt_s2 += \
                (body_sdf_batch > 0).sum().float().item() / 10475.0
            loss_contact_opt_s2 += 1.0

    print('scene:', args.scene_name)

    loss_non_collision_sample = loss_non_collision_sample / n_sample
    loss_contact_sample = loss_contact_sample / n_sample
    print('w/o optimization body: non_collision score:',
          loss_non_collision_sample)
    print('w/o optimization body: contact score:', loss_contact_sample)

    loss_non_collision_opt_s1 = loss_non_collision_opt_s1 / n_sample
    loss_contact_opt_s1 = loss_contact_opt_s1 / n_sample
    print('optimized body s1: non_collision score:', loss_non_collision_opt_s1)
    print('optimized body s1: contact score:', loss_contact_opt_s1)

    loss_non_collision_opt_s2 = loss_non_collision_opt_s2 / n_sample
    loss_contact_opt_s2 = loss_contact_opt_s2 / n_sample
    print('optimized body s2: non_collision score:', loss_non_collision_opt_s2)
    print('optimized body s2: contact score:', loss_contact_opt_s2)
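# The per-sample scoring above condenses to a small helper (illustrative
# name, same logic): the non-collision score is the fraction of vertices
# with positive sdf, and a body counts as "in contact" as soon as any
# vertex penetrates the scene (sdf < 0).
def score_body(body_sdf_batch, n_verts=10475):
    if body_sdf_batch.lt(0).sum().item() < 1:  # no penetrating vertex
        return 1.0, 0.0  # (non_collision, contact)
    non_collision = (body_sdf_batch > 0).sum().float().item() / n_verts
    return non_collision, 1.0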
def fit_single_frame(img,
                     keypoints,
                     init_trans,
                     scan,
                     scene_name,
                     body_model,
                     camera,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     body_scene_rendering_fn='body_scene.png',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length_x=5000.,
                     focal_length_y=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     ####################
                     ### PROX
                     render_results=True,
                     camera_mode='moving',
                     ## Depth
                     s2m=False,
                     s2m_weights=None,
                     m2s=False,
                     m2s_weights=None,
                     rho_s2m=1,
                     rho_m2s=1,
                     init_mode=None,
                     trans_opt_stages=None,
                     viz_mode='mv',
                     # penetration
                     sdf_penetration=False,
                     sdf_penetration_weights=0.0,
                     sdf_dir=None,
                     cam2world_dir=None,
                     # contact
                     contact=False,
                     rho_contact=1.0,
                     contact_loss_weights=None,
                     contact_angle=15,
                     contact_body_parts=None,
                     body_segments_dir=None,
                     load_scene=False,
                     scene_dir=None,
                     height=None,
                     weight=None,
                     gender='male',
                     weight_w=0,
                     height_w=0,
                     **kwargs):
    if kwargs['optim_type'] == 'lbfgsls':
        assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'
    batch_size = keypoints.shape[0]
    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # if visualize:
    #     pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [1, ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
        ' does not match the number of data term weights {}'.format(
            len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (len(hand_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (len(hand_joints_weights) ==
                    len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) ==
            len(body_pose_prior_weights)), msg.format(
                len(shape_weights), len(body_pose_prior_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (len(jaw_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) ==
                len(body_pose_prior_weights)), msg.format(
                    len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) ==
                len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    if use_vposer:
        # pose_embedding = torch.zeros([batch_size, 32],
        #                              dtype=dtype, device=device,
        #                              requires_grad=True)
        # Patrick: hack to set default body pose to something more sleep-y
        mean_body = np.array([[0.19463745, 1.6240447, 0.6890624, 0.19186097,
                               0.08003145, -0.04189298, 3.450903, -0.29570094,
                               0.25072002, -1.1879578, 0.33350763, 0.23568614,
                               0.38122794, -2.1258948, 0.2910664, 2.2407222,
                               -0.5400814, -0.95984083, -1.2880017, 1.1122228,
                               0.7411389, -0.2265636, -4.8202057, -1.950323,
                               -0.28771818, -1.9282387, 0.9928907, -0.27183488,
                               -0.55805033, 0.04047768, -0.537362, 0.65770334]])
        pose_embedding = torch.tensor(mean_body, dtype=dtype, device=device,
                                      requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        # body_mean_pose = body_pose_prior.get_mean().detach().cpu()
        # body_mean_pose = torch.zeros([batch_size, 69], dtype=dtype)
        # mean_body = np.array([[-2.33263850e-01, 1.35460928e-01, 2.94471830e-01, -3.22930813e-01,
        #                        -4.73931670e-01, -2.67531037e-01, 7.12558180e-02, 7.89440796e-03,
        #                        8.67700949e-03, 1.05982251e-01, 2.79584467e-01, -7.04243258e-02,
        #                        3.61106455e-01, -5.87305248e-01, 1.10897996e-01, -1.68918714e-01,
        #                        -4.60174456e-02, 3.28684039e-02, 5.80525696e-01, -5.11317095e-03,
        #                        -1.57546505e-01, 5.85777402e-01, -8.94948393e-02, 2.24680841e-01,
        #                        1.55473784e-01, 5.38146123e-04, 4.30279821e-02, -4.68525589e-02,
        #                        7.75185153e-02, 7.82282930e-03, 6.74356073e-02, 4.09710407e-02,
        #                        -3.60425897e-02, -4.71813440e-01, 5.02379127e-02, 2.02309843e-02,
        #                        5.29680364e-02, 1.68510173e-02, 2.25090146e-01, -4.52307612e-02,
        #                        7.72185996e-02, -2.17333943e-01, 3.30020368e-01, 4.21866514e-02,
        #                        7.15153441e-02, 3.05950731e-01, -3.63454908e-01, -1.28235269e+00,
        #                        5.09610713e-01, 4.65482563e-01, 1.20263052e+00, 5.56594551e-01,
        #                        -2.24000740e+00, 3.83565158e-01, 5.31355202e-01, 2.21637583e+00,
        #                        -5.63146770e-01, -3.01193684e-01, -4.31942672e-01, 6.85038209e-01,
        #                        3.61178756e-01, 2.76136428e-01, -2.64388829e-01, 0.00000000e+00,
        #                        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        #                        0.00000000e+00]])
        mean_body = np.array(
            joint_limits.axang_limits_patrick / 180 * np.pi).mean(1)
        body_mean_pose = torch.tensor(mean_body, dtype=dtype).unsqueeze(0)

    betanet = None
    if height is not None:
        betanet = torch.load('models/betanet_old_pytorch.pt')
        betanet = betanet.to(device=device)
        betanet.eval()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2]

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = scan.to(device=device)

    # load pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    # if sdf_penetration:
    #     with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
    #         sdf_data = json.load(f)
    #         grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device)
    #         grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device)
    #         grid_dim = sdf_data['dim']
    #     voxel_size = (grid_max - grid_min) / grid_dim
    #     sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(grid_dim, grid_dim, grid_dim)
    #     sdf = torch.tensor(sdf, dtype=dtype, device=device)
    #     if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
    #         sdf_normals = np.load(osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(grid_dim, grid_dim, grid_dim, 3)
    #         sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
    #     else:
    #         print("Normals not found...")

    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3),
                         dtype=dtype, device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3),
                         dtype=dtype, device=device)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # load vertex ids of contact parts
    contact_verts_ids = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'),
                      'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(
            return_verts=True,
            body_pose=torch.zeros((batch_size, 63),
                                  dtype=dtype, device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = \
            body_model.faces_tensor.detach().cpu().numpy().reshape(-1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)

        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir,
                                               scene_name + '.ply'))
            scene.vn = scene.estimate_vertex_normals()
            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype, device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype, device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long, device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {'data_weight': data_weights,
                        'body_pose_weight': body_pose_prior_weights,
                        'shape_weight': shape_weights}
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals)) for vals in
                   zip(*(opt_weights_dict[k] for k in keys
                         if opt_weights_dict[k] is not None))]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # load indices of the head of smpl-x model
    with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # which initialization mode to choose: similar triangles, mean of the
    # scan, or the average of both
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(
            body_model, gt_joints, edge_indices,
            use_vposer=use_vposer, vposer=vposer,
            pose_embedding=pose_embedding,
            model_type=kwargs.get('model_type', 'smpl'),
            focal_length=focal_length_x, dtype=dtype)) / 2.0
    else:
        init_t = fitting.guess_init(body_model, gt_joints, edge_indices,
                                    use_vposer=use_vposer, vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get('model_type',
                                                          'smpl'),
                                    focal_length=focal_length_x, dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face, use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               betanet=betanet,
                               height=height,
                               weight=weight,
                               gender=gender,
                               weight_w=weight_w,
                               height_w=height_w,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                viz_mode=viz_mode, **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)
        _, H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True
            camera_opt_params = [camera.translation,
                                 body_model.global_orient]
        elif camera_mode == 'fixed':
            # body_model.reset_params()
            # body_model.transl[:] = torch.tensor(init_t)
            # body_model.body_pose[:] = torch.tensor(body_mean_pose)
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl,
                                 body_model.global_orient]

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.norm(gt_joints[:, left_shoulder_idx, :] -
                                   gt_joints[:, right_shoulder_idx, :],
                                   dim=1)
        try_both_orient = shoulder_dist.min() < side_view_thsh

        kwargs['lr'] *= 10
        camera_optimizer, camera_create_graph = \
            optim_factory.create_optimizer(camera_opt_params, **kwargs)
        kwargs['lr'] /= 10

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer, body_model, camera, gt_joints,
            camera_loss, create_graph=camera_create_graph,
            use_vposer=use_vposer, vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False, return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial
        # translation of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close, then rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            with torch.no_grad():
                flipped_orient = torch.zeros_like(body_model.global_orient)
                for i in range(batch_size):
                    body_orient = \
                        body_model.global_orient[i, :].detach().cpu().numpy()
                    local_flip = cv2.Rodrigues(body_orient)[0].dot(
                        cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
                    local_flip = cv2.Rodrigues(local_flip)[0].ravel()
                    flipped_orient[i, :] = torch.Tensor(local_flip).to(device)
            orientations = [body_model.global_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0
        # for or_idx, orient in enumerate(orientations):
        or_idx = 0
        while or_idx < len(orientations):
            global_vars.cur_orientation = or_idx
            orient = orientations[or_idx]
            print('Trying orientation', or_idx, 'of', len(orientations))
            opt_start = time.time()
            or_idx += 1

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)
                    pose_embedding += torch.tensor(mean_body, dtype=dtype,
                                                   device=device)

            for opt_idx, curr_weights in enumerate(opt_weights):
                global_vars.cur_opt_stage = opt_idx

                if opt_idx not in trans_opt_stages:
                    body_model.transl.requires_grad = False
                else:
                    body_model.transl.requires_grad = True

                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = \
                    optim_factory.create_optimizer(final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer, body_model,
                    camera=camera, gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss, create_graph=body_create_graph,
                    use_vposer=use_vposer, vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v, scene_vn=scene_vn,
                    scene_f=scene_f, ftov=ftov,
                    return_verts=True, return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()

                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure, final_params,
                    body_model,
                    pose_embedding=pose_embedding, vposer=vposer,
                    use_vposer=use_vposer)

                # print('Final loss val', final_loss_val)
                # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
                #     break

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write(
                            'Stage {:03d} done after {:.4f} seconds'.format(
                                opt_idx, elapsed))

            # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
            #     print('Optimization FAILURE, retrying')
            #     orientations.append(orientations[or_idx - 1] * 0.9)
            #     continue

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'
                    .format(or_idx, elapsed))
                tqdm.write('Body final loss val = {:.5f}'.format(
                    final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {'camera_' + str(key): val.detach().cpu().numpy()
                      for key, val in camera.named_parameters()}
            result['camera_focal_length_x'] = \
                camera.focal_length_x.detach().cpu().numpy()
            result['camera_focal_length_y'] = \
                camera.focal_length_y.detach().cpu().numpy()
            result['camera_center'] = camera.center.detach().cpu().numpy()
            result.update({key: val.detach().cpu().numpy()
                           for key, val in body_model.named_parameters()})
            if use_vposer:
                result['pose_embedding'] = \
                    pose_embedding.detach().cpu().numpy()
            body_pose = vposer.decode(
                pose_embedding,
                output_type='aa').view(1, -1) if use_vposer else None
            if "smplx.body_models.SMPL'" in str(type(body_model)):
                wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                         dtype=body_pose.dtype,
                                         device=body_pose.device)
                body_pose = torch.cat([body_pose, wrist_pose], dim=1)
            result['body_pose'] = body_pose.detach().cpu().numpy()
            result['final_loss_val'] = final_loss_val
            result['loss_dict'] = global_vars.cur_loss_dict
            result['betanet_weight'] = global_vars.cur_weight
            result['betanet_height'] = global_vars.cur_height
            result['gt_joints'] = gt_joints.detach().cpu().numpy()
            result['max_joint'] = global_vars.cur_max_joint
            results.append(result)

        for idx, res_folder in enumerate(result_fn):  # Iterate over batch
            pkl_data = {}
            min_loss = np.inf
            all_results = []
            for result in results:  # Iterate over orientations
                sel_res = misc_utils.get_data_from_batched_dict(
                    result, idx, len(result_fn))
                all_results.append(sel_res)
                cost = sel_res['loss_dict']['total'] + \
                    sel_res['loss_dict']['pprior'] * 60
                if cost < min_loss:
                    min_loss = cost
                    pkl_data.update(sel_res)

            pkl_data['all_results'] = all_results
            with open(res_folder, 'wb') as result_file:
                pickle.dump(pkl_data, result_file, protocol=2)

            img_s = img[idx, :].detach().cpu().numpy()
            img_s = pil_img.fromarray((img_s * 255).astype(np.uint8))
            img_s.save(out_img_fn[idx])
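# Sketch of the 180-degree flip used for the second orientation above, pulled
# out as a standalone helper (illustrative name, same math): compose the
# current global orientation with a rotation of pi about the y axis, going
# through rotation matrices via cv2.Rodrigues. Reuses cv2/np from above.
def flip_global_orient(orient_aa):
    # orient_aa: np.ndarray of shape [3], axis-angle global orientation
    flip_mat = cv2.Rodrigues(np.array([0., np.pi, 0.]))[0]
    flipped_mat = cv2.Rodrigues(orient_aa)[0].dot(flip_mat)
    return cv2.Rodrigues(flipped_mat)[0].ravel()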
def fit_frames(img,
               keypoints,
               body_model,
               camera=None,
               joint_weights=None,
               body_pose_prior=None,
               jaw_prior=None,
               left_hand_prior=None,
               right_hand_prior=None,
               shape_prior=None,
               expr_prior=None,
               angle_prior=None,
               loss_type="smplify",
               use_cuda=True,
               init_joints_idxs=(9, 12, 2, 5),
               use_face=False,
               use_hands=True,
               data_weights=None,
               body_pose_prior_weights=None,
               hand_pose_prior_weights=None,
               jaw_pose_prior_weights=None,
               shape_weights=None,
               expr_weights=None,
               hand_joints_weights=None,
               face_joints_weights=None,
               depth_loss_weight=1e2,
               interpenetration=False,
               coll_loss_weights=None,
               df_cone_height=0.5,
               penalize_outside=True,
               max_collisions=8,
               point2plane=False,
               part_segm_fn="",
               focal_length=5000.0,
               side_view_thsh=25.0,
               rho=100,
               vposer_latent_dim=32,
               vposer_ckpt="",
               use_joints_conf=False,
               interactive=True,
               visualize=False,
               batch_size=1,
               dtype=torch.float32,
               ign_part_pairs=None,
               left_shoulder_idx=2,
               right_shoulder_idx=5,
               freeze_camera=True,
               **kwargs):
    # assert batch_size == 1, "PyTorch L-BFGS only supports batch_size == 1"

    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]
    if data_weights is None:
        data_weights = [1] * len(body_pose_prior_weights)
    msg = "Number of Body pose prior weights {}".format(
        len(body_pose_prior_weights)
    ) + " does not match the number of data term weights {}".format(
        len(data_weights))
    assert len(data_weights) == len(body_pose_prior_weights), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
        msg = ("Number of Body pose prior weights does not match the" +
               " number of hand pose prior weights")
        assert len(hand_pose_prior_weights) == \
            len(body_pose_prior_weights), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ("Number of Body pose prior weights does not match the" +
               " number of hand joint distance weights")
        assert len(hand_joints_weights) == \
            len(body_pose_prior_weights), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
    msg = ("Number of Body pose prior weights = {} does not match the" +
           " number of Shape prior weights = {}")
    assert len(shape_weights) == len(body_pose_prior_weights), msg.format(
        len(shape_weights), len(body_pose_prior_weights))

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ("Number of Body pose prior weights does not match the" +
           " number of collision loss weights")
    assert len(coll_loss_weights) == len(body_pose_prior_weights), msg

    use_vposer = kwargs.get("use_vposer", True)
    vposer, pose_embedding = [None] * 2
    if use_vposer:
        pose_embedding = torch.zeros(
            [batch_size, vposer_latent_dim],
            dtype=dtype,
            device=device,
            requires_grad=True,
        )

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model="snapshot")
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = (
            vposeutils.get_vposer_mean_pose().detach().cpu().numpy())
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(
            keypoint_data.shape[0], -1)
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, "Interpenetration term can only be used with CUDA"
        assert torch.cuda.is_available(), (
            "No CUDA Device! Interpenetration term can only be used" +
            " with CUDA")

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=df_cone_height,
            point2plane=point2plane,
            vectorized=True,
            penalize_outside=penalize_outside,
        )

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, "rb") as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding="latin1")
            faces_segm = face_segm_data["segm"]
            faces_parents = face_segm_data["parents"]
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs,
            ).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        "data_weight": data_weights,
        "body_pose_weight": body_pose_prior_weights,
        "shape_weight": shape_weights,
    }
    if use_face:
        opt_weights_dict["face_weight"] = face_joints_weights
        opt_weights_dict["expr_prior_weight"] = expr_weights
        opt_weights_dict["jaw_prior_weight"] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict["hand_weight"] = hand_joints_weights
        opt_weights_dict["hand_prior_weight"] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict["coll_loss_weight"] = coll_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys
                          if opt_weights_dict[k] is not None))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    # Hand joints start at 25 (before body)
    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize, **kwargs) as monitor:
        img = torch.tensor(img, dtype=dtype)
        H, W, _ = img.shape
        data_weight = 1000 / H

        orientations = [body_model.global_orient.detach().cpu().numpy()]

        # # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc="Orientation")):
            opt_start = time.time()

            new_params = defaultdict(global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc="Stage")):
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                (
                    body_optimizer,
                    body_create_graph,
                ) = optim_factory.create_optimizer(final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights["data_weight"] = data_weight
                curr_weights["bending_prior_weight"] = (
                    3.17 * curr_weights["body_pose_weight"])
                if use_hands:
                    # joint_weights[:, 25:67] = curr_weights['hand_weight']
                    pass
                if use_face:
                    joint_weights[:, 67:] = curr_weights["face_weight"]
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True,
                    return_full_pose=True,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()

                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write(
                            "Stage {:03d} done after {:.4f} seconds".format(
                                opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    "Body fitting Orientation {} done after {:.4f} seconds".
                    format(or_idx, elapsed))
                tqdm.write(
                    "Body final loss val = {:.5f}".format(final_loss_val))

        # Get the result of the fitting process
        # Store in it the errors list in order to compare multiple
        # orientations, if they exist
        result = {
            "camera_" + str(key): val.detach().cpu().numpy()
            for key, val in camera.named_parameters()
        }
        result.update({
            key: val.detach().cpu().numpy()
            for key, val in body_model.named_parameters()
        })
        if use_vposer:
            result["pose_embedding"] = (
                pose_embedding.detach().cpu().numpy())
        body_pose = (vposer.decode(
            pose_embedding, output_type="aa").reshape(
                pose_embedding.shape[0], -1) if use_vposer else None)
        result["body_pose"] = body_pose.detach().cpu().numpy()

        model_output = body_model(return_verts=True, body_pose=body_pose)

        return model_output, result
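# Hedged wrapper sketching how fit_frames outputs are typically consumed
# (run_fit_frames_example and "assets/vposer" are illustrative, not from the
# source; the caller must supply the priors and camera that fit_frames needs).
def run_fit_frames_example(img, keypoints, body_model, camera,
                           joint_weights, body_pose_prior, shape_prior,
                           angle_prior, **kwargs):
    model_output, result = fit_frames(img, keypoints, body_model,
                                      camera=camera,
                                      joint_weights=joint_weights,
                                      body_pose_prior=body_pose_prior,
                                      shape_prior=shape_prior,
                                      angle_prior=angle_prior,
                                      vposer_ckpt="assets/vposer",
                                      **kwargs)
    # model_output carries the posed SMPL-X vertices; result holds numpy
    # copies of the optimized parameters, including the VPoser embedding
    # and decoded body_pose when use_vposer is enabled.
    verts = model_output.vertices.detach().cpu().numpy()  # [batch, 10475, 3]
    print(result["pose_embedding"].shape)  # (batch, 32) when use_vposer
    return verts, result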
def main(args):
    scene_name = os.path.abspath(args.gen_folder).split("/")[-1]
    outimg_dir = args.outimg_dir
    if not os.path.exists(outimg_dir):
        os.makedirs(outimg_dir)

    ### setup visualization window
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=960, height=540, visible=True)
    render_opt = vis.get_render_option().mesh_show_back_face = True

    ### put the scene into the environment
    scene = o3d.io.read_triangle_mesh(
        osp.join(args.prox_dir, scene_name + '.ply'))
    vis.add_geometry(scene)
    vis.update_geometry()

    # put the body into the environment
    vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0')
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender=args.gender,
                         ext='npz',
                         num_pca_comps=args.num_pca_comps,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=True,
                         create_jaw_pose=True,
                         create_leye_pose=True,
                         create_reye_pose=True,
                         create_transl=True)

    ## create a coordinate frame at the camera location
    # mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # mesh_frame.transform(trans)
    # vis.add_geometry(mesh_frame)
    # vis.update_geometry()
    # print(trans)

    ## create a coordinate frame at the world origin
    # mesh_frame2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # vis.add_geometry(mesh_frame2)
    # vis.update_geometry()
    # print(trans)

    cv2.namedWindow("GUI")
    gen_file_list = glob.glob(os.path.join(args.gen_folder, '*'))
    body = o3d.geometry.TriangleMesh()
    vis.add_geometry(body)

    for idx, gen_file in enumerate(gen_file_list):
        with open(gen_file, 'rb') as f:
            param = pickle.load(f)
        cam_ext = param['cam_ext'][0]
        cam_int = param['cam_int'][0]
        body_pose = vposer.decode(torch.tensor(param['body_pose']),
                                  output_type='aa').view(1, -1)

        torch_param = {}
        for key in param.keys():
            if key in ['body_pose', 'camera_rotation', 'camera_translation']:
                continue
            else:
                torch_param[key] = torch.tensor(param[key])

        output = model(return_verts=True, body_pose=body_pose, **torch_param)
        vertices = output.vertices.detach().cpu().numpy().squeeze()

        body.vertices = o3d.utility.Vector3dVector(vertices)
        body.triangles = o3d.utility.Vector3iVector(model.faces)
        body.vertex_normals = o3d.utility.Vector3dVector([])
        body.triangle_normals = o3d.utility.Vector3dVector([])
        body.compute_vertex_normals()

        T_mat = np.eye(4)
        T_mat[1, :] = np.array([0, -1, 0, 0])
        T_mat[2, :] = np.array([0, 0, -1, 0])
        trans = np.dot(cam_ext, T_mat)
        body.transform(trans)
        vis.update_geometry()

        # while True:
        #     vis.poll_events()
        #     vis.update_renderer()
        #     cv2.imshow("GUI", np.random.random([10, 10, 3]))
        #     # ctr = vis.get_view_control()
        #     # cam_param = ctr.convert_to_pinhole_camera_parameters()
        #     # print(cam_param.extrinsic)
        #     key = cv2.waitKey(15)
        #     if key == 27:
        #         break

        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param = update_cam(cam_param, trans)
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis, outfilename=os.path.join(
            outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))
        # vis.run()
        # capture_image(vis, outfilename=os.path.join(outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))

        ### setup rendering cam, depth capture, segmentation capture
        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param.extrinsic = trans2_dict[scene_name]
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis, outfilename=os.path.join(
            outimg_dir, 'img_{:06d}_cam2.png'.format(idx)))
def main(args): fitting_dir = args.fitting_dir recording_name = os.path.abspath(fitting_dir).split("/")[-1] fitting_dir = osp.join(fitting_dir, 'results') data_dir = args.data_dir cam2world_dir = osp.join(data_dir, 'cam2world') scene_dir = osp.join(data_dir, 'scenes_semantics') recording_dir = osp.join(data_dir, 'recordings', recording_name) scene_name = os.path.abspath(recording_dir).split("/")[-1].split("_")[0] ## setup the output folder output_folder = os.path.join('/mnt/hdd/PROX', 'snapshot_virtualcam_TNoise0.5') if not os.path.exists(output_folder): os.mkdir(output_folder) ### setup visualization window vis = o3d.visualization.Visualizer() vis.create_window(width=480, height=270, visible=True) render_opt = vis.get_render_option().mesh_show_back_face = True ### put the scene into the visualizer scene = o3d.io.read_triangle_mesh( osp.join(scene_dir, scene_name + '_withlabels.ply')) vis.add_geometry(scene) ## get scene 3D scene bounding box scene_o = o3d.io.read_triangle_mesh( osp.join(scene_dir, scene_name + '.ply')) scene_min = scene_o.get_min_bound() #[x_min, y_min, z_min] scene_max = scene_o.get_max_bound() #[x_max, y_max, z_max] # reduce the scene region furthermore, to avoid cams behind the window shift = 0.7 scene_min = scene_min + shift scene_max = scene_max - shift ### get the real camera config trans_calib = np.eye(4) with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f: trans_calib = np.array(json.load(f)) ## put the body into the environment vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0') vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') model = smplx.create(args.model_folder, model_type='smplx', gender='neutral', ext='npz', num_pca_comps=12, create_global_orient=True, create_body_pose=True, create_betas=True, create_left_hand_pose=True, create_right_hand_pose=True, create_expression=False, create_jaw_pose=False, create_leye_pose=False, create_reye_pose=False, create_transl=True) rec_count = -1 sample_rate = 15 # 0.5second for img_name in sorted(os.listdir(fitting_dir))[::sample_rate]: ## get humam body params filename = osp.join(fitting_dir, img_name, '000.pkl') print('frame: ' + filename) if not os.path.exists(filename): print('file does not exist. 
Continue') continue with open(osp.join(fitting_dir, img_name, '000.pkl'), 'rb') as f: body_dict = pickle.load(f) if np.sum(np.isnan(body_dict['body_pose'])) > 0: continue rec_count = rec_count + 1 ## save depth, semantics and render cam outname1 = os.path.join(output_folder, recording_name) if not os.path.exists(outname1): os.mkdir(outname1) ######################### then we obtain the virutal cam ################################ ## find world coordinate of the human body in the current frame body_params_W_list, dT = update_globalRT_for_smplx( body_dict, model, vposer, [trans_calib]) body_T_world = body_params_W_list[0]['transl'][0] + dT ## get virtual cams, and transform global_R and global_T to virtual cams new_cammat_ext_list0 = [] new_cammat_ext_list0 = get_new_cams(scene_name, s_min=scene_min, s_max=scene_max, body_T=body_T_world) random.shuffle(new_cammat_ext_list0) new_cammat_ext_list = new_cammat_ext_list0[:30] print('--obtain {:d} cams'.format(len(new_cammat_ext_list))) new_cammat_list = [invert_transform(x) for x in new_cammat_ext_list] body_params_new_list, _ = update_globalRT_for_smplx( body_params_W_list[0], model, vposer, new_cammat_list, delta_T=dT) #### capture depth and seg in new cams for idx_cam, cam_ext in enumerate(new_cammat_ext_list): ### save filename outname = os.path.join( outname1, 'rec_frame{:06d}_cam{:06d}.mat'.format(rec_count, idx_cam)) ## put the render cam to the real cam ctr = vis.get_view_control() cam_param = ctr.convert_to_pinhole_camera_parameters() cam_param = update_render_cam(cam_param, cam_ext) ctr.convert_from_pinhole_camera_parameters(cam_param) vis.poll_events() vis.update_renderer() ## get render cam parameters cam_dict = {} cam_dict['extrinsic'] = cam_param.extrinsic cam_dict['intrinsic'] = cam_param.intrinsic.intrinsic_matrix ## capture depth image depth = np.asarray(vis.capture_depth_float_buffer(do_render=True)) _h = depth.shape[0] _w = depth.shape[1] depth0 = depth depth_canvas, scaling_factor = data_preprocessing(depth, 'depth') ### skip settings when the human body is severely occluded. body_is_occluded = is_body_occluded(body_params_new_list[idx_cam], cam_dict, depth) if body_is_occluded: print( '-- body is occluded or not in the scene at current view.') continue ## capture semantics seg = np.asarray(vis.capture_screen_float_buffer(do_render=True)) verid = np.mean(seg * 255 / 5.0, axis=-1) #.astype(int) seg0 = verid # verid = cv2.resize(verid, (_w//factor, _h//factor)) seg_canvas, _ = data_preprocessing(verid, 'seg') # pdb.set_trace() ## save file to disk ot_dict = {} ot_dict['scaling_factor'] = scaling_factor ot_dict['depth'] = depth_canvas ot_dict['depth0'] = depth0 ot_dict['seg0'] = seg0 ot_dict['seg'] = seg_canvas ot_dict['cam'] = cam_dict ot_dict['body'] = body_params_new_list[idx_cam] sio.savemat(outname, ot_dict) vis.destroy_window()
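`invert_transform` above inverts the rigid 4x4 camera matrices before re-expressing the body parameters; a minimal sketch consistent with that usage (assuming pure rotation plus translation, so no general matrix inverse is needed):

import numpy as np

def invert_transform(T):
    # Sketch of a rigid-transform inverse: [R | t]^-1 = [R^T | -R^T t].
    R, t = T[:3, :3], T[:3, 3]
    T_inv = np.eye(4)
    T_inv[:3, :3] = R.T
    T_inv[:3, 3] = -R.T @ t
    return T_inv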
if mode == "smplx": model_params = dict( model_path="assets/models", model_type="smplx", gender="female", # create_body_pose=True, dtype=torch.float32, use_face=False, ) model = smplx.create(**model_params) model.cuda() else: bm_path = "assets/models/smplx/SMPLX_FEMALE.npz" bm = BodyModel(bm_path=bm_path, batch_size=1).to("cuda") bm.cuda() vp, ps = load_vposer("assets/vposer") vp = vp.to("cuda") vp.eval() # Sample a 32 dimentional vector from a Normal distribution poZ_body_sample = torch.zeros(1, 32).cuda() pose_body = vp.decode(poZ_body_sample, output_type="aa").view(-1, 63) pose_embedding = torch.zeros( [1, 32], requires_grad=True, device=poZ_body_sample.device ) pose_embedding = torch.rand( [1, 32], requires_grad=True, device=poZ_body_sample.device )
def fit_single_frame( img, keypoints, init_trans, scan, scene_name, body_model, camera, joint_weights, body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior, shape_prior, expr_prior, angle_prior, result_fn='out.pkl', mesh_fn='out.obj', body_scene_rendering_fn='body_scene.png', out_img_fn='overlay.png', loss_type='smplify', use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=True, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=True, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn='', focal_length_x=5000., focal_length_y=5000., side_view_thsh=25., rho=100, vposer_latent_dim=32, vposer_ckpt='', use_joints_conf=False, interactive=True, visualize=False, save_meshes=True, degrees=None, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, #################### ### PROX render_results=True, camera_mode='moving', ## Depth s2m=False, s2m_weights=None, m2s=False, m2s_weights=None, rho_s2m=1, rho_m2s=1, init_mode=None, trans_opt_stages=None, viz_mode='mv', #penetration sdf_penetration=False, sdf_penetration_weights=0.0, sdf_dir=None, cam2world_dir=None, #contact contact=False, rho_contact=1.0, contact_loss_weights=None, contact_angle=15, contact_body_parts=None, body_segments_dir=None, load_scene=False, scene_dir=None, **kwargs): assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' body_model.reset_params() body_model.transl.requires_grad = True device = torch.device('cuda') if use_cuda else torch.device('cpu') if visualize: pil_img.fromarray((img * 255).astype(np.uint8)).show() if degrees is None: degrees = [0, 90, 180, 270] if data_weights is None: data_weights = [ 1, ] * 5 if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] msg = ('Number of Body pose prior weights {}'.format( len(body_pose_prior_weights)) + ' does not match the number of data term weights {}'.format( len(data_weights))) assert (len(data_weights) == len(body_pose_prior_weights)), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights does not match the' + ' number of hand pose prior weights') assert ( len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of hand joint distance weights') assert ( len(hand_joints_weights) == len(body_pose_prior_weights)), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Shape prior weights = {}') assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format( len(shape_weights), len(body_pose_prior_weights)) if use_face: if jaw_pose_prior_weights is None: jaw_pose_prior_weights = [[x] * 3 for x in shape_weights] else: jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')), jaw_pose_prior_weights) jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights] msg = ('Number of Body pose prior weights does not match the' + ' number of jaw pose prior weights') assert ( len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), 
msg if expr_weights is None: expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Expression prior weights = {}') assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format( len(body_pose_prior_weights), len(expr_weights)) if face_joints_weights is None: face_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of face joint distance weights') assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ('Number of Body pose prior weights does not match the' + ' number of collision loss weights') assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg use_vposer = kwargs.get('use_vposer', True) vposer, pose_embedding = [ None, ] * 2 if use_vposer: pose_embedding = torch.zeros([batch_size, 32], dtype=dtype, device=device, requires_grad=True) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype) else: body_mean_pose = body_pose_prior.get_mean().detach().cpu() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :2] if use_joints_conf: joints_conf = keypoint_data[:, :, 2].reshape(1, -1) # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) if use_joints_conf: joints_conf = joints_conf.to(device=device, dtype=dtype) scan_tensor = None if scan is not None: scan_tensor = torch.tensor(scan.get('points'), device=device, dtype=dtype).unsqueeze(0) # load pre-computed signed distance field sdf = None sdf_normals = None grid_min = None grid_max = None voxel_size = None if sdf_penetration: with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f: sdf_data = json.load(f) grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device) grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device) grid_dim = sdf_data['dim'] voxel_size = (grid_max - grid_min) / grid_dim sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape( grid_dim, grid_dim, grid_dim) sdf = torch.tensor(sdf, dtype=dtype, device=device) if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')): sdf_normals = np.load( osp.join(sdf_dir, scene_name + '_normals.npy')).reshape( grid_dim, grid_dim, grid_dim, 3) sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device) else: print("Normals not found...") with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f: cam2world = np.array(json.load(f)) R = torch.tensor(cam2world[:3, :3].reshape(3, 3), dtype=dtype, device=device) t = torch.tensor(cam2world[:3, 3].reshape(1, 3), dtype=dtype, device=device) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! 
Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # load vertix ids of contact parts contact_verts_ids = ftov = None if contact: contact_verts_ids = [] for part in contact_body_parts: with open(os.path.join(body_segments_dir, part + '.json'), 'r') as f: data = json.load(f) contact_verts_ids.append(list(set(data["verts_ind"]))) contact_verts_ids = np.concatenate(contact_verts_ids) vertices = body_model(return_verts=True, body_pose=torch.zeros((batch_size, 63), dtype=dtype, device=device)).vertices vertices_np = vertices.detach().cpu().numpy().squeeze() body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape( -1, 3) m = Mesh(v=vertices_np, f=body_faces_np) ftov = m.faces_by_vertex(as_sparse_matrix=True) ftov = sparse.coo_matrix(ftov) indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device) values = torch.FloatTensor(ftov.data).to(device) shape = ftov.shape ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape)) # Read the scene scan if any scene_v = scene_vn = scene_f = None if scene_name is not None: if load_scene: scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply')) scene.vn = scene.estimate_vertex_normals() scene_v = torch.tensor(scene.v[np.newaxis, :], dtype=dtype, device=device).contiguous() scene_vn = torch.tensor(scene.vn[np.newaxis, :], dtype=dtype, device=device) scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :], dtype=torch.long, device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = { 'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights } if use_face: opt_weights_dict['face_weight'] = face_joints_weights opt_weights_dict['expr_prior_weight'] = expr_weights opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights if use_hands: opt_weights_dict['hand_weight'] = hand_joints_weights opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights if s2m: opt_weights_dict['s2m_weight'] = s2m_weights if m2s: opt_weights_dict['m2s_weight'] = m2s_weights if sdf_penetration: opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights if contact: opt_weights_dict['contact_loss_weight'] = contact_loss_weights keys = opt_weights_dict.keys() opt_weights = [ dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None)) ] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # load indices of the head of smpl-x model with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp: head_indx = np.array(json.load(fp)) N = body_model.get_num_verts() body_indx = np.setdiff1d(np.arange(N), head_indx) head_mask = np.in1d(np.arange(N), head_indx) 
body_mask = np.in1d(np.arange(N), body_indx) # The indices of the joints used for the initialization of the camera init_joints_idxs = torch.tensor(init_joints_idxs, device=device) edge_indices = kwargs.get('body_tri_idxs') # which initialization mode to choose: similar traingles, mean of the scan or the average of both if init_mode == 'scan': init_t = init_trans elif init_mode == 'both': init_t = (init_trans.to(device) + fitting.guess_init( body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get('model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype)) / 2.0 else: init_t = fitting.guess_init(body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get( 'model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype) camera_loss = fitting.create_loss('camera_init', trans_estimation=init_t, init_joints_idxs=init_joints_idxs, depth_loss_weight=depth_loss_weight, camera_mode=camera_mode, dtype=dtype).to(device=device) camera_loss.trans_estimation[:] = init_t loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, s2m=s2m, m2s=m2s, rho_s2m=rho_s2m, rho_m2s=rho_m2s, head_mask=head_mask, body_mask=body_mask, sdf_penetration=sdf_penetration, voxel_size=voxel_size, grid_min=grid_min, grid_max=grid_max, sdf=sdf, sdf_normals=sdf_normals, R=R, t=t, contact=contact, contact_verts_ids=contact_verts_ids, rho_contact=rho_contact, contact_angle=contact_angle, dtype=dtype, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, viz_mode=viz_mode, **kwargs) as monitor: img = torch.tensor(img, dtype=dtype) H, W, _ = img.shape # Reset the parameters to estimate the initial translation of the # body model if camera_mode == 'moving': body_model.reset_params(body_pose=body_mean_pose) # Update the value of the translation of the camera as well as # the image center. 
with torch.no_grad(): camera.translation[:] = init_t.view_as(camera.translation) camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5 # Re-enable gradient calculation for the camera translation camera.translation.requires_grad = True camera_opt_params = [camera.translation, body_model.global_orient] elif camera_mode == 'fixed': body_model.reset_params(body_pose=body_mean_pose, transl=init_t) camera_opt_params = [body_model.transl, body_model.global_orient] # If the distance between the 2D shoulders is smaller than a # predefined threshold then try 2 fits, the initial one and a 180 # degree rotation shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx], gt_joints[:, right_shoulder_idx]) try_both_orient = shoulder_dist.item() < side_view_thsh camera_optimizer, camera_create_graph = optim_factory.create_optimizer( camera_opt_params, **kwargs) # The closure passed to the optimizer fit_camera = monitor.create_fitting_closure( camera_optimizer, body_model, camera, gt_joints, camera_loss, create_graph=camera_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, return_full_pose=False, return_verts=False) # Step 1: Optimize over the torso joints the camera translation # Initialize the computational graph by feeding the initial translation # of the camera and the initial pose of the body model. camera_init_start = time.time() cam_init_loss_val = monitor.run_fitting(camera_optimizer, fit_camera, camera_opt_params, body_model, use_vposer=use_vposer, pose_embedding=pose_embedding, vposer=vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() tqdm.write('Camera initialization done after {:.4f}'.format( time.time() - camera_init_start)) tqdm.write('Camera initialization final loss {:.4f}'.format( cam_init_loss_val)) # If the 2D detections/positions of the shoulder joints are too # close the rotate the body by 180 degrees and also fit to that # orientation if try_both_orient: body_orient = body_model.global_orient.detach().cpu().numpy() flipped_orient = cv2.Rodrigues(body_orient)[0].dot( cv2.Rodrigues(np.array([0., np.pi, 0]))[0]) flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel() flipped_orient = torch.tensor(flipped_orient, dtype=dtype, device=device).unsqueeze(dim=0) orientations = [body_orient, flipped_orient] else: orientations = [body_model.global_orient.detach().cpu().numpy()] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] body_transl = body_model.transl.clone().detach() # Step 2: Optimize the full model final_loss_val = 0 for or_idx, orient in enumerate(tqdm(orientations, desc='Orientation')): opt_start = time.time() new_params = defaultdict(transl=body_transl, global_orient=orient, body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) for opt_idx, curr_weights in enumerate( tqdm(opt_weights, desc='Stage')): if opt_idx not in trans_opt_stages: body_model.transl.requires_grad = False else: body_model.transl.requires_grad = True body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) if use_hands: joint_weights[:, 25:76] = 
curr_weights['hand_weight'] if use_face: joint_weights[:, 76:] = curr_weights['face_weight'] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, scene_v=scene_v, scene_vn=scene_vn, scene_f=scene_f, ftov=ftov, return_verts=True, return_full_pose=True) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write( 'Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write( 'Body fitting Orientation {} done after {:.4f} seconds'. format(or_idx, elapsed)) tqdm.write( 'Body final loss val = {:.5f}'.format(final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = { 'camera_' + str(key): val.detach().cpu().numpy() for key, val in camera.named_parameters() } result.update({ key: val.detach().cpu().numpy() for key, val in body_model.named_parameters() }) if use_vposer: result['pose_embedding'] = pose_embedding.detach().cpu().numpy( ) body_pose = vposer.decode(pose_embedding, output_type='aa').view( 1, -1) if use_vposer else None result['body_pose'] = body_pose.detach().cpu().numpy() results.append({'loss': final_loss_val, 'result': result}) with open(result_fn, 'wb') as result_file: if len(results) > 1: min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1) else: min_idx = 0 pickle.dump(results[min_idx]['result'], result_file, protocol=2) if save_meshes or visualize: body_pose = vposer.decode(pose_embedding, output_type='aa').view( 1, -1) if use_vposer else None model_type = kwargs.get('model_type', 'smpl') append_wrists = model_type == 'smpl' and use_vposer if append_wrists: wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device) body_pose = torch.cat([body_pose, wrist_pose], dim=1) model_output = body_model(return_verts=True, body_pose=body_pose) vertices = model_output.vertices.detach().cpu().numpy().squeeze() import trimesh out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False) out_mesh.export(mesh_fn) if render_results: import pyrender # common H, W = 1080, 1920 camera_center = np.array([951.30, 536.77]) camera_pose = np.eye(4) camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1, 1) * camera_pose camera = pyrender.camera.IntrinsicsCamera(fx=1060.53, fy=1060.38, cx=camera_center[0], cy=camera_center[1]) light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0) material = pyrender.MetallicRoughnessMaterial( metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0)) body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material) ## rendering body img = img.detach().cpu().numpy() H, W, _ = img.shape scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3)) scene.add(camera, pose=camera_pose) scene.add(light, 
pose=camera_pose) # for node in light_nodes: # scene.add_node(node) scene.add(body_mesh, 'mesh') r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H, point_size=1.0) color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA) color = color.astype(np.float32) / 255.0 valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] input_img = img output_img = (color[:, :, :-1] * valid_mask + (1 - valid_mask) * input_img) img = pil_img.fromarray((output_img * 255).astype(np.uint8)) img.save(out_img_fn) ##redering body+scene body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material) static_scene = trimesh.load(osp.join(scene_dir, scene_name + '.ply')) trans = np.linalg.inv(cam2world) static_scene.apply_transform(trans) static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene) scene = pyrender.Scene() scene.add(camera, pose=camera_pose) scene.add(light, pose=camera_pose) scene.add(static_scene_mesh, 'mesh') scene.add(body_mesh, 'mesh') r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H) color, _ = r.render(scene) color = color.astype(np.float32) / 255.0 img = pil_img.fromarray((color * 255).astype(np.uint8)) img.save(body_scene_rendering_fn)
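The SDF grid loaded above is consumed inside `fitting.create_loss`; as a standalone illustration of the mechanics (a sketch, not the repo's exact loss), penetration can be scored by trilinearly sampling the SDF at the body vertices with `F.grid_sample` and penalizing negative values, mirroring the grid-sampling pattern used later in this document:

import torch
import torch.nn.functional as F

def sdf_penetration_penalty(vertices, sdf, grid_min, grid_max):
    # vertices: [1, N, 3] in scene coordinates; sdf: [D, D, D] tensor;
    # grid_min / grid_max: [3] tensors bounding the SDF grid.
    norm_v = (vertices - grid_min) / (grid_max - grid_min) * 2 - 1
    n = norm_v.shape[1]
    # grid_sample indexes (x, y, z) in reverse order of the array axes
    sdf_vals = F.grid_sample(sdf[None, None],
                             norm_v[:, :, [2, 1, 0]].view(1, n, 1, 1, 3),
                             padding_mode='border', align_corners=True)
    inside = sdf_vals[sdf_vals < 0]
    return inside.abs().sum() if inside.numel() > 0 else vertices.new_zeros(())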
def __init__(
    self,
    debug=True,
    camintr=None,
    hand_pca_nb=6,
    head_center_idx=8949,
    opt_weights={
        # "hand_prior_weight": [1e2, 5 * 1e1, 1e1, 0.5 * 1e1],
        "hand_prior_weight": [0, 0, 0, 0],
        "hand_weight": [0.0, 0.0, 0.0, 1.0],
        # "body_pose_weight": [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78],
        "body_pose_weight": [0, 0, 0, 0],
        # "shape_weight": [1e2, 5 * 1e1, 1e1, 0.5 * 1e1],
        "shape_weight": [0, 0, 0, 0],
    },
    mano_corresp_path="assets/models/MANO_SMPLX_vertex_ids.pkl",
    smpl_root="assets/models",
    vposer_dir="assets/vposer",
    vposer_dim=32,
    data_weight=1000 / 256,
    parts_path="assets/models/smplx/smplx_parts_segm.pkl",
):
    self.debug = debug
    self.hand_pca_nb = hand_pca_nb
    self.head_center_idx = head_center_idx
    self.smplx_vertex_nb = 10475
    self.vposer_dim = vposer_dim
    # Optim weights
    self.opt_weights = opt_weights
    self.data_weight = data_weight
    # Initialize SMPL-X model
    self.smpl_model = vposeutils.get_smplx(model_root=smpl_root, vposer_dir=vposer_dir)
    self.smpl_f = self.smpl_model.faces
    with open(mano_corresp_path, "rb") as p_f:
        self.mano_corresp = pickle.load(p_f)
    # Get vposer
    self.vposer = load_vposer(vposer_dir, vp_model="snapshot")[0]
    self.vposer.eval()
    self.tareader = TarReader()
    # Translate human so that head is at camera level
    self.set_head2cam_trans()
    self.armfaces = smplvis.filter_parts(self.smpl_f, parts_path)
    self.joint_weights = initialize.get_joint_weights()
    fx, fy = camintr[0, 0], camintr[1, 1]
    center = camintr[:2, 2]
    rot = torch.eye(3)
    rot[0, 0] = -1
    rot[1, 1] = -1
    self.camera = camera.create_camera(
        focal_length_x=fx,
        focal_length_y=fy,
        center=torch.Tensor(center).unsqueeze(0),
        rotation=rot.unsqueeze(0),
    )
    # Initialize priors
    self.body_pose_prior = prior.create_prior(prior_type="l2")
    self.left_hand_prior = prior.create_prior(
        prior_type="l2", use_left_hand=True, num_gaussians=hand_pca_nb,
    )
    self.right_hand_prior = prior.create_prior(
        prior_type="l2", use_right_hand=True, num_gaussians=hand_pca_nb,
    )
    self.angle_prior = prior.create_prior(prior_type="angle")
    self.shape_prior = prior.create_prior(prior_type="l2")
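`prior.create_prior` comes from the smplify-x family of codebases; for the "l2" type it amounts to a squared-norm penalty on the given parameter block. A minimal stand-in sketch:

import torch

class L2Prior(torch.nn.Module):
    # Hypothetical stand-in for prior.create_prior(prior_type="l2"):
    # penalize the squared norm of a pose/shape parameter block.
    def forward(self, module_input, *args, **kwargs):
        return torch.sum(module_input.pow(2))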
    body_rotmat = R.from_rotvec(body_rot_angle).as_dcm()  # [b, 3, 3] rotation matrices (scipy<1.4 API)
    body_transf = np.tile(np.eye(4), (batch_size, 1, 1))
    body_transf[:, :-1, :-1] = body_rotmat
    body_transf[:, :-1, -1] = body_transl
    body_transf_w = np.matmul(trans_to_target_origin, body_transf)
    body_params_dict['global_orient'] = R.from_dcm(body_transf_w[:, :-1, :-1]).as_rotvec()
    body_params_dict['transl'] = body_transf_w[:, :-1, -1] - delta_T
    return body_params_dict

## figure out body model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vposer_model, _ = load_vposer('/home/yzhang/body_models/VPoser/vposer_v1_0/', vp_model='snapshot')
vposer_model = vposer_model.to(device)
smplx_model = smplx.create('/home/yzhang/body_models/VPoser/',
                           model_type='smplx',
                           gender='neutral', ext='npz',
                           num_pca_comps=12,
                           create_global_orient=True,
                           create_body_pose=True,
                           create_betas=True,
                           create_left_hand_pose=True,
                           create_right_hand_pose=True,
                           create_expression=True,
                           create_jaw_pose=True,
                           create_leye_pose=True,
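Note that `as_dcm`/`from_dcm` are the pre-1.4 scipy names and were removed in scipy 1.6; on current scipy the same round trip reads as follows:

import numpy as np
from scipy.spatial.transform import Rotation as R

body_rot_angle = np.zeros((4, 3))                        # e.g. a batch of axis-angle vectors
body_rotmat = R.from_rotvec(body_rot_angle).as_matrix()  # was .as_dcm()
back_to_aa = R.from_matrix(body_rotmat).as_rotvec()      # was R.from_dcm(...)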
def fit_single_frame(img, keypoints, body_model, camera, joint_weights, body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior, shape_prior, expr_prior, angle_prior, result_fn='out.pkl', mesh_fn='out.obj', out_img_fn='overlay.png', loss_type='smplify', use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=True, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=True, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn='', focal_length=5000., side_view_thsh=25., rho=100, vposer_latent_dim=32, vposer_ckpt='', use_joints_conf=False, interactive=True, visualize=False, save_meshes=True, degrees=None, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, **kwargs): assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' device = torch.device('cuda') if use_cuda else torch.device('cpu') if degrees is None: degrees = [0, 90, 180, 270] if data_weights is None: data_weights = [ 1, ] * 5 if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] msg = ('Number of Body pose prior weights {}'.format( len(body_pose_prior_weights)) + ' does not match the number of data term weights {}'.format( len(data_weights))) assert (len(data_weights) == len(body_pose_prior_weights)), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights does not match the' + ' number of hand pose prior weights') assert ( len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of hand joint distance weights') assert ( len(hand_joints_weights) == len(body_pose_prior_weights)), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Shape prior weights = {}') assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format( len(shape_weights), len(body_pose_prior_weights)) if use_face: if jaw_pose_prior_weights is None: jaw_pose_prior_weights = [[x] * 3 for x in shape_weights] else: jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')), jaw_pose_prior_weights) jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights] msg = ('Number of Body pose prior weights does not match the' + ' number of jaw pose prior weights') assert ( len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg if expr_weights is None: expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Expression prior weights = {}') assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format( len(body_pose_prior_weights), len(expr_weights)) if face_joints_weights is None: face_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of face joint distance weights') assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ('Number of Body pose prior weights does not match the' + ' 
number of collision loss weights') assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg use_vposer = kwargs.get('use_vposer', True) vposer, pose_embedding = [ None, ] * 2 if use_vposer: pose_embedding = torch.zeros([batch_size, 32], dtype=dtype, device=device, requires_grad=True) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype) else: body_mean_pose = body_pose_prior.get_mean().detach().cpu() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :2] if use_joints_conf: joints_conf = keypoint_data[:, :, 2].reshape(1, -1) # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) if use_joints_conf: joints_conf = joints_conf.to(device=device, dtype=dtype) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = { 'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights } if use_face: opt_weights_dict['face_weight'] = face_joints_weights opt_weights_dict['expr_prior_weight'] = expr_weights opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights if use_hands: opt_weights_dict['hand_weight'] = hand_joints_weights opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights keys = opt_weights_dict.keys() opt_weights = [ dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None)) ] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # The indices of the joints used for the initialization of the camera init_joints_idxs = torch.tensor(init_joints_idxs, device=device) edge_indices = kwargs.get('body_tri_idxs') init_t = fitting.guess_init(body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get('model_type', 'smpl'), focal_length=focal_length, dtype=dtype) camera_loss = fitting.create_loss('camera_init', trans_estimation=init_t, init_joints_idxs=init_joints_idxs, depth_loss_weight=depth_loss_weight, dtype=dtype).to(device=device) camera_loss.trans_estimation[:] = init_t loss = 
fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, dtype=dtype, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, **kwargs) as monitor: img = torch.tensor(img, dtype=dtype) H, W, _ = img.shape data_weight = 1000 / H # The closure passed to the optimizer camera_loss.reset_loss_weights({'data_weight': data_weight}) # Reset the parameters to estimate the initial translation of the # body model body_model.reset_params(body_pose=body_mean_pose) # If the distance between the 2D shoulders is smaller than a # predefined threshold then try 2 fits, the initial one and a 180 # degree rotation shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx], gt_joints[:, right_shoulder_idx]) try_both_orient = shoulder_dist.item() < side_view_thsh # Update the value of the translation of the camera as well as # the image center. with torch.no_grad(): camera.translation[:] = init_t.view_as(camera.translation) camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5 # Re-enable gradient calculation for the camera translation camera.translation.requires_grad = True camera_opt_params = [camera.translation, body_model.global_orient] camera_optimizer, camera_create_graph = optim_factory.create_optimizer( camera_opt_params, **kwargs) # The closure passed to the optimizer fit_camera = monitor.create_fitting_closure( camera_optimizer, body_model, camera, gt_joints, camera_loss, create_graph=camera_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, return_full_pose=False, return_verts=False) # Step 1: Optimize over the torso joints the camera translation # Initialize the computational graph by feeding the initial translation # of the camera and the initial pose of the body model. 
camera_init_start = time.time() cam_init_loss_val = monitor.run_fitting(camera_optimizer, fit_camera, camera_opt_params, body_model, use_vposer=use_vposer, pose_embedding=pose_embedding, vposer=vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() tqdm.write('Camera initialization done after {:.4f}'.format( time.time() - camera_init_start)) tqdm.write('Camera initialization final loss {:.4f}'.format( cam_init_loss_val)) # If the 2D detections/positions of the shoulder joints are too # close the rotate the body by 180 degrees and also fit to that # orientation if try_both_orient: body_orient = body_model.global_orient.detach().cpu().numpy() flipped_orient = cv2.Rodrigues(body_orient)[0].dot( cv2.Rodrigues(np.array([0., np.pi, 0]))[0]) flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel() flipped_orient = torch.tensor(flipped_orient, dtype=dtype, device=device).unsqueeze(dim=0) orientations = [body_orient, flipped_orient] else: orientations = [body_model.global_orient.detach().cpu().numpy()] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] # Step 2: Optimize the full model final_loss_val = 0 for or_idx, orient in enumerate(tqdm(orientations, desc='Orientation')): opt_start = time.time() new_params = defaultdict(global_orient=orient, body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) for opt_idx, curr_weights in enumerate( tqdm(opt_weights, desc='Stage')): body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['data_weight'] = data_weight curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) if use_hands: joint_weights[:, 25:76] = curr_weights['hand_weight'] if use_face: joint_weights[:, 76:] = curr_weights['face_weight'] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, return_verts=True, return_full_pose=True) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write( 'Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write( 'Body fitting Orientation {} done after {:.4f} seconds'. 
format(or_idx, elapsed)) tqdm.write( 'Body final loss val = {:.5f}'.format(final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = { 'camera_' + str(key): val.detach().cpu().numpy() for key, val in camera.named_parameters() } result.update({ key: val.detach().cpu().numpy() for key, val in body_model.named_parameters() }) if use_vposer: result['body_pose'] = pose_embedding.detach().cpu().numpy() results.append({'loss': final_loss_val, 'result': result}) with open(result_fn, 'wb') as result_file: if len(results) > 1: min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1) else: min_idx = 0 pickle.dump(results[min_idx]['result'], result_file, protocol=2) if save_meshes or visualize: body_pose = vposer.decode(pose_embedding, output_type='aa').view( 1, -1) if use_vposer else None model_type = kwargs.get('model_type', 'smpl') append_wrists = model_type == 'smpl' and use_vposer if append_wrists: wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device) body_pose = torch.cat([body_pose, wrist_pose], dim=1) model_output = body_model(return_verts=True, body_pose=body_pose) vertices = model_output.vertices.detach().cpu().numpy().squeeze() import trimesh out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False) rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]) out_mesh.apply_transform(rot) out_mesh.export(mesh_fn) if visualize: import pyrender material = pyrender.MetallicRoughnessMaterial( metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0)) mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material) scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3)) scene.add(mesh, 'mesh') camera_center = camera.center.detach().cpu().numpy().squeeze() camera_transl = camera.translation.detach().cpu().numpy().squeeze() # Equivalent to 180 degrees around the y-axis. Transforms the fit to # OpenGL compatible coordinate system. camera_transl[0] *= -1.0 camera_pose = np.eye(4) camera_pose[:3, 3] = camera_transl camera = pyrender.camera.IntrinsicsCamera(fx=focal_length, fy=focal_length, cx=camera_center[0], cy=camera_center[1]) scene.add(camera, pose=camera_pose) # Get the lights from the viewer light_nodes = monitor.mv.viewer._create_raymond_lights() for node in light_nodes: scene.add_node(node) r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H, point_size=1.0) color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA) color = color.astype(np.float32) / 255.0 valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] input_img = img.detach().cpu().numpy() output_img = (color[:, :, :-1] * valid_mask + (1 - valid_mask) * input_img) img = pil_img.fromarray((output_img * 255).astype(np.uint8)) img.save(out_img_fn)
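The `opt_weights` construction above zips the per-stage weight lists into one dict per optimization stage; a small standalone illustration of the same idiom:

opt_weights_dict = {
    'data_weight': [1, 1, 1, 1],
    'body_pose_weight': [4.04e2, 4.04e2, 57.4, 4.78],
    'shape_weight': [1e2, 5e1, 1e1, 5.0],
}
keys = opt_weights_dict.keys()
opt_weights = [dict(zip(keys, vals))
               for vals in zip(*(opt_weights_dict[k] for k in keys))]
# opt_weights[2] -> {'data_weight': 1, 'body_pose_weight': 57.4, 'shape_weight': 10.0}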
def optimize(): scene_mesh, cur_scene_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch = read_mesh_sdf(args.dataset_path, args.dataset, args.scene_name) save_path = '{}/{}/{}'.format(args.save_path, args.dataset, args.scene_name) if not os.path.exists(save_path): os.makedirs(save_path) smplx_model = smplx.create(args.smplx_model_path, model_type='smplx', gender='neutral', ext='npz', num_pca_comps=12, create_global_orient=True, create_body_pose=True, create_betas=True, create_left_hand_pose=True, create_right_hand_pose=True, create_expression=True, create_jaw_pose=True, create_leye_pose=True, create_reye_pose=True, create_transl=True, batch_size=1 ).to(device) print('[INFO] smplx model loaded.') vposer_model, _ = load_vposer(args.vposer_model_path, vp_model='snapshot') vposer_model = vposer_model.to(device) print('[INFO] vposer model loaded') ####################### calculate scene bps representation ################## rot_angle, scene_min_x, scene_max_x, scene_min_y, scene_max_y = define_scene_boundary(args.dataset, args.scene_name) scene_verts_crop_local_list, scene_verts_local_list, = [], [] shift_list = [] rot_angle_list_1, rot_angle_list_2 = [], [] np.random.seed(0) random_seed_list = np.random.randint(10000, size=args.n_sample) for i in tqdm(range(args.n_sample)): scene_verts = rotate_scene_smplx_predefine(cur_scene_verts, rot_angle=rot_angle) scene_verts_local, scene_verts_crop_local, shift = crop_scene_cube_smplx_predifine( scene_verts, r=args.cube_size, with_wall_ceilling=True, random_seed=random_seed_list[i], scene_min_x=scene_min_x, scene_max_x=scene_max_x, scene_min_y=scene_min_y, scene_max_y=scene_max_y, rotate=True) scene_verts_crop_local_list.append(scene_verts_crop_local) # list, different verts num for each cropped scene scene_verts_local_list.append(scene_verts_local) shift_list.append(shift) rot_angle_list_1.append(rot_angle) print('[INFO] scene mesh cropped and shifted.') scene_bps_list, body_bps_list = [], [] scene_bps_verts_global_list, scene_bps_verts_local_list = [], [] scene_bps_verts_global_list = [] scene_basis_set = bps_gen_ball_inside(n_bps=args.num_bps, random_seed=100) np.random.seed(1) random_seed_list = np.random.randint(10000, size=args.n_sample) for i in tqdm(range(args.n_sample)): scene_verts_global, scene_verts_crop_global, rot_angle = \ augmentation_crop_scene_smplx(scene_verts_local_list[i] / args.cube_size, scene_verts_crop_local_list[i] / args.cube_size, random_seed_list[i]) scene_bps, selected_scene_verts_global, selected_ind = bps_encode_scene(scene_basis_set, scene_verts_crop_global) # [n_feat, n_bps] scene_bps_list.append(scene_bps) selected_scene_verts_local = scene_verts_crop_local_list[i][selected_ind] scene_bps_verts_local_list.append(selected_scene_verts_local) scene_bps_verts_global_list.append(selected_scene_verts_global) rot_angle_list_2.append(rot_angle) scene_bps_list = np.asarray(scene_bps_list) # [n_sample*4, n_feat, n_bps] scene_bps_verts_local_list = np.asarray(scene_bps_verts_local_list) scene_verts_local_list = np.asarray(scene_verts_local_list) np.save('{}/scene_bps_list.npy'.format(save_path), scene_bps_list) print('[INFO] scene bps/verts saved.') ######################## set dataloader and load model ######################## dataset = TestLoader() dataset.n_samples = args.n_sample dataset.scene_bps_list = scene_bps_list # [n_sample, n_feat, n_bps] dataset.scene_bps_verts_list = scene_bps_verts_local_list print('[INFO] dataloader updated, select n_samples={}'.format(dataset.__len__())) test_dataloader = 
torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False) scene_bps_AE = BPSRecMLP(n_bps=args.num_bps, n_bps_feat=1, hsize1=1024, hsize2=512).to(device) weights = torch.load(args.scene_bps_AE_path, map_location=lambda storage, loc: storage) scene_bps_AE.load_state_dict(weights) c_VAE = BPS_CVAE(n_bps=args.num_bps, n_bps_feat=1, hsize1=1024, hsize2=512, eps_d=args.eps_d).to(device) weights = torch.load(args.cVAE_path, map_location=lambda storage, loc: storage) c_VAE.load_state_dict(weights) scene_AE = Verts_AE(n_bps=10000, hsize1=1024, hsize2=512).to(device) weights = torch.load(args.scene_verts_AE_path, map_location=lambda storage, loc: storage) scene_AE.load_state_dict(weights) body_dec = Body_Dec_shift(n_bps=10000, n_bps_feat=1, hsize1=1024, hsize2=512, n_body_verts=10475, body_param_dim=75, rec_goal='body_verts').to(device) weights = torch.load(args.bodyDec_path, map_location=lambda storage, loc: storage) body_dec.load_state_dict(weights) scene_bps_AE.eval() c_VAE.eval() scene_AE.eval() body_dec.eval() ######################## initialize for optimization ########################## body_verts_sample_list = [] body_bps_sample_list = [] np.random.seed(2) random_seed_list = np.random.randint(10000, size=args.n_sample) with torch.no_grad(): for step, data in tqdm(enumerate(test_dataloader)): [scene_bps, scene_bps_verts] = [item.to(device) for item in data] scene_bps_verts = scene_bps_verts / args.cube_size _, scene_bps_feat = scene_bps_AE(scene_bps) _, scene_bps_verts_feat = scene_AE(scene_bps_verts) # torch.manual_seed(random_seed_list[step]) body_bps_sample = c_VAE.sample(1, scene_bps_feat) # [1, 1, 10000] body_bps_sample_list.append(body_bps_sample[0].detach().cpu().numpy()) # [n, 1, 10000] body_verts_sample, body_shift = body_dec(body_bps_sample, scene_bps_verts_feat) # [1, 3, 10475], unit ball scale, local coordinate # shifted generated body body_shift = body_shift.repeat(1, 1, 10475).reshape([body_verts_sample.shape[0], 10475, 3]) # [bs, 10475, 3] body_verts_sample = body_verts_sample + body_shift.permute(0, 2, 1) # [bs, 3, 10475] body_verts_sample_list.append(body_verts_sample[0].detach().cpu().numpy()) # [n, 3, 10475] contact_part = ['L_Leg', 'R_Leg'] vid, _ = get_contact_id(body_segments_folder=os.path.join(args.prox_dataset_path, 'body_segments'), contact_body_parts=contact_part) ############################### save data ############################### shift_list = np.asarray(shift_list) rot_angle_list_1 = np.asarray(rot_angle_list_1) rot_angle_list_2 = np.asarray(rot_angle_list_2) body_bps_sample_list = np.asarray(body_bps_sample_list) body_verts_sample_list = np.asarray(body_verts_sample_list) np.save('{}/shift_list.npy'.format(save_path), shift_list) np.save('{}/rot_angle_list_1.npy'.format(save_path), rot_angle_list_1) np.save('{}/rot_angle_list_2.npy'.format(save_path), rot_angle_list_2) np.save('{}/body_bps_sample_list.npy'.format(save_path), body_bps_sample_list) # save generated body verts in original scale np.save('{}/body_verts_sample_list.npy'.format(save_path), body_verts_sample_list.transpose((0,2,1))*args.cube_size) ############################### optimization ################################## if args.optimize: #################### stage 1 (simple optimization, without contact/collision loss) ################### print('[INFO] start optimization stage 1...') body_params_opt_list_s1 = [] for cnt in range(args.n_sample): print('stage 1: current cnt:', cnt) body_params_rec = torch.randn(1, 72).to(device) # initiliza smplx 
params, bs=1, local coordinate system body_params_rec[0, 0] = 0.0 body_params_rec[0, 1] = 0.0 body_params_rec[0, 2] = 0.0 body_params_rec[0, 3] = 1.5 body_params_rec[0, 4] = 0.0 body_params_rec[0, 5] = 0.0 body_params_rec = convert_to_6D_rot(body_params_rec) body_params_rec.requires_grad = True optimizer = optim.Adam([body_params_rec], lr=0.1) body_bps = torch.from_numpy(body_bps_sample_list[cnt]).float().unsqueeze(0).to(device) # [bs=1, 1, 10000] body_verts = torch.from_numpy(body_verts_sample_list[cnt]).float().unsqueeze(0).to(device) # [bs=1, 3, 10475] body_verts = body_verts.permute(0, 2, 1) # [1, 10475, 3] body_verts = body_verts * args.cube_size # to local coordinate system scale for step in tqdm(range(args.itr_s1)): if step > 100: for param_group in optimizer.param_groups: param_group['lr'] = 0.01 if step > 300: for param_group in optimizer.param_groups: param_group['lr'] = 0.001 optimizer.zero_grad() body_params_rec_72 = convert_to_3D_rot(body_params_rec) # tensor, [bs=1, 72] body_pose_joint = vposer_model.decode(body_params_rec_72[:, 16:48], output_type='aa').view(1, -1) # tensor, [bs=1, 63] body_verts_rec = gen_body_mesh(body_params_rec_72, body_pose_joint, smplx_model)[0] # tensor, [n_body_vert, 3] # transform body verts to unit ball global coordinate system temp = body_verts_rec / args.cube_size # scale into unit ball body_verts_rec_global = torch.zeros(body_verts_rec.shape).to(device) body_verts_rec_global[:, 0] = temp[:, 0] * math.cos(rot_angle_list_2[cnt]) - \ temp[:, 1] * math.sin(rot_angle_list_2[cnt]) body_verts_rec_global[:, 1] = temp[:, 0] * math.sin(rot_angle_list_2[cnt]) + \ temp[:, 1] * math.cos(rot_angle_list_2[cnt]) body_verts_rec_global[:, 2] = temp[:, 2] # calculate optimized body bps feature body_bps_rec = torch.zeros(body_bps.shape) if args.weight_loss_rec_bps > 0: nbrs = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm="ball_tree").fit(body_verts_rec_global.detach().cpu().numpy()) neigh_dist, neigh_ind = nbrs.kneighbors(scene_bps_verts_global_list[cnt]) body_bps_rec = body_verts_rec_global[neigh_ind.squeeze()] - \ torch.from_numpy(scene_bps_verts_global_list[cnt]).float().to(device) # [n_bps, 3] body_bps_rec = torch.sqrt(body_bps_rec[:, 0] ** 2 + body_bps_rec[:, 1] ** 2 + body_bps_rec[:, 2] ** 2).unsqueeze(0).unsqueeze(0) # [bs=1, 1, n_bps] ### body bps feature reconstruct loss loss_rec_verts = F.l1_loss(body_verts_rec.unsqueeze(0), body_verts) loss_rec_bps = F.l1_loss(body_bps, body_bps_rec) ### vposer loss body_params_rec_72 = convert_to_3D_rot(body_params_rec) vposer_pose = body_params_rec_72[:, 16:48] loss_vposer = torch.mean(vposer_pose ** 2) ### shape prior loss shape_params = body_params_rec_72[:, 6:16] loss_shape = torch.mean(shape_params ** 2) ### hand pose prior loss hand_params = body_params_rec_72[:, 48:] loss_hand = torch.mean(hand_params ** 2) loss = args.weight_loss_rec_verts * loss_rec_verts + args.weight_loss_rec_bps * loss_rec_bps + \ args.weight_loss_vposer * loss_vposer + \ args.weight_loss_shape * loss_shape + \ args.weight_loss_hand * loss_hand loss.backward(retain_graph=True) optimizer.step() body_params_opt_list_s1.append(body_params_rec[0].detach().cpu().numpy()) body_params_opt_list_s1 = np.asarray(body_params_opt_list_s1) np.save('{}/body_params_opt_list_s1.npy'.format(save_path), body_params_opt_list_s1) ################ stage 2 (advanced optimization, with contact/collision loss) ################## print('[INFO] start optimization stage 2...') body_params_opt_list_s2 = [] for cnt in range(args.n_sample): print('current cnt:', 
cnt) body_params_rec = body_params_opt_list_s1[cnt] # [75] body_params_rec = torch.from_numpy(body_params_rec).float().to(device).unsqueeze(0) body_params_rec.requires_grad = True optimizer = optim.Adam([body_params_rec], lr=0.01) body_bps = torch.from_numpy(body_bps_sample_list[cnt]).float().unsqueeze(0).to(device) # [bs=1, 1, 10000] body_verts = torch.from_numpy(body_verts_sample_list[cnt]).float().unsqueeze(0).to(device) body_verts = body_verts.permute(0, 2, 1) # [1, 10475, 3] body_verts = body_verts * args.cube_size # to local coordinate system scale for step in tqdm(range(args.itr_s2)): optimizer.zero_grad() body_params_rec_72 = convert_to_3D_rot(body_params_rec) # tensor, [bs=1, 72] body_pose_joint = vposer_model.decode(body_params_rec_72[:, 16:48], output_type='aa').view(1,-1) # tensor, [bs=1, 63] body_verts_rec = gen_body_mesh(body_params_rec_72, body_pose_joint, smplx_model)[0] # tensor, [n_body_vert, 3] # transform body verts to unit ball global coordinate temp = body_verts_rec / args.cube_size # scale into unit ball body_verts_rec_global = torch.zeros(body_verts_rec.shape).to(device) body_verts_rec_global[:, 0] = temp[:, 0] * math.cos(rot_angle_list_2[cnt]) - \ temp[:, 1] * math.sin(rot_angle_list_2[cnt]) body_verts_rec_global[:, 1] = temp[:, 0] * math.sin(rot_angle_list_2[cnt]) + \ temp[:, 1] * math.cos(rot_angle_list_2[cnt]) body_verts_rec_global[:, 2] = temp[:, 2] # calculate body_bps_rec body_bps_rec = torch.zeros(body_bps.shape) if args.weight_loss_rec_bps > 0: nbrs = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm="ball_tree").fit(body_verts_rec_global.detach().cpu().numpy()) neigh_dist, neigh_ind = nbrs.kneighbors(scene_bps_verts_global_list[cnt]) body_bps_rec = body_verts_rec_global[neigh_ind.squeeze()] - \ torch.from_numpy(scene_bps_verts_global_list[cnt]).float().to(device) # [n_bps, 3] body_bps_rec = torch.sqrt(body_bps_rec[:, 0] ** 2 + body_bps_rec[:, 1] ** 2 + body_bps_rec[:, 2] ** 2).unsqueeze(0).unsqueeze(0) # [bs=1, 1, n_bps] ### body bps encoding reconstruct loss loss_rec_verts = F.l1_loss(body_verts_rec.unsqueeze(0), body_verts) loss_rec_bps = F.l1_loss(body_bps, body_bps_rec) ### vposer loss body_params_rec_72 = convert_to_3D_rot(body_params_rec) vposer_pose = body_params_rec_72[:, 16:48] loss_vposer = torch.mean(vposer_pose ** 2) ### shape prior loss shape_params = body_params_rec_72[:, 6:16] loss_shape = torch.mean(shape_params ** 2) ### hand pose prior loss hand_params = body_params_rec_72[:, 48:] loss_hand = torch.mean(hand_params ** 2) # transfrom local body_verts_rec to prox coordinate system body_verts_rec_prox = torch.zeros(body_verts_rec.shape).to(device) temp = body_verts_rec - torch.from_numpy(shift_list[cnt]).float().to(device) body_verts_rec_prox[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \ temp[:, 1] * math.sin(-rot_angle_list_1[cnt]) body_verts_rec_prox[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \ temp[:, 1] * math.cos(-rot_angle_list_1[cnt]) body_verts_rec_prox[:, 2] = temp[:, 2] body_verts_rec_prox = body_verts_rec_prox.unsqueeze(0) # tensor, [bs=1, 10475, 3] ### sdf collision loss norm_verts_batch = (body_verts_rec_prox - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1 n_verts = norm_verts_batch.shape[1] body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1), norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3), padding_mode='border') # if there are no penetrating vertices then set sdf_penetration_loss = 0 if body_sdf_batch.lt(0).sum().item() < 1: loss_collision = torch.tensor(0.0, 
dtype=torch.float32).to(device) else: loss_collision = body_sdf_batch[body_sdf_batch < 0].abs().mean() ### contact loss body_verts_contact = body_verts_rec.unsqueeze(0)[:, vid, :] # [1,1121,3] dist_chamfer_contact = ext.chamferDist() # scene_verts: [bs=1, n_scene_verts, 3] scene_verts = torch.from_numpy(scene_verts_local_list[cnt]).float().to(device).unsqueeze(0) # [1,50000,3] contact_dist, _ = dist_chamfer_contact(body_verts_contact.contiguous(), scene_verts.contiguous()) loss_contact = torch.mean(torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0)) loss = args.weight_loss_rec_verts * loss_rec_verts + args.weight_loss_rec_bps * loss_rec_bps + \ args.weight_loss_vposer * loss_vposer + \ args.weight_loss_shape * loss_shape + \ args.weight_loss_hand * loss_hand + \ args.weight_collision * loss_collision + args.weight_loss_contact * loss_contact loss.backward(retain_graph=True) optimizer.step() body_params_opt_list_s2.append(body_params_rec[0].detach().cpu().numpy()) body_params_opt_list_s2 = np.asarray(body_params_opt_list_s2) np.save('{}/body_params_opt_list_s2.npy'.format(save_path), body_params_opt_list_s2)
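The z-axis rotation of body vertices is written out element-wise twice in the stages above; the same arithmetic as a small helper (a sketch, assuming `verts` is an [N, 3] tensor and `angle` is in radians):

import math
import torch

def rotate_z(verts, angle):
    # Rotate [N, 3] vertices about the z-axis, matching the element-wise
    # cos/sin expressions used in the optimization loops above.
    c, s = math.cos(angle), math.sin(angle)
    out = torch.zeros_like(verts)
    out[:, 0] = verts[:, 0] * c - verts[:, 1] * s
    out[:, 1] = verts[:, 0] * s + verts[:, 1] * c
    out[:, 2] = verts[:, 2]
    return out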