Example #1
import torch
from human_body_prior.tools.model_loader import load_vposer


def get_vposer_mean_pose(vposer_dir="misc/vposer_dir"):
    pose_embedding = torch.zeros([1, 32])
    vposer, _ = load_vposer(vposer_dir, vp_model="snapshot")
    vposer.eval()

    mean_pose = (vposer.decode(pose_embedding,
                               output_type="aa").contiguous().view(1, -1))
    return mean_pose
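A hedged usage sketch: the decoded mean pose is the 63-D axis-angle body pose (21 SMPL-X body joints), which can be fed straight into an SMPL-X model as in the other examples; the model path below is hypothetical.

import smplx

mean_pose = get_vposer_mean_pose("misc/vposer_dir")              # [1, 63] axis-angle
body_model = smplx.create("assets/models", model_type="smplx")   # hypothetical model path
output = body_model(body_pose=mean_pose, return_verts=True)
print(output.vertices.shape)                                     # torch.Size([1, 10475, 3])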
Example #2
    def __init__(self, trainconfig, lossconfig):
        for key, val in trainconfig.items():
            setattr(self, key, val)


        for key, val in lossconfig.items():
            setattr(self, key, val)

        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

        if len(self.ckp_dir) > 0:
            self.resume_training = True

        ### define model

        if self.use_cont_rot:
            n_dim_body = 72 + 3
        else:
            n_dim_body = 72

        self.model_h_latentD = 256
        self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                                     latentD_l=self.model_h_latentD,
                                     scene_model_ckpt=self.scene_model_ckpt,
                                     n_dim_body=n_dim_body,
                                     n_dim_scene=self.model_h_latentD)

        self.optimizer_h = optim.Adam(self.model_h.parameters(), 
                                      lr=self.init_lr_h)


        ### body mesh model
        self.vposer, _ = load_vposer(self.vposer_ckpt_path, vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path, 
                                            model_type='smplx',
                                            gender='neutral', ext='npz',
                                            num_pca_comps=12,
                                            create_global_orient=True,
                                            create_body_pose=True,
                                            create_betas=True,
                                            create_left_hand_pose=True,
                                            create_right_hand_pose=True,
                                            create_expression=True,
                                            create_jaw_pose=True,
                                            create_leye_pose=True,
                                            create_reye_pose=True,
                                            create_transl=True,
                                            batch_size=self.batch_size
                                            )

        self.smplx_face_idx = np.load(os.path.join(self.human_model_path, 
                                        'smplx/SMPLX_NEUTRAL.npz'),
                                    allow_pickle=True)['f'].reshape(-1,3)
        self.smplx_face_idx = torch.tensor(self.smplx_face_idx.astype(np.int64), 
                                            device=self.device)

        print('--[INFO] device: ' + str(torch.cuda.get_device_name(self.device)))
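The constructor above pulls everything it needs from two plain dicts; a purely illustrative sketch of what they might contain (keys inferred from the attributes read above, all values hypothetical):

trainconfig = {
    'save_dir': 'checkpoints/run1',
    'ckp_dir': '',                               # a non-empty path would set resume_training
    'use_cont_rot': True,                        # selects the 72+3-dim body representation
    'init_lr_h': 1e-4,
    'scene_model_ckpt': 'checkpoints/scene_model.pt',
    'vposer_ckpt_path': 'checkpoints/vposer_v1_0',
    'human_model_path': 'models',
    'batch_size': 32,
    'device': torch.device('cuda:0'),
}
lossconfig = {'weight_kld': 0.1}                 # hypothetical; the loss keys are not shown above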
Example #3
    def __init__(self, fittingconfig, lossconfig):


        for key, val in fittingconfig.items():
            setattr(self, key, val)


        for key, val in lossconfig.items():
            setattr(self, key, val)


        self.vposer, _ = load_vposer(self.vposer_ckpt_path, 
                                     vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path, 
                                       model_type='smplx',
                                       gender='neutral', ext='npz',
                                       num_pca_comps=12,
                                       create_global_orient=True,
                                       create_body_pose=True,
                                       create_betas=True,
                                       create_left_hand_pose=True,
                                       create_right_hand_pose=True,
                                       create_expression=True,
                                       create_jaw_pose=True,
                                       create_leye_pose=True,
                                       create_reye_pose=True,
                                       create_transl=True,
                                       batch_size=self.batch_size
                                       )
        self.vposer.to(self.device)
        self.body_mesh_model.to(self.device)

        # torch.autograd.Variable is deprecated; a plain tensor with requires_grad=True is equivalent here
        self.xhr_rec = torch.randn(1, 75, device=self.device, requires_grad=True)
        self.optimizer = optim.Adam([self.xhr_rec], lr=self.init_lr_h)




        ## read scene sdf
        with open(self.scene_sdf_path + '.json') as f:
            sdf_data = json.load(f)
            grid_min = np.array(sdf_data['min'])
            grid_max = np.array(sdf_data['max'])
            grid_dim = sdf_data['dim']
        sdf = np.load(self.scene_sdf_path + '_sdf.npy').reshape(grid_dim, grid_dim, grid_dim)

        self.s_grid_min_batch = torch.tensor(grid_min, dtype=torch.float32, device=self.device).unsqueeze(0)
        self.s_grid_max_batch = torch.tensor(grid_max, dtype=torch.float32, device=self.device).unsqueeze(0)
        self.s_sdf_batch = torch.tensor(sdf, dtype=torch.float32, device=self.device).unsqueeze(0)

        ## read scene vertices
        scene_o3d = o3d.io.read_triangle_mesh(self.scene_verts_path)
        scene_verts = np.asarray(scene_o3d.vertices)
        self.s_verts_batch = torch.tensor(scene_verts, dtype=torch.float32, device=self.device).unsqueeze(0)
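A hedged sketch of how the SDF grid loaded above is typically queried during fitting; the normalization and grid_sample call mirror the collision/contact scoring in Example #11, and presenting it as this class's exact loss term is an assumption.

import torch
import torch.nn.functional as F

def query_scene_sdf(body_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch):
    # body_verts: [1, N, 3] vertices in the scene/grid frame.
    # Normalize to [-1, 1] and trilinearly sample the SDF volume.
    norm_verts = (body_verts - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
    # grid_sample orders sampling coordinates as (x=W, y=H, z=D), hence the [2, 1, 0] flip.
    return F.grid_sample(s_sdf_batch.unsqueeze(1),
                         norm_verts[:, :, [2, 1, 0]].view(1, -1, 1, 1, 3),
                         padding_mode='border')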
Example #4
def sample_vposer(expr_dir, bm, num_samples=5, vp_model='snapshot'):
    import os
    from human_body_prior.tools.omni_tools import id_generator, makepath
    from human_body_prior.tools.model_loader import load_vposer
    from human_body_prior.tools.omni_tools import copy2cpu

    vposer_pt, ps = load_vposer(expr_dir, vp_model=vp_model)

    sampled_pose_body = copy2cpu(vposer_pt.sample_poses(num_poses=num_samples))

    out_dir = makepath(
        os.path.join(ps.work_dir, 'evaluations', 'pose_generation'))
    out_imgpath = os.path.join(out_dir, '%s.png' % id_generator(6))

    dump_vposer_samples(bm, sampled_pose_body, out_imgpath)
    print('Dumped samples at %s' % out_dir)
    return sampled_pose_body
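A hedged usage sketch: expr_dir is a trained VPoser experiment directory and bm a body model compatible with dump_vposer_samples (assumed available in the surrounding module); the human_body_prior BodyModel wrapper is used here and both paths are hypothetical.

from human_body_prior.body_model.body_model import BodyModel

bm = BodyModel('models/smplx/SMPLX_NEUTRAL.npz')                       # hypothetical model path
sampled = sample_vposer('checkpoints/vposer_v1_0', bm, num_samples=5)  # hypothetical expr_dir
print(sampled.shape)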
Example #5
    def __init__(
        self,
        debug=True,
        batch_size=0,
        hand_pca_nb=6,
        head_center_idx=8949,
        smpl_root="assets/models",
        vposer_dir="assets/vposer",
        vposer_dim=32,
        parts_path="assets/models/smplx/smplx_parts_segm.pkl",
        mano_corresp_path="assets/models/MANO_SMPLX_vertex_ids.pkl",
    ):
        super().__init__()
        self.debug = debug
        self.hand_pca_nb = hand_pca_nb
        self.head_center_idx = head_center_idx
        self.smplx_vertex_nb = 10475
        self.vposer_dim = vposer_dim

        # Initialize SMPL-X model
        self.smpl_model = vposeutils.get_smplx(model_root=smpl_root,
                                               vposer_dir=vposer_dir)
        self.smpl_f = self.smpl_model.faces
        # Get vposer
        self.vposer = load_vposer(vposer_dir, vp_model="snapshot")[0]
        self.vposer.eval()

        # Translate human so that head is at camera level
        self.set_head2cam_trans()

        self.armfaces = smplvis.filter_parts(self.smpl_f, parts_path)
        with open(mano_corresp_path, "rb") as p_f:
            self.mano_corresp = pickle.load(p_f)

        # Initialize model parameters
        self.batch_size = batch_size
        left_hand_pose = self.smpl_model.left_hand_pose.repeat(batch_size, 1)
        right_hand_pose = self.smpl_model.right_hand_pose.repeat(batch_size, 1)
        pose_embedding = (
            self.get_neutral_pose_embedding().unsqueeze(0).repeat(
                batch_size, 1))
        self.pose_embedding = torch.nn.Parameter(pose_embedding,
                                                 requires_grad=True)
        self.left_hand_pose = torch.nn.Parameter(left_hand_pose,
                                                 requires_grad=True)
        self.right_hand_pose = torch.nn.Parameter(right_hand_pose,
                                                  requires_grad=True)
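A hedged sketch of how this module would typically turn the learnable pose_embedding into an SMPL-X body pose in its forward pass, mirroring the vposer.decode(..., output_type="aa") pattern used throughout these examples; whether the original class defines such a helper is an assumption.

    def get_body_pose(self):
        # Decode the (batch_size, 32) VPoser latent into a (batch_size, 63) axis-angle body pose.
        return self.vposer.decode(self.pose_embedding,
                                  output_type="aa").view(self.batch_size, -1)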
Example #6
def main(**args):
    input_media = args.pop('input_media')
    config = easy_configuration.configure(**args)
    device = torch.device('cpu')
    dtype = torch.float32

    body_model = config['neutral_model']
    body_model.to(device=device)
    camera = config['camera']

    vposer_ckpt = args['vposer_ckpt']
    vposer_ckpt = osp.expandvars(vposer_ckpt)
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
    vposer = vposer.to(device=device)
    vposer.eval()

    viewer = MeshViewer()


    body_color = (1.0, 1.0, 0.9, 1.0)
    img_np = cv2.imread(input_media)
    img = get_img(input_media)
    # Keypoints for the first person
    keypoints = get_keypoints(img_np)[[0]]

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    gt_joints = gt_joints.to(device=device, dtype=dtype)


    def render_embedding(vposer, pose_embedding, body_model, viewer):
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(1, -1)
        body_pose = body_pose.to(device=device)  # .to() returns a new tensor; keep the result
        body_model_output = body_model(body_pose=body_pose)

        vertices = body_model_output.vertices.detach().cpu().numpy().squeeze()

        viewer.update_mesh(vertices, body_model.faces)

    while True:
        pose_embedding = torch.randn([1, 32], dtype=torch.float32, device=device) * 10929292929
        render_embedding(vposer, pose_embedding, body_model, viewer)
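Note that the 10929292929 multiplier above pushes the latent roughly ten orders of magnitude outside VPoser's unit-Gaussian prior, so the viewer cycles through extreme, implausible poses; sampling at unit scale, e.g. pose_embedding = torch.randn([1, 32], dtype=torch.float32, device=device), would instead show bodies drawn from the learned prior.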
Example #7
    def __init__(self, testconfig):
        for key, val in testconfig.items():
            setattr(self, key, val)

        if not os.path.exists(self.ckpt_dir):
            print('--[ERROR] checkpoints do not exist')
            sys.exit()

        #define model
        if self.use_cont_rot:
            n_dim_body = 72 + 3
        else:
            n_dim_body = 72

        self.model_h_latentD = 256
        self.model_h = HumanCVAES2(latentD_g=self.model_h_latentD,
                                   latentD_l=self.model_h_latentD,
                                   n_dim_body=n_dim_body,
                                   n_dim_scene=self.model_h_latentD,
                                   test=True)

        ### body mesh model
        self.vposer, _ = load_vposer(self.vposer_ckpt_path,
                                     vp_model='snapshot')
        self.body_mesh_model = smplx.create(self.human_model_path,
                                            model_type='smplx',
                                            gender='neutral',
                                            ext='npz',
                                            num_pca_comps=12,
                                            create_global_orient=True,
                                            create_body_pose=True,
                                            create_betas=True,
                                            create_left_hand_pose=True,
                                            create_right_hand_pose=True,
                                            create_expression=True,
                                            create_jaw_pose=True,
                                            create_leye_pose=True,
                                            create_reye_pose=True,
                                            create_transl=True,
                                            batch_size=self.n_samples)
Example #8
def load_avakhitov_fits(dp, load_betas=True, load_body_poses=True, load_expressions=False, load_fid_lst=True):
    result = dict()
    for flag, k, fn_no_ext in [
        [load_betas, 'betas', 'betas'],
        [load_body_poses, 'body_poses', 'poses'],
        [load_expressions, 'expressions', 'expressions'],
        [load_fid_lst, 'fid_lst', 'fid_lst']
    ]:
        if flag:
            load_fp = osp.join(dp, f'{fn_no_ext}.npy')
            try:
                loaded = np.load(load_fp)
            except Exception:
                print('Failed to load ' + load_fp)
                raise

            if fn_no_ext == 'poses':
                #load the vposer model
                if loaded.shape[1] == 69:
                    pose_body = loaded[:, 0:32]
                else:
                    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
                    vposer.eval()
                    pose_body_vp = torch.tensor(loaded[:, 0:32])      
                    #convert from vposer to rotation matrices
                    pose_body_mats = vposer.decode(pose_body_vp).reshape(len(loaded), -1, 3, 3).detach().cpu().numpy()
                    pose_body = np.zeros((pose_body_mats.shape[0], 63))
                    for i in range(0, pose_body_mats.shape[0]):
                        for j in range(0, pose_body_mats.shape[1]):
                            rot_vec, jac = cv2.Rodrigues(pose_body_mats[i,j])
                            pose_body[i, 3*j : 3*j+3] = rot_vec.reshape(-1)   
                result[k] = pose_body
                result['global_rvecs'] = loaded[:, -3:]
                result['global_tvecs'] = loaded[:, -6:-3]
                result['n'] = len(loaded)
            else:
                result[k] = loaded
    return result
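A hedged usage sketch; the directory is hypothetical and must contain betas.npy, poses.npy and fid_lst.npy, and the module-level vposer_ckpt path the function relies on must already be set.

fits = load_avakhitov_fits('data/seq01/smplx_fits')     # hypothetical directory
print(fits['n'])                                        # number of frames
print(fits['betas'].shape, fits['fid_lst'].shape)
print(fits['body_poses'].shape)                         # [n, 63] axis-angle (or [n, 32] latents, per the branch above)
print(fits['global_rvecs'].shape, fits['global_tvecs'].shape)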
Example #9
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=batch_size
                               ).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')


    ######## set body mesh gen dataloader ###########
    scene_name = 'MPH1Library'
    dataset = PreprocessLoader(scene_name=scene_name)
    dataset.load_body_params(proxd_path)
    dataset.load_cam_params(cam2world_path)
    dataset.load_scene(scene_mesh_file=scene_mesh_path)
    bps_dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=4,
                                                 drop_last=True)  # drop_last=True: the smplx model needs a predefined batch size

    ######### get body mesh for all samples ###########
    body_verts_list = []
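A hedged sketch of the loop this scaffolding leads into, reusing the params-to-mesh path shown in Example #11 (convert_to_3D_rot and gen_body_mesh are assumed to be the same helpers; the batch layout yielded by PreprocessLoader is an assumption):

    for body_params in bps_dataloader:
        body_params = body_params.float().to(device)                     # assumed [bs, 75] raw params
        body_params = convert_to_3D_rot(body_params)                     # [bs, 72]
        body_pose = vposer_model.decode(body_params[:, 16:48],
                                        output_type='aa').view(batch_size, -1)
        body_verts = gen_body_mesh(body_params, body_pose, smplx_model)  # [bs, n_verts, 3]
        body_verts_list.append(body_verts.detach().cpu().numpy())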
Example #10
                        **args)

    model = smplx.create(**model_params)
    model = model.to(device=device)

    batch_size = args.get('batch_size', 1)
    use_vposer = args.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    vposer_ckpt = args.get('vposer_ckpt', '')
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype, device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    for pkl_path in pkl_paths:
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        if use_vposer:
            with torch.no_grad():
                pose_embedding[:] = torch.tensor(
                    data['body_pose'], device=device, dtype=dtype)

        est_params = {}
        for key, val in data.items():
            if key == 'body_pose' and use_vposer:
                body_pose = vposer.decode(
Example #11
def optimize_visulize():
    # read scene mesh, scene sdf
    scene, cur_scene_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch = read_mesh_sdf(
        args.dataset_path, args.dataset, args.scene_name)
    smplx_model = smplx.create(args.smplx_model_path,
                               model_type='smplx',
                               gender='neutral',
                               ext='npz',
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=1).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(args.vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')

    ##################### load optimization results ##################
    shift_list = np.load('{}/{}/shift_list.npy'.format(
        args.optimize_result_dir, args.scene_name))
    rot_angle_list_1 = np.load('{}/{}/rot_angle_list_1.npy'.format(
        args.optimize_result_dir, args.scene_name))

    if args.optimize:
        body_params_opt_list_s1 = np.load(
            '{}/{}/body_params_opt_list_s1.npy'.format(
                args.optimize_result_dir, args.scene_name))
        body_params_opt_list_s2 = np.load(
            '{}/{}/body_params_opt_list_s2.npy'.format(
                args.optimize_result_dir, args.scene_name))
    body_verts_sample_list = np.load('{}/{}/body_verts_sample_list.npy'.format(
        args.optimize_result_dir, args.scene_name))
    n_sample = len(body_verts_sample_list)

    ########################## evaluation (contact/collision score) #########################
    loss_non_collision_sample, loss_contact_sample = 0, 0
    loss_non_collision_opt_s1, loss_contact_opt_s1 = 0, 0
    loss_non_collision_opt_s2, loss_contact_opt_s2 = 0, 0
    body_params_prox_list_s1, body_params_prox_list_s2 = [], []
    body_verts_opt_prox_s2_list = []

    for cnt in tqdm(range(0, n_sample)):
        body_verts_sample = body_verts_sample_list[cnt]  # [10475, 3]

        # smplx params --> body mesh
        body_params_opt_s1 = torch.from_numpy(
            body_params_opt_list_s1[cnt]).float().unsqueeze(0).to(
                device)  # [1,75]
        body_params_opt_s1 = convert_to_3D_rot(
            body_params_opt_s1)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(body_params_opt_s1[:, 16:48],
                                              output_type='aa').view(
                                                  1, -1)  # [1, 63]
        body_verts_opt_s1 = gen_body_mesh(body_params_opt_s1, body_pose_joint,
                                          smplx_model)[0]  # [n_body_vert, 3]
        body_verts_opt_s1 = body_verts_opt_s1.detach().cpu().numpy()

        body_params_opt_s2 = torch.from_numpy(
            body_params_opt_list_s2[cnt]).float().unsqueeze(0).to(device)
        body_params_opt_s2 = convert_to_3D_rot(
            body_params_opt_s2)  # tensor, [bs=1, 72]
        body_pose_joint = vposer_model.decode(body_params_opt_s2[:, 16:48],
                                              output_type='aa').view(1, -1)
        body_verts_opt_s2 = gen_body_mesh(body_params_opt_s2, body_pose_joint,
                                          smplx_model)[0]
        body_verts_opt_s2 = body_verts_opt_s2.detach().cpu().numpy()

        ####################### transform local body verts to prox coordinate system ####################
        # generated body verts from cvae, before optimization
        body_verts_sample_prox = np.zeros(
            body_verts_sample.shape)  # [10475, 3]
        temp = body_verts_sample - shift_list[cnt]
        body_verts_sample_prox[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_sample_prox[:, 2] = temp[:, 2]

        ######### optimized body verts
        trans_matrix_1 = np.array([[
            math.cos(-rot_angle_list_1[cnt]),
            -math.sin(-rot_angle_list_1[cnt]), 0, 0
        ],
                                   [
                                       math.sin(-rot_angle_list_1[cnt]),
                                       math.cos(-rot_angle_list_1[cnt]), 0, 0
                                   ], [0, 0, 1, 0], [0, 0, 0, 1]])
        trans_matrix_2 = np.array([[1, 0, 0, -shift_list[cnt][0]],
                                   [0, 1, 0, -shift_list[cnt][1]],
                                   [0, 0, 1, -shift_list[cnt][2]],
                                   [0, 0, 0, 1]])
        ### stage 1: simple optimization results
        body_verts_opt_prox_s1 = np.zeros(
            body_verts_opt_s1.shape)  # [10475, 3]
        temp = body_verts_opt_s1 - shift_list[cnt]
        body_verts_opt_prox_s1[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s1[:, 2] = temp[:, 2]
        # transform local params to prox coordinate system
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_opt_s1[0].cpu().numpy(), smplx_model,
            trans_matrix_2)  # [72]
        body_params_prox_s1 = update_globalRT_for_smplx(
            body_params_prox_s1, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s1.append(body_params_prox_s1)

        ### stage 2: advanced optimization results
        body_verts_opt_prox_s2 = np.zeros(
            body_verts_opt_s2.shape)  # [10475, 3]
        temp = body_verts_opt_s2 - shift_list[cnt]
        body_verts_opt_prox_s2[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                       temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                       temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
        body_verts_opt_prox_s2[:, 2] = temp[:, 2]
        # transform local params to prox coordinate system
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_opt_s2[0].cpu().numpy(), smplx_model,
            trans_matrix_2)  # [72]
        body_params_prox_s2 = update_globalRT_for_smplx(
            body_params_prox_s2, smplx_model, trans_matrix_1)  # [72]
        body_params_prox_list_s2.append(body_params_prox_s2)
        body_verts_opt_prox_s2_list.append(body_verts_opt_prox_s2)

        ########################### visualization ##########################
        if args.visualize:
            body_mesh_sample = o3d.geometry.TriangleMesh()
            body_mesh_sample.vertices = o3d.utility.Vector3dVector(
                body_verts_sample_prox)
            body_mesh_sample.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_sample.compute_vertex_normals()

            body_mesh_opt_s1 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s1.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s1)
            body_mesh_opt_s1.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s1.compute_vertex_normals()

            body_mesh_opt_s2 = o3d.geometry.TriangleMesh()
            body_mesh_opt_s2.vertices = o3d.utility.Vector3dVector(
                body_verts_opt_prox_s2)
            body_mesh_opt_s2.triangles = o3d.utility.Vector3iVector(
                smplx_model.faces)
            body_mesh_opt_s2.compute_vertex_normals()

            o3d.visualization.draw_geometries(
                [scene, body_mesh_sample])  # generated body mesh by cvae
            o3d.visualization.draw_geometries([scene, body_mesh_opt_s1
                                               ])  # simple-optimized body mesh
            o3d.visualization.draw_geometries([scene, body_mesh_opt_s2
                                               ])  # adv-optimized body mesh

        #####################  compute non-collision/contact score ##############
        # body verts before optimization
        body_verts_sample_prox_tensor = torch.from_numpy(
            body_verts_sample_prox).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_sample_prox_tensor - s_grid_min_batch
                            ) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        if body_sdf_batch.lt(0).sum().item(
        ) < 1:  # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_sample += 1.0
            loss_contact_sample += 0.0
        else:
            loss_non_collision_sample += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_sample += 1.0

        # stage 1: simple optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s1).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / (
            s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        if body_sdf_batch.lt(0).sum().item(
        ) < 1:  # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_opt_s1 += 1.0
            loss_contact_opt_s1 += 0.0
        else:
            loss_non_collision_opt_s1 += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_opt_s1 += 1.0

        # stage 2: advanced optimization results
        body_verts_opt_prox_tensor = torch.from_numpy(
            body_verts_opt_prox_s2).float().unsqueeze(0).to(
                device)  # [1, 10475, 3]
        norm_verts_batch = (body_verts_opt_prox_tensor - s_grid_min_batch) / (
            s_grid_max_batch - s_grid_min_batch) * 2 - 1
        body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(
                                           -1, 10475, 1, 1, 3),
                                       padding_mode='border')
        if body_sdf_batch.lt(0).sum().item(
        ) < 1:  # no interpenetration: fewer than one negative sdf entry
            loss_non_collision_opt_s2 += 1.0
            loss_contact_opt_s2 += 0.0
        else:
            loss_non_collision_opt_s2 += (body_sdf_batch >
                                          0).sum().float().item() / 10475.0
            loss_contact_opt_s2 += 1.0

    print('scene:', args.scene_name)

    loss_non_collision_sample = loss_non_collision_sample / n_sample
    loss_contact_sample = loss_contact_sample / n_sample
    print('w/o optimization body: non_collision score:',
          loss_non_collision_sample)
    print('w/o optimization body: contact score:', loss_contact_sample)

    loss_non_collision_opt_s1 = loss_non_collision_opt_s1 / n_sample
    loss_contact_opt_s1 = loss_contact_opt_s1 / n_sample
    print('optimized body s1: non_collision score:', loss_non_collision_opt_s1)
    print('optimized body s1: contact score:', loss_contact_opt_s1)

    loss_non_collision_opt_s2 = loss_non_collision_opt_s2 / n_sample
    loss_contact_opt_s2 = loss_contact_opt_s2 / n_sample
    print('optimized body s2: non_collision score:', loss_non_collision_opt_s2)
    print('optimized body s2: contact score:', loss_contact_opt_s2)
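The per-axis trigonometry repeated above is just a translation followed by a rotation about z; a purely illustrative helper that performs the same local-to-PROX transform in one place (not part of the original script):

import numpy as np

def to_prox_coords(verts, shift, rot_angle):
    # Equivalent to the per-axis code above: subtract the shift,
    # then rotate by -rot_angle about the z axis.
    c, s = np.cos(-rot_angle), np.sin(-rot_angle)
    rot_z = np.array([[c, -s, 0.0],
                      [s,  c, 0.0],
                      [0.0, 0.0, 1.0]])
    return (verts - shift) @ rot_z.T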
Example #12
def fit_single_frame(img,
                     keypoints,
                     init_trans,
                     scan,
                     scene_name,
                     body_model,
                     camera,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     body_scene_rendering_fn='body_scene.png',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length_x=5000.,
                     focal_length_y=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     ####################
                     ### PROX
                     render_results=True,
                     camera_mode='moving',
                     ## Depth
                     s2m=False,
                     s2m_weights=None,
                     m2s=False,
                     m2s_weights=None,
                     rho_s2m=1,
                     rho_m2s=1,
                     init_mode=None,
                     trans_opt_stages=None,
                     viz_mode='mv',
                     #penetration
                     sdf_penetration=False,
                     sdf_penetration_weights=0.0,
                     sdf_dir=None,
                     cam2world_dir=None,
                     #contact
                     contact=False,
                     rho_contact=1.0,
                     contact_loss_weights=None,
                     contact_angle=15,
                     contact_body_parts=None,
                     body_segments_dir=None,
                     load_scene=False,
                     scene_dir=None,
                     height=None,
                     weight=None,
                     gender='male',
                     weight_w=0,
                     height_w=0,
                     **kwargs):

    if kwargs['optim_type'] == 'lbfgsls':
        assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    batch_size = keypoints.shape[0]

    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # if visualize:
    #     pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [1, ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = (
        'Number of Body pose prior weights {}'.format(
            len(body_pose_prior_weights)) +
        ' does not match the number of data term weights {}'.format(
            len(data_weights)))
    assert (len(data_weights) ==
            len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (len(hand_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (len(hand_joints_weights) ==
                    len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) ==
            len(body_pose_prior_weights)), msg.format(
                len(body_pose_prior_weights),
                len(shape_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (len(jaw_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) ==
                len(body_pose_prior_weights)), msg.format(
                    len(body_pose_prior_weights),
                    len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) ==
                len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) ==
            len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    if use_vposer:
        # pose_embedding = torch.zeros([batch_size, 32],
        #                              dtype=dtype, device=device,
        #                              requires_grad=True)

        # Patrick: hack to set default body pose to something more sleep-y
        mean_body = np.array([[ 0.19463745,  1.6240447,   0.6890624,   0.19186097,  0.08003145, -0.04189298,
                       3.450903,   -0.29570094,  0.25072002, -1.1879578,   0.33350763,  0.23568614,
                       0.38122794, -2.1258948,   0.2910664,   2.2407222,  -0.5400814,  -0.95984083,
                      -1.2880017,   1.1122228,   0.7411389,  -0.2265636,  -4.8202057,  -1.950323,
                      -0.28771818, -1.9282387,   0.9928907,  -0.27183488, -0.55805033,  0.04047768,
                      -0.537362,    0.65770334]])

        pose_embedding = torch.tensor(mean_body, dtype=dtype, device=device, requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype)
    else:
        # body_mean_pose = body_pose_prior.get_mean().detach().cpu()
        # body_mean_pose = torch.zeros([batch_size, 69], dtype=dtype)

        # mean_body =  np.array([[-2.33263850e-01,  1.35460928e-01,  2.94471830e-01, -3.22930813e-01,
        #                         -4.73931670e-01, -2.67531037e-01,  7.12558180e-02,  7.89440796e-03,
        #                         8.67700949e-03,  1.05982251e-01,  2.79584467e-01, -7.04243258e-02,
        #                         3.61106455e-01, -5.87305248e-01,  1.10897996e-01, -1.68918714e-01,
        #                         -4.60174456e-02,  3.28684039e-02,  5.80525696e-01, -5.11317095e-03,
        #                         -1.57546505e-01,  5.85777402e-01, -8.94948393e-02,  2.24680841e-01,
        #                         1.55473784e-01,  5.38146123e-04,  4.30279821e-02, -4.68525589e-02,
        #                         7.75185153e-02,  7.82282930e-03,  6.74356073e-02,  4.09710407e-02,
        #                         -3.60425897e-02, -4.71813440e-01,  5.02379127e-02,  2.02309843e-02,
        #                         5.29680364e-02,  1.68510173e-02,  2.25090146e-01, -4.52307612e-02,
        #                         7.72185996e-02, -2.17333943e-01,  3.30020368e-01,  4.21866514e-02,
        #                         7.15153441e-02,  3.05950731e-01, -3.63454908e-01, -1.28235269e+00,
        #                         5.09610713e-01,  4.65482563e-01,  1.20263052e+00,  5.56594551e-01,
        #                         -2.24000740e+00,  3.83565158e-01,  5.31355202e-01,  2.21637583e+00,
        #                         -5.63146770e-01, -3.01193684e-01, -4.31942672e-01,  6.85038209e-01,
        #                         3.61178756e-01,  2.76136428e-01, -2.64388829e-01,  0.00000000e+00,
        #                         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        #                         0.00000000e+00]])
        mean_body = np.array(joint_limits.axang_limits_patrick / 180 * np.pi).mean(1)
        body_mean_pose = torch.tensor(mean_body, dtype=dtype).unsqueeze(0)


    betanet = None
    if height is not None:
        betanet = torch.load('models/betanet_old_pytorch.pt')
        betanet = betanet.to(device=device)
        betanet.eval()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2]

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = scan.to(device=device)

    # load pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    # if sdf_penetration:
    #     with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
    #         sdf_data = json.load(f)
    #         grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device)
    #         grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device)
    #         grid_dim = sdf_data['dim']
    #     voxel_size = (grid_max - grid_min) / grid_dim
    #     sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(grid_dim, grid_dim, grid_dim)
    #     sdf = torch.tensor(sdf, dtype=dtype, device=device)
    #     if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
    #         sdf_normals = np.load(osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(grid_dim, grid_dim, grid_dim, 3)
    #         sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
    #     else:
    #         print("Normals not found...")


    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3), dtype=dtype, device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3), dtype=dtype, device=device)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # load vertex ids of contact parts
    contact_verts_ids = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'), 'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(return_verts=True, body_pose= torch.zeros((batch_size, 63), dtype=dtype, device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(-1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)

        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))
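        # ftov is a sparse vertex-to-face incidence matrix; it is handed to the fitting
        # closure below so the contact term can pool per-face quantities (e.g. face
        # normals) back onto body vertices.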

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply'))

            scene.vn = scene.estimate_vertex_normals()

            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype,
                                   device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype,
                                    device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long,
                                   device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {'data_weight': data_weights,
                        'body_pose_weight': body_pose_prior_weights,
                        'shape_weight': shape_weights}
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals)) for vals in
                   zip(*(opt_weights_dict[k] for k in keys
                         if opt_weights_dict[k] is not None))]
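    # Each dict in opt_weights defines one optimization stage: the i-th entries of all the
    # weight lists above are zipped together, which is why the earlier asserts require the
    # lists to have matching lengths.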
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # load indices of the head of smpl-x model
    with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # which initialization mode to choose: similar triangles, mean of the scan, or the average of both
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(body_model, gt_joints, edge_indices,
                                    use_vposer=use_vposer, vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get('model_type', 'smpl'),
                                    focal_length=focal_length_x, dtype=dtype) ) /2.0

    else:
        init_t = fitting.guess_init(body_model, gt_joints, edge_indices,
                                    use_vposer=use_vposer, vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get('model_type', 'smpl'),
                                    focal_length=focal_length_x, dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face, use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               betanet=betanet,
                               height=height,
                               weight=weight,
                               gender=gender,
                               weight_w=weight_w,
                               height_w=height_w,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, viz_mode=viz_mode, **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        _, H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True

            camera_opt_params = [camera.translation, body_model.global_orient]

        elif camera_mode == 'fixed':
            # body_model.reset_params()
            # body_model.transl[:] = torch.tensor(init_t)
            # body_model.body_pose[:] = torch.tensor(body_mean_pose)
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl, body_model.global_orient]

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.norm(gt_joints[:, left_shoulder_idx, :] - gt_joints[:, right_shoulder_idx, :], dim=1)
        try_both_orient = shoulder_dist.min() < side_view_thsh

        kwargs['lr'] *= 10
        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(camera_opt_params, **kwargs)
        kwargs['lr'] /= 10

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer, body_model, camera, gt_joints,
            camera_loss, create_graph=camera_create_graph,
            use_vposer=use_vposer, vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False, return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params, body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections of the shoulder joints are too close,
        # rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            with torch.no_grad():
                flipped_orient = torch.zeros_like(body_model.global_orient)
                for i in range(batch_size):
                    body_orient = body_model.global_orient[i, :].detach().cpu().numpy()
                    local_flip = cv2.Rodrigues(body_orient)[0].dot(cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
                    local_flip = cv2.Rodrigues(local_flip)[0].ravel()

                    flipped_orient[i, :] = torch.Tensor(local_flip).to(device)

            orientations = [body_model.global_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0

        # for or_idx, orient in enumerate(orientations):
        or_idx = 0
        while or_idx < len(orientations):
            global_vars.cur_orientation = or_idx
            orient = orientations[or_idx]
            print('Trying orientation', or_idx, 'of', len(orientations))
            opt_start = time.time()
            or_idx += 1

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)
                    pose_embedding += torch.tensor(mean_body, dtype=dtype, device=device)

            for opt_idx, curr_weights in enumerate(opt_weights):
                global_vars.cur_opt_stage = opt_idx

                if opt_idx not in trans_opt_stages:
                    body_model.transl.requires_grad = False
                else:
                    body_model.transl.requires_grad = True
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params,
                    **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer, body_model,
                    camera=camera, gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss, create_graph=body_create_graph,
                    use_vposer=use_vposer, vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v, scene_vn=scene_vn, scene_f=scene_f,ftov=ftov,
                    return_verts=True, return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure, final_params,
                    body_model,
                    pose_embedding=pose_embedding, vposer=vposer,
                    use_vposer=use_vposer)

                # print('Final loss val', final_loss_val)
                # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
                #     break

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write('Stage {:03d} done after {:.4f} seconds'.format(
                            opt_idx, elapsed))

            # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
            #     print('Optimization FAILURE, retrying')
            #     orientations.append(orientations[or_idx-1] * 0.9)
            #     continue

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write('Body fitting Orientation {} done after {:.4f} seconds'.format(or_idx, elapsed))
                tqdm.write('Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {'camera_' + str(key): val.detach().cpu().numpy()
                      for key, val in camera.named_parameters()}

            result['camera_focal_length_x'] = camera.focal_length_x.detach().cpu().numpy()
            result['camera_focal_length_y'] = camera.focal_length_y.detach().cpu().numpy()
            result['camera_center'] = camera.center.detach().cpu().numpy()

            result.update({key: val.detach().cpu().numpy()
                           for key, val in body_model.named_parameters()})
            if use_vposer:
                result['pose_embedding'] = pose_embedding.detach().cpu().numpy()
                body_pose = vposer.decode(pose_embedding, output_type='aa').view(1, -1) if use_vposer else None

                if "smplx.body_models.SMPL'" in str(type(body_model)):
                    wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device)
                    body_pose = torch.cat([body_pose, wrist_pose], dim=1)

                result['body_pose'] = body_pose.detach().cpu().numpy()
            result['final_loss_val'] = final_loss_val
            result['loss_dict'] = global_vars.cur_loss_dict
            result['betanet_weight'] = global_vars.cur_weight
            result['betanet_height'] = global_vars.cur_height
            result['gt_joints'] = gt_joints.detach().cpu().numpy()
            result['max_joint'] = global_vars.cur_max_joint

            results.append(result)

        for idx, res_folder in enumerate(result_fn):    # Iterate over batch
            pkl_data = {}
            min_loss = np.inf
            all_results = []
            for result in results:  # Iterate over orientations
                sel_res = misc_utils.get_data_from_batched_dict(result, idx, len(result_fn))
                all_results.append(sel_res)

                cost = sel_res['loss_dict']['total'] + sel_res['loss_dict']['pprior'] * 60
                if cost < min_loss:
                    min_loss = cost
                    pkl_data.update(sel_res)

            pkl_data['all_results'] = all_results

            with open(res_folder, 'wb') as result_file:
                pickle.dump(pkl_data, result_file, protocol=2)

            img_s = img[idx, :].detach().cpu().numpy()
            img_s = pil_img.fromarray((img_s * 255).astype(np.uint8))
            img_s.save(out_img_fn[idx])
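# --- Illustrative sketch (not part of the original snippets): loading one of the
# --- result pickles written above and rebuilding the body mesh from it. The paths
# --- ('out.pkl', 'models', 'vposer_v1_0') and the assumption that 'betas',
# --- 'global_orient' and 'transl' are among the stored keys are for illustration only.
import pickle

import numpy as np
import smplx
import torch
from human_body_prior.tools.model_loader import load_vposer


def as_batch(x):
    # stored arrays may have been squeezed, so reshape them to (1, -1)
    return torch.tensor(np.asarray(x), dtype=torch.float32).view(1, -1)


with open('out.pkl', 'rb') as f:
    fitted = pickle.load(f)

vposer, _ = load_vposer('vposer_v1_0', vp_model='snapshot')
vposer.eval()
body_model = smplx.create('models', model_type='smplx', gender='neutral', ext='npz')

# decode the 32-d VPoser latent back to a 63-d axis-angle body pose
body_pose = vposer.decode(as_batch(fitted['pose_embedding']),
                          output_type='aa').contiguous().view(1, -1)
output = body_model(return_verts=True,
                    body_pose=body_pose,
                    betas=as_batch(fitted['betas']),
                    global_orient=as_batch(fitted['global_orient']),
                    transl=as_batch(fitted['transl']))
vertices = output.vertices.detach().cpu().numpy().squeeze()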
Example #13
0
def fit_frames(img,
               keypoints,
               body_model,
               camera=None,
               joint_weights=None,
               body_pose_prior=None,
               jaw_prior=None,
               left_hand_prior=None,
               right_hand_prior=None,
               shape_prior=None,
               expr_prior=None,
               angle_prior=None,
               loss_type="smplify",
               use_cuda=True,
               init_joints_idxs=(9, 12, 2, 5),
               use_face=False,
               use_hands=True,
               data_weights=None,
               body_pose_prior_weights=None,
               hand_pose_prior_weights=None,
               jaw_pose_prior_weights=None,
               shape_weights=None,
               expr_weights=None,
               hand_joints_weights=None,
               face_joints_weights=None,
               depth_loss_weight=1e2,
               interpenetration=False,
               coll_loss_weights=None,
               df_cone_height=0.5,
               penalize_outside=True,
               max_collisions=8,
               point2plane=False,
               part_segm_fn="",
               focal_length=5000.0,
               side_view_thsh=25.0,
               rho=100,
               vposer_latent_dim=32,
               vposer_ckpt="",
               use_joints_conf=False,
               interactive=True,
               visualize=False,
               batch_size=1,
               dtype=torch.float32,
               ign_part_pairs=None,
               left_shoulder_idx=2,
               right_shoulder_idx=5,
               freeze_camera=True,
               **kwargs):
    # assert batch_size == 1, "PyTorch L-BFGS only supports batch_size == 1"

    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]
    if data_weights is None:
        data_weights = [1] * len(body_pose_prior_weights)

    msg = "Number of Body pose prior weights {}".format(
        len(body_pose_prior_weights)
    ) + " does not match the number of data term weights {}".format(
        len(data_weights))
    assert len(data_weights) == len(body_pose_prior_weights), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
        msg = ("Number of Body pose prior weights does not match the" +
               " number of hand pose prior weights")
        assert len(hand_pose_prior_weights) == len(
            body_pose_prior_weights), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ("Number of Body pose prior weights does not match the" +
                   " number of hand joint distance weights")
            assert len(hand_joints_weights) == len(
                body_pose_prior_weights), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
    msg = ("Number of Body pose prior weights = {} does not match the" +
           " number of Shape prior weights = {}")
    assert len(shape_weights) == len(body_pose_prior_weights), msg.format(
        len(shape_weights), len(body_pose_prior_weights))

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ("Number of Body pose prior weights does not match the" +
           " number of collision loss weights")
    assert len(coll_loss_weights) == len(body_pose_prior_weights), msg

    use_vposer = kwargs.get("use_vposer", True)
    vposer, pose_embedding = [None] * 2
    if use_vposer:
        pose_embedding = torch.zeros(
            [batch_size, vposer_latent_dim],
            dtype=dtype,
            device=device,
            requires_grad=True,
        )

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model="snapshot")
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = (
            vposeutils.get_vposer_mean_pose().detach().cpu().numpy())
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(keypoint_data.shape[0],
                                                     -1)
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, "Interpenetration term can only be used with CUDA"
        assert torch.cuda.is_available(), (
            "No CUDA Device! Interpenetration term can only be used" +
            " with CUDA")

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=df_cone_height,
            point2plane=point2plane,
            vectorized=True,
            penalize_outside=penalize_outside,
        )

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, "rb") as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding="latin1")
            faces_segm = face_segm_data["segm"]
            faces_parents = face_segm_data["parents"]
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs,
            ).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        "data_weight": data_weights,
        "body_pose_weight": body_pose_prior_weights,
        "shape_weight": shape_weights,
    }
    if use_face:
        opt_weights_dict["face_weight"] = face_joints_weights
        opt_weights_dict["expr_prior_weight"] = expr_weights
        opt_weights_dict["jaw_prior_weight"] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict["hand_weight"] = hand_joints_weights
        opt_weights_dict["hand_prior_weight"] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict["coll_loss_weight"] = coll_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys
                          if opt_weights_dict[k] is not None))
    ]
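    # Each entry of opt_weights is the weight dictionary for one annealing stage:
    # stage i pairs data_weights[i] with body_pose_prior_weights[i], shape_weights[i],
    # and so on; the loop below just converts every weight to a tensor on the device.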
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    # Hand joints start at index 25, right after the body joints
    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        data_weight = 1000 / H
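        # the 2D joint data term is scaled by 1000 / H so the re-projection error
        # is roughly independent of the input image resolution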
        orientations = [body_model.global_orient.detach().cpu().numpy()]

        # # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc="Orientation")):
            opt_start = time.time()

            new_params = defaultdict(global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc="Stage")):

                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                (
                    body_optimizer,
                    body_create_graph,
                ) = optim_factory.create_optimizer(final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights["data_weight"] = data_weight
                curr_weights["bending_prior_weight"] = (
                    3.17 * curr_weights["body_pose_weight"])
                if use_hands:
                    # joint_weights[:, 25:67] = curr_weights['hand_weight']
                    pass
                if use_face:
                    joint_weights[:, 67:] = curr_weights["face_weight"]
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True,
                    return_full_pose=True,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write(
                            "Stage {:03d} done after {:.4f} seconds".format(
                                opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    "Body fitting Orientation {} done after {:.4f} seconds".
                    format(or_idx, elapsed))
                tqdm.write(
                    "Body final loss val = {:.5f}".format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {
                "camera_" + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result["pose_embedding"] = (
                    pose_embedding.detach().cpu().numpy())
                body_pose = (vposer.decode(
                    pose_embedding, output_type="aa").reshape(
                        pose_embedding.shape[0], -1) if use_vposer else None)
                result["body_pose"] = body_pose.detach().cpu().numpy()

        model_output = body_model(return_verts=True, body_pose=body_pose)
    return model_output, result
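# --- Illustrative sketch (not part of the original snippets): a typical call to
# --- fit_frames followed by exporting the fitted mesh. It assumes img, keypoints,
# --- body_model, camera, joint_weights and the prior modules have already been
# --- constructed as in the other snippets; trimesh and the output filename are
# --- assumptions.
import trimesh

model_output, result = fit_frames(img, keypoints, body_model,
                                  camera=camera,
                                  joint_weights=joint_weights,
                                  body_pose_prior=body_pose_prior,
                                  shape_prior=shape_prior,
                                  angle_prior=angle_prior,
                                  use_hands=False,
                                  use_face=False,
                                  use_joints_conf=True,
                                  use_vposer=True,
                                  vposer_ckpt='vposer_v1_0')
vertices = model_output.vertices.detach().cpu().numpy().squeeze()
out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
out_mesh.export('fitted_body.obj')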
Example #14
0
def main(args):

    scene_name = os.path.abspath(args.gen_folder).split("/")[-1]

    outimg_dir = args.outimg_dir
    if not os.path.exists(outimg_dir):
        os.makedirs(outimg_dir)

    ### setup visualization window
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=960, height=540, visible=True)
    render_opt = vis.get_render_option()
    render_opt.mesh_show_back_face = True

    ### put the scene into the environment
    scene = o3d.io.read_triangle_mesh(
        osp.join(args.prox_dir, scene_name + '.ply'))
    vis.add_geometry(scene)
    vis.update_geometry()

    # put the body into the environment
    vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0')
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')

    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender=args.gender,
                         ext='npz',
                         num_pca_comps=args.num_pca_comps,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=True,
                         create_jaw_pose=True,
                         create_leye_pose=True,
                         create_reye_pose=True,
                         create_transl=True)

    ## create a coordinate frame at the camera location
    # mesh_corn = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # mesh_corn.transform(trans)
    # vis.add_geometry(mesh_corn)
    # vis.update_geometry()
    # print(trans)

    ## create a coordinate frame at the world origin
    # mesh_corn2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
    # vis.add_geometry(mesh_corn2)
    # vis.update_geometry()
    # print(trans)

    cv2.namedWindow("GUI")

    gen_file_list = glob.glob(os.path.join(args.gen_folder, '*'))

    body = o3d.geometry.TriangleMesh()
    vis.add_geometry(body)
    for idx, gen_file in enumerate(gen_file_list):

        with open(gen_file, 'rb') as f:
            param = pickle.load(f)

        cam_ext = param['cam_ext'][0]
        cam_int = param['cam_int'][0]

        body_pose = vposer.decode(torch.tensor(param['body_pose']),
                                  output_type='aa').view(1, -1)
        torch_param = {}

        for key in param.keys():
            if key in ['body_pose', 'camera_rotation', 'camera_translation']:
                continue
            else:
                torch_param[key] = torch.tensor(param[key])

        output = model(return_verts=True, body_pose=body_pose, **torch_param)
        vertices = output.vertices.detach().cpu().numpy().squeeze()

        body.vertices = o3d.utility.Vector3dVector(vertices)
        body.triangles = o3d.utility.Vector3iVector(model.faces)
        body.vertex_normals = o3d.utility.Vector3dVector([])
        body.triangle_normals = o3d.utility.Vector3dVector([])
        body.compute_vertex_normals()
        T_mat = np.eye(4)
        T_mat[1, :] = np.array([0, -1, 0, 0])
        T_mat[2, :] = np.array([0, 0, -1, 0])
        trans = np.dot(cam_ext, T_mat)
        body.transform(trans)
        vis.update_geometry()

        # while True:
        #     vis.poll_events()
        #     vis.update_renderer()
        #     cv2.imshow("GUI", np.random.random([10,10,3]))

        #     # ctr = vis.get_view_control()
        #     # cam_param = ctr.convert_to_pinhole_camera_parameters()
        #     # print(cam_param.extrinsic)

        #     key = cv2.waitKey(15)
        #     if key == 27:
        #         break

        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param = update_cam(cam_param, trans)
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis,
                      outfilename=os.path.join(
                          outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))

        # vis.run()
        # capture_image(vis, outfilename=os.path.join(outimg_dir, 'img_{:06d}_cam1.png'.format(idx)))

        ### setup rendering cam, depth capture, segmentation capture
        ctr = vis.get_view_control()
        cam_param = ctr.convert_to_pinhole_camera_parameters()
        cam_param.extrinsic = trans2_dict[scene_name]
        ctr.convert_from_pinhole_camera_parameters(cam_param)
        vis.poll_events()
        vis.update_renderer()
        capture_image(vis,
                      outfilename=os.path.join(
                          outimg_dir, 'img_{:06d}_cam2.png'.format(idx)))
Example #15
0
def main(args):
    fitting_dir = args.fitting_dir
    recording_name = os.path.abspath(fitting_dir).split("/")[-1]
    fitting_dir = osp.join(fitting_dir, 'results')
    data_dir = args.data_dir
    cam2world_dir = osp.join(data_dir, 'cam2world')
    scene_dir = osp.join(data_dir, 'scenes_semantics')
    recording_dir = osp.join(data_dir, 'recordings', recording_name)
    scene_name = os.path.abspath(recording_dir).split("/")[-1].split("_")[0]

    ## setup the output folder
    output_folder = os.path.join('/mnt/hdd/PROX',
                                 'snapshot_virtualcam_TNoise0.5')
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    ### setup visualization window
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=480, height=270, visible=True)
    render_opt = vis.get_render_option()
    render_opt.mesh_show_back_face = True

    ### put the scene into the visualizer
    scene = o3d.io.read_triangle_mesh(
        osp.join(scene_dir, scene_name + '_withlabels.ply'))
    vis.add_geometry(scene)

    ## get scene 3D scene bounding box
    scene_o = o3d.io.read_triangle_mesh(
        osp.join(scene_dir, scene_name + '.ply'))
    scene_min = scene_o.get_min_bound()  #[x_min, y_min, z_min]
    scene_max = scene_o.get_max_bound()  #[x_max, y_max, z_max]
    # reduce the scene region further, to avoid cameras behind the window
    shift = 0.7
    scene_min = scene_min + shift
    scene_max = scene_max - shift

    ### get the real camera config
    trans_calib = np.eye(4)
    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        trans_calib = np.array(json.load(f))

    ## put the body into the environment
    vposer_ckpt = osp.join(args.model_folder, 'vposer_v1_0')
    vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')

    model = smplx.create(args.model_folder,
                         model_type='smplx',
                         gender='neutral',
                         ext='npz',
                         num_pca_comps=12,
                         create_global_orient=True,
                         create_body_pose=True,
                         create_betas=True,
                         create_left_hand_pose=True,
                         create_right_hand_pose=True,
                         create_expression=False,
                         create_jaw_pose=False,
                         create_leye_pose=False,
                         create_reye_pose=False,
                         create_transl=True)

    rec_count = -1
    sample_rate = 15  # every 15 frames, i.e. 0.5 s at 30 fps
    for img_name in sorted(os.listdir(fitting_dir))[::sample_rate]:

        ## get human body params
        filename = osp.join(fitting_dir, img_name, '000.pkl')
        print('frame: ' + filename)

        if not os.path.exists(filename):
            print('file does not exist. Continue')
            continue

        with open(osp.join(fitting_dir, img_name, '000.pkl'), 'rb') as f:
            body_dict = pickle.load(f)

        if np.sum(np.isnan(body_dict['body_pose'])) > 0:
            continue

        rec_count = rec_count + 1
        ## save depth, semantics and render cam
        outname1 = os.path.join(output_folder, recording_name)
        if not os.path.exists(outname1):
            os.mkdir(outname1)

        ######################### then we obtain the virtual cam ################################

        ## find world coordinate of the human body in the current frame
        body_params_W_list, dT = update_globalRT_for_smplx(
            body_dict, model, vposer, [trans_calib])
        body_T_world = body_params_W_list[0]['transl'][0] + dT

        ## get virtual cams, and transform global_R and global_T to virtual cams
        new_cammat_ext_list0 = []
        new_cammat_ext_list0 = get_new_cams(scene_name,
                                            s_min=scene_min,
                                            s_max=scene_max,
                                            body_T=body_T_world)
        random.shuffle(new_cammat_ext_list0)
        new_cammat_ext_list = new_cammat_ext_list0[:30]

        print('--obtain {:d} cams'.format(len(new_cammat_ext_list)))

        new_cammat_list = [invert_transform(x) for x in new_cammat_ext_list]
        body_params_new_list, _ = update_globalRT_for_smplx(
            body_params_W_list[0], model, vposer, new_cammat_list, delta_T=dT)

        #### capture depth and seg in new cams
        for idx_cam, cam_ext in enumerate(new_cammat_ext_list):

            ### save filename
            outname = os.path.join(
                outname1,
                'rec_frame{:06d}_cam{:06d}.mat'.format(rec_count, idx_cam))

            ## put the render cam to the real cam
            ctr = vis.get_view_control()
            cam_param = ctr.convert_to_pinhole_camera_parameters()
            cam_param = update_render_cam(cam_param, cam_ext)
            ctr.convert_from_pinhole_camera_parameters(cam_param)
            vis.poll_events()
            vis.update_renderer()

            ## get render cam parameters
            cam_dict = {}
            cam_dict['extrinsic'] = cam_param.extrinsic
            cam_dict['intrinsic'] = cam_param.intrinsic.intrinsic_matrix

            ## capture depth image
            depth = np.asarray(vis.capture_depth_float_buffer(do_render=True))
            _h = depth.shape[0]
            _w = depth.shape[1]
            depth0 = depth
            depth_canvas, scaling_factor = data_preprocessing(depth, 'depth')

            ### skip this view when the human body is severely occluded.
            body_is_occluded = is_body_occluded(body_params_new_list[idx_cam],
                                                cam_dict, depth)

            if body_is_occluded:
                print(
                    '-- body is occluded or not in the scene at current view.')
                continue

            ## capture semantics
            seg = np.asarray(vis.capture_screen_float_buffer(do_render=True))
            verid = np.mean(seg * 255 / 5.0, axis=-1)  #.astype(int)
            seg0 = verid
            # verid = cv2.resize(verid, (_w//factor, _h//factor))
            seg_canvas, _ = data_preprocessing(verid, 'seg')

            # pdb.set_trace()

            ## save file to disk
            ot_dict = {}
            ot_dict['scaling_factor'] = scaling_factor
            ot_dict['depth'] = depth_canvas
            ot_dict['depth0'] = depth0
            ot_dict['seg0'] = seg0
            ot_dict['seg'] = seg_canvas
            ot_dict['cam'] = cam_dict
            ot_dict['body'] = body_params_new_list[idx_cam]
            sio.savemat(outname, ot_dict)

    vis.destroy_window()
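# --- Illustrative sketch (not part of the original snippets): reading back one of
# --- the .mat files written by the loop above. scipy.io and the example filename
# --- are assumptions.
import scipy.io as sio

rec = sio.loadmat('rec_frame000000_cam000000.mat')
depth_canvas = rec['depth']            # preprocessed depth map
seg_canvas = rec['seg']                # preprocessed semantic map
scaling_factor = rec['scaling_factor']
# 'cam' and 'body' were nested dicts before saving, so loadmat returns them as
# MATLAB-style structs that need field-wise access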
Example #16
0
if mode == "smplx":
    model_params = dict(
        model_path="assets/models",
        model_type="smplx",
        gender="female",
        # create_body_pose=True,
        dtype=torch.float32,
        use_face=False,
    )
    model = smplx.create(**model_params)
    model.cuda()
else:
    bm_path = "assets/models/smplx/SMPLX_FEMALE.npz"
    bm = BodyModel(bm_path=bm_path, batch_size=1).to("cuda")
    bm.cuda()
vp, ps = load_vposer("assets/vposer")
vp = vp.to("cuda")
vp.eval()

# A 32-dimensional latent code; the zero vector decodes to the mean pose of the VPoser prior
poZ_body_sample = torch.zeros(1, 32).cuda()
pose_body = vp.decode(poZ_body_sample, output_type="aa").view(-1, 63)


pose_embedding = torch.zeros(
    [1, 32], requires_grad=True, device=poZ_body_sample.device
)
pose_embedding = torch.rand(
    [1, 32], requires_grad=True, device=poZ_body_sample.device
)
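# --- Illustrative sketch (not part of the original snippets): the lines above first
# --- build a zero embedding and then overwrite it with a random one; below, an
# --- actual sample from the standard Normal prior is decoded and pushed through the
# --- SMPL-X model ('model' from the smplx branch above).
z_sample = torch.randn(1, 32).cuda()                                            # z ~ N(0, I)
sampled_pose = vp.decode(z_sample, output_type="aa").contiguous().view(1, -1)   # 1 x 63
with torch.no_grad():
    sampled_output = model(return_verts=True, body_pose=sampled_pose)
sampled_verts = sampled_output.vertices.detach().cpu().numpy().squeeze()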
Example #17
0
def fit_single_frame(
        img,
        keypoints,
        init_trans,
        scan,
        scene_name,
        body_model,
        camera,
        joint_weights,
        body_pose_prior,
        jaw_prior,
        left_hand_prior,
        right_hand_prior,
        shape_prior,
        expr_prior,
        angle_prior,
        result_fn='out.pkl',
        mesh_fn='out.obj',
        body_scene_rendering_fn='body_scene.png',
        out_img_fn='overlay.png',
        loss_type='smplify',
        use_cuda=True,
        init_joints_idxs=(9, 12, 2, 5),
        use_face=True,
        use_hands=True,
        data_weights=None,
        body_pose_prior_weights=None,
        hand_pose_prior_weights=None,
        jaw_pose_prior_weights=None,
        shape_weights=None,
        expr_weights=None,
        hand_joints_weights=None,
        face_joints_weights=None,
        depth_loss_weight=1e2,
        interpenetration=True,
        coll_loss_weights=None,
        df_cone_height=0.5,
        penalize_outside=True,
        max_collisions=8,
        point2plane=False,
        part_segm_fn='',
        focal_length_x=5000.,
        focal_length_y=5000.,
        side_view_thsh=25.,
        rho=100,
        vposer_latent_dim=32,
        vposer_ckpt='',
        use_joints_conf=False,
        interactive=True,
        visualize=False,
        save_meshes=True,
        degrees=None,
        batch_size=1,
        dtype=torch.float32,
        ign_part_pairs=None,
        left_shoulder_idx=2,
        right_shoulder_idx=5,
        ####################
        ### PROX
        render_results=True,
        camera_mode='moving',
        ## Depth
        s2m=False,
        s2m_weights=None,
        m2s=False,
        m2s_weights=None,
        rho_s2m=1,
        rho_m2s=1,
        init_mode=None,
        trans_opt_stages=None,
        viz_mode='mv',
        #penetration
        sdf_penetration=False,
        sdf_penetration_weights=0.0,
        sdf_dir=None,
        cam2world_dir=None,
        #contact
        contact=False,
        rho_contact=1.0,
        contact_loss_weights=None,
        contact_angle=15,
        contact_body_parts=None,
        body_segments_dir=None,
        load_scene=False,
        scene_dir=None,
        **kwargs):
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'
    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if visualize:
        pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [
            1,
        ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
           ' does not match the number of data term weights {}'.format(
               len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (
            len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (
                len(hand_joints_weights) == len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format(
        len(shape_weights), len(body_pose_prior_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (
            len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format(
            len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [
        None,
    ] * 2
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype,
                                     device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = torch.tensor(scan.get('points'),
                                   device=device,
                                   dtype=dtype).unsqueeze(0)

    # load pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    if sdf_penetration:
        with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
            sdf_data = json.load(f)
            grid_min = torch.tensor(np.array(sdf_data['min']),
                                    dtype=dtype,
                                    device=device)
            grid_max = torch.tensor(np.array(sdf_data['max']),
                                    dtype=dtype,
                                    device=device)
            grid_dim = sdf_data['dim']
        voxel_size = (grid_max - grid_min) / grid_dim
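        # the scene SDF lives on a regular grid_dim^3 grid; together with voxel_size
        # and the grid bounds it lets the loss look up a signed distance for every
        # body vertex, where negative values indicate penetration into the scene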
        sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(
            grid_dim, grid_dim, grid_dim)
        sdf = torch.tensor(sdf, dtype=dtype, device=device)
        if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
            sdf_normals = np.load(
                osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(
                    grid_dim, grid_dim, grid_dim, 3)
            sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
        else:
            print("Normals not found...")

    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3),
                         dtype=dtype,
                         device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3),
                         dtype=dtype,
                         device=device)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # load vertex ids of contact parts
    contact_verts_ids = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'),
                      'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(return_verts=True,
                              body_pose=torch.zeros((batch_size, 63),
                                                    dtype=dtype,
                                                    device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(
            -1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)

        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))
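        # ftov is a sparse faces-by-vertex incidence matrix of the body mesh; the
        # contact term uses it to aggregate face normals into per-vertex normals on
        # the posed body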

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply'))

            scene.vn = scene.estimate_vertex_normals()

            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype,
                                   device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype,
                                    device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long,
                                   device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys
                          if opt_weights_dict[k] is not None))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # load vertex indices of the head of the SMPL-X model
    with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)
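    # head_mask / body_mask flag which SMPL-X vertices belong to the head and to the
    # rest of the body; they are passed to the loss below so the scan-based terms
    # can treat the two regions differently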

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # which initialization mode to choose: similar triangles, mean of the scan, or the average of both
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(
            body_model,
            gt_joints,
            edge_indices,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            model_type=kwargs.get('model_type', 'smpl'),
            focal_length=focal_length_x,
            dtype=dtype)) / 2.0

    else:
        init_t = fitting.guess_init(body_model,
                                    gt_joints,
                                    edge_indices,
                                    use_vposer=use_vposer,
                                    vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get(
                                        'model_type', 'smpl'),
                                    focal_length=focal_length_x,
                                    dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                viz_mode=viz_mode,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True

            camera_opt_params = [camera.translation, body_model.global_orient]

        elif camera_mode == 'fixed':
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl, body_model.global_orient]

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer,
            body_model,
            camera,
            gt_joints,
            camera_loss,
            create_graph=camera_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False,
            return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close, rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc='Orientation')):
            opt_start = time.time()

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):
                if opt_idx not in trans_opt_stages:
                    body_model.transl.requires_grad = False
                else:
                    body_model.transl.requires_grad = True
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v,
                    scene_vn=scene_vn,
                    scene_f=scene_f,
                    ftov=ftov,
                    return_verts=True,
                    return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write(
                            'Stage {:03d} done after {:.4f} seconds'.format(
                                opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.
                    format(or_idx, elapsed))
                tqdm.write(
                    'Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {
                'camera_' + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result['pose_embedding'] = \
                    pose_embedding.detach().cpu().numpy()
                body_pose = vposer.decode(pose_embedding,
                                          output_type='aa').view(1, -1)
                result['body_pose'] = body_pose.detach().cpu().numpy()

            results.append({'loss': final_loss_val, 'result': result})

        with open(result_fn, 'wb') as result_file:
            if len(results) > 1:
                min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1)
            else:
                min_idx = 0
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)
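        # The pickle written above can be read back with a sketch like the one
        # below (hedged: the available keys depend on which camera / body-model
        # parameters were registered via named_parameters()):
        #
        #   with open(result_fn, 'rb') as f:
        #       fitted = pickle.load(f)
        #   # e.g. keys such as 'betas', 'global_orient' or 'camera_translation'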

    if save_meshes or visualize:
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(
            1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)
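        # VPoser decodes 63 body-pose parameters (21 joints x 3), while the SMPL
        # body_pose expects 69 (23 joints x 3), so six zeros are appended above to
        # stand in for the two wrist joints when the SMPL model is used.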

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        out_mesh.export(mesh_fn)

    if render_results:
        import pyrender

        # common
        H, W = 1080, 1920
        camera_center = np.array([951.30, 536.77])
        camera_pose = np.eye(4)
        camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1,
                                                               1) * camera_pose
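        # Broadcasting the column vector over the identity flips the sign of the
        # 2nd and 3rd rows, i.e. camera_pose becomes diag(1, -1, -1, 1): the usual
        # flip from a computer-vision camera frame to the OpenGL/pyrender one.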
        camera = pyrender.camera.IntrinsicsCamera(fx=1060.53,
                                                  fy=1060.38,
                                                  cx=camera_center[0],
                                                  cy=camera_center[1])
        light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0)

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        ## rendering body
        img = img.detach().cpu().numpy()
        H, W, _ = img.shape

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)
        # for node in light_nodes:
        #     scene.add_node(node)

        scene.add(body_mesh, 'mesh')

        r = pyrender.OffscreenRenderer(viewport_width=W,
                                       viewport_height=H,
                                       point_size=1.0)
        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0

        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        img = pil_img.fromarray((output_img * 255).astype(np.uint8))
        img.save(out_img_fn)

        ## rendering body + scene
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)
        static_scene = trimesh.load(osp.join(scene_dir, scene_name + '.ply'))
        trans = np.linalg.inv(cam2world)
        static_scene.apply_transform(trans)

        static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene)

        scene = pyrender.Scene()
        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)

        scene.add(static_scene_mesh, 'mesh')
        scene.add(body_mesh, 'mesh')

        r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H)
        color, _ = r.render(scene)
        color = color.astype(np.float32) / 255.0
        img = pil_img.fromarray((color * 255).astype(np.uint8))
        img.save(body_scene_rendering_fn)
Example #18
    def __init__(
        self,
        debug=True,
        camintr=None,
        hand_pca_nb=6,
        head_center_idx=8949,
        opt_weights={
            # "hand_prior_weight":[1e2, 5 * 1e1, 1e1, 0.5 * 1e1],
            "hand_prior_weight": [0, 0, 0, 0],
            "hand_weight": [0.0, 0.0, 0.0, 1.0],
            # "body_pose_weight": [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78],
            "body_pose_weight": [0, 0, 0, 0],
            # "shape_weight": [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]},
            "shape_weight": [0, 0, 0, 0],
        },
        mano_corresp_path="assets/models/MANO_SMPLX_vertex_ids.pkl",
        smpl_root="assets/models",
        vposer_dir="assets/vposer",
        vposer_dim=32,
        data_weight=1000 / 256,
        parts_path="assets/models/smplx/smplx_parts_segm.pkl",
    ):
        self.debug = debug
        self.hand_pca_nb = hand_pca_nb
        self.head_center_idx = head_center_idx
        self.smplx_vertex_nb = 10475
        self.vposer_dim = vposer_dim

        # Optim weights
        self.opt_weights = opt_weights
        self.data_weight = data_weight

        # Initialize SMPL-X model
        self.smpl_model = vposeutils.get_smplx(model_root=smpl_root,
                                               vposer_dir=vposer_dir)
        self.smpl_f = self.smpl_model.faces
        with open(mano_corresp_path, "rb") as p_f:
            self.mano_corresp = pickle.load(p_f)
        # Get vposer
        self.vposer = load_vposer(vposer_dir, vp_model="snapshot")[0]
        self.vposer.eval()
        self.tareader = TarReader()

        # Translate human so that head is at camera level
        self.set_head2cam_trans()

        self.armfaces = smplvis.filter_parts(self.smpl_f, parts_path)
        self.joint_weights = initialize.get_joint_weights()
        fx, fy = camintr[0, 0], camintr[1, 1]
        center = camintr[:2, 2]
        rot = torch.eye(3)
        rot[0, 0] = -1
        rot[1, 1] = -1
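        # Negating the x and y axes gives a 180-degree rotation about the camera
        # z-axis, presumably to match the image-space projection convention
        # expected by the fitting code (the same flip used when rendering).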
        self.camera = camera.create_camera(
            focal_length_x=fx,
            focal_length_y=fy,
            center=torch.Tensor(center).unsqueeze(0),
            rotation=rot.unsqueeze(0),
        )

        # Initialize priors
        self.body_pose_prior = prior.create_prior(prior_type="l2")
        self.left_hand_prior = prior.create_prior(
            prior_type="l2",
            use_left_hand=True,
            num_gaussians=hand_pca_nb,
        )
        self.right_hand_prior = prior.create_prior(
            prior_type="l2",
            use_right_hand=True,
            num_gaussians=hand_pca_nb,
        )
        self.angle_prior = prior.create_prior(prior_type="angle")
        self.shape_prior = prior.create_prior(prior_type="l2")
Example #19
    body_rotmat = R.from_rotvec(
        body_rot_angle).as_dcm()  # to a [b, 3,3] rotation mat
    body_transf = np.tile(np.eye(4), (batch_size, 1, 1))
    body_transf[:, :-1, :-1] = body_rotmat
    body_transf[:, :-1, -1] = body_transl
    body_transf_w = np.matmul(trans_to_target_origin, body_transf)
    body_params_dict['global_orient'] = R.from_dcm(
        body_transf_w[:, :-1, :-1]).as_rotvec()
    body_params_dict['transl'] = body_transf_w[:, :-1, -1] - delta_T

    return body_params_dict
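# Note: scipy renamed Rotation.as_dcm()/from_dcm() to as_matrix()/from_matrix()
# in SciPy 1.4 (the old names were later removed), so on newer SciPy versions
# the conversions above can be written as, for example:
#
#   body_rotmat = R.from_rotvec(body_rot_angle).as_matrix()
#   body_params_dict['global_orient'] = R.from_matrix(
#       body_transf_w[:, :-1, :-1]).as_rotvec()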


## figure out body model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vposer_model, _ = load_vposer('/home/yzhang/body_models/VPoser/vposer_v1_0/',
                              vp_model='snapshot')
vposer_model = vposer_model.to(device)

smplx_model = smplx.create('/home/yzhang/body_models/VPoser/',
                           model_type='smplx',
                           gender='neutral',
                           ext='npz',
                           num_pca_comps=12,
                           create_global_orient=True,
                           create_body_pose=True,
                           create_betas=True,
                           create_left_hand_pose=True,
                           create_right_hand_pose=True,
                           create_expression=True,
                           create_jaw_pose=True,
                           create_leye_pose=True,
                           # remaining flags below are assumed to mirror the
                           # identical smplx.create(...) calls used elsewhere
                           # in these examples
                           create_reye_pose=True,
                           create_transl=True).to(device)
Example #20
def fit_single_frame(img,
                     keypoints,
                     body_model,
                     camera,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     **kwargs):
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [
            1,
        ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
           ' does not match the number of data term weights {}'.format(
               len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (
            len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand joint distance weights')
        assert (
            len(hand_joints_weights) == len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format(
        len(body_pose_prior_weights), len(shape_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (
            len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format(
            len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = None, None
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype,
                                     device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys
                          if opt_weights_dict[k] is not None))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)
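    # After the zip above, opt_weights is a list with one dict per optimisation
    # stage; each dict holds that stage's scalar weights as 0-dim tensors.
    # A sketch, assuming the default four-stage weight lists:
    #
    #   opt_weights[0] -> {'data_weight': tensor(1.),
    #                      'body_pose_weight': tensor(404.),
    #                      'shape_weight': tensor(100.), ...}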

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')
    init_t = fitting.guess_init(body_model,
                                gt_joints,
                                edge_indices,
                                use_vposer=use_vposer,
                                vposer=vposer,
                                pose_embedding=pose_embedding,
                                model_type=kwargs.get('model_type', 'smpl'),
                                focal_length=focal_length,
                                dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        data_weight = 1000 / H
        # The closure passed to the optimizer
        camera_loss.reset_loss_weights({'data_weight': data_weight})

        # Reset the parameters to estimate the initial translation of the
        # body model
        body_model.reset_params(body_pose=body_mean_pose)

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        # Update the value of the translation of the camera as well as
        # the image center.
        with torch.no_grad():
            camera.translation[:] = init_t.view_as(camera.translation)
            camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

        # Re-enable gradient calculation for the camera translation
        camera.translation.requires_grad = True

        camera_opt_params = [camera.translation, body_model.global_orient]

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer,
            body_model,
            camera,
            gt_joints,
            camera_loss,
            create_graph=camera_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            return_full_pose=False,
            return_verts=False)

        # Step 1: Optimize the camera translation over the torso joints
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections of the shoulder joints are too close together,
        # then rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []

        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc='Orientation')):
            opt_start = time.time()

            new_params = defaultdict(global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):

                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['data_weight'] = data_weight
                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True,
                    return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    tqdm.write(
                        'Stage {:03d} done after {:.4f} seconds'.format(
                            opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.
                    format(or_idx, elapsed))
                tqdm.write(
                    'Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process and store the final
            # loss value with it, so that multiple orientations can be
            # compared if they exist
            result = {
                'camera_' + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result['body_pose'] = pose_embedding.detach().cpu().numpy()

            results.append({'loss': final_loss_val, 'result': result})

        with open(result_fn, 'wb') as result_file:
            if len(results) > 1:
                min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1)
            else:
                min_idx = 0
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)

    if save_meshes or visualize:
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(
            1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        rot = trimesh.transformations.rotation_matrix(np.radians(180),
                                                      [1, 0, 0])
        out_mesh.apply_transform(rot)
        out_mesh.export(mesh_fn)

    if visualize:
        import pyrender

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_center = camera.center.detach().cpu().numpy().squeeze()
        camera_transl = camera.translation.detach().cpu().numpy().squeeze()
        # Equivalent to 180 degrees around the y-axis. Transforms the fit to
        # OpenGL compatible coordinate system.
        camera_transl[0] *= -1.0

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_transl

        camera = pyrender.camera.IntrinsicsCamera(fx=focal_length,
                                                  fy=focal_length,
                                                  cx=camera_center[0],
                                                  cy=camera_center[1])
        scene.add(camera, pose=camera_pose)

        # Get the lights from the viewer
        light_nodes = monitor.mv.viewer._create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        r = pyrender.OffscreenRenderer(viewport_width=W,
                                       viewport_height=H,
                                       point_size=1.0)
        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0

        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img.detach().cpu().numpy()
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        img = pil_img.fromarray((output_img * 255).astype(np.uint8))
        img.save(out_img_fn)
Example #21
def optimize():
    scene_mesh, cur_scene_verts, s_grid_min_batch, s_grid_max_batch, s_sdf_batch = read_mesh_sdf(args.dataset_path,
                                                                                                 args.dataset,
                                                                                                 args.scene_name)
    save_path = '{}/{}/{}'.format(args.save_path, args.dataset, args.scene_name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    smplx_model = smplx.create(args.smplx_model_path, model_type='smplx',
                               gender='neutral', ext='npz',
                               num_pca_comps=12,
                               create_global_orient=True,
                               create_body_pose=True,
                               create_betas=True,
                               create_left_hand_pose=True,
                               create_right_hand_pose=True,
                               create_expression=True,
                               create_jaw_pose=True,
                               create_leye_pose=True,
                               create_reye_pose=True,
                               create_transl=True,
                               batch_size=1
                               ).to(device)
    print('[INFO] smplx model loaded.')

    vposer_model, _ = load_vposer(args.vposer_model_path, vp_model='snapshot')
    vposer_model = vposer_model.to(device)
    print('[INFO] vposer model loaded')


    ####################### calculate scene bps representation ##################
    rot_angle, scene_min_x, scene_max_x, scene_min_y, scene_max_y = define_scene_boundary(args.dataset, args.scene_name)

    scene_verts_crop_local_list, scene_verts_local_list = [], []
    shift_list = []
    rot_angle_list_1, rot_angle_list_2 = [], []

    np.random.seed(0)
    random_seed_list = np.random.randint(10000, size=args.n_sample)
    for i in tqdm(range(args.n_sample)):
        scene_verts = rotate_scene_smplx_predefine(cur_scene_verts, rot_angle=rot_angle)
        scene_verts_local, scene_verts_crop_local, shift = crop_scene_cube_smplx_predifine(
            scene_verts, r=args.cube_size, with_wall_ceilling=True, random_seed=random_seed_list[i],
            scene_min_x=scene_min_x, scene_max_x=scene_max_x, scene_min_y=scene_min_y, scene_max_y=scene_max_y,
            rotate=True)
        scene_verts_crop_local_list.append(scene_verts_crop_local)  # list, different verts num for each cropped scene
        scene_verts_local_list.append(scene_verts_local)
        shift_list.append(shift)
        rot_angle_list_1.append(rot_angle)
    print('[INFO] scene mesh cropped and shifted.')


    scene_bps_list, body_bps_list = [], []
    scene_bps_verts_global_list, scene_bps_verts_local_list = [], []
    scene_basis_set = bps_gen_ball_inside(n_bps=args.num_bps, random_seed=100)
    np.random.seed(1)
    random_seed_list = np.random.randint(10000, size=args.n_sample)
    for i in tqdm(range(args.n_sample)):
        scene_verts_global, scene_verts_crop_global, rot_angle = \
            augmentation_crop_scene_smplx(scene_verts_local_list[i] / args.cube_size,
                                          scene_verts_crop_local_list[i] / args.cube_size,
                                          random_seed_list[i])
        scene_bps, selected_scene_verts_global, selected_ind = bps_encode_scene(scene_basis_set,
                                                                                scene_verts_crop_global)  # [n_feat, n_bps]
        scene_bps_list.append(scene_bps)
        selected_scene_verts_local = scene_verts_crop_local_list[i][selected_ind]
        scene_bps_verts_local_list.append(selected_scene_verts_local)
        scene_bps_verts_global_list.append(selected_scene_verts_global)
        rot_angle_list_2.append(rot_angle)

    scene_bps_list = np.asarray(scene_bps_list)  # [n_sample*4, n_feat, n_bps]
    scene_bps_verts_local_list = np.asarray(scene_bps_verts_local_list)
    scene_verts_local_list = np.asarray(scene_verts_local_list)
    np.save('{}/scene_bps_list.npy'.format(save_path), scene_bps_list)
    print('[INFO] scene bps/verts saved.')
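    # bps_encode_scene above follows the basis-point-set idea: each basis point in
    # scene_basis_set is described by its distance (and offset) to the nearest scene
    # vertex, giving a fixed-length encoding regardless of how many vertices the
    # cropped scene has. A minimal sketch of the distance part, assuming
    # scene_basis_set is an [n_bps, 3] array of basis points:
    #
    #   from sklearn.neighbors import NearestNeighbors
    #   nbrs = NearestNeighbors(n_neighbors=1).fit(scene_verts_crop_global)
    #   dist, ind = nbrs.kneighbors(scene_basis_set)   # each of shape [n_bps, 1]
    #   scene_bps = dist.reshape(1, -1)                # one feature per basis point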


    ######################## set dataloader and load model ########################
    dataset = TestLoader()
    dataset.n_samples = args.n_sample
    dataset.scene_bps_list = scene_bps_list  # [n_sample, n_feat, n_bps]
    dataset.scene_bps_verts_list = scene_bps_verts_local_list
    print('[INFO] dataloader updated, select n_samples={}'.format(dataset.__len__()))
    test_dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                                                  num_workers=0, drop_last=False)

    scene_bps_AE = BPSRecMLP(n_bps=args.num_bps, n_bps_feat=1, hsize1=1024, hsize2=512).to(device)
    weights = torch.load(args.scene_bps_AE_path, map_location=lambda storage, loc: storage)
    scene_bps_AE.load_state_dict(weights)

    c_VAE = BPS_CVAE(n_bps=args.num_bps, n_bps_feat=1, hsize1=1024, hsize2=512, eps_d=args.eps_d).to(device)
    weights = torch.load(args.cVAE_path, map_location=lambda storage, loc: storage)
    c_VAE.load_state_dict(weights)

    scene_AE = Verts_AE(n_bps=10000, hsize1=1024, hsize2=512).to(device)
    weights = torch.load(args.scene_verts_AE_path, map_location=lambda storage, loc: storage)
    scene_AE.load_state_dict(weights)

    body_dec = Body_Dec_shift(n_bps=10000, n_bps_feat=1, hsize1=1024, hsize2=512, n_body_verts=10475,
                              body_param_dim=75, rec_goal='body_verts').to(device)
    weights = torch.load(args.bodyDec_path, map_location=lambda storage, loc: storage)
    body_dec.load_state_dict(weights)

    scene_bps_AE.eval()
    c_VAE.eval()
    scene_AE.eval()
    body_dec.eval()


    ######################## initialize for optimization ##########################
    body_verts_sample_list = []
    body_bps_sample_list = []
    np.random.seed(2)
    random_seed_list = np.random.randint(10000, size=args.n_sample)
    with torch.no_grad():
        for step, data in tqdm(enumerate(test_dataloader)):
            [scene_bps, scene_bps_verts] = [item.to(device) for item in data]
            scene_bps_verts = scene_bps_verts / args.cube_size

            _, scene_bps_feat = scene_bps_AE(scene_bps)
            _, scene_bps_verts_feat = scene_AE(scene_bps_verts)
            # torch.manual_seed(random_seed_list[step])
            body_bps_sample = c_VAE.sample(1, scene_bps_feat)  # [1, 1, 10000]
            body_bps_sample_list.append(body_bps_sample[0].detach().cpu().numpy())  # [n, 1, 10000]
            body_verts_sample, body_shift = body_dec(body_bps_sample, scene_bps_verts_feat)  # [1, 3, 10475], unit ball scale, local coordinate

            # shifted generated body
            body_shift = body_shift.repeat(1, 1, 10475).reshape([body_verts_sample.shape[0], 10475, 3])  # [bs, 10475, 3]
            body_verts_sample = body_verts_sample + body_shift.permute(0, 2, 1)  # [bs, 3, 10475]

            body_verts_sample_list.append(body_verts_sample[0].detach().cpu().numpy())  # [n, 3, 10475]

    contact_part = ['L_Leg', 'R_Leg']
    vid, _ = get_contact_id(body_segments_folder=os.path.join(args.prox_dataset_path, 'body_segments'),
                            contact_body_parts=contact_part)


    ############################### save data ###############################
    shift_list = np.asarray(shift_list)
    rot_angle_list_1 = np.asarray(rot_angle_list_1)
    rot_angle_list_2 = np.asarray(rot_angle_list_2)
    body_bps_sample_list = np.asarray(body_bps_sample_list)
    body_verts_sample_list = np.asarray(body_verts_sample_list)

    np.save('{}/shift_list.npy'.format(save_path), shift_list)
    np.save('{}/rot_angle_list_1.npy'.format(save_path), rot_angle_list_1)
    np.save('{}/rot_angle_list_2.npy'.format(save_path), rot_angle_list_2)
    np.save('{}/body_bps_sample_list.npy'.format(save_path), body_bps_sample_list)
    # save generated body verts in original scale
    np.save('{}/body_verts_sample_list.npy'.format(save_path), body_verts_sample_list.transpose((0,2,1))*args.cube_size)


    ############################### optimization ##################################
    if args.optimize:
        #################### stage 1 (simple optimization, without contact/collision loss) ###################
        print('[INFO] start optimization stage 1...')
        body_params_opt_list_s1 = []
        for cnt in range(args.n_sample):
            print('stage 1: current cnt:', cnt)
            body_params_rec = torch.randn(1, 72).to(device)  # initialize smplx params, bs=1, local coordinate system
            body_params_rec[0, 0] = 0.0
            body_params_rec[0, 1] = 0.0
            body_params_rec[0, 2] = 0.0
            body_params_rec[0, 3] = 1.5
            body_params_rec[0, 4] = 0.0
            body_params_rec[0, 5] = 0.0
            body_params_rec = convert_to_6D_rot(body_params_rec)
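            # convert_to_6D_rot / convert_to_3D_rot (used below) presumably swap the
            # 3-dim axis-angle global orientation for the continuous 6D rotation
            # representation (the first two columns of the rotation matrix), which
            # behaves better under gradient-based optimisation; hence the 72-dim
            # params become 75-dim here and are mapped back before VPoser decoding.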
            body_params_rec.requires_grad = True

            optimizer = optim.Adam([body_params_rec], lr=0.1)

            body_bps = torch.from_numpy(body_bps_sample_list[cnt]).float().unsqueeze(0).to(device)  # [bs=1, 1, 10000]
            body_verts = torch.from_numpy(body_verts_sample_list[cnt]).float().unsqueeze(0).to(device)  # [bs=1, 3, 10475]
            body_verts = body_verts.permute(0, 2, 1)  # [1, 10475, 3]
            body_verts = body_verts * args.cube_size  # to local coordinate system scale

            for step in tqdm(range(args.itr_s1)):
                if step > 100:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = 0.01
                if step > 300:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = 0.001
                optimizer.zero_grad()

                body_params_rec_72 = convert_to_3D_rot(body_params_rec)  # tensor, [bs=1, 72]
                body_pose_joint = vposer_model.decode(body_params_rec_72[:, 16:48], output_type='aa').view(1, -1)  # tensor, [bs=1, 63]
                body_verts_rec = gen_body_mesh(body_params_rec_72, body_pose_joint, smplx_model)[0]  # tensor, [n_body_vert, 3]

                # transform body verts to unit ball global coordinate system
                temp = body_verts_rec / args.cube_size  # scale into unit ball
                body_verts_rec_global = torch.zeros(body_verts_rec.shape).to(device)
                body_verts_rec_global[:, 0] = temp[:, 0] * math.cos(rot_angle_list_2[cnt]) - \
                                              temp[:, 1] * math.sin(rot_angle_list_2[cnt])
                body_verts_rec_global[:, 1] = temp[:, 0] * math.sin(rot_angle_list_2[cnt]) + \
                                              temp[:, 1] * math.cos(rot_angle_list_2[cnt])
                body_verts_rec_global[:, 2] = temp[:, 2]

                # calculate optimized body bps feature
                body_bps_rec = torch.zeros(body_bps.shape).to(device)  # keep on the same device as body_bps
                if args.weight_loss_rec_bps > 0:
                    nbrs = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm="ball_tree").fit(body_verts_rec_global.detach().cpu().numpy())
                    neigh_dist, neigh_ind = nbrs.kneighbors(scene_bps_verts_global_list[cnt])
                    body_bps_rec = body_verts_rec_global[neigh_ind.squeeze()] - \
                                   torch.from_numpy(scene_bps_verts_global_list[cnt]).float().to(device)  # [n_bps, 3]
                    body_bps_rec = torch.sqrt(body_bps_rec[:, 0] ** 2 + body_bps_rec[:, 1] ** 2 + body_bps_rec[:, 2] ** 2).unsqueeze(0).unsqueeze(0)  # [bs=1, 1, n_bps]
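                    # The nearest-neighbour indices are found on a detached numpy
                    # copy (no gradient), but body_bps_rec is rebuilt by indexing the
                    # differentiable body_verts_rec_global, so gradients still flow
                    # from loss_rec_bps back into body_params_rec.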

                ### body bps feature reconstruct loss
                loss_rec_verts = F.l1_loss(body_verts_rec.unsqueeze(0), body_verts)
                loss_rec_bps = F.l1_loss(body_bps, body_bps_rec)

                ### vposer loss
                body_params_rec_72 = convert_to_3D_rot(body_params_rec)
                vposer_pose = body_params_rec_72[:, 16:48]
                loss_vposer = torch.mean(vposer_pose ** 2)
                ### shape prior loss
                shape_params = body_params_rec_72[:, 6:16]
                loss_shape = torch.mean(shape_params ** 2)
                ### hand pose prior loss
                hand_params = body_params_rec_72[:, 48:]
                loss_hand = torch.mean(hand_params ** 2)

                loss = args.weight_loss_rec_verts * loss_rec_verts + args.weight_loss_rec_bps * loss_rec_bps + \
                       args.weight_loss_vposer * loss_vposer + \
                       args.weight_loss_shape * loss_shape + \
                       args.weight_loss_hand * loss_hand
                loss.backward(retain_graph=True)
                optimizer.step()

            body_params_opt_list_s1.append(body_params_rec[0].detach().cpu().numpy())


        body_params_opt_list_s1 = np.asarray(body_params_opt_list_s1)
        np.save('{}/body_params_opt_list_s1.npy'.format(save_path), body_params_opt_list_s1)


        ################ stage 2 (advanced optimization, with contact/collision loss) ##################
        print('[INFO] start optimization stage 2...')
        body_params_opt_list_s2 = []
        for cnt in range(args.n_sample):
            print('current cnt:', cnt)
            body_params_rec = body_params_opt_list_s1[cnt]  # [75]
            body_params_rec = torch.from_numpy(body_params_rec).float().to(device).unsqueeze(0)
            body_params_rec.requires_grad = True

            optimizer = optim.Adam([body_params_rec], lr=0.01)

            body_bps = torch.from_numpy(body_bps_sample_list[cnt]).float().unsqueeze(0).to(device)  # [bs=1, 1, 10000]
            body_verts = torch.from_numpy(body_verts_sample_list[cnt]).float().unsqueeze(0).to(device)
            body_verts = body_verts.permute(0, 2, 1)  # [1, 10475, 3]
            body_verts = body_verts * args.cube_size  # to local coordinate system scale

            for step in tqdm(range(args.itr_s2)):
                optimizer.zero_grad()

                body_params_rec_72 = convert_to_3D_rot(body_params_rec)  # tensor, [bs=1, 72]
                body_pose_joint = vposer_model.decode(body_params_rec_72[:, 16:48], output_type='aa').view(1,-1)  # tensor, [bs=1, 63]
                body_verts_rec = gen_body_mesh(body_params_rec_72, body_pose_joint, smplx_model)[0]  # tensor, [n_body_vert, 3]

                # transform body verts to unit ball global coordinate
                temp = body_verts_rec / args.cube_size  # scale into unit ball
                body_verts_rec_global = torch.zeros(body_verts_rec.shape).to(device)
                body_verts_rec_global[:, 0] = temp[:, 0] * math.cos(rot_angle_list_2[cnt]) - \
                                              temp[:, 1] * math.sin(rot_angle_list_2[cnt])
                body_verts_rec_global[:, 1] = temp[:, 0] * math.sin(rot_angle_list_2[cnt]) + \
                                              temp[:, 1] * math.cos(rot_angle_list_2[cnt])
                body_verts_rec_global[:, 2] = temp[:, 2]

                # calculate body_bps_rec
                body_bps_rec = torch.zeros(body_bps.shape).to(device)  # keep on the same device as body_bps
                if args.weight_loss_rec_bps > 0:
                    nbrs = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm="ball_tree").fit(body_verts_rec_global.detach().cpu().numpy())
                    neigh_dist, neigh_ind = nbrs.kneighbors(scene_bps_verts_global_list[cnt])
                    body_bps_rec = body_verts_rec_global[neigh_ind.squeeze()] - \
                                   torch.from_numpy(scene_bps_verts_global_list[cnt]).float().to(device)  # [n_bps, 3]
                    body_bps_rec = torch.sqrt(body_bps_rec[:, 0] ** 2 + body_bps_rec[:, 1] ** 2 + body_bps_rec[:, 2] ** 2).unsqueeze(0).unsqueeze(0)  # [bs=1, 1, n_bps]

                ### body bps encoding reconstruct loss
                loss_rec_verts = F.l1_loss(body_verts_rec.unsqueeze(0), body_verts)
                loss_rec_bps = F.l1_loss(body_bps, body_bps_rec)

                ### vposer loss
                body_params_rec_72 = convert_to_3D_rot(body_params_rec)
                vposer_pose = body_params_rec_72[:, 16:48]
                loss_vposer = torch.mean(vposer_pose ** 2)
                ### shape prior loss
                shape_params = body_params_rec_72[:, 6:16]
                loss_shape = torch.mean(shape_params ** 2)
                ### hand pose prior loss
                hand_params = body_params_rec_72[:, 48:]
                loss_hand = torch.mean(hand_params ** 2)

                # transform local body_verts_rec to the PROX coordinate system
                body_verts_rec_prox = torch.zeros(body_verts_rec.shape).to(device)
                temp = body_verts_rec - torch.from_numpy(shift_list[cnt]).float().to(device)
                body_verts_rec_prox[:, 0] = temp[:, 0] * math.cos(-rot_angle_list_1[cnt]) - \
                                            temp[:, 1] * math.sin(-rot_angle_list_1[cnt])
                body_verts_rec_prox[:, 1] = temp[:, 0] * math.sin(-rot_angle_list_1[cnt]) + \
                                            temp[:, 1] * math.cos(-rot_angle_list_1[cnt])
                body_verts_rec_prox[:, 2] = temp[:, 2]
                body_verts_rec_prox = body_verts_rec_prox.unsqueeze(0)  # tensor, [bs=1, 10475, 3]

                ### sdf collision loss
                norm_verts_batch = (body_verts_rec_prox - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
                n_verts = norm_verts_batch.shape[1]
                body_sdf_batch = F.grid_sample(s_sdf_batch.unsqueeze(1),
                                               norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),
                                               padding_mode='border')
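                # grid_sample expects sampling locations in [-1, 1] (hence the
                # normalisation by the SDF grid bounds above) in (x, y, z) order,
                # the reverse of the (D, H, W) tensor layout, which is why the
                # last dimension is re-indexed with [2, 1, 0].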
                # if there are no penetrating vertices then set sdf_penetration_loss = 0
                if body_sdf_batch.lt(0).sum().item() < 1:
                    loss_collision = torch.tensor(0.0, dtype=torch.float32).to(device)
                else:
                    loss_collision = body_sdf_batch[body_sdf_batch < 0].abs().mean()

                ### contact loss
                body_verts_contact = body_verts_rec.unsqueeze(0)[:, vid, :]  # [1,1121,3]
                dist_chamfer_contact = ext.chamferDist()
                # scene_verts: [bs=1, n_scene_verts, 3]
                scene_verts = torch.from_numpy(scene_verts_local_list[cnt]).float().to(device).unsqueeze(0)  # [1,50000,3]
                contact_dist, _ = dist_chamfer_contact(body_verts_contact.contiguous(),
                                                       scene_verts.contiguous())
                loss_contact = torch.mean(torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))
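                # contact_dist holds squared chamfer distances from the selected
                # contact vertices (legs, via get_contact_id) to the scene; the
                # d / (d + 1) mapping squashes each term roughly into [0, 1), so
                # far-away vertices cannot dominate the loss, and the 1e-4 keeps
                # the sqrt differentiable near zero.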

                loss = args.weight_loss_rec_verts * loss_rec_verts + args.weight_loss_rec_bps * loss_rec_bps + \
                       args.weight_loss_vposer * loss_vposer + \
                       args.weight_loss_shape * loss_shape + \
                       args.weight_loss_hand * loss_hand + \
                       args.weight_collision * loss_collision + args.weight_loss_contact * loss_contact
                loss.backward(retain_graph=True)
                optimizer.step()

            body_params_opt_list_s2.append(body_params_rec[0].detach().cpu().numpy())

        body_params_opt_list_s2 = np.asarray(body_params_opt_list_s2)
        np.save('{}/body_params_opt_list_s2.npy'.format(save_path), body_params_opt_list_s2)
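        # A minimal sketch (assuming the helper functions used above and that
        # trimesh is imported) of how the optimised parameters can be turned back
        # into a mesh for inspection:
        #
        #   params_s2 = np.load('{}/body_params_opt_list_s2.npy'.format(save_path))
        #   body_params_75 = torch.from_numpy(params_s2[0:1]).float().to(device)
        #   body_params_72 = convert_to_3D_rot(body_params_75)
        #   pose = vposer_model.decode(body_params_72[:, 16:48],
        #                              output_type='aa').view(1, -1)
        #   verts = gen_body_mesh(body_params_72, pose, smplx_model)[0]
        #   mesh = trimesh.Trimesh(verts.detach().cpu().numpy(), smplx_model.faces)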