Example #1
def __get_smpl_model(demo_type, smpl_type):
    smplx_model_path = './extra_data/smpl/SMPLX_NEUTRAL.pkl'
    smpl_model_path = './extra_data/smpl/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'

    if demo_type == 'hand':
        # use original smpl-x
        smpl = smplx.create(smplx_model_path,
                            model_type="smplx",
                            batch_size=1,
                            gender='neutral',
                            num_betas=10,
                            use_pca=False,
                            ext='pkl')
    else:
        if smpl_type == 'smplx':
            # use modified smpl-x from body module
            smpl = SMPLX(smplx_model_path,
                         batch_size=1,
                         num_betas=10,
                         use_pca=False,
                         create_transl=False)
        else:
            # use modified smpl from body module
            assert smpl_type == 'smpl'
            smpl = SMPL(smpl_model_path, batch_size=1, create_transl=False)
    return smpl
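A minimal smoke test for the helper above (not part of the original project; it assumes the hard-coded .pkl files exist and that the modified SMPL-X wrapper keeps the stock forward signature):

smpl = __get_smpl_model(demo_type='body', smpl_type='smplx')
output = smpl(return_verts=True)  # neutral shape, zero pose
print(output.vertices.shape)      # SMPL-X meshes have 10475 vertices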
Example #2
def main(model_folder,
         corr_fname,
         ext='npz',
         head_color=(0.3, 0.3, 0.6),
         gender='neutral'):

    head_idxs = np.load(corr_fname)

    model = smplx.create(model_folder,
                         model_type='smplx',
                         gender=gender,
                         ext=ext)
    betas = torch.zeros([1, 10], dtype=torch.float32)
    expression = torch.zeros([1, 10], dtype=torch.float32)

    output = model(betas=betas, expression=expression, return_verts=True)
    vertices = output.vertices.detach().cpu().numpy().squeeze()
    joints = output.joints.detach().cpu().numpy().squeeze()

    print('Vertices shape =', vertices.shape)
    print('Joints shape =', joints.shape)

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    mesh.triangles = o3d.utility.Vector3iVector(model.faces)
    mesh.compute_vertex_normals()

    colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
    colors[head_idxs] = head_color

    mesh.vertex_colors = o3d.utility.Vector3dVector(colors)

    o3d.visualization.draw_geometries([mesh])
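A hypothetical invocation of this demo (both paths are placeholders; `corr_fname` must be a .npy file of SMPL-X head vertex indices, as loaded by `np.load` above):

main(model_folder='models',             # placeholder SMPL-X model directory
     corr_fname='smplx_head_idxs.npy',  # placeholder correspondence file
     head_color=(0.3, 0.3, 0.6))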
Example #3
def get_all_smpl(pkl_data, json_data):
    gender = json_data['people'][0]['gender_gt']
    all_meshes = []

    trans = np.array([4, 0, 0])

    for i, result in enumerate(pkl_data['all_results']):
        t = trans + [0, i * 3, 0]
        betas = torch.Tensor(result['betas']).unsqueeze(0)
        pose = torch.Tensor(result['body_pose']).unsqueeze(0)
        transl = torch.Tensor(result['transl']).unsqueeze(0)
        global_orient = torch.Tensor(result['global_orient']).unsqueeze(0)

        model = smplx.create('models', model_type='smpl', gender=gender)
        output = model(betas=betas, body_pose=pose, transl=transl, global_orient=global_orient, return_verts=True)
        smpl_vertices = output.vertices.detach().cpu().numpy().squeeze()

        smpl_o3d = o3d.geometry.TriangleMesh()
        smpl_o3d.triangles = o3d.utility.Vector3iVector(model.faces)
        smpl_o3d.vertices = o3d.utility.Vector3dVector(smpl_vertices)
        smpl_o3d.compute_vertex_normals()
        smpl_o3d.translate(t)

        for idx, key in enumerate(result['loss_dict'].keys()):
            lbl = '{} {:.2f}'.format(key, float(result['loss_dict'][key]))
            all_meshes.append(text_3d(lbl, t + [1, idx * 0.2 - 1, 2], direction=(0.01, 0, -1), degree=-90, font_size=150, density=0.2))

        all_meshes.append(smpl_o3d)

    return all_meshes
Example #4
    def __init__(self, trainconfig, lossconfig):
        for key, val in trainconfig.items():
            setattr(self, key, val)


        for key, val in lossconfig.items():
            setattr(self, key, val)

        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

        if len(self.ckp_dir) > 0:
            self.resume_training=True

        ### define model

        if self.use_cont_rot:
            n_dim_body=72+3
        else:
            n_dim_body=72

        self.model_h_latentD = 256
        self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                                     latentD_l=self.model_h_latentD,
                                     scene_model_ckpt=self.scene_model_ckpt,
                                     n_dim_body=n_dim_body,
                                     n_dim_scene=self.model_h_latentD)

        self.optimizer_h = optim.Adam(self.model_h.parameters(), 
                                      lr=self.init_lr_h)


        ### body mesh model
        self.vposer, _ = load_vposer(self.vposer_ckpt_path, vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path, 
                                            model_type='smplx',
                                            gender='neutral', ext='npz',
                                            num_pca_comps=12,
                                            create_global_orient=True,
                                            create_body_pose=True,
                                            create_betas=True,
                                            create_left_hand_pose=True,
                                            create_right_hand_pose=True,
                                            create_expression=True,
                                            create_jaw_pose=True,
                                            create_leye_pose=True,
                                            create_reye_pose=True,
                                            create_transl=True,
                                            batch_size=self.batch_size
                                            )

        self.smplx_face_idx = np.load(os.path.join(self.human_model_path, 
                                        'smplx/SMPLX_NEUTRAL.npz'),
                                    allow_pickle=True)['f'].reshape(-1,3)
        self.smplx_face_idx = torch.tensor(self.smplx_face_idx.astype(np.int64), 
                                            device=self.device)

        print('--[INFO] device: ' + str(torch.cuda.get_device_name(self.device)))
Example #5
def get_smpl(pkl_data, json_data):
    gender = json_data['people'][0]['gender_gt']
    print('Target height {}, weight {}'.format(json_data['people'][0]['height'], json_data['people'][0]['weight']))

    betas = torch.Tensor(pkl_data['betas']).unsqueeze(0)
    pose = torch.Tensor(pkl_data['body_pose']).unsqueeze(0)
    transl = torch.Tensor(pkl_data['transl']).unsqueeze(0)
    global_orient = torch.Tensor(pkl_data['global_orient']).unsqueeze(0)

    model = smplx.create('models', model_type='smpl', gender=gender)
    output = model(betas=betas, body_pose=pose, transl=transl, global_orient=global_orient, return_verts=True)
    smpl_vertices = output.vertices.detach().cpu().numpy().squeeze()
    smpl_joints = output.joints.detach().cpu().numpy().squeeze()

    output_unposed = model(betas=betas, body_pose=pose * 0, transl=transl, global_orient=global_orient, return_verts=True)
    smpl_vertices_unposed = output_unposed.vertices.detach().cpu().numpy().squeeze()

    for i, lbl in enumerate(['Wingspan', 'Height', 'Thickness']):
        print('Actual', lbl, smpl_vertices_unposed[:, i].max() - smpl_vertices_unposed[:, i].min(), end=' ')
    print()

    smpl_trimesh = trimesh.Trimesh(vertices=np.asarray(smpl_vertices_unposed), faces=model.faces)
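    # weight ~= volume (m^3) x 1000 L/m^3 x ~1.03 kg/L (approx. human density)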
    print('Est weight from volume', smpl_trimesh.volume * 1.03 * 1000)
    # print('Pose embedding', pkl_data['pose_embedding'])
    # print('Body pose', np.array2string(pkl_data['body_pose'], separator=', '))

    smpl_o3d = o3d.geometry.TriangleMesh()
    smpl_o3d.triangles = o3d.utility.Vector3iVector(model.faces)
    smpl_o3d.vertices = o3d.utility.Vector3dVector(smpl_vertices)
    smpl_o3d.compute_vertex_normals()
    # smpl_o3d.paint_uniform_color([0.3, 0.3, 0.3])

    smpl_o3d_2 = o3d.geometry.TriangleMesh()
    smpl_o3d_2.triangles = o3d.utility.Vector3iVector(model.faces)
    smpl_o3d_2.vertices = o3d.utility.Vector3dVector(smpl_vertices + np.array([1.5, 0, 0]))
    smpl_o3d_2.compute_vertex_normals()
    smpl_o3d_2.paint_uniform_color([0.7, 0.3, 0.3])

    # Visualize SMPL joints - Patrick

    camera = PerspectiveCamera(rotation=torch.tensor(pkl_data['camera_rotation']).unsqueeze(0),
                               translation=torch.tensor(pkl_data['camera_translation']).unsqueeze(0),
                               center=torch.tensor(pkl_data['camera_center']),
                               focal_length_x=torch.tensor(pkl_data['camera_focal_length_x']),
                               focal_length_y=torch.tensor(pkl_data['camera_focal_length_y']))

    gt_pos_3d = camera.inverse_camera_tform(torch.tensor(pkl_data['gt_joints']).unsqueeze(0), 1.8).detach().squeeze(0).cpu().numpy()

    all_markers = []
    for i in range(25):
        color = cm.jet(i / 25.0)[:3]
        # smpl_marker = get_o3d_sphere(color=color, pos=smpl_joints[i, :])
        # all_markers.append(smpl_marker)

        pred_marker = get_o3d_sphere(color=color, pos=gt_pos_3d[i, :], radius=0.03)
        all_markers.append(pred_marker)

    return smpl_vertices, model.faces, smpl_o3d, smpl_o3d_2, all_markers
Example #6
def create_smpl_model(**args):
    assert os.path.exists(
        args['smpl_model_folder']
    ), 'Path {} does not exist in argument smpl_model_folder!'.format(
        args['smpl_model_folder'])

    # NOTE: model_params is presumably defined at module scope in the source
    # project; a minimal dict is built here so the snippet is self-contained.
    model_params = {'model_path': args['smpl_model_folder']}
    body_model = smplx.create(model_type='smpl', gender='male', **model_params)
    body_model = body_model.to(device=DEVICE)
    return body_model
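`DEVICE` is assumed to come from the surrounding module; a self-contained setup might look like:

# Hypothetical module-level setup assumed by create_smpl_model above.
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

body_model = create_smpl_model(smpl_model_folder='models')  # placeholder path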
Example #7
    def __init__(self, fittingconfig, lossconfig):


        for key, val in fittingconfig.items():
            setattr(self, key, val)


        for key, val in lossconfig.items():
            setattr(self, key, val)


        self.vposer, _ = load_vposer(self.vposer_ckpt_path, 
                                     vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path, 
                                       model_type='smplx',
                                       gender='neutral', ext='npz',
                                       num_pca_comps=12,
                                       create_global_orient=True,
                                       create_body_pose=True,
                                       create_betas=True,
                                       create_left_hand_pose=True,
                                       create_right_hand_pose=True,
                                       create_expression=True,
                                       create_jaw_pose=True,
                                       create_leye_pose=True,
                                       create_reye_pose=True,
                                       create_transl=True,
                                       batch_size=self.batch_size
                                       )
        self.vposer.to(self.device)
        self.body_mesh_model.to(self.device)

        self.xhr_rec = Variable(torch.randn(1,75).to(self.device), requires_grad=True)
        self.optimizer = optim.Adam([self.xhr_rec], lr=self.init_lr_h)




        ## read scene sdf
        with open(self.scene_sdf_path + '.json') as f:
            sdf_data = json.load(f)
            grid_min = np.array(sdf_data['min'])
            grid_max = np.array(sdf_data['max'])
            grid_dim = sdf_data['dim']
        sdf = np.load(self.scene_sdf_path + '_sdf.npy').reshape(grid_dim, grid_dim, grid_dim)

        self.s_grid_min_batch = torch.tensor(grid_min, dtype=torch.float32, device=self.device).unsqueeze(0)
        self.s_grid_max_batch = torch.tensor(grid_max, dtype=torch.float32, device=self.device).unsqueeze(0)
        self.s_sdf_batch = torch.tensor(sdf, dtype=torch.float32, device=self.device).unsqueeze(0)

        ## read scene vertices
        scene_o3d = o3d.io.read_triangle_mesh(self.scene_verts_path)
        scene_verts = np.asarray(scene_o3d.vertices)
        self.s_verts_batch = torch.tensor(scene_verts, dtype=torch.float32, device=self.device).unsqueeze(0)
Example #8
def get_smpl_vertices(betas,
                      expression,
                      smpl_file_name: str = None,
                      texture_file_name: str = None,
                      uv_map_file_name: str = None,
                      body_pose: torch.Tensor = None,
                      return_betas_exps=False) -> np.ndarray:
    """
    Load SMPL model, texture file and uv-map.
    Set arm angles and convert to mesh.

    Parameters
    ----------
    betas: np.array(1, 10)
        Betas for smpl
    expression: np.array(1, 10)
        Expression for smpl
    smpl_file_name : str
        file name of smpl model (.pkl).
    texture_file_name : str
        file name of texture for smpl (.jpg).
    uv_map_file_name : str
        file name of uv-map for smpl (.npy).
    body_pose : torch.Tensor[1, 69]
        Body pose for SMPL.
    return_betas_exps : bool
        Whether to also return the betas and expression tensors.

    Returns
    -------
    vertices : np.array
        SMPL mesh with texture and desired body pose.

    """
    if smpl_file_name is None:
        smpl_file_name = "SMPLs/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl"
    if texture_file_name is None:
        texture_file_name = "textures/female1.jpg"
    if uv_map_file_name is None:
        uv_map_file_name = "textures/smpl_uv_map.npy"
    model = smplx.create(smpl_file_name, model_type='smpl')
    # set betas and expression to fixed values
    betas = torch.tensor(betas).float()
    expression = torch.tensor(expression).float()
    output = model(betas=betas,
                   expression=expression,
                   return_verts=True,
                   body_pose=body_pose)
    vertices = output.vertices.detach().cpu().numpy().squeeze()
    return vertices
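A usage sketch with neutral parameters (the default .pkl and texture paths hard-coded above are assumed to exist):

betas = np.zeros((1, 10), dtype=np.float32)
expression = np.zeros((1, 10), dtype=np.float32)
body_pose = torch.zeros((1, 69), dtype=torch.float32)  # 23 joints x 3 axis-angle
vertices = get_smpl_vertices(betas, expression, body_pose=body_pose)
print(vertices.shape)  # (6890, 3) for SMPL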
Example #9
    def __init__(self, model_type, body_models_dp, device):
        super().__init__()
        self.device = device
        self.models = dict()
        for gender in ['female', 'male']:
            model = smplx.create(
                body_models_dp,
                model_type=model_type,
                gender=gender,
                batch_size=1,
                create_transl=False
            ).to(device=self.device)
            model.eval()
            self.models[gender] = model
Example #10
def main(args, fitting_dir):
    recording_name = os.path.abspath(fitting_dir).split("/")[-1]
    fitting_dir = osp.join(fitting_dir, 'results')
    scene_name = recording_name.split("_")[0]
    print("scene_name")
    base_dir = args.base_dir
    cam2world_dir = osp.join(base_dir, 'cam2world')
    scene_dir = osp.join(base_dir, 'scenes')

    female_subjects_ids = [162, 3452, 159, 3403]
    subject_id = int(recording_name.split('_')[1])
    if subject_id in female_subjects_ids:
        gender = 'female'
    else:
        gender = 'male'

    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        trans = np.array(json.load(f))

    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender=gender,
                         ext='npz',
                         num_pca_comps=args.num_pca_comps,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=True,
                         create_jaw_pose=True,
                         create_leye_pose=True,
                         create_reye_pose=True,
                         create_transl=True)
    lis = sorted(os.listdir(fitting_dir))

    arr = torch.zeros((len(lis), 103))
    for i, img_name in enumerate(lis):
        try:
            vec = get_vec(osp.join(fitting_dir, img_name, '000.pkl'))
        except Exception:
            # missing or unreadable frame: carry the previous frame forward
            arr[i, :] = arr[i - 1, :]
            continue
        trans_vec = transform(vec, trans, model)
        # print(trans_vec.shape)
        arr[i, :] = trans_vec[0, :]
    arr = vec_to_cont(arr)
    np.save("../dataset/PROXD_cleaned/" + recording_name, arr.numpy())
    print("recording_name compiled all dataset")
Example #11
def vis_sequence(cfg, sequence, mv):

        seq_data = parse_npz(sequence)
        n_comps = seq_data['n_comps']
        gender = seq_data['gender']

        T = seq_data.n_frames

        sbj_mesh = os.path.join(grab_path, '..', seq_data.body.vtemp)
        sbj_vtemp = np.array(Mesh(filename=sbj_mesh).vertices)

        sbj_m = smplx.create(model_path=cfg.model_path,
                             model_type='smplx',
                             gender=gender,
                             num_pca_comps=n_comps,
                             v_template=sbj_vtemp,
                             batch_size=T)

        sbj_parms = params2torch(seq_data.body.params)
        verts_sbj = to_cpu(sbj_m(**sbj_parms).vertices)


        obj_mesh = os.path.join(grab_path, '..', seq_data.object.object_mesh)
        obj_mesh = Mesh(filename=obj_mesh)
        obj_vtemp = np.array(obj_mesh.vertices)
        obj_m = ObjectModel(v_template=obj_vtemp,
                            batch_size=T)
        obj_parms = params2torch(seq_data.object.params)
        verts_obj = to_cpu(obj_m(**obj_parms).vertices)

        table_mesh = os.path.join(grab_path, '..', seq_data.table.table_mesh)
        table_mesh = Mesh(filename=table_mesh)
        table_vtemp = np.array(table_mesh.vertices)
        table_m = ObjectModel(v_template=table_vtemp,
                            batch_size=T)
        table_parms = params2torch(seq_data.table.params)
        verts_table = to_cpu(table_m(**table_parms).vertices)

        skip_frame = 4
        for frame in range(0,T, skip_frame):
            o_mesh = Mesh(vertices=verts_obj[frame], faces=obj_mesh.faces, vc=colors['yellow'])
            o_mesh.set_vertex_colors(vc=colors['red'], vertex_ids=seq_data['contact']['object'][frame] > 0)

            s_mesh = Mesh(vertices=verts_sbj[frame], faces=sbj_m.faces, vc=colors['pink'], smooth=True)
            s_mesh.set_vertex_colors(vc=colors['red'], vertex_ids=seq_data['contact']['body'][frame] > 0)

            t_mesh = Mesh(vertices=verts_table[frame], faces=table_mesh.faces, vc=colors['white'])

            mv.set_static_meshes([o_mesh, s_mesh, t_mesh])
Example #12
def get_smpl(joints, axes, amounts, translation=(0, 0, 0)):
    model = smplx.create('models', model_type='smpl', gender='male')

    body_pose = torch.zeros([1, 69])
    for i in range(len(joints)):
        pose_index = int(joints[i] * 3 + axes[i])
        body_pose[0, pose_index] = axang_mean[
            pose_index] + axang_var[pose_index] * (amounts[i] * 2 - 1)

    output = model(body_pose=torch.Tensor(body_pose), return_verts=True)
    smpl_vertices = output.vertices.detach().cpu().numpy().squeeze()

    smpl_o3d = o3d.geometry.TriangleMesh()
    smpl_o3d.triangles = o3d.utility.Vector3iVector(model.faces)
    smpl_o3d.vertices = o3d.utility.Vector3dVector(smpl_vertices +
                                                   np.array(translation))
    smpl_o3d.compute_vertex_normals()
    smpl_o3d.paint_uniform_color([amounts[0] / 2 + 0.5, 0.3, 0.3])

    return smpl_o3d
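The pose vector is laid out as 23 body joints x 3 axis-angle components, which is what `joints[i] * 3 + axes[i]` indexes into; `axang_mean` and `axang_var` are per-dimension statistics defined elsewhere in the source project. A hypothetical call:

# Bend one leg joint about its x-axis at the top of the modeled range.
mesh = get_smpl(joints=[3], axes=[0], amounts=[1.0])
o3d.visualization.draw_geometries([mesh])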
Example #13
    def __init__(self,
                 body_models_dp,
                 use_pca=True, num_pca_comps=45, flat_hand_mean=False,
                 device=torch.device('cpu')):
        super().__init__()
        self.models = dict()
        self.device = device
        for gender in ['female', 'male']:
            model = smplx.create(
                body_models_dp,
                model_type='smplx',
                gender=gender,
                batch_size=1,
                use_pca=use_pca,
                num_pca_comps=num_pca_comps,
                flat_hand_mean=flat_hand_mean,
                create_transl=False
            ).to(device=self.device)
            model.eval()
            self.models[gender] = model
Example #14
    def __init__(self, testconfig):
        for key, val in testconfig.items():
            setattr(self, key, val)

        if not os.path.exists(self.ckpt_dir):
            print('--[ERROR] checkpoints do not exist')
            sys.exit()

        #define model
        if self.use_cont_rot:
            n_dim_body = 72 + 3
        else:
            n_dim_body = 72

        self.model_h_latentD = 256
        self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                                   latentD_l=self.model_h_latentD,
                                   n_dim_body=n_dim_body,
                                   n_dim_scene=self.model_h_latentD,
                                   test=True)

        ### body mesh model
        self.vposer, _ = load_vposer(self.vposer_ckpt_path,
                                     vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path,
                                            model_type='smplx',
                                            gender='neutral',
                                            ext='npz',
                                            num_pca_comps=12,
                                            create_global_orient=True,
                                            create_body_pose=True,
                                            create_betas=True,
                                            create_left_hand_pose=True,
                                            create_right_hand_pose=True,
                                            create_expression=True,
                                            create_jaw_pose=True,
                                            create_leye_pose=True,
                                            create_reye_pose=True,
                                            create_transl=True,
                                            batch_size=self.n_samples)
Example #15
def main(model_folder,
         corr_fname,
         ext='npz',
         hand_color=(0.3, 0.3, 0.6),
         gender='neutral',
         hand='right'):

    with open(corr_fname, 'rb') as f:
        idxs_data = pickle.load(f)
        if hand == 'both':
            hand_idxs = np.concatenate(
                [idxs_data['left_hand'], idxs_data['right_hand']])
        else:
            hand_idxs = idxs_data[f'{hand}_hand']

    model = smplx.create(model_folder,
                         model_type='smplx',
                         gender=gender,
                         ext=ext)
    betas = torch.zeros([1, 10], dtype=torch.float32)
    expression = torch.zeros([1, 10], dtype=torch.float32)

    output = model(betas=betas, expression=expression, return_verts=True)
    vertices = output.vertices.detach().cpu().numpy().squeeze()
    joints = output.joints.detach().cpu().numpy().squeeze()

    print('Vertices shape =', vertices.shape)
    print('Joints shape =', joints.shape)

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    mesh.triangles = o3d.utility.Vector3iVector(model.faces)
    mesh.compute_vertex_normals()

    colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
    colors[hand_idxs] = hand_color

    mesh.vertex_colors = o3d.utility.Vector3dVector(colors)

    o3d.visualization.draw_geometries([mesh])
Example #16
def get_smplx(model_root="misc/models"):
    model_params = dict(
        model_path=model_root,
        model_type="smplx",
        gender="neutral",
        create_body_pose=True,
        dtype=torch.float32,
        use_face=False,
    )

    model = smplx.create(**model_params)
    print(model)

    # Initialize body parameters.
    betas = torch.zeros([1, 10], dtype=torch.float32)
    expression = torch.randn([1, 10], dtype=torch.float32)
    mean_pose = get_vposer_mean_pose()
    # Set them.
    model.expression.data.copy_(expression)
    model.betas.data.copy_(betas)
    model.body_pose.data.copy_(mean_pose)
    return model
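`get_vposer_mean_pose` is project-specific; copying into the `.data` of the registered parameters sets the module's defaults, so a later call with no pose arguments uses them:

model = get_smplx()                # assumes misc/models contains SMPL-X files
output = model(return_verts=True)  # forward pass with the parameters set above
print(output.vertices.shape)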
Example #17
def main(model_folder, model_type='smplx', ext='npz',
         gender='neutral', plot_joints=False,
         use_face_contour=False):

    model = smplx.create(model_folder, model_type=model_type,
                         gender=gender, use_face_contour=use_face_contour,
                         ext=ext)
    print(model)

    betas = torch.randn([1, 10], dtype=torch.float32)
    expression = torch.randn([1, 10], dtype=torch.float32)

    output = model(betas=betas, expression=expression,
                   return_verts=True)
    vertices = output.vertices.detach().cpu().numpy().squeeze()
    joints = output.joints.detach().cpu().numpy().squeeze()

    print('Vertices shape =', vertices.shape)
    print('Joints shape =', joints.shape)

    vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
    tri_mesh = trimesh.Trimesh(vertices, model.faces,
                               vertex_colors=vertex_colors)

    mesh = pyrender.Mesh.from_trimesh(tri_mesh)

    scene = pyrender.Scene()
    scene.add(mesh)

    if plot_joints:
        sm = trimesh.creation.uv_sphere(radius=0.005)
        sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
        tfs = np.tile(np.eye(4), (len(joints), 1, 1))
        tfs[:, :3, 3] = joints
        joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
        scene.add(joints_pcl)

    pyrender.Viewer(scene, use_raymond_lighting=True)
Example #18
def main(model_folder,
         model_type='smplx',
         ext='npz',
         gender='neutral',
         plot_joints=False,
         plotting_module='pyrender',
         use_face_contour=False):

    model = smplx.create(model_folder,
                         model_type=model_type,
                         gender=gender,
                         use_face_contour=use_face_contour,
                         ext=ext)
    print(model)

    betas = torch.randn([1, 10], dtype=torch.float32)
    expression = torch.randn([1, 10], dtype=torch.float32)

    output = model(betas=betas, expression=expression, return_verts=True)
    vertices = output.vertices.detach().cpu().numpy().squeeze()
    joints = output.joints.detach().cpu().numpy().squeeze()

    print('Vertices shape =', vertices.shape)
    print('Joints shape =', joints.shape)

    if plotting_module == 'pyrender':
        import pyrender
        import trimesh
        vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
        tri_mesh = trimesh.Trimesh(vertices,
                                   model.faces,
                                   vertex_colors=vertex_colors)

        mesh = pyrender.Mesh.from_trimesh(tri_mesh)

        scene = pyrender.Scene()
        scene.add(mesh)

        if plot_joints:
            sm = trimesh.creation.uv_sphere(radius=0.005)
            sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
            tfs = np.tile(np.eye(4), (len(joints), 1, 1))
            tfs[:, :3, 3] = joints
            joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
            scene.add(joints_pcl)

        pyrender.Viewer(scene, use_raymond_lighting=True)
    elif plotting_module == 'matplotlib':
        from matplotlib import pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        mesh = Poly3DCollection(vertices[model.faces], alpha=0.1)
        face_color = (1.0, 1.0, 0.9)
        edge_color = (0, 0, 0)
        mesh.set_edgecolor(edge_color)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
        ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')

        if plot_joints:
            ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)
        plt.show()
    elif plotting_module == 'open3d':
        import open3d as o3d

        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(vertices)
        mesh.triangles = o3d.utility.Vector3iVector(model.faces)
        mesh.compute_vertex_normals()
        mesh.paint_uniform_color([0.3, 0.3, 0.3])

        o3d.visualization.draw_geometries([mesh])
    else:
        raise ValueError('Unknown plotting_module: {}'.format(plotting_module))
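This mirrors the stock smplx demo script; a command-line entry point along these lines (flag names are assumptions) would drive it:

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-folder', required=True)
    parser.add_argument('--model-type', default='smplx')
    parser.add_argument('--plotting-module', default='pyrender')
    parser.add_argument('--plot-joints', action='store_true')
    args = parser.parse_args()
    main(args.model_folder, model_type=args.model_type,
         plot_joints=args.plot_joints,
         plotting_module=args.plotting_module)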
Example #19
def main(args):

    scene_name = os.path.abspath(args.gen_folder).split("/")[-1]

    outimg_dir = args.outimg_dir
    if not os.path.exists(outimg_dir):
        os.makedirs(outimg_dir)

    ### setup visualization window
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=960, height=540, visible=True)
    render_opt = vis.get_render_option()
    render_opt.mesh_show_back_face = True

    ### put the scene into the environment
    scene = o3d.io.read_triangle_mesh(
        osp.join(args.prox_dir, scene_name + '.ply'))
    vis.add_geometry(scene)
    vis.update_geometry()

    # put the body into the environment
    vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0')
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')

    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender=args.gender,
                         ext='npz',
                         num_pca_comps=args.num_pca_comps,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=True,
                         create_jaw_pose=True,
                         create_leye_pose=True,
                         create_reye_pose=True,
                         create_transl=True)

    ## create a coordinate frame at the camera location
    # mesh_corn = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # mesh_corn.transform(trans)
    # vis.add_geometry(mesh_corn)
    # vis.update_geometry()
    # print(trans)

    ## create a coordinate frame at the world origin
    # mesh_corn2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # vis.add_geometry(mesh_corn2)
    # vis.update_geometry()
    # print(trans)

    cv2.namedWindow("GUI")

    gen_file_list = glob.glob(os.path.join(args.gen_folder, '*'))

    body = o3d.geometry.TriangleMesh()
    vis.add_geometry(body)
    for idx, gen_file in enumerate(gen_file_list):

        with open(gen_file, 'rb') as f:
            param = pickle.load(f)

        cam_ext = param['cam_ext'][0]
        cam_int = param['cam_int'][0]

        body_pose = vposer.decode(torch.tensor(param['body_pose']),
                                  output_type='aa').view(1, -1)
        torch_param = {}

        for key in param.keys():
            if key in ['body_pose', 'camera_rotation', 'camera_translation']:
                continue
            else:
                torch_param[key] = torch.tensor(param[key])

        output = model(return_verts=True, body_pose=body_pose, **torch_param)
        vertices = output.vertices.detach().cpu().numpy().squeeze()

        body.vertices = o3d.utility.Vector3dVector(vertices)
        body.triangles = o3d.utility.Vector3iVector(model.faces)
        body.vertex_normals = o3d.utility.Vector3dVector([])
        body.triangle_normals = o3d.utility.Vector3dVector([])
        body.compute_vertex_normals()
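        # T_mat flips y and z (a 180-degree rotation about x), presumably to
        # convert between the OpenGL-style mesh frame and the camera frame.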
        T_mat = np.eye(4)
        T_mat[1, :] = np.array([0, -1, 0, 0])
        T_mat[2, :] = np.array([0, 0, -1, 0])
        trans = np.dot(cam_ext, T_mat)
        body.transform(trans)
        vis.update_geometry()

        # while True:
        #     vis.poll_events()
        #     vis.update_renderer()
        #     cv2.imshow("GUI", np.random.random([10,10,3]))

        #     # ctr = vis.get_view_control()
        #     # cam_param = ctr.convert_to_pinhole_camera_parameters()
        #     # print(cam_param.extrinsic)

        #     key = cv2.waitKey(15)
        #     if key == 27:
        #         break

        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param = update_cam(cam_param, trans)
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis,
                      outfilename=os.path.join(
                          outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))

        # vis.run()
        # capture_image(vis, outfilename=os.path.join(outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))

        ### setup rendering cam, depth capture, segmentation capture
        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param.extrinsic = trans2_dict[scene_name]
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis,
                      outfilename=os.path.join(
                          outimg_dir, 'img_{:06d}_cam2.png'.format(idx)))
Example #20
def main(args):
    fitting_dir = args.fitting_dir
    recording_name = os.path.abspath(fitting_dir).split("/")[-1]
    fitting_dir = osp.join(fitting_dir, 'results')
    data_dir = args.data_dir
    cam2world_dir = osp.join(data_dir, 'cam2world')
    scene_dir = osp.join(data_dir, 'scenes_semantics')
    recording_dir = osp.join(data_dir, 'recordings', recording_name)
    scene_name = os.path.abspath(recording_dir).split("/")[-1].split("_")[0]

    ## setup the output folder
    output_folder = os.path.join('/mnt/hdd/PROX',
                                 'snapshot_virtualcam_TNoise0.5')
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    ### setup visualization window
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=480, height=270, visible=True)
    render_opt = vis.get_render_option()
    render_opt.mesh_show_back_face = True

    ### put the scene into the visualizer
    scene = o3d.io.read_triangle_mesh(
        osp.join(scene_dir, scene_name + '_withlabels.ply'))
    vis.add_geometry(scene)

    ## get scene 3D scene bounding box
    scene_o = o3d.io.read_triangle_mesh(
        osp.join(scene_dir, scene_name + '.ply'))
    scene_min = scene_o.get_min_bound()  #[x_min, y_min, z_min]
    scene_max = scene_o.get_max_bound()  #[x_max, y_max, z_max]
    # shrink the scene region further, to avoid cameras behind the window
    shift = 0.7
    scene_min = scene_min + shift
    scene_max = scene_max - shift

    ### get the real camera config
    trans_calib = np.eye(4)
    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        trans_calib = np.array(json.load(f))

    ## put the body into the environment
    vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0')
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')

    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender='neutral',
                         ext='npz',
                         num_pca_comps=12,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=False,
                         create_jaw_pose=False,
                         create_leye_pose=False,
                         create_reye_pose=False,
                         create_transl=True)

    rec_count = -1
    sample_rate = 15  # every 15 frames (~0.5 s, assuming 30 fps recordings)
    for img_name in sorted(os.listdir(fitting_dir))[::sample_rate]:

        ## get human body params
        filename = osp.join(fitting_dir, img_name, '000.pkl')
        print('frame: ' + filename)

        if not os.path.exists(filename):
            print('file does not exist. Continue')
            continue

        with open(osp.join(fitting_dir, img_name, '000.pkl'), 'rb') as f:
            body_dict = pickle.load(f)

        if np.sum(np.isnan(body_dict['body_pose'])) > 0:
            continue

        rec_count = rec_count + 1
        ## save depth, semantics and render cam
        outname1 = os.path.join(output_folder, recording_name)
        if not os.path.exists(outname1):
            os.mkdir(outname1)

        ######################### then we obtain the virtual cam ################################

        ## find world coordinate of the human body in the current frame
        body_params_W_list, dT = update_globalRT_for_smplx(
            body_dict, model, vposer, [trans_calib])
        body_T_world = body_params_W_list[0]['transl'][0] + dT

        ## get virtual cams, and transform global_R and global_T to virtual cams
        new_cammat_ext_list0 = []
        new_cammat_ext_list0 = get_new_cams(scene_name,
                                            s_min=scene_min,
                                            s_max=scene_max,
                                            body_T=body_T_world)
        random.shuffle(new_cammat_ext_list0)
        new_cammat_ext_list = new_cammat_ext_list0[:30]

        print('--obtain {:d} cams'.format(len(new_cammat_ext_list)))

        new_cammat_list = [invert_transform(x) for x in new_cammat_ext_list]
        body_params_new_list, _ = update_globalRT_for_smplx(
            body_params_W_list[0], model, vposer, new_cammat_list, delta_T=dT)

        #### capture depth and seg in new cams
        for idx_cam, cam_ext in enumerate(new_cammat_ext_list):

            ### save filename
            outname = os.path.join(
                outname1,
                'rec_frame{:06d}_cam{:06d}.mat'.format(rec_count, idx_cam))

            ## put the render cam to the real cam
            ctr = vis.get_view_control()
            cam_param = ctr.convert_to_pinhole_camera_parameters()
            cam_param = update_render_cam(cam_param, cam_ext)
            ctr.convert_from_pinhole_camera_parameters(cam_param)
            vis.poll_events()
            vis.update_renderer()

            ## get render cam parameters
            cam_dict = {}
            cam_dict['extrinsic'] = cam_param.extrinsic
            cam_dict['intrinsic'] = cam_param.intrinsic.intrinsic_matrix

            ## capture depth image
            depth = np.asarray(vis.capture_depth_float_buffer(do_render=True))
            _h = depth.shape[0]
            _w = depth.shape[1]
            depth0 = depth
            depth_canvas, scaling_factor = data_preprocessing(depth, 'depth')

            ### skip settings when the human body is severely occluded.
            body_is_occluded = is_body_occluded(body_params_new_list[idx_cam],
                                                cam_dict, depth)

            if body_is_occluded:
                print(
                    '-- body is occluded or not in the scene at current view.')
                continue

            ## capture semantics
            seg = np.asarray(vis.capture_screen_float_buffer(do_render=True))
            verid = np.mean(seg * 255 / 5.0, axis=-1)  #.astype(int)
            seg0 = verid
            # verid = cv2.resize(verid, (_w//factor, _h//factor))
            seg_canvas, _ = data_preprocessing(verid, 'seg')

            # pdb.set_trace()

            ## save file to disk
            ot_dict = {}
            ot_dict['scaling_factor'] = scaling_factor
            ot_dict['depth'] = depth_canvas
            ot_dict['depth0'] = depth0
            ot_dict['seg0'] = seg0
            ot_dict['seg'] = seg_canvas
            ot_dict['cam'] = cam_dict
            ot_dict['body'] = body_params_new_list[idx_cam]
            sio.savemat(outname, ot_dict)

    vis.destroy_window()
Example #21
def main():
    description = 'Example script for untangling SMPL self intersections'
    parser = argparse.ArgumentParser(description=description,
                                     prog='Batch SMPL-Untangle')
    parser.add_argument('--param_fn', type=str,
                        nargs='*',
                        required=True,
                        help='The pickle file with the model parameters')
    parser.add_argument('--interactive', default=True,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Display the mesh during the optimization' +
                        ' process')
    parser.add_argument('--delay', type=int, default=50,
                        help='The delay for the animation callback in ms')
    parser.add_argument('--model_folder', type=str,
                        default='models',
                        help='The path to the LBS model')
    parser.add_argument('--model_type', type=str,
                        default='smpl', choices=['smpl', 'smplx', 'smplh'],
                        help='The type of model to create')
    parser.add_argument('--point2plane', default=False,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Use point to distance')
    parser.add_argument('--optimize_pose', default=True,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Enable optimization over the joint pose')
    parser.add_argument('--optimize_shape', default=False,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Enable optimization over the shape of the model')
    parser.add_argument('--sigma', default=0.5, type=float,
                        help='The height of the cone used to calculate the' +
                        ' distance field loss')
    parser.add_argument('--lr', default=1, type=float,
                        help='The learning rate for SGD')
    parser.add_argument('--coll_loss_weight', default=1e-4, type=float,
                        help='The weight for the collision loss')
    parser.add_argument('--pose_reg_weight', default=0, type=float,
                        help='The weight for the pose regularizer')
    parser.add_argument('--shape_reg_weight', default=0, type=float,
                        help='The weight for the shape regularizer')
    parser.add_argument('--max_collisions', default=8, type=int,
                        help='The maximum number of bounding box collisions')
    parser.add_argument('--part_segm_fn', default='', type=str,
                        help='The file with the part segmentation for the' +
                        ' faces of the model')
    parser.add_argument('--print_timings', default=False,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Print timings for all the operations')

    args = parser.parse_args()

    model_folder = args.model_folder
    model_type = args.model_type
    param_fn = args.param_fn
    interactive = args.interactive
    delay = args.delay
    point2plane = args.point2plane
    #  optimize_shape = args.optimize_shape
    #  optimize_pose = args.optimize_pose
    lr = args.lr
    coll_loss_weight = args.coll_loss_weight
    pose_reg_weight = args.pose_reg_weight
    shape_reg_weight = args.shape_reg_weight
    max_collisions = args.max_collisions
    sigma = args.sigma
    part_segm_fn = args.part_segm_fn
    print_timings = args.print_timings

    if interactive:
        import trimesh
        import pyrender

    device = torch.device('cuda')
    batch_size = len(param_fn)

    params_dict = defaultdict(lambda: [])
    for idx, fn in enumerate(param_fn):
        with open(fn, 'rb') as param_file:
            data = pickle.load(param_file, encoding='latin1')

        assert 'betas' in data, \
            'No key for shape parameter in provided pickle file'
        assert 'global_pose' in data, \
            'No key for the global pose in the given pickle file'
        assert 'pose' in data, \
            'No key for the pose of the joints in the given pickle file'

        for key, val in data.items():
            params_dict[key].append(val)

    params = {}
    for key in params_dict:
        params[key] = np.stack(params_dict[key], axis=0).astype(np.float32)
        if len(params[key].shape) < 2:
            params[key] = params[key][np.newaxis]
    if 'global_pose' in params:
        params['global_orient'] = params['global_pose']
    if 'pose' in params:
        params['body_pose'] = params['pose']

    if part_segm_fn:
        # Read the part segmentation
        with open(part_segm_fn, 'rb') as faces_parents_file:
            data = pickle.load(faces_parents_file, encoding='latin1')
        faces_segm = data['segm']
        faces_parents = data['parents']
        # Create the module used to filter invalid collision pairs
        filter_faces = FilterFaces(faces_segm, faces_parents).to(device=device)

    # Create the body model
    body = create(model_folder, batch_size=batch_size,
                  model_type=model_type).to(device=device)
    body.reset_params(**params)

    # Clone the given pose to use it as a target for regularization
    init_pose = body.body_pose.clone().detach()

    # Create the search tree
    search_tree = BVH(max_collisions=max_collisions)

    pen_distance = \
        collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
                                                     point2plane=point2plane,
                                                     vectorized=True)

    mse_loss = nn.MSELoss(reduction='sum').to(device=device)

    face_tensor = torch.tensor(body.faces.astype(np.int64), dtype=torch.long,
                               device=device).unsqueeze_(0).repeat([batch_size,
                                                                    1, 1])
    with torch.no_grad():
        output = body(get_skin=True)
        verts = output.vertices

    bs, nv = verts.shape[:2]
    bs, nf = face_tensor.shape[:2]
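    # Offset each batch element's face indices so they address the flattened
    # (batch * num_verts, 3) vertex view used in the optimization loop below.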
    faces_idx = face_tensor + \
        (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]

    optimizer = torch.optim.SGD([body.body_pose], lr=lr)

    if interactive:
        # Plot the initial mesh
        with torch.no_grad():
            output = body(get_skin=True)
            verts = output.vertices

            np_verts = verts.detach().cpu().numpy()

        def create_mesh(vertices, faces, color=(0.3, 0.3, 0.3, 1.0),
                        wireframe=False):

            tri_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
            rot = trimesh.transformations.rotation_matrix(np.radians(180),
                                                          [1, 0, 0])
            tri_mesh.apply_transform(rot)

            material = pyrender.MetallicRoughnessMaterial(
                metallicFactor=0.0,
                alphaMode='BLEND',
                baseColorFactor=color)
            return pyrender.Mesh.from_trimesh(
                tri_mesh,
                material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0],
                               ambient_light=(1.0, 1.0, 1.0))
        for bidx in range(np_verts.shape[0]):
            curr_verts = np_verts[bidx].copy()
            body_mesh = create_mesh(curr_verts, body.faces,
                                    color=(0.3, 0.3, 0.3, 0.99),
                                    wireframe=True)

            pose = np.eye(4)
            pose[0, 3] = bidx * 2
            scene.add(body_mesh,
                      name='body_mesh_{:03d}'.format(bidx),
                      pose=pose)

        viewer = pyrender.Viewer(scene, use_raymond_lighting=True,
                                 viewport_size=(1200, 800),
                                 cull_faces=False,
                                 run_in_thread=True)

    query_names = ['recv_mesh', 'intr_mesh', 'body_mesh']

    step = 0
    while True:
        optimizer.zero_grad()

        if print_timings:
            start = time.time()

        if print_timings:
            torch.cuda.synchronize()
        output = body(get_skin=True)
        verts = output.vertices

        if print_timings:
            torch.cuda.synchronize()
            print('Body model forward: {:5f}'.format(time.time() - start))

        if print_timings:
            torch.cuda.synchronize()
            start = time.time()
        triangles = verts.view([-1, 3])[faces_idx]
        if print_timings:
            torch.cuda.synchronize()
            print('Triangle indexing: {:5f}'.format(time.time() - start))

        with torch.no_grad():
            if print_timings:
                start = time.time()
            collision_idxs = search_tree(triangles)
            if print_timings:
                torch.cuda.synchronize()
                print('Collision Detection: {:5f}'.format(time.time() -
                                                          start))
            if part_segm_fn:
                if print_timings:
                    start = time.time()
                collision_idxs = filter_faces(collision_idxs)
                if print_timings:
                    torch.cuda.synchronize()
                    print('Collision filtering: {:5f}'.format(time.time() -
                                                              start))

        if print_timings:
            start = time.time()
        pen_loss = coll_loss_weight * \
            pen_distance(triangles, collision_idxs)
        if print_timings:
            torch.cuda.synchronize()
            print('Penetration loss: {:5f}'.format(time.time() - start))

        shape_reg_loss = torch.tensor(0, device=device,
                                      dtype=torch.float32)
        if shape_reg_weight > 0:
            shape_reg_loss = shape_reg_weight * torch.sum(body.betas ** 2)
        pose_reg_loss = torch.tensor(0, device=device,
                                     dtype=torch.float32)
        if pose_reg_weight > 0:
            pose_reg_loss = pose_reg_weight * \
                mse_loss(body.body_pose, init_pose)

        loss = pen_loss + pose_reg_loss + shape_reg_loss

        np_loss = loss.detach().cpu().squeeze().tolist()
        if not isinstance(np_loss, list):
            np_loss = [np_loss]
        msg = '{:.5f} ' * len(np_loss)
        print('Loss per model:', msg.format(*np_loss))

        if print_timings:
            start = time.time()
        loss.backward(torch.ones_like(loss))
        if print_timings:
            torch.cuda.synchronize()
            print('Backward pass: {:5f}'.format(time.time() - start))

        if interactive:
            with torch.no_grad():
                output = body(get_skin=True)
                verts = output.vertices

                np_verts = verts.detach().cpu().numpy()

            np_collision_idxs = collision_idxs.detach().cpu().numpy()
            np_receivers = np_collision_idxs[:, :, 0]
            np_intruders = np_collision_idxs[:, :, 1]

            viewer.render_lock.acquire()

            for node in scene.get_nodes():
                if node.name is None:
                    continue
                if any([query in node.name for query in query_names]):
                    scene.remove_node(node)

            for bidx in range(batch_size):
                recv_faces_idxs = np_receivers[bidx][np_receivers[bidx] >= 0]
                intr_faces_idxs = np_intruders[bidx][np_intruders[bidx] >= 0]
                recv_faces = body.faces[recv_faces_idxs]
                intr_faces = body.faces[intr_faces_idxs]

                curr_verts = np_verts[bidx].copy()
                body_mesh = create_mesh(curr_verts, body.faces,
                                        color=(0.3, 0.3, 0.3, 0.99),
                                        wireframe=True)

                pose = np.eye(4)
                pose[0, 3] = bidx * 2
                scene.add(body_mesh,
                          name='body_mesh_{:03d}'.format(bidx),
                          pose=pose)

                if len(intr_faces) > 0:
                    intr_mesh = create_mesh(curr_verts, intr_faces,
                                            color=(0.9, 0.0, 0.0, 1.0))
                    scene.add(intr_mesh,
                              name='intr_mesh_{:03d}'.format(bidx),
                              pose=pose)

                if len(recv_faces) > 0:
                    recv_mesh = create_mesh(curr_verts, recv_faces,
                                            color=(0.0, 0.9, 0.0, 1.0))
                    scene.add(recv_mesh, name='recv_mesh_{:03d}'.format(bidx),
                              pose=pose)
            viewer.render_lock.release()

            if not viewer.is_active:
                break

            time.sleep(delay / 1000)
        optimizer.step()

        step += 1
Example #22
File: main.py Project: wkailiu/prox
def main(**args):
    data_folder = args.get('recording_dir')
    recording_name = osp.basename(args.get('recording_dir'))
    scene_name = recording_name.split("_")[0]
    base_dir = os.path.abspath(
        osp.join(args.get('recording_dir'), os.pardir, os.pardir))
    keyp_dir = osp.join(base_dir, 'keypoints')
    keyp_folder = osp.join(keyp_dir, recording_name)
    cam2world_dir = osp.join(base_dir, 'cam2world')
    scene_dir = osp.join(base_dir, 'scenes')
    calib_dir = osp.join(base_dir, 'calibration')
    sdf_dir = osp.join(base_dir, 'sdf')
    body_segments_dir = osp.join(base_dir, 'body_segments')

    output_folder = args.get('output_folder')
    output_folder = osp.expandvars(output_folder)
    output_folder = osp.join(output_folder, recording_name)
    if not osp.exists(output_folder):
        os.makedirs(output_folder)

    # Store the arguments for the current experiment
    conf_fn = osp.join(output_folder, 'conf.yaml')
    with open(conf_fn, 'w') as conf_file:
        yaml.dump(args, conf_file)
    #remove 'output_folder' from args list
    args.pop('output_folder')

    result_folder = args.pop('result_folder', 'results')
    result_folder = osp.join(output_folder, result_folder)
    if not osp.exists(result_folder):
        os.makedirs(result_folder)

    mesh_folder = args.pop('mesh_folder', 'meshes')
    mesh_folder = osp.join(output_folder, mesh_folder)
    if not osp.exists(mesh_folder):
        os.makedirs(mesh_folder)

    out_img_folder = osp.join(output_folder, 'images')
    if not osp.exists(out_img_folder):
        os.makedirs(out_img_folder)

    body_scene_rendering_dir = os.path.join(output_folder, 'renderings')
    if not osp.exists(body_scene_rendering_dir):
        os.mkdir(body_scene_rendering_dir)

    float_dtype = args['float_dtype']
    if float_dtype == 'float64':
        dtype = torch.float64
    elif float_dtype == 'float32':
        dtype = torch.float32
    else:
        print('Unknown float type {}, exiting!'.format(float_dtype))
        sys.exit(-1)

    use_cuda = args.get('use_cuda', True)
    if use_cuda and not torch.cuda.is_available():
        print('CUDA is not available, exiting!')
        sys.exit(-1)

    img_folder = args.pop('img_folder', 'Color')
    dataset_obj = create_dataset(img_folder=img_folder,
                                 data_folder=data_folder,
                                 keyp_folder=keyp_folder,
                                 calib_dir=calib_dir,
                                 **args)

    start = time.time()

    input_gender = args.pop('gender', 'neutral')
    gender_lbl_type = args.pop('gender_lbl_type', 'none')
    max_persons = args.pop('max_persons', -1)

    float_dtype = args.get('float_dtype', 'float32')
    if float_dtype == 'float64':
        dtype = torch.float64
    elif float_dtype == 'float32':
        dtype = torch.float32
    else:
        raise ValueError('Unknown float type {}, exiting!'.format(float_dtype))

    joint_mapper = JointMapper(dataset_obj.get_model2data())

    model_params = dict(model_path=args.get('model_folder'),
                        joint_mapper=joint_mapper,
                        create_global_orient=True,
                        create_body_pose=not args.get('use_vposer'),
                        create_betas=True,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        create_expression=True,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=True,
                        dtype=dtype,
                        **args)

    male_model = smplx.create(gender='male', **model_params)
    # SMPL-H has no gender-neutral model
    if args.get('model_type') != 'smplh':
        neutral_model = smplx.create(gender='neutral', **model_params)
    female_model = smplx.create(gender='female', **model_params)

    # Create the camera object
    camera_center = None \
        if args.get('camera_center_x') is None or args.get('camera_center_y') is None \
        else torch.tensor([args.get('camera_center_x'), args.get('camera_center_y')], dtype=dtype).view(-1, 2)
    camera = create_camera(focal_length_x=args.get('focal_length_x'),
                           focal_length_y=args.get('focal_length_y'),
                           center=camera_center,
                           batch_size=args.get('batch_size'),
                           dtype=dtype)

    if hasattr(camera, 'rotation'):
        camera.rotation.requires_grad = False

    use_hands = args.get('use_hands', True)
    use_face = args.get('use_face', True)

    body_pose_prior = create_prior(prior_type=args.get('body_prior_type'),
                                   dtype=dtype,
                                   **args)

    jaw_prior, expr_prior = None, None
    if use_face:
        jaw_prior = create_prior(prior_type=args.get('jaw_prior_type'),
                                 dtype=dtype,
                                 **args)
        expr_prior = create_prior(prior_type=args.get('expr_prior_type', 'l2'),
                                  dtype=dtype,
                                  **args)

    left_hand_prior, right_hand_prior = None, None
    if use_hands:
        lhand_args = args.copy()
        lhand_args['num_gaussians'] = args.get('num_pca_comps')
        left_hand_prior = create_prior(
            prior_type=args.get('left_hand_prior_type'),
            dtype=dtype,
            use_left_hand=True,
            **lhand_args)

        rhand_args = args.copy()
        rhand_args['num_gaussians'] = args.get('num_pca_comps')
        right_hand_prior = create_prior(
            prior_type=args.get('right_hand_prior_type'),
            dtype=dtype,
            use_right_hand=True,
            **rhand_args)

    shape_prior = create_prior(prior_type=args.get('shape_prior_type', 'l2'),
                               dtype=dtype,
                               **args)

    angle_prior = create_prior(prior_type='angle', dtype=dtype)

    if use_cuda and torch.cuda.is_available():
        device = torch.device('cuda')

        camera = camera.to(device=device)
        female_model = female_model.to(device=device)
        male_model = male_model.to(device=device)
        if args.get('model_type') != 'smplh':
            neutral_model = neutral_model.to(device=device)
        body_pose_prior = body_pose_prior.to(device=device)
        angle_prior = angle_prior.to(device=device)
        shape_prior = shape_prior.to(device=device)
        if use_face:
            expr_prior = expr_prior.to(device=device)
            jaw_prior = jaw_prior.to(device=device)
        if use_hands:
            left_hand_prior = left_hand_prior.to(device=device)
            right_hand_prior = right_hand_prior.to(device=device)
    else:
        device = torch.device('cpu')

    # A weight for every joint of the model
    joint_weights = dataset_obj.get_joint_weights().to(device=device,
                                                       dtype=dtype)
    # Add a fake batch dimension for broadcasting
    joint_weights.unsqueeze_(dim=0)

    for idx, data in enumerate(dataset_obj):

        img = data['img']
        fn = data['fn']
        keypoints = data['keypoints']
        depth_im = data['depth_im']
        mask = data['mask']
        init_trans = None if data['init_trans'] is None else torch.tensor(
            data['init_trans'], dtype=dtype).view(-1, 3)
        scan = data['scan_dict']
        print('Processing: {}'.format(data['img_path']))

        curr_result_folder = osp.join(result_folder, fn)
        if not osp.exists(curr_result_folder):
            os.makedirs(curr_result_folder)
        curr_mesh_folder = osp.join(mesh_folder, fn)
        if not osp.exists(curr_mesh_folder):
            os.makedirs(curr_mesh_folder)
        # TODO: SMPLifyD and PROX won't work for multiple persons
        for person_id in range(keypoints.shape[0]):
            if person_id >= max_persons and max_persons > 0:
                continue

            curr_result_fn = osp.join(curr_result_folder,
                                      '{:03d}.pkl'.format(person_id))
            curr_mesh_fn = osp.join(curr_mesh_folder,
                                    '{:03d}.ply'.format(person_id))
            curr_body_scene_rendering_fn = osp.join(body_scene_rendering_dir,
                                                    fn + '.png')

            curr_img_folder = osp.join(output_folder, 'images', fn,
                                       '{:03d}'.format(person_id))
            if not osp.exists(curr_img_folder):
                os.makedirs(curr_img_folder)

            if gender_lbl_type != 'none':
                if gender_lbl_type == 'pd' and 'gender_pd' in data:
                    gender = data['gender_pd'][person_id]
                if gender_lbl_type == 'gt' and 'gender_gt' in data:
                    gender = data['gender_gt'][person_id]
            else:
                gender = input_gender

            if gender == 'neutral':
                body_model = neutral_model
            elif gender == 'female':
                body_model = female_model
            elif gender == 'male':
                body_model = male_model

            out_img_fn = osp.join(curr_img_folder, 'output.png')

            fit_single_frame(
                img,
                keypoints[[person_id]],
                init_trans,
                scan,
                cam2world_dir=cam2world_dir,
                scene_dir=scene_dir,
                sdf_dir=sdf_dir,
                body_segments_dir=body_segments_dir,
                scene_name=scene_name,
                body_model=body_model,
                camera=camera,
                joint_weights=joint_weights,
                dtype=dtype,
                output_folder=output_folder,
                result_folder=curr_result_folder,
                out_img_fn=out_img_fn,
                result_fn=curr_result_fn,
                mesh_fn=curr_mesh_fn,
                body_scene_rendering_fn=curr_body_scene_rendering_fn,
                shape_prior=shape_prior,
                expr_prior=expr_prior,
                body_pose_prior=body_pose_prior,
                left_hand_prior=left_hand_prior,
                right_hand_prior=right_hand_prior,
                jaw_prior=jaw_prior,
                angle_prior=angle_prior,
                **args)

    elapsed = time.time() - start
    time_msg = time.strftime('%H hours, %M minutes, %S seconds',
                             time.gmtime(elapsed))
    print('Processing the data took: {}'.format(time_msg))
Example #23
def main(args):
    fitting_dir = args.fitting_dir
    recording_name = os.path.abspath(fitting_dir).split("/")[-1]
    fitting_dir = osp.join(fitting_dir, 'results')
    scene_name = recording_name.split("_")[0]
    base_dir = args.base_dir
    cam2world_dir = osp.join(base_dir, 'cam2world')
    scene_dir = osp.join(base_dir, 'scenes')
    recording_dir = osp.join(base_dir, 'recordings', recording_name)
    color_dir = os.path.join(recording_dir, 'Color')

    female_subjects_ids = [162, 3452, 159, 3403]
    subject_id = int(recording_name.split('_')[1])
    if subject_id in female_subjects_ids:
        gender = 'female'
    else:
        gender = 'male'

    cv2.namedWindow('frame', cv2.WINDOW_NORMAL)

    vis = o3d.visualization.Visualizer()
    vis.create_window()

    scene = o3d.io.read_triangle_mesh(osp.join(scene_dir, scene_name + '.ply'))
    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        trans = np.array(json.load(f))
    vis.add_geometry(scene)


    model = smplx.create(args.model_folder, model_type='smplx',
                         gender=gender, ext='npz',
                         num_pca_comps=args.num_pca_comps,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=True,
                         create_jaw_pose=True,
                         create_leye_pose=True,
                         create_reye_pose=True,
                         create_transl=True
                         )

    count = 0
    for img_name in sorted(os.listdir(fitting_dir))[args.start::args.step]:
        print('viz frame {}'.format(img_name))

        with open(osp.join(fitting_dir, img_name, '000.pkl'), 'rb') as f:
            param = pickle.load(f)
        torch_param = {}
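        # Skip fitting-time extras: 'pose_embedding' is a VPoser latent rather
        # than an SMPL-X input (the stored body_pose already holds the decoded
        # pose), and the camera parameters are extrinsics, not model kwargs.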
        for key in param.keys():
            if key in ['pose_embedding', 'camera_rotation', 'camera_translation']:
                continue
            else:
                torch_param[key] = torch.tensor(param[key])

        output = model(return_verts=True, **torch_param)
        vertices = output.vertices.detach().cpu().numpy().squeeze()

        if count == 0:
            body = o3d.geometry.TriangleMesh()
            vis.add_geometry(body)
        body.vertices = o3d.utility.Vector3dVector(vertices)
        body.triangles = o3d.utility.Vector3iVector(model.faces)
        # clear any cached normals before recomputing them for the new pose
        body.vertex_normals = o3d.utility.Vector3dVector([])
        body.triangle_normals = o3d.utility.Vector3dVector([])
        body.compute_vertex_normals()
        body.transform(trans)


        color_img = cv2.imread(os.path.join(color_dir, img_name + '.jpg'))
        color_img = cv2.flip(color_img, 1)

        vis.update_geometry(body)
        while True:
            cv2.imshow('frame', color_img)
            vis.poll_events()
            vis.update_renderer()
            key = cv2.waitKey(30)
            if key == 27:
                break

        count += 1
Example #24
    # coordinate masking for error calculation
    ih26m_joint_from_mesh = ih26m_joint_from_mesh[np.tile(
        ih26m_joint_valid == 1, (1, 3))].reshape(-1, 3)
    ih26m_joint_cam = ih26m_joint_cam[np.tile(ih26m_joint_valid == 1,
                                              (1, 3))].reshape(-1, 3)

    error = np.sqrt(np.sum((ih26m_joint_from_mesh - ih26m_joint_cam)**2,
                           1)).mean()
    return error


# mano layer
smplx_path = 'SMPLX_PATH'
mano_layer = {
    'right': smplx.create(smplx_path, 'mano', use_pca=False, is_rhand=True),
    'left': smplx.create(smplx_path, 'mano', use_pca=False, is_rhand=False)
}
ih26m_joint_regressor = np.load('J_regressor_mano_ih26m.npy')

# fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
if torch.sum(
        torch.abs(mano_layer['left'].shapedirs[:, 0, :] -
                  mano_layer['right'].shapedirs[:, 0, :])) < 1:
    print('Fix shapedirs bug of MANO')
    mano_layer['left'].shapedirs[:, 0, :] *= -1

root_path = '../../data/InterHand2.6M/data/'
img_root_path = osp.join(root_path, 'images')
annot_root_path = osp.join(root_path, 'annotations')
subset = 'all'
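
# A minimal, hedged sketch of how the regressor above is typically applied:
# joints are regressed from MANO output vertices with the usual (J, V) x (V, 3)
# matrix product. The zero-pose inputs here are purely illustrative.
sample_output = mano_layer['right'](global_orient=torch.zeros(1, 3),
                                    hand_pose=torch.zeros(1, 45),
                                    betas=torch.zeros(1, 10))
sample_verts = sample_output.vertices.detach().cpu().numpy().squeeze()  # (778, 3)
sample_joints = np.dot(ih26m_joint_regressor, sample_verts)  # (num_joints, 3)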
Example #25
def train():
    parser = config_parser()
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.default_device = device
    if args.model_type not in [
            "nerf", "smpl_nerf", "append_to_nerf", "smpl", "warp",
            'vertex_sphere', "smpl_estimator", "original_nerf",
            'dummy_dynamic', 'image_wise_dynamic',
            "append_vertex_locations_to_nerf", 'append_smpl_params'
    ]:
        raise Exception("The model type ", args.model_type, " does not exist.")

    transform = transforms.Compose([
        NormalizeRGB(),
        CoarseSampling(args.near, args.far, args.number_coarse_samples),
        ToTensor()
    ])

    train_dir = os.path.join(args.dataset_dir, 'train')
    val_dir = os.path.join(args.dataset_dir, 'val')
    if args.model_type == "nerf":
        train_data = RaysFromImagesDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'), transform)
        val_data = RaysFromImagesDataset(
            val_dir, os.path.join(val_dir, 'transforms.json'), transform)
    elif args.model_type == "smpl" or args.model_type == "warp":
        train_data = SmplDataset(train_dir,
                                 os.path.join(train_dir, 'transforms.json'),
                                 args,
                                 transform=NormalizeRGB())
        val_data = SmplDataset(val_dir,
                               os.path.join(val_dir, 'transforms.json'),
                               args,
                               transform=NormalizeRGB())
    elif args.model_type == "smpl_nerf" or args.model_type == "append_to_nerf" or args.model_type == "append_smpl_params":
        train_data = SmplNerfDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'), transform)
        val_data = SmplNerfDataset(val_dir,
                                   os.path.join(val_dir, 'transforms.json'),
                                   transform)
    elif args.model_type == "vertex_sphere":
        train_data = VertexSphereDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'), args)
        val_data = VertexSphereDataset(
            val_dir, os.path.join(val_dir, 'transforms.json'), args)
    elif args.model_type == "smpl_estimator":
        transform = NormalizeRGBImage()
        train_data = SmplEstimatorDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'),
            args.vertex_sphere_radius, transform)
        val_data = SmplEstimatorDataset(
            val_dir, os.path.join(val_dir, 'transforms.json'),
            args.vertex_sphere_radius, transform)
    elif args.model_type == "original_nerf":
        train_data = OriginalNerfDataset(
            args.dataset_dir,
            os.path.join(args.dataset_dir, 'transforms_train.json'), transform)
        val_data = OriginalNerfDataset(
            args.dataset_dir,
            os.path.join(args.dataset_dir, 'transforms_val.json'), transform)
    elif args.model_type == "dummy_dynamic":
        train_data = DummyDynamicDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'), transform)
        val_data = DummyDynamicDataset(
            val_dir, os.path.join(val_dir, 'transforms.json'), transform)
    elif args.model_type == "append_vertex_locations_to_nerf":
        train_data = DummyDynamicDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'), transform)
        val_data = DummyDynamicDataset(
            val_dir, os.path.join(val_dir, 'transforms.json'), transform)
    elif args.model_type == 'image_wise_dynamic':
        canonical_pose1 = torch.zeros(38).view(1, -1)
        canonical_pose2 = torch.zeros(2).view(1, -1)
        canonical_pose3 = torch.zeros(27).view(1, -1)
        arm_angle_l = torch.tensor([np.deg2rad(10)]).float().view(1, -1)
        arm_angle_r = torch.tensor([np.deg2rad(10)]).float().view(1, -1)
        smpl_estimator = DummyImageWiseEstimator(canonical_pose1,
                                                 canonical_pose2,
                                                 canonical_pose3, arm_angle_l,
                                                 arm_angle_r,
                                                 torch.zeros(10).view(1, -1),
                                                 torch.zeros(69).view(1, -1))
        train_data = ImageWiseDataset(
            train_dir, os.path.join(train_dir, 'transforms.json'),
            smpl_estimator, transform, args)
        val_data = ImageWiseDataset(val_dir,
                                    os.path.join(val_dir, 'transforms.json'),
                                    smpl_estimator, transform, args)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batchsize_val,
                                             shuffle=False,
                                             num_workers=0)
    position_encoder = PositionalEncoder(args.number_frequencies_postitional,
                                         args.use_identity_positional)
    direction_encoder = PositionalEncoder(args.number_frequencies_directional,
                                          args.use_identity_directional)
    model_coarse = RenderRayNet(args.netdepth,
                                args.netwidth,
                                position_encoder.output_dim * 3,
                                direction_encoder.output_dim * 3,
                                skips=args.skips)
    model_fine = RenderRayNet(args.netdepth_fine,
                              args.netwidth_fine,
                              position_encoder.output_dim * 3,
                              direction_encoder.output_dim * 3,
                              skips=args.skips_fine)

    if args.model_type == "smpl_nerf":
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        positions_dim = position_encoder.output_dim if args.human_pose_encoding else 1
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_warp_field = WarpFieldNet(args.netdepth_warp, args.netwidth_warp,
                                        positions_dim * 3, human_pose_dim * 2)

        solver = SmplNerfSolver(model_coarse, model_fine, model_warp_field,
                                position_encoder, direction_encoder,
                                human_pose_encoder, train_data.canonical_smpl,
                                args, torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)

        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, model_warp_field],
                 ['model_coarse.pt', 'model_fine.pt', 'model_warp_field.pt'],
                 parser)

    elif args.model_type == 'smpl':
        solver = SmplSolver(model_coarse, model_fine, position_encoder,
                            direction_encoder, args, torch.optim.Adam,
                            torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'nerf' or args.model_type == "original_nerf":
        solver = NerfSolver(model_coarse, model_fine, position_encoder,
                            direction_encoder, args, torch.optim.Adam,
                            torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'warp':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        positions_dim = position_encoder.output_dim if args.human_pose_encoding else 1
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_warp_field = WarpFieldNet(args.netdepth_warp, args.netwidth_warp,
                                        positions_dim * 3, human_pose_dim * 2)
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        solver = WarpSolver(model_warp_field, position_encoder,
                            direction_encoder, human_pose_encoder, args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir, [model_warp_field],
                 ['model_warp_field.pt'], parser)
    elif args.model_type == 'append_smpl_params':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1

        model_coarse = RenderRayNet(
            args.netdepth,
            args.netwidth,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 69,
            skips=args.skips,
            use_directional_input=args.use_directional_input)
        model_fine = RenderRayNet(
            args.netdepth_fine,
            args.netwidth_fine,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 69,
            skips=args.skips_fine,
            use_directional_input=args.use_directional_input)

        if args.load_run is not None:
            model_coarse.load_state_dict(
                torch.load(os.path.join(args.load_run, 'model_coarse.pt'),
                           map_location=torch.device(device)))
            model_fine.load_state_dict(
                torch.load(os.path.join(args.load_run, 'model_fine.pt'),
                           map_location=torch.device(device)))
            print("Models loaded from ", args.load_run)
        if args.siren:
            model_coarse = SirenRenderRayNet(
                args.netdepth,
                args.netwidth,
                position_encoder.output_dim * 3,
                direction_encoder.output_dim * 3,
                human_pose_dim * 69,
                skips=args.skips,
                use_directional_input=args.use_directional_input)
            model_fine = SirenRenderRayNet(
                args.netdepth_fine,
                args.netwidth_fine,
                position_encoder.output_dim * 3,
                direction_encoder.output_dim * 3,
                human_pose_dim * 69,
                skips=args.skips_fine,
                use_directional_input=args.use_directional_input)
        solver = AppendSmplParamsSolver(model_coarse, model_fine,
                                        position_encoder, direction_encoder,
                                        human_pose_encoder, args,
                                        torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

        model_dependent = [human_pose_encoder, human_pose_dim]
        inference_gif(solver.writer.log_dir, args.model_type, args, train_data,
                      val_data, position_encoder, direction_encoder,
                      model_coarse, model_fine, model_dependent)
    elif args.model_type == 'append_to_nerf':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_coarse = RenderRayNet(
            args.netdepth,
            args.netwidth,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 2,
            skips=args.skips,
            use_directional_input=args.use_directional_input)
        model_fine = RenderRayNet(
            args.netdepth_fine,
            args.netwidth_fine,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 2,
            skips=args.skips_fine,
            use_directional_input=args.use_directional_input)
        solver = AppendToNerfSolver(model_coarse, model_fine, position_encoder,
                                    direction_encoder, human_pose_encoder,
                                    args, torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

        model_dependent = [human_pose_encoder, human_pose_dim]
        inference_gif(solver.writer.log_dir, args.model_type, args, train_data,
                      val_data, position_encoder, direction_encoder,
                      model_coarse, model_fine, model_dependent)
    elif args.model_type == 'append_vertex_locations_to_nerf':
        model_coarse = AppendVerticesNet(args.netdepth,
                                         args.netwidth,
                                         position_encoder.output_dim * 3,
                                         direction_encoder.output_dim * 3,
                                         6890,
                                         additional_input_layers=1,
                                         skips=args.skips)
        model_fine = AppendVerticesNet(args.netdepth_fine,
                                       args.netwidth_fine,
                                       position_encoder.output_dim * 3,
                                       direction_encoder.output_dim * 3,
                                       6890,
                                       additional_input_layers=1,
                                       skips=args.skips_fine)
        smpl_estimator = DummySmplEstimatorModel(train_data.goal_poses,
                                                 train_data.betas)
        smpl_file_name = "SMPLs/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl"
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        solver = AppendVerticesSolver(model_coarse, model_fine, smpl_estimator,
                                      smpl_model, position_encoder,
                                      direction_encoder, args,
                                      torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'vertex_sphere':
        solver = VertexSphereSolver(model_coarse, model_fine, position_encoder,
                                    direction_encoder, args, torch.optim.Adam,
                                    torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'smpl_estimator':

        model = SmplEstimator(human_size=len(args.human_joints))

        solver = SmplEstimatorSolver(model, args, torch.optim.Adam,
                                     torch.nn.MSELoss())
        solver.train(train_loader, val_loader)
        save_run(solver.writer.log_dir, [model], ['model_smpl_estimator.pt'],
                 parser)
    elif args.model_type == "dummy_dynamic":
        smpl_file_name = "SMPLs/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl"
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        smpl_estimator = DummySmplEstimatorModel(train_data.goal_poses,
                                                 train_data.betas)
        solver = DynamicSolver(model_fine, model_coarse, smpl_estimator,
                               smpl_model, position_encoder, direction_encoder,
                               args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, smpl_estimator],
                 ['model_coarse.pt', 'model_fine.pt', 'smpl_estimator.pt'],
                 parser)
    elif args.model_type == "image_wise_dynamic":
        if args.load_coarse_model is not None:
            print("Load model..")
            model_coarse.load_state_dict(
                torch.load(args.load_coarse_model,
                           map_location=torch.device(device)))
            for params in model_coarse.parameters():
                params.requires_grad = False
            model_coarse.eval()
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=0)
        val_loader = torch.utils.data.DataLoader(val_data,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=0)
        smpl_file_name = "SMPLs/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl"
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        solver = ImageWiseSolver(model_coarse, model_fine, smpl_estimator,
                                 smpl_model, position_encoder,
                                 direction_encoder, args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, smpl_estimator],
                 ['model_coarse.pt', 'model_fine.pt', 'smpl_estimator.pt'],
                 parser)
Example #26
    model_params = dict(model_path=args.get('model_folder'),
                        #  joint_mapper=joint_mapper,
                        create_global_orient=True,
                        create_body_pose=not args.get('use_vposer'),
                        create_betas=True,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        create_expression=True,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        dtype=dtype,
                        **args)

    model = smplx.create(**model_params)
    model = model.to(device=device)

    batch_size = args.get('batch_size', 1)
    use_vposer = args.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    vposer_ckpt = args.get('vposer_ckpt', '')
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype, device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()
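
    # A hedged sketch of the typical next step (mirroring the decode() usage in
    # Example #27): turn the VPoser latent into an axis-angle body pose and
    # feed it to the SMPL-X model.
    if use_vposer:
        body_pose = vposer.decode(pose_embedding,
                                  output_type='aa').view(batch_size, -1)  # [bs, 63]
        model_output = model(return_verts=True, body_pose=body_pose)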
Example #27
def optimize_visulize():
    # read scene mesh, scene sdf
    scene, cur_scene_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch = read_mesh_sdf(
        args.dataset_path, args.dataset, args.scene_name)
    smplx_model = smplx.create(args.smplx_model_path,
                               model_type='smplx',
                               gender='neutral',
                               ext='npz',
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=1).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(args.vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')

    ##################### load optimization results ##################
    shift_list = np.load('{}/{}/shift_list.npy'.format(
        args.optimize_result_dir, args.scene_name))
    rot_angle_list_1 = np.load('{}/{}/rot_angle_list_1.npy'.format(
        args.optimize_result_dir, args.scene_name))

    if args.optimize:
        body_params_opt_list_s1 = np.load(
            '{}/{}/body_params_opt_list_s1.npy'.format(
                args.optimize_result_dir, args.scene_name))
        body_params_opt_list_s2 = np.load(
            '{}/{}/body_params_opt_list_s2.npy'.format(
                args.optimize_result_dir, args.scene_name))
    body_verts_sample_list = np.load('{}/{}/body_verts_sample_list.npy'.format(
        args.optimize_result_dir, args.scene_name))
    n_sample = len(body_verts_sample_list)

    ########################## evaluation (contact/collision score) #########################
    loss_non_collision_sample, loss_contact_sample = 0, 0
    loss_non_collision_opt_s1, loss_contact_opt_s1 = 0, 0
    loss_non_collision_opt_s2, loss_contact_opt_s2 = 0, 0
    body_params_prox_list_s1, body_params_prox_list_s2 = [], []
    body_verts_opt_prox_s2_list = []

    for cnt in tqdm(range(0, n_sample)):
        body_verts_sample = body_verts_sample_list[cnt]  # [10475, 3]

        # smplx params --> body mesh
        body_params_opt_s1 = torch.from_numpy(
            body_params_opt_list_s1[cnt]).float().unsqueeze(0).to(
                device)  # [1,75]
        body_params_opt_s1 = convert_to_3D_rot(
            body_params_opt_s1)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(body_params_opt_s1[:, 16:48],
                                              output_type='aa').view(
                                                  1, -1)  # [1, 63]
        body_verts_opt_s1 = gen_body_mesh(body_params_opt_s1, body_pose_joint,
                                          smplx_model)[0]  # [n_body_vert, 3]
        body_verts_opt_s1 = body_verts_opt_s1.detach().cpu().numpy()

        body_params_opt_s2 = torch.from_numpy(
            body_params_opt_list_s2[cnt]).float().unsqueeze(0).to(device)
        body_params_opt_s2 = convert_to_3D_rot(
            body_params_opt_s2)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(body_params_opt_s2[:, 16:48],
                                              output_type='aa').view(1, -1)
        body_verts_opt_s2 = gen_body_mesh(body_params_opt_s2, body_pose_joint,
                                          smplx_model)[0]
        body_verts_opt_s2 = body_verts_opt_s2.detach().cpu().numpy()

        ####################### transform local body verts to prox coordinate system ####################
        # generated body verts from cvae, before optimization
        body_verts_sample_prox = np.zeros(
            body_verts_sample.shape)  # [10475, 3]
        temp = body_verts_sample - shift_list[cnt]
        body_verts_sample_prox[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 2] = temp[:, 2]

        ######### optimized body verts
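        # trans_matrix_1 is a rotation about the z-axis by -rot_angle and
        # trans_matrix_2 a translation by -shift; applied in sequence below,
        # they map the local fitting frame back into the PROX scene frame.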
        trans_matrix_1 = np.array(
            [[math.cos(-rot_angle_list_1[cnt]), -math.sin(-rot_angle_list_1[cnt]), 0, 0],
             [math.sin(-rot_angle_list_1[cnt]), math.cos(-rot_angle_list_1[cnt]), 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]])
        trans_matrix_2 = np.array([[1, 0, 0, -shift_list[cnt][0]],
                                   [0, 1, 0, -shift_list[cnt][1]],
                                   [0, 0, 1, -shift_list[cnt][2]],
                                   [0, 0, 0, 1]])
        ### stage 1: simple optimization results
        body_verts_opt_prox_s1 = np.zeros(
            body_verts_opt_s1.shape)  # [10475, 3]
        temp = body_verts_opt_s1 - shift_list[cnt]
        body_verts_opt_prox_s1[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 2] = temp[:, 2]
        # transform local params to prox coordinate system
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_opt_s1[0].cpu().numpy(), smplx_model,
            trans_matrix_2)  # [72]
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_prox_s1, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s1.append(body_params_prox_s1)

        ### stage 2: advanced optimization results
        body_verts_opt_prox_s2 = np.zeros(
            body_verts_opt_s2.shape)  # [10475, 3]
        temp = body_verts_opt_s2 - shift_list[cnt]
        body_verts_opt_prox_s2[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 2] = temp[:, 2]
        # transform local params to prox coordinate system
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_opt_s2[0].cpu().numpy(), smplx_model,
            trans_matrix_2)  # [72]
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_prox_s2, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s2.append(body_params_prox_s2)
        body_verts_opt_prox_s2_list.append(body_verts_opt_prox_s2)

        ########################### visualization ##########################
        if args.visualize:
            body_mesh_sample = o3d.geometry.TriangleMesh()
            body_mesh_sample.vertices = o3d.utility.Vector3dVector(
                body_verts_sample_prox)
            body_mesh_sample.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_sample.compute_vertex_normals()

            body_mesh_opt_s1 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s1.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s1)
            body_mesh_opt_s1.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s1.compute_vertex_normals()

            body_mesh_opt_s2 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s2.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s2)
            body_mesh_opt_s2.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s2.compute_vertex_normals()

            # generated body mesh by cvae
            o3d.visualization.draw_geometries([scene, body_mesh_sample])
            # simple-optimized body mesh
            o3d.visualization.draw_geometries([scene, body_mesh_opt_s1])
            # adv-optimized body mesh
            o3d.visualization.draw_geometries([scene, body_mesh_opt_s2])

        #####################  compute non-collision/contact score ##############
        # body verts before optimization
        body_verts_sample_prox_tensor = torch.from_numpy(
            body_verts_sample_prox).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
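        # Normalize vertex coordinates into the SDF grid's [-1, 1] range and
        # reorder the xyz components ([2, 1, 0]) to match the axis layout that
        # F.grid_sample expects for this volume before sampling SDF values.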
        norm_verts_batch = (body_verts_sample_prox_tensor - s_grid_min_batch
                            ) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        # no interpenetration: fewer than one negative sdf entry
        if body_sdf_batch.lt(0).sum().item() < 1:
            loss_non_collision_sample += 1.0
            loss_contact_sample += 0.0
        else:
            loss_non_collision_sample += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_sample += 1.0

        # stage 1: simple optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s1).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / (
            s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        # no interpenetration: fewer than one negative sdf entry
        if body_sdf_batch.lt(0).sum().item() < 1:
            loss_non_collision_opt_s1 += 1.0
            loss_contact_opt_s1 += 0.0
        else:
            loss_non_collision_opt_s1 += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_opt_s1 += 1.0

        # stage 2: advanced optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s2).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / (
            s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        # no interpenetration: fewer than one negative sdf entry
        if body_sdf_batch.lt(0).sum().item() < 1:
            loss_non_collision_opt_s2 += 1.0
            loss_contact_opt_s2 += 0.0
        else:
            loss_non_collision_opt_s2 += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_opt_s2 += 1.0

    print('scene:', args.scene_name)

    loss_non_collision_sample = loss_non_collision_sample / n_sample
    loss_contact_sample = loss_contact_sample / n_sample
    print('w/o optimization body: non_collision score:',
          loss_non_collision_sample)
    print('w/o optimization body: contact score:', loss_contact_sample)

    loss_non_collision_opt_s1 = loss_non_collision_opt_s1 / n_sample
    loss_contact_opt_s1 = loss_contact_opt_s1 / n_sample
    print('optimized body s1: non_collision score:', loss_non_collision_opt_s1)
    print('optimized body s1: contact score:', loss_contact_opt_s1)

    loss_non_collision_opt_s2 = loss_non_collision_opt_s2 / n_sample
    loss_contact_opt_s2 = loss_contact_opt_s2 / n_sample
    print('optimized body s2: non_collision score:', loss_non_collision_opt_s2)
    print('optimized body s2: contact score:', loss_contact_opt_s2)
Example #28
    def __init__(self, opt):

        BaseModel.initialize(self, opt)

        # set params
        self.inputSize = opt.inputSize

        self.single_branch = opt.single_branch
        self.two_branch = opt.two_branch
        self.aux_as_main = opt.aux_as_main
        assert (not self.single_branch and self.two_branch) or (
            self.single_branch and not self.two_branch)
        if self.aux_as_main:
            assert self.single_branch

        if opt.isTrain and opt.process_rank <= 0:
            if self.two_branch:
                print("!!!!!!!!!!!! Attention, use two branch framework")
                time.sleep(10)
            else:
                print("!!!!!!!!!!!! Attention, use one branch framework")

        self.total_params_dim = opt.total_params_dim
        self.cam_params_dim = opt.cam_params_dim
        self.pose_params_dim = opt.pose_params_dim
        self.shape_params_dim = opt.shape_params_dim
        self.top_finger_joints_type = opt.top_finger_joints_type

        assert(self.total_params_dim ==
               self.cam_params_dim+self.pose_params_dim+self.shape_params_dim)

        if opt.dist:
            self.batch_size = opt.batchSize // torch.distributed.get_world_size()
        else:
            self.batch_size = opt.batchSize
        nb = self.batch_size

        # set input image and 2d keypoints
        self.input_img = self.Tensor(
            nb, opt.input_nc, self.inputSize, self.inputSize)
      
        # joints 2d
        self.keypoints = self.Tensor(nb, opt.num_joints, 2)
        self.keypoints_weights = self.Tensor(nb, opt.num_joints)

        # mano pose params
        self.gt_pose_params = self.Tensor(nb, opt.pose_params_dim)
        self.mano_params_weight = self.Tensor(nb, 1)

        # joints 3d
        self.joints_3d = self.Tensor(nb, opt.num_joints, 3)
        self.joints_3d_weight = self.Tensor(nb, opt.num_joints, 1)

        # load mean params, the mean params are from HMR
        self.mean_param_file = osp.join(
            opt.model_root, opt.mean_param_file)
        self.load_params()

        # set differential SMPL (implemented with pytorch) and smpl_renderer
        # smplx_model_path = osp.join(opt.model_root, opt.smplx_model_file)
        smplx_model_path = opt.smplx_model_file
        self.smplx = smplx.create(
            smplx_model_path,
            model_type="smplx",
            batch_size=self.batch_size,
            gender='neutral',
            num_betas=10,
            use_pca=False,
            ext='pkl').cuda()

        # set encoder and optimizer
        self.encoder = H3DWEncoder(opt, self.mean_params).cuda()
        if opt.dist:
            self.encoder = DistributedDataParallel(
                self.encoder, device_ids=[torch.cuda.current_device()])
        if self.isTrain:
            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=opt.lr_e)
        
        # load pretrained / trained weights for encoder
        if self.isTrain:
            assert False, "Not implemented Yet"
            pass
        else:
            # load trained model for testing
            which_epoch = opt.which_epoch
            if which_epoch == 'latest' or int(which_epoch) > 0:
                self.success_load = self.load_network(self.encoder, 'encoder', which_epoch)
            else:
                checkpoint_path = opt.checkpoint_path
                if not osp.exists(checkpoint_path):
                    print(f"Error: {checkpoint_path} does not exist")
                    self.success_load = False
                else:
                    if self.opt.dist:
                        self.encoder.module.load_state_dict(torch.load(
                            checkpoint_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())))
                    else:
                        saved_weights = torch.load(checkpoint_path)
                        self.encoder.load_state_dict(saved_weights)
                    self.success_load = True
Example #29
    proxd_path = os.path.join(dataset_path, 'PROXD')
    cam2world_path = os.path.join(dataset_path, 'cam2world')
    vposer_model_path = os.path.join(dataset_path, 'body_models/vposer_v1_0')
    smplx_model_path = os.path.join(dataset_path, 'body_models/smplx_model')

    scene_mesh_path = os.path.join(dataset_path, 'scenes_downsampled')
    scene_sdf_path = os.path.join(dataset_path, 'scenes_sdf')

    smplx_model = smplx.create(smplx_model_path, model_type='smplx',
                               gender='neutral', ext='npz',
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=batch_size
                               ).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')


    ######## set body mesh gen dataloader ###########
Example #30
def save_grab_vertices(cfg, logger=None, **params):

    grab_path = cfg.grab_path
    out_path = cfg.out_path
    if out_path is None:
        out_path = grab_path
    makepath(out_path)

    if logger is None:
        logger = makelogger(log_dir=os.path.join(out_path,
                                                 'grab_preprocessing.log'),
                            mode='a').info
    logger('Starting to get vertices for GRAB!')

    all_seqs = glob.glob(grab_path + '/*/*.npz')

    logger('Total sequences: %d' % len(all_seqs))

    # stime = datetime.now().replace(microsecond=0)
    # shutil.copy2(sys.argv[0],
    #              os.path.join(out_path,
    #                           os.path.basename(sys.argv[0]).replace('.py','_%s.py' % datetime.strftime(stime,'%Y%m%d_%H%M'))))

    for sequence in tqdm(all_seqs):

        outfname = makepath(sequence.replace(grab_path, out_path).replace(
            '.npz', '_verts_body.npz'),
                            isfile=True)

        action_name = os.path.basename(sequence)
        if os.path.exists(outfname):
            logger('Results for %s split already exist.' % (action_name))
            continue
        else:
            logger('Processing data for %s split.' % (action_name))

        seq_data = parse_npz(sequence)
        n_comps = seq_data['n_comps']
        gender = seq_data['gender']

        T = seq_data.n_frames

        if cfg.save_body_verts:

            sbj_mesh = os.path.join(grab_path, '..', seq_data.body.vtemp)
            sbj_vtemp = np.array(Mesh(filename=sbj_mesh).vertices)

            sbj_m = smplx.create(model_path=cfg.model_path,
                                 model_type='smplx',
                                 gender=gender,
                                 num_pca_comps=n_comps,
                                 v_template=sbj_vtemp,
                                 batch_size=T)

            sbj_parms = params2torch(seq_data.body.params)
            verts_sbj = to_cpu(sbj_m(**sbj_parms).vertices)
            np.savez_compressed(outfname, verts_body=verts_sbj)

        if cfg.save_lhand_verts:
            lh_mesh = os.path.join(grab_path, '..', seq_data.lhand.vtemp)
            lh_vtemp = np.array(Mesh(filename=lh_mesh).vertices)

            lh_m = smplx.create(model_path=cfg.model_path,
                                model_type='mano',
                                is_rhand=False,
                                v_template=lh_vtemp,
                                num_pca_comps=n_comps,
                                flat_hand_mean=True,
                                batch_size=T)

            lh_parms = params2torch(seq_data.lhand.params)
            verts_lh = to_cpu(lh_m(**lh_parms).vertices)
            np.savez_compressed(outfname.replace('_verts_body.npz',
                                                 '_verts_lhand.npz'),
                                verts_lhand=verts_lh)

        if cfg.save_rhand_verts:
            rh_mesh = os.path.join(grab_path, '..', seq_data.rhand.vtemp)
            rh_vtemp = np.array(Mesh(filename=rh_mesh).vertices)

            rh_m = smplx.create(model_path=cfg.model_path,
                                model_type='mano',
                                is_rhand=True,
                                v_template=rh_vtemp,
                                num_pca_comps=n_comps,
                                flat_hand_mean=True,
                                batch_size=T)

            rh_parms = params2torch(seq_data.rhand.params)
            verts_rh = to_cpu(rh_m(**rh_parms).vertices)
            np.savez_compressed(outfname.replace('_verts_body.npz',
                                                 '_verts_rhand.npz'),
                                verts_rhand=verts_rh)

        if cfg.save_object_verts:

            obj_mesh = os.path.join(grab_path, '..',
                                    seq_data.object.object_mesh)
            obj_vtemp = np.array(Mesh(filename=obj_mesh).vertices)
            sample_id = np.random.choice(obj_vtemp.shape[0],
                                         cfg.n_verts_sample,
                                         replace=False)
            obj_m = ObjectModel(v_template=obj_vtemp[sample_id], batch_size=T)
            obj_parms = params2torch(seq_data.object.params)
            verts_obj = to_cpu(obj_m(**obj_parms).vertices)
            np.savez_compressed(outfname.replace('_verts_body.npz',
                                                 '_verts_object.npz'),
                                verts_object=verts_obj)

        logger('Processing finished')