def by_vertex_clusters(mesh, sampling_dist, cluster_pos='median', output='swc',
                       vertex_map=False, drop_disconnected=False,
                       progress=True):
    """Skeletonize a contracted mesh by clustering vertices.

    Notes
    -----
    This algorithm traverses the graph and groups vertices together that are
    within a given distance to each other. This uses the geodesic
    (along-the-mesh) distance, not simply the Euclidean distance. Subsequently
    these groups of vertices are collapsed and re-connected respecting the
    topology of the input mesh.

    The graph traversal is fast and scales well, so this method is well suited
    for meshes with lots of vertices. On the downside: this implementation is
    not very clever and you might have to play around with the parameters
    (mostly ``sampling_dist``) to get decent results.

    Parameters
    ----------
    mesh : mesh obj
        The mesh to be skeletonized. Can be an object that has ``.vertices``
        and ``.faces`` properties (e.g. a trimesh.Trimesh) or a tuple
        ``(vertices, faces)`` or a dictionary
        ``{'vertices': vertices, 'faces': faces}``.
    sampling_dist : float | int
        Maximal distance at which vertices are clustered. This parameter
        should be tuned based on the resolution of your mesh.
    cluster_pos : "median" | "center"
        How to determine the x/y/z coordinates of the collapsed vertex
        clusters (i.e. the skeleton's nodes)::

          - "median": Use the vertex closest to cluster's center of mass.
          - "center": Use the center of mass. This makes for smoother
            skeletons but can lead to nodes outside the mesh.

    output : "swc" | "graph" | "both"
        Determines the function's output. See ``Returns``.
    vertex_map : bool
        If True, we will add a "vertex_id" property to the graph and column
        to the SWC table that maps the cluster ID to its first vertex in the
        original mesh.
    drop_disconnected : bool
        If True, will drop disconnected nodes from the skeleton. Note that
        this might result in empty skeletons.
    progress : bool
        If True, will show progress bar.

    Returns
    -------
    "swc" : pandas.DataFrame
        SWC representation of the skeleton.
    "graph" : networkx.Graph
        Graph representation of the skeleton.
    "both" : tuple
        Both of the above: ``(swc, graph)``.

    """
    assert output in ['swc', 'graph', 'both']
    assert cluster_pos in ['center', 'median']

    mesh = make_trimesh(mesh, validate=False)

    # Produce weighted edges: one (vertex, vertex, length) row per unique edge
    edges = np.concatenate((mesh.edges_unique,
                            mesh.edges_unique_length.reshape(mesh.edges_unique.shape[0], 1)),
                           axis=1)

    # Generate Graph (must be undirected)
    G = nx.Graph()
    G.add_weighted_edges_from(edges)

    # Run the graph traversal that groups vertices into spatial clusters
    not_visited = set(G.nodes)
    seen = set()
    clusters = []
    to_visit = len(not_visited)
    with tqdm(desc='Clustering', total=len(not_visited),
              disable=progress is False) as pbar:
        while not_visited:
            # Pick a random node
            start = not_visited.pop()
            # Get all nodes in the geodesic vicinity
            cl, seen = dfs(G, n=start, dist_traveled=0,
                           max_dist=sampling_dist, seen=seen)
            cl = set(cl)

            # Append this cluster and track visited/not-visited nodes
            clusters.append(cl)
            not_visited = not_visited - cl

            # Update progress bar
            pbar.update(to_visit - len(not_visited))
            to_visit = len(not_visited)

    # `clusters` is a list of sets -> let's turn it into list of arrays.
    # The cast to int is needed because the weighted edge array above is
    # float, so the graph's node labels are floats too.
    clusters = [np.array(list(c)).astype(int) for c in clusters]

    # Get positions of clusters
    if cluster_pos == 'center':
        # Get the center of each cluster
        cl_coords = np.array([np.mean(mesh.vertices[c], axis=0)
                              for c in clusters])
    elif cluster_pos == 'median':
        # Get the node that's closest to the cluster's center
        # (closeness is measured as summed absolute coordinate difference)
        cl_coords = []
        for c in clusters:
            cnt = np.mean(mesh.vertices[c], axis=0)
            cnt_dist = np.sum(np.fabs(mesh.vertices[c] - cnt), axis=1)
            median = mesh.vertices[c][np.argmin(cnt_dist)]
            cl_coords.append(median)
        cl_coords = np.array(cl_coords)

    # Generate edges: relabel each vertex index with the index of its cluster
    cl_edges = np.array(mesh.edges_unique)
    if fastremap:
        mapping = {n: i for i, l in enumerate(clusters) for n in l}
        cl_edges = fastremap.remap(cl_edges, mapping,
                                   preserve_missing_labels=False,
                                   in_place=True)
    else:
        # Vertex indices and cluster indices share the same integer range.
        # Relabelling `cl_edges` in place one cluster at a time would
        # therefore re-map cluster indices assigned in an earlier iteration
        # whenever they coincide with a vertex index of a later cluster.
        # A vertex -> cluster lookup table applied in one go avoids that.
        lookup = np.full(len(mesh.vertices), -1, dtype=cl_edges.dtype)
        for i, c in enumerate(clusters):
            lookup[c] = i
        cl_edges = lookup[cl_edges]
        # Mirror `preserve_missing_labels=False`: unmapped vertices are errors
        if np.any(cl_edges < 0):
            raise KeyError('Mesh edges reference vertices that are not part '
                           'of any cluster.')

    # Remove directionality from cluster edges
    cl_edges = np.sort(cl_edges, axis=1)

    # Get unique edges
    cl_edges = np.unique(cl_edges, axis=0)

    # Calculate edge lengths
    co1 = cl_coords[cl_edges[:, 0]]
    co2 = cl_coords[cl_edges[:, 1]]
    cl_edge_lengths = np.sqrt(np.sum((co1 - co2)**2, axis=1))

    # Produce adjacency matrix from edges and edge lengths
    n_clusters = len(clusters)
    adj = scipy.sparse.coo_matrix((cl_edge_lengths,
                                   (cl_edges[:, 0], cl_edges[:, 1])),
                                  shape=(n_clusters, n_clusters))

    # The cluster graph likely still contains cycles, let's get rid of them
    # using a minimum spanning tree
    mst = scipy.sparse.csgraph.minimum_spanning_tree(adj, overwrite=True)

    # Turn into COO matrix
    coo = mst.tocoo()

    # Extract edge list
    edges = np.array([coo.row, coo.col]).T

    # Produce final graph - this also takes care of some fixing
    G = edges_to_graph(edges,
                       nodes=np.unique(cl_edges.flatten()),
                       drop_disconnected=drop_disconnected,
                       fix_tree=True)

    # At this point nodes are labeled by index of the cluster
    # Let's give them a "vertex_id" property mapping back to the
    # first vertex in that cluster
    if vertex_map:
        mapping = {i: l[0] for i, l in enumerate(clusters)}
        nx.set_node_attributes(G, mapping, name="vertex_id")

    if output == 'graph':
        return G

    # Generate SWC
    swc = make_swc(G, cl_coords)

    # Add vertex ID column if requested
    if vertex_map:
        swc['vertex_id'] = swc.node_id.map(mapping)

    if output == 'both':
        return swc, G

    return swc
