def _read_identity_morph_targets(self):
        """Reads and returns the identities in the face model.

        Returns:
            A tuple whose first element is a list of identity names and whose
            second element is a list of openmesh.PolyMesh identity meshes.
        """
        id_names = []
        id_meshes = []

        identityNum = 0
        while True:
            id_name = 'identity{:03d}'.format(identityNum)
            file_name = id_name + ".obj"
            file_path = os.path.join(self._model_path, file_name)
            mesh = None
            try:
                print("Reading identity morph target: " + id_name)
                mesh = om.read_polymesh(file_path)
                id_names.append(id_name)
                id_meshes.append(mesh)
                identityNum = identityNum + 1
            except Exception as e:
                print("Unable to read identity morph target. Continuing...")
                break
            else:
                continue
            finally:
                pass

        return id_names, id_meshes
    def _read_generic_neutral_mesh(self):
        """Load the generic neutral mesh of the face model.

        The mesh is read from ``generic_neutral_mesh.obj`` in
        ``self._model_path`` with per-halfedge texture coordinates requested.

        Returns:
            The generic neutral mesh as an openmesh.PolyMesh.
        """
        mesh_path = os.path.join(self._model_path, 'generic_neutral_mesh.obj')
        return om.read_polymesh(mesh_path, halfedge_tex_coord=True)
    def _read_expression_morph_targets(self, expression_names):
        """Reads and returns the expressions in the face model.

        Returns:
            A tuple whose first element is a list of expression names and whose
            second element is a list of openmesh.PolyMesh expression meshes.
        """
        ex_names = []
        ex_meshes = []
        for ex_name in expression_names:
            print("Reading expression morph target: " + ex_name)
            file_name = ex_name + '.obj'
            file_path = os.path.join(self._model_path, file_name)
            mesh = om.read_polymesh(file_path)
            ex_names.append(ex_name)
            ex_meshes.append(mesh)
        return ex_names, ex_meshes
 def test_read_nonexistent_om(self):
     # Both the triangle-mesh and the polygon-mesh reader must reject a
     # path that does not exist by raising RuntimeError.
     missing_path = "TestFiles/nonexistent.om"
     for reader in (openmesh.read_trimesh, openmesh.read_polymesh):
         with self.assertRaises(RuntimeError):
             self.mesh = reader(missing_path)
