示例#1
0
def dist_triangle_probs_to_sampled_surface(surf_pos, surf_normals, gen_verts, gen_faces, gen_face_probs, n_sample_pts=5000):
    """Expected distance from a probabilistic triangulation to a sampled surface.

    Points are sampled on the generated triangles. Each sample is weighted by the
    probability of the face it was drawn from and measured against its nearest
    reference surface sample. When normals are provided, the distance is taken
    along the neighbor's normal (point-to-plane); otherwise the k-NN distance is
    used directly.

    Args:
        surf_pos: positions of the reference surface samples.
        surf_normals: normals matching ``surf_pos``. Pass an empty sequence or
            tensor to use plain nearest-neighbor distances.
        gen_verts: vertex positions of the generated triangulation.
        gen_faces: triangle vertex indices of the generated triangulation.
        gen_face_probs: per-face probabilities, indexed by face id.
        n_sample_pts: number of points sampled on the generated triangulation.

    Returns:
        Scalar tensor: the mean of the probability-weighted distances, divided
        by a characteristic length of ``surf_pos``. Returns zero (on
        ``surf_pos``'s device) when ``gen_faces`` is empty.
    """
    if gen_faces.shape[0] == 0:
        return torch.tensor(0., device=surf_pos.device)

    # Characteristic length: mean of utils.norm over the centered surface samples
    # (presumably the per-point distance to the centroid).
    char_len = utils.norm(surf_pos - torch.mean(surf_pos, dim=0, keepdim=True)).mean()

    # Sample points on the generated triangulation, remembering the source face
    samples, face_inds, _ = mesh_utils.sample_points_on_surface(
        gen_verts, gen_faces, n_sample_pts, return_inds_and_bary=True)

    # Likelihood associated with each sample = probability of its source face
    point_probs = gen_face_probs[face_inds]

    # Nearest reference sample for every generated sample
    knn_dist, neigh = knn.find_knn(samples, surf_pos, k=1)
    neigh_inds = neigh.squeeze(1)

    if len(surf_normals) == 0:
        dists = knn_dist
    else:
        # Point-to-plane distance using the neighbor's normal
        neigh_pos = surf_pos[neigh_inds, :]
        neigh_normal = surf_normals[neigh_inds, :]
        vecs = neigh_pos - samples
        dists = torch.abs(utils.dot(vecs, neigh_normal))

    # One distance per sample. The k-NN output presumably carries a trailing
    # k=1 axis (neigh is squeezed on dim 1 above). Multiplying an (N, 1)
    # tensor by the (N,) point_probs would broadcast to (N, N), so flatten
    # first. This is a no-op when dists is already 1-D.
    dists = dists.reshape(-1)

    # Expected distance integral
    exp_dist = torch.mean(dists * point_probs)

    return exp_dist / char_len
示例#2
0
def dist_surface_to_triangle_probs(gen_verts, gen_faces, gen_face_probs, n_sample_pts=None, mesh=None, surf_samples=None):
    """Expected distance from a reference surface to a probabilistic triangulation.

    For every reference sample, the distance to each generated triangle is
    computed and the nearest ``k = min(32, #faces)`` triangles are kept. Triangle
    presence is treated as independent: the probability that the j-th nearest
    triangle is the closest present one is ``p_j * prod_{i<j} (1 - p_i)``. Any
    remaining probability mass is assigned a distance of ``char_len``, so samples
    with no likely triangle nearby are penalized.

    Exactly one source of reference samples must be supplied: either
    ``surf_samples`` directly, or ``mesh=(verts, faces)`` together with
    ``n_sample_pts`` to sample it.

    Args:
        gen_verts: vertex positions of the generated triangulation.
        gen_faces: triangle vertex indices of the generated triangulation.
        gen_face_probs: per-face probabilities, indexed by face id.
        n_sample_pts: number of points to sample from ``mesh``.
        mesh: optional ``(surf_verts, surf_faces)`` reference mesh.
        surf_samples: optional precomputed reference sample positions.

    Returns:
        Scalar tensor: the mean over samples of the expected distance, divided
        by a characteristic length of the samples. Returns zero (on
        ``gen_verts``'s device) when ``gen_faces`` is empty.

    Raises:
        ValueError: if both or neither of ``surf_samples`` / ``mesh`` are given,
            or if ``n_sample_pts`` is combined with ``surf_samples``.
    """
    if gen_faces.shape[0] == 0:
        return torch.tensor(0., device=gen_verts.device)

    if surf_samples is not None:
        if mesh is not None or n_sample_pts is not None:
            raise ValueError("bad args!")
    elif mesh is None:
        raise ValueError("bad args! pass either surf_samples or mesh")
    else:
        surf_verts, surf_faces = mesh
        # Sample points on the known surface
        # NOTE(review): n_sample_pts may still be None here; behavior then depends
        # on mesh_utils.sample_points_on_surface — confirm.
        surf_samples = mesh_utils.sample_points_on_surface(surf_verts, surf_faces, n_sample_pts)

    # Characteristic length: mean of utils.norm over the centered samples
    char_len = utils.norm(surf_samples - torch.mean(surf_samples, dim=0, keepdim=True)).mean()

    # Find the distance to all triangles in the generated surface
    tri_dists = mesh_utils.point_triangle_distance(surf_samples, gen_verts, gen_faces)

    # Keep the k nearest triangles per sample, in ascending distance order
    k_val = min(32, tri_dists.shape[-1])
    tri_dists_sorted, sorted_inds = torch.topk(tri_dists, largest=False, k=k_val, dim=-1)

    # Compute the likelihood that each triangle is the nearest present one for that sample
    sorted_probs = gen_face_probs[sorted_inds]

    prob_none_closer = torch.cat((  # shift to the right, put 1 in first col
        torch.ones_like(sorted_probs)[:, :1],
        torch.cumprod(1. - sorted_probs, dim=-1)[:, :-1],
    ), dim=-1)

    prob_is_closest = prob_none_closer * sorted_probs

    # Append a last distance very far away, so you get high loss values if nothing is close
    last_prob = 1.0 - torch.sum(prob_is_closest, dim=-1)
    last_dist = char_len * torch.ones(tri_dists.shape[0], dtype=tri_dists.dtype, device=tri_dists.device)
    prob_is_closest = torch.cat((prob_is_closest, last_prob.unsqueeze(-1)), dim=-1)
    prob_is_closest = torch.clamp(prob_is_closest, 0., 1.)  # for floating point reasons
    tri_dists_sorted = torch.cat((tri_dists_sorted, last_dist.unsqueeze(-1)), dim=-1)

    # Use these likelihoods to get expected distance
    expected_dist = torch.sum(prob_is_closest * tri_dists_sorted, dim=-1)

    result = torch.mean(expected_dist / char_len)
    return result
