def render_video(mesh_dir, video_fn, overwrite=False):
    """Render every ``*.obj`` mesh in ``mesh_dir`` and encode the frames into a video.

    Meshes are loaded in sorted filename order and rendered one by one with a
    psbody ``MeshViewer`` (1000x800 window). Each one is saved as a numbered PNG
    snapshot in a temporary directory, and ``ffmpeg`` then encodes the PNGs into
    ``video_fn``. The temporary directory is always removed, even when rendering
    or encoding fails.

    :param mesh_dir: directory containing the ``*.obj`` files to render.
    :param video_fn: path of the output video file.
    :param overwrite: if ``video_fn`` already exists, delete it and re-render
        when True; otherwise leave it untouched and return immediately.
    """
    from glob import glob
    from os import remove
    from os.path import exists, isdir, islink, join
    from shutil import rmtree
    from subprocess import call
    from tempfile import mkdtemp

    if exists(video_fn):
        if not overwrite:
            print("File {0} exists, not re-rendering".format(video_fn))
            return
        print("File {0} exists, removing it and remaking it".format(video_fn))
        # Same effect as `rm -rf video_fn`, without spawning a subprocess.
        if isdir(video_fn) and not islink(video_fn):
            rmtree(video_fn)
        else:
            remove(video_fn)

    files_seq = sorted(glob(join(mesh_dir, '*.obj')))
    if not files_seq:
        print('No files to render in {}'.format(mesh_dir))
        return

    # Heavy third-party dependencies are only needed once there is work to do.
    from psbody.mesh import Mesh, MeshViewer
    from tqdm import tqdm

    # Load the meshes
    print("Loading meshes from {}..".format(mesh_dir))
    meshes = [Mesh(filename=fn) for fn in files_seq]

    # mkdtemp() already creates a fresh, private, empty directory.
    tmp_folder = mkdtemp()
    try:
        mv = MeshViewer(window_width=1000, window_height=800)
        print('Rendering extracted meshes (tmp file, auto-removed later)..')
        for k, mesh in enumerate(tqdm(meshes)):
            mv.set_dynamic_meshes([mesh])
            mv.save_snapshot(join(tmp_folder, '{:0>6d}.png'.format(k)),
                             blocking=True)

        # NOTE(review): `-r 15` is placed after `-i`, so it sets the *output*
        # frame rate. ffmpeg reads the PNG sequence at its default input rate
        # and may drop frames as a result. If 15 meshes per second was intended,
        # use `-framerate 15` before `-i` instead. Confirm.
        cmd = [
            'ffmpeg', '-i', '{0}/%06d.png'.format(tmp_folder), '-vcodec',
            'h264', '-pix_fmt', 'yuv420p', '-r', '15', '-an', '-b:v', '5000k',
            video_fn
        ]
        ret = call(cmd)
        if ret != 0:
            print('ffmpeg exited with code {} while writing {}'.format(
                ret, video_fn))
    finally:
        rmtree(tmp_folder, ignore_errors=True)