def perform_remap(a, relabel_map):
    """Relabel array ``a`` using the ``relabel_map`` lookup table.

    Thin wrapper around ``fastremap.remap`` that is always called with
    ``preserve_missing_labels=True`` and returns whatever ``remap`` returns.
    """
    return fastremap.remap(a, relabel_map, preserve_missing_labels=True)
def by_edge_collapse(mesh, shape_weight=1, sample_weight=0.1, output='swc', drop_disconnected=False, progress=True):
    """Skeletonize a (contracted) mesh by collapsing edges.

    Notes
    -----
    This algorithm (described in [1]) iteratively collapses edges that are
    part of a face until no more faces are left. Edges are chosen based on a
    cost function that penalizes collapses that would change the shape of the
    object or would introduce long edges.

    This is somewhat sensitive to the dimensions of the input mesh: too large
    and you might experience slow-downs or numpy OverflowErrors; too low and
    you might get skeletons that don't quite match the mesh (e.g. too few
    nodes). If you experience either, try down- or up-scaling your mesh,
    respectively.

    Parameters
    ----------
    mesh : mesh obj
        The mesh to be skeletonized. Can be an object that has ``.vertices``
        and ``.faces`` properties (e.g. a trimesh.Trimesh) or a tuple
        ``(vertices, faces)`` or a dictionary
        ``{'vertices': vertices, 'faces': faces}``.
    shape_weight : float, optional
        Weight for shape costs which penalize collapsing edges that would
        drastically change the shape of the object.
    sample_weight : float, optional
        Weight for sampling costs which penalize collapses that would
        generate prohibitively long edges.
    output : "swc" | "graph" | "both"
        Determines the function's output. See ``Returns``.
    drop_disconnected : bool
        Intended to drop disconnected nodes from the skeleton.
        NOTE(review): this argument is currently not used - the final
        ``edges_to_graph`` call is made with ``drop_disconnected=True``
        regardless of the value passed in. Confirm which is intended.
    progress : bool
        If True, will show progress bar.

    Returns
    -------
    "swc" : pandas.DataFrame
        SWC representation of the skeleton.
    "graph" : networkx.Graph
        Graph representation of the skeleton.
    "both" : tuple
        Both of the above. NOTE(review): the code returns ``(graph, swc)``,
        i.e. the graph comes first - confirm whether ``(swc, graph)`` was
        intended before relying on the order.

    References
    ----------
    [1] Au OK, Tai CL, Chu HK, Cohen-Or D, Lee TY. Skeleton extraction by mesh
        contraction. ACM Transactions on Graphics (TOG). 2008 Aug 1;27(3):44.

    """
    assert output in ['swc', 'graph', 'both']

    mesh = make_trimesh(mesh, validate=False)

    # Shorthand faces and edges
    # We convert to arrays to (a) make a copy and (b) remove potential overhead
    # from these originally being trimesh TrackedArrays
    edges = np.array(mesh.edges_unique)
    verts = np.array(mesh.vertices)

    # For cost calculations we will normalise coordinates
    # This prevents getting ridiculously large cost values ?e300
    # verts = (verts - verts.min()) / (verts.max() - verts.min())
    # (the normalisation above is disabled; raw coordinates are used)
    edge_lengths = np.sqrt(np.sum((verts[edges[:, 0]] - verts[edges[:, 1]])**2, axis=1))

    # Get a list of faces: [(edge1, edge2, edge3), ...]
    face_edges = np.array(mesh.faces_unique_edges)
    # Make sure these faces are unique, i.e. no [(e1, e2, e3), (e3, e2, e1)]
    face_edges = np.unique(np.sort(face_edges, axis=1), axis=0)

    # Shape cost initialisation:
    # Each vertex has a matrix Q which is used to determine the shape cost
    # of collapsing each node. We need to generate a matrix (Q) for each
    # vertex, then when we collapse two nodes, we can update using
    # Qj <- Qi + Qj, so the edges previously associated with vertex i
    # are now associated with vertex j.

    # For each edge, generate a matrix (K). K is made up of two sets of
    # coordinates in 3D space, a and b. a is the normalised edge vector
    # of edge(i,j) and b = a * <x/y/z coordinates of vertex i>
    #
    # The matrix K takes the form:
    #
    #        Kij = 0, -az, ay, -bx
    #              az, 0, -ax, -by
    #             -ay, ax, 0,  -bz
    edge_co0, edge_co1 = verts[edges[:, 0]], verts[edges[:, 1]]
    a = (edge_co1 - edge_co0) / edge_lengths.reshape(edges.shape[0], 1)
    # Note: It's a bit unclear to me whether the normalised edge vector should
    # be allowed to have negative values but I seem to be getting better
    # results if I use absolute values
    a = np.fabs(a)
    b = a * edge_co0

    # Bunch of zeros
    zero = np.zeros(a.shape[0])

    # Generate matrix K: after np.array() this has shape (3, 4, len(edges))
    K = [[zero, -a[:, 2], a[:, 1], -b[:, 0]], [a[:, 2], zero, -a[:, 0], -b[:, 1]], [-a[:, 1], a[:, 0], zero, -b[:, 2]]]
    K = np.array(K)

    # Q for vertex i is then the sum of the products of (kT,k) for ALL edges
    # connected to vertex i:
    # Initialize matrix of correct shape
    # NOTE(review): np.float128 is not provided by NumPy on every platform -
    # confirm the targets this needs to run on.
    Q_array = np.zeros((4, 4, verts.shape[0]), dtype=np.float128)

    # Generate (kT, K)
    kT = np.transpose(K, axes=(1, 0, 2))

    # To get the sum of the products in the correct format we have to
    # do some annoying transposes to get to (4, 4, len(edges))
    K_dot = np.matmul(K.T, kT.T).T

    # Iterate over all vertices
    for v in range(len(verts)):
        # Find edges that contain this vertex
        cond1 = edges[:, 0] == v
        cond2 = edges[:, 1] == v
        # Note that this does not take directionality of edges into account
        # Not sure if that's intended?

        # Get indices of these edges
        indices = np.where(cond1 | cond2)[0]

        # Get the products for all edges adjacent to this vertex
        Q = K_dot[:, :, indices]
        # Sum over all edges
        Q = Q.sum(axis=2)
        # Add to Q array
        Q_array[:, :, v] = Q

    # Not sure if we are doing something wrong when calculating the Q array but
    # we end up having negative values which translate into negative scores.
    # This in turn is bad because we propagate that negative score when
    # collapsing edges which leads to a "zipper-effect" where nodes collapse
    # in sequence a->b->c->d until they hit some node with really high cost
    # Q_array -= Q_array.min()

    # Edge collapse:
    # Determining which edge to collapse is a weighted sum of the shape and
    # sampling cost. The shape cost of vertex i is Fa(p) = pT Qi p where p is
    # the coordinates of point p (vertex i here) in homogeneous representation.
    # The variable w (set below) is the value for the homogeneous 4th dimension.
    # T denotes transpose of matrix.
    # The shape cost of collapsing the edge Fa(i,j) = Fi(vj) + Fj(vj).
    # vi and vj being the coordinates of the vertex in homogeneous representation
    # (p in equation before)
    # The sampling cost penalises edge collapses that generate overly long edges,
    # based on the distance traveled by all edges to vi, when vi is merged with
    # vj. (Eq. 7 in paper)
    # You cannot collapse an edge (i -> j) if k is a common adjacent vertex of
    # both i and j, but (i/j/k) is not a face.
    # We will set the cost of these edges to infinity.

    # Now work out the shape cost of collapsing each node (eq. 7)
    # First get coordinates of the first node of each edge
    # Note that in Nik's implementation this was the second node
    p = verts[edges[:, 0]]

    # Append weight factor (turns each row into a homogeneous 4-vector)
    w = 1
    p = np.append(p, np.full((p.shape[0], 1), w), axis=1)

    this_Q1 = Q_array[:, :, edges[:, 0]]
    this_Q2 = Q_array[:, :, edges[:, 1]]

    F1 = np.einsum('ij,kji->ij', p, this_Q1)[:, [0, 1]]
    F2 = np.einsum('ij,kji->ij', p, this_Q2)[:, [0, 1]]

    # Calculate and append shape cost
    F = np.append(F1, F2, axis=1)
    shape_cost = np.sum(F, axis=1)

    # Sum lengths of all edges associated with a given vertex
    # This is easiest by generating a sparse matrix from the edges
    # and then summing by row
    adj = scipy.sparse.coo_matrix((edge_lengths, (edges[:, 0], edges[:, 1])), shape=(verts.shape[0], verts.shape[0]))

    # This makes sure the matrix is symmetrical, i.e. a->b == a<-b
    # Note that I'm not sure whether this is strictly necessary but it really
    # can't hurt
    adj = adj + adj.T

    # Get the lengths associated with each vertex
    verts_lengths = adj.sum(axis=1)

    # We need to flatten this (something funny with summing sparse matrices)
    verts_lengths = np.array(verts_lengths).flatten()

    # Map the sum of vertex lengths onto edges (as per first vertex in edge)
    ik_edge = verts_lengths[edges[:, 0]]

    # Calculate sampling cost
    sample_cost = edge_lengths * (ik_edge - edge_lengths)

    # Determine which edge to collapse and collapse it
    # Total Cost - weighted sum of shape and sample cost, equation 8 in paper
    F_T = shape_cost * shape_weight + sample_cost * sample_weight

    # Now start collapsing edges one at a time
    face_count = face_edges.shape[0]  # keep track of face counts for progress bar
    # Per-edge flags: `is_collapsed` marks edges that were merged away,
    # `keep` marks edges without adjacent faces that end up in the skeleton
    is_collapsed = np.full(edges.shape[0], False)
    keep = np.full(edges.shape[0], False)
    with tqdm(desc='Collapsing edges', total=face_count, disable=progress is False) as pbar:
        while face_edges.size:
            # Uncomment to get a more-or-less random edge collapse
            # F_T[:] = 0

            # Update progress bar
            pbar.update(face_count - face_edges.shape[0])
            face_count = face_edges.shape[0]

            # This has to come at the beginning of the loop
            # Set cost of collapsing edges without faces to infinite
            F_T[keep] = np.inf
            F_T[is_collapsed] = np.inf

            # Get the edge that we want to collapse
            collapse_ix = np.argmin(F_T)
            # Get the vertices this edge connects
            u, v = edges[collapse_ix]
            # Get all edges that contain these vertices:
            # First, edges that are (uv, x)
            connects_uv = np.isin(edges[:, 0], [u, v])
            # Second, check if any (uv, x) edges are (uv, uv)
            connects_uv[connects_uv] = np.isin(edges[connects_uv, 1], [u, v])

            # Remove uu and vv edges
            uuvv = edges[:, 0] == edges[:, 1]
            connects_uv = connects_uv & ~uuvv
            # Get the edge's indices
            clps_edges = np.where(connects_uv)[0]

            # Now find the faces the collapsed edge is part of
            # Note: splitting this into three conditions is marginally faster than
            # np.any(np.isin(face_edges, clps_edges), axis=1)
            uv0 = np.isin(face_edges[:, 0], clps_edges)
            uv1 = np.isin(face_edges[:, 1], clps_edges)
            uv2 = np.isin(face_edges[:, 2], clps_edges)
            has_uv = uv0 | uv1 | uv2

            # If these edges do not have adjacent faces anymore
            if not np.any(has_uv):
                # Track this edge as a keeper
                keep[clps_edges] = True
                continue

            # Get the collapsed faces [(e1, e2, e3), ...] for this edge
            clps_faces = face_edges[has_uv]

            # Remove the collapsed faces
            face_edges = face_edges[~has_uv]

            # Track these edges as collapsed
            is_collapsed[clps_edges] = True

            # Get the adjacent edges (i.e. non-uv edges)
            # NOTE(review): the reshape assumes each collapsed face contains
            # exactly one of `clps_edges` - confirm this always holds.
            adj_edges = clps_faces[~np.isin(clps_faces, clps_edges)].reshape(clps_faces.shape[0], 2)

            # We have to do some sorting and finding unique edges to make sure
            # remapping is done correctly further down
            # NOTE: Not sure we really need this, so leaving it out for now
            # adj_edges = np.unique(np.sort(adj_edges, axis=1), axis=0)

            # We need to keep track of changes to the adjacent faces
            # Basically each face in (i, j, k) will be reduced to one edge
            # which points from u -> v
            # -> replace occurrences of losing edge with winning edge
            for win, loose in adj_edges:
                if fastremap:
                    face_edges = fastremap.remap(face_edges, {loose: win}, preserve_missing_labels=True, in_place=True)
                else:
                    face_edges[face_edges == loose] = win
                is_collapsed[loose] = True

            # Replace occurrences of first node u with second node v
            if fastremap:
                edges = fastremap.remap(edges, {u: v}, preserve_missing_labels=True, in_place=True)
            else:
                edges[edges == u] = v

            # Add shape cost of u to shape costs of v
            Q_array[:, :, v] += Q_array[:, :, u]

            # Determine which edges require update of costs:
            # In theory we only need to update costs for edges that are
            # associated with vertices v and u (which now also v)
            has_v = (edges[:, 0] == v) | (edges[:, 1] == v)

            # Uncomment to temporarily force updating costs for all edges
            # has_v[:] = True

            # Update shape costs
            # NOTE(review): `p` was built with one row per *edge* (first
            # vertex of each edge) but is indexed with *vertex* IDs here -
            # confirm this is intended.
            this_Q1 = Q_array[:, :, edges[has_v, 0]]
            this_Q2 = Q_array[:, :, edges[has_v, 1]]
            F1 = np.einsum('ij,kji->ij', p[edges[has_v, 0]], this_Q1)[:, [0, 1]]
            F2 = np.einsum('ij,kji->ij', p[edges[has_v, 1]], this_Q2)[:, [0, 1]]
            F = np.append(F1, F2, axis=1)
            new_shape_cost = np.sum(F, axis=1)

            # Update sum of incoming edge lengths
            # Technically we would have to recalculate lengths of adjacent edges
            # every time but we will take the cheap way out and simply add them up
            verts_lengths[v] += verts_lengths[u]
            # Update sample costs for edges associated with v
            ik_edge = verts_lengths[edges[has_v, 0]]
            new_sample_cost = edge_lengths[has_v] * (ik_edge - edge_lengths[has_v])

            F_T[has_v] = new_shape_cost * shape_weight + new_sample_cost * sample_weight

    # After the edge collapse, the edges are garbled - I have yet to figure out
    # why and whether that can be prevented. However the vertices in those
    # edges are correct and so we just need to reconstruct their connectivity
    # by extracting a minimum spanning tree over the mesh.
    corrected_edges = mst_over_mesh(mesh, edges[keep].flatten())

    # Generate graph
    # NOTE(review): `drop_disconnected=True` is hard-coded; the function's
    # `drop_disconnected` argument is not forwarded.
    G = edges_to_graph(corrected_edges, vertices=mesh.vertices, fix_tree=True, weight=False, drop_disconnected=True)

    if output == 'graph':
        return G

    swc = make_swc(G, mesh)

    if output == 'both':
        # NOTE(review): order is (graph, swc)
        return (G, swc)

    return swc
