示例#1
0
def get_interpenetration_module():
    """Build the objects needed for a mesh self-interpenetration term.

    All settings are hard-coded in this function. ``os`` and ``pkl`` are
    expected to be available from the enclosing module.

    Returns:
        tuple: ``(search_tree, pen_distance, tri_filtering_module)`` where
        ``search_tree`` is a ``BVH`` instance, ``pen_distance`` is a
        ``DistanceFieldPenetrationLoss`` instance and
        ``tri_filtering_module`` is a ``FilterFaces`` module moved to CUDA,
        or ``None`` when no part-segmentation file is configured.
    """
    # Imported lazily so the CUDA extension is only required when this
    # function is actually called.
    from mesh_intersection.bvh_search_tree import BVH
    import mesh_intersection.loss as collisions_loss
    from mesh_intersection.filter_faces import FilterFaces

    # The trailing values are alternatives left by the original author.
    max_collisions = 8  # 128
    df_cone_height = 0.5  # 0.0001
    point2plane = False
    penalize_outside = True
    part_segm_fn = 'smplx_parts_segm.pkl'
    # Pairs of part ids passed to FilterFaces as ``ign_part_pairs``.
    ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"]

    search_tree = BVH(max_collisions=max_collisions)
    pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
        sigma=df_cone_height,
        point2plane=point2plane,
        vectorized=True,
        penalize_outside=penalize_outside)

    # Bound up front so the return statement is valid even when the
    # segmentation file name is empty.
    tri_filtering_module = None
    if part_segm_fn:
        part_segm_fn = os.path.expandvars(part_segm_fn)
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(part_segm_fn, 'rb') as faces_parents_file:
            face_segm_data = pkl.load(faces_parents_file, encoding='latin1')
        tri_filtering_module = FilterFaces(
            faces_segm=face_segm_data['segm'],
            faces_parents=face_segm_data['parents'],
            ign_part_pairs=ign_part_pairs).cuda()
    return search_tree, pen_distance, tri_filtering_module
示例#2
0
def main():
    """Reduce self-intersections in a batch of OBJ meshes with SGD.

    Reads every ``*.obj`` file in ``--data_folder``, stacks the meshes into
    one batch and optimises a copy of the vertices for ``--iterations``
    steps. The objective is the weighted distance-field penetration loss
    plus, when ``--verts_reg_weight`` is positive, a summed squared error
    that keeps the vertices close to their initial positions.

    Relies on module-level names: ``argparse``, ``glob``, ``np``, ``torch``,
    ``nn``, ``readObj``, ``BVH`` and ``collisions_loss``. A CUDA device is
    required.
    """
    description = 'Example script for untangling Mesh self intersections'
    parser = argparse.ArgumentParser(description=description,
                                     prog='Batch Mesh Untangle')
    parser.add_argument('--data_folder', type=str,
                        default='data',
                        help='The path to data')
    parser.add_argument('--point2plane', default=False,
                        type=lambda arg: arg.lower() in ['true', '1'],
                        help='Use point to distance')
    parser.add_argument('--sigma', default=0.5, type=float,
                        help='The height of the cone used to calculate the' +
                        ' distance field loss')
    parser.add_argument('--lr', default=1, type=float,
                        help='The learning rate for SGD')
    parser.add_argument('--coll_loss_weight', default=1e-4, type=float,
                        help='The weight for the collision loss')
    parser.add_argument('--verts_reg_weight', default=1e-5, type=float,
                        help='The weight for the verts regularizer')
    parser.add_argument('--max_collisions', default=8, type=int,
                        help='The maximum number of bounding box collisions')
    parser.add_argument('--iterations', default=100, type=int,
                        help='Number of optimization iterations')

    args = parser.parse_args()
    data_folder = args.data_folder
    point2plane = args.point2plane
    lr = args.lr
    coll_loss_weight = args.coll_loss_weight
    # Previously never read from ``args``, which made the regulariser
    # branch below fail with a NameError.
    verts_reg_weight = args.verts_reg_weight
    max_collisions = args.max_collisions
    sigma = args.sigma
    iterations = args.iterations

    device = torch.device('cuda')

    obj_paths = sorted(glob(data_folder+'/*.obj'))
    verts_list, faces_list = [], []
    for obj_path in obj_paths:
        cur_verts, _, cur_faces = readObj(obj_path)
        # presumably converts 1-based OBJ indices to 0-based -- confirm
        # against readObj
        cur_faces -= 1
        verts_list.append(cur_verts)
        faces_list.append(cur_faces)
    # NOTE(review): stacking assumes every mesh has the same number of
    # vertices and faces -- confirm for the data set in use.
    all_verts = np.array(verts_list).astype(np.float32)
    all_faces = np.array(faces_list).astype(np.int64)

    # Fixed reference copy of the input vertices (regularisation target).
    verts_tensor = torch.tensor(all_verts.copy(), dtype=torch.float32,
                                device=device)
    face_tensor = torch.tensor(all_faces, dtype=torch.long,
                               device=device)
    # The variable being optimised, initialised to the input vertices.
    param = torch.tensor(all_verts.copy(), dtype=torch.float32,
                         device=device).requires_grad_()
    bs, nv = verts_tensor.shape[:2]
    # Offset each mesh's face indices by ``nv`` so they address the
    # flattened (bs * nv, 3) vertex array used inside the loop.
    batch_offsets = torch.arange(bs, dtype=torch.long, device=device) * nv
    faces_idx = face_tensor + batch_offsets[:, None, None]

    # Create the search tree
    search_tree = BVH(max_collisions=max_collisions)

    pen_distance = \
        collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
                                                     point2plane=point2plane,
                                                     vectorized=True)

    mse_loss = nn.MSELoss(reduction='sum').to(device=device)

    optimizer = torch.optim.SGD([param], lr=lr)

    for _ in range(iterations):
        optimizer.zero_grad()

        triangles = param.view([-1, 3])[faces_idx]

        # Collision detection is not differentiated through.
        with torch.no_grad():
            collision_idxs = search_tree(triangles)

        pen_loss = coll_loss_weight * \
            pen_distance(triangles, collision_idxs)

        verts_reg_loss = torch.tensor(0, device=device,
                                      dtype=torch.float32)
        if verts_reg_weight > 0:
            verts_reg_loss = verts_reg_weight * \
                mse_loss(param, verts_tensor)

        loss = pen_loss + verts_reg_loss

        # ``tolist`` yields a bare float for a single value; wrap it so the
        # format string below always receives a sequence.
        np_loss = loss.detach().cpu().squeeze().tolist()
        if not isinstance(np_loss, list):
            np_loss = [np_loss]
        msg = '{:.5f} ' * len(np_loss)
        print('Loss per model:', msg.format(*np_loss))

        # An explicit gradient of ones lets a non-scalar loss be
        # back-propagated.
        loss.backward(torch.ones_like(loss))
        optimizer.step()

    optimized_verts = param.detach().cpu().numpy()
    for i in range(optimized_verts.shape[0]):
