Example #1
def by_vertex_clusters(mesh, sampling_dist, cluster_pos='median',
                       output='swc', vertex_map=False,
                       drop_disconnected=False, progress=True):
    """Skeletonize a contracted mesh by clustering vertices.

    Notes
    -----
    This algorithm traverses the graph and groups together vertices that are
    within a given distance of each other. This uses the geodesic
    (along-the-mesh) distance, not simply the Euclidean distance. Subsequently
    these groups of vertices are collapsed and re-connected respecting the
    topology of the input mesh.

    The graph traversal is fast and scales well, so this method is well suited
    for meshes with lots of vertices. On the downside: this implementation is
    not very clever and you might have to play around with the parameters
    (mostly ``sampling_dist``) to get decent results.

    Parameters
    ----------
    mesh :          mesh obj
                    The mesh to be skeletonized. Can be an object that has
                    ``.vertices`` and ``.faces`` properties (e.g. a
                    trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                    dictionary ``{'vertices': vertices, 'faces': faces}``.
    sampling_dist : float | int
                    Maximal distance at which vertices are clustered. This
                    parameter should be tuned based on the resolution of your
                    mesh (see Examples).
    cluster_pos :   "median" | "center"
                    How to determine the x/y/z coordinates of the collapsed
                    vertex clusters (i.e. the skeleton's nodes)::

                      - "median": Use the vertex closest to cluster's center of
                        mass.
                      - "center": Use the center of mass. This makes for smoother
                        skeletons but can lead to nodes outside the mesh.
    vertex_map :    bool
                    If True, we will add a "vertex_id" property to the graph and
                    a column to the SWC table that maps each cluster ID to its
                    first vertex in the original mesh.
    output :        "swc" | "graph" | "both"
                    Determines the function's output. See ``Returns``.
    drop_disconnected : bool
                    If True, will drop disconnected nodes from the skeleton.
                    Note that this might result in empty skeletons.
    progress :      bool
                    If True, will show progress bar.

    Returns
    -------
    "swc" :         pandas.DataFrame
                    SWC representation of the skeleton.
    "graph" :       networkx.Graph
                    Graph representation of the skeleton.
    "both" :        tuple
                    Both of the above: ``(swc, graph)``.

    """
    assert output in ['swc', 'graph', 'both']
    assert cluster_pos in ['center', 'median']

    mesh = make_trimesh(mesh, validate=False)

    # Produce weighted edges
    edges = np.concatenate((mesh.edges_unique,
                            mesh.edges_unique_length.reshape(mesh.edges_unique.shape[0], 1)),
                           axis=1)

    # Generate Graph (must be undirected)
    G = nx.Graph()
    G.add_weighted_edges_from(edges)

    # Run the graph traversal that groups vertices into spatial clusters
    not_visited = set(G.nodes)
    seen = set()
    clusters = []
    to_visit = len(not_visited)
    with tqdm(desc='Clustering', total=len(not_visited), disable=progress is False) as pbar:
        while not_visited:
            # Pick a random node
            start = not_visited.pop()
            # Get all nodes in the geodesic vicinity
            cl, seen = dfs(G, n=start, dist_traveled=0,
                           max_dist=sampling_dist, seen=seen)
            cl = set(cl)

            # Append this cluster and track visited/not-visited nodes
            clusters.append(cl)
            not_visited = not_visited - cl

            # Update  progress bar
            pbar.update(to_visit - len(not_visited))
            to_visit = len(not_visited)

    # `clusters` is a list of sets -> let's turn it into list of arrays
    clusters = [np.array(list(c)).astype(int) for c in clusters]

    # Get positions of clusters
    if cluster_pos == 'center':
        # Get the center of each cluster
        cl_coords = np.array([np.mean(mesh.vertices[c], axis=0) for c in clusters])
    elif cluster_pos == 'median':
        # Get the node that's closest to the cluster's center
        cl_coords = []
        for c in clusters:
            cnt = np.mean(mesh.vertices[c], axis=0)
            cnt_dist = np.sum(np.fabs(mesh.vertices[c] - cnt), axis=1)
            median = mesh.vertices[c][np.argmin(cnt_dist)]
            cl_coords.append(median)
        cl_coords = np.array(cl_coords)

    # Generate edges
    cl_edges = np.array(mesh.edges_unique)
    if fastremap:
        mapping = {n: i for i, l in enumerate(clusters) for n in l}
        cl_edges = fastremap.remap(cl_edges, mapping, preserve_missing_labels=False, in_place=True)
    else:
        for i, c in enumerate(clusters):
            cl_edges[np.isin(cl_edges, c)] = i

    # Remove directionality from cluster edges
    cl_edges = np.sort(cl_edges, axis=1)

    # Get unique edges
    cl_edges = np.unique(cl_edges, axis=0)

    # Calculate edge lengths
    co1 = cl_coords[cl_edges[:, 0]]
    co2 = cl_coords[cl_edges[:, 1]]
    cl_edge_lengths = np.sqrt(np.sum((co1 - co2)**2, axis=1))

    # Produce adjacency matrix from edges and edge lengths
    n_clusters = len(clusters)
    adj = scipy.sparse.coo_matrix((cl_edge_lengths,
                                   (cl_edges[:, 0], cl_edges[:, 1])),
                                  shape=(n_clusters, n_clusters))

    # The cluster graph likely still contains cycles; let's get rid of them using
    # a minimum spanning tree
    mst = scipy.sparse.csgraph.minimum_spanning_tree(adj,
                                                     overwrite=True)

    # Turn into COO matrix
    coo = mst.tocoo()

    # Extract edge list
    edges = np.array([coo.row, coo.col]).T

    # Produce final graph - this also takes care of some fixing
    G = edges_to_graph(edges, nodes=np.unique(cl_edges.flatten()),
                       drop_disconnected=drop_disconnected, fix_tree=True)

    # At this point nodes are labeled by index of the cluster
    # Let's give them a "vertex_id" property mapping back to the
    # first vertex in that cluster
    if vertex_map:
        mapping = {i: l[0] for i, l in enumerate(clusters)}
        nx.set_node_attributes(G, mapping, name="vertex_id")

    if output == 'graph':
        return G

    # Generate SWC
    swc = make_swc(G, cl_coords)

    # Add vertex ID column if requested
    if vertex_map:
        swc['vertex_id'] = swc.node_id.map(mapping)

    if output == 'both':
        return swc, G

    return swc
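
# A minimal toy sketch (illustration only, not part of the source above) of the
# cluster-edge remapping step: every mesh vertex is assigned to a cluster index
# and the mesh edges are rewritten in terms of those indices. The fastremap path
# and the np.isin fallback are expected to agree.
import numpy as np
import fastremap

toy_clusters = [np.array([0, 1]), np.array([2, 3]), np.array([4])]
toy_edges = np.array([[0, 2], [1, 2], [3, 4]])

toy_mapping = {n: i for i, cl in enumerate(toy_clusters) for n in cl}
via_fastremap = fastremap.remap(toy_edges.copy(), toy_mapping,
                                preserve_missing_labels=False)

via_isin = toy_edges.copy()
for i, cl in enumerate(toy_clusters):
    via_isin[np.isin(via_isin, cl)] = i

assert np.array_equal(via_fastremap, via_isin)  # both: [[0, 1], [0, 1], [1, 2]]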
Example #2
def perform_remap(a, relabel_map):
  remapped_a = fastremap.remap(a, relabel_map, preserve_missing_labels=True) 
  return remapped_a
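
# A toy call of the wrapper above (illustration only): with
# preserve_missing_labels=True, labels that are absent from relabel_map are
# left untouched instead of raising a KeyError.
import numpy as np

labels = np.array([[1, 1, 2],
                   [2, 3, 5]], dtype=np.uint32)
print(perform_remap(labels, {1: 10, 2: 20}))
# -> [[10 10 20]
#     [20  3  5]]   (3 and 5 are not in the map and stay unchanged)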
Example #3
def by_edge_collapse(mesh, shape_weight=1, sample_weight=0.1, output='swc',
                     drop_disconnected=False, progress=True):
    """Skeletonize a (contracted) mesh by collapsing edges.

    Notes
    -----
    This algorithm (described in [1]) iteratively collapses edges that are part
    of a face until no more faces are left. Edges are chosen based on a cost
    function that penalizes collapses that would change the shape of the object
    or would introduce long edges.

    This is somewhat sensitive to the dimensions of the input mesh: too large
    and you might experience slow-downs or numpy OverflowErrors; too small and
    you might get skeletons that don't quite match the mesh (e.g. too few nodes).
    If you experience either, try down- or up-scaling your mesh, respectively.

    Parameters
    ----------
    mesh :          mesh obj
                    The mesh to be skeletonized. Can be an object that has
                    ``.vertices`` and ``.faces`` properties (e.g. a
                    trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                    dictionary ``{'vertices': vertices, 'faces': faces}``.
    shape_weight :  float, optional
                    Weight for shape costs which penalize collapsing edges that
                    would drastically change the shape of the object.
    sample_weight : float, optional
                    Weight for sampling costs which penalize collapses that
                    would generate prohibitively long edges.
    output :        "swc" | "graph" | "both"
                    Determines the function's output. See ``Returns``.
    drop_disconnected : bool
                    If True, will drop disconnected nodes from the skeleton.
                    Note that this might result in empty skeletons.
    progress :      bool
                    If True, will show progress bar.

    Returns
    -------
    "swc" :         pandas.DataFrame
                    SWC representation of the skeleton.
    "graph" :       networkx.Graph
                    Graph representation of the skeleton.
    "both" :        tuple
                    Both of the above: ``(swc, graph)``.

    References
    ----------
    [1] Au OK, Tai CL, Chu HK, Cohen-Or D, Lee TY. Skeleton extraction by mesh
        contraction. ACM Transactions on Graphics (TOG). 2008 Aug 1;27(3):44.

    """
    assert output in ['swc', 'graph', 'both']

    mesh = make_trimesh(mesh, validate=False)

    # Shorthand faces and edges
    # We convert to arrays to (a) make a copy and (b) remove potential overhead
    # from these originally being trimesh TrackedArrays
    edges = np.array(mesh.edges_unique)
    verts = np.array(mesh.vertices)

    # For cost calculations we will normalise coordinates
    # This prevents getting ridiculously large cost values (on the order of 1e300)
    # verts = (verts - verts.min()) / (verts.max() - verts.min())
    edge_lengths = np.sqrt(np.sum((verts[edges[:, 0]] - verts[edges[:, 1]])**2, axis=1))

    # Get a list of faces: [(edge1, edge2, edge3), ...]
    face_edges = np.array(mesh.faces_unique_edges)
    # Make sure these faces are unique, i.e. no [(e1, e2, e3), (e3, e2, e1)]
    face_edges = np.unique(np.sort(face_edges, axis=1), axis=0)

    # Shape cost initialisation:
    # Each vertex has a matrix Q which is used to determine the shape cost
    # of collapsing each node. We need to generate a matrix (Q) for each
    # vertex, then when we collapse two nodes, we can update using
    # Qj <- Qi + Qj, so the edges previously associated with vertex i
    # are now associated with vertex j.

    # For each edge, generate a matrix (K). K is made up of two sets of
    # coordinates in 3D space, a and b. a is the normalised edge vector
    # of edge(i,j) and b = a * <x/y/z coordinates of vertex i>
    #
    # The matrix K takes the form:
    #
    #        Kij = 0, -az, ay, -bx
    #              az, 0, -ax, -by
    #             -ay, ax, 0,  -bz

    edge_co0, edge_co1 = verts[edges[:, 0]], verts[edges[:, 1]]
    a = (edge_co1 - edge_co0) / edge_lengths.reshape(edges.shape[0], 1)
    # Note: It's a bit unclear to me whether the normalised edge vector should
    # be allowed to have negative values but I seem to be getting better
    # results if I use absolute values
    a = np.fabs(a)
    b = a * edge_co0

    # Bunch of zeros
    zero = np.zeros(a.shape[0])

    # Generate matrix K
    K = [[zero,    -a[:, 2], a[:, 1],    -b[:, 0]],
         [a[:, 2],  zero,    -a[:, 0],   -b[:, 1]],
         [-a[:, 1], a[:, 0], zero,       -b[:, 2]]]
    K = np.array(K)

    # Q for vertex i is then the sum of the products of (kT,k) for ALL edges
    # connected to vertex i:
    # Initialize matrix of correct shape
    Q_array = np.zeros((4, 4, verts.shape[0]), dtype=np.float128)

    # Generate (kT, K)
    kT = np.transpose(K, axes=(1, 0, 2))

    # To get the sum of the products in the correct format we have to
    # do some annoying transposes to get to (4, 4, len(edges))
    K_dot = np.matmul(K.T, kT.T).T

    # Iterate over all vertices
    for v in range(len(verts)):
        # Find edges that contain this vertex
        cond1 = edges[:, 0] == v
        cond2 = edges[:, 1] == v
        # Note that this does not take directionality of edges into account
        # Not sure if that's intended?

        # Get indices of these edges
        indices = np.where(cond1 | cond2)[0]

        # Get the products for all edges adjacent to this vertex
        Q = K_dot[:, :, indices]
        # Sum over all edges
        Q = Q.sum(axis=2)
        # Add to Q array
        Q_array[:, :, v] = Q

    # Not sure if we are doing something wrong when calculating the Q array but
    # we end up having negative values which translate into negative scores.
    # This in turn is bad because we propagate that negative score when
    # collapsing edges which leads to a "zipper-effect" where nodes collapse
    # in sequence a->b->c->d until they hit some node with really high cost
    # Q_array -= Q_array.min()

    # Edge collapse:
    # Determining which edge to collapse is a weighted sum of the shape and
    # sampling cost. The shape cost of vertex i is Fa(p) = pT Qi p where p is
    # the coordinates of point p (vertex i here) in homogeneous representation.
    # The variable w from above is the value for the homogeneous 4th dimension.
    # T denotes transpose of matrix.
    # The shape cost of collapsing the edge Fa(i,j) = Fi(vj) + Fj(vj).
    # vi and vj being the coordinates of the vertex in homogeneous representation
    # (p in equation before)
    # The sampling cost penalises edge collapses that generate overly long edges,
    # based on the distance traveled by all edges to vi, when vi is merged with
    # vj. (Eq. 7 in paper)
    # You cannot collapse an edge (i -> j) if k is a common adjacent vertex of
    # both i and j, but (i/j/k) is not a face.
    # We will set the cost of these edges to infinity.

    # Now work out the shape cost of collapsing each node (eq. 7)
    # First get coordinates of the first node of each edge
    # Note that in Nik's implementation this was the second node
    p = verts[edges[:, 0]]

    # Append weight factor
    w = 1
    p = np.append(p, np.full((p.shape[0], 1), w), axis=1)

    this_Q1 = Q_array[:, :, edges[:, 0]]
    this_Q2 = Q_array[:, :, edges[:, 1]]

    F1 = np.einsum('ij,kji->ij', p, this_Q1)[:, [0, 1]]
    F2 = np.einsum('ij,kji->ij', p, this_Q2)[:, [0, 1]]

    # Calculate and append shape cost
    F = np.append(F1, F2, axis=1)
    shape_cost = np.sum(F, axis=1)

    # Sum lengths of all edges associated with a given vertex
    # This is easiest by generating a sparse matrix from the edges
    # and then summing by row
    adj = scipy.sparse.coo_matrix((edge_lengths,
                                   (edges[:, 0], edges[:, 1])),
                                  shape=(verts.shape[0], verts.shape[0]))

    # This makes sure the matrix is symmetrical, i.e. a->b == a<-b
    # Note that I'm not sure whether this is strictly necessary but it really
    # can't hurt
    adj = adj + adj.T

    # Get the lengths associated with each vertex
    verts_lengths = adj.sum(axis=1)

    # We need to flatten this (something funny with summing sparse matrices)
    verts_lengths = np.array(verts_lengths).flatten()

    # Map the sum of vertex lengths onto edges (as per first vertex in edge)
    ik_edge = verts_lengths[edges[:, 0]]

    # Calculate sampling cost
    sample_cost = edge_lengths * (ik_edge - edge_lengths)

    # Determine which edge to collapse and collapse it
    # Total Cost - weighted sum of shape and sample cost, equation 8 in paper
    F_T = shape_cost * shape_weight + sample_cost * sample_weight

    # Now start collapsing edges one at a time
    face_count = face_edges.shape[0]  # keep track of face counts for progress bar
    is_collapsed = np.full(edges.shape[0], False)
    keep = np.full(edges.shape[0], False)
    with tqdm(desc='Collapsing edges', total=face_count, disable=progress is False) as pbar:
        while face_edges.size:
            # Uncomment to get a more-or-less random edge collapse
            # F_T[:] = 0

            # Update progress bar
            pbar.update(face_count - face_edges.shape[0])
            face_count = face_edges.shape[0]

            # This has to come at the beginning of the loop
            # Set cost of collapsing edges without faces to infinite
            F_T[keep] = np.inf
            F_T[is_collapsed] = np.inf

            # Get the edge that we want to collapse
            collapse_ix = np.argmin(F_T)
            # Get the vertices this edge connects
            u, v = edges[collapse_ix]
            # Get all edges that contain these vertices:
            # First, edges that are (uv, x)
            connects_uv = np.isin(edges[:, 0], [u, v])
            # Second, check if any (uv, x) edges are (uv, uv)
            connects_uv[connects_uv] = np.isin(edges[connects_uv, 1], [u, v])

            # Remove uu and vv edges
            uuvv = edges[:, 0] == edges[:, 1]
            connects_uv = connects_uv & ~uuvv
            # Get the edge's indices
            clps_edges = np.where(connects_uv)[0]

            # Now find the faces the collapsed edge is part of
            # Note: splitting this into three conditions is marginally faster than
            # np.any(np.isin(face_edges, clps_edges), axis=1)
            uv0 = np.isin(face_edges[:, 0], clps_edges)
            uv1 = np.isin(face_edges[:, 1], clps_edges)
            uv2 = np.isin(face_edges[:, 2], clps_edges)
            has_uv = uv0 | uv1 | uv2

            # If these edges do not have adjacent faces anymore
            if not np.any(has_uv):
                # Track this edge as a keeper
                keep[clps_edges] = True
                continue

            # Get the collapsed faces [(e1, e2, e3), ...] for this edge
            clps_faces = face_edges[has_uv]

            # Remove the collapsed faces
            face_edges = face_edges[~has_uv]

            # Track these edges as collapsed
            is_collapsed[clps_edges] = True

            # Get the adjacent edges (i.e. non-uv edges)
            adj_edges = clps_faces[~np.isin(clps_faces, clps_edges)].reshape(clps_faces.shape[0], 2)

            # We have to do some sorting and finding unique edges to make sure
            # remapping is done correctly further down
            # NOTE: Not sure we really need this, so leaving it out for now
            # adj_edges = np.unique(np.sort(adj_edges, axis=1), axis=0)

            # We need to keep track of changes to the adjacent faces
            # Basically each face in (i, j, k) will be reduced to one edge
            # which points from u -> v
            # -> replace occurrences of the losing edge with the winning edge
            for win, loose in adj_edges:
                if fastremap:
                    face_edges = fastremap.remap(face_edges, {loose: win},
                                                 preserve_missing_labels=True,
                                                 in_place=True)
                else:
                    face_edges[face_edges == loose] = win
                is_collapsed[loose] = True

            # Replace occurrences of first node u with second node v
            if fastremap:
                edges = fastremap.remap(edges, {u: v},
                                        preserve_missing_labels=True,
                                        in_place=True)
            else:
                edges[edges == u] = v

            # Add shape cost of u to shape costs of v
            Q_array[:, :, v] += Q_array[:, :, u]

            # Determine which edges require update of costs:
            # In theory we only need to update costs for edges that are
            # associated with vertices v and u (which is now also v)
            has_v = (edges[:, 0] == v) | (edges[:, 1] == v)

            # Uncomment to temporarily force updating costs for all edges
            # has_v[:] = True

            # Update shape costs
            this_Q1 = Q_array[:, :, edges[has_v, 0]]
            this_Q2 = Q_array[:, :, edges[has_v, 1]]

            F1 = np.einsum('ij,kji->ij', p[edges[has_v, 0]], this_Q1)[:, [0, 1]]
            F2 = np.einsum('ij,kji->ij', p[edges[has_v, 1]], this_Q2)[:, [0, 1]]

            F = np.append(F1, F2, axis=1)
            new_shape_cost = np.sum(F, axis=1)

            # Update sum of incoming edge lengths
            # Technically we would have to recalculate lengths of adjacent edges
            # every time but we will take the cheap way out and simply add them up
            verts_lengths[v] += verts_lengths[u]
            # Update sample costs for edges associated with v
            ik_edge = verts_lengths[edges[has_v, 0]]
            new_sample_cost = edge_lengths[has_v] * (ik_edge - edge_lengths[has_v])

            F_T[has_v] = new_shape_cost * shape_weight + new_sample_cost * sample_weight

    # After the edge collapse, the edges are garbled - I have yet to figure out
    # why and whether that can be prevented. However the vertices in those
    # edges are correct and so we just need to reconstruct their connectivity
    # by extracting a minimum spanning tree over the mesh.
    corrected_edges = mst_over_mesh(mesh, edges[keep].flatten())

    # Generate graph
    G = edges_to_graph(corrected_edges, vertices=mesh.vertices, fix_tree=True, weight=False,
                       drop_disconnected=True)

    if output == 'graph':
        return G

    swc = make_swc(G, mesh)

    if output == 'both':
        return (G, swc)

    return swc
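
# A small self-contained sketch (illustration only) of the shape-cost building
# blocks described in the comments above: for a single edge (i, j), K is the
# 3x4 matrix built from the normalised edge vector a and b = a * <coords of
# vertex i>, Q = K.T @ K, and the shape cost of a point p is p^T Q p with p in
# homogeneous coordinates. (The function above additionally takes np.fabs(a).)
import numpy as np

vi = np.array([0.0, 0.0, 0.0])
vj = np.array([1.0, 2.0, 2.0])

a = (vj - vi) / np.linalg.norm(vj - vi)   # normalised edge vector
b = a * vi                                # zeros here, since vi is the origin

K = np.array([[0,     -a[2],  a[1], -b[0]],
              [a[2],   0,    -a[0], -b[1]],
              [-a[1],  a[0],  0,    -b[2]]])

Q = K.T @ K                               # 4x4 shape-cost matrix for this edge

p = np.append(vi, 1.0)                    # homogeneous representation, w = 1
print(p @ Q @ p)                          # 0.0 - vi lies on the line through the edge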
Example #4
def merge_vertices(mesh, dist='auto', inplace=False):
    """Merge vertices closer than a given distance.

    Parameters
    ----------
    mesh :      trimesh.Trimesh
                Mesh to merge vertices on.
    dist :      "auto" | number
                Distance at which to merge vertices. If "auto" will use
                ``mesh.edges_unique_length.mean() / 100``.
    inplace :   bool
                If True will modify the original mesh.

    Returns
    -------
    trimesh.Trimesh

    """
    assert isinstance(mesh, tm.Trimesh)

    if not inplace:
        mesh = mesh.copy()

    # Generate KDTree
    tree = sp.spatial.cKDTree(mesh.vertices)

    if dist == 'auto':
        dist = mesh.edges_unique_length.mean() / 100

    # Query tree
    pairs = tree.query_pairs(dist)

    # Facilitate remapping by removing extra steps: A->B->C to A->C, B->C
    G = nx.Graph()
    G.add_edges_from(pairs)
    mapping = {
        n: list(c)[0]
        for c in nx.connected_components(G) for n in list(c)[1:]
    }

    with mesh._cache:
        # Update faces
        if fastremap:
            mesh.faces = fastremap.remap(mesh.faces,
                                         mapping,
                                         preserve_missing_labels=True,
                                         in_place=True)
        else:
            for k, v in mapping.items():
                mesh.faces[mesh.faces == k] = v

    # Remove dropped vertices
    remove = np.isin(np.arange(mesh.vertices.shape[0]), list(mapping.keys()))
    mesh.update_vertices(~remove)

    # Remove degenerate and duplicate faces
    mesh.remove_degenerate_faces()
    mesh.remove_duplicate_faces()

    # Fix normals
    mesh.fix_normals()

    return mesh
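
# A toy sketch (illustration only) of the mapping built above: pairs of nearby
# vertices are treated as graph edges and every vertex in a connected component
# is mapped to a single representative, so chains like A-B-C collapse in one
# step instead of A->B->C.
import networkx as nx

pairs = [(0, 1), (1, 2), (5, 6)]  # e.g. what tree.query_pairs(dist) might return
G = nx.Graph()
G.add_edges_from(pairs)
mapping = {n: list(c)[0]
           for c in nx.connected_components(G) for n in list(c)[1:]}
# One possible result: {1: 0, 2: 0, 6: 5} - the representative is simply
# whichever vertex comes first in each component's set order.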
Example #5
def collapse_soma_skeleton(
    soma_verts,
    soma_pt,
    verts,
    edges,
    mesh_to_skeleton_map=None,
    collapse_index=None,
    return_filter=False,
    return_soma_ind=False,
):
    """function to adjust skeleton result to move root to soma_pt

    Parameters
    ----------
    soma_pt : numpy.array
        a 3 long vector of xyz locations of the soma (None to just remove duplicates)
    verts : numpy.array
        a Nx3 array of xyz vertex locations
    edges : numpy.array
        a Kx2 array of edges of the skeleton
    soma_d_thresh : float
        distance from soma_pt to collapse skeleton nodes
    mesh_to_skeleton_map : np.array
        a M long array of how each mesh index maps to a skeleton vertex
        (default None).  The function will update this as it collapses vertices to root.
    soma_mesh_indices : np.array
         a K long array of indices in the mesh that should be considered soma
          Any skeleton vertex on these mesh vertices will be collapsed to the root.
    return_filter : bool
        whether to return a list of which skeleton vertices were used in the end
        for the reduced set of skeleton vertices
    only_soma_component : bool
        whether to collapse only the skeleton connected component which is closest to the soma_pt
        (default True)
    return_soma_ind : bool
        whether to return which skeleton index that is the soma_pt

    Returns
    -------
    np.array
        verts, Px3 array of xyz skeleton vertices
    np.array
        edges, Qx2 array of skeleton edges
    (np.array)
        new_mesh_to_skeleton_map, returned if mesh_to_skeleton_map and soma_pt passed
    (np.array)
        used_vertices, if return_filter this contains the indices into the passed verts which the return verts is using
    int
        an index into the returned verts that is the root of the skeleton node, only returned if return_soma_ind is True

    """
    if soma_verts is not None:
        soma_pt_m = soma_pt.reshape(1, 3)
        if collapse_index is None:
            new_verts = np.vstack((verts, soma_pt_m))
            soma_i = verts.shape[0]
        else:
            new_verts = verts
            soma_i = collapse_index
            soma_verts = soma_verts[soma_verts != soma_i]

        edges_m = edges.copy()
        edges_m[np.isin(edges, soma_verts)] = soma_i

        simple_verts, simple_edges = utils.remove_unused_verts(new_verts, edges_m)
        good_edges = ~(simple_edges[:, 0] == simple_edges[:, 1])

        if mesh_to_skeleton_map is not None:
            consolidate_dict = {v: soma_i for v in soma_verts}
            new_index_dict, _ = utils.remap_dict(len(new_verts), consolidate_dict)
            new_index_dict[-1] = -1
            mesh_to_skeleton_map[np.isnan(mesh_to_skeleton_map)] = -1
            new_mesh_to_skeleton_map = fastremap.remap(
                mesh_to_skeleton_map, new_index_dict
            )

        output = [simple_verts, simple_edges[good_edges]]
        if mesh_to_skeleton_map is not None:
            output.append(new_mesh_to_skeleton_map)
        if return_filter:
            used_vertices = np.unique(edges_m.ravel())
            if collapse_index is None:
                # Remove the largest value which is soma_i
                used_vertices = used_vertices[:-1]
            output.append(used_vertices)
        if return_soma_ind:
            output.append(new_index_dict[soma_i])
        return output

    else:
        simple_verts, simple_edges = utils.remove_unused_verts(verts, edges)
        return simple_verts, simple_edges
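
# A toy sketch (illustration only) of the soma-collapse step above: every
# occurrence of a soma vertex in the edge list is replaced by the single soma
# index, which leaves self-edges that are filtered out afterwards.
import numpy as np

toy_edges = np.array([[0, 1], [1, 2], [2, 3]])
toy_soma_verts = np.array([1, 2])
toy_soma_i = 4  # e.g. the index of the appended soma point

edges_m = toy_edges.copy()
edges_m[np.isin(edges_m, toy_soma_verts)] = toy_soma_i
# -> [[0, 4], [4, 4], [4, 3]]; the (4, 4) self-edge is dropped later on.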
Example #6
def remap_segmentation(cv,
                       chunk_x,
                       chunk_y,
                       chunk_z,
                       mip=2,
                       overlap_vx=1,
                       time_stamp=None,
                       progress=False):
    ws_cv = CloudVolume(cv.meta.cloudpath,
                        mip=mip,
                        progress=progress,
                        fill_missing=cv.fill_missing)
    mip_diff = mip - cv.meta.watershed_mip

    mip_chunk_size = np.array(cv.meta.graph_chunk_size,
                              dtype=int) / np.array(
                                  [2**mip_diff, 2**mip_diff, 1])
    mip_chunk_size = mip_chunk_size.astype(int)

    offset = Vec(chunk_x, chunk_y, chunk_z) * mip_chunk_size
    bbx = Bbox(offset, offset + mip_chunk_size + overlap_vx)
    if cv.meta.chunks_start_at_voxel_offset:
        bbx += ws_cv.voxel_offset
    bbx = Bbox.clamp(bbx, ws_cv.bounds)

    seg = ws_cv[bbx][..., 0]

    if not np.any(seg):
        return seg

    sv_remapping, unsafe_dict = get_lx_overlapping_remappings(
        cv,
        chunk_x,
        chunk_y,
        chunk_z,
        time_stamp=time_stamp,
        progress=progress)

    seg = fastremap.mask_except(seg, list(sv_remapping.keys()), in_place=True)
    fastremap.remap(seg,
                    sv_remapping,
                    preserve_missing_labels=True,
                    in_place=True)

    for unsafe_root_id in tqdm(unsafe_dict.keys(),
                               desc="Unsafe Relabel",
                               disable=(not progress)):
        bin_seg = seg == unsafe_root_id

        if np.sum(bin_seg) == 0:
            continue

        l2_edges = []
        cc_seg = cc3d.connected_components(bin_seg)
        for i_cc in range(1, np.max(cc_seg) + 1):
            bin_cc_seg = cc_seg == i_cc

            overlaps = []
            overlaps.extend(np.unique(seg[-2, :, :][bin_cc_seg[-1, :, :]]))
            overlaps.extend(np.unique(seg[:, -2, :][bin_cc_seg[:, -1, :]]))
            overlaps.extend(np.unique(seg[:, :, -2][bin_cc_seg[:, :, -1]]))
            overlaps = np.unique(overlaps)

            linked_l2_ids = overlaps[np.in1d(overlaps,
                                             unsafe_dict[unsafe_root_id])]

            if len(linked_l2_ids) == 0:
                seg[bin_cc_seg] = 0
            elif len(linked_l2_ids) == 1:
                seg[bin_cc_seg] = linked_l2_ids[0]
            else:
                seg[bin_cc_seg] = linked_l2_ids[0]

                for i_l2_id in range(len(linked_l2_ids) - 1):
                    for j_l2_id in range(i_l2_id + 1, len(linked_l2_ids)):
                        l2_edges.append(
                            [linked_l2_ids[i_l2_id], linked_l2_ids[j_l2_id]])

        if len(l2_edges) > 0:
            g = nx.Graph()
            g.add_edges_from(l2_edges)

            ccs = nx.connected_components(g)

            for cc in ccs:
                cc_ids = np.sort(list(cc))
                seg[np.in1d(seg, cc_ids[1:]).reshape(seg.shape)] = cc_ids[0]

    return seg
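
# A toy sketch (illustration only) of the two fastremap calls above: labels
# that are not keys of the mapping are first zeroed out with mask_except, then
# the remaining watershed ids are rewritten to their agglomerated ids.
import numpy as np
import fastremap

toy_seg = np.array([[1, 2], [3, 4]], dtype=np.uint64)
toy_remapping = {1: 10, 2: 10, 4: 20}

toy_seg = fastremap.mask_except(toy_seg, list(toy_remapping.keys()), in_place=True)
# -> [[1, 2], [0, 4]]   (3 is not a key and gets masked to 0)
toy_seg = fastremap.remap(toy_seg, toy_remapping,
                          preserve_missing_labels=True, in_place=True)
# -> [[10, 10], [0, 20]]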
Example #7
def l2_skeleton(root_id, refine=False, drop_missing=True,
                threads=10, progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id  :          int | list of ints
                        Root ID(s) of the flywire neuron(s) you want to
                        skeletonize.
    refine :            bool
                        If True, will refine skeleton nodes by moving them in
                        the center of their corresponding chunk meshes.

    Only relevant if ``refine=True``:

    drop_missing :      bool
                        If True, will drop nodes that don't have a corresponding
                        chunk mesh. These are typically chunks that are very
                        small and dropping them might actually be beneficial.
    threads :           int
                        How many parallel threads to use for fetching the
                        chunk meshes. Reduce the number if you run into
                        ``HTTPErrors``. Only relevant if `use_flycache=False`.
    progress :          bool
                        Whether to show a progress bar.

    Returns
    -------
    skeleton :          navis.TreeNeuron
                        The extracted skeleton.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    """
    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma

    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        nl = []
        for rid in navis.config.tqdm(root_id, desc='Skeletonizing',
                                     disable=not progress, leave=False):
            n = l2_skeleton(rid, refine=refine, drop_missing=drop_missing,
                            threads=threads, progress=progress, dataset=dataset)
            nl.append(n)
        return navis.NeuronList(nl)

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N,2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l)) for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchical tree by removing cycles and
    # ensuring all edges point towards a root
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidean space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids,
                                                        token=token,
                                                        progress=progress)

            # Drop missing (i.e. [0,0,0]) meshes
            centroids = {k: v for k, v in centroids.items() if v != [0, 0, 0]}
        else:
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads, progress=progress)

        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
        swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
        swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
    else:
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    return tn
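
# A toy sketch (illustration only) of the index-compaction step above: the L2
# chunk ids are large uint64 values, so they are mapped to consecutive indices
# 0..N-1 before the graph and SWC table are built.
import numpy as np
import fastremap

toy_l2_eg = np.array([[720575940, 720575943],
                      [720575943, 720575947]], dtype=np.uint64)

toy_l2_ids = np.unique(toy_l2_eg)
toy_l2dict = {l2: ii for ii, l2 in enumerate(toy_l2_ids)}
print(fastremap.remap(toy_l2_eg, toy_l2dict))
# -> [[0 1]
#     [1 2]]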
Example #8
  def download(
    self, bbox, mip=None, 
    parallel=None, segids=None,
    preserve_zeros=False,
    agglomerate=False, timestamp=None,
    stop_layer=None
  ):
    """
    Downloads base segmentation and optionally agglomerates
    labels based on information in the graph server.

    bbox: specifies cutout to fetch
    mip: which resolution level to get (default self.mip)
    parallel: what parallel level to use (default self.parallel)

    agglomerate: if true, remap all watershed ids in the volume
      and return a flat segmentation.

    if agglomerate is true these options are available:

    timestamp: (agglomerate only) get the roots from this date and time
      formats accepted:
        int: unix timestamp
        datetime: self-explanatory
        string: ISO 8601 date
    stop_layer: (agglomerate only) (int) if specified, return the lowest 
      parent at or above that layer. If not specified, go all the way 
      to the root id. 
        Layer 1: Watershed
        Layer 2: Within-Chunk Agglomeration
        Layer 2+: Between chunk interconnections (skip connections possible)

    if agglomerate is false, these other options come into play:

    segids: agglomerate the leaves of these segids from the graph 
      server and label them with the given segid.
    preserve_zeros: If segids is not None:
      False: mask other segids with zero
      True: mask other segids with the largest integer value
        contained by the image data type and leave zero as is.

    Returns: img as a VolumeCutout
    """
    if type(bbox) is Vec:
      bbox = Bbox(bbox, bbox+1)
    
    bbox = Bbox.create(
      bbox, context=self.bounds, 
      bounded=self.bounded, 
      autocrop=self.autocrop
    )
  
    if bbox.subvoxel():
      raise exceptions.EmptyRequestException("Requested {} is smaller than a voxel.".format(bbox))

    if mip is None:
      mip = self.mip

    mip0_bbox = self.bbox_to_mip(bbox, mip=mip, to_mip=0)
    # Only ever necessary to make requests within the bounding box
    # to the server. We can fill black in other situations.
    mip0_bbox = bbox.intersection(self.meta.bounds(0), mip0_bbox)

    img = super(CloudVolumeGraphene, self).download(bbox, mip=mip, parallel=parallel)

    if agglomerate:
      img = self.agglomerate_cutout(img, timestamp=timestamp, stop_layer=stop_layer)
      return VolumeCutout.from_volume(self.meta, mip, img, bbox)

    if segids is None:
      return img

    segids = list(toiter(segids))

    remapping = {}
    for segid in segids:
      leaves = self.get_leaves(segid, mip0_bbox, 0)
      remapping.update({ leaf: segid for leaf in leaves })
    
    img = fastremap.remap(img, remapping, preserve_missing_labels=True, in_place=True)

    mask_value = 0
    if preserve_zeros:
      mask_value = np.inf
      if np.issubdtype(self.dtype, np.integer):
        mask_value = np.iinfo(self.dtype).max

      segids.append(0)

    img = fastremap.mask_except(img, segids, in_place=True, value=mask_value)

    return VolumeCutout.from_volume(
      self.meta, mip, img, bbox 
    )
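
# A toy sketch (illustration only, standalone rather than part of the class
# above) of the preserve_zeros branch: instead of masking unrequested segids to
# 0, they are masked to the dtype's maximum value and 0 itself is added to the
# keep-list so genuine background stays 0.
import numpy as np
import fastremap

toy_img = np.array([0, 7, 8, 9], dtype=np.uint8)
toy_segids = [7]

mask_value = np.iinfo(toy_img.dtype).max        # 255 for uint8
toy_img = fastremap.mask_except(toy_img, toy_segids + [0],
                                in_place=True, value=mask_value)
# -> [0, 7, 255, 255]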
Example #9
 def agglomerate_cutout(self, img, timestamp=None, stop_layer=None):
   """Remap a graphene volume to its latest root ids. This creates a flat segmentation."""
   labels = fastremap.unique(img)
   roots = self.get_roots(labels, timestamp=timestamp, binary=True, stop_layer=stop_layer)
   mapping = { segid: root for segid, root in zip(labels, roots) }
   return fastremap.remap(img, mapping, preserve_missing_labels=True, in_place=True)
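
# A toy sketch (illustration only, standalone rather than part of the class
# above) of the flattening in agglomerate_cutout: each label in the cutout is
# looked up once and rewritten to its root id. The roots array stands in for
# whatever get_roots would return, one root per label and in the same order.
import numpy as np
import fastremap

toy_img = np.array([[101, 102], [102, 103]], dtype=np.uint64)
toy_labels = fastremap.unique(toy_img)            # [101, 102, 103]
toy_roots = np.array([7, 7, 9], dtype=np.uint64)  # assumed get_roots output
toy_mapping = {segid: root for segid, root in zip(toy_labels, toy_roots)}
print(fastremap.remap(toy_img, toy_mapping, preserve_missing_labels=True))
# -> [[7 7]
#     [7 9]]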
Example #10
def by_vertex_clusters(mesh, sampling_dist, cluster_pos='median', progress=True):
    """Skeletonize a (contracted) mesh by clustering vertices.

    The algorithm traverses the mesh graph and groups together vertices that
    are within a given distance of each other. This uses the geodesic
    (along-the-mesh) distance, not simply the Euclidean distance. Subsequently
    these groups of vertices are collapsed and re-connected respecting the
    topology of the input mesh.

    The graph traversal is fast and scales well, so this method is well suited
    for meshes with lots of vertices. On the downside: this implementation is
    not very clever and you might have to play around with the parameters
    (mostly ``sampling_dist``) to get decent results.

    Parameters
    ----------
    mesh :          mesh obj
                    The mesh to be skeletonized. Can be an object that has
                    ``.vertices`` and ``.faces`` properties (e.g. a
                    trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                    dictionary ``{'vertices': vertices, 'faces': faces}``.
    sampling_dist : float | int
                    Maximal distance at which vertices are clustered. This
                    parameter should be tuned based on the resolution of your
                    mesh (see Examples).
    cluster_pos :   "median" | "center"
                    How to determine the x/y/z coordinates of the collapsed
                    vertex clusters (i.e. the skeleton's nodes)::

                      - "median": Use the vertex closest to cluster's center of
                        mass.
                      - "center": Use the center of mass. This makes for smoother
                        skeletons but can lead to nodes outside the mesh.
    progress :      bool
                    If True, will show progress bar.

    Examples
    --------
    >>> import skeletor as sk
    >>> mesh = sk.example_mesh()
    >>> cont = sk.pre.contract(mesh, epsilon=0.1)
    >>> skel = sk.skeletonize.by_vertex_clusters(cont, sampling_dist=0.05)
    >>> skel.mesh = mesh

    Returns
    -------
    skeletor.Skeleton
                    Holds results of the skeletonization and enables quick
                    visualization.

    """
    assert cluster_pos in ['center', 'median']

    mesh = make_trimesh(mesh, validate=False)

    # Produce weighted edges
    edges = np.concatenate((mesh.edges_unique,
                            mesh.edges_unique_length.reshape(mesh.edges_unique.shape[0], 1)),
                           axis=1)

    # Generate Graph (must be undirected)
    G = nx.Graph()
    G.add_weighted_edges_from(edges)

    # Run the graph traversal that groups vertices into spatial clusters
    not_visited = set(G.nodes)
    seen = set()
    clusters = []
    to_visit = len(not_visited)
    with tqdm(desc='Clustering', total=len(not_visited), disable=progress is False) as pbar:
        while not_visited:
            # Pick a random node
            start = not_visited.pop()
            # Get all nodes in the geodesic vicinity
            cl, seen = dfs(G, n=start, dist_traveled=0,
                           max_dist=sampling_dist, seen=seen)
            cl = set(cl)

            # Append this cluster and track visited/not-visited nodes
            clusters.append(cl)
            not_visited = not_visited - cl

            # Update  progress bar
            pbar.update(to_visit - len(not_visited))
            to_visit = len(not_visited)

    # `clusters` is a list of sets -> let's turn it into list of arrays
    clusters = [np.array(list(c)).astype(int) for c in clusters]

    # Get positions of clusters
    if cluster_pos == 'center':
        # Get the center of each cluster
        cl_coords = np.array([np.mean(mesh.vertices[c], axis=0) for c in clusters])
    elif cluster_pos == 'median':
        # Get the node that's closest to the cluster's center
        cl_coords = []
        for c in clusters:
            cnt = np.mean(mesh.vertices[c], axis=0)
            cnt_dist = np.sum(np.fabs(mesh.vertices[c] - cnt), axis=1)
            median = mesh.vertices[c][np.argmin(cnt_dist)]
            cl_coords.append(median)
        cl_coords = np.array(cl_coords)

    # Generate edges
    cl_edges = np.array(mesh.edges_unique)
    if fastremap:
        mapping = {n: i for i, l in enumerate(clusters) for n in l}
        cl_edges = fastremap.remap(cl_edges, mapping, preserve_missing_labels=False, in_place=True)
    else:
        for i, c in enumerate(clusters):
            cl_edges[np.isin(cl_edges, c)] = i

    # Remove directionality from cluster edges
    cl_edges = np.sort(cl_edges, axis=1)

    # Get unique edges
    cl_edges = np.unique(cl_edges, axis=0)

    # Calculate edge lengths
    co1 = cl_coords[cl_edges[:, 0]]
    co2 = cl_coords[cl_edges[:, 1]]
    cl_edge_lengths = np.sqrt(np.sum((co1 - co2)**2, axis=1))

    # Produce adjacency matrix from edges and edge lengths
    n_clusters = len(clusters)
    adj = scipy.sparse.coo_matrix((cl_edge_lengths,
                                   (cl_edges[:, 0], cl_edges[:, 1])),
                                  shape=(n_clusters, n_clusters))

    # The cluster graph likely still contains cycles; let's get rid of them using
    # a minimum spanning tree
    mst = scipy.sparse.csgraph.minimum_spanning_tree(adj,
                                                     overwrite=True)

    # Turn into COO matrix
    coo = mst.tocoo()

    # Extract edge list
    edges = np.array([coo.row, coo.col]).T

    # Produce final graph - this also takes care of some fixing
    G = edges_to_graph(edges, nodes=np.unique(cl_edges.flatten()),
                       drop_disconnected=False, fix_tree=True)

    # Generate a mesh vertex -> skeleton vertex map
    # Note that nodes are labeled by index of the cluster
    vertex_to_node_map = [i for i, cl in enumerate(clusters) for n in cl]

    # Generate SWC
    swc, new_ids = make_swc(G, cl_coords, reindex=True, validate=False)

    # Update mesh map
    vertex_to_node_map = np.array([new_ids[n] for n in vertex_to_node_map])

    return Skeleton(swc=swc, mesh=mesh, mesh_map=vertex_to_node_map,
                    method='vertex_clusters')
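
# A toy sketch (illustration only) of the cycle-removal step above: the cluster
# graph may still contain cycles, so a minimum spanning tree over the weighted
# adjacency matrix turns it into a tree before the SWC table is built.
import numpy as np
import scipy.sparse
import scipy.sparse.csgraph

# Three clusters forming a triangle (i.e. a cycle); weights are edge lengths.
toy_adj = scipy.sparse.coo_matrix(([1.0, 1.0, 2.5], ([0, 0, 1], [1, 2, 2])),
                                  shape=(3, 3))

toy_mst = scipy.sparse.csgraph.minimum_spanning_tree(toy_adj)
toy_coo = toy_mst.tocoo()
print(np.array([toy_coo.row, toy_coo.col]).T)
# -> [[0 1]
#     [0 2]]   (the longest edge, (1, 2), was dropped)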
Example #11
def collapse_zero_length_edges(vertices,
                               edges,
                               root,
                               radius,
                               mesh_to_skel_map,
                               mesh_index,
                               node_mask,
                               vertex_properties={}):
    "Remove zero length edges from a skeleton"

    zl = np.linalg.norm(vertices[edges[:, 0]] - vertices[edges[:, 1]],
                        axis=1) == 0
    if not np.any(zl):
        return vertices, edges, root, radius, mesh_to_skel_map, mesh_index, node_mask, vertex_properties

    consolidate_dict = {x[0]: x[1] for x in edges[zl]}
    # Compress multiple zero edges in a row
    while np.any(
            np.isin(np.array(list(consolidate_dict.keys())),
                    np.array(list(consolidate_dict.values())))):
        all_keys = np.array(list(consolidate_dict.keys()))
        dup_keys = np.flatnonzero(
            np.isin(all_keys, np.array(list(consolidate_dict.values()))))
        first_key = all_keys[dup_keys[0]]
        first_val = consolidate_dict.pop(first_key)
        for ii, jj in consolidate_dict.items():
            if jj == first_key:
                consolidate_dict[ii] = first_val

    new_index_dict, node_filter = remap_dict(len(vertices), consolidate_dict)

    new_vertices = vertices[node_filter]
    new_edges = fastremap.remap(edges, new_index_dict)
    new_edges = new_edges[new_edges[:, 0] != new_edges[:, 1]]

    if mesh_to_skel_map is not None:
        new_index_dict[-1] = -1
        new_mesh_to_skel_map = fastremap.remap(mesh_to_skel_map,
                                               new_index_dict)
    else:
        new_mesh_to_skel_map = None

    new_root = new_index_dict.get(root, root)
    if radius is not None:
        new_radius = radius[node_filter]
    else:
        new_radius = None

    if mesh_index is not None:
        new_mesh_index = mesh_index[node_filter]
    else:
        new_mesh_index = None

    if node_mask is not None:
        new_node_mask = node_mask[node_filter]
    else:
        new_node_mask = None

    new_vp = {}
    for vp, val in vertex_properties.items():
        try:
            new_vp[vp] = val[node_filter]
        except Exception:
            # Vertex properties that cannot be indexed with the node filter are skipped
            pass

    return new_vertices, new_edges, new_root, new_radius, new_mesh_to_skel_map, new_mesh_index, new_node_mask, new_vp
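
# A toy sketch (illustration only) of what the chain-compression loop above
# achieves, written as a simpler path-following formulation: chains of
# zero-length edges such as 1 -> 2 -> 3 are flattened so that every key maps
# directly to the final surviving vertex (assumes the chains are acyclic, as
# in the function above).
toy_consolidate = {1: 2, 2: 3, 5: 6}

flattened = {}
for key, val in toy_consolidate.items():
    while val in toy_consolidate:   # follow the chain to its end
        val = toy_consolidate[val]
    flattened[key] = val
# -> {1: 3, 2: 3, 5: 6}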
Example #12
    def download(self,
                 bbox,
                 mip=None,
                 parallel=None,
                 segids=None,
                 preserve_zeros=False):
        """
    Downloads base segmentation and optionally agglomerates
    labels based on information in the graph server.

    bbox: specifies cutout to fetch
    mip: which resolution level to get (default self.mip)
    parallel: what parallel level to use (default self.parallel)

    segids: agglomerate the leaves of these segids from the graph 
      server and label them with the given segid.
    preserve_zeros: If segids is not None:
      False: mask other segids with zero
      True: mask other segids with the largest integer value
        contained by the image data type and leave zero as is.

    Returns: img
    """
        if type(bbox) is Vec:
            bbox = Bbox(bbox, bbox + 1)

        bbox = Bbox.create(bbox,
                           context=self.bounds,
                           bounded=self.bounded,
                           autocrop=self.autocrop)

        if bbox.subvoxel():
            raise exceptions.EmptyRequestException(
                "Requested {} is smaller than a voxel.".format(bbox))

        if mip is None:
            mip = self.mip

        mip0_bbox = self.bbox_to_mip(bbox, mip=mip, to_mip=0)
        # Only ever necessary to make requests within the bounding box
        # to the server. We can fill black in other situations.
        mip0_bbox = bbox.intersection(self.meta.bounds(0), mip0_bbox)

        img = super(CloudVolumeGraphene, self).download(bbox,
                                                        mip=mip,
                                                        parallel=parallel)

        if segids is None:
            return img

        segids = list(toiter(segids))

        remapping = {}
        for segid in segids:
            leaves = self.get_leaves(segid, mip0_bbox, 0)
            remapping.update({leaf: segid for leaf in leaves})

        img = fastremap.remap(img,
                              remapping,
                              preserve_missing_labels=True,
                              in_place=True)

        mask_value = 0
        if preserve_zeros:
            mask_value = np.inf
            if np.issubdtype(self.dtype, np.integer):
                mask_value = np.iinfo(self.dtype).max

            segids.append(0)

        img = fastremap.mask_except(img,
                                    segids,
                                    in_place=True,
                                    value=mask_value)

        return VolumeCutout.from_volume(self.meta, mip, img, bbox)