def optimize_pose_only(th_scan_meshes, smplx, iterations, steps_per_iter,
                       scan_part_labels, smplx_part_labels, search_tree=None,
                       pen_distance=None, tri_filtering_module=None,
                       display=None):
    """Fit SMPL-X pose, translation and the first two betas to the scans.

    The fit runs ``1 + iterations`` outer iterations, each of
    ``steps_per_iter`` Adam steps (lr 0.02).

    - Outer iteration 0 updates only translation, the first two shape
      coefficients and the global orientation.
    - Later iterations also update the body pose and both hand poses.

    Losses come from ``forward_step`` and are reduced by ``backward_step``
    using the weights from ``get_loss_weights``. On return, the optimized
    parameters are copied back into ``smplx`` in place.

    NOTE(review): a SMPL variant with the same name exists. If both are defined
    in one module, the later definition shadows the earlier one.

    :param th_scan_meshes: scan meshes to fit. Elements expose ``.vertices``
        and ``.faces`` tensors.
    :param smplx: SMPL-X model. It supplies the initial parameters and
        receives the optimized ones.
    :param iterations: number of outer iterations after the global-only one.
    :param steps_per_iter: optimizer steps per outer iteration.
    :param scan_part_labels: scan part labels forwarded to ``forward_step``.
        They are also used to colour the displayed scan.
    :param smplx_part_labels: SMPL-X part labels forwarded to ``forward_step``.
    :param search_tree: optional helper forwarded unchanged to ``forward_step``.
    :param pen_distance: optional helper forwarded unchanged to ``forward_step``.
    :param tri_filtering_module: optional helper forwarded unchanged to
        ``forward_step``.
    :param display: index of the scan to visualise live, or None to disable.
    """
    # NOTE(review): batch size is hard-coded to 1 (previously
    # smplx.pose.shape[0]). `verts[display]` below therefore presumably only
    # works for display == 0. Confirm.
    batch_sz = 1
    split_smplx = th_batch_SMPLX_split_params(
        batch_sz,
        top_betas=smplx.betas.data[:, :2],
        other_betas=smplx.betas.data[:, 2:],
        global_pose=smplx.global_pose.data,
        body_pose=smplx.body_pose.data,
        left_hand_pose=smplx.left_hand_pose.data,
        right_hand_pose=smplx.right_hand_pose.data,
        expression=smplx.expression.data,
        jaw_pose=smplx.jaw_pose.data,
        leye_pose=smplx.leye_pose.data,
        reye_pose=smplx.reye_pose.data,
        faces=smplx.faces,
        gender=smplx.gender).to(DEVICE)

    # Stage 1 optimizer: translation, leading betas and global orientation.
    # Expression, jaw and eye poses are never handed to an optimizer.
    optimizer = torch.optim.Adam(
        [split_smplx.trans, split_smplx.top_betas, split_smplx.global_pose],
        0.02, betas=(0.9, 0.999))

    # Get loss_weights
    weight_dict = get_loss_weights()

    if display is not None:
        assert int(display) < len(th_scan_meshes)
        mv = MeshViewer(keepalive=True)
        # The scan and the SMPL-X topology are not optimized. Build them once
        # here instead of re-copying them from the device on every step.
        smplx_faces = smplx.faces.cpu().numpy()
        scan_mesh = Mesh(
            v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
            f=th_scan_meshes[display].faces.cpu().numpy(),
            vc=np.array([0, 1, 0]))
        scan_mesh.set_vertex_colors_from_weights(
            scan_part_labels[display].cpu().detach().numpy())

    iter_for_global = 1
    for it in range(iter_for_global + iterations):
        loop = tqdm(range(steps_per_iter))
        if it < iter_for_global:
            # Optimize global orientation
            print('Optimizing SMPLX global orientation')
            loop.set_description('Optimizing SMPLX global orientation')
        elif it == iter_for_global:
            # Stage 2: also optimize body and hand poses. A fresh optimizer is
            # created so the new parameter groups are included.
            print('Optimizing SMPLX pose only')
            loop.set_description('Optimizing SMPLX pose only')
            optimizer = torch.optim.Adam([
                split_smplx.trans, split_smplx.top_betas,
                split_smplx.global_pose, split_smplx.body_pose,
                split_smplx.left_hand_pose, split_smplx.right_hand_pose
            ], 0.02, betas=(0.9, 0.999))
        else:
            loop.set_description('Optimizing SMPLX pose only')

        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step(th_scan_meshes, split_smplx,
                                     scan_part_labels, smplx_part_labels,
                                     search_tree, pen_distance,
                                     tri_filtering_module)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            l_str = 'Iter: {}'.format(i)
            for k in loss_dict:
                l_str += ', {}: {:0.4f}'.format(
                    k, weight_dict[k](loss_dict[k], it).mean().item())
            loop.set_description(l_str)

            if display is not None:
                # Visualisation only: no autograd graph is needed.
                with torch.no_grad():
                    verts = split_smplx()
                smplx_mesh = Mesh(v=verts[display].cpu().detach().numpy(),
                                  f=smplx_faces)
                mv.set_dynamic_meshes([smplx_mesh, scan_mesh])

    # Put back pose, shape and trans into original smplx. The jaw pose and
    # expression are not copied back; they were not optimized.
    smplx.global_pose.data = split_smplx.global_pose.data
    smplx.body_pose.data = split_smplx.body_pose.data
    smplx.left_hand_pose.data = split_smplx.left_hand_pose.data
    smplx.right_hand_pose.data = split_smplx.right_hand_pose.data
    smplx.leye_pose.data = split_smplx.leye_pose.data
    smplx.reye_pose.data = split_smplx.reye_pose.data
    smplx.betas.data = split_smplx.betas.data
    smplx.trans.data = split_smplx.trans.data

    print('** Optimised smplx pose **')