示例#3
0
def main():
    """Untangle self-intersections of a batch of body models with SGD.

    One parameter pickle is loaded per ``--param_fn`` entry; each must
    contain the keys ``betas``, ``global_pose`` and ``pose``. A body model
    with that batch size is created and its ``body_pose`` is optimised to
    minimise the weighted penetration loss plus optional pose and shape
    regularisers.

    With ``--interactive`` (the default) the meshes are shown in a pyrender
    viewer, colliding faces are highlighted, and the loop ends when the
    viewer is closed. Without it the loop has no exit condition and runs
    until the process is interrupted.

    Relies on module-level names: ``argparse``, ``pickle``, ``time``,
    ``defaultdict``, ``np``, ``torch``, ``nn``, ``create``, ``BVH``,
    ``collisions_loss`` and ``FilterFaces``. A CUDA device is required.
    """
    def str2bool(arg):
        # Only 'true' / '1' (case-insensitive) count as True.
        return arg.lower() in ['true', '1']

    description = 'Example script for untangling SMPL self intersections'
    parser = argparse.ArgumentParser(description=description,
                                     prog='Batch SMPL-Untangle')
    parser.add_argument('--param_fn', type=str,
                        nargs='*',
                        required=True,
                        help='The pickle file with the model parameters')
    parser.add_argument('--interactive', default=True,
                        type=str2bool,
                        help='Display the mesh during the optimization' +
                        ' process')
    parser.add_argument('--delay', type=int, default=50,
                        help='The delay for the animation callback in ms')
    parser.add_argument('--model_folder', type=str,
                        default='models',
                        help='The path to the LBS model')
    parser.add_argument('--model_type', type=str,
                        default='smpl', choices=['smpl', 'smplx', 'smplh'],
                        help='The type of model to create')
    parser.add_argument('--point2plane', default=False,
                        type=str2bool,
                        help='Use point to distance')
    parser.add_argument('--optimize_pose', default=True,
                        type=str2bool,
                        help='Enable optimization over the joint pose')
    parser.add_argument('--optimize_shape', default=False,
                        type=str2bool,
                        help='Enable optimization over the shape of the model')
    parser.add_argument('--sigma', default=0.5, type=float,
                        help='The height of the cone used to calculate the' +
                        ' distance field loss')
    parser.add_argument('--lr', default=1, type=float,
                        help='The learning rate for SGD')
    parser.add_argument('--coll_loss_weight', default=1e-4, type=float,
                        help='The weight for the collision loss')
    parser.add_argument('--pose_reg_weight', default=0, type=float,
                        help='The weight for the pose regularizer')
    parser.add_argument('--shape_reg_weight', default=0, type=float,
                        help='The weight for the shape regularizer')
    parser.add_argument('--max_collisions', default=8, type=int,
                        help='The maximum number of bounding box collisions')
    parser.add_argument('--part_segm_fn', default='', type=str,
                        help='The file with the part segmentation for the' +
                        ' faces of the model')
    parser.add_argument('--print_timings', default=False,
                        type=str2bool,
                        help='Print timings for all the operations')

    args = parser.parse_args()

    model_folder = args.model_folder
    model_type = args.model_type
    param_fn = args.param_fn
    interactive = args.interactive
    delay = args.delay
    point2plane = args.point2plane
    # --optimize_shape / --optimize_pose are parsed but not used: only
    # body.body_pose is handed to the optimizer below.
    lr = args.lr
    coll_loss_weight = args.coll_loss_weight
    pose_reg_weight = args.pose_reg_weight
    shape_reg_weight = args.shape_reg_weight
    max_collisions = args.max_collisions
    sigma = args.sigma
    part_segm_fn = args.part_segm_fn
    print_timings = args.print_timings

    if interactive:
        # Visualisation dependencies are only needed in interactive mode.
        import trimesh
        import pyrender

    device = torch.device('cuda')
    batch_size = len(param_fn)

    # Gather every key of every pickle into per-key lists (one entry per
    # file). NOTE: pickle can execute arbitrary code; load trusted files
    # only.
    params_dict = defaultdict(list)
    for fn in param_fn:
        with open(fn, 'rb') as param_file:
            data = pickle.load(param_file, encoding='latin1')

        assert 'betas' in data, \
            'No key for shape parameter in provided pickle file'
        assert 'global_pose' in data, \
            'No key for the global pose in the given pickle file'
        assert 'pose' in data, \
            'No key for the pose of the joints in the given pickle file'

        for key, val in data.items():
            params_dict[key].append(val)

    params = {}
    for key, values in params_dict.items():
        params[key] = np.stack(values, axis=0).astype(np.float32)
        if len(params[key].shape) < 2:
            params[key] = params[key][np.newaxis]
    # Expose the pickle's key names under the names passed to
    # body.reset_params below.
    if 'global_pose' in params:
        params['global_orient'] = params['global_pose']
    if 'pose' in params:
        params['body_pose'] = params['pose']

    filter_faces = None
    if part_segm_fn:
        # Read the part segmentation
        with open(part_segm_fn, 'rb') as faces_parents_file:
            data = pickle.load(faces_parents_file, encoding='latin1')
        faces_segm = data['segm']
        faces_parents = data['parents']
        # Create the module used to filter invalid collision pairs
        filter_faces = FilterFaces(faces_segm, faces_parents).to(device=device)

    # Create the body model
    body = create(model_folder, batch_size=batch_size,
                  model_type=model_type).to(device=device)
    body.reset_params(**params)

    # Clone the given pose to use it as a target for regularization
    init_pose = body.body_pose.clone().detach()

    # Create the search tree
    search_tree = BVH(max_collisions=max_collisions)

    pen_distance = \
        collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
                                                     point2plane=point2plane,
                                                     vectorized=True)

    mse_loss = nn.MSELoss(reduction='sum').to(device=device)

    face_tensor = torch.tensor(body.faces.astype(np.int64), dtype=torch.long,
                               device=device).unsqueeze_(0).repeat([batch_size,
                                                                    1, 1])
    with torch.no_grad():
        output = body(get_skin=True)
        verts = output.vertices

    bs, nv = verts.shape[:2]
    # Offset each model's face indices by ``nv`` so they address the
    # flattened (bs * nv, 3) vertex array used inside the loop.
    batch_offsets = torch.arange(bs, dtype=torch.long, device=device) * nv
    faces_idx = face_tensor + batch_offsets[:, None, None]

    optimizer = torch.optim.SGD([body.body_pose], lr=lr)

    if interactive:
        # Plot the initial mesh; the vertices computed above are still
        # current because no optimisation step has been taken yet.
        np_verts = verts.detach().cpu().numpy()

        def create_mesh(vertices, faces, color=(0.3, 0.3, 0.3, 1.0),
                        wireframe=False):
            # ``wireframe`` is accepted but not used by this helper.
            tri_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
            # Rotate 180 degrees about the x axis before display.
            rot = trimesh.transformations.rotation_matrix(np.radians(180),
                                                          [1, 0, 0])
            tri_mesh.apply_transform(rot)

            material = pyrender.MetallicRoughnessMaterial(
                metallicFactor=0.0,
                alphaMode='BLEND',
                baseColorFactor=color)
            return pyrender.Mesh.from_trimesh(
                tri_mesh,
                material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0],
                               ambient_light=(1.0, 1.0, 1.0))
        for bidx in range(np_verts.shape[0]):
            curr_verts = np_verts[bidx].copy()
            body_mesh = create_mesh(curr_verts, body.faces,
                                    color=(0.3, 0.3, 0.3, 0.99),
                                    wireframe=True)

            # Lay the batch out side by side along the x axis.
            pose = np.eye(4)
            pose[0, 3] = bidx * 2
            scene.add(body_mesh,
                      name='body_mesh_{:03d}'.format(bidx),
                      pose=pose)

        viewer = pyrender.Viewer(scene, use_raymond_lighting=True,
                                 viewport_size=(1200, 800),
                                 cull_faces=False,
                                 run_in_thread=True)

    # Name fragments of the scene nodes that are rebuilt every iteration.
    query_names = ['recv_mesh', 'intr_mesh', 'body_mesh']

    while True:
        optimizer.zero_grad()

        if print_timings:
            start = time.time()

        if print_timings:
            torch.cuda.synchronize()
        output = body(get_skin=True)
        verts = output.vertices

        if print_timings:
            torch.cuda.synchronize()
            print('Body model forward: {:5f}'.format(time.time() - start))

        if print_timings:
            torch.cuda.synchronize()
            start = time.time()
        triangles = verts.view([-1, 3])[faces_idx]
        if print_timings:
            torch.cuda.synchronize()
            print('Triangle indexing: {:5f}'.format(time.time() - start))

        # Collision detection and filtering are not differentiated through.
        with torch.no_grad():
            if print_timings:
                start = time.time()
            collision_idxs = search_tree(triangles)
            if print_timings:
                torch.cuda.synchronize()
                print('Collision Detection: {:5f}'.format(time.time() -
                                                          start))
            if part_segm_fn:
                if print_timings:
                    start = time.time()
                collision_idxs = filter_faces(collision_idxs)
                if print_timings:
                    torch.cuda.synchronize()
                    print('Collision filtering: {:5f}'.format(time.time() -
                                                              start))

        if print_timings:
            start = time.time()
        pen_loss = coll_loss_weight * \
            pen_distance(triangles, collision_idxs)
        if print_timings:
            torch.cuda.synchronize()
            print('Penetration loss: {:5f}'.format(time.time() - start))

        shape_reg_loss = torch.tensor(0, device=device,
                                      dtype=torch.float32)
        if shape_reg_weight > 0:
            shape_reg_loss = shape_reg_weight * torch.sum(body.betas ** 2)
        pose_reg_loss = torch.tensor(0, device=device,
                                     dtype=torch.float32)
        if pose_reg_weight > 0:
            # Compare against the tensor that ``init_pose`` was cloned from
            # and that the optimizer updates (body.body_pose).
            pose_reg_loss = pose_reg_weight * \
                mse_loss(body.body_pose, init_pose)

        loss = pen_loss + pose_reg_loss + shape_reg_loss

        # ``tolist`` yields a bare float for a single value; wrap it so the
        # format string below always receives a sequence.
        np_loss = loss.detach().cpu().squeeze().tolist()
        if not isinstance(np_loss, list):
            np_loss = [np_loss]
        msg = '{:.5f} ' * len(np_loss)
        print('Loss per model:', msg.format(*np_loss))

        if print_timings:
            start = time.time()
        # An explicit gradient of ones lets a non-scalar loss be
        # back-propagated.
        loss.backward(torch.ones_like(loss))
        if print_timings:
            torch.cuda.synchronize()
            print('Backward pass: {:5f}'.format(time.time() - start))

        if interactive:
            # The parameters have not been stepped yet in this iteration,
            # so the vertices from the forward pass above are up to date.
            np_verts = verts.detach().cpu().numpy()

            np_collision_idxs = collision_idxs.detach().cpu().numpy()
            np_receivers = np_collision_idxs[:, :, 0]
            np_intruders = np_collision_idxs[:, :, 1]

            # Hold the render lock while mutating the scene and always
            # release it, even if building a mesh raises.
            viewer.render_lock.acquire()
            try:
                # Iterate over a snapshot since nodes are removed in the
                # loop body.
                for node in list(scene.get_nodes()):
                    if node.name is None:
                        continue
                    if any(query in node.name for query in query_names):
                        scene.remove_node(node)

                for bidx in range(batch_size):
                    # Negative entries are dropped before indexing faces.
                    recv_faces_idxs = \
                        np_receivers[bidx][np_receivers[bidx] >= 0]
                    intr_faces_idxs = \
                        np_intruders[bidx][np_intruders[bidx] >= 0]
                    recv_faces = body.faces[recv_faces_idxs]
                    intr_faces = body.faces[intr_faces_idxs]

                    curr_verts = np_verts[bidx].copy()
                    body_mesh = create_mesh(curr_verts, body.faces,
                                            color=(0.3, 0.3, 0.3, 0.99),
                                            wireframe=True)

                    pose = np.eye(4)
                    pose[0, 3] = bidx * 2
                    scene.add(body_mesh,
                              name='body_mesh_{:03d}'.format(bidx),
                              pose=pose)

                    if len(intr_faces) > 0:
                        intr_mesh = create_mesh(curr_verts, intr_faces,
                                                color=(0.9, 0.0, 0.0, 1.0))
                        scene.add(intr_mesh,
                                  name='intr_mesh_{:03d}'.format(bidx),
                                  pose=pose)

                    if len(recv_faces) > 0:
                        recv_mesh = create_mesh(curr_verts, recv_faces,
                                                color=(0.0, 0.9, 0.0, 1.0))
                        scene.add(recv_mesh,
                                  name='recv_mesh_{:03d}'.format(bidx),
                                  pose=pose)
            finally:
                viewer.render_lock.release()

            # Closing the viewer window ends the optimisation.
            if not viewer.is_active:
                break

            time.sleep(delay / 1000)
        optimizer.step()
