예제 #1
0
def find_edges_to_link(mesh,
                       vert_ind_a,
                       vert_ind_b,
                       distance_upper_bound=2500,
                       verbose=False):
    '''Find edges to add to a mesh graph so that two vertices end up on the
    same connected component.

    The search is restricted to the vertices within ``2 * d`` of the midpoint
    of the two points, where ``d`` is their euclidean distance. Candidate
    linking edges are proposed with ``find_close_edges_sym`` (falling back to
    ``find_all_close_edges`` if none are found) and added to the masked graph.
    A shortest path between the two points is then computed on that graph,
    and only the candidate edges lying on that path are returned.

    Parameters
    ----------
    mesh: trimesh_io.Mesh
        a mesh to find edges on
    vert_ind_a: int
        one index into mesh.vertices, the first point
    vert_ind_b: int
        a second index into mesh.vertices, the second point
    distance_upper_bound: float
        intended maximum length of an added edge (default 2500, in units of
        mesh.vertices).
        TODO: not presently implemented; the value is currently ignored.
    verbose: bool
        whether to print debug info (fallback usage and per-stage timings)

    Returns
    -------
    np.array
        a Kx2 array of mesh indices (in the original, unmasked index space)
        that represent edges to add to the mesh to link the two points along
        the shortest path across mutually closest vertices from connected
        components.

    Raises
    ------
    Exception
        if no candidate edges can be found, or if the two points are still
        not connected after adding the candidate edges.
    '''
    timings = {}
    start_time = time.time()

    # find the distance between the merge points and their midpoint
    d = np.linalg.norm(mesh.vertices[vert_ind_a, :] -
                       mesh.vertices[vert_ind_b, :])
    c = np.mean(mesh.vertices[[vert_ind_a, vert_ind_b], :], axis=0)
    # cut down the mesh to only include mesh vertices within 2x the euclidean
    # length of the merge edge from its midpoint
    inds = mesh.kdtree.query_ball_point(c, d * 2)
    # convert this to a mask; np.bool was removed in NumPy 1.24, so use the
    # builtin bool dtype
    mask = np.zeros(len(mesh.vertices), dtype=bool)
    mask[inds] = True

    timings['create_mask'] = time.time() - start_time
    start_time = time.time()
    # create a masked version of the mesh
    mask_mesh = mesh.apply_mask(mask)

    timings['apply_mask'] = time.time() - start_time
    start_time = time.time()
    ccs, labels = sparse.csgraph.connected_components(mask_mesh.csgraph,
                                                      return_labels=True)

    # map the original indices into this masked space
    mask_inds = mask_mesh.filter_unmasked_indices(
        np.array([vert_ind_a, vert_ind_b]))

    timings['masked_ccs'] = time.time() - start_time
    start_time = time.time()

    # find all the mutually closest edges between the components that
    # contain the two points
    new_edges = find_close_edges_sym(mask_mesh.vertices, labels,
                                     labels[mask_inds[0]],
                                     labels[mask_inds[1]])
    timings['find_close_edges_sym'] = time.time() - start_time
    start_time = time.time()

    # if there is no way to do this, fall back to adding all
    # edges that are close
    if len(new_edges) == 0:
        if verbose:
            print('finding all close edges')
        new_edges = find_all_close_edges(mask_mesh.vertices, labels, ccs)
        if verbose:
            print(f'new_edges shape {new_edges.shape}')
    # if there are still no edges the points cannot be linked
    if len(new_edges) == 0:
        raise Exception('no close edges found')

    # build a graph of the masked mesh with the candidate edges added
    total_edges = np.vstack([mask_mesh.graph_edges, new_edges])
    graph = utils.create_csgraph(mask_mesh.vertices, total_edges)
    timings['make_new_mesh'] = time.time() - start_time
    start_time = time.time()

    # weighted, undirected shortest paths from the first point
    d_ais_to_all, pred = sparse.csgraph.dijkstra(graph,
                                                 indices=mask_inds[0],
                                                 unweighted=False,
                                                 directed=False,
                                                 return_predecessors=True)
    # recorded under its own key so it does not overwrite the
    # 'find_close_edges_sym' timing above
    timings['dijkstra'] = time.time() - start_time
    start_time = time.time()
    # make sure we found a path to the second point
    if np.isinf(d_ais_to_all[mask_inds[1]]):
        raise Exception(
            f"cannot find link between {vert_ind_a} and {vert_ind_b}")

    # keep only the path edges that are among the added candidate edges,
    # then map them back into the original mesh index space
    path = utils.get_path(mask_inds[0], mask_inds[1], pred)
    path_as_edges = utils.paths_to_edges([path])
    good_edges = np_shared_rows(path_as_edges, new_edges)
    good_edges = np.sort(path_as_edges[good_edges], axis=1)
    timings['remap answers'] = time.time() - start_time
    if verbose:
        print(timings)
    return mask_mesh.map_indices_to_unmasked(good_edges)
