Beispiel #1
0
def optimize_offsets(th_scan_meshes, smpl, init_smpl_meshes, iterations,
                     steps_per_iter):
    """
    Optimize SMPL+D: the model's offsets together with pose, trans and betas.

    Runs `iterations` outer rounds of `steps_per_iter` Adam steps each. The
    outer round index is forwarded to backward_step, and the progress bar
    shows the mean of every raw loss term multiplied by 100.

    :param th_scan_meshes: forwarded unchanged to forward_step.
    :param smpl: model whose offsets, pose, trans and betas are handed to the optimizer.
    :param init_smpl_meshes: forwarded unchanged to forward_step.
    :param iterations: number of outer rounds.
    :param steps_per_iter: number of optimizer steps per outer round.
    """
    free_params = [smpl.offsets, smpl.pose, smpl.trans, smpl.betas]
    adam = torch.optim.Adam(free_params, lr=0.005, betas=(0.9, 0.999))

    loss_weights = get_loss_weights()

    for outer in range(iterations):
        progress = tqdm(range(steps_per_iter))
        progress.set_description('Optimizing SMPL+D')
        for step in progress:
            adam.zero_grad()
            # Forward pass yields the individual loss terms ...
            losses = forward_step(th_scan_meshes, smpl, init_smpl_meshes)
            # ... which backward_step combines into the scalar that is differentiated.
            total = backward_step(losses, loss_weights, outer)
            total.backward()
            adam.step()

            # Progress text: step index followed by each loss term scaled by 100.
            pieces = ['Lx100. Iter: {}'.format(step)]
            pieces.extend(
                ', {}: {:0.4f}'.format(name, term.mean().item() * 100)
                for name, term in losses.items())
            progress.set_description(''.join(pieces))
Beispiel #2
0
def optimize_pose_shape(th_scan_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                        display=None):
    """
    Optimize SMPL translation, betas and pose with Adam.

    Runs `iterations` outer rounds of `steps_per_iter` optimizer steps each. The outer
    round index is forwarded to backward_step and to the per-loss weight functions.

    :param th_scan_meshes: scans, forwarded to forward_step; indexed by `display` for visualization.
    :param smpl: model whose trans, betas and pose are handed to the optimizer.
    :param iterations: number of outer rounds.
    :param steps_per_iter: number of optimizer steps per outer round.
    :param scan_part_labels: forwarded to forward_step; also used to color the displayed scan.
    :param smpl_part_labels: forwarded to forward_step.
    :param display: if not None, pass index of the scan in th_scan_meshes to visualize.
    """
    # Optimizer
    optimizer = torch.optim.Adam([smpl.trans, smpl.betas, smpl.pose], 0.02, betas=(0.9, 0.999))

    # Get loss_weights
    weight_dict = get_loss_weights()

    # Display
    mv = None
    if display is not None:
        assert int(display) < len(th_scan_meshes), 'display must index a scan in th_scan_meshes'
        mv = MeshViewer()

    for it in range(iterations):
        loop = tqdm(range(steps_per_iter))
        loop.set_description('Optimizing SMPL')
        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step(th_scan_meshes, smpl, scan_part_labels, smpl_part_labels)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            l_str = 'Iter: {}'.format(i)
            for k in loss_dict:
                l_str += ', {}: {:0.4f}'.format(k, weight_dict[k](loss_dict[k], it).mean().item())
            # Refresh the bar once per step with the complete string rather than
            # once per loss term with a partial one.
            loop.set_description(l_str)

            if display is not None:
                # The forward pass here is only for visualization; skip autograd bookkeeping.
                with torch.no_grad():
                    verts, _, _, _ = smpl()
                smpl_mesh = Mesh(v=verts[display].cpu().detach().numpy(), f=smpl.faces.cpu().numpy())
                scan_mesh = Mesh(v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
                                 f=th_scan_meshes[display].faces.cpu().numpy(), vc=np.array([0, 1, 0]))
                scan_mesh.set_vertex_colors_from_weights(scan_part_labels[display].cpu().detach().numpy())
                mv.set_static_meshes([scan_mesh, smpl_mesh])

    print('** Optimised smpl pose and shape **')