示例#4
0
def non_linear_solver(setting,
                      data,
                      batch_size=1,
                      data_weights=None,
                      body_pose_prior_weights=None,
                      shape_weights=None,
                      coll_loss_weights=None,
                      use_joints_conf=False,
                      use_3d=False,
                      rho=100,
                      interpenetration=False,
                      loss_type='smplify',
                      visualize=False,
                      use_vposer=True,
                      interactive=True,
                      use_cuda=True,
                      is_seq=False,
                      **kwargs):
    """Run the staged non-linear fitting of the body model.

    One optimisation stage is run per entry of the weight lists. For each
    stage a fresh optimizer is created over the model parameters that
    require gradients (plus ``pose_embedding`` when a VPoser is given), the
    loss weights are reset and ``monitor.run_fitting`` is executed.

    Args:
        setting: dict providing ``device``, ``dtype``, ``vposer``,
            ``joints_weight``, ``model``, ``camera``, ``pose_embedding``,
            ``seq_start``, ``body_pose_prior``, ``shape_prior`` and
            ``angle_prior``.
        data: dict providing ``keypoints`` and ``img``; the height of the
            first image sets the data weight (``500 / H``).
        batch_size: must be 1.
        data_weights, body_pose_prior_weights, shape_weights,
            coll_loss_weights: per-stage weight sequences, all of the same
            length. ``coll_loss_weights`` is length-checked even when
            ``interpenetration`` is False.
        use_joints_conf: take per-view confidences from index 2 of the last
            keypoint axis.
        use_3d: also fit to 3D joints (see the review note in the body).
        rho: forwarded to ``fitting.create_loss``.
        interpenetration: build and use the mesh collision term.
        loss_type: forwarded to ``fitting.create_loss``.
        visualize: forwarded to ``fitting.FittingMonitor``.
        use_vposer: forwarded to the fitting closure and ``run_fitting``.
        interactive: print per-stage and total timings via ``tqdm.write``.
        use_cuda: synchronise CUDA before reading timers; required when
            ``interpenetration`` is set.
        is_seq: when set and ``setting['seq_start']`` is falsy, stages 0 and
            1 are skipped and the body-pose weight of stage 2 is scaled by
            0.15.
        **kwargs: forwarded to ``fitting.create_loss``,
            ``fitting.FittingMonitor`` and
            ``optim_factory.create_optimizer``.

    Returns:
        dict: the model's named parameters as numpy arrays, plus ``'loss'``
        (value returned by the last ``run_fitting`` call, or 0 if no stage
        ran) and ``'pose_embedding'``.
    """
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    device = setting['device']
    dtype = setting['dtype']
    vposer = setting['vposer']
    keypoints = data['keypoints']
    joint_weights = setting['joints_weight']
    model = setting['model']
    camera = setting['camera']
    pose_embedding = setting['pose_embedding']
    seq_start = setting['seq_start']

    assert (len(data_weights) == len(body_pose_prior_weights)
            and len(shape_weights) == len(body_pose_prior_weights)
            and len(coll_loss_weights)
            == len(body_pose_prior_weights)), "Number of weight must match"

    # process keypoints
    # The indexing below requires a 4-D array whose last axis holds at
    # least (x, y, confidence); iterating the first axis gives one entry
    # per element of ``joints_conf`` (presumably one per view -- confirm).
    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :, :2]
    # Bound up front: both are passed to the fitting closure regardless of
    # the flags below.
    joints_conf = None
    joints3d_conf = None
    if use_joints_conf:
        joints_conf = [
            v[:, :, 2].reshape(1, -1).to(device=device, dtype=dtype)
            for v in keypoint_data
        ]

    gt_joints3d = None
    if use_3d:
        # NOTE(review): ``joints3d`` and ``use_hip`` are not defined in this
        # function; they must come from module scope -- confirm, or read
        # them from ``data`` / ``kwargs``.
        joints_data = torch.tensor(joints3d, dtype=dtype)
        gt_joints3d = joints_data[:, :3]
        if use_joints_conf:
            joints3d_conf = joints_data[:, 3].reshape(1, -1).to(device=device,
                                                                dtype=dtype)
            if not use_hip:
                # Zero the confidence of joints 11 and 12 (presumably the
                # hips, given the flag name -- confirm).
                joints3d_conf[0][11] = 0
                joints3d_conf[0][12] = 0

        gt_joints3d = gt_joints3d.to(device=device, dtype=dtype)
    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    # we do not use this term at this time
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        # NOTE(review): ``max_collisions``, ``df_cone_height``,
        # ``point2plane``, ``penalize_outside``, ``part_segm_fn`` and
        # ``ign_part_pairs`` are not defined in this function and must be
        # provided by module scope -- confirm.
        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation. The expanded path gets its own
            # name: re-assigning ``part_segm_fn`` would make it a local
            # variable and break the check above with UnboundLocalError.
            segm_path = os.path.expandvars(part_segm_fn)
            # NOTE: pickle can execute arbitrary code; load trusted files
            # only.
            with open(segm_path, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights

    # get weights for each stage: transpose {name: [w_stage0, ...]} into
    # [{name: w_stage0, ...}, ...]. Names and value lists are filtered
    # together so they cannot get out of step.
    active_weights = {
        name: vals for name, vals in opt_weights_dict.items()
        if vals is not None
    }
    opt_weights = [
        dict(zip(active_weights, stage_vals))
        for stage_vals in zip(*active_weights.values())
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # create fitting loss
    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=setting['body_pose_prior'],
                               shape_prior=setting['shape_prior'],
                               angle_prior=setting['angle_prior'],
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               use_3d=use_3d,
                               **kwargs)
    loss = loss.to(device=device)

    monitor = fitting.FittingMonitor(batch_size=batch_size,
                                     visualize=visualize,
                                     **kwargs)

    # The data term is scaled by the height of the first image.
    H, _, _ = data['img'][0].shape
    data_weight = 500 / H

    # Step 1: Optimize the full model
    final_loss_val = 0
    opt_start = time.time()

    for opt_idx, curr_weights in enumerate(tqdm(opt_weights, desc='Stage')):
        # pass stage1 and stage2 if it is a sequence
        if not seq_start and is_seq:
            if opt_idx < 2:
                continue
            elif opt_idx == 2:
                curr_weights['body_pose_weight'] *= 0.15

        # Optimise every trainable model parameter (and the VPoser latent
        # when a VPoser is in use).
        final_params = [
            p for p in model.parameters() if p.requires_grad
        ]
        if vposer is not None:
            final_params.append(pose_embedding)

        body_optimizer, body_create_graph = optim_factory.create_optimizer(
            final_params, **kwargs)
        body_optimizer.zero_grad()

        # The per-stage data weight is replaced by the image-derived one.
        curr_weights['data_weight'] = data_weight
        curr_weights['bending_prior_weight'] = (
            3.17 * curr_weights['body_pose_weight'])
        loss.reset_loss_weights(curr_weights)

        closure = monitor.create_fitting_closure(
            body_optimizer,
            model,
            camera=camera,
            gt_joints=gt_joints,
            joints_conf=joints_conf,
            gt_joints3d=gt_joints3d,
            joints3d_conf=joints3d_conf,
            joint_weights=joint_weights,
            loss=loss,
            create_graph=body_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            return_verts=True,
            return_full_pose=True,
            use_3d=use_3d)

        if interactive:
            # Synchronise so the timer is not skewed by pending GPU work.
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            stage_start = time.time()
        final_loss_val = monitor.run_fitting(body_optimizer,
                                             closure,
                                             final_params,
                                             model,
                                             pose_embedding=pose_embedding,
                                             vposer=vposer,
                                             use_vposer=use_vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = time.time() - stage_start
            tqdm.write('Stage {:03d} done after {:.4f} seconds'.format(
                opt_idx, elapsed))

    if interactive:
        if use_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.time() - opt_start
        tqdm.write('Body fitting done after {:.4f} seconds'.format(elapsed))
        tqdm.write('Body final loss val = {:.5f}'.format(final_loss_val))

    # Get the result of the fitting process. Built unconditionally so the
    # function also returns a value when ``interactive`` is False.
    result = {
        key: val.detach().cpu().numpy()
        for key, val in model.named_parameters()
    }
    result['loss'] = final_loss_val
    result['pose_embedding'] = pose_embedding
    return result
示例#5
0
def fit_single_frame(
        img,
        keypoints,
        init_trans,
        scan,
        scene_name,
        body_model,
        camera,
        joint_weights,
        body_pose_prior,
        jaw_prior,
        left_hand_prior,
        right_hand_prior,
        shape_prior,
        expr_prior,
        angle_prior,
        result_fn='out.pkl',
        mesh_fn='out.obj',
        body_scene_rendering_fn='body_scene.png',
        out_img_fn='overlay.png',
        loss_type='smplify',
        use_cuda=True,
        init_joints_idxs=(9, 12, 2, 5),
        use_face=True,
        use_hands=True,
        data_weights=None,
        body_pose_prior_weights=None,
        hand_pose_prior_weights=None,
        jaw_pose_prior_weights=None,
        shape_weights=None,
        expr_weights=None,
        hand_joints_weights=None,
        face_joints_weights=None,
        depth_loss_weight=1e2,
        interpenetration=True,
        coll_loss_weights=None,
        df_cone_height=0.5,
        penalize_outside=True,
        max_collisions=8,
        point2plane=False,
        part_segm_fn='',
        focal_length_x=5000.,
        focal_length_y=5000.,
        side_view_thsh=25.,
        rho=100,
        vposer_latent_dim=32,
        vposer_ckpt='',
        use_joints_conf=False,
        interactive=True,
        visualize=False,
        save_meshes=True,
        degrees=None,
        batch_size=1,
        dtype=torch.float32,
        ign_part_pairs=None,
        left_shoulder_idx=2,
        right_shoulder_idx=5,
        # ---- PROX-specific options ----
        render_results=True,
        camera_mode='moving',
        # Depth (scan <-> mesh) terms
        s2m=False,
        s2m_weights=None,
        m2s=False,
        m2s_weights=None,
        rho_s2m=1,
        rho_m2s=1,
        init_mode=None,
        trans_opt_stages=None,
        viz_mode='mv',
        # Scene penetration (signed distance field)
        sdf_penetration=False,
        sdf_penetration_weights=0.0,
        sdf_dir=None,
        cam2world_dir=None,
        # Body/scene contact
        contact=False,
        rho_contact=1.0,
        contact_loss_weights=None,
        contact_angle=15,
        contact_body_parts=None,
        body_segments_dir=None,
        load_scene=False,
        scene_dir=None,
        **kwargs):
    """Fit ``body_model`` to one frame and write the outcome to disk.

    The fit has two steps.  Step 1 optimises the global orientation together
    with the camera translation (``camera_mode == 'moving'``) or the body
    translation (``camera_mode == 'fixed'``) against a ``'camera_init'`` loss.
    Step 2 optimises the full model stage by stage, once per candidate
    orientation; a second, 180-degree flipped orientation is tried when the
    2D shoulder distance is below ``side_view_thsh``.  The parameters of the
    orientation with the lower final loss are pickled to ``result_fn``.

    Args:
        img: Input image as an array; it is multiplied by 255 for display,
            so it is presumably a float image in [0, 1] (confirm with caller).
        keypoints: 2D keypoints, indexed as ``[:, :, :2]`` for positions and
            ``[:, :, 2]`` for confidences.
        init_trans: Translation estimate used when ``init_mode`` is
            ``'scan'`` or ``'both'``.
        scan: Optional mapping with a ``'points'`` entry, or ``None``.
        scene_name: Base name of the per-scene files (cam2world json, SDF,
            scene ``.ply``).
        result_fn, mesh_fn, out_img_fn, body_scene_rendering_fn: Output paths
            for the pickled result, the exported mesh, the image overlay and
            the body-plus-scene rendering.
        *_weights: Per-stage weight lists.  Every list must have the same
            length as ``body_pose_prior_weights``; one optimisation stage is
            run per entry.
        trans_opt_stages: Stage indices in which ``body_model.transl`` is
            optimised.  ``None`` means every stage.
        cam2world_dir: Directory holding ``<scene_name>.json``; always read.
        body_segments_dir: Directory holding ``body_mask.json`` (always read)
            and the per-part json files used when ``contact`` is set.
        sdf_dir: Directory with the pre-computed SDF; read when
            ``sdf_penetration`` is set.
        scene_dir: Directory with ``<scene_name>.ply``; read when
            ``load_scene`` or ``render_results`` is set.
        **kwargs: ``use_vposer``, ``body_tri_idxs`` and ``model_type`` are
            read here; the whole mapping is also forwarded to the loss, the
            fitting monitor and the optimizer factory.

    Returns:
        None.  All results are written to the given file names.

    Raises:
        AssertionError: If ``batch_size != 1``, if the weight lists disagree
            in length, or if interpenetration is requested without CUDA.
        ValueError: If ``camera_mode`` is neither ``'moving'`` nor
            ``'fixed'``.
    """
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'
    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if visualize:
        pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    # ------------------------------------------------------------------
    # Fill in default per-stage weights and check that all lists agree.
    # ------------------------------------------------------------------
    if data_weights is None:
        data_weights = [
            1,
        ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
           ' does not match the number of data term weights {}'.format(
               len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (
            len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        # Validate caller-supplied weights too, not only the defaults.
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand joint distance weights')
        assert (
            len(hand_joints_weights) == len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format(
        len(body_pose_prior_weights), len(shape_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            # Each entry is a comma separated string of floats.
            jaw_pose_prior_weights = [
                [float(value) for value in stage.split(',')]
                for stage in jaw_pose_prior_weights
            ]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (
            len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format(
            len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    # ------------------------------------------------------------------
    # Pose prior (VPoser) set-up.
    # ------------------------------------------------------------------
    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [
        None,
    ] * 2
    if use_vposer:
        # NOTE(review): the embedding size is hard-coded to 32 while
        # body_mean_pose below uses vposer_latent_dim -- confirm they are
        # meant to be the same value.
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype,
                                     device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    # ------------------------------------------------------------------
    # Observations: 2D keypoints, optional confidences, optional scan.
    # ------------------------------------------------------------------
    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    # Always bind joints_conf: it is handed to the fitting closure even when
    # confidences are disabled.
    joints_conf = None
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = torch.tensor(scan.get('points'),
                                   device=device,
                                   dtype=dtype).unsqueeze(0)

    # Load the pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    if sdf_penetration:
        with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
            sdf_data = json.load(f)
            grid_min = torch.tensor(np.array(sdf_data['min']),
                                    dtype=dtype,
                                    device=device)
            grid_max = torch.tensor(np.array(sdf_data['max']),
                                    dtype=dtype,
                                    device=device)
            grid_dim = sdf_data['dim']
        voxel_size = (grid_max - grid_min) / grid_dim
        sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(
            grid_dim, grid_dim, grid_dim)
        sdf = torch.tensor(sdf, dtype=dtype, device=device)
        if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
            sdf_normals = np.load(
                osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(
                    grid_dim, grid_dim, grid_dim, 3)
            sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
        else:
            print("Normals not found...")

    # Camera-to-world transform of the scene, split into rotation and
    # translation.
    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3),
                         dtype=dtype,
                         device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3),
                         dtype=dtype,
                         device=device)

    # ------------------------------------------------------------------
    # Self-interpenetration: search tree, penetration loss, face filter.
    # ------------------------------------------------------------------
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation.
            # NOTE: pickle can execute arbitrary code; only load
            # segmentation files from a trusted source.
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # ------------------------------------------------------------------
    # Contact: vertex ids of the contact parts and faces-by-vertex matrix.
    # ------------------------------------------------------------------
    contact_verts_ids = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'),
                      'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(return_verts=True,
                              body_pose=torch.zeros((batch_size, 63),
                                                    dtype=dtype,
                                                    device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(
            -1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)

        # Convert the scipy sparse matrix into a torch sparse tensor.
        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply'))

            scene.vn = scene.estimate_vertex_normals()

            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype,
                                   device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype,
                                    device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long,
                                   device=device)

    # ------------------------------------------------------------------
    # Turn the per-term weight lists into one weight dict per stage.
    # ------------------------------------------------------------------
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    # Drop terms without weights *before* pairing names with values, so a
    # missing list cannot shift the remaining values onto the wrong keys.
    active_keys = [
        key for key, stage_vals in opt_weights_dict.items()
        if stage_vals is not None
    ]
    opt_weights = [
        dict(zip(active_keys, vals))
        for vals in zip(*(opt_weights_dict[key] for key in active_keys))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # Without an explicit selection the translation stays trainable in every
    # stage, matching the requires_grad flag set at the top of the function.
    if trans_opt_stages is None:
        trans_opt_stages = range(len(opt_weights))

    # Load the vertex indices listed in body_mask.json and derive the
    # complementary head / body masks over all model vertices.
    with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # Which initialization to use: the mean of the scan ('scan'), the
    # similar-triangles guess (default), or the average of both ('both').
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(
            body_model,
            gt_joints,
            edge_indices,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            model_type=kwargs.get('model_type', 'smpl'),
            focal_length=focal_length_x,
            dtype=dtype)) / 2.0

    else:
        init_t = fitting.guess_init(body_model,
                                    gt_joints,
                                    edge_indices,
                                    use_vposer=use_vposer,
                                    vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get(
                                        'model_type', 'smpl'),
                                    focal_length=focal_length_x,
                                    dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                viz_mode=viz_mode,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True

            camera_opt_params = [camera.translation, body_model.global_orient]

        elif camera_mode == 'fixed':
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl, body_model.global_orient]

        else:
            raise ValueError(
                "camera_mode must be 'moving' or 'fixed', got {!r}".format(
                    camera_mode))

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer,
            body_model,
            camera,
            gt_joints,
            camera_loss,
            create_graph=camera_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False,
            return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close then rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc='Orientation')):
            opt_start = time.time()

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):
                # The translation is only trainable in the selected stages.
                body_model.transl.requires_grad = opt_idx in trans_opt_stages
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v,
                    scene_vn=scene_vn,
                    scene_f=scene_f,
                    ftov=ftov,
                    return_verts=True,
                    return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    tqdm.write(
                        'Stage {:03d} done after {:.4f} seconds'.format(
                            opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.
                    format(or_idx, elapsed))
                tqdm.write(
                    'Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {
                'camera_' + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result['pose_embedding'] = (
                    pose_embedding.detach().cpu().numpy())
                # With VPoser the body pose is decoded from the embedding and
                # overrides the model's own body_pose entry.
                body_pose = vposer.decode(pose_embedding,
                                          output_type='aa').view(1, -1)
                result['body_pose'] = body_pose.detach().cpu().numpy()

            results.append({'loss': final_loss_val, 'result': result})

        with open(result_fn, 'wb') as result_file:
            if len(results) > 1:
                min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1)
            else:
                min_idx = 0
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)

    # ------------------------------------------------------------------
    # Build the output mesh; it is needed for export and for rendering.
    # ------------------------------------------------------------------
    if save_meshes or visualize or render_results:
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(
            1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        if save_meshes or visualize:
            out_mesh.export(mesh_fn)

    if render_results:
        import pyrender

        # Set-up shared by both renderings.
        camera_center = np.array([951.30, 536.77])
        camera_pose = np.eye(4)
        # Negate the second and third rows of the identity pose.
        camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1,
                                                               1) * camera_pose
        render_camera = pyrender.camera.IntrinsicsCamera(fx=1060.53,
                                                         fy=1060.38,
                                                         cx=camera_center[0],
                                                         cy=camera_center[1])
        light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0)

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        # --- Rendering 1: body overlaid on the input image ---
        input_img = img.detach().cpu().numpy()
        H, W, _ = input_img.shape

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(render_camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)

        scene.add(body_mesh, 'mesh')

        renderer = pyrender.OffscreenRenderer(viewport_width=W,
                                              viewport_height=H,
                                              point_size=1.0)
        try:
            color, _ = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            # Release the renderer's resources even if rendering fails.
            renderer.delete()
        color = color.astype(np.float32) / 255.0

        # Keep the rendered pixel where alpha > 0, the input image elsewhere.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        overlay = pil_img.fromarray((output_img * 255).astype(np.uint8))
        overlay.save(out_img_fn)

        # --- Rendering 2: body together with the static scene ---
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)
        static_scene = trimesh.load(osp.join(scene_dir, scene_name + '.ply'))
        # Apply the inverse of cam2world to the scene mesh.
        trans = np.linalg.inv(cam2world)
        static_scene.apply_transform(trans)

        static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene)

        scene = pyrender.Scene()
        scene.add(render_camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)

        scene.add(static_scene_mesh, 'mesh')
        scene.add(body_mesh, 'mesh')

        renderer = pyrender.OffscreenRenderer(viewport_width=W,
                                              viewport_height=H)
        try:
            color, _ = renderer.render(scene)
        finally:
            renderer.delete()
        color = color.astype(np.float32) / 255.0
        body_scene_img = pil_img.fromarray((color * 255).astype(np.uint8))
        body_scene_img.save(body_scene_rendering_fn)
示例#6
0
def fit_single_frame(img,
                     keypoints,
                     body_model,
                     camera,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     **kwargs):
    """Fit ``body_model`` to the 2D keypoints detected in a single image.

    The fit runs in two steps inside a ``fitting.FittingMonitor``:

    1. The camera translation and the global orientation of the body are
       optimized with the ``'camera_init'`` loss, starting from the
       translation returned by ``fitting.guess_init``.
    2. Every body-model parameter with ``requires_grad`` set (plus the pose
       embedding when VPoser is used) is optimized in stages. Stage ``i``
       uses the ``i``-th entry of each per-stage weight list.

    When the 2D distance between the two shoulder keypoints is smaller than
    ``side_view_thsh``, step 2 is run for two initial orientations (the one
    found in step 1 and the same one rotated by pi around the y axis) and
    the result with the lower final loss is kept.

    Outputs are written to disk, nothing is returned:

    * ``result_fn``: pickle (protocol 2) with the camera parameters
      (prefixed ``camera_``) and the body-model parameters; ``body_pose``
      holds the pose embedding when VPoser is used.
    * ``mesh_fn``: the fitted mesh, when ``save_meshes`` or ``visualize``.
    * ``out_img_fn``: the mesh rendered over ``img``, when ``visualize``.

    Args:
        img: Image data with three dimensions (unpacked as ``H, W, _``).
            Presumably float RGB in ``[0, 1]`` since the overlay is
            multiplied by 255 before saving -- TODO confirm with caller.
        keypoints: 3-D array-like; ``[..., :2]`` are the 2D joint positions
            and ``[..., 2]`` the confidences (read only when
            ``use_joints_conf`` is set).
        body_model: Body model; must provide ``reset_params``,
            ``parameters``, ``named_parameters``, ``global_orient`` and
            ``faces``.
        camera: Camera module whose ``translation`` and ``center`` are
            overwritten and whose translation is optimized.
        joint_weights: Per-joint weights. Modified in place: columns
            ``25:76`` receive the stage hand weight and columns ``76:`` the
            stage face weight.
        body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior,
        shape_prior, expr_prior, angle_prior: Prior objects forwarded to
            ``fitting.create_loss``.
        loss_type: Name of the loss forwarded to ``fitting.create_loss``.
        use_cuda: Run on ``cuda`` instead of ``cpu``.
        init_joints_idxs: Joint indices used by the camera-init loss.
        use_face, use_hands: Enable the face / hand terms and weights.
        data_weights, body_pose_prior_weights, hand_pose_prior_weights,
        jaw_pose_prior_weights, shape_weights, expr_weights,
        hand_joints_weights, face_joints_weights, coll_loss_weights:
            Per-stage weight lists. All of them must have the same length
            as ``body_pose_prior_weights``. ``jaw_pose_prior_weights``, when
            given, is a sequence of comma-separated strings. A missing
            ``data_weights`` defaults to one ``1`` per stage.
        depth_loss_weight: Forwarded to the camera-init loss.
        interpenetration: Build the collision term (requires CUDA and the
            ``mesh_intersection`` package).
        df_cone_height, penalize_outside, max_collisions, point2plane,
        part_segm_fn, ign_part_pairs: Settings of the collision term.
        focal_length: Focal length used for the initial translation guess
            and for the rendering camera.
        side_view_thsh: Shoulder-distance threshold described above.
        rho: Forwarded to ``fitting.create_loss``.
        vposer_latent_dim: Size of the VPoser pose embedding.
        vposer_ckpt: VPoser checkpoint path (environment variables are
            expanded).
        use_joints_conf: Pass the keypoint confidences to the loss.
        interactive: Print timing and loss information through ``tqdm``.
        visualize: Render the overlay image (also passed to the monitor).
        save_meshes: Export the fitted mesh.
        degrees: Accepted for interface compatibility; not used here.
        batch_size: Must be 1.
        dtype: Torch dtype of every tensor created here.
        left_shoulder_idx, right_shoulder_idx: Keypoint indices of the
            shoulders.
        **kwargs: ``use_vposer`` (default ``True``), ``body_tri_idxs`` and
            ``model_type`` (default ``'smpl'``) are read here; everything is
            also forwarded to ``fitting.create_loss``,
            ``fitting.FittingMonitor`` and
            ``optim_factory.create_optimizer``.

    Raises:
        AssertionError: If ``batch_size != 1``, if a weight list does not
            match the number of stages, or if ``interpenetration`` is
            requested without CUDA.
    """
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    if data_weights is None:
        # One data weight per stage. Deriving the length from the pose
        # prior weights keeps the defaults consistent with each other.
        data_weights = [1] * len(body_pose_prior_weights)

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
           ' does not match the number of data term weights {}'.format(
               len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (
            len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        # Validate user supplied lists as well, not only the default.
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand joint distance weights')
        assert (
            len(hand_joints_weights) == len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format(
        len(body_pose_prior_weights), len(shape_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            # Each entry is a comma separated string, e.g. '47.8,478,478'.
            jaw_pose_prior_weights = [
                [float(value) for value in stage.split(',')]
                for stage in jaw_pose_prior_weights
            ]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (
            len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format(
            len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = None, None
    if use_vposer:
        # The embedding must have the same size as the VPoser latent space.
        pose_embedding = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype,
                                     device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    # Transfer the data to the correct device
    gt_joints = keypoint_data[:, :, :2].to(device=device, dtype=dtype)
    # Must always be bound: it is handed to the fitting closure even when
    # the confidences are not used.
    joints_conf = None
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation.
            # NOTE: unpickling can execute arbitrary code; only point
            # ``part_segm_fn`` at trusted files.
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights

    # Drop missing lists from the keys AND the values so that the names
    # stay aligned with the zipped per-stage values.
    keys = [key for key, vals in opt_weights_dict.items() if vals is not None]
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')
    init_t = fitting.guess_init(body_model,
                                gt_joints,
                                edge_indices,
                                use_vposer=use_vposer,
                                vposer=vposer,
                                pose_embedding=pose_embedding,
                                model_type=kwargs.get('model_type', 'smpl'),
                                focal_length=focal_length,
                                dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        # The data term is scaled by the image height.
        data_weight = 1000 / H
        camera_loss.reset_loss_weights({'data_weight': data_weight})

        # Reset the parameters to estimate the initial translation of the
        # body model
        body_model.reset_params(body_pose=body_mean_pose)

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        # Update the value of the translation of the camera as well as
        # the image center.
        with torch.no_grad():
            camera.translation[:] = init_t.view_as(camera.translation)
            camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

        # Re-enable gradient calculation for the camera translation
        camera.translation.requires_grad = True

        camera_opt_params = [camera.translation, body_model.global_orient]

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer,
            body_model,
            camera,
            gt_joints,
            camera_loss,
            create_graph=camera_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            return_full_pose=False,
            return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close then rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []

        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc='Orientation')):
            opt_start = time.time()

            # Every orientation starts from the mean pose.
            body_model.reset_params(global_orient=orient,
                                    body_pose=body_mean_pose)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):

                final_params = [
                    param for param in body_model.parameters()
                    if param.requires_grad
                ]

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['data_weight'] = data_weight
                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True,
                    return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    tqdm.write(
                        'Stage {:03d} done after {:.4f} seconds'.format(
                            opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.
                    format(or_idx, elapsed))
                tqdm.write(
                    'Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {
                'camera_' + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result['body_pose'] = pose_embedding.detach().cpu().numpy()

            results.append({'loss': final_loss_val, 'result': result})

        # Keep the first orientation only when it is strictly better.
        if len(results) > 1:
            min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1)
        else:
            min_idx = 0
        with open(result_fn, 'wb') as result_file:
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)

    if save_meshes or visualize:
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(
            1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            # Pad the decoded pose with six zeros for the 'smpl' model type.
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        rot = trimesh.transformations.rotation_matrix(np.radians(180),
                                                      [1, 0, 0])
        out_mesh.apply_transform(rot)
        out_mesh.export(mesh_fn)

    if visualize:
        import pyrender

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_center = camera.center.detach().cpu().numpy().squeeze()
        camera_transl = camera.translation.detach().cpu().numpy().squeeze()
        # Equivalent to 180 degrees around the y-axis. Transforms the fit to
        # OpenGL compatible coordinate system.
        camera_transl[0] *= -1.0

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_transl

        # Separate name so the fitted ``camera`` argument is not shadowed.
        render_camera = pyrender.camera.IntrinsicsCamera(fx=focal_length,
                                                         fy=focal_length,
                                                         cx=camera_center[0],
                                                         cy=camera_center[1])
        scene.add(render_camera, pose=camera_pose)

        # Get the lights from the viewer
        # NOTE(review): relies on a private viewer method and on the monitor
        # still exposing ``mv`` after its context has exited -- confirm.
        light_nodes = monitor.mv.viewer._create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        r = pyrender.OffscreenRenderer(viewport_width=W,
                                       viewport_height=H,
                                       point_size=1.0)
        try:
            color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            # Release the off-screen rendering resources; this function is
            # typically called once per frame.
            r.delete()
        color = color.astype(np.float32) / 255.0

        # Pixels covered by the mesh have a non-zero alpha value.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img.detach().cpu().numpy()
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        overlay = pil_img.fromarray((output_img * 255).astype(np.uint8))
        overlay.save(out_img_fn)
示例#7
0
def fit_SMPLXD(scans,
               smplx_pkl,
               gender='male',
               save_path=None,
               scale_file=None,
               interpenetration=True,
               part_segm_fn='/home/chen/IPNet_SMPLX/assets/smplx_parts_segm.pkl',
               smplx_model_path='/home/chen/SMPLX/models/smplx'):
    """Optimize per-vertex offsets of an SMPL-X model to a batch of scans.

    The model is initialized from previously saved SMPL-X parameter pickles,
    the scans are loaded as meshes, and ``optimize_offsets`` is run on them.
    When ``save_path`` is given, the resulting meshes and one parameter
    pickle are written there.

    Args:
        scans: List of scan mesh paths; its length is the batch size.
        smplx_pkl: List of pickle paths. Each pickle holds a dict with the
            keys ``global_pose``, ``body_pose``, ``left_hand_pose``,
            ``right_hand_pose``, ``expression``, ``jaw_pose``,
            ``leye_pose``, ``reye_pose``, ``betas`` and ``trans``.
            Presumably one file per scan -- verify against caller.
        gender: Gender passed to the SMPL-X model.
        save_path: Output directory, created when missing. Nothing is
            written when it is ``None``.
        scale_file: Optional list of ``np.load``-able files, one per scan,
            holding ``(scale, offset)``: ``offset`` is added to the scan
            vertices and the result is multiplied by ``scale``.
        interpenetration: Build the collision search tree, penetration loss
            and face filter and hand them to the optimizer.
        part_segm_fn: Part segmentation pickle used to filter collision
            pairs (environment variables are expanded). Skipped when empty.
        smplx_model_path: Directory of the SMPL-X model files.

    Returns:
        Tuple of numpy arrays ``(global_pose, body_pose, left_hand_pose,
        right_hand_pose, expression, jaw_pose, leye_pose, reye_pose, betas,
        trans, offsets)`` read from the optimized model.
    """
    search_tree = None
    pen_distance = None
    tri_filtering_module = None
    max_collisions = 128
    df_cone_height = 0.0001
    point2plane = False
    penalize_outside = True
    ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"]
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces
        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=df_cone_height,
            point2plane=point2plane,
            vectorized=True,
            penalize_outside=penalize_outside)
        if part_segm_fn:
            # NOTE: unpickling can execute arbitrary code; only load
            # trusted segmentation files.
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pkl.load(faces_parents_file,
                                          encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Same device as the model so the collision filter and the
            # vertices never end up on different devices.
            tri_filtering_module = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(DEVICE)

    # Get SMPLX faces
    spx = SMPLX(model_path=smplx_model_path,
                batch_size=1,
                gender=gender)
    smplx_faces = spx.faces
    # Convert the indices integer-to-integer instead of through float32.
    th_faces = torch.tensor(smplx_faces.astype('int64'),
                            dtype=torch.long).to(DEVICE)

    # Batch size
    batch_sz = len(scans)

    # Init SMPLX: gather every parameter of every pickle, then stack them.
    param_keys = ('global_pose', 'body_pose', 'left_hand_pose',
                  'right_hand_pose', 'expression', 'jaw_pose', 'leye_pose',
                  'reye_pose', 'betas', 'trans')
    collected = {key: [] for key in param_keys}
    for spkl in smplx_pkl:
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # parameter files.
        with open(spkl, 'rb') as param_file:
            smplx_dict = pkl.load(param_file)
        for key in param_keys:
            value = smplx_dict[key]
            if key == 'betas' and len(value) == 10:
                # Copy the ten shape coefficients into a float32 array.
                temp = np.zeros((10, ))
                temp[:10] = value
                value = temp.astype('float32')
            collected[key].append(value)
    init = {
        key: torch.tensor(np.array(values))
        for key, values in collected.items()
    }

    smplx = th_batch_SMPLX(batch_sz,
                           init['betas'],
                           init['global_pose'],
                           init['body_pose'],
                           init['left_hand_pose'],
                           init['right_hand_pose'],
                           init['trans'],
                           init['expression'],
                           init['jaw_pose'],
                           init['leye_pose'],
                           init['reye_pose'],
                           faces=th_faces,
                           gender=gender).to(DEVICE)
    verts = smplx()
    # Detached copies: the initial meshes must not follow the optimization.
    init_smplx_meshes = [
        tm.from_tensors(vertices=v.clone().detach(), faces=smplx.faces)
        for v in verts
    ]

    # Load scans
    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(
            torch.tensor(temp.v.astype('float32'),
                         requires_grad=False,
                         device=DEVICE),
            torch.tensor(temp.f.astype('int32'),
                         requires_grad=False,
                         device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        for n, sc in enumerate(scale_file):
            dat = np.load(sc, allow_pickle=True)
            # Translate first, then scale.
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    # Optimize
    optimize_offsets(th_scan_meshes, smplx, init_smplx_meshes, 5, 10,
                     search_tree, pen_distance, tri_filtering_module)
    print('Done')

    verts = smplx.get_vertices_clean_hand()
    th_smplx_meshes = [
        tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts
    ]

    # Read the optimized parameters once; they are used both for the saved
    # pickle and for the return value.
    fitted = {
        'global_pose': smplx.global_pose,
        'body_pose': smplx.body_pose,
        'left_hand_pose': smplx.left_hand_pose,
        'right_hand_pose': smplx.right_hand_pose,
        'expression': smplx.expression,
        'jaw_pose': smplx.jaw_pose,
        'leye_pose': smplx.leye_pose,
        'reye_pose': smplx.reye_pose,
        'betas': smplx.betas,
        'trans': smplx.trans,
        'offsets': smplx.offsets_clean_hand,
    }
    fitted = {
        key: value.cpu().detach().numpy()
        for key, value in fitted.items()
    }

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        # NOTE(review): a single fixed name is used, so only the first
        # batch element gets a parameter pickle -- confirm this is intended
        # for batches larger than one.
        names = ['full.ply']  # [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(
            th_smplx_meshes,
            [join(save_path, n.replace('.ply', '_smplxd.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])
        # Save params: one dict per (name, batch element) pair.
        for n, *values in zip(names, *fitted.values()):
            smplx_dict = dict(zip(fitted.keys(), values))
            out_fn = join(save_path, n.replace('.ply', '_smplxd.pkl'))
            with open(out_fn, 'wb') as out_file:
                pkl.dump(smplx_dict, out_file)

    return tuple(fitted.values())
示例#8
0
def fit_SMPLX(scans,
              scan_labels,
              gender='male',
              save_path=None,
              scale_file=None,
              display=None,
              interpenetration=True):
    """Fit a batch of SMPL-X models to scans and return the fitted parameters.

    The pose is optimised first (``optimize_pose_only``), then pose and shape
    together (``optimize_pose_shape``).

    :param scans: list of scan mesh paths; one model is fitted per scan.
    :param scan_labels: list of paths loaded with ``np.load`` that hold the
        per-scan part labels (cast to int32).
    :param gender: 'male' or 'female'. Any other value prints a message and
        exits the process.
    :param save_path: optional output directory. When given it is created if
        needed and the fitted meshes, the scans and one parameter pickle per
        scan are written into it.
    :param scale_file: optional list of paths (one per scan) loaded with
        ``np.load``; element 1 is added to the scan vertices and element 0
        then multiplies them.
    :param display: only tested against None; any other value is forwarded
        to the optimisers as 0.
    :param interpenetration: when True, build the ``mesh_intersection``
        search tree, penetration loss and face filter (the filter is moved
        to CUDA) and pass them to the optimisers.
    :return: tuple of numpy arrays ``(global_pose, body_pose, left_hand_pose,
        right_hand_pose, expression, jaw_pose, leye_pose, reye_pose, betas,
        trans)``. It is returned whether or not ``save_path`` is given.
    """
    search_tree = None
    pen_distance = None
    tri_filtering_module = None
    max_collisions = 128
    df_cone_height = 0.0001
    point2plane = False
    penalize_outside = True
    part_segm_fn = '/home/chen/IPNet_SMPLX/assets/smplx_parts_segm.pkl'
    ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"]
    if interpenetration:
        # Imported lazily so the dependency is only needed when the
        # interpenetration term is requested.
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces
        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=df_cone_height,
            point2plane=point2plane,
            vectorized=True,
            penalize_outside=penalize_outside)
        if part_segm_fn:
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pkl.load(faces_parents_file,
                                          encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            tri_filtering_module = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).cuda()

    # Get SMPLX faces
    spx = SMPLX(model_path="/home/chen/SMPLX/models/smplx",
                batch_size=1,
                gender=gender)
    smplx_faces = spx.faces
    # The faces are cast through float32 before being stored as long indices
    # (kept exactly as in the original code).
    th_faces = torch.tensor(smplx_faces.astype('float32'),
                            dtype=torch.long).to(DEVICE)

    # Load SMPLX parts: every vertex index listed under the n-th key of the
    # pickle gets label n.
    with open('/home/chen/IPNet_SMPLX/assets/smplx_parts_dense.pkl',
              'rb') as part_labels_file:
        part_labels = pkl.load(part_labels_file)
    labels = np.zeros((10475, ), dtype='int32')
    for n, k in enumerate(part_labels):
        labels[part_labels[k]] = n
    labels = torch.tensor(labels).unsqueeze(0).to(DEVICE)

    # Load scan parts
    scan_part_labels = [
        torch.tensor(np.load(sc_l).astype('int32')).to(DEVICE)
        for sc_l in scan_labels
    ]

    # Batch size
    batch_sz = len(scans)

    # Set optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # The raw model pickle is only read for the mean hand poses.
    if gender == 'male':
        model_pkl_fn = '/home/chen/SMPLX/models/smplx/SMPLX_MALE.pkl'
    elif gender == 'female':
        model_pkl_fn = '/home/chen/SMPLX/models/smplx/SMPLX_FEMALE.pkl'
    else:
        print('Wrong gender input!')
        exit()
    with open(model_pkl_fn, 'rb') as model_file:
        temp_model = pkl.load(model_file, encoding='latin1')
    left_hand_mean = torch.tensor(temp_model['hands_meanl']).unsqueeze(0)
    right_hand_mean = torch.tensor(temp_model['hands_meanr']).unsqueeze(0)

    # TODO consider to add the prior for smplx
    betas = torch.zeros((batch_sz, 10))
    global_pose = torch.zeros((batch_sz, 3))
    body_pose = torch.zeros((batch_sz, 63))
    trans = torch.zeros((batch_sz, 3))
    # NOTE(review): the hand means have a leading dimension of 1 (from
    # unsqueeze(0)) while every other parameter has batch_sz rows - confirm
    # th_batch_SMPLX handles this when batch_sz > 1.
    left_hand_pose, right_hand_pose = left_hand_mean, right_hand_mean
    expression = torch.zeros((batch_sz, 10))
    jaw_pose = torch.zeros((batch_sz, 3))
    leye_pose = torch.zeros((batch_sz, 3))
    reye_pose = torch.zeros((batch_sz, 3))

    # Init SMPLX, pose with mean smplx pose, as in ch.registration
    smplx = th_batch_SMPLX(batch_sz,
                           betas,
                           global_pose,
                           body_pose,
                           left_hand_pose,
                           right_hand_pose,
                           trans,
                           expression,
                           jaw_pose,
                           leye_pose,
                           reye_pose,
                           faces=th_faces,
                           gender=gender).to(DEVICE)
    smplx_part_labels = torch.cat([labels] * batch_sz, axis=0)

    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(
            torch.tensor(temp.v.astype('float32'),
                         requires_grad=False,
                         device=DEVICE),
            torch.tensor(temp.f.astype('int32'),
                         requires_grad=False,
                         device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        for n, sc in enumerate(scale_file):
            dat = np.load(sc, allow_pickle=True)
            # Offset first (dat[1]), then scale (dat[0]).
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    # Optimize pose first
    optimize_pose_only(th_scan_meshes,
                       smplx,
                       pose_iterations,
                       pose_steps_per_iter,
                       scan_part_labels,
                       smplx_part_labels,
                       search_tree,
                       pen_distance,
                       tri_filtering_module,
                       display=None if display is None else 0)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes,
                        smplx,
                        iterations,
                        steps_per_iter,
                        scan_part_labels,
                        smplx_part_labels,
                        search_tree,
                        pen_distance,
                        tri_filtering_module,
                        display=None if display is None else 0)

    verts = smplx()
    th_smplx_meshes = [
        tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts
    ]

    # Fitted parameters, in the order they are saved and returned.
    param_names = ('global_pose', 'body_pose', 'left_hand_pose',
                   'right_hand_pose', 'expression', 'jaw_pose', 'leye_pose',
                   'reye_pose', 'betas', 'trans')
    fitted_params = tuple(
        getattr(smplx, name).cpu().detach().numpy() for name in param_names)

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(
            th_smplx_meshes,
            [join(save_path, n.replace('.ply', '_smplx.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])

        # Save params: one pickle per scan holding that scan's row of every
        # parameter array.
        for *scan_params, scan_name in zip(*fitted_params, names):
            smplx_dict = dict(zip(param_names, scan_params))
            out_fn = join(save_path, scan_name.replace('.ply', '_smplx.pkl'))
            with open(out_fn, 'wb') as out_file:
                pkl.dump(smplx_dict, out_file)

    # Returned unconditionally; previously the return sat inside the
    # save_path branch, so callers that did not save got None.
    return fitted_params
示例#9
0
def fit_frames(img,
               keypoints,
               body_model,
               camera=None,
               joint_weights=None,
               body_pose_prior=None,
               jaw_prior=None,
               left_hand_prior=None,
               right_hand_prior=None,
               shape_prior=None,
               expr_prior=None,
               angle_prior=None,
               loss_type="smplify",
               use_cuda=True,
               init_joints_idxs=(9, 12, 2, 5),
               use_face=False,
               use_hands=True,
               data_weights=None,
               body_pose_prior_weights=None,
               hand_pose_prior_weights=None,
               jaw_pose_prior_weights=None,
               shape_weights=None,
               expr_weights=None,
               hand_joints_weights=None,
               face_joints_weights=None,
               depth_loss_weight=1e2,
               interpenetration=False,
               coll_loss_weights=None,
               df_cone_height=0.5,
               penalize_outside=True,
               max_collisions=8,
               point2plane=False,
               part_segm_fn="",
               focal_length=5000.0,
               side_view_thsh=25.0,
               rho=100,
               vposer_latent_dim=32,
               vposer_ckpt="",
               use_joints_conf=False,
               interactive=True,
               visualize=False,
               batch_size=1,
               dtype=torch.float32,
               ign_part_pairs=None,
               left_shoulder_idx=2,
               right_shoulder_idx=5,
               freeze_camera=True,
               **kwargs):
    """Fit ``body_model`` to 2D keypoints with a staged optimisation.

    Each per-stage weight list contributes one entry to every optimisation
    stage; a fresh optimizer is created for every stage.

    Args:
        img: image data; it is converted to a tensor and only its first
            dimension ``H`` is used, to set the data weight to ``1000 / H``.
        keypoints: array indexed as ``[:, :, :2]`` for the joint positions
            and ``[:, :, 2]`` for the confidences.
        body_model: model that is reset, optimised and finally evaluated.
        camera: forwarded to the fitting closure; when not None its named
            parameters are stored in the result with a ``camera_`` prefix.
        joint_weights: forwarded to the loss and the closure. Must be
            indexable when ``use_face`` is set together with face weights.
        body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior,
        shape_prior, expr_prior, angle_prior, loss_type, rho: forwarded to
            ``fitting.create_loss``. ``body_pose_prior`` also supplies the
            mean pose when VPoser is disabled.
        use_cuda: run on ``cuda`` when True, otherwise on ``cpu``.
        use_face, use_hands: enable the face / hand weight entries.
        data_weights, body_pose_prior_weights, hand_pose_prior_weights,
        jaw_pose_prior_weights, shape_weights, expr_weights,
        hand_joints_weights, face_joints_weights, coll_loss_weights:
            per-stage weight lists. Entries left as None are omitted from
            the per-stage weight dictionaries.
        interpenetration, df_cone_height, penalize_outside, max_collisions,
        point2plane, part_segm_fn, ign_part_pairs: configuration of the
            optional ``mesh_intersection`` collision term (CUDA only).
        vposer_latent_dim, vposer_ckpt: size of the pose embedding and the
            checkpoint passed to ``load_vposer``; used when the
            ``use_vposer`` keyword (default True) is set.
        use_joints_conf: use the keypoint confidences; otherwise None is
            passed as ``joints_conf``.
        interactive: print timing information through ``tqdm.write``.
        visualize, batch_size: forwarded to ``fitting.FittingMonitor``.
        dtype: tensor dtype used throughout.
        init_joints_idxs, depth_loss_weight, focal_length, side_view_thsh,
        left_shoulder_idx, right_shoulder_idx, freeze_camera: accepted but
            not used by this function.
        **kwargs: forwarded to the loss, the monitor and the optimizer
            factory.

    Returns:
        ``(model_output, result)`` where ``model_output`` is the output of
        the final ``body_model`` call and ``result`` maps parameter names to
        numpy arrays.
    """
    # assert batch_size == 1, "PyTorch L-BFGS only supports batch_size == 1"

    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]
    if data_weights is None:
        data_weights = [1] * len(body_pose_prior_weights)

    msg = "Number of Body pose prior weights {}".format(
        len(body_pose_prior_weights)
    ) + " does not match the number of data term weights {}".format(
        len(data_weights))
    assert len(data_weights) == len(body_pose_prior_weights), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
        msg = ("Number of Body pose prior weights does not match the" +
               " number of hand pose prior weights")
        assert len(hand_pose_prior_weights) == len(
            body_pose_prior_weights), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        # Checked for caller-supplied weights too; the check used to run only
        # for the default value.
        msg = ("Number of Body pose prior weights does not match the" +
               " number of hand joint distance weights")
        assert len(hand_joints_weights) == len(
            body_pose_prior_weights), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1]
    msg = ("Number of Body pose prior weights = {} does not match the" +
           " number of Shape prior weights = {}")
    # Arguments follow the order of the placeholders in the message.
    assert len(shape_weights) == len(body_pose_prior_weights), msg.format(
        len(body_pose_prior_weights), len(shape_weights))

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ("Number of Body pose prior weights does not match the" +
           " number of collision loss weights")
    assert len(coll_loss_weights) == len(body_pose_prior_weights), msg

    use_vposer = kwargs.get("use_vposer", True)
    vposer, pose_embedding = None, None
    if use_vposer:
        pose_embedding = torch.zeros(
            [batch_size, vposer_latent_dim],
            dtype=dtype,
            device=device,
            requires_grad=True,
        )

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model="snapshot")
        vposer = vposer.to(device=device)
        vposer.eval()

        body_mean_pose = (
            vposeutils.get_vposer_mean_pose().detach().cpu().numpy())
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    # Transfer the data to the correct device
    gt_joints = keypoint_data[:, :, :2].to(device=device, dtype=dtype)
    # Always bound: the closure below receives it even when confidences are
    # not used.
    joints_conf = None
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(keypoint_data.shape[0],
                                                     -1)
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, "Interpenetration term can only be used with CUDA"
        assert torch.cuda.is_available(), (
            "No CUDA Device! Interpenetration term can only be used" +
            " with CUDA")

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=df_cone_height,
            point2plane=point2plane,
            vectorized=True,
            penalize_outside=penalize_outside,
        )

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, "rb") as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding="latin1")
            faces_segm = face_segm_data["segm"]
            faces_parents = face_segm_data["parents"]
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs,
            ).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        "data_weight": data_weights,
        "body_pose_weight": body_pose_prior_weights,
        "shape_weight": shape_weights,
    }
    if use_face:
        opt_weights_dict["face_weight"] = face_joints_weights
        opt_weights_dict["expr_prior_weight"] = expr_weights
        opt_weights_dict["jaw_prior_weight"] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict["hand_weight"] = hand_joints_weights
        opt_weights_dict["hand_prior_weight"] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict["coll_loss_weight"] = coll_loss_weights

    # Drop the None entries from the keys as well as from the values so each
    # stage value stays paired with its own key.
    stage_keys = [
        key for key, vals in opt_weights_dict.items() if vals is not None
    ]
    opt_weights = [
        dict(zip(stage_keys, vals))
        for vals in zip(*(opt_weights_dict[key] for key in stage_keys))
    ]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    # Stays None without VPoser, so the final body_model call then receives
    # body_pose=None instead of hitting an unbound local.
    body_pose = None

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, _, _ = img.shape

        data_weight = 1000 / H
        orientations = [body_model.global_orient.detach().cpu().numpy()]

        # # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc="Orientation")):
            opt_start = time.time()

            new_params = dict(global_orient=orient,
                              body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc="Stage")):

                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                (
                    body_optimizer,
                    body_create_graph,
                ) = optim_factory.create_optimizer(final_params, **kwargs)
                body_optimizer.zero_grad()

                # The per-stage data weight is replaced by the image-height
                # based value.
                curr_weights["data_weight"] = data_weight
                curr_weights["bending_prior_weight"] = (
                    3.17 * curr_weights["body_pose_weight"])
                # Hand joint weights are intentionally left untouched:
                # joint_weights[:, 25:67] = curr_weights['hand_weight']
                if use_face and "face_weight" in curr_weights:
                    # Skipped when no face joint weights were supplied.
                    joint_weights[:, 67:] = curr_weights["face_weight"]
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True,
                    return_full_pose=True,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer,
                )

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    tqdm.write(
                        "Stage {:03d} done after {:.4f} seconds".format(
                            opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    "Body fitting Orientation {} done after {:.4f} seconds".
                    format(or_idx, elapsed))
                tqdm.write(
                    "Body final loss val = {:.5f}".format(final_loss_val))

            # Get the result of the fitting process
            result = {}
            if camera is not None:
                result.update({
                    "camera_" + str(key): val.detach().cpu().numpy()
                    for key, val in camera.named_parameters()
                })
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result["pose_embedding"] = (
                    pose_embedding.detach().cpu().numpy())
                body_pose = vposer.decode(
                    pose_embedding, output_type="aa").reshape(
                        pose_embedding.shape[0], -1)
                result["body_pose"] = body_pose.detach().cpu().numpy()

        model_output = body_model(return_verts=True, body_pose=body_pose)
    return model_output, result
示例#10
0
def fit_single_frame(keypoints,
                     body_model,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     **kwargs):
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if data_weights is None:
        data_weights = [1, ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = (
        'Number of Body pose prior weights {}'.format(
            len(body_pose_prior_weights)) +
        ' does not match the number of data term weights {}'.format(
            len(data_weights)))
    assert (len(data_weights) ==
            len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (len(hand_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (len(hand_joints_weights) ==
                    len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) ==
            len(body_pose_prior_weights)), msg.format(
                len(shape_weights),
                len(body_pose_prior_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (len(jaw_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) ==
                len(body_pose_prior_weights)), msg.format(
                    len(body_pose_prior_weights),
                    len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) ==
                len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) ==
            len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype, device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :-1]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, -1].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {'data_weight': data_weights,
                        'body_pose_weight': body_pose_prior_weights,
                        'shape_weight': shape_weights}
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals)) for vals in
                   zip(*(opt_weights_dict[k] for k in keys
                         if opt_weights_dict[k] is not None))]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face, use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(
            batch_size=batch_size, visualize=visualize, **kwargs) as monitor:

        data_weight = 0.7

        # Reset the parameters to estimate the initial translation of the
        # body model
        body_model.reset_params(body_pose=body_mean_pose)


        orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []

        # Optimize the full model
        final_loss_val = 0
        opt_start = time.time()

        new_params = defaultdict(body_pose=body_mean_pose)
        body_model.reset_params(**new_params)
        if use_vposer:
            with torch.no_grad():
                pose_embedding.fill_(0)

        for opt_idx, curr_weights in enumerate(tqdm(opt_weights, desc='Stage')):

            body_params = list(body_model.parameters())

            final_params = list(
                filter(lambda x: x.requires_grad, body_params))

            if use_vposer:
                final_params.append(pose_embedding)

            body_optimizer, body_create_graph = optim_factory.create_optimizer(
                final_params,
                **kwargs)
            body_optimizer.zero_grad()

            curr_weights['data_weight'] = data_weight
            curr_weights['bending_prior_weight'] = (
                3.17 * curr_weights['body_pose_weight'])
            if use_hands:
                joint_weights[:, 25:67] = curr_weights['hand_weight']
            if use_face:
                joint_weights[:, 67:] = curr_weights['face_weight']
            loss.reset_loss_weights(curr_weights)

            closure = monitor.create_fitting_closure(
                body_optimizer, body_model,
                gt_joints=gt_joints,
                joints_conf=joints_conf,
                joint_weights=joint_weights,
                loss=loss, create_graph=body_create_graph,
                use_vposer=use_vposer, vposer=vposer,
                pose_embedding=pose_embedding,
                return_verts=True, return_full_pose=True)

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                stage_start = time.time()
            final_loss_val = monitor.run_fitting(
                body_optimizer,
                closure, final_params,
                body_model,
                pose_embedding=pose_embedding, vposer=vposer,
                use_vposer=use_vposer)

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - stage_start
                if interactive:
                    tqdm.write('Stage {:03d} done after {:.4f} seconds'.format(
                        opt_idx, elapsed))

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = time.time() - opt_start
            tqdm.write(
                'Body fitting Orientation done after {:.4f} seconds'.format(elapsed))
            tqdm.write('Body final loss val = {:.5f}'.format(
                final_loss_val))

        # Get the result of the fitting process
        # Store in it the errors list in order to compare multiple
        # orientations, if they exist
        result = {}
        result.update({key: val.detach().cpu().numpy()
                       for key, val in body_model.named_parameters()})
        if use_vposer:
            result['body_pose'] = pose_embedding.detach().cpu().numpy()

        results = {'loss': final_loss_val, 'result': result}

    with open(result_fn, 'wb') as result_file:
        pickle.dump(results['result'], result_file, protocol=2)

    if save_meshes or visualize:
        body_pose = vposer.decode(
            pose_embedding,
            output_type='aa').view(1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
                wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                         dtype=body_pose.dtype,
                                         device=body_pose.device)
                body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        out_mesh.apply_transform(rot)
        out_mesh.export(mesh_fn)
    def __call__(self,
                 init_pose,
                 init_betas,
                 init_cam_t,
                 j3d,
                 fit_beta=False,
                 conf_3d=1.0):
        """Fit the SMPL model to a set of 3D joints in two stages.

        Stage 1 optimizes only the global orientation and the camera
        translation. Stage 2 optimizes the body pose, the global orientation
        and the camera translation (plus the betas when ``fit_beta`` is set).
        Each stage runs either L-BFGS or Adam depending on ``self.use_lbfgs``.

        Input:
            init_pose: SMPL pose estimate; columns ``[:3]`` are taken as the
                global orientation and columns ``[3:]`` as the body pose
            init_betas: SMPL betas estimate
            init_cam_t: Camera translation estimate. NOTE: the value passed
                in is never read; the translation is re-estimated from the
                joints with ``self.guess_init_3d``
            j3d: joints 3d aka keypoints
            fit_beta: when truthy the betas are optimized in stage 2,
                otherwise they are kept fixed
            conf_3d: confidence for 3d joints, forwarded to the body loss
        Returns:
            A 6-tuple ``(vertices, joints, pose, betas, camera_translation,
            final_loss)`` where
            vertices: Vertices of optimized shape
            joints: 3D joints of optimized shape
            pose: SMPL pose parameters of optimized shape (global
                orientation concatenated with body pose on the last dim)
            betas: SMPL beta parameters of optimized shape
            camera_translation: Camera translation
            final_loss: body fitting loss evaluated once more, without
                gradients, on the optimized parameters
        """

        # Optional mesh self-intersection machinery; everything stays None
        # when the collision term is disabled.
        search_tree = None
        pen_distance = None
        filter_faces = None
        if self.use_collision:
            # Imported lazily so the dependency is only needed when the
            # collision term is enabled.
            from mesh_intersection.bvh_search_tree import BVH
            import mesh_intersection.loss as collisions_loss
            from mesh_intersection.filter_faces import FilterFaces

            search_tree = BVH(max_collisions=8)
            pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
                sigma=0.5, point2plane=False, vectorized=True,
                penalize_outside=True)

            if self.part_segm_fn:
                # Read the part segmentation.
                # NOTE(review): pickle.load can execute arbitrary code; only
                # point part_segm_fn at trusted files.
                segm_path = os.path.expandvars(self.part_segm_fn)
                with open(segm_path, 'rb') as segm_file:
                    segm_data = pickle.load(segm_file, encoding='latin1')
                # Module used to filter invalid collision pairs
                filter_faces = FilterFaces(
                    faces_segm=segm_data['segm'],
                    faces_parents=segm_data['parents'],
                    ign_part_pairs=None).to(device=self.device)

        # Work on detached copies of the initial estimates.
        global_orient = init_pose[:, :3].detach().clone()
        body_pose = init_pose[:, 3:].detach().clone()
        betas = init_betas.detach().clone()

        def run_smpl(**extra):
            # Forward pass of the body model with the current parameters.
            return self.smpl(global_orient=global_orient,
                             body_pose=body_pose,
                             betas=betas,
                             **extra)

        # Initial camera translation guessed from the un-optimized joints.
        cam_t_guess = self.guess_init_3d(run_smpl().joints, j3d).detach()
        camera_translation = cam_t_guess.clone()

        # Untouched copy of the initial body pose, handed to the body loss.
        preserve_pose = init_pose[:, 3:].detach().clone()

        # ---- Step 1: optimize camera translation and body orientation ----
        for frozen in (body_pose, betas):
            frozen.requires_grad_(False)
        for trainable in (global_orient, camera_translation):
            trainable.requires_grad_(True)

        camera_opt_params = [global_orient, camera_translation]

        if self.use_lbfgs:
            camera_optimizer = torch.optim.LBFGS(
                camera_opt_params,
                max_iter=self.num_iters,
                lr=self.step_size,
                line_search_fn='strong_wolfe')

            def camera_closure():
                camera_optimizer.zero_grad()
                # NOTE(review): this branch passes the full joint sets,
                # while the Adam branch below indexes them with
                # smpl_index / corr_index -- confirm which is intended.
                loss = self.camera_fitting_loss_3d(run_smpl().joints,
                                                   camera_translation,
                                                   cam_t_guess, j3d)
                loss.backward()
                return loss

            for _ in range(10):
                camera_optimizer.step(camera_closure)
        else:
            camera_optimizer = torch.optim.Adam(
                camera_opt_params,
                lr=self.step_size,
                betas=(0.9, 0.999))

            for _ in range(20):
                current_joints = run_smpl().joints
                loss = self.camera_fitting_loss_3d(
                    current_joints[:, self.smpl_index], camera_translation,
                    cam_t_guess, j3d[:, self.corr_index])
                camera_optimizer.zero_grad()
                loss.backward()
                camera_optimizer.step()

        # ---- Step 2: optimize body joints ----
        # The camera translation is NOT frozen here: it keeps
        # requires_grad=True and is optimized together with the pose.
        for trainable in (body_pose, global_orient, camera_translation):
            trainable.requires_grad_(True)

        # When fitting a sequence the shape is typically kept fixed
        # (fit_beta=False); otherwise betas join the optimized parameters.
        body_opt_params = [body_pose, global_orient, camera_translation]
        if fit_beta:
            betas.requires_grad_(True)
            body_opt_params.insert(1, betas)
        else:
            betas.requires_grad_(False)

        def body_loss(output, **extra):
            # Body fitting loss for one SMPL output; `extra` carries the
            # branch-specific keyword overrides.
            return self.body_fitting_loss_3d(
                body_pose,
                preserve_pose,
                betas,
                output.joints[:, self.smpl_index],
                camera_translation,
                j3d[:, self.corr_index],
                self.pose_prior,
                joints3d_conf=conf_3d,
                joint_loss_weight=600.0,
                use_collision=self.use_collision,
                model_vertices=output.vertices,
                model_faces=self.model_faces,
                search_tree=search_tree,
                pen_distance=pen_distance,
                filter_faces=filter_faces,
                **extra)

        if self.use_lbfgs:
            body_optimizer = torch.optim.LBFGS(
                body_opt_params,
                max_iter=self.num_iters,
                lr=self.step_size,
                line_search_fn='strong_wolfe')

            def body_closure():
                body_optimizer.zero_grad()
                # NOTE(review): pose_preserve_weight is only set in this
                # L-BFGS branch; the Adam branch and the final loss rely on
                # the loss function's default.
                loss = body_loss(run_smpl(), pose_preserve_weight=5.0)
                loss.backward()
                return loss

            for _ in range(self.num_iters):
                body_optimizer.step(body_closure)
        else:
            body_optimizer = torch.optim.Adam(
                body_opt_params,
                lr=self.step_size,
                betas=(0.9, 0.999))

            for _ in range(self.num_iters):
                loss = body_loss(run_smpl())
                body_optimizer.zero_grad()
                loss.backward()
                body_optimizer.step()

        # Evaluate the final loss on the optimized parameters.
        with torch.no_grad():
            final_output = run_smpl(return_full_pose=True)
            final_loss = body_loss(final_output)

        return (final_output.vertices.detach(),
                final_output.joints.detach(),
                torch.cat([global_orient, body_pose], dim=-1).detach(),
                betas.detach(),
                camera_translation.detach(),
                final_loss)
示例#12
0
def fit_single_frame(img,
                     keypoints,
                     init_trans,
                     scan,
                     scene_name,
                     body_model,
                     camera,
                     joint_weights,
                     body_pose_prior,
                     jaw_prior,
                     left_hand_prior,
                     right_hand_prior,
                     shape_prior,
                     expr_prior,
                     angle_prior,
                     result_fn='out.pkl',
                     mesh_fn='out.obj',
                     body_scene_rendering_fn='body_scene.png',
                     out_img_fn='overlay.png',
                     loss_type='smplify',
                     use_cuda=True,
                     init_joints_idxs=(9, 12, 2, 5),
                     use_face=True,
                     use_hands=True,
                     data_weights=None,
                     body_pose_prior_weights=None,
                     hand_pose_prior_weights=None,
                     jaw_pose_prior_weights=None,
                     shape_weights=None,
                     expr_weights=None,
                     hand_joints_weights=None,
                     face_joints_weights=None,
                     depth_loss_weight=1e2,
                     interpenetration=True,
                     coll_loss_weights=None,
                     df_cone_height=0.5,
                     penalize_outside=True,
                     max_collisions=8,
                     point2plane=False,
                     part_segm_fn='',
                     focal_length_x=5000.,
                     focal_length_y=5000.,
                     side_view_thsh=25.,
                     rho=100,
                     vposer_latent_dim=32,
                     vposer_ckpt='',
                     use_joints_conf=False,
                     interactive=True,
                     visualize=False,
                     save_meshes=True,
                     degrees=None,
                     batch_size=1,
                     dtype=torch.float32,
                     ign_part_pairs=None,
                     left_shoulder_idx=2,
                     right_shoulder_idx=5,
                     ####################
                     ### PROX
                     render_results=True,
                     camera_mode='moving',
                     ## Depth
                     s2m=False,
                     s2m_weights=None,
                     m2s=False,
                     m2s_weights=None,
                     rho_s2m=1,
                     rho_m2s=1,
                     init_mode=None,
                     trans_opt_stages=None,
                     viz_mode='mv',
                     #penetration
                     sdf_penetration=False,
                     sdf_penetration_weights=0.0,
                     sdf_dir=None,
                     cam2world_dir=None,
                     #contact
                     contact=False,
                     rho_contact=1.0,
                     contact_loss_weights=None,
                     contact_angle=15,
                     contact_body_parts=None,
                     body_segments_dir=None,
                     load_scene=False,
                     scene_dir=None,
                     height=None,
                     weight=None,
                     gender='male',
                     weight_w=0,
                     height_w=0,
                     **kwargs):

    if kwargs['optim_type'] == 'lbfgsls':
        assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    batch_size = keypoints.shape[0]

    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # if visualize:
    #     pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [1, ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = (
        'Number of Body pose prior weights {}'.format(
            len(body_pose_prior_weights)) +
        ' does not match the number of data term weights {}'.format(
            len(data_weights)))
    assert (len(data_weights) ==
            len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (len(hand_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (len(hand_joints_weights) ==
                    len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) ==
            len(body_pose_prior_weights)), msg.format(
                len(shape_weights),
                len(body_pose_prior_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (len(jaw_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) ==
                len(body_pose_prior_weights)), msg.format(
                    len(body_pose_prior_weights),
                    len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) ==
                len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) ==
            len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    if use_vposer:
        # pose_embedding = torch.zeros([batch_size, 32],
        #                              dtype=dtype, device=device,
        #                              requires_grad=True)

        # Patrick: hack to set default body pose to something more sleep-y
        mean_body = np.array([[ 0.19463745,  1.6240447,   0.6890624,   0.19186097,  0.08003145, -0.04189298,
                       3.450903,   -0.29570094,  0.25072002, -1.1879578,   0.33350763,  0.23568614,
                       0.38122794, -2.1258948,   0.2910664,   2.2407222,  -0.5400814,  -0.95984083,
                      -1.2880017,   1.1122228,   0.7411389,  -0.2265636,  -4.8202057,  -1.950323,
                      -0.28771818, -1.9282387,   0.9928907,  -0.27183488, -0.55805033,  0.04047768,
                      -0.537362,    0.65770334]])

        pose_embedding = torch.tensor(mean_body, dtype=dtype, device=device, requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype)
    else:
        # body_mean_pose = body_pose_prior.get_mean().detach().cpu()
        # body_mean_pose = torch.zeros([batch_size, 69], dtype=dtype)

        # mean_body =  np.array([[-2.33263850e-01,  1.35460928e-01,  2.94471830e-01, -3.22930813e-01,
        #                         -4.73931670e-01, -2.67531037e-01,  7.12558180e-02,  7.89440796e-03,
        #                         8.67700949e-03,  1.05982251e-01,  2.79584467e-01, -7.04243258e-02,
        #                         3.61106455e-01, -5.87305248e-01,  1.10897996e-01, -1.68918714e-01,
        #                         -4.60174456e-02,  3.28684039e-02,  5.80525696e-01, -5.11317095e-03,
        #                         -1.57546505e-01,  5.85777402e-01, -8.94948393e-02,  2.24680841e-01,
        #                         1.55473784e-01,  5.38146123e-04,  4.30279821e-02, -4.68525589e-02,
        #                         7.75185153e-02,  7.82282930e-03,  6.74356073e-02,  4.09710407e-02,
        #                         -3.60425897e-02, -4.71813440e-01,  5.02379127e-02,  2.02309843e-02,
        #                         5.29680364e-02,  1.68510173e-02,  2.25090146e-01, -4.52307612e-02,
        #                         7.72185996e-02, -2.17333943e-01,  3.30020368e-01,  4.21866514e-02,
        #                         7.15153441e-02,  3.05950731e-01, -3.63454908e-01, -1.28235269e+00,
        #                         5.09610713e-01,  4.65482563e-01,  1.20263052e+00,  5.56594551e-01,
        #                         -2.24000740e+00,  3.83565158e-01,  5.31355202e-01,  2.21637583e+00,
        #                         -5.63146770e-01, -3.01193684e-01, -4.31942672e-01,  6.85038209e-01,
        #                         3.61178756e-01,  2.76136428e-01, -2.64388829e-01,  0.00000000e+00,
        #                         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        #                         0.00000000e+00]])
        mean_body = np.array(joint_limits.axang_limits_patrick / 180 * np.pi).mean(1)
        body_mean_pose = torch.tensor(mean_body, dtype=dtype).unsqueeze(0)


    betanet = None
    if height is not None:
        betanet = torch.load('models/betanet_old_pytorch.pt')
        betanet = betanet.to(device=device)
        betanet.eval()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2]

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = scan.to(device=device)

    # load pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    # if sdf_penetration:
    #     with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
    #         sdf_data = json.load(f)
    #         grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device)
    #         grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device)
    #         grid_dim = sdf_data['dim']
    #     voxel_size = (grid_max - grid_min) / grid_dim
    #     sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(grid_dim, grid_dim, grid_dim)
    #     sdf = torch.tensor(sdf, dtype=dtype, device=device)
    #     if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
    #         sdf_normals = np.load(osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(grid_dim, grid_dim, grid_dim, 3)
    #         sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
    #     else:
    #         print("Normals not found...")


    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3), dtype=dtype, device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3), dtype=dtype, device=device)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # load vertex ids of contact parts
    contact_verts_ids  = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'), 'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(return_verts=True, body_pose= torch.zeros((batch_size, 63), dtype=dtype, device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(-1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)

        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply'))

            scene.vn = scene.estimate_vertex_normals()

            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype,
                                   device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype,
                                    device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long,
                                   device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {'data_weight': data_weights,
                        'body_pose_weight': body_pose_prior_weights,
                        'shape_weight': shape_weights}
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals)) for vals in
                   zip(*(opt_weights_dict[k] for k in keys
                         if opt_weights_dict[k] is not None))]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # load indices of the head of smpl-x model
    with open( osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # Which initialization mode to choose: similar triangles, mean of the scan, or the average of both
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(body_model, gt_joints, edge_indices,
                                    use_vposer=use_vposer, vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get('model_type', 'smpl'),
                                    focal_length=focal_length_x, dtype=dtype) ) /2.0

    else:
        init_t = fitting.guess_init(body_model, gt_joints, edge_indices,
                                    use_vposer=use_vposer, vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get('model_type', 'smpl'),
                                    focal_length=focal_length_x, dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face, use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               betanet=betanet,
                               height=height,
                               weight=weight,
                               gender=gender,
                               weight_w=weight_w,
                               height_w=height_w,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, viz_mode=viz_mode, **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        _, H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True

            camera_opt_params = [camera.translation, body_model.global_orient]

        elif camera_mode == 'fixed':
            # body_model.reset_params()
            # body_model.transl[:] = torch.tensor(init_t)
            # body_model.body_pose[:] = torch.tensor(body_mean_pose)
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl, body_model.global_orient]

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.norm(gt_joints[:, left_shoulder_idx, :] - gt_joints[:, right_shoulder_idx, :], dim=1)
        try_both_orient = shoulder_dist.min() < side_view_thsh

        kwargs['lr'] *= 10
        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(camera_opt_params, **kwargs)
        kwargs['lr'] /= 10

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer, body_model, camera, gt_joints,
            camera_loss, create_graph=camera_create_graph,
            use_vposer=use_vposer, vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False, return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params, body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close, then rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            with torch.no_grad():
                flipped_orient = torch.zeros_like(body_model.global_orient)
                for i in range(batch_size):
                    body_orient = body_model.global_orient[i, :].detach().cpu().numpy()
                    local_flip = cv2.Rodrigues(body_orient)[0].dot(cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
                    local_flip = cv2.Rodrigues(local_flip)[0].ravel()

                    flipped_orient[i, :] = torch.Tensor(local_flip).to(device)

            orientations = [body_model.global_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0

        # for or_idx, orient in enumerate(orientations):
        or_idx = 0
        while or_idx < len(orientations):
            global_vars.cur_orientation = or_idx
            orient = orientations[or_idx]
            print('Trying orientation', or_idx, 'of', len(orientations))
            opt_start = time.time()
            or_idx += 1

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)
                    pose_embedding += torch.tensor(mean_body, dtype=dtype, device=device)

            for opt_idx, curr_weights in enumerate(opt_weights):
                global_vars.cur_opt_stage = opt_idx

                if opt_idx not in trans_opt_stages:
                    body_model.transl.requires_grad = False
                else:
                    body_model.transl.requires_grad = True
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params,
                    **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer, body_model,
                    camera=camera, gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss, create_graph=body_create_graph,
                    use_vposer=use_vposer, vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v, scene_vn=scene_vn, scene_f=scene_f,ftov=ftov,
                    return_verts=True, return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure, final_params,
                    body_model,
                    pose_embedding=pose_embedding, vposer=vposer,
                    use_vposer=use_vposer)

                # print('Final loss val', final_loss_val)
                # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
                #     break

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write('Stage {:03d} done after {:.4f} seconds'.format(
                            opt_idx, elapsed))

            # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']):
            #     print('Optimization FAILURE, retrying')
            #     orientations.append(orientations[or_idx-1] * 0.9)
            #     continue

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write('Body fitting Orientation {} done after {:.4f} seconds'.format(or_idx, elapsed))
                tqdm.write('Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {'camera_' + str(key): val.detach().cpu().numpy()
                      for key, val in camera.named_parameters()}

            result['camera_focal_length_x'] = camera.focal_length_x.detach().cpu().numpy()
            result['camera_focal_length_y'] = camera.focal_length_y.detach().cpu().numpy()
            result['camera_center'] = camera.center.detach().cpu().numpy()

            result.update({key: val.detach().cpu().numpy()
                           for key, val in body_model.named_parameters()})
            if use_vposer:
                result['pose_embedding'] = pose_embedding.detach().cpu().numpy()
                body_pose = vposer.decode(pose_embedding, output_type='aa').view(1, -1) if use_vposer else None

                if "smplx.body_models.SMPL'" in str(type(body_model)):
                    wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device)
                    body_pose = torch.cat([body_pose, wrist_pose], dim=1)

                result['body_pose'] = body_pose.detach().cpu().numpy()
            result['final_loss_val'] = final_loss_val
            result['loss_dict'] = global_vars.cur_loss_dict
            result['betanet_weight'] = global_vars.cur_weight
            result['betanet_height'] = global_vars.cur_height
            result['gt_joints'] = gt_joints.detach().cpu().numpy()
            result['max_joint'] = global_vars.cur_max_joint

            results.append(result)

        for idx, res_folder in enumerate(result_fn):    # Iterate over batch
            pkl_data = {}
            min_loss = np.inf
            all_results = []
            for result in results:  # Iterate over orientations
                sel_res = misc_utils.get_data_from_batched_dict(result, idx, len(result_fn))
                all_results.append(sel_res)

                cost = sel_res['loss_dict']['total'] + sel_res['loss_dict']['pprior'] * 60
                if cost < min_loss:
                    min_loss = cost
                    pkl_data.update(sel_res)

            pkl_data['all_results'] = all_results

            with open(res_folder, 'wb') as result_file:
                pickle.dump(pkl_data, result_file, protocol=2)

            img_s = img[idx, :].detach().cpu().numpy()
            img_s = pil_img.fromarray((img_s * 255).astype(np.uint8))
            img_s.save(out_img_fn[idx])