Example #1
    def prepare_sphere_template(self, template_vertices, template_faces):
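        # convert the sphere template to polar coordinates; the angles become the
        # (optionally learnable) template parameter while the radius stays fixed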
        assert (template_vertices.ndim == 3
                and template_vertices.shape[1] == 3)  # (1,3,V)
        angle, self.r = xyz_to_polar(template_vertices)

        self.template = nn.Parameter(angle,
                                     requires_grad=self.opt.optimize_template)
        assert (template_faces.ndim == 3
                and template_faces.shape[2] == 3)  # (1,F,3)
        self.register_buffer("template_faces", template_faces)
        self.register_buffer("template_vertices", template_vertices)
        if self.template.requires_grad:
            logger.info("Enabled vertex optimization")
Example #2
    def set_up_template(self, template_vertices, template_faces):
        # save template as buffer
        assert (template_vertices.ndim == 3
                and template_vertices.shape[1] == self.dim)  # (1,3,V)
        if self.dim == 3:
            assert (template_faces.ndim == 3
                    and template_faces.shape[2] == 3)  # (1,F,3)

        self.register_buffer("template_faces", template_faces)
        self.register_buffer("template_vertices", template_vertices)
        self.template_vertices = nn.Parameter(
            self.template_vertices, requires_grad=(self.opt.optimize_template))
        if self.template_vertices.requires_grad:
            logger.info("Enabled vertex optimization")
Example #3
 def __init__(self, opt):
     super().__init__()
     self.opt = opt
     self.loss = defaultdict(float)
     self.labeled_chamfer_loss = LabeledChamferDistance(beta=opt.beta,
                                                        gamma=opt.gamma,
                                                        delta=opt.delta)
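     # regularizers on the predicted cage: edge length, dihedral angle, MVC magnitude, Laplacian and smoothness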
     self.cage_shortLength_loss = SimpleMeshRepulsionLoss(
         0.02, reduction="mean", consistent_topology=True)
     self.cage_faceAngle_loss = MeshDihedralAngleLoss(threshold=np.pi / 30)
     self.mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
     self.cage_laplacian = MeshLaplacianLoss(
         torch.nn.L1Loss(reduction="mean"),
         use_cot=False,
         use_norm=True,
         consistent_topology=True,
         precompute_L=True)
     self.cage_smooth_loss = MeshSmoothLoss(
         torch.nn.MSELoss(reduction="mean"), use_cot=False, use_norm=True)
     self.grounding_loss = GroundingLoss(
         up_dim=(1 if "SHAPENET" in opt.dataset else 2))
     if opt.sym_plane is not None:
         self.symmetry_loss = SymmetryLoss(sym_plane=opt.sym_plane,
                                           NCHW=False).cuda()
     # mesh_chamfer_loss = losses.InterpolatedCDTriMesh(interpolate_n=5, beta=1.0, gamma=0.0, delta=1/30)
     # cage_inside_loss = InsideLoss3DTriMesh(reduction="max")
     # cage_inside_loss = ExtPointToNearestFaceDistance(reduction="mean", min_dist=opt.cinside_eps)
     if self.opt.dataset in ("SURREAL", "FAUST"):
         logger.info("Using GTNormal loss")
         self.shape_normal_loss = GTNormalLoss()
     else:
         logger.info("Using KNN for normal loss")
         self.shape_normal_loss = NormalLoss(reduction="none", nn_size=16)
     self.shape_fnormal_loss = FaceNormalLoss(n_faces=300)
     self.stretch_loss = PointStretchLoss((4 if opt.dim == 3 else 2),
                                          reduction="mean")
     self.edge_loss = PointEdgeLengthLoss(
         (4 if opt.dim == 3 else 2), torch.nn.MSELoss(reduction="mean"))
     if self.opt.regular_sampling or (not opt.mesh_data):
         logger.info("Using point laplacian loss")
         self.shape_laplacian = PointLaplacianLoss(
             16, torch.nn.MSELoss(reduction="none"), use_norm=opt.slap_norm)
     else:
         logger.info("Using mesh laplacian loss")
         self.shape_laplacian = MeshLaplacianLoss(
             torch.nn.MSELoss(reduction="none"),
             use_cot=True,
             use_norm=True,
             consistent_topology=True,
             precompute_L=True)
     self.p2f_loss = LocalFeatureLoss(16,
                                      torch.nn.MSELoss(reduction="none"))
Example #4
def sample_pts(input_dir, output_dir, overwrite_pts=False):
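    # sample N_SAMPLE points from each source/output mesh in parallel using the external MeshSample binary (SAMPLE_BIN)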
    pool1 = ThreadPool(processes=N_CORE)
    results = []
    source_files = glob(os.path.join(input_dir, "*Sa*.ply"))+glob(os.path.join(input_dir, "*Sb.ply"))
    if len(source_files) == 0:
        source_files = glob(os.path.join(input_dir, "*Sa*.obj"))+glob(os.path.join(input_dir, "*Sb.obj"))
    logger.info("Sampling {} meshes into {}".format(len(source_files),output_dir))
    for source in source_files:
        source_filename = os.path.splitext(os.path.basename(source))[0]
        # ./MeshSample -v source output
        output_file = os.path.join(output_dir, "{}.pts".format(source_filename))
        if overwrite_pts or not os.path.isfile(output_file):
            # results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} -s1 {} {}".format(N_SAMPLE, source, output_file),)))
            results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} {} {}".format(N_SAMPLE, source, output_file),)))

    # Close the pool
    pool1.close()
    pool1.join()
    for result in results:
        out, err = result.get()
        if len(err) > 0:
            print("err: {}".format(repr(err)))
    results.clear()
    return True
Example #5
def svr_sample_pts(input_dir, output_dir, overwrite_pts=True):
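    # sample each unique target mesh (*-Sb) once, then every deformed output (*Sab*);
    # fall back to sampling all meshes if no targets are found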
    pool1 = ThreadPool(processes=N_CORE)
    results = []
    source_files = glob(os.path.join(input_dir, "*Sb.ply"))
    if len(source_files) == 0:
        source_files = glob(os.path.join(input_dir, "*Sb.obj"))
    target_names = np.unique(np.array([os.path.basename(p).split("-")[1] for p in source_files])).tolist()
    logger.info("Sampling {} target meshes into {}".format(len(target_names),output_dir))
    for target in target_names:
        source_filename = glob(os.path.join(input_dir, "*{}-Sb.*".format(target)))[0]
        # ./MeshSample -v source output
        output_file = os.path.join(output_dir, "{}.pts".format(target))
        if overwrite_pts or not os.path.isfile(output_file):
            # results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} -s1 {} {}".format(N_SAMPLE, source_filename, output_file),)))
            results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} {} {}".format(N_SAMPLE, source_filename, output_file),)))
    source_files = glob(os.path.join(input_dir, "*Sab*.ply"))
    if len(source_files) == 0:
        source_files = glob(os.path.join(input_dir, "*Sab*.obj"))
    logger.info("Sampling {} output meshes into {}".format(len(source_files),output_dir))
    for source in source_files:
        source_filename = os.path.splitext(os.path.basename(source))[0]
        # ./MeshSample -v source output
        output_file = os.path.join(output_dir, "{}.pts".format(source_filename))
        if overwrite_pts or not os.path.isfile(output_file):
            # results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} -s1 {} {}".format(N_SAMPLE, source, output_file),)))
            results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} {} {}".format(N_SAMPLE, source, output_file),)))
    if len(target_names) == 0:
        # no paired targets found, fall back to sampling every mesh in the directory
        source_files = glob(os.path.join(input_dir, "*.obj"))+glob(os.path.join(input_dir, "*.ply"))
        logger.info("Sampling {} output meshes into {}".format(len(source_files), output_dir))
        for source in source_files:
            source_filename = os.path.splitext(os.path.basename(source))[0]
            # ./MeshSample -v source output
            output_file = os.path.join(output_dir, "{}.pts".format(source_filename))
            if overwrite_pts or not os.path.isfile(output_file):
                # results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} -s1 {} {}".format(N_SAMPLE, source, output_file),)))
                results.append(pool1.apply_async(call_proc, (SAMPLE_BIN + " -n{} {} {}".format(N_SAMPLE, source, output_file),)))

    # Close the pool
    pool1.close()
    pool1.join()
    for result in results:
        out, err = result.get()
        if len(err) > 0:
            print("err: {}".format(repr(err)))
    results.clear()
    return True
Example #6
    def __init__(self,
                 opt,
                 dim,
                 num_points,
                 bottleneck_size,
                 template_vertices=None,
                 template_faces=None,
                 source_vertices=None,
                 source_faces=None,
                 **kwargs):
        super().__init__()
        self.opt = opt
        self.initialized = False
        self.dim = dim
        ###### shared encoder ########
        if opt.pointnet2:
            self.encoder = PointNet2feat(dim=dim,
                                         num_points=opt.num_point,
                                         bottleneck_size=bottleneck_size,
                                         normalization=opt.normalization)
            bottleneck_size = self.encoder.bottleneck_size
        else:
            self.encoder = nn.Sequential(
                PointNetfeat(dim=dim,
                             num_points=opt.num_point,
                             bottleneck_size=bottleneck_size,
                             normalization=opt.normalization),
                Linear(bottleneck_size,
                       bottleneck_size,
                       activation="tanh",
                       normalization=opt.normalization))

        ###### save template and source to buffer ########
        self.initialize_buffers(template_vertices, template_faces,
                                source_vertices, source_faces)
        self.prob = None
        # print("!!!code_scale", self.code_scale)

        ###### cage refinement and cage deformation ########
        if opt.optimize_template:
            self.template_vertices = nn.Parameter(self.template_vertices)
            logger.info("optimize template cage as parameters")
        if opt.deform_template:
            logger.info("optimize template cage with point fold")
            self.nc_decoder = DeformationSharedMLP(
                dim, normalization=opt.normalization, residual=opt.c_residual)

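        # deformation decoder: a folding-based point generator (MultiFoldPointGen) when opt.atlas is set,
        # otherwise an MLPDeformer operating on the cage vertices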
        if opt.atlas:
            self.nd_decoder = MultiFoldPointGen(
                (bottleneck_size +
                 dim if opt.use_correspondence else bottleneck_size),
                dim,
                n_fold=opt.n_fold,
                normalization=opt.normalization,
                concat_prim=opt.concat_prim,
                return_aux=False,
                residual=opt.d_residual)
        else:
            self.nd_decoder = MLPDeformer(
                dim=dim,
                bottleneck_size=bottleneck_size,
                npoint=self.template_vertices.shape[-1],
                residual=opt.d_residual,
                normalization=opt.normalization)
Example #7
        return ("", "")


if __name__ == "__main__":
    N_CORE = 8
    N_POINT = 5000
    print("Using %d of %d cores" % (N_CORE, multiprocessing.cpu_count()))

    source_dir = sys.argv[1]  # input directory
    output_dir = sys.argv[2]  # output directory

    ###################################
    # 1. gather source and target
    ###################################
    source_files = find_files(source_dir, 'obj')
    logger.info("Found {} source files".format(len(source_files)))

    os.makedirs(output_dir, exist_ok=True)

    ###################################
    # Sample
    ###################################
    pool = ThreadPool(processes=N_CORE)
    results = []
    for input_file in source_files:
        source_name = os.path.splitext(os.path.basename(input_file))[0]
        my_out_dir = os.path.join(
            output_dir, os.path.relpath(os.path.dirname(input_file),
                                        source_dir))
        os.makedirs(my_out_dir, exist_ok=True)
        output_file = os.path.join(my_out_dir, source_name + ".obj")
Example #8
def train():
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))
    source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
    source_face = dataset.mesh_face.unsqueeze(0)
    cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
    cage_face = dataset.cage_face.unsqueeze(0)
    mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0])
    build_gemm(mesh, cage_face[0])
    cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
    cage_edges = edge_vertex_indices(cage_face[0])

    # network
    net = networks.FixedSourceDeformer(
        opt,
        3,
        opt.num_point,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_shape.transpose(1, 2),
        template_faces=cage_face,
        source_vertices=source_shape.transpose(1, 2),
        source_faces=source_face).cuda()
    print(net)
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)
    net.train()

    all_losses = losses.AllLosses(opt)

    # optimizer
    optimizer = torch.optim.Adam([{
        'params': net.nd_decoder.parameters()
    }, {
        "params": net.encoder.parameters()
    }],
                                 lr=opt.lr)

    # train
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "network2.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"),
                 opt.log_dir)
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)),
        net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.source_faces[0].detach().cpu())
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)),
        net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.template_faces[0].detach().cpu())

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                max(int(opt.nepochs * 0.75),
                                                    1),
                                                gamma=0.5,
                                                last_epoch=-1)

    # train
    net.train()
    t = 0
    start_epoch = 0
    warmed_up = False
    mvc_weight = opt.mvc_weight
    opt.mvc_weight = 0
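    # the MVC regularizer weight is zeroed during warm-up and restored once warm-up finishes (see below)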

    os.makedirs(opt.log_dir, exist_ok=True)
    running_avg_loss = -1
    log_file = open(os.path.join(opt.log_dir, "loss_log.txt"), "a")
    log_interval = min(max(len(dataloader) // 5, 50), 200)
    save_interval = max(opt.nepochs // 10, 1)

    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for epoch_t, data in enumerate(dataloader):
                progress = epoch_t / len(dataloader) + epoch
                warming_up = progress < opt.warmup_epochs
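                # once warm-up ends, add the cage decoder / template vertices to the optimizer
                # at a reduced learning rate and restore the MVC weight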
                if (opt.deform_template or opt.optimize_template) and (
                        progress >= opt.warmup_epochs) and (not warmed_up):
                    if opt.deform_template:
                        optimizer.add_param_group({
                            'params':
                            net.nc_decoder.parameters(),
                            'lr':
                            0.1 * opt.lr
                        })
                    if opt.optimize_template:
                        optimizer.add_param_group({
                            'params': net.template_vertices,
                            'lr': 0.1 * opt.lr
                        })
                    warmed_up = True
                    # start to compute mvc weight
                    opt.mvc_weight = mvc_weight
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="warmed_up")

                ############# get data ###########
                data = dataset.uncollate(data)
                data["cage_edge_points"] = cage_edge_points
                data["cage_edges"] = cage_edges
                data["source_shape"] = net.source_vertices.detach()
                data["source_face"] = net.source_faces.detach()

                ############# run network ###########
                optimizer.zero_grad()
                target_shape_t = data["target_shape"].transpose(1, 2)
                sample_idx = None

                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                    if data["source_normals"] is not None:
                        data["source_normals"] = torch.gather(
                            data["source_normals"], 1,
                            sample_idx.unsqueeze(-1).expand(-1, -1, 3))

                outputs = net(target_shape_t, sample_idx)
                if opt.sfnormal_weight > 0 and ("source_mesh" in data
                                                and data["source_mesh"] is not None):
                    if outputs["deformed"].shape[1] == data[
                            "source_mesh"].shape[1]:
                        outputs["deformed_hr"] = outputs["deformed"]
                    else:
                        outputs["deformed_hr"] = deform_with_MVC(
                            outputs["cage"].expand(
                                data["source_mesh"].shape[0], -1, -1).detach(),
                            outputs["new_cage"], outputs["cage_face"].expand(
                                data["source_mesh"].shape[0], -1,
                                -1), data["source_mesh"])
                data["source_shape"] = outputs["source_shape"]

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum.item()
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_sum > 10 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        not warmed_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)
                    # save_ply(data["target_shape"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sb.ply".format(t)))
                    # save_ply(outputs["deformed"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sab.ply".format(t)))
                    # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage1.ply".format(t)),
                    #               outputs["cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True)
                    # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage2.ply".format(t)),
                    #               outputs["new_cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True)

                if loss_sum > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_sum, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

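                # when alternating cage/deformation optimization, zero the gradients of whichever branch is frozen at this step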
                if opt.alternate_cd:
                    optimize_C = (progress > opt.warmup_epochs) and (
                        t % (opt.c_step + opt.d_step)) > opt.d_step
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                        net.encoder.zero_grad()
                    else:
                        try:
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                # clamp_gradient_norm(net, 1)
                optimizer.step()
                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test_all(net=net)
Example #9
def create_comparison_montage(render_dirs,
                              labels,
                              output_dir,
                              output_cage=False):
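    # assemble a one-row comparison montage of the input, target and each method's output with ImageMagick's montage tool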
    files = find_files(render_dirs[0], "png")
    source_files = [p for p in files if "Sb.png" in p]
    logger.info("Found {} files".format(len(source_files)))
    pool = ThreadPool(processes=4)
    results = []
    os.makedirs(output_dir, exist_ok=True)
    for source in source_files:
        sname, tname = os.path.basename(source).split("-")[:2]
        output_file = os.path.join(output_dir,
                                   "{}-{}.png".format(sname, tname))

        images = [
            glob(os.path.join(cur_dir, "{}-{}-Sab*.png".format(sname, tname)))
            for cur_dir in render_dirs
        ]
        if not all([len(im_found) > 0 for im_found in images]):
            indices = [i for i, x in enumerate(images) if len(x) == 0]
            logger.warn(
                "", "Cannot find {} in {}".format(
                    "{}-{}-Sab*.png".format(sname, tname),
                    ", ".join([render_dirs[i] for i in indices])))
            continue

        images = [im_found[0] for im_found in images]
        cages = []
        mylabels = labels[:]
        if output_cage:
            # find cages
            for i, dir_img in enumerate(zip(render_dirs, images)):
                cur_dir, image = dir_img
                cage1 = glob(image.replace("Sab", "cage1"))
                cage2 = glob(image.replace("Sab", "cage2"))
                if len(cage1) > 0 and len(cage2) > 0:
                    cages.append((i, cage1[0], cage2[0]))
            # insert cages to the correct position in images
            cnt = 0
            for offset, cage1, cage2 in cages:
                images.insert(offset + cnt + 1, cage1)
                images.insert(offset + cnt + 2, cage2)
                mylabels.insert(offset + cnt + 1,
                                mylabels[offset + cnt] + "_cage1")
                mylabels.insert(offset + cnt + 2,
                                mylabels[offset + cnt] + "_cage2")
                cnt += 2
            assert (len(images) == len(mylabels))
        image_strs = " ".join([
            "-label {} {}".format(l, i) for l, i in zip(mylabels, images)
        ])

        num_cols = len(images) + 2
        target = source.replace("Sa", "Sb")
        results.append(
            pool.apply_async(call_proc, (
                "montage -geometry +0+0 -gravity Center -crop 420x450+0+0 +repage -tile {}x1 -label input {} -label target {} {} {}"
                .format(num_cols, target, source, image_strs, output_file), )))
    # Close the pool and wait for each running task to complete
    pool.close()
    pool.join()
    for result in results:
        out, err = result.get()
        if len(err) > 0:
            print("err: {}".format(err))
Example #10
def create_two_row_comparison_montage(render_dirs,
                                      labels,
                                      output_dir,
                                      output_cage=True):
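    # two-row comparison montage: deformed outputs on the top row and the concatenated cage pairs on the bottom row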
    files = find_files(render_dirs[0], "png")
    source_files = [p for p in files if "Sa.png" in p]
    logger.info("Found {} files".format(len(source_files)))
    pool = ThreadPool(processes=4)
    results = []
    os.makedirs(output_dir, exist_ok=True)
    # first concatenate cage1-cage2
    for cur_dir in render_dirs:
        cage1s = glob(os.path.join(cur_dir, "*cage1*.png"))
        cage2s = [f.replace("cage1", "cage2") for f in cage1s]
        for cage1, cage2 in zip(cage1s, cage2s):
            if not (os.path.isfile(cage1) and os.path.isfile(cage2)):
                continue
            output_file = cage1.replace("cage1", "cages")
            results.append(
                pool.apply_async(call_proc, (
                    "montage -geometry +0+0 -gravity Center -crop 400x400+0+0 +repage -tile 2x1 {} {} {}"
                    .format(cage1, cage2, output_file), )))
    pool.close()
    pool.join()
    for result in results:
        out, err = result.get()
        if len(err) > 0:
            print("err: {}".format(err))
    results.clear()
    pool = ThreadPool(processes=4)
    for source in source_files:
        sname, tname = os.path.basename(source).split("-")[:2]
        output_file = os.path.join(output_dir,
                                   "{}-{}.png".format(sname, tname))

        images = [
            glob(os.path.join(cur_dir, "{}-{}-Sab*.png".format(sname, tname)))
            for cur_dir in render_dirs
        ]
        if not all([len(im_found) > 0 for im_found in images]):
            indices = [i for i, x in enumerate(images) if len(x) == 0]
            logger.warn(
                "", "Cannot find {} in {}".format(
                    "{}-{}-Sab*.png".format(sname, tname),
                    ", ".join([render_dirs[i] for i in indices])))

        images = [
            "null:" if len(im_found) == 0 else im_found[0]
            for im_found in images
        ]

        cages = [
            glob(os.path.join(cur_dir, "{}-{}-cages*.png".format(sname,
                                                                 tname)))
            for cur_dir in render_dirs
        ]
        if not all([len(im_found) > 0 for im_found in cages]):
            indices = [i for i, x in enumerate(cages) if len(x) == 0]
            logger.warn(
                "", "Cannot find {} in {}".format(
                    "{}-{}-cages*.png".format(sname, tname),
                    ", ".join([render_dirs[i] for i in indices])))

        cages = [
            "null:" if len(im_found) == 0 else im_found[0]
            for im_found in cages
        ]

        mylabels = labels[:]

        assert (len(images) == len(mylabels))
        image_strs = " ".join(images)
        cage_str = " ".join(
            ["-label {} {}".format(l, i) for l, i in zip(mylabels, cages)])

        num_cols = len(images) + 1
        target = source.replace("Sa", "Sb")
        results.append(
            pool.apply_async(
                call_proc,
                ("montage -geometry \'420x400>+0+0\' -tile {}x2 {} "
                 "{} {} {} {}".format(num_cols, source, image_strs, target,
                                      cage_str, output_file), )))
    # Close the pool and wait for each running task to complete
    pool.close()
    pool.join()
    for result in results:
        out, err = result.get()
        if len(err) > 0:
            print("err: {}".format(err))

    for cur_dir in render_dirs:
        call_proc("rm {}".format(os.path.join(cur_dir, "*.cages*.png")))
Example #11
def train():
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=tolerating_collate,
        num_workers=2,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    if opt.dim == 3:
        # cage (1,N,3)
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        cage_edge_points_list = []
        cage_edges_list = []
        for F in init_cage_Fs:
            mesh = Mesh(vertices=init_cage_V[0], faces=F[0])
            build_gemm(mesh, F[0])
            cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
            cage_edge_points_list.append(cage_edge_points)
            cage_edges_list = [edge_vertex_indices(F[0])]
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1,
                                                               -1).cuda()
        ]

    # network
    net = networks.NetworkFull(
        opt,
        dim=opt.dim,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_V_t,
        template_faces=init_cage_Fs[-1],
    ).cuda()

    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)

    all_losses = losses.AllLosses(opt)
    # optimizer
    optimizer = torch.optim.Adam([{
        "params": net.encoder.parameters()
    }, {
        "params": net.nd_decoder.parameters()
    }, {
        "params": net.merger.parameters()
    }],
                                 lr=opt.lr)

    if opt.full_net:
        optimizer.add_param_group({
            'params': net.nc_decoder.parameters(),
            'lr': 0.1 * opt.lr
        })
    if opt.optimize_template:
        optimizer.add_param_group({
            'params': net.template_vertices,
            'lr': opt.lr
        })

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                int(opt.nepochs * 0.4),
                                                gamma=0.1,
                                                last_epoch=-1)

    # train
    net.train()
    start_epoch = 0
    t = 0

    steps_C = 20
    steps_D = 20

    # train
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "networks.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "option.py"),
                 opt.log_dir)
    print(net)

    log_file = open(os.path.join(opt.log_dir, "training_log.txt"), "a")
    log_file.write(str(net) + "\n")

    log_interval = max(len(dataloader) // 5, 50)
    save_interval = max(opt.nepochs // 10, 1)
    running_avg_loss = -1

    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for t_epoch, data in enumerate(dataloader):
                warming_up = epoch < opt.warmup_epochs
                progress = t_epoch / len(dataloader) + epoch
                optimize_C = (t % (steps_C + steps_D)) > steps_D

                ############# get data ###########
                data = dataset.uncollate(data)
                data = crisscross_input(data)
                if opt.dim == 3:
                    data["cage_edge_points"] = cage_edge_points_list[-1]
                    data["cage_edges"] = cage_edges_list[-1]
                source_shape, target_shape = data["source_shape"], data[
                    "target_shape"]

                ############# blending ############
                if opt.blend_style:
                    blend_alpha = torch.rand(
                        (source_shape.shape[0], 1),
                        dtype=torch.float32).to(device=source_shape.device)
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                ############# run network ###########
                optimizer.zero_grad()
                # optimizer_C.zero_grad()
                # optimizer_D.zero_grad()
                source_shape_t = source_shape.transpose(1, 2)
                target_shape_t = target_shape.transpose(1, 2)
                outputs = net(source_shape_t, target_shape_t, data["alpha"])

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum.item()
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_sum > 5 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        warming_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                if loss_sum > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_sum, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()
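                # during warm-up, cancel gradients for the encoder and cage decoder so only the deformation branch is updated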
                if epoch < opt.warmup_epochs:
                    try:
                        net.nc_decoder.zero_grad()
                        net.encoder.zero_grad()
                    except AttributeError:
                        net.template_vertices.grad.zero_()

                if opt.alternate_cd:
                    optimize_C = (epoch > opt.warmup_epochs) and (
                        epoch % (opt.c_epoch + opt.d_epoch)) > opt.d_epoch
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                    else:
                        try:
                            net.encoder.zero_grad()
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                clamp_gradient(net, 0.1)
                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()
            if opt.eval:
                try:
                    test(net=net, save_subdir="epoch_{}".format(epoch))
                except Exception as e:
                    traceback.print_exc(file=sys.stdout)
                    logger.warn("Failed to run test", str(e))

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test(net=net)
Example #12
def optimize(opt):
    """
    weights are the same with the original source mesh
    target=net(old_source)
    """
    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    target_shape_arr = target_mesh.points()
    target_shape = target_shape_arr.copy()
    target_shape = torch.from_numpy(
        target_shape[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    if os.path.isfile(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")) and os.path.isfile(opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")):
        new_label_path = opt.model.replace(os.path.splitext(opt.model)[1], ".picked")
        orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        new_label_name = new_label.iloc[:,5].tolist()
        new_to_orig_idx = []
        for i, name in enumerate(new_label_name):
            matched_idx = orig_label_name[orig_label_name==name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        new_to_orig_idx = np.array(new_to_orig_idx)
        if new_label.shape[1] == 10:
            new_vidx = new_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            new_label_points = torch.from_numpy(new_label.iloc[:,6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2) # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            new_label.to_csv(new_label_path, sep=" ", header=[str(new_label.shape[0])]+[""]*(new_label.shape[1]-1), index=False)
            target_points = target_points[:, new_to_orig_idx[:,0], :]

        target_points = target_points.cuda()
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, new_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            orig_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            orig_label.to_csv(orig_label_path, sep=" ", header=[str(orig_label.shape[0])]+[""]*(orig_label.shape[1]-1), index=False)
            source_points = source_points[:,new_to_orig_idx[:,1],:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
        # # shift target so that the belly match
        # try:
        #     orig_bellyUp_idx = orig_label_name[orig_label_name=="bellUp"].index[0]
        #     orig_bellyUp = orig_label_points[orig_bellyUp_idx, :]
        #     new_bellyUp_idx = [i for i, i2 in new_to_orig_idx if i2==orig_bellyUp_idx][0]
        #     new_bellyUp = new_label_points[new_bellyUp_idx,:]
        #     target_points += (orig_bellyUp - new_bellyUp)
        # except Exception as e:
        #     logger.warn("Couldn\'t match belly to belly")
        #     traceback.print_exc(file=sys.stdout)

        # source_points[0] = center_bounding_box(source_points[0])[0]
    elif not os.path.isfile(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")) and os.path.isfile(opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")):
        logger.info("Assuming Faust model")
        orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:,9].to_numpy()).long()
            source_points = source_shape[:,idx,:]
            target_points = target_shape[:,idx,:]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            idx = idx.squeeze(-1)
            target_points = target_shape[:,idx,:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        # corresp_idx = torch.randint(0, shape_f.shape[1], (100,)).cuda()
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)

    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
    target_scale_factor = (source_scale/target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points*target_scale_factor).detach()
    # make sure test use the normalized
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    om.write_mesh(os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0]+"_normalized.obj"), target_mesh)
    opt.model = os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0]+"_normalized.obj")
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(
        target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
                                 use_norm=True, consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(opt.nepochs*0.4), gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        weights_ref = mean_value_coordinates_3D(
            source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError

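    # optimize the cage vertices so that the MVC weights of the target landmarks match
    # the reference weights computed from the source landmarks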
    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(
            target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights-weights_ref)**2)
        # reg = torch.sum((cage_init-cage_v)**2, dim=-1)*1e-4
        reg = 0
        if opt.clap_weight > 0:
            reg = lap_loss(cage_init, cage_v, face=cage_f)*opt.clap_weight
            reg = reg.mean()
        if opt.mvc_weight > 0:
            reg += mvc_reg_loss(weights)*opt.mvc_weight

        # weight regularizer with the shape difference
        # dist = torch.sum((source_points - target_points)**2, dim=-1)
        # weights = torch.exp(-dist)
        # reg = reg*weights*0.1

        loss = loss_mvc + reg
        if (t+1) % 50 == 0:
            print("t {}/{} mvc_loss: {} reg: {}".format(t,
                                                        opt.nepochs, loss_mvc.item(), reg.item()))

        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()

    return cage_v, cage_f
Example #13
def test_all(opt, new_cage_shape):
    opt.phase = "test"
    opt.target_model = None
    print(opt.model)

    if opt.is_poly:
        source_mesh = om.read_polymesh(opt.model)
    else:
        source_mesh = om.read_trimesh(opt.model)

    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False,
                                             collate_fn=tolerating_collate,
                                             num_workers=0, worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] + id))

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    # states["template_vertices"] = new_cage_shape.transpose(1, 2)
    # states["source_vertices"] = new_source.transpose(1,2)
    # states["source_faces"] = new_source_face
    # new_source_face = states["source_faces"]

    om.write_mesh(os.path.join(opt.log_dir, opt.subdir,
                               "template-Sa.ply"), source_mesh)

    net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size,
                                       template_vertices=states["template_vertices"], template_faces=states["template_faces"].cuda(),
                                       source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda()
    load_network(net, states)

    source_points = torch.from_numpy(
        source_mesh.points().copy()).float().cuda().unsqueeze(0)
    with torch.no_grad():
        # source_face = net.source_faces.detach()
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            target_shape, target_filename = data["target_shape"], data["target_file"]
            logger.info("", data["target_file"][0])

            sample_idx = None
            if "sample_idx" in data:
                sample_idx = data["sample_idx"]

            outputs = net(target_shape.transpose(1, 2), cage_only=True)
            if opt.d_residual:
                cage_offset = outputs["new_cage"]-outputs["cage"]
                outputs["cage"] = new_cage_shape
                outputs["new_cage"] = new_cage_shape+cage_offset

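            # deform the full-resolution source mesh with mean value coordinates given the predicted cage pair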
            deformed = deform_with_MVC(outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand(
                outputs["cage"].shape[0], -1, -1), source_points)

            for b in range(deformed.shape[0]):
                t_filename = os.path.splitext(target_filename[b])[0]
                source_mesh_arr = source_mesh.points()
                source_mesh_arr[:] = deformed[b].cpu().detach().numpy()
                om.write_mesh(os.path.join(
                    opt.log_dir, opt.subdir, "template-{}-Sab.obj".format(t_filename)), source_mesh)
                # if data["target_face"] is not None and data["target_mesh"] is not None:
                # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sa.ply".format(t_filename)),
                #             source_mesh[0].detach().cpu(), source_face[b].detach().cpu())
                pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sb.ply".format(t_filename)),
                                     data["target_mesh"][b].detach().cpu(), data["target_face"][b].detach().cpu())
                # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab.ply".format(t_filename)),
                #             deformed[b].detach().cpu(), source_face[b].detach().cpu())

                # else:
                #     save_ply(source_mesh[0].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sa.ply".format(t_filename)))
                #     save_ply(target_shape[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sb.ply".format(t_filename)),
                #                 normals=data["target_normals"][b].detach().cpu())
                #     save_ply(deformed[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sab.ply".format(t_filename)),
                #                 normals=data["target_normals"][b].detach().cpu())

                pymesh.save_mesh_raw(
                    os.path.join(opt.log_dir, opt.subdir, "template-{}-cage1.ply".format(t_filename)),
                    outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu())
                pymesh.save_mesh_raw(
                    os.path.join(opt.log_dir, opt.subdir, "template-{}-cage2.ply".format(t_filename)),
                    outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu())

            # if opt.opt_lap and deformed.shape[1] == source_mesh.shape[1]:
            #     deformed = optimize_lap(opt, source_mesh, deformed, source_face)
            #     for b in range(deformed.shape[0]):
            #         pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab-optlap.ply".format(t_filename)),
            #                                 deformed[b].detach().cpu(), source_face[b].detach().cpu())

            if i % 20 == 0:
                logger.success("[{}/{}] Done".format(i, len(dataloader)))

    dataset.render_result(os.path.join(opt.log_dir, opt.subdir))
Example #14
def evaluate_deformation(result_dirs, resample, mse=False, overwrite_pts=False):
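    # compare deformed outputs across result folders: Chamfer distance (or MSE) to the target
    # plus cotangent/uniform Laplacian differences w.r.t. the source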
    CD_name = "MSE" if mse else "CD"
    if isinstance(result_dirs, str):
        result_dirs = [result_dirs]
    ########## initialize ############
    eval_result = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))  # eval_result[folder][metric][file]
    cotLap = CotLaplacian()
    uniLap = UniformLaplacian()
    if resample:
        for cur_dir in result_dirs:
            pts_dir = os.path.join(cur_dir, "eval_pts")
            os.makedirs(pts_dir, exist_ok=True)
            result = sample_pts(cur_dir, pts_dir, overwrite_pts)
            if not result:
                logger.warn("Failed to sample points in {}".format(cur_dir))

    ########## load results ###########
    # find Sa.ply, Sb.ply and a list of Sab.ply
    ###################################
    for i, name in enumerate(result_dirs):
        print("dir{}: {}".format(i, name))
    files = glob(os.path.join(result_dirs[0], "*.ply"))+glob(os.path.join(result_dirs[0], "*.obj"))
    source_files = [p for p in files if "Sa." in p]
    target_files = [p.replace("Sa", "Sb") for p in source_files]
    assert(all([os.path.isfile(f) for f in target_files]))
    logger.info("Found {} source target pairs".format(len(source_files)))
    ########## evaluation ############
    print("{}: {}".format("filename".ljust(70), " | ".join(["dir{}".format(i).rjust(45) for i in range(len(result_dirs))])))
    print("{}: {}".format(" ".ljust(70), " | ".join([(CD_name+"/CotLap/CotLapNorm/UniLap/UniLapNorm").rjust(45) for i in range(len(result_dirs))])))
    cnt = 0
    for source, target in zip(source_files, target_files):
        source_filename = os.path.basename(source)
        target_filename = os.path.basename(target)
        try:
            if resample:
                source_pts_file = os.path.join(result_dirs[0], "eval_pts", source_filename[:-4]+".pts")
                target_pts_file = os.path.join(result_dirs[0], "eval_pts", target_filename[:-4]+".pts")
                if not os.path.isfile(source_pts_file):
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(source_pts_file))
                    continue
                if not os.path.isfile(target_pts_file):
                    logger.warn("Couldn't find {}. Skipping to the next.".format(target_pts_file))
                    continue
                source_pts = load(source_pts_file)
                target_pts = load(target_pts_file)
                source_pts = torch.from_numpy(source_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                target_pts = torch.from_numpy(target_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()

            ext = os.path.splitext(source_filename)[1]
            sab_str = source_filename.replace("Sa"+ext, "Sab*")
            outputs = [glob( os.path.join(cur_dir, sab_str) ) for cur_dir in result_dirs]
            if not all([len(o) > 0 for o in outputs]):
                logger.warn("Couldn\'t find {} in all folders, skipping to process the next".format(sab_str))
                continue

            # read Sa, Sb
            source_shape, source_face = read_trimesh(source, clean=False)
            target_shape, _ = read_trimesh(target, clean=False)
            source_shape = torch.from_numpy(source_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
            target_shape = torch.from_numpy(target_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
            source_face = torch.from_numpy(source_face[:,:3].astype(np.int64)).unsqueeze(0).cuda()

            # laplacian for source (fixed)
            cotLap.L = None
            ref_lap = cotLap(source_shape, source_face)
            ref_lap_norm = torch.norm(ref_lap, dim=-1)

            uniLap.L = None
            ref_ulap = uniLap(source_shape, source_face)
            ref_ulap_norm = torch.norm(ref_ulap, dim=-1)

            filename = os.path.splitext(os.path.basename(source))[0]
            for output, cur_dir in zip(outputs, result_dirs):
                if len(output)>1:
                    logger.warn("Found multiple outputs {}. Using the last one".format(output))
                if len(output) == 0:
                    logger.warn("Found no outputs for {} in {}".format(sab_str, cur_dir))
                    continue
                output = output[-1]

                output_shape, output_face = read_trimesh(output, clean=False)
                output_shape = torch.from_numpy(output_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                output_face = torch.from_numpy(output_face[:,:3].astype(np.int64)).unsqueeze(0).cuda()

                # chamfer
                if not mse:
                    if resample:
                        output_filename = os.path.basename(output)
                        output_pts_file = os.path.join(cur_dir, "eval_pts", output_filename[:-4]+".pts")
                        output_pts = load(output_pts_file)
                        output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                        dist12, dist21, _, _ = nndistance(target_pts, output_pts)
                    else:
                        dist12, dist21, _, _ = nndistance(target_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1))
                    eval_result[cur_dir][CD_name][filename] = cd
                    eval_result[cur_dir][CD_name]["avg"] += (cd - eval_result[cur_dir][CD_name]["avg"])/(eval_result[cur_dir][CD_name]["cnt"]+1)
                    eval_result[cur_dir][CD_name]["cnt"] += 1
                else:
                    # use a separate name so the `mse` mode flag is not clobbered by the metric value
                    mse_val = torch.sum((output_shape-target_shape)**2, dim=-1).mean().item()
                    eval_result[cur_dir][CD_name][filename] = mse_val
                    eval_result[cur_dir][CD_name]["avg"] += (mse_val - eval_result[cur_dir][CD_name]["avg"])/(eval_result[cur_dir][CD_name]["cnt"]+1)
                    eval_result[cur_dir][CD_name]["cnt"] += 1


                lap = cotLap(output_shape)
                lap_loss = torch.mean((lap-ref_lap)**2).item()
                eval_result[cur_dir]["CotLap"][filename] = lap_loss
                eval_result[cur_dir]["CotLap"]["avg"] += (lap_loss - eval_result[cur_dir]["CotLap"]["avg"])/(eval_result[cur_dir]["CotLap"]["cnt"]+1)
                eval_result[cur_dir]["CotLap"]["cnt"] += 1

                lap_norm = torch.norm(lap, dim=-1)
                lap_norm_loss = torch.mean((lap_norm-ref_lap_norm).abs()).item()
                eval_result[cur_dir]["CotLapNorm"][filename] = lap_norm_loss
                eval_result[cur_dir]["CotLapNorm"]["avg"] += (lap_norm_loss - eval_result[cur_dir]["CotLapNorm"]["avg"])/(eval_result[cur_dir]["CotLapNorm"]["cnt"]+1)
                eval_result[cur_dir]["CotLapNorm"]["cnt"] += 1

                lap = uniLap(output_shape)
                lap_loss = torch.mean((lap-ref_ulap)**2).item()
                eval_result[cur_dir]["UniLap"][filename] = lap_loss
                eval_result[cur_dir]["UniLap"]["avg"] += (lap_loss - eval_result[cur_dir]["UniLap"]["avg"])/(eval_result[cur_dir]["UniLap"]["cnt"]+1)
                eval_result[cur_dir]["UniLap"]["cnt"] += 1

                lap_norm = torch.norm(lap, dim=-1)
                lap_norm_loss = torch.mean((lap_norm-ref_ulap_norm).abs()).item()
                eval_result[cur_dir]["UniLapNorm"][filename] = lap_norm_loss
                eval_result[cur_dir]["UniLapNorm"]["avg"] += (lap_norm_loss - eval_result[cur_dir]["UniLapNorm"]["avg"])/(eval_result[cur_dir]["UniLapNorm"]["cnt"]+1)
                eval_result[cur_dir]["UniLapNorm"]["cnt"] += 1


            print("{}: {}".format(filename.ljust(70), " | ".join(
                ["{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}".format(
                    eval_result[cur_dir][CD_name][filename],
                    eval_result[cur_dir]["CotLap"][filename], eval_result[cur_dir]["CotLapNorm"][filename],
                    eval_result[cur_dir]["UniLap"][filename], eval_result[cur_dir]["UniLapNorm"][filename]
                    )
                for cur_dir in result_dirs]
                ).ljust(30)))
        except Exception as e:
            traceback.print_exc(file=sys.stdout)
            logger.warn("Failed to evaluation {}. Skip to process the next.".format(source_filename))


    print("{}: {}".format("AVG".ljust(70), " | ".join(
        ["{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}".format(eval_result[cur_dir][CD_name]["avg"],
            eval_result[cur_dir]["CotLap"]["avg"], eval_result[cur_dir]["CotLapNorm"]["avg"],
            eval_result[cur_dir]["UniLap"]["avg"], eval_result[cur_dir]["UniLapNorm"]["avg"],
            )
            for cur_dir in result_dirs]
        ).ljust(30)))

    ########## write evaluation ############
    for cur_dir in result_dirs:
        for metric in eval_result[cur_dir]:
            output_file = os.path.join(cur_dir, "eval_{}.txt".format(metric))
            with open(output_file, "w") as eval_file:
                for name, value in eval_result[cur_dir][metric].items():
                    if (name != "avg" and name != "cnt"):
                        eval_file.write("{} {:8.4g}\n".format(name, value))

                eval_file.write("avg {:8.4g}".format(eval_result[cur_dir][metric]["avg"]))
Example #15
0
def evaluate_svr(result_dirs, resample, overwrite_pts=False):
    """ ours is the first in the result dirs """
    if isinstance(result_dirs, str):
        result_dirs = [result_dirs]
    ########## initialize ############
    eval_result = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 1e10)))  # eval_result[folder][metric][file]
    avg_result = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))  # avg_result[folder][metric][stat]

    cotLap = CotLaplacian()
    uniLap = UniformLaplacian()
    if resample and not opt.mse:
        for cur_dir in result_dirs:
            pts_dir = os.path.join(cur_dir, "eval_pts")
            os.makedirs(pts_dir, exist_ok=True)
            result = svr_sample_pts(cur_dir, pts_dir, overwrite_pts)
            if not result:
                logger.warn("Failed to sample points in {}".format(cur_dir))

    ########## load results ###########
    # find Sa.ply, Sb.ply and a list of Sab.ply
    ###################################
    for i, name in enumerate(result_dirs):
        print("dir{}: {}".format(i, name))
    files = find_files(result_dirs[0], ["ply", "obj"])
    target_files = [p for p in files if "Sb." in p]
    target_names = np.unique(np.array([os.path.basename(p).split("-")[1] for p in target_files])).tolist()
    logger.info("Found {} target files".format(len(target_names)))

    ########## evaluation ############
    print("{}: {}".format("filename".ljust(70), " | ".join(["dir{}".format(i).rjust(20) for i in range(len(result_dirs))])))
    print("{}: {}".format(" ".ljust(70), " | ".join(["CD/HD".rjust(20) for i in range(len(result_dirs))])))
    cnt = 0
    for target in target_names:
        # 1. load ground truth
        try:
            gt_path = glob(os.path.join(result_dirs[0], "*-{}-Sb.*".format(target)))[0]
            gt_shape, gt_face = read_trimesh(gt_path, clean=False)
            gt_shape = torch.from_numpy(gt_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
            if resample:
                gt_pts_file = os.path.join(result_dirs[0], "eval_pts", "{}.pts".format(target))
                if not os.path.isfile(gt_pts_file):
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(gt_pts_file))
                    continue
                gt_pts = load(gt_pts_file)
                gt_pts = torch.from_numpy(gt_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()

            ours_paths = glob(os.path.join(result_dirs[0], "*-{}-Sab.*".format(target)))
            others_path = [glob( os.path.join(cur_dir, "{}.*".format(target)) ) for cur_dir in result_dirs[1:]]

            # 2. evaluate ours, all *-{target}-Sab
            if len(ours_paths) == 0:
                logger.warn("Cound\'t find {}. Skip to process the next.".format(os.path.join(result_dirs[0], "*-{}-Sab.*".format(target))))
                continue

            for ours in ours_paths:
                # load shape and points
                output_shape, output_face = read_trimesh(ours, clean=False)
                output_shape = torch.from_numpy(output_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                ours = os.path.basename(ours)
                cur_dir = result_dirs[0]
                if resample:
                    output_pts_file = os.path.join(cur_dir, "eval_pts", ours[:-4]+".pts")
                    if not os.path.isfile(output_pts_file):
                        logger.warn("Cound\'t find {}. Skip to process the next source.".format(output_pts_file))
                        continue
                    output_pts = load(output_pts_file)
                    output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                    # compute chamfer
                    dist12, dist21, _, _ = nndistance(gt_pts, output_pts)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())
                else:
                    dist12, dist21, _, _ = nndistance(gt_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())

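                # keep the best (minimum) CD/HD over all Sab candidates for this target; entries default to 1e10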
                eval_result[cur_dir]["CD"][target] = min(eval_result[cur_dir]["CD"][target], cd)
                avg_result[cur_dir]["CD"]["avg"] += (cd - avg_result[cur_dir]["CD"]["avg"])/(avg_result[cur_dir]["CD"]["cnt"]+1)
                avg_result[cur_dir]["CD"]["cnt"]+=1
                eval_result[cur_dir]["HD"][target] = min(eval_result[cur_dir]["HD"][target], hd)
                avg_result[cur_dir]["HD"]["avg"] += (hd - avg_result[cur_dir]["HD"]["avg"])/(avg_result[cur_dir]["HD"]["cnt"]+1)
                avg_result[cur_dir]["HD"]["cnt"]+=1

            # 3. evaluation others
            for cur_dir in result_dirs[1:]:
                result_path = glob(os.path.join(cur_dir, "{}.*".format(target)))
                if len(result_path) == 0:
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(result_path))
                    continue
                result_path = result_path[0]
                output_shape, output_face = read_trimesh(result_path, clean=False)
                output_shape = torch.from_numpy(output_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                result_name = os.path.splitext(os.path.basename(result_path))[0]
                if resample:
                    output_pts_file = os.path.join(cur_dir, "eval_pts", result_name+".pts")
                    if not os.path.isfile(output_pts_file):
                        logger.warn("Cound\'t find {}. Skip to process the next source.".format(output_pts_file))
                        continue
                    output_pts = load(output_pts_file)
                    output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                    # compute chamfer
                    dist12, dist21, _, _ = nndistance(gt_pts, output_pts)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())
                else:
                    dist12, dist21, _, _ = nndistance(gt_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())

                eval_result[cur_dir]["CD"][target] = min(eval_result[cur_dir]["CD"][target], cd)
                avg_result[cur_dir]["CD"]["avg"] += (cd - avg_result[cur_dir]["CD"]["avg"])/(avg_result[cur_dir]["CD"]["cnt"]+1)
                avg_result[cur_dir]["CD"]["cnt"]+=1
                eval_result[cur_dir]["HD"][target] = min(eval_result[cur_dir]["HD"][target], hd)
                avg_result[cur_dir]["HD"]["avg"] += (hd - avg_result[cur_dir]["HD"]["avg"])/(avg_result[cur_dir]["HD"]["cnt"]+1)
                avg_result[cur_dir]["HD"]["cnt"]+=1

            print("{}: {}".format(target.ljust(70), " | ".join(
                ["{:8.4g}/{:8.4g}".format(
                    eval_result[cur_dir]["CD"][target],
                    eval_result[cur_dir]["HD"][target],
                    )
                for cur_dir in result_dirs]
                ).ljust(30)))
        except Exception as e:
            traceback.print_exc(file=sys.stdout)
            logger.warn("Failed to evaluation {}. Skip to process the next.".format(target))


    print("{}: {}".format("AVG".ljust(70), " | ".join(
        ["{:8.4g}/{:8.4g}".format(
            avg_result[cur_dir]["CD"]["avg"],
            avg_result[cur_dir]["HD"]["avg"],
            )
            for cur_dir in result_dirs]
        ).ljust(30)))

    ########## write evaluation ############
    for cur_dir in result_dirs:
        for metric in eval_result[cur_dir]:
            output_file = os.path.join(cur_dir, "eval_{}.txt".format(metric))
            with open(output_file, "w") as eval_file:
                for name, value in eval_result[cur_dir][metric].items():
                    if (name != "avg" and name != "cnt"):
                        eval_file.write("{} {:8.4g}\n".format(name, value))

                eval_file.write("avg {:8.4g}".format(eval_result[cur_dir][metric]["avg"]))