예제 #2
0
def mesh_teasar(mesh,
                root=None,
                valid=None,
                root_ds=None,
                root_pred=None,
                soma_pt=None,
                soma_thresh=7500,
                invalidation_d=10000,
                return_timing=False,
                return_map=False):
    """Core skeletonization function used to skeletonize a single component of a mesh.

    Runs a TEASAR-style loop over the mesh graph:

    1. Pick the still-valid vertex farthest from the root.
    2. Find the shortest path from it to the vertices already on the
       skeleton, and record that path as a branch.
    3. Invalidate every vertex within ``invalidation_d`` (graph distance) of
       the new path.

    The loop stops when no valid vertices remain.

    Parameters
    ----------
    mesh: trimesh_io.Mesh
        mesh to skeletonize; ``mesh.vertices`` and ``mesh.csgraph`` are used
    root: int or None
        vertex index of the skeleton root. If None, ``setup_root`` is called
        with ``soma_pt``/``soma_thresh``. Its returned root, ``root_ds``,
        ``root_pred`` and ``valid`` replace any values passed in.
    valid: np.array or None
        per-vertex mask of vertices still to be covered; it is updated in
        place as vertices are invalidated. If None, every vertex except the
        root starts valid.
    root_ds: np.array or None
        precomputed dijkstra distances from root. Recomputed (together with
        ``root_pred``) if either of the two is missing.
    root_pred: np.array or None
        precomputed dijkstra predecessors from root.
    soma_pt, soma_thresh:
        passed through to ``setup_root`` when ``root`` is None
    invalidation_d: float
        graph distance from each new path within which vertices are invalidated
    return_timing: bool
        also return per-stage timing lists and the total elapsed time
    return_map: bool
        also return, for every mesh vertex, the index of the nearest path
        vertex found during invalidation (nan if none was within reach)

    Returns
    -------
    tuple
        ``(paths, path_lengths)``, followed by ``mesh_to_skeleton_map`` if
        ``return_map``, followed by ``(time_arrays, dt)`` if ``return_timing``.

    Raises
    ------
    Exception
        if ``valid`` has the wrong length, a valid vertex is unreachable from
        the root, or a selected target cannot be reached.
    """
    # if no root passed, then calculate one
    if root is None:
        root, root_ds, root_pred, valid = setup_root(mesh,
                                                     soma_pt=soma_pt,
                                                     soma_thresh=soma_thresh)
    # if root distances / predecessors have not been precalculated, do so.
    # root_pred is walked in the loop below, so both are required.
    if root_ds is None or root_pred is None:
        root_ds, root_pred = sparse.csgraph.dijkstra(mesh.csgraph,
                                                     False,
                                                     root,
                                                     return_predecessors=True)
    # if certain vertices haven't been pre-invalidated start with just
    # the root vertex invalidated
    if valid is None:
        # np.bool was removed in NumPy 1.24; use the builtin bool dtype
        valid = np.ones(len(mesh.vertices), bool)
        valid[root] = False
    elif len(valid) != len(mesh.vertices):
        raise Exception("valid must be length of vertices")

    if return_map:
        mesh_to_skeleton_dist = np.full(len(mesh.vertices), np.inf)
        mesh_to_skeleton_map = np.full(len(mesh.vertices), np.nan)

    total_to_visit = np.sum(valid)
    if np.sum(np.isinf(root_ds) & valid) != 0:
        print(np.where(np.isinf(root_ds) & valid))
        raise Exception("all valid vertices should be reachable from root")

    # each branch's vertex path, and its length along the mesh graph
    paths = []
    path_lengths = []

    # vertices already on the skeleton, in insertion order (the order decides
    # argmin tie-breaking below) ...
    visited_nodes = [root]
    # ... and the same set as a boolean mask for O(1) membership tests
    is_visited = np.zeros(len(mesh.vertices), dtype=bool)
    is_visited[root] = True

    # per-stage timings for each iteration
    start = time.time()
    time_arrays = [[], [], [], [], []]

    with tqdm(total=total_to_visit) as pbar:
        # keep looping till all vertices have been invalidated
        while np.any(valid):
            t = time.time()
            # next target: the valid vertex farthest from root. Invalid
            # vertices are masked to -inf so they can never be chosen.
            # Multiplying by valid would give them a score of 0, which can
            # tie with a valid zero-distance vertex and win the argmax.
            target = np.argmax(np.where(valid, root_ds, -np.inf))
            if np.isinf(root_ds[target]):
                raise Exception('target cannot be reached')
            time_arrays[0].append(time.time() - t)

            t = time.time()
            # follow the route from target toward the root until reaching an
            # already visited node (max_branch). The difference
            # dist(root->target) - dist(root->max_branch) bounds the shortest
            # route from target to any skeleton vertex. It is used as the
            # dijkstra search radius for this target.
            max_branch = target
            while not is_visited[max_branch]:
                max_branch = root_pred[max_branch]
            max_path_length = root_ds[target] - root_ds[max_branch]

            ds, pred_t = sparse.csgraph.dijkstra(mesh.csgraph,
                                                 False,
                                                 target,
                                                 limit=max_path_length,
                                                 return_predecessors=True)

            # the already-visited vertex with the shortest path to target is
            # the skeleton point this branch connects to
            branch = visited_nodes[np.argmin(ds[visited_nodes])]
            time_arrays[1].append(time.time() - t)

            t = time.time()
            # get the path from the target to the branch point; its last
            # element (presumably branch itself) is already visited
            path = utils.get_path(target, branch, pred_t)
            new_nodes = path[0:-1]
            visited_nodes += new_nodes
            is_visited[new_nodes] = True
            # record its length and the path itself
            assert not np.isinf(ds[branch])
            path_lengths.append(ds[branch])
            paths.append(path)
            time_arrays[2].append(time.time() - t)

            t = time.time()
            # distance from every vertex to the nearest vertex of the new
            # path, searched only out to invalidation_d; sources gives that
            # nearest path vertex
            dm, _, sources = sparse.csgraph.dijkstra(mesh.csgraph,
                                                     False,
                                                     path,
                                                     limit=invalidation_d,
                                                     min_only=True,
                                                     return_predecessors=True)
            time_arrays[3].append(time.time() - t)

            t = time.time()
            # all finite distances are within the invalidation zone
            nodes_to_update = ~np.isinf(dm)
            marked = np.sum(valid & nodes_to_update)
            if return_map:
                # keep the nearest path vertex seen across all branches;
                # (inf < x) is False, so unreached vertices are left untouched
                closer = dm < mesh_to_skeleton_dist
                mesh_to_skeleton_map[closer] = sources[closer]
                mesh_to_skeleton_dist[closer] = dm[closer]

            valid[nodes_to_update] = False

            # advance the progress bar by the number of newly invalidated vertices
            pbar.update(marked)
            time_arrays[4].append(time.time() - t)
    # record the total time
    dt = time.time() - start

    out_tuple = (paths, path_lengths)
    if return_map:
        out_tuple += (mesh_to_skeleton_map, )
    if return_timing:
        out_tuple += (time_arrays, dt)

    return out_tuple