Beispiel #3
0
def optimize_pose_only(th_scan_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                       display=None):
    """
    Initially we want to only optimize the global rotation of SMPL. Next we optimize full pose.

    The parameters of `smpl` are copied into a split-parameter model. The first outer round
    optimizes trans, top_betas and global_pose only; the following `iterations` rounds
    additionally optimize other_pose. Afterwards pose, betas and trans of the split model are
    written back into `smpl`.

    :param th_scan_meshes: scans, forwarded to forward_step; indexed by `display` for visualization.
    :param smpl: model providing the initial pose/betas/faces/gender; receives the optimized
        pose, betas and trans at the end.
    :param iterations: number of full-pose rounds run after the single global-orientation round.
    :param steps_per_iter: number of optimizer steps per round.
    :param scan_part_labels: forwarded to forward_step; also used to color the displayed scan.
    :param smpl_part_labels: forwarded to forward_step.
    :param display: if not None, index of the scan in th_scan_meshes to visualize.
    """

    batch_sz = smpl.pose.shape[0]
    split_smpl = th_batch_SMPL_split_params(batch_sz, top_betas=smpl.betas.data[:, :2],
                                            other_betas=smpl.betas.data[:, 2:],
                                            global_pose=smpl.pose.data[:, :3], other_pose=smpl.pose.data[:, 3:],
                                            faces=smpl.faces, gender=smpl.gender).to(DEVICE)
    optimizer = torch.optim.Adam([split_smpl.trans, split_smpl.top_betas, split_smpl.global_pose], 0.02,
                                 betas=(0.9, 0.999))

    # Get loss_weights
    weight_dict = get_loss_weights()

    mv = None
    if display is not None:
        assert int(display) < len(th_scan_meshes), 'display must index a scan in th_scan_meshes'
        mv = MeshViewer(keepalive=True)

    iter_for_global = 1
    for it in range(iter_for_global + iterations):
        loop = tqdm(range(steps_per_iter))
        if it < iter_for_global:
            # Optimize global orientation
            phase = 'Optimizing SMPL global orientation'
            print(phase)
        else:
            phase = 'Optimizing SMPL pose only'
            if it == iter_for_global:
                # Now optimize full SMPL pose: announce the switch once and rebuild the
                # optimizer so that other_pose is included.
                print(phase)
                optimizer = torch.optim.Adam([split_smpl.trans, split_smpl.top_betas, split_smpl.global_pose,
                                              split_smpl.other_pose], 0.02, betas=(0.9, 0.999))
        loop.set_description(phase)

        for i in loop:
            optimizer.zero_grad()
            # Get losses for a forward pass
            loss_dict = forward_step(th_scan_meshes, split_smpl, scan_part_labels, smpl_part_labels)
            # Get total loss for backward pass
            tot_loss = backward_step(loss_dict, weight_dict, it)
            tot_loss.backward()
            optimizer.step()

            l_str = 'Iter: {}'.format(i)
            for k in loss_dict:
                l_str += ', {}: {:0.4f}'.format(k, weight_dict[k](loss_dict[k], it).mean().item())
            # Refresh the bar once per step with the complete string rather than
            # once per loss term with a partial one.
            loop.set_description(l_str)

            if display is not None:
                # The forward pass here is only for visualization; skip autograd bookkeeping.
                with torch.no_grad():
                    verts, _, _, _ = split_smpl()
                smpl_mesh = Mesh(v=verts[display].cpu().detach().numpy(), f=smpl.faces.cpu().numpy())
                scan_mesh = Mesh(v=th_scan_meshes[display].vertices.cpu().detach().numpy(),
                                 f=th_scan_meshes[display].faces.cpu().numpy(), vc=np.array([0, 1, 0]))
                scan_mesh.set_vertex_colors_from_weights(scan_part_labels[display].cpu().detach().numpy())

                mv.set_dynamic_meshes([smpl_mesh, scan_mesh])

    # Put back pose, shape and trans into original smpl
    smpl.pose.data = split_smpl.pose.data
    smpl.betas.data = split_smpl.betas.data
    smpl.trans.data = split_smpl.trans.data

    print('** Optimised smpl pose **')