def run(in_files, out_dir, ref_file):
    ref_v, ref_f = read_trimesh(ref_file)
    ref_v = torch.from_numpy(ref_v[:, :3]).float()
    ref_f = torch.from_numpy(ref_f).long()
    _, ref_area = compute_face_normals_and_areas(ref_v, ref_f)
    ref_area = torch.sum(ref_area, dim=-1)
    for in_file in in_files:
        v, f = read_trimesh(in_file)
        v = torch.from_numpy(v[:, :3]).float()
        f = torch.from_numpy(f).long()
        v, _, _ = center_bounding_box(v)
        _, area = compute_face_normals_and_areas(v, f)
        area = torch.sum(area, dim=-1)
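        # isotropic rescale so the total surface area matches the reference:
        # scaling vertices by s scales areas by s**2, hence s = sqrt(A_ref / A)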
        ratio = torch.sqrt(ref_area/area)
        ratio = ratio.unsqueeze(-1).unsqueeze(-1)
        v = v * ratio
        out_path = os.path.join(out_dir, os.path.basename(in_file))
        pymesh.save_mesh_raw(out_path, v.numpy(), f.numpy())
        print("saved to {}".format(out_path))
Example #2
def test(net=None, subdir="test"):
    opt.phase = "test"
    if isinstance(opt.target_model, str):
        opt.target_model = [opt.target_model]

    if net is None:
        states = torch.load(opt.ckpt)
        if "states" in states:
            states = states["states"]
        if opt.template:
            cage_shape, cage_face = read_trimesh(opt.template)
            cage_shape = torch.from_numpy(
                cage_shape[:, :3]).unsqueeze(0).float()
            cage_face = torch.from_numpy(cage_face).unsqueeze(0).long()
            states["template_vertices"] = cage_shape.transpose(1, 2)
            states["template_faces"] = cage_face

        if opt.source_model:
            source_shape, source_face = read_trimesh(opt.source_model)
            source_shape = torch.from_numpy(
                source_shape[:, :3]).unsqueeze(0).float()
            source_face = torch.from_numpy(source_face).unsqueeze(0).long()
            states["source_vertices"] = source_shape.transpose(1, 2)
            states["source_faces"] = source_shape

        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=states["template_vertices"],
            template_faces=states["template_faces"],
            source_vertices=states["source_vertices"],
            source_faces=states["source_faces"]).cuda()

        load_network(net, states)
        net = net.cuda()
        net.eval()
    else:
        net.eval()

    print(net)

    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with torch.no_grad():
        for target_model in opt.target_model:
            assert os.path.isfile(target_model)
            target_shape, target_face = read_trimesh(target_model)
            # target_shape = read_ply(target_model)[:,:3]
            # target_shape, _, scale = normalize_to_box(target_shape)
            # normalize according to height (y axis)
            # target_shape = target_shape/2*1.7
            target_shape = torch.from_numpy(
                target_shape[:, :3]).cuda().float().unsqueeze(0)
            if target_face is None:
                target_face = net.source_faces
            else:
                target_face = torch.from_numpy(
                    target_face).cuda().long().unsqueeze(0)
            t_filename = os.path.splitext(os.path.basename(target_model))[0]

            source_mesh = net.source_vertices.transpose(1, 2).detach()
            source_face = net.source_faces.detach()

            # furthest sampling
            target_shape_sampled = furthest_point_sample(
                target_shape, net.source_vertices.shape[2], NCHW=False)[1]
            # target_shape_sampled = (target_shape[:, np.random.permutation(target_shape.shape[1]), :]).contiguous()
            outputs = net(target_shape_sampled.transpose(1, 2),
                          None,
                          cage_only=True)
            # deformed = outputs["deformed"]

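            # warp the full-resolution source mesh with mean value coordinates
            # (MVC) w.r.t. the predicted cage pair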
            deformed = deform_with_MVC(
                outputs["cage"], outputs["new_cage"],
                outputs["cage_face"].expand(outputs["cage"].shape[0], -1,
                                            -1), source_mesh)

            b = 0

            save_ply(
                target_shape_sampled[b].cpu().numpy(),
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.pts".format(t_filename)))
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sa.ply".format(t_filename)),
                source_mesh[0].detach().cpu(), source_face[0].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.ply".format(t_filename)),
                target_shape[b].detach().cpu(), target_face[b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sab.ply".format(t_filename)),
                deformed[b].detach().cpu(), source_face[b].detach().cpu())

            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage1.ply".format(t_filename)),
                outputs["cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage2.ply".format(t_filename)),
                outputs["new_cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())

    PairedSurreal.render_result(test_output_dir)
Example #3
def optimize(opt):
    """
    weights are the same with the original source mesh
    target=net(old_source)
    """
    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    target_shape_arr = target_mesh.points()
    target_shape = target_shape_arr.copy()
    target_shape = torch.from_numpy(
        target_shape[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    new_label_path = opt.model.replace(os.path.splitext(opt.model)[1], ".picked")
    orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
    if os.path.isfile(new_label_path) and os.path.isfile(orig_label_path):
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        new_label_name = new_label.iloc[:,5].tolist()
        new_to_orig_idx = []
        for i, name in enumerate(new_label_name):
            matched_idx = orig_label_name[orig_label_name==name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        new_to_orig_idx = np.array(new_to_orig_idx)
        if new_label.shape[1] == 10:
            new_vidx = new_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            new_label_points = torch.from_numpy(new_label.iloc[:,6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2) # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            new_label.to_csv(new_label_path, sep=" ", header=[str(new_label.shape[0])]+[""]*(new_label.shape[1]-1), index=False)
            target_points = target_points[:, new_to_orig_idx[:,0], :]

        target_points = target_points.cuda()
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, new_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            orig_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            orig_label.to_csv(orig_label_path, sep=" ", header=[str(orig_label.shape[0])]+[""]*(orig_label.shape[1]-1), index=False)
            source_points = source_points[:,new_to_orig_idx[:,1],:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
        # # shift target so that the belly match
        # try:
        #     orig_bellyUp_idx = orig_label_name[orig_label_name=="bellUp"].index[0]
        #     orig_bellyUp = orig_label_points[orig_bellyUp_idx, :]
        #     new_bellyUp_idx = [i for i, i2 in new_to_orig_idx if i2==orig_bellyUp_idx][0]
        #     new_bellyUp = new_label_points[new_bellyUp_idx,:]
        #     target_points += (orig_bellyUp - new_bellyUp)
        # except Exception as e:
        #     logger.warn("Couldn\'t match belly to belly")
        #     traceback.print_exc(file=sys.stdout)

        # source_points[0] = center_bounding_box(source_points[0])[0]
    elif not os.path.isfile(new_label_path) and os.path.isfile(orig_label_path):
        logger.info("Assuming Faust model")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:,9].to_numpy()).long()
            source_points = source_shape[:,idx,:]
            target_points = target_shape[:,idx,:]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            idx = idx.squeeze(-1)
            target_points = target_shape[:,idx,:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        # corresp_idx = torch.randint(0, shape_f.shape[1], (100,)).cuda()
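        # pick up to 4800 unique random vertices as pseudo-correspondences
        # (valid because registered meshes share the same vertex ordering)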
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)

    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
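    # index 1 selects the y (height) extent, giving one isotropic scale factor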
    target_scale_factor = (source_scale/target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points*target_scale_factor).detach()
    # make sure test use the normalized
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    normalized_path = os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0]+"_normalized.obj")
    om.write_mesh(normalized_path, target_mesh)
    opt.model = normalized_path
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(
        target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
                                 use_norm=True, consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
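    # regularizers: keep the cage Laplacian close to its initial value and
    # penalize extreme mean-value-coordinate weights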
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(opt.nepochs*0.4), gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        weights_ref = mean_value_coordinates_3D(
            source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError

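    # optimize cage vertices so the target points' MVC weights w.r.t. the
    # moving cage match the source points' reference weights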
    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(
            target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights-weights_ref)**2)
        # reg = torch.sum((cage_init-cage_v)**2, dim=-1)*1e-4
        reg = 0
        if opt.clap_weight > 0:
            reg = lap_loss(cage_init, cage_v, face=cage_f)*opt.clap_weight
            reg = reg.mean()
        if opt.mvc_weight > 0:
            reg += mvc_reg_loss(weights)*opt.mvc_weight

        # weight regularizer with the shape difference
        # dist = torch.sum((source_points - target_points)**2, dim=-1)
        # weights = torch.exp(-dist)
        # reg = reg*weights*0.1

        loss = loss_mvc + reg
        if (t+1) % 50 == 0:
            print("t {}/{} mvc_loss: {} reg: {}".format(t,
                                                        opt.nepochs, loss_mvc.item(), reg.item()))

        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()

    return cage_v, cage_f
Example #4

if __name__ == "__main__":
    parser = MyOptions()
    opt = parser.parse()

    opt.log_dir = os.path.dirname(opt.ckpt)

    os.makedirs(os.path.join(opt.log_dir, opt.subdir), exist_ok=True)
    if opt.use_cage is None:
        # optimize initial cage for the new target
        cage_v, cage_f = optimize(opt)
        pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "optimized_template_cage.ply"),
                             cage_v[0].detach().cpu(), cage_f[0].detach().cpu())
    else:
        cage_v, cage_f = read_trimesh(opt.use_cage)
        cage_v = torch.from_numpy(cage_v[:, :3].astype(np.float32)).cuda()
        cage_f = torch.from_numpy(cage_f[:, :3].astype(np.int64)).cuda()
        cage_v.unsqueeze_(0)
        cage_f.unsqueeze_(0)

    # # test using the new source and initial cage
    # target_shape_pose, target_face_pose, _ = read_trimesh("/home/mnt/points/data/MPI-FAUST/training/registrations/tr_reg_002.ply")
    # target_shape_pose = torch.from_numpy(target_shape_pose[:,:3].astype(np.float32)).cuda()
    # target_face_pose = torch.from_numpy(target_face_pose[:,:3].astype(np.int64)).cuda()
    # target_shape_pose, _, _ = center_bounding_box(target_shape_pose)
    # target_shape_pose.unsqueeze_(0)
    # target_face_pose.unsqueeze_(0)
    # test_one(opt, cage_v, target_shape, target_face, target_shape_pose, target_face_pose)
    test_all(opt, cage_v)
Example #5
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR + "/..")
from common import read_trimesh

work_dir = sys.argv[1]
output_dir = sys.argv[2]
sources = glob(os.path.join(work_dir, "*Sa.*"))
logger.info("Found {} source files".format(len(sources)))

os.makedirs(output_dir, exist_ok=True)

for source in sources:
    target = source.replace("Sa", "Sb")
    fn = os.path.basename(target)
    fn = fn.replace("Sb", "Sab")
    V_t, F_t = read_trimesh(target, clean=True)
    V_t = V_t[:,:3]
    V_s, F_s = read_trimesh(source, clean=True)
    V_s = V_s[:,:3]

    bb_max = np.max(V_t, axis=0)
    bb_min = np.min(V_t, axis=0)
    size_t = (bb_max - bb_min)

    bb_max = np.max(V_s, axis=0)
    bb_min = np.min(V_s, axis=0)
    size_s = (bb_max - bb_min)

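    # per-axis (anisotropic) scale mapping the source bounding box onto the target's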
    V_st = (V_s * size_t/size_s)
    # save the bbox-rescaled source as the Sab baseline
    pymesh.save_mesh(os.path.join(output_dir, fn), pymesh.form_mesh(V_st, F_s))
    pymesh.save_mesh(os.path.join(output_dir, fn.replace("Sab", "Sa")), pymesh.form_mesh(V_s, F_s))
    pymesh.save_mesh(os.path.join(output_dir, fn.replace("Sab", "Sb")), pymesh.form_mesh(V_t, F_t))
Example #6
def evaluate_deformation(result_dirs, resample, mse=False, overwrite_pts=False):
    CD_name = "MSE" if mse else "CD"
    if isinstance(result_dirs, str):
        result_dirs = [result_dirs]
    ########## initialize ############
    eval_result = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))  # eval_result[folder][metric][file]
    cotLap = CotLaplacian()
    uniLap = UniformLaplacian()
    if resample:
        for cur_dir in result_dirs:
            pts_dir = os.path.join(cur_dir, "eval_pts")
            os.makedirs(pts_dir, exist_ok=True)
            result = sample_pts(cur_dir, pts_dir, overwrite_pts)
            if not result:
                logger.warn("Failed to sample points in {}".format(cur_dir))

    ########## load results ###########
    # find Sa.ply, Sb.ply and a list of Sab.ply
    ###################################
    for i, name in enumerate(result_dirs):
        print("dir{}: {}".format(i, name))
    files = glob(os.path.join(result_dirs[0], "*.ply"))+glob(os.path.join(result_dirs[0], "*.obj"))
    source_files = [p for p in files if "Sa." in p]
    target_files = [p.replace("Sa", "Sb") for p in source_files]
    assert all(os.path.isfile(f) for f in target_files)
    logger.info("Found {} source target pairs".format(len(source_files)))
    ########## evaluation ############
    print("{}: {}".format("filename".ljust(70), " | ".join(["dir{}".format(i).rjust(45) for i in range(len(result_dirs))])))
    print("{}: {}".format(" ".ljust(70), " | ".join([(CD_name+"/CotLap/CotLapNorm/UniLap/UniLapNorm").rjust(45) for i in range(len(result_dirs))])))
    for source, target in zip(source_files, target_files):
        source_filename = os.path.basename(source)
        target_filename = os.path.basename(target)
        try:
            if resample:
                source_pts_file = os.path.join(result_dirs[0], "eval_pts", source_filename[:-4]+".pts")
                target_pts_file = os.path.join(result_dirs[0], "eval_pts", target_filename[:-4]+".pts")
                if not os.path.isfile(source_pts_file):
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(source_pts_file))
                    continue
                if not os.path.isfile(target_pts_file):
                    logger.warn("Couldn't find {}. Skipping to the next.".format(target_pts_file))
                    continue
                source_pts = load(source_pts_file)
                target_pts = load(target_pts_file)
                source_pts = torch.from_numpy(source_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                target_pts = torch.from_numpy(target_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()

            ext = os.path.splitext(source_filename)[1]
            sab_str = source_filename.replace("Sa"+ext, "Sab*")
            outputs = [glob( os.path.join(cur_dir, sab_str) ) for cur_dir in result_dirs]
            if not all([len(o) > 0 for o in outputs]):
                logger.warn("Couldn\'t find {} in all folders, skipping to process the next".format(sab_str))
                continue

            # read Sa, Sb
            source_shape, source_face = read_trimesh(source, clean=False)
            target_shape, _ = read_trimesh(target, clean=False)
            source_shape = torch.from_numpy(source_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
            target_shape = torch.from_numpy(target_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
            source_face = torch.from_numpy(source_face[:,:3].astype(np.int64)).unsqueeze(0).cuda()

            # laplacian for source (fixed)
            cotLap.L = None
            ref_lap = cotLap(source_shape, source_face)
            ref_lap_norm = torch.norm(ref_lap, dim=-1)

            uniLap.L = None
            ref_ulap = uniLap(source_shape, source_face)
            ref_ulap_norm = torch.norm(ref_ulap, dim=-1)

            filename = os.path.splitext(os.path.basename(source))[0]
            for output, cur_dir in zip(outputs, result_dirs):
                if len(output)>1:
                    logger.warn("Found multiple outputs {}. Using the last one".format(output))
                if len(output) == 0:
                    logger.warn("Found no outputs for {} in {}".format(sab_str, cur_dir))
                    continue
                output = output[-1]

                output_shape, output_face = read_trimesh(output, clean=False)
                output_shape = torch.from_numpy(output_shape[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                output_face = torch.from_numpy(output_face[:,:3].astype(np.int64)).unsqueeze(0).cuda()

                # chamfer
                if not mse:
                    if resample:
                        output_filename = os.path.basename(output)
                        output_pts_file = os.path.join(cur_dir, "eval_pts", output_filename[:-4]+".pts")
                        output_pts = load(output_pts_file)
                        output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                        dist12, dist21, _, _ = nndistance(target_pts, output_pts)
                    else:
                        dist12, dist21, _, _ = nndistance(target_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1))
                    eval_result[cur_dir][CD_name][filename] = cd
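                    # incremental mean: avg <- avg + (x - avg) / (n + 1)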
                    eval_result[cur_dir][CD_name]["avg"] += (cd - eval_result[cur_dir][CD_name]["avg"])/(eval_result[cur_dir][CD_name]["cnt"]+1)
                    eval_result[cur_dir][CD_name]["cnt"] += 1
                else:
                    # per-vertex MSE (output and target share vertex ordering)
                    mse_val = torch.sum((output_shape-target_shape)**2, dim=-1).mean().item()
                    eval_result[cur_dir][CD_name][filename] = mse_val
                    eval_result[cur_dir][CD_name]["avg"] += (mse_val - eval_result[cur_dir][CD_name]["avg"])/(eval_result[cur_dir][CD_name]["cnt"]+1)
                    eval_result[cur_dir][CD_name]["cnt"] += 1


                lap = cotLap(output_shape)
                lap_loss = torch.mean((lap-ref_lap)**2).item()
                eval_result[cur_dir]["CotLap"][filename] = lap_loss
                eval_result[cur_dir]["CotLap"]["avg"] += (lap_loss - eval_result[cur_dir]["CotLap"]["avg"])/(eval_result[cur_dir]["CotLap"]["cnt"]+1)
                eval_result[cur_dir]["CotLap"]["cnt"] += 1

                lap_norm = torch.norm(lap, dim=-1)
                lap_norm_loss = torch.mean((lap_norm-ref_lap_norm).abs()).item()
                eval_result[cur_dir]["CotLapNorm"][filename] = lap_norm_loss
                eval_result[cur_dir]["CotLapNorm"]["avg"] += (lap_norm_loss - eval_result[cur_dir]["CotLapNorm"]["avg"])/(eval_result[cur_dir]["CotLapNorm"]["cnt"]+1)
                eval_result[cur_dir]["CotLapNorm"]["cnt"] += 1

                lap = uniLap(output_shape)
                lap_loss = torch.mean((lap-ref_ulap)**2).item()
                eval_result[cur_dir]["UniLap"][filename] = lap_loss
                eval_result[cur_dir]["UniLap"]["avg"] += (lap_loss - eval_result[cur_dir]["UniLap"]["avg"])/(eval_result[cur_dir]["UniLap"]["cnt"]+1)
                eval_result[cur_dir]["UniLap"]["cnt"] += 1

                lap_norm = torch.norm(lap, dim=-1)
                lap_norm_loss = torch.mean((lap_norm-ref_ulap_norm).abs()).item()
                eval_result[cur_dir]["UniLapNorm"][filename] = lap_norm_loss
                eval_result[cur_dir]["UniLapNorm"]["avg"] += (lap_norm_loss - eval_result[cur_dir]["UniLapNorm"]["avg"])/(eval_result[cur_dir]["UniLapNorm"]["cnt"]+1)
                eval_result[cur_dir]["UniLapNorm"]["cnt"] += 1


            print("{}: {}".format(filename.ljust(70), " | ".join(
                ["{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}".format(
                    eval_result[cur_dir][CD_name][filename],
                    eval_result[cur_dir]["CotLap"][filename], eval_result[cur_dir]["CotLapNorm"][filename],
                    eval_result[cur_dir]["UniLap"][filename], eval_result[cur_dir]["UniLapNorm"][filename]
                    )
                for cur_dir in result_dirs]
                ).ljust(30)))
        except Exception as e:
            traceback.print_exc(file=sys.stdout)
            logger.warn("Failed to evaluation {}. Skip to process the next.".format(source_filename))


    print("{}: {}".format("AVG".ljust(70), " | ".join(
        ["{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}/{:8.4g}".format(eval_result[cur_dir][CD_name]["avg"],
            eval_result[cur_dir]["CotLap"]["avg"], eval_result[cur_dir]["CotLapNorm"]["avg"],
            eval_result[cur_dir]["UniLap"]["avg"], eval_result[cur_dir]["UniLapNorm"]["avg"],
            )
            for cur_dir in result_dirs]
        ).ljust(30)))

    ########## write evaluation ############
    for cur_dir in result_dirs:
        for metric in eval_result[cur_dir]:
            output_file = os.path.join(cur_dir, "eval_{}.txt".format(metric))
            with open(output_file, "w") as eval_file:
                for name, value in eval_result[cur_dir][metric].items():
                    if (name != "avg" and name != "cnt"):
                        eval_file.write("{} {:8.4g}\n".format(name, value))

                eval_file.write("avg {:8.4g}".format(eval_result[cur_dir][metric]["avg"]))
Example #7
def evaluate_svr(result_dirs, resample, overwrite_pts=False):
    """ ours is the first in the result dirs """
    if isinstance(result_dirs, str):
        result_dirs = [result_dirs]
    ########## initialize ############
    eval_result = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 1e10)))  # eval_result[folder][metric][file], keeps the best (min) value
    avg_result = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))  # avg_result[folder][metric]["avg"|"cnt"]

    cotLap = CotLaplacian()
    uniLap = UniformLaplacian()
    if resample and not opt.mse:
        for cur_dir in result_dirs:
            pts_dir = os.path.join(cur_dir, "eval_pts")
            os.makedirs(pts_dir, exist_ok=True)
            result = svr_sample_pts(cur_dir, pts_dir, overwrite_pts)
            if not result:
                logger.warn("Failed to sample points in {}".format(cur_dir))

    ########## load results ###########
    # find Sa.ply, Sb.ply and a list of Sab.ply
    ###################################
    for i, name in enumerate(result_dirs):
        print("dir{}: {}".format(i, name))
    files = find_files(result_dirs[0], ["ply", "obj"])
    target_files = [p for p in files if "Sb." in p]
    target_names = np.unique(np.array([os.path.basename(p).split("-")[1] for p in target_files])).tolist()
    logger.info("Found {} target files".format(len(target_names)))

    ########## evaluation ############
    print("{}: {}".format("filename".ljust(70), " | ".join(["dir{}".format(i).rjust(20) for i in range(len(result_dirs))])))
    print("{}: {}".format(" ".ljust(70), " | ".join(["CD/HD".rjust(20) for i in range(len(result_dirs))])))
    for target in target_names:
        # 1. load ground truth
        gt_path = glob(os.path.join(result_dirs[0], "*-{}-Sb.*".format(target)))[0]
        try:
            gt_shape, gt_face = read_trimesh(gt_path, clean=False)
            gt_shape = torch.from_numpy(gt_shape[:, :3].astype(np.float32)).unsqueeze(0).cuda()
            if resample:
                gt_pts_file = os.path.join(result_dirs[0], "eval_pts", "{}.pts".format(target))
                if not os.path.isfile(gt_pts_file):
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(gt_pts_file))
                    continue
                gt_pts = load(gt_pts_file)
                gt_pts = torch.from_numpy(gt_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()

            ours_paths = glob(os.path.join(result_dirs[0], "*-{}-Sab.*".format(target)))

            # 2. evaluate ours, all *-{target}-Sab
            if len(ours_paths) == 0:
                logger.warn("Cound\'t find {}. Skip to process the next.".format(os.path.join(result_dirs[0], "*-{}-Sab.*".format(target))))
                continue

            for ours in ours_paths:
                # load shape and points
                output_shape, output_face = read_trimesh(ours, clean=False)
                output_shape = torch.from_numpy(output_shape[:, :3].astype(np.float32)).unsqueeze(0).cuda()
                ours = os.path.basename(ours)
                cur_dir = result_dirs[0]
                if resample:
                    output_pts_file = os.path.join(cur_dir, "eval_pts", ours[:-4]+".pts")
                    if not os.path.isfile(output_pts_file):
                        logger.warn("Cound\'t find {}. Skip to process the next source.".format(output_pts_file))
                        continue
                    output_pts = load(output_pts_file)
                    output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                    # compute chamfer
                    dist12, dist21, _, _ = nndistance(gt_pts, output_pts)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
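                    # symmetric Hausdorff distance between the two point sets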
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())
                else:
                    dist12, dist21, _, _ = nndistance(gt_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())

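                # keep the best (minimum) CD/HD over all candidate outputs for this target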
                eval_result[cur_dir]["CD"][target] = min(eval_result[cur_dir]["CD"][target], cd)
                avg_result[cur_dir]["CD"]["avg"] += (cd - avg_result[cur_dir]["CD"]["avg"])/(avg_result[cur_dir]["CD"]["cnt"]+1)
                avg_result[cur_dir]["CD"]["cnt"]+=1
                eval_result[cur_dir]["HD"][target] = min(eval_result[cur_dir]["HD"][target], hd)
                avg_result[cur_dir]["HD"]["avg"] += (hd - avg_result[cur_dir]["HD"]["avg"])/(avg_result[cur_dir]["HD"]["cnt"]+1)
                avg_result[cur_dir]["HD"]["cnt"]+=1

            # 3. evaluation others
            for cur_dir in result_dirs[1:]:
                result_path = glob(os.path.join(cur_dir, "{}.*".format(target)))
                if len(result_path) == 0:
                    logger.warn("Cound\'t find {}. Skip to process the next.".format(result_path))
                    continue
                result_path = result_path[0]
                output_shape, output_face = read_trimesh(result_path, clean=False)
                output_shape = torch.from_numpy(output_shape[:, :3].astype(np.float32)).unsqueeze(0).cuda()
                result_name = os.path.splitext(os.path.basename(result_path))[0]
                if resample:
                    output_pts_file = os.path.join(cur_dir, "eval_pts", result_name+".pts")
                    if not os.path.isfile(output_pts_file):
                        logger.warn("Cound\'t find {}. Skip to process the next source.".format(output_pts_file))
                        continue
                    output_pts = load(output_pts_file)
                    output_pts = torch.from_numpy(output_pts[:,:3].astype(np.float32)).unsqueeze(0).cuda()
                    # compute chamfer
                    dist12, dist21, _, _ = nndistance(gt_pts, output_pts)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())
                else:
                    dist12, dist21, _, _ = nndistance(gt_shape, output_shape)
                    cd = torch.mean(torch.mean(dist12, dim=-1) + torch.mean(dist21, dim=-1)).item()
                    hd = max(torch.max(dist12).item(), torch.max(dist21).item())

                eval_result[cur_dir]["CD"][target] = min(eval_result[cur_dir]["CD"][target], cd)
                avg_result[cur_dir]["CD"]["avg"] += (cd - avg_result[cur_dir]["CD"]["avg"])/(avg_result[cur_dir]["CD"]["cnt"]+1)
                avg_result[cur_dir]["CD"]["cnt"]+=1
                eval_result[cur_dir]["HD"][target] = min(eval_result[cur_dir]["HD"][target], hd)
                avg_result[cur_dir]["HD"]["avg"] += (hd - avg_result[cur_dir]["HD"]["avg"])/(avg_result[cur_dir]["HD"]["cnt"]+1)
                avg_result[cur_dir]["HD"]["cnt"]+=1

            print("{}: {}".format(target.ljust(70), " | ".join(
                ["{:8.4g}/{:8.4g}".format(
                    eval_result[cur_dir]["CD"][target],
                    eval_result[cur_dir]["HD"][target],
                    )
                for cur_dir in result_dirs]
                ).ljust(30)))
        except Exception as e:
            traceback.print_exc(file=sys.stdout)
            logger.warn("Failed to evaluation {}. Skip to process the next.".format(target))


    print("{}: {}".format("AVG".ljust(70), " | ".join(
        ["{:8.4g}/{:8.4g}".format(
            avg_result[cur_dir]["CD"]["avg"],
            avg_result[cur_dir]["HD"]["avg"],
            )
            for cur_dir in result_dirs]
        ).ljust(30)))

    ########## write evaluation ############
    for cur_dir in result_dirs:
        for metric in eval_result[cur_dir]:
            output_file = os.path.join(cur_dir, "eval_{}.txt".format(metric))
            with open(output_file, "w") as eval_file:
                for name, value in eval_result[cur_dir][metric].items():
                    if (name != "avg" and name != "cnt"):
                        eval_file.write("{} {:8.4g}\n".format(name, value))

                eval_file.write("avg {:8.4g}".format(eval_result[cur_dir][metric]["avg"]))