示例#3
0
def overlap_kernel(gen_verts, gen_faces, gen_face_probs, n_sample_pts=5000):
    """Penalize samples that are not covered by exactly one triangle.

    Points are sampled on the generated triangulation, with
    ``gen_face_probs`` passed to the sampler as ``face_probs``. Each sample is
    evaluated against every triangle with ``mesh_utils.triangle_kernel``, and the
    kernel values are scaled by the triangle probabilities. An ideal sample sees
    a single contribution of 1 and zeros elsewhere. Both the summed contribution
    and the largest single contribution are therefore pushed toward 1.

    Args:
        gen_verts: vertex positions of the generated triangulation.
        gen_faces: triangle vertex indices of the generated triangulation.
        gen_face_probs: per-face probabilities, indexed by face id.
        n_sample_pts: number of points to sample.

    Returns:
        Scalar tensor: the mean squared deviation score. Returns zero (on
        ``gen_verts``'s device) when there are no faces.
    """
    no_faces = gen_faces.shape[0] == 0
    if no_faces:
        return torch.tensor(0., device=gen_verts.device)

    # Draw points on the generated triangles; face indices / barycentrics are unused
    pts, _, _ = mesh_utils.sample_points_on_surface(
        gen_verts,
        gen_faces,
        n_sample_pts,
        face_probs=gen_face_probs,
        return_inds_and_bary=True,
    )

    # Kernel value of every sample against every triangle, weighted by triangle probability
    kvals = mesh_utils.triangle_kernel(pts, gen_verts, gen_faces, kernel_height=0.5)
    weighted = kvals * gen_face_probs.unsqueeze(0)

    total = weighted.sum(dim=-1)
    peak = weighted.max(dim=-1).values

    # Averaging over the samples acts as a normalization by expected surface area
    return ((total - 1.) ** 2 + (peak - 1.) ** 2).mean()