def test(net=None, save_subdir="test"):
    """Evaluate a cage-deformation network on the test split and save outputs.

    Builds the test dataset from the module-level ``opt`` (after setting
    ``opt.phase = "test"``), runs ``net`` on every (source, target) pair under
    ``torch.no_grad()`` and writes into ``<opt.log_dir>/<save_subdir>``:

    * ``eval.txt`` with one progress line per pair,
    * ``<s>-<t>-Sa.*``: the source shape / mesh,
    * ``<s>-<t>-Sab-<b>.*``: the deformed source for each batch entry ``b``,
    * ``<s>-<t>-Sb.*``: the target shape / mesh,
    * ``<s>-<t>-cage1-<b>.ply`` / ``<s>-<t>-cage2-<b>.ply``: the input and
      deformed cages.

    When ``opt.blend_style`` is set, every pair is evaluated at 4 evenly
    spaced blend weights in [0, 1], giving 4 batch entries per pair.
    Finally ``dataset.render_result`` is called on the output directory.

    Args:
        net: network to evaluate. When ``None`` a ``networks.NetworkFull`` is
            built around the initial cage and loaded from ``opt.ckpt``.
        save_subdir: sub-directory of ``opt.log_dir`` receiving the outputs.
    """
    opt.phase = "test"
    dataset = build_dataset(opt)

    # initial cage: loaded from the template (3D) or built by generatePolygon (2D)
    if opt.dim == 3:
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1, -1).cuda()
        ]

    if net is None:
        net = networks.NetworkFull(
            opt,
            dim=opt.dim,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_V_t,
            template_faces=init_cage_Fs[-1],
        ).cuda()
        net.eval()
        load_network(net, opt.ckpt)
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=tolerating_collate,
        num_workers=0,
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    test_output_dir = os.path.join(opt.log_dir, save_subdir)
    os.makedirs(test_output_dir, exist_ok=True)

    def out_path(file_name):
        # every artifact of this run lives in the same output directory
        return os.path.join(test_output_dir, file_name)

    with open(out_path("eval.txt"), "w") as f:
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = dataset.uncollate(data)

                ############# blending ############
                if opt.blend_style:
                    # evaluate the pair at 4 evenly spaced blend weights
                    num_alpha = 4
                    blend_alpha = torch.linspace(
                        0, 1, steps=num_alpha,
                        dtype=torch.float32).cuda().reshape(num_alpha, 1)
                    data["source_shape"] = data["source_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                    data["target_shape"] = data["target_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                source_shape_t = data["source_shape"].transpose(
                    1, 2).contiguous().detach()
                target_shape_t = data["target_shape"].transpose(
                    1, 2).contiguous().detach()

                outputs = net(source_shape_t, target_shape_t, blend_alpha)
                deformed = outputs["deformed"]

                ####################### evaluation ########################
                s_filename = os.path.splitext(data["source_file"][0])[0]
                t_filename = os.path.splitext(data["target_file"][0])[0]
                pair = "{}-{}".format(s_filename, t_filename)

                log_str = "{}/{} {}-{} ".format(i, len(dataloader), s_filename,
                                                t_filename)
                print(log_str)
                f.write(log_str + "\n")

                # .get(): these entries may be missing or None
                source_mesh_entry = data.get("source_mesh")
                source_face = data.get("source_face")
                target_mesh_entry = data.get("target_mesh")
                target_face = data.get("target_face")

                # A source mesh given as a file path is read and normalized
                # once per pair; it is shared by every batch entry.
                hires_source = None
                if source_mesh_entry is not None and isinstance(
                        source_mesh_entry[0], str):
                    hires_source = om.read_polymesh(
                        source_mesh_entry[0]).points().copy()
                    hires_source = dataset.normalize(hires_source, opt.isV2)
                    hires_source = torch.from_numpy(
                        hires_source.astype(np.float32)).unsqueeze(0).cuda()

                ###################### target output #######################
                # The target does not depend on the batch entry, so it is
                # written once per pair (centered copy, data is not mutated).
                if target_face is not None and target_mesh_entry is not None:
                    target_verts = center_bounding_box(target_mesh_entry[0])[0]
                    tosave = pymesh.form_mesh(
                        vertices=target_verts.detach().cpu(),
                        faces=target_face[0].detach().cpu())
                    pymesh.save_mesh(out_path(pair + "-Sb.obj"),
                                     tosave,
                                     use_float=True)
                elif (target_face is None and target_mesh_entry is not None
                      and isinstance(target_mesh_entry[0], str)):
                    mesh = om.read_polymesh(target_mesh_entry[0])
                    points_arr = mesh.points()
                    points_arr[:] = dataset.normalize(points_arr.copy(),
                                                      opt.isV2)
                    om.write_mesh(out_path(pair + "-Sb.obj"), mesh)
                else:
                    save_pts(out_path(pair + "-Sb.pts"),
                             data["target_shape"][0].detach().cpu())

                ###################### per-entry outputs ###################
                for b in range(deformed.shape[0]):
                    # deformed_b is a batch of one holding entry b's result;
                    # `deformed` itself is never rebound, so every b indexes
                    # the network output correctly.
                    if source_mesh_entry is None:
                        deformed_b = deformed[b:b + 1]
                    else:
                        mvc_points = (hires_source if hires_source is not None
                                      else source_mesh_entry)
                        deformed_b = deform_with_MVC(
                            outputs["cage"][b:b + 1],
                            outputs["new_cage"][b:b + 1],
                            outputs["cage_face"], mvc_points)
                    deformed_b[0] = center_bounding_box(deformed_b[0])[0]

                    if source_face is not None and source_mesh_entry is not None:
                        source_verts = source_mesh_entry[0].detach().cpu()
                        source_verts = center_bounding_box(source_verts)[0]
                        source_face_0 = source_face[0].detach().cpu()
                        tosave = pymesh.form_mesh(vertices=source_verts,
                                                  faces=source_face_0)
                        pymesh.save_mesh(out_path(pair + "-Sa.obj"),
                                         tosave,
                                         use_float=True)
                        tosave = pymesh.form_mesh(
                            vertices=deformed_b[0].detach().cpu(),
                            faces=source_face_0)
                        pymesh.save_mesh(
                            out_path("{}-Sab-{}.obj".format(pair, b)),
                            tosave,
                            use_float=True,
                        )
                    elif source_face is None and hires_source is not None:
                        # reuse the original file's topology, swap in points
                        mesh = om.read_polymesh(source_mesh_entry[0])
                        points_arr = mesh.points()
                        points_arr[:] = hires_source[0].detach().cpu().numpy()
                        om.write_mesh(out_path(pair + "-Sa.obj"), mesh)
                        points_arr[:] = deformed_b[0].detach().cpu().numpy()
                        om.write_mesh(
                            out_path("{}-Sab-{}.obj".format(pair, b)), mesh)
                    else:
                        # save to "pts" for rendering
                        save_pts(out_path(pair + "-Sa.pts"),
                                 data["source_shape"][b].detach().cpu())
                        save_pts(out_path("{}-Sab-{}.pts".format(pair, b)),
                                 deformed_b[0].detach().cpu())

                    # cages are centered only after entry b has been deformed
                    outputs["cage"][b] = center_bounding_box(
                        outputs["cage"][b])[0]
                    outputs["new_cage"][b] = center_bounding_box(
                        outputs["new_cage"][b])[0]
                    pymesh.save_mesh_raw(
                        out_path("{}-cage1-{}.ply".format(pair, b)),
                        outputs["cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)
                    pymesh.save_mesh_raw(
                        out_path("{}-cage2-{}.ply".format(pair, b)),
                        outputs["new_cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)

    dataset.render_result(test_output_dir)
import sys
import os
import torch
import numpy as np
import openmesh as om
import pandas as pd
from pytorch_points.network.operations import faiss_knn
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR + "/..")
from common import read_trimesh

if len(sys.argv) < 3:
    sys.exit("usage: {} SOURCE_MODEL TARGET_MODEL".format(sys.argv[0]))

source_model = sys.argv[1]
target_model = sys.argv[2]
# picked-landmark files share each model's base name with a ".picked" extension
orig_label_path = os.path.splitext(source_model)[0] + ".picked"
new_lable = os.path.splitext(target_model)[0] + ".picked"  # NOTE: name kept as-is for later references

target_mesh = om.read_polymesh(target_model)
target_shape_arr = target_mesh.points()
target_shape = target_shape_arr.copy()
# keep xyz only, as a (1, N, 3) float32 tensor
target_shape = torch.from_numpy(target_shape[:, :3].astype(np.float32))
target_shape.unsqueeze_(0)
# first line of a .picked file is a header; columns are space separated
orig_label = pd.read_csv(orig_label_path,
                         delimiter=" ",
                         skiprows=1,
                         header=None)
orig_label_name = orig_label.iloc[:, 5]
source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(
    np.float32))
source_points = source_points.unsqueeze(0)
# find the closest point on the original meshes
source_mesh = om.read_polymesh(source_model)
# Example #7
# 0
def optimize(opt):
    """Optimize the template cage so a new target reproduces the source's MVC weights.

    Correspondence points are gathered on the source (template) and on the
    target mesh ``opt.model``. The cage vertices loaded from ``opt.ckpt`` are
    then optimized with Adam so that the mean value coordinates of the target
    points w.r.t. the optimized cage match the coordinates of the source points
    w.r.t. the initial cage.

    Correspondences come from, in order of preference:

    1. ``.picked`` landmark files next to both ``opt.model`` and
       ``opt.source_model``, matched by landmark name;
    2. a ``.picked`` file for the source only ("Faust" case); its vertex
       indices are reused on the target;
    3. when ``opt.corres_idx`` is None and both meshes have the same vertex
       count, a random subset of shared vertex indices.

    Side effects:
        * ``.picked`` files without a vertex-index column get one written back.
        * The normalized target mesh, the initial template and cage, and the
          correspondence point clouds are written to
          ``<opt.log_dir>/<opt.subdir>``.
        * ``opt.model`` is redirected to the normalized target mesh.

    Args:
        opt: options namespace (model, source_model, ckpt, is_poly,
            corres_idx, dim, lr, nepochs, clap_weight, mvc_weight, log_dir,
            subdir, ...).

    Returns:
        ``(cage_v, cage_f)``: the optimized cage vertices and the cage faces.

    Raises:
        ValueError: if no correspondence source applies, or no landmark names
            match between the two ``.picked`` files.
        NotImplementedError: if ``opt.dim != 3``.
    """
    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    # target_shape_arr is assigned into below before target_mesh is written,
    # so it is expected to alias the mesh's point storage.
    target_shape_arr = target_mesh.points()
    target_shape = target_shape_arr.copy()
    target_shape = torch.from_numpy(
        target_shape[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    # Landmark files share the mesh's base name with a ".picked" extension.
    # splitext (not str.replace of the extension) so directory names that
    # contain the extension text, or extension-less paths, are not mangled.
    new_label_path = os.path.splitext(opt.model)[0] + ".picked"
    orig_label_path = os.path.splitext(opt.source_model)[0] + ".picked"
    has_new_label = os.path.isfile(new_label_path)
    has_orig_label = os.path.isfile(orig_label_path)

    if has_new_label and has_orig_label:
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:, 5]
        new_label_name = new_label.iloc[:, 5].tolist()
        # (new row, orig row) pairs for names that occur exactly once in orig
        new_to_orig_idx = []
        for i, name in enumerate(new_label_name):
            matched_idx = orig_label_name[orig_label_name == name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        if not new_to_orig_idx:
            raise ValueError("No picked landmark in {} matches one in {}".format(
                new_label_path, orig_label_path))
        new_to_orig_idx = np.array(new_to_orig_idx)
        # column 9, when present, is a cached vertex index (written below)
        if new_label.shape[1] == 10:
            new_vidx = new_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            new_label_points = torch.from_numpy(new_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            # snap the landmarks to their nearest target vertices
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2)  # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            new_label.to_csv(new_label_path, sep=" ",
                             header=[str(new_label.shape[0])] + [""] * (new_label.shape[1] - 1),
                             index=False)
            target_points = target_points[:, new_to_orig_idx[:, 0], :]

        target_points = target_points.cuda()
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, new_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            orig_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            orig_label.to_csv(orig_label_path, sep=" ",
                              header=[str(orig_label.shape[0])] + [""] * (orig_label.shape[1] - 1),
                              index=False)
            source_points = source_points[:, new_to_orig_idx[:, 1], :]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
    elif not has_new_label and has_orig_label:
        logger.info("Assuming Faust model")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:, 9].to_numpy()).long()
            source_points = source_shape[:, idx, :]
            target_points = target_shape[:, idx, :]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            idx = idx.squeeze(-1)
            target_points = target_shape[:, idx, :]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)
    else:
        raise ValueError(
            "No usable correspondences between {} and {}: provide .picked label "
            "files or meshes with identical vertex counts".format(opt.source_model, opt.model))

    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
    # rescale the target by the ratio of component 1 of the bbox scales
    target_scale_factor = (source_scale / target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points * target_scale_factor).detach()

    # make sure test uses the normalized target
    output_dir = os.path.join(opt.log_dir, opt.subdir)
    normalized_path = os.path.join(
        output_dir, os.path.splitext(os.path.basename(opt.model))[0] + "_normalized.obj")
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    om.write_mesh(normalized_path, target_mesh)
    opt.model = normalized_path
    pymesh.save_mesh_raw(os.path.join(output_dir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(output_dir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(output_dir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(output_dir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
                                 use_norm=True, consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    # StepLR divides by step_size; int(0.4 * nepochs) is 0 when nepochs < 3
    step_size = max(1, int(opt.nepochs * 0.4))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        # reference weights: source points w.r.t. the initial (fixed) cage
        weights_ref = mean_value_coordinates_3D(
            source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError

    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(
            target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights - weights_ref) ** 2)
        # a tensor (not int 0) so reg.item() works with both regularizers off
        reg = torch.zeros((), device=loss_mvc.device, dtype=loss_mvc.dtype)
        if opt.clap_weight > 0:
            reg = (lap_loss(cage_init, cage_v, face=cage_f) * opt.clap_weight).mean()
        if opt.mvc_weight > 0:
            reg = reg + mvc_reg_loss(weights) * opt.mvc_weight

        loss = loss_mvc + reg
        if (t + 1) % 50 == 0:
            print("t {}/{} mvc_loss: {} reg: {}".format(t, opt.nepochs,
                                                        loss_mvc.item(), reg.item()))

        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()

    return cage_v, cage_f
# Example #8
# 0
def test_all(opt, new_cage_shape):
    """Deform the template mesh ``opt.model`` towards every test target.

    For each target yielded by the test dataset, the network predicts a cage
    and a deformed cage. The template points are deformed with
    ``deform_with_MVC``, and the outputs are written to
    ``<opt.log_dir>/<opt.subdir>``:

    * ``template-Sa.ply``: the undeformed template (once);
    * ``template-<t>-Sab.obj``: the deformed template per target;
    * ``template-<t>-Sb.ply``: the target mesh, when the target mesh and
      faces are available;
    * ``template-<t>-cage1.ply`` and ``template-<t>-cage2.ply``: the cages.

    Finally ``dataset.render_result`` is called on the output directory.

    Args:
        opt: options namespace. ``opt.phase`` is set to "test" and
            ``opt.target_model`` to None before the dataset is built.
        new_cage_shape: rest cage used when ``opt.d_residual`` is set; the
            predicted cage offset is added to it.
    """
    opt.phase = "test"
    opt.target_model = None
    print(opt.model)

    if opt.is_poly:
        source_mesh = om.read_polymesh(opt.model)
    else:
        source_mesh = om.read_trimesh(opt.model)

    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, drop_last=False,
        collate_fn=tolerating_collate,
        num_workers=0,
        worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id))

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]

    output_dir = os.path.join(opt.log_dir, opt.subdir)
    om.write_mesh(os.path.join(output_dir, "template-Sa.ply"), source_mesh)

    net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size,
                                       template_vertices=states["template_vertices"], template_faces=states["template_faces"].cuda(),
                                       source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda()
    load_network(net, states)

    # Deformation always starts from the original template points; this copy
    # stays intact while source_mesh's points are overwritten per target.
    source_points = torch.from_numpy(
        source_mesh.points().copy()).float().cuda().unsqueeze(0)
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            target_shape, target_filename = data["target_shape"], data["target_file"]
            # the original passed the name as a format arg of "", logging nothing
            logger.info(str(data["target_file"][0]))

            outputs = net(target_shape.transpose(1, 2), cage_only=True)
            if opt.d_residual:
                # apply the predicted cage offset to the externally supplied cage
                cage_offset = outputs["new_cage"] - outputs["cage"]
                outputs["cage"] = new_cage_shape
                outputs["new_cage"] = new_cage_shape + cage_offset

            deformed = deform_with_MVC(outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand(
                outputs["cage"].shape[0], -1, -1), source_points)

            target_mesh = data.get("target_mesh")
            target_face = data.get("target_face")
            for b in range(deformed.shape[0]):
                t_filename = os.path.splitext(target_filename[b])[0]
                # write entry b's deformation through the template's topology
                source_mesh_arr = source_mesh.points()
                source_mesh_arr[:] = deformed[b].cpu().detach().numpy()
                om.write_mesh(os.path.join(
                    output_dir, "template-{}-Sab.obj".format(t_filename)), source_mesh)
                if target_mesh is not None and target_face is not None:
                    pymesh.save_mesh_raw(os.path.join(output_dir, "template-{}-Sb.ply".format(t_filename)),
                                         target_mesh[b].detach().cpu(), target_face[b].detach().cpu())

                pymesh.save_mesh_raw(
                    os.path.join(output_dir, "template-{}-cage1.ply".format(t_filename)),
                    outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu(),
                )
                pymesh.save_mesh_raw(
                    os.path.join(output_dir, "template-{}-cage2.ply".format(t_filename)),
                    outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu(),
                )

            if i % 20 == 0:
                logger.success("[{}/{}] Done".format(i, len(dataloader)))

    dataset.render_result(output_dir)