Esempio n. 1
0
def optimize_pose_shape(th_scan_meshes,
                        smpl,
                        iterations,
                        steps_per_iter,
                        th_pose_3d=None,
                        display=None):
    """
    Optimize SMPL translation, shape (betas) and pose to fit the given scans.

    Runs ``iterations`` outer rounds of ``steps_per_iter`` Adam steps each.
    The outer round index ``it`` is passed to ``backward_step`` and to each
    loss-weight callable, so the loss weighting may vary between rounds.

    :param th_scan_meshes: sequence of target scan meshes. For display, each
        element must expose ``vertices`` and ``faces`` tensors.
    :param smpl: SMPL model whose ``trans``, ``betas`` and ``pose`` parameters
        are optimized in place.
    :param iterations: number of outer optimization rounds.
    :param steps_per_iter: number of optimizer steps per round.
    :param th_pose_3d: optional value passed through unchanged to
        ``forward_step`` (presumably 3D pose targets -- TODO confirm).
    :param display: if not None, pass index of the scan in th_scan_meshes to visualize.
    """
    # Optimizer over the translation, shape and pose parameters.
    optimizer = torch.optim.Adam([smpl.trans, smpl.betas, smpl.pose],
                                 0.02,
                                 betas=(0.9, 0.999))

    # Get loss_weights: loss name -> callable(loss, round_index).
    weight_dict = get_loss_weights()

    # Display
    if display is not None:
        # Normalize once so the same integer is used for validation and indexing.
        display = int(display)
        assert display < len(th_scan_meshes)
        mv = MeshViewer()

    for it in range(iterations):
        loop = tqdm(range(steps_per_iter))
        loop.set_description('Optimizing SMPL')
        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step(th_scan_meshes, smpl, th_pose_3d)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            # Logging only: build the complete progress string, then update
            # the bar once per step instead of once per loss term.
            with torch.no_grad():
                l_str = 'Iter: {}'.format(i) + ''.join(
                    ', {}: {:0.4f}'.format(
                        k, weight_dict[k](loss_dict[k], it).mean().item())
                    for k in loss_dict)
            loop.set_description(l_str)

            if display is not None:
                # Visualization-only forward pass: no autograd graph needed.
                with torch.no_grad():
                    verts, _, _, _ = smpl()
                smpl_mesh = Mesh(v=verts[display].cpu().detach().numpy(),
                                 f=smpl.faces.cpu().numpy())
                scan_mesh = Mesh(
                    v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
                    f=th_scan_meshes[display].faces.cpu().numpy(),
                    vc=np.array([0, 1, 0]))
                mv.set_static_meshes([scan_mesh, smpl_mesh])

    print('** Optimised smpl pose and shape **')
Esempio n. 2
0
def optimize_pose_shape(th_scan_meshes,
                        smplx,
                        iterations,
                        steps_per_iter,
                        scan_part_labels,
                        smplx_part_labels,
                        search_tree=None,
                        pen_distance=None,
                        tri_filtering_module=None,
                        display=None):
    """
    Optimize SMPLX translation, shape, global/body pose and hand poses.

    Runs ``iterations`` outer rounds of ``steps_per_iter`` Adam steps each.
    The outer round index ``it`` is passed to ``backward_step`` and to each
    loss-weight callable, so the loss weighting may vary between rounds.

    :param th_scan_meshes: sequence of target scan meshes. For display, each
        element must expose ``vertices`` and ``faces`` tensors.
    :param smplx: SMPLX model whose ``trans``, ``betas``, ``global_pose``,
        ``body_pose``, ``left_hand_pose`` and ``right_hand_pose`` parameters
        are optimized in place.
    :param iterations: number of outer optimization rounds.
    :param steps_per_iter: number of optimizer steps per round.
    :param scan_part_labels: per-scan part labels, passed to ``forward_step``.
        When displaying, they are also used to color the scan.
    :param smplx_part_labels: model part labels, passed to ``forward_step``.
    :param search_tree: optional, passed through to ``forward_step``.
    :param pen_distance: optional, passed through to ``forward_step``.
    :param tri_filtering_module: optional, passed through to ``forward_step``.
    :param display: if not None, pass index of the scan in th_scan_meshes to visualize.
    """
    # Optimizer. Expression and jaw pose are deliberately not included, so
    # they keep whatever values the model already holds.
    optimizer = torch.optim.Adam([
        smplx.trans, smplx.betas, smplx.global_pose, smplx.body_pose,
        smplx.left_hand_pose, smplx.right_hand_pose
    ],
                                 0.02,
                                 betas=(0.9, 0.999))

    # Get loss_weights: loss name -> callable(loss, round_index).
    weight_dict = get_loss_weights()

    # Display
    if display is not None:
        # Normalize once so the same integer is used for validation and indexing.
        display = int(display)
        assert display < len(th_scan_meshes)
        mv = MeshViewer()

    for it in range(iterations):
        loop = tqdm(range(steps_per_iter))
        loop.set_description('Optimizing SMPLX')
        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step(th_scan_meshes, smplx, scan_part_labels,
                                     smplx_part_labels, search_tree,
                                     pen_distance, tri_filtering_module)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            # Logging only: build the complete progress string, then update
            # the bar once per step instead of once per loss term.
            with torch.no_grad():
                l_str = 'Iter: {}'.format(i) + ''.join(
                    ', {}: {:0.4f}'.format(
                        k, weight_dict[k](loss_dict[k], it).mean().item())
                    for k in loss_dict)
            loop.set_description(l_str)

            if display is not None:
                # Visualization-only forward pass: no autograd graph needed.
                # This variant's model call returns the vertices directly.
                with torch.no_grad():
                    verts = smplx()
                smplx_mesh = Mesh(v=verts[display].cpu().detach().numpy(),
                                  f=smplx.faces.cpu().numpy())
                scan_mesh = Mesh(
                    v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
                    f=th_scan_meshes[display].faces.cpu().numpy(),
                    vc=np.array([0, 1, 0]))
                # Re-color the scan from its part labels (presumably one
                # color per part -- TODO confirm Mesh API semantics).
                scan_mesh.set_vertex_colors_from_weights(
                    scan_part_labels[display].cpu().detach().numpy())
                mv.set_static_meshes([scan_mesh, smplx_mesh])

    print('** Optimised smplx pose and shape **')