def main():
    """Command-line entry point: predict a triangulation for an input point set.

    Loads the input, which is one of:

    - a ``.npz`` record with ``vert_pos`` and ``surf_pos`` arrays;
    - a ``.xyz`` whitespace-separated point list;
    - any mesh readable by ``utils.read_mesh``, optionally resampled with
      ``--sample_cloud``.

    Then runs ``PointTriNet_Mesher.predict_mesh``, keeps faces whose probability
    exceeds ``--prob_thresh``, and greedily fills holes. The result is either
    shown in polyscope or, when ``--output`` is given, written to ``.ply`` files
    with the viewer disabled.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument('model_weights_path',
                        type=str,
                        help='path to the model checkpoint')
    parser.add_argument('input_path', type=str, help='path to the input')

    parser.add_argument('--disable_cuda',
                        action='store_true',
                        help='disable cuda')

    parser.add_argument('--sample_cloud',
                        type=int,
                        help='run on sampled points')

    parser.add_argument('--n_rounds',
                        type=int,
                        default=5,
                        help='number of rounds')
    parser.add_argument('--prob_thresh',
                        type=float,
                        default=.9,
                        help='threshold for final surface')

    parser.add_argument(
        '--output',
        type=str,
        help='path to save the resulting high prob mesh to. also disables viz')
    parser.add_argument('--output_trim_unused',
                        action='store_true',
                        help='trim unused vertices when outputting')

    # Parse arguments
    args = parser.parse_args()
    set_args_defaults(args)

    # Writing output disables the interactive viewer
    viz = not args.output
    args.polyscope = False

    # Initialize polyscope
    if viz:
        polyscope.init()

    # === Load the input
    # The three input kinds are mutually exclusive. Without `elif`, a .npz input
    # would fall through into the mesh reader below and be overwritten.

    if args.input_path.endswith(".npz"):
        record = np.load(args.input_path)
        verts = torch.tensor(record['vert_pos'],
                             dtype=args.dtype,
                             device=args.device)
        surf_samples = torch.tensor(record['surf_pos'],
                                    dtype=args.dtype,
                                    device=args.device)

        samples = verts.clone()
        faces = torch.zeros((0, 3), dtype=torch.int64, device=args.device)

        # polyscope is only initialized when viz is on
        if viz:
            polyscope.register_point_cloud("surf samples", toNP(surf_samples))

    elif args.input_path.endswith(".xyz"):
        raw_pts = np.loadtxt(args.input_path)
        verts = torch.tensor(raw_pts, dtype=args.dtype, device=args.device)

        samples = verts.clone()
        faces = torch.zeros((0, 3), dtype=torch.int64, device=args.device)

        if viz:
            polyscope.register_point_cloud("surf samples", toNP(verts))

    else:
        print("reading mesh")
        verts, faces = utils.read_mesh(args.input_path)
        print("  {} verts   {} faces".format(verts.shape[0], faces.shape[0]))
        verts = torch.tensor(verts, dtype=args.dtype, device=args.device)
        faces = torch.tensor(faces, dtype=torch.long, device=args.device)

        if args.sample_cloud:
            samples = mesh_utils.sample_points_on_surface(
                verts, faces, args.sample_cloud)
        else:
            samples = verts.clone()

    # === Load the model

    print("loading model weights")
    model = PointTriNet_Mesher()
    # map_location lets a checkpoint saved on GPU load under --disable_cuda
    model.load_state_dict(
        torch.load(args.model_weights_path, map_location=args.device))
    # Keep the parameters on the same device as the input tensors
    model = model.to(args.device)

    model.eval()

    with torch.no_grad():

        # Sample lots of faces from the vertices
        print("predicting")
        candidate_triangles, candidate_probs = model.predict_mesh(
            samples.unsqueeze(0), n_rounds=args.n_rounds)
        candidate_triangles = candidate_triangles.squeeze(0)
        candidate_probs = candidate_probs.squeeze(0)
        print("done predicting")

        # Keep confident faces, then greedily close holes
        high_prob = args.prob_thresh
        high_mask = candidate_probs > high_prob
        high_faces = candidate_triangles[high_mask]
        closed_faces = mesh_utils.fill_holes_greedy(high_faces)

        if viz:
            polyscope.register_point_cloud("input points", toNP(samples))

            spmesh = polyscope.register_surface_mesh("all faces",
                                                     toNP(samples),
                                                     toNP(candidate_triangles),
                                                     enabled=False)
            spmesh.add_scalar_quantity("probs",
                                       toNP(candidate_probs),
                                       defined_on='faces')

            spmesh = polyscope.register_surface_mesh(
                "high prob mesh " + str(high_prob), toNP(samples),
                toNP(high_faces))
            spmesh.add_scalar_quantity(
                "probs",
                toNP(candidate_probs[high_mask]),
                defined_on='faces')

            spmesh = polyscope.register_surface_mesh("hole-closed mesh " +
                                                     str(high_prob),
                                                     toNP(samples),
                                                     toNP(closed_faces),
                                                     enabled=False)

            polyscope.show()

        # Save output
        if args.output:

            out_verts = toNP(samples)
            out_faces = toNP(high_faces)
            out_faces_closed = toNP(closed_faces)
            out_verts_closed = out_verts

            if args.output_trim_unused:
                # remove_unreferenced re-indexes the faces it is given, so each
                # mesh must be trimmed against the untrimmed vertex array on its
                # own. Otherwise the closed faces would index a vertex array
                # compacted for the open mesh.
                out_verts_closed, out_faces_closed, _, _ = igl.remove_unreferenced(
                    out_verts, out_faces_closed)
                out_verts, out_faces, _, _ = igl.remove_unreferenced(
                    out_verts, out_faces)

            igl.write_triangle_mesh(args.output + "_mesh.ply", out_verts,
                                    out_faces)
            write_ply_points(args.output + "_samples.ply", toNP(samples))

            igl.write_triangle_mesh(args.output + "_pred_mesh.ply", out_verts,
                                    out_faces)
            igl.write_triangle_mesh(args.output + "_pred_mesh_closed.ply",
                                    out_verts_closed, out_faces_closed)