def merge_vertices(mesh, dist='auto', inplace=False):
    """Merge vertices closer than a given distance.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Mesh to merge vertices on.
    dist : "auto" | number
        Distance at which to merge vertices. If "auto" will use
        ``mesh.edges_unique_length.mean() / 100``.
    inplace : bool
        If True will modify the original mesh. Otherwise a copy is
        modified and returned.

    Returns
    -------
    trimesh.Trimesh
        The mesh with merged vertices (the input itself if ``inplace=True``).

    """
    assert isinstance(mesh, tm.Trimesh)

    if not inplace:
        mesh = mesh.copy()

    # Generate KDTree
    tree = sp.spatial.cKDTree(mesh.vertices)

    if dist == 'auto':
        dist = mesh.edges_unique_length.mean() / 100

    # Query tree for all pairs of vertices within `dist` of each other
    pairs = tree.query_pairs(dist)

    # Facilitate remapping by removing extra steps: A->B->C to A->C, B->C
    # Every connected component of close-by vertices collapses onto a single
    # representative. The component is unpacked exactly once instead of
    # rebuilding a list from it for every one of its members.
    G = nx.Graph()
    G.add_edges_from(pairs)
    mapping = {}
    for component in nx.connected_components(G):
        keeper, *dropped = component
        for n in dropped:
            mapping[n] = keeper

    with mesh._cache:
        # Update faces so that they point at the representative vertices
        if fastremap:
            mesh.faces = fastremap.remap(mesh.faces, mapping,
                                         preserve_missing_labels=True,
                                         in_place=True)
        else:
            # Representatives are never keys of `mapping`, so applying the
            # replacements one after another cannot chain
            for k, v in mapping.items():
                mesh.faces[mesh.faces == k] = v

        # Remove dropped vertices
        remove = np.isin(np.arange(mesh.vertices.shape[0]), list(mapping))
        mesh.update_vertices(~remove)

    # Remove degenerate and duplicate faces
    mesh.remove_degenerate_faces()
    mesh.remove_duplicate_faces()

    # Fix normals
    mesh.fix_normals()

    return mesh