def optimize_pose_only(th_scan_meshes, smpl, iterations, steps_per_iter,
                       th_pose_3d, prior_weight, display=None):
    """Fit SMPL pose, translation and the first two betas to 3D keypoints.

    The fit runs ``1 + iterations`` outer iterations, each of
    ``steps_per_iter`` Adam steps (lr 0.02).

    - Outer iteration 0 updates only translation, the first two shape
      coefficients and the global orientation.
    - Later iterations also update the remaining pose (``other_pose``).

    Losses come from ``forward_step_pose_only`` and are reduced by
    ``backward_step`` using the weights from ``get_loss_weights``. On return,
    the optimized pose, betas and translation are copied back into ``smpl``
    in place.

    :param th_scan_meshes: scan meshes, used here only for visualisation.
        Elements expose ``.vertices`` and ``.faces`` tensors.
    :param smpl: SMPL model. It supplies the initial parameters and receives
        the optimized ones.
    :param iterations: number of outer iterations after the global-only one.
    :param steps_per_iter: optimizer steps per outer iteration.
    :param th_pose_3d: 3D keypoint targets forwarded to
        ``forward_step_pose_only``.
    :param prior_weight: weights for the joints, forwarded to
        ``forward_step_pose_only``. They depend on how visible each joint is
        in the 3D scan; for example, a hand could be inside a pocket.
    :param display: index of the scan to visualise live, or None to disable.
    """
    batch_sz = smpl.pose.shape[0]
    split_smpl = th_batch_SMPL_split_params(batch_sz,
                                            top_betas=smpl.betas.data[:, :2],
                                            other_betas=smpl.betas.data[:, 2:],
                                            global_pose=smpl.pose.data[:, :3],
                                            other_pose=smpl.pose.data[:, 3:],
                                            faces=smpl.faces,
                                            gender=smpl.gender).cuda()
    # Stage 1 optimizer: translation, leading betas and global orientation.
    optimizer = torch.optim.Adam(
        [split_smpl.trans, split_smpl.top_betas, split_smpl.global_pose],
        0.02, betas=(0.9, 0.999))

    # Get loss_weights
    weight_dict = get_loss_weights()

    if display is not None:
        assert int(display) < len(th_scan_meshes)
        mv = MeshViewer(keepalive=True)
        # The scan and the SMPL topology are not optimized. Build them once
        # here instead of re-copying them from the device on every step.
        smpl_faces = smpl.faces.cpu().numpy()
        scan_mesh = Mesh(
            v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
            f=th_scan_meshes[display].faces.cpu().numpy(),
            vc=np.array([0, 1, 0]))

    iter_for_global = 1
    for it in range(iter_for_global + iterations):
        loop = tqdm(range(steps_per_iter))
        if it < iter_for_global:
            # Optimize global orientation
            print('Optimizing SMPL global orientation')
            loop.set_description('Optimizing SMPL global orientation')
        elif it == iter_for_global:
            # Stage 2: also optimize the remaining pose. A fresh optimizer is
            # created so the new parameter group is included.
            print('Optimizing SMPL pose only')
            loop.set_description('Optimizing SMPL pose only')
            optimizer = torch.optim.Adam([
                split_smpl.trans, split_smpl.top_betas, split_smpl.global_pose,
                split_smpl.other_pose
            ], 0.02, betas=(0.9, 0.999))
        else:
            loop.set_description('Optimizing SMPL pose only')

        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step_pose_only(split_smpl, th_pose_3d,
                                               prior_weight)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            l_str = 'Iter: {}'.format(i)
            for k in loss_dict:
                l_str += ', {}: {:0.4f}'.format(
                    k, weight_dict[k](loss_dict[k], it).mean().item())
            loop.set_description(l_str)

            if display is not None:
                # Visualisation only: no autograd graph is needed.
                with torch.no_grad():
                    verts, _, _, _ = split_smpl()
                smpl_mesh = Mesh(v=verts[display].cpu().detach().numpy(),
                                 f=smpl_faces)
                mv.set_dynamic_meshes([smpl_mesh, scan_mesh])

    # Put back pose, shape and trans into original smpl
    smpl.pose.data = split_smpl.pose.data
    smpl.betas.data = split_smpl.betas.data
    smpl.trans.data = split_smpl.trans.data

    print('** Optimised smpl pose **')