def collapse_soma_skeleton(
    soma_verts,
    soma_pt,
    verts,
    edges,
    mesh_to_skeleton_map=None,
    collapse_index=None,
    return_filter=False,
    return_soma_ind=False,
):
    """Adjust a skeleton by collapsing a set of vertices onto a single root.

    Parameters
    ----------
    soma_verts : numpy.array or None
        indices of the skeleton vertices that should be collapsed onto the
        root. If None, nothing is collapsed and only unused vertices are
        removed.
    soma_pt : numpy.array
        a 3 long vector of xyz locations of the soma. It is appended to
        `verts` as the new root vertex when `collapse_index` is None and is
        not used otherwise.
    verts : numpy.array
        a Nx3 array of xyz vertex locations
    edges : numpy.array
        a Kx2 array of edges of the skeleton
    mesh_to_skeleton_map : np.array
        a M long array of how each mesh index maps to a skeleton vertex
        (default None). NaN entries of the passed array are set to -1
        in place before a remapped version is produced.
    collapse_index : int
        index of an existing vertex in `verts` to collapse onto instead of
        appending `soma_pt` as a new vertex (default None).
    return_filter : bool
        whether to return a list of which skeleton vertices were used in the
        end for the reduced set of skeleton vertices
    return_soma_ind : bool
        whether to return which skeleton index that is the soma_pt

    Returns
    -------
    np.array
        verts, Px3 array of xyz skeleton vertices
    np.array
        edges, Qx2 array of skeleton edges (self-loops removed)
    (np.array)
        new_mesh_to_skeleton_map, returned if mesh_to_skeleton_map and
        soma_verts passed
    (np.array)
        used_vertices, if return_filter this contains the indices into the
        passed verts which the return verts is using
    int
        an index into the returned verts that is the root of the skeleton
        node, only returned if return_soma_ind is True

    Notes
    -----
    When `soma_verts` is given the results come back as a list in the order
    above; when it is None a ``(verts, edges)`` tuple is returned and all
    optional outputs are omitted.
    """
    if soma_verts is None:
        simple_verts, simple_edges = utils.remove_unused_verts(verts, edges)
        return simple_verts, simple_edges

    if collapse_index is None:
        # The soma point becomes a brand-new vertex at the end of the array
        new_verts = np.vstack((verts, soma_pt.reshape(1, 3)))
        soma_i = verts.shape[0]
    else:
        # Collapse onto an existing vertex; it must not map onto itself
        new_verts = verts
        soma_i = collapse_index
        soma_verts = soma_verts[soma_verts != soma_i]

    # Point every edge end that sits on a soma vertex at the root instead
    edges_m = edges.copy()
    edges_m[np.isin(edges, soma_verts)] = soma_i

    simple_verts, simple_edges = utils.remove_unused_verts(new_verts, edges_m)
    # Edges that had both ends collapsed are now self-loops -> drop them
    good_edges = ~(simple_edges[:, 0] == simple_edges[:, 1])

    # The old -> new index lookup is needed both for remapping the
    # mesh-to-skeleton map and for reporting the root index. It used to be
    # built only when a map was passed, so `return_soma_ind=True` without a
    # map failed with a NameError.
    new_index_dict = None
    if mesh_to_skeleton_map is not None or return_soma_ind:
        consolidate_dict = {v: soma_i for v in soma_verts}
        new_index_dict, _ = utils.remap_dict(len(new_verts), consolidate_dict)

    output = [simple_verts, simple_edges[good_edges]]

    if mesh_to_skeleton_map is not None:
        # -1 marks mesh vertices without a skeleton vertex and must survive
        new_index_dict[-1] = -1
        mesh_to_skeleton_map[np.isnan(mesh_to_skeleton_map)] = -1
        new_mesh_to_skeleton_map = fastremap.remap(
            mesh_to_skeleton_map, new_index_dict
        )
        output.append(new_mesh_to_skeleton_map)

    if return_filter:
        used_vertices = np.unique(edges_m.ravel())
        if collapse_index is None:
            # Remove the largest value which is soma_i
            used_vertices = used_vertices[:-1]
        output.append(used_vertices)

    if return_soma_ind:
        output.append(new_index_dict[soma_i])

    return output
def remap_segmentation(cv, chunk_x, chunk_y, chunk_z, mip=2, overlap_vx=1,
                       time_stamp=None, progress=False):
    """Download the segmentation of one chunk and relabel its supervoxels.

    Parameters
    ----------
    cv :            volume object
                    Must provide ``.meta`` (``cloudpath``, ``watershed_mip``,
                    ``graph_chunk_size``, ``chunks_start_at_voxel_offset``)
                    and ``.fill_missing``.
    chunk_x, chunk_y, chunk_z : int
                    Chunk coordinates; multiplied by the chunk size at
                    ``mip`` to get the voxel offset of the cutout.
    mip :           int
                    Resolution level at which the segmentation is fetched.
    overlap_vx :    int
                    Number of voxels by which the cutout extends beyond the
                    chunk along each axis.
    time_stamp :    optional
                    Passed through to ``get_lx_overlapping_remappings``.
    progress :      bool
                    Whether to show progress bars.

    Returns
    -------
    seg :           array
                    The relabelled segmentation. Returned unmodified if the
                    cutout contains only zeros.

    """
    ws_cv = CloudVolume(cv.meta.cloudpath, mip=mip, progress=progress,
                        fill_missing=cv.fill_missing)
    mip_diff = mip - cv.meta.watershed_mip

    # Chunk size at the requested mip: x/y are scaled by 2**mip_diff, z is
    # left as is. `np.int` was removed from NumPy (it was only an alias for
    # the builtin `int`), so the builtin is used directly.
    mip_chunk_size = np.array(cv.meta.graph_chunk_size, dtype=int) / np.array(
        [2**mip_diff, 2**mip_diff, 1])
    mip_chunk_size = mip_chunk_size.astype(int)

    offset = Vec(chunk_x, chunk_y, chunk_z) * mip_chunk_size
    bbx = Bbox(offset, offset + mip_chunk_size + overlap_vx)
    if cv.meta.chunks_start_at_voxel_offset:
        bbx += ws_cv.voxel_offset
    bbx = Bbox.clamp(bbx, ws_cv.bounds)

    seg = ws_cv[bbx][..., 0]

    # Nothing to relabel in an all-zero cutout
    if not np.any(seg):
        return seg

    sv_remapping, unsafe_dict = get_lx_overlapping_remappings(
        cv, chunk_x, chunk_y, chunk_z, time_stamp=time_stamp, progress=progress)

    # Zero out everything without a remapping, then apply the remapping
    seg = fastremap.mask_except(seg, list(sv_remapping.keys()), in_place=True)
    fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True)

    for unsafe_root_id in tqdm(unsafe_dict.keys(), desc="Unsafe Relabel",
                               disable=(not progress)):
        bin_seg = seg == unsafe_root_id

        if np.sum(bin_seg) == 0:
            continue

        l2_edges = []
        cc_seg = cc3d.connected_components(bin_seg)
        for i_cc in range(1, np.max(cc_seg) + 1):
            bin_cc_seg = cc_seg == i_cc

            # Labels found in the second-to-last slice of each axis at the
            # positions where this component touches the last slice
            overlaps = []
            overlaps.extend(np.unique(seg[-2, :, :][bin_cc_seg[-1, :, :]]))
            overlaps.extend(np.unique(seg[:, -2, :][bin_cc_seg[:, -1, :]]))
            overlaps.extend(np.unique(seg[:, :, -2][bin_cc_seg[:, :, -1]]))
            overlaps = np.unique(overlaps)

            linked_l2_ids = overlaps[np.isin(overlaps, unsafe_dict[unsafe_root_id])]

            if len(linked_l2_ids) == 0:
                seg[bin_cc_seg] = 0
            else:
                seg[bin_cc_seg] = linked_l2_ids[0]

                # With more than one linked ID, remember all pairs so they
                # can be merged below (a single ID yields no pairs)
                for i_l2_id in range(len(linked_l2_ids) - 1):
                    for j_l2_id in range(i_l2_id + 1, len(linked_l2_ids)):
                        l2_edges.append(
                            [linked_l2_ids[i_l2_id], linked_l2_ids[j_l2_id]])

        if len(l2_edges) > 0:
            # Collapse each group of linked IDs onto its smallest member
            g = nx.Graph()
            g.add_edges_from(l2_edges)

            ccs = nx.connected_components(g)

            for cc in ccs:
                cc_ids = np.sort(list(cc))
                seg[np.isin(seg, cc_ids[1:])] = cc_ids[0]

    return seg
def l2_skeleton(root_id, refine=False, drop_missing=True, threads=10,
                progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id  :      int | list of ints
                    Root ID(s) of the flywire neuron(s) you want to
                    skeletonize.
    refine :        bool
                    If True, will refine skeleton nodes by moving them in
                    the center of their corresponding chunk meshes.
    drop_missing :  bool
                    Only relevant if ``refine=True``. If True, will drop
                    nodes that don't have a corresponding chunk mesh. These
                    are typically chunks that are very small and dropping
                    them might actually be beneficial.
    threads :       int
                    Only relevant if ``refine=True`` and
                    ``use_flycache=False``. How many parallel threads to use
                    for fetching the chunk meshes. Reduce the number if you
                    run into ``HTTPErrors``.
    progress :      bool
                    Whether to show a progress bar.
    dataset :       str
                    Dataset to query. "production" and "sandbox" are
                    translated to their datastack names; any other value is
                    passed on unchanged.
    **kwargs
                    ``use_flycache`` (bool, default False): if True, fetch
                    L2 centroids from the fly cache when refining. Only
                    supported for the production dataset.

    Returns
    -------
    skeleton :      navis.TreeNeuron
                    The extracted skeleton. For a list of root IDs a
                    ``navis.NeuronList`` is returned.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    """
    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma

    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        nl = []
        for rid in navis.config.tqdm(root_id, desc='Skeletonizing',
                                     disable=not progress, leave=False):
            # Forward **kwargs too - otherwise `use_flycache` is silently
            # reset to its default for every neuron of a list
            n = l2_skeleton(rid, refine=refine, drop_missing=drop_missing,
                            threads=threads, progress=progress,
                            dataset=dataset, **kwargs)
            nl.append(n)
        return navis.NeuronList(nl)

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N,2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l)) for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchal tree by removing cycles and
    # ensuring all edges point towards a root
    # (the helpers live in different modules depending on the major version)
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidean space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    # Place each node at the center of its chunk
    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids,
                                                        token=token,
                                                        progress=progress)

            # Drop missing (i.e. [0,0,0]) meshes
            centroids = {k: v for k, v in centroids.items() if v != [0, 0, 0]}
        else:
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads,
                                         progress=progress)

        # Key the centroids by node index instead of L2 ID
        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
        swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
        swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
    else:
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    return tn
def download(
    self, bbox, mip=None, parallel=None,
    segids=None, preserve_zeros=False,
    agglomerate=False, timestamp=None,
    stop_layer=None
):
    """
    Fetch a cutout of the base segmentation, optionally relabelled with
    information from the graph server.

    bbox: specifies cutout to fetch
    mip: which resolution level to get (default self.mip)
    parallel: what parallel level to use (default self.parallel)

    agglomerate: if true, remap all watershed ids in the volume
      and return a flat segmentation.

    if agglomerate is true these options are available:

    timestamp: (agglomerate only) get the roots from this date and time
      formats accepted:
        int: unix timestamp
        datetime: self explanatory
        string: ISO 8601 date
    stop_layer: (agglomerate only) (int) if specified, return the lowest
      parent at or above that layer. If not specified, go all the way
      to the root id.
        Layer 1: Watershed
        Layer 2: Within-Chunk Agglomeration
        Layer 2+: Between chunk interconnections (skip connections possible)

    if agglomerate is false, these other options come into play:

    segids: agglomerate the leaves of these segids from the graph
      server and label them with the given segid.
    preserve_zeros: If segids is not None:
      False: mask other segids with zero
      True: mask other segids with the largest integer value
        contained by the image data type and leave zero as is.

    Returns: img as a VolumeCutout
    """
    # A single point is treated as a one voxel cutout
    if type(bbox) is Vec:
        bbox = Bbox(bbox, bbox+1)

    bbox = Bbox.create(
        bbox, context=self.bounds,
        bounded=self.bounded, autocrop=self.autocrop
    )

    if bbox.subvoxel():
        raise exceptions.EmptyRequestException("Requested {} is smaller than a voxel.".format(bbox))

    mip = self.mip if mip is None else mip

    # Only ever necessary to make requests within the bounding box
    # to the server. We can fill black in other situations.
    full_res_bbox = self.bbox_to_mip(bbox, mip=mip, to_mip=0)
    full_res_bbox = bbox.intersection(self.meta.bounds(0), full_res_bbox)

    cutout = super(CloudVolumeGraphene, self).download(bbox, mip=mip, parallel=parallel)

    if agglomerate:
        cutout = self.agglomerate_cutout(cutout, timestamp=timestamp, stop_layer=stop_layer)
        return VolumeCutout.from_volume(self.meta, mip, cutout, bbox)

    if segids is None:
        return cutout

    segids = list(toiter(segids))

    # leaf label -> requested segid it belongs to
    leaf_to_segid = {}
    for segid in segids:
        for leaf in self.get_leaves(segid, full_res_bbox, 0):
            leaf_to_segid[leaf] = segid

    cutout = fastremap.remap(cutout, leaf_to_segid, preserve_missing_labels=True, in_place=True)

    if preserve_zeros:
        # Zero stays untouched; everything else that was not requested gets
        # the largest representable value (infinity for non-integer dtypes)
        if np.issubdtype(self.dtype, np.integer):
            mask_value = np.iinfo(self.dtype).max
        else:
            mask_value = np.inf
        segids.append(0)
    else:
        mask_value = 0

    cutout = fastremap.mask_except(cutout, segids, in_place=True, value=mask_value)

    return VolumeCutout.from_volume(self.meta, mip, cutout, bbox)
def agglomerate_cutout(self, img, timestamp=None, stop_layer=None):
    """
    Relabel a graphene cutout with the roots of its labels, which
    creates a flat segmentation.

    The unique labels in img are looked up via self.get_roots (timestamp
    and stop_layer are passed through) and img is remapped in place.
    """
    unique_labels = fastremap.unique(img)
    root_ids = self.get_roots(
        unique_labels, timestamp=timestamp,
        binary=True, stop_layer=stop_layer
    )
    label_to_root = dict(zip(unique_labels, root_ids))
    return fastremap.remap(img, label_to_root, preserve_missing_labels=True, in_place=True)
def by_vertex_clusters(mesh, sampling_dist, cluster_pos='median', progress=True):
    """Skeletonize a (contracted) mesh by clustering vertices.

    The algorithm traverses the mesh graph and groups vertices together that
    are within a given distance to each other. This uses the geodesic
    (along-the-mesh) distance, not simply the Euclidean distance. Subsequently
    these groups of vertices are collapsed and re-connected respecting the
    topology of the input mesh.

    The graph traversal is fast and scales well, so this method is well suited
    for meshes with lots of vertices. On the downside: this implementation is
    not very clever and you might have to play around with the parameters
    (mostly ``sampling_dist``) to get decent results.

    Parameters
    ----------
    mesh :          mesh obj
                    The mesh to be skeletonized. Can be an object that has
                    ``.vertices`` and ``.faces`` properties (e.g. a
                    trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                    dictionary ``{'vertices': vertices, 'faces': faces}``.
    sampling_dist : float | int
                    Maximal distance at which vertices are clustered. This
                    parameter should be tuned based on the resolution of your
                    mesh (see Examples).
    cluster_pos :   "median" | "center"
                    How to determine the x/y/z coordinates of the collapsed
                    vertex clusters (i.e. the skeleton's nodes)::

                      - "median": Use the vertex closest to the cluster's
                        center of mass (closest by summed absolute
                        coordinate difference).
                      - "center": Use the center of mass. This makes for
                        smoother skeletons but can lead to nodes outside
                        the mesh.
    progress :      bool
                    If True, will show progress bar.

    Returns
    -------
    skeletor.Skeleton
                    Holds results of the skeletonization and enables quick
                    visualization. Its ``mesh_map`` has one entry per mesh
                    vertex giving the skeleton node that vertex was collapsed
                    into; vertices that are not part of any mesh edge (and
                    hence of no cluster) are marked with ``-1``.

    Examples
    --------
    ``sampling_dist`` is in the units of the mesh; the value below is only a
    placeholder and needs to be adjusted to your data.

    >>> import skeletor as sk
    >>> mesh = sk.example_mesh()
    >>> cont = sk.pre.contract(mesh, epsilon=0.1)
    >>> skel = sk.skeletonize.by_vertex_clusters(cont, sampling_dist=100)
    >>> skel.mesh = mesh

    """
    assert cluster_pos in ['center', 'median']

    mesh = make_trimesh(mesh, validate=False)

    # Produce weighted edges: one (source, target, length) row per mesh edge
    edges = np.concatenate((mesh.edges_unique,
                            mesh.edges_unique_length.reshape(mesh.edges_unique.shape[0], 1)),
                           axis=1)

    # Generate Graph (must be undirected)
    G = nx.Graph()
    G.add_weighted_edges_from(edges)

    # Run the graph traversal that groups vertices into spatial clusters
    not_visited = set(G.nodes)
    seen = set()
    clusters = []
    to_visit = len(not_visited)
    with tqdm(desc='Clustering', total=len(not_visited), disable=progress is False) as pbar:
        while not_visited:
            # Pick an arbitrary node that has not been assigned yet
            start = not_visited.pop()
            # Get all nodes in the geodesic vicinity
            cl, seen = dfs(G, n=start, dist_traveled=0,
                           max_dist=sampling_dist, seen=seen)
            cl = set(cl)

            # Append this cluster and track visited/not-visited nodes
            clusters.append(cl)
            not_visited = not_visited - cl

            # Update progress bar by the number of nodes just consumed
            pbar.update(to_visit - len(not_visited))
            to_visit = len(not_visited)

    # `clusters` is a list of sets -> let's turn it into list of arrays.
    # The cast is required because the graph was built from a float array.
    clusters = [np.array(list(c)).astype(int) for c in clusters]

    # Get positions of clusters
    if cluster_pos == 'center':
        # Get the center of each cluster
        cl_coords = np.array([np.mean(mesh.vertices[c], axis=0) for c in clusters])
    elif cluster_pos == 'median':
        # Get the vertex that's closest to the cluster's center
        cl_coords = []
        for c in clusters:
            cnt = np.mean(mesh.vertices[c], axis=0)
            cnt_dist = np.sum(np.fabs(mesh.vertices[c] - cnt), axis=1)
            median = mesh.vertices[c][np.argmin(cnt_dist)]
            cl_coords.append(median)
        cl_coords = np.array(cl_coords)

    # Generate edges between clusters by relabeling every mesh edge with the
    # cluster indices of its two vertices
    cl_edges = np.array(mesh.edges_unique)
    if fastremap:
        mapping = {n: i for i, members in enumerate(clusters) for n in members}
        cl_edges = fastremap.remap(cl_edges, mapping,
                                   preserve_missing_labels=False, in_place=True)
    else:
        # Relabel through a lookup table. Overwriting `cl_edges` in place
        # cluster by cluster is not safe: a cluster index written in an
        # earlier pass can equal a vertex id that belongs to a later cluster
        # and would then be relabeled a second time.
        vertex_to_cluster = np.full(len(mesh.vertices), -1, dtype=cl_edges.dtype)
        for i, c in enumerate(clusters):
            vertex_to_cluster[c] = i
        cl_edges = vertex_to_cluster[cl_edges]

    # Remove directionality from cluster edges
    cl_edges = np.sort(cl_edges, axis=1)

    # Get unique edges
    cl_edges = np.unique(cl_edges, axis=0)

    # Calculate edge lengths
    co1 = cl_coords[cl_edges[:, 0]]
    co2 = cl_coords[cl_edges[:, 1]]
    cl_edge_lengths = np.sqrt(np.sum((co1 - co2)**2, axis=1))

    # Produce adjacency matrix from edges and edge lengths
    n_clusters = len(clusters)
    adj = scipy.sparse.coo_matrix((cl_edge_lengths,
                                   (cl_edges[:, 0], cl_edges[:, 1])),
                                  shape=(n_clusters, n_clusters))

    # The cluster graph likely still contain cycles, let's get rid of them using
    # a minimum spanning tree
    mst = scipy.sparse.csgraph.minimum_spanning_tree(adj, overwrite=True)

    # Turn into COO matrix
    coo = mst.tocoo()

    # Extract edge list
    edges = np.array([coo.row, coo.col]).T

    # Produce final graph - this also takes care of some fixing
    G = edges_to_graph(edges,
                       nodes=np.unique(cl_edges.flatten()),
                       drop_disconnected=False,
                       fix_tree=True)

    # Generate SWC; `new_ids` translates cluster index -> re-indexed node ID
    swc, new_ids = make_swc(G, cl_coords, reindex=True, validate=False)

    # Generate the mesh vertex -> skeleton node map. It has to be indexed by
    # mesh vertex ID, so each cluster's node ID is written at the positions
    # of the cluster's member vertices. Vertices without a cluster stay -1.
    vertex_to_node_map = np.full(len(mesh.vertices), -1, dtype=int)
    for i, c in enumerate(clusters):
        if len(c):
            vertex_to_node_map[c] = new_ids[i]

    return Skeleton(swc=swc, mesh=mesh, mesh_map=vertex_to_node_map,
                    method='vertex_clusters')
def collapse_zero_length_edges(vertices, edges, root, radius, mesh_to_skel_map,
                               mesh_index, node_mask, vertex_properties=None):
    """Remove zero length edges from a skeleton.

    Edges whose two end points have identical coordinates are collapsed: the
    first vertex of such an edge is merged into the second one, the remaining
    vertices are re-indexed and all per-vertex inputs are filtered to match.

    Parameters
    ----------
    vertices : np.ndarray
        Vertex coordinates, one row per vertex.
    edges : np.ndarray
        Two-column array of vertex indices.
    root : int
        Index of the root vertex.
    radius : np.ndarray or None
        Per-vertex values, filtered like ``vertices`` if given.
    mesh_to_skel_map : np.ndarray or None
        Map holding skeleton vertex indices, re-indexed if given. The value
        ``-1`` is passed through unchanged.
    mesh_index : np.ndarray or None
        Per-vertex values, filtered like ``vertices`` if given.
    node_mask : np.ndarray or None
        Per-vertex values, filtered like ``vertices`` if given.
    vertex_properties : dict, optional
        Name -> per-vertex values. Entries that cannot be indexed with the
        vertex filter are silently left out of the result. Defaults to an
        empty dictionary.

    Returns
    -------
    tuple
        ``(vertices, edges, root, radius, mesh_to_skel_map, mesh_index,
        node_mask, vertex_properties)`` after collapsing. If there are no
        zero length edges the inputs are returned as they were passed in.
    """
    # A fresh dict per call: the early return below hands this object back to
    # the caller, so a shared default would leak mutations between calls.
    if vertex_properties is None:
        vertex_properties = {}

    zl = np.linalg.norm(vertices[edges[:, 0]] - vertices[edges[:, 1]], axis=1) == 0
    if not np.any(zl):
        return (vertices, edges, root, radius, mesh_to_skel_map,
                mesh_index, node_mask, vertex_properties)

    # vertex to be removed -> vertex it is merged into
    consolidate_dict = {x[0]: x[1] for x in edges[zl]}

    # Compress multiple zero edges in a row: keep going for as long as some
    # key also shows up as a merge target
    while True:
        all_keys = np.array(list(consolidate_dict.keys()))
        all_vals = np.array(list(consolidate_dict.values()))
        dup_keys = np.flatnonzero(np.isin(all_keys, all_vals))
        if dup_keys.size == 0:
            break
        first_key = all_keys[dup_keys[0]]
        # NOTE(review): the pop also removes `first_key`'s own entry, so that
        # vertex is no longer marked for merging afterwards - confirm this is
        # the intended outcome for chains of zero length edges.
        first_val = consolidate_dict.pop(first_key)
        # Only values of existing keys change here, which is safe while
        # iterating over the dictionary
        for ii, jj in consolidate_dict.items():
            if jj == first_key:
                consolidate_dict[ii] = first_val

    new_index_dict, node_filter = remap_dict(len(vertices), consolidate_dict)

    new_vertices = vertices[node_filter]
    new_edges = fastremap.remap(edges, new_index_dict)
    # Drop edges that now start and end on the same vertex
    new_edges = new_edges[new_edges[:, 0] != new_edges[:, 1]]

    if mesh_to_skel_map is not None:
        # Let -1 entries pass through the remapping untouched
        new_index_dict[-1] = -1
        new_mesh_to_skel_map = fastremap.remap(mesh_to_skel_map, new_index_dict)
    else:
        new_mesh_to_skel_map = None

    new_root = new_index_dict.get(root, root)

    new_radius = radius[node_filter] if radius is not None else None
    new_mesh_index = mesh_index[node_filter] if mesh_index is not None else None
    new_node_mask = node_mask[node_filter] if node_mask is not None else None

    new_vp = {}
    for vp, val in vertex_properties.items():
        try:
            new_vp[vp] = val[node_filter]
        except Exception:
            # Best effort: properties that cannot be filtered per vertex are
            # dropped. Interpreter-level signals such as KeyboardInterrupt
            # are deliberately not caught.
            pass

    return (new_vertices, new_edges, new_root, new_radius,
            new_mesh_to_skel_map, new_mesh_index, new_node_mask, new_vp)
def download(self, bbox, mip=None, parallel=None, segids=None, preserve_zeros=False):
    """
    Fetch a cutout of the base segmentation and optionally relabel the
    leaves of selected segids using information from the graph server.

    bbox: region to fetch
    mip: resolution level to read (self.mip is used when None)
    parallel: parallelism level, forwarded to the parent class download
    segids: for each of these ids, fetch its leaves from the graph server
      and label them with that id.
    preserve_zeros: only relevant when segids is not None.
      False: every other label is masked with zero.
      True: every other label is masked with the largest value of an
        integer image dtype (np.inf for non-integer dtypes) and zero is
        left untouched.

    Raises: exceptions.EmptyRequestException if the resolved bounding box
      is smaller than a voxel.

    Returns: the result of the parent class download when segids is None,
      otherwise the relabeled image wrapped with VolumeCutout.from_volume.
    """
    if type(bbox) is Vec:
        # A bare point is turned into the unit box that starts at it.
        bbox = Bbox(bbox, bbox + 1)

    bbox = Bbox.create(
        bbox,
        context=self.bounds,
        bounded=self.bounded,
        autocrop=self.autocrop,
    )

    if bbox.subvoxel():
        message = "Requested {} is smaller than a voxel.".format(bbox)
        raise exceptions.EmptyRequestException(message)

    mip = self.mip if mip is None else mip

    # The graph server only ever needs to be queried inside the volume's
    # mip 0 bounds; anything outside of them can be filled with black.
    leaf_bbox = self.bbox_to_mip(bbox, mip=mip, to_mip=0)
    leaf_bbox = bbox.intersection(self.meta.bounds(0), leaf_bbox)

    image = super(CloudVolumeGraphene, self).download(bbox, mip=mip, parallel=parallel)

    if segids is None:
        return image

    keep = list(toiter(segids))

    # leaf id -> requested segid; a leaf claimed by several segids ends up
    # with the one that comes last in `keep`.
    relabel = {}
    for segid in keep:
        for leaf in self.get_leaves(segid, leaf_bbox, 0):
            relabel[leaf] = segid

    image = fastremap.remap(image, relabel, preserve_missing_labels=True, in_place=True)

    if not preserve_zeros:
        fill = 0
    else:
        if np.issubdtype(self.dtype, np.integer):
            fill = np.iinfo(self.dtype).max
        else:
            fill = np.inf
        # Zero joins the list of labels that survive the masking step.
        keep.append(0)

    image = fastremap.mask_except(image, keep, in_place=True, value=fill)

    return VolumeCutout.from_volume(self.meta, mip, image, bbox)