Ejemplo n.º 1
0
def in_volume_ncoll(points: np.ndarray,
                    volume: Volume,
                    n_rays: Optional[int] = 3) -> Sequence[bool]:
    """Use ncollpyde to test if points are within a given volume.

    Parameters
    ----------
    points :    np.ndarray
                Coordinates to test; passed straight to
                ``ncollpyde.Volume.contains``.
    volume :    Volume
                Mesh-like object; only its ``.vertices`` and ``.faces`` are
                used.
    n_rays :    int, optional
                Number of rays ncollpyde casts per point. ``None`` falls back
                to 3.

    Returns
    -------
    Sequence[bool]
                Whatever ``ncollpyde.Volume.contains`` returns for ``points``.

    Raises
    ------
    TypeError
                If ``n_rays`` is not an integer (booleans are rejected too).
    ValueError
                If ``n_rays`` is not positive.

    """
    if n_rays is None:
        n_rays = 3

    # ``bool`` is a subclass of ``int``: reject it explicitly so that e.g.
    # ``n_rays=True`` does not silently mean "cast a single ray"
    if isinstance(n_rays, bool) or not isinstance(n_rays, (int, np.integer)):
        raise TypeError(f'n_rays must be integer, got "{type(n_rays)}"')

    if n_rays <= 0:
        raise ValueError('n_rays must be > 0')

    coll = ncollpyde.Volume(volume.vertices, volume.faces, n_rays=int(n_rays))

    return coll.contains(points)
Ejemplo n.º 2
0
def _test_single_edge(l1, l2, seg_id, vol):
    """Test whether the straight edge between two nodes avoids the mesh.

    Parameters
    ----------
    l1, l2 :    np.ndarray | array-like
                Locations (3 coordinates each) of the nodes connected by the
                edge.
    seg_id :    int
                Segment ID of one of the nodes.
    vol :       cloudvolume.CloudVolume
                Passed on to ``get_mesh``.

    Returns
    -------
    bool
                ``True`` if no mesh was returned for the edge's bounding box
                or if the line segment ``l1 -> l2`` does not intersect the
                mesh; ``False`` if it does intersect.

    """
    l1 = np.asarray(l1)
    l2 = np.asarray(l2)

    # Axis-aligned bounding box spanned by the two nodes. Assuming l1/l2 are
    # 3-element vectors, the transpose gives one row per dimension holding
    # (min, max)
    bbox = np.array([l1, l2])
    bbox = np.array([bbox.min(axis=0), bbox.max(axis=0)]).T

    # Get the mesh
    mesh = get_mesh(seg_id, bbox=bbox, vol=vol)

    # No mesh means that edge is most likely True
    if not mesh:
        return True

    # Prepare raycasting
    coll = ncollpyde.Volume(np.array(mesh.vertices, dtype=float, order='C'),
                            np.array(mesh.faces, dtype=np.int32, order='C'))

    # Cast the single ray l1 -> l2
    inter_ix, _, _ = coll.intersections(l1.reshape(1, 3), l2.reshape(1, 3))

    # Count the hits rather than testing ``inter_xyz.any()``: the latter
    # would miss a genuine intersection located exactly at (0, 0, 0)
    if len(inter_ix) == 0:
        return True

    # The edge crosses the mesh surface
    return False
Ejemplo n.º 3
0
def get_radius_ray(swc, mesh, n_rays=20, aggregate='mean', projection='sphere',
                   fallback='knn'):
    """Extract radii using ray casting.

    For every node, ``n_rays`` rays are cast from the node's position. The
    distances from the node to where those rays intersect the mesh are
    aggregated into one radius per node.

    Parameters
    ----------
    swc :           pandas.DataFrame
                    SWC table with ``x``, ``y`` and ``z`` columns.
    mesh :          trimesh.Trimesh
                    Mesh to cast rays against (uses ``.vertices``/``.faces``).
    n_rays :        int
                    Number of rays to cast for each node.
    aggregate :     "mean" | "median" | "max" | "min" | "percentile75"
                    Function used to aggregate radii over all intersections
                    for a given node.
    projection :    "sphere" | "tangents"
                    Whether to cast rays in a sphere around each node or in a
                    circle orthogonally to the node's tangent vector.
    fallback :      "knn" | None | number
                    If a point is outside or right on the surface of the mesh
                    the raycasting will return nonsense results. We can either
                    ignore those cases (``None``), assign an arbitrary number
                    or fall back to radii from k-nearest-neighbors (``knn``).
                    Nodes whose rays hit nothing (radius 0) are treated the
                    same way.

    Returns
    -------
    radii :     np.ndarray
                One radius per row of ``swc`` (same order).

    """
    agg_map = {'mean': np.mean, 'max': np.max, 'min': np.min,
               'median': np.median, 'percentile75': lambda x: np.percentile(x, 75)}
    assert aggregate in agg_map
    agg_func = agg_map[aggregate]

    assert projection in ['sphere', 'tangents']
    assert (fallback == 'knn') or isinstance(fallback, numbers.Number) or fallback is None

    xyz = ['x', 'y', 'z']

    # Ray length: the largest extent of the skeleton's bounding box
    dim = (swc[xyz].max() - swc[xyz].min()).values
    radius = max(dim)

    # Each point is the source of ``n_rays`` consecutive rays, i.e. ray ``k``
    # belongs to point ``k // n_rays``
    points = swc[xyz].values
    sources = np.repeat(points, n_rays, axis=0)

    if projection == 'sphere':
        # Get (random) points on a sphere and scale by radius
        targets = fibonacci_sphere(n_rays, randomize=True) * radius
        # Reshape to match sources
        targets = np.tile(targets, (points.shape[0], 1))
        # Offset onto sources
        targets += sources
    else:
        tangents, normals, binormals = frenet_frames(swc)

        # ``np.float`` was removed in NumPy 1.24; use the builtin ``float``
        v = np.arange(n_rays, dtype=float) / n_rays * 2 * np.pi

        all_cx = (radius * -1. * np.tile(np.cos(v), points.shape[0]).reshape((n_rays, points.shape[0]), order='F')).T
        cx_norm = (all_cx[:, :, np.newaxis] * normals[:, np.newaxis, :]).reshape(sources.shape)

        all_cy = (radius * np.tile(np.sin(v), points.shape[0]).reshape((n_rays, points.shape[0]), order='F')).T
        cy_norm = (all_cy[:, :, np.newaxis] * binormals[:, np.newaxis, :]).reshape(sources.shape)

        targets = sources + cx_norm + cy_norm

    # Initialize ncollpyde Volume
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

    # Get intersections: `ix` points to index of line segment; `loc` is the
    # x/y/z coordinate of the intersection. NOTE: hits are deliberately not
    # filtered by the backface flag - doing so reduced the number of
    # intersections to 0 for many points.
    ix, loc, _ = coll.intersections(sources, targets)
    ix = np.asarray(ix).astype(int)

    # Calculate intersection distances
    dist = np.sqrt(np.sum((sources[ix] - loc)**2, axis=1))

    # Map from `ix` back to index of original point
    org_ix = ix // n_rays

    # Group distances by originating point. Sort first (stable) so we do not
    # rely on hits coming back ordered by ray, then split at the first
    # occurrence of each point index.
    order = np.argsort(org_ix, kind='stable')
    org_ix, dist = org_ix[order], dist[order]
    uniq_ix, first = np.unique(org_ix, return_index=True)

    # Aggregate over each original ix; points without any hit keep 0
    final_dist = np.zeros(points.shape[0])
    for i, group in zip(uniq_ix, np.split(dist, first[1:])):
        final_dist[i] = agg_func(group)

    if fallback is not None:
        # Points outside the mesh or without any hit need fixing
        inside = np.asarray(coll.contains(points))
        needs_fix = ~inside | (final_dist == 0)

        if needs_fix.any():
            if isinstance(fallback, numbers.Number):
                final_dist[needs_fix] = fallback
            elif fallback == 'knn':
                final_dist[needs_fix] = get_radius_kkn(points[needs_fix], mesh, aggregate=aggregate)

    return final_dist
Ejemplo n.º 4
0
def remove_hairs(s, mesh=None, inplace=False):
    """Remove "hairs" that sometimes occur along the backbone.

    Works by finding terminal twigs that consist of only a single node, i.e.
    leaf nodes whose parent is a branch point. Of those, we remove the ones
    that lie inside the mesh and are within line of sight of their parent.

    Note that this is currently not used for clean up as it does not work very
    well: it removes as many correct hairs as genuine small branches.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify the `s` inplace.

    Returns
    -------
    skeletor.Skeleton
                Skeleton with line-of-sight hairs removed. If no candidate
                twigs exist it is returned unchanged; otherwise it is
                reindexed via ``s.reindex``.

    """
    if mesh is None:
        mesh = s.mesh

    # Make a copy of the skeleton
    if not inplace:
        s = s.copy()

    swc = s.swc
    xyz = ['x', 'y', 'z']

    # Branch points: nodes that are the parent of more than one node
    pcount = swc[swc.parent_id >= 0].groupby('parent_id').size()
    bp = pcount[pcount > 1].index

    # Leaf nodes (nobody's parent) attached directly to a branch point
    twigs = swc[~swc.node_id.isin(swc.parent_id)]
    twigs = twigs[twigs.parent_id.isin(bp)]

    if twigs.empty:
        return s

    # Initialize ncollpyde Volume
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

    # Remove twigs that aren't inside the volume
    twigs = twigs[coll.contains(twigs[xyz].values)]

    if twigs.empty:
        # Nothing left to test - skip ray casting with empty arrays
        to_remove = twigs
    else:
        # One ray per twig: from the twig node to its parent
        sources = twigs[xyz].values
        targets = swc.set_index('node_id').loc[twigs.parent_id, xyz].values

        # `ix` holds the index of every ray that intersects the mesh
        ix, _, _ = coll.intersections(sources, targets)

        # Rays without any intersection have line of sight -> remove
        los = ~np.isin(np.arange(sources.shape[0]), ix)
        to_remove = twigs[los]

    s.swc = swc[~swc.node_id.isin(to_remove.node_id)].copy()

    # Update the mesh map: vertices mapped to a removed twig now map to the
    # twig's parent
    mesh_map = getattr(s, 'mesh_map', None)
    if mesh_map is not None and not to_remove.empty:
        # Work on an array copy: comparing a plain list with a scalar is not
        # element-wise, and copying avoids mutating a map that may be shared
        # with the skeleton we copied from
        mesh_map = np.array(mesh_map)
        for node, parent in zip(to_remove.node_id, to_remove.parent_id):
            mesh_map[mesh_map == node] = parent
        s.mesh_map = mesh_map

    # Reindex nodes
    s.reindex(inplace=True)

    return s
Ejemplo n.º 5
0
def drop_line_of_sight_twigs(s, mesh=None, max_dist='auto', inplace=False):
    """Collapse twigs that are in line of sight to each other.

    A "twig" here is a leaf node that lies inside the mesh. Two twigs are in
    line of sight if they are within ``max_dist`` of each other and the
    straight line between them does not intersect the mesh. Within each group
    of mutually visible twigs only the one with the longest edge to its parent
    is kept; the others are removed.

    Note that this only removes 1 layer of twigs (i.e. only the actual leaf
    nodes). Nothing is stopping you from running this function recursively
    though.

    Also note that this function needs a rework because it does not take
    connected components into account and hence collapses things that were
    not meant to be connected.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    max_dist :  "auto" | int | float
                Maximum Euclidean distance allowed between leaf nodes for them
                to be considered for collapsing. If "auto", will use the length
                of the longest edge in skeleton as limit. A falsy value
                disables the limit.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify the `s` inplace.

    Returns
    -------
    skeletor.Skeleton
                Skeleton with line-of-sight twigs removed, reindexed via
                ``s.reindex``.

    """
    # Fall back to the skeleton's own mesh as documented
    if mesh is None:
        mesh = s.mesh

    # Make a copy of the skeleton
    if not inplace:
        s = s.copy()

    xyz = ['x', 'y', 'z']

    # Add distance to parents (0 for roots). Initialise as float so the float
    # distances below can be assigned without an implicit dtype upcast.
    s.swc['parent_dist'] = 0.0
    not_root = s.swc.parent_id >= 0
    co1 = s.swc.loc[not_root, xyz].values
    co2 = s.swc.set_index('node_id').loc[s.swc.loc[not_root, 'parent_id'],
                                         xyz].values
    s.swc.loc[not_root, 'parent_dist'] = np.sqrt(np.sum((co1 - co2)**2, axis=1))

    # If max dist is 'auto', we will use the longest child->parent edge in the
    # skeleton as limit
    if max_dist == 'auto':
        max_dist = s.swc.parent_dist.max()

    # Initialize ncollpyde Volume
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

    # Twigs: leaf nodes (nobody's parent) that are inside the volume
    twigs = s.swc[~s.swc.node_id.isin(s.swc.parent_id)]
    twigs = twigs[coll.contains(twigs[xyz].values)]

    # At least two twigs are needed to form a pair. Bailing out also avoids
    # ``squareform`` on an empty distance vector, which returns a 1x1 matrix
    # that no longer lines up with the (empty) list of pairs.
    to_remove = []
    if twigs.shape[0] >= 2:
        to_remove = _line_of_sight_losers(twigs, coll, max_dist)

    # Drop the tips we flagged for removal and the new column we added
    s.swc = s.swc[~s.swc.node_id.isin(to_remove)].drop('parent_dist', axis=1)

    # Clean up node/vertex order
    s.reindex(inplace=True)

    return s


def _line_of_sight_losers(twigs, coll, max_dist):
    """Return node IDs of twigs that lose a line-of-sight collapse.

    ``twigs`` is a slice of an SWC table with ``node_id``, ``x``/``y``/``z``
    and ``parent_dist`` columns; ``coll`` is an ``ncollpyde.Volume``;
    ``max_dist`` is the pair distance cutoff (falsy means no cutoff).
    """
    twigs_co = twigs[['x', 'y', 'z']].values
    ids = twigs.node_id.values
    n = len(ids)

    # All ordered pairs: row ``i * n + j`` is the ray from twig i to twig j
    sources = np.repeat(twigs_co, n, axis=0)
    targets = np.tile(twigs_co, (n, 1))
    pairs = np.column_stack((np.repeat(ids, n), np.tile(ids, n)))

    # If max distance, drop pairs that are too far apart. The flattened
    # square distance matrix uses the same i * n + j ordering as the pairs.
    if max_dist:
        d = scipy.spatial.distance.squareform(
            scipy.spatial.distance.pdist(twigs_co))
        is_close = d.ravel() <= max_dist
        pairs, sources, targets = pairs[is_close], sources[is_close], targets[is_close]

    # Drop self rays
    not_self = pairs[:, 0] != pairs[:, 1]
    pairs, sources, targets = pairs[not_self], sources[not_self], targets[not_self]

    if pairs.shape[0] == 0:
        return []

    # `ix` holds the index of every ray that intersects the mesh; rays
    # without any intersection have line of sight
    ix, _, _ = coll.intersections(sources, targets)
    los = ~np.isin(np.arange(pairs.shape[0]), ix)

    # Group mutually visible twigs into clusters
    G = nx.Graph()
    G.add_edges_from(pairs[los])

    # The twig with the longest parent edge wins its clique. Distance to root
    # might be a better criterion but is more expensive to compute.
    seg_lengths = twigs.set_index('node_id').parent_dist.to_dict()
    losers = []
    seen = set()
    for nodes in nx.connected_components(G):
        # Avoid collapsing chains (A sees B, B sees C, ...) by working on
        # cliques and skipping nodes already handled in an earlier clique
        for clique in nx.find_cliques(nx.subgraph(G, nodes)):
            clique = set(clique) - seen
            if len(clique) < 2:
                continue
            seen |= clique
            losers += sorted(clique, key=lambda x: seg_lengths[x])[:-1]

    return losers
Ejemplo n.º 6
0
def recenter_vertices(s, mesh=None, inplace=False):
    """Move nodes that ended up outside the mesh back inside.

    Nodes can end up outside the original mesh e.g. if the mesh contraction
    didn't do a good job (most likely because of internal/degenerate faces that
    messed up the normals). This function rectifies this by snapping those
    nodes back to the closest vertex and then tries to move them into the
    mesh's center. That second step is not guaranteed to work, but at least
    you won't have any more nodes outside the mesh.

    Please note that if connected (!) nodes end up on the same position (i.e
    because they snapped to the same vertex), we will collapse them.

    Parameters
    ----------
    s :         skeletor.Skeleton
    mesh :      trimesh.Trimesh
                Original mesh. If not provided will use ``s.mesh``.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify the `s` inplace.

    Returns
    -------
    skeletor.Skeleton
                Skeleton with outside nodes moved back onto/into the mesh.
                Returned without changes if no node is outside the mesh.

    """
    if mesh is None:
        mesh = s.mesh

    # Copy skeleton
    if not inplace:
        s = s.copy()

    # Find nodes that are outside the mesh
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)
    outside = ~coll.contains(s.vertices)

    # Skip if all inside
    if not np.any(outside):
        return s

    # For each outside node find the closest mesh vertex
    tree = scipy.spatial.cKDTree(mesh.vertices)
    _, closest_ix = tree.query(s.vertices[outside])

    # We don't want to just snap them back to the closest vertex but try to
    # find the center. For this we will:
    # 1. Move each vertex inside the mesh by one normal length
    # 2. Cast a ray along the inverted normal to find the opposite side
    # 3. Place the node halfway between vertex and the nearest hit
    closest_vertex = mesh.vertices[closest_ix]
    # The offset "should" put sources inside the mesh; that doesn't always
    # happen if the mesh is not watertight
    vnormals = mesh.vertex_normals[closest_ix]
    sources = closest_vertex - vnormals
    targets = sources - vnormals * 1e4

    # Cast rays: `ray_ix` is the index of the ray, `loc` the hit coordinate
    ray_ix, loc, _ = coll.intersections(sources, targets)
    ray_ix = np.asarray(ray_ix, dtype=int)
    loc = np.asarray(loc)

    # Half-vector towards the nearest hit of each ray. Rays without any hit
    # keep a zero offset, i.e. their node snaps to the closest vertex.
    halfvec = np.zeros(sources.shape)
    if ray_ix.shape[0]:
        hit_dist = np.linalg.norm(loc - sources[ray_ix], axis=1)
        # Sort by ray, then by distance -> first entry per ray is the nearest
        order = np.lexsort((hit_dist, ray_ix))
        ray_ix, loc = ray_ix[order], loc[order]
        _, first = np.unique(ray_ix, return_index=True)
        ray_ix, loc = ray_ix[first], loc[first]
        halfvec[ray_ix] = (loc - closest_vertex[ray_ix]) / 2

    # Offset vertices
    final_pos = closest_vertex + halfvec

    # Keep only those that are properly inside the mesh and fall back to the
    # closest vertex if that's not the case
    now_inside = coll.contains(final_pos)
    final_pos[~now_inside] = closest_vertex[~now_inside]

    # Replace coordinates
    s.swc.loc[outside, 'x'] = final_pos[:, 0]
    s.swc.loc[outside, 'y'] = final_pos[:, 1]
    s.swc.loc[outside, 'z'] = final_pos[:, 2]

    # At this point we may have nodes that snapped to the same vertex and
    # therefore end up at the same position. We will collapse those nodes
    # - but only if they are actually connected!
    _, inverse, counts = np.unique(s.vertices, return_counts=True,
                                   return_inverse=True, axis=0)
    # Some NumPy versions return a non-1D inverse when ``axis`` is given
    inverse = np.asarray(inverse).reshape(-1)

    if counts.max() > 1:
        rewire = {}
        # Go over each duplicated coordinate
        for coord_ix in np.where(counts > 1)[0]:
            # Rows (node indices) sitting on this coordinate
            node_ix = np.where(inverse == coord_ix)[0]

            # Edges among those nodes only
            edges = s.edges[np.all(np.isin(s.edges, node_ix), axis=1)]

            # Work on the graph to collapse nodes sequentially A->B->C
            G = nx.DiGraph()
            G.add_edges_from(edges)
            for cc in nx.connected_components(G.to_undirected()):
                # Root is the node without any outdegree in this subgraph
                root = [n for n in cc if G.out_degree[n] == 0][0]
                # NOTE(review): `edges` only keeps edges whose endpoints both
                # share this coordinate, so `root` is itself one of the
                # duplicated nodes; both it and `clps_into` are kept. Confirm
                # this is intended.
                clps_into = next(G.predecessors(root))
                rewire.update({n: clps_into for n in cc
                               if n not in {root, clps_into}})

        # Only mess with the skeleton if there were nodes to be merged
        if rewire:
            # NOTE(review): keys of `rewire` are row indices into s.vertices
            # but are applied to node IDs below - assumes they coincide (as
            # after `reindex`). TODO confirm.
            s.swc['parent_id'] = s.swc.parent_id.map(lambda x: rewire.get(x, x))

            # Drop nodes that were collapsed
            s.swc = s.swc.loc[~s.swc.node_id.isin(rewire)]

            # Update mesh map (kept as an array so element-wise comparisons
            # on it keep working)
            if s.mesh_map is not None:
                s.mesh_map = np.asarray([rewire.get(x, x) for x in s.mesh_map])

            # Reindex to make vertex IDs continuous again
            s.reindex(inplace=True)

            # This prevents future SettingWithCopy warnings
            if not inplace:
                s.swc = s.swc.copy()

    return s
Ejemplo n.º 7
0
def by_tangent_ball(mesh):
    """Skeletonize a mesh by finding the maximal tangent ball.

    This algorithm casts a ray from every mesh vertex along its inverse normals
    (requires `ncollpyde`). It then creates a sphere that is tangent to the
    vertex and to where the ray hit the inside of a face on the opposite side
    (the nearest such hit if there are several). Next it drops spheres that
    overlap with another, larger sphere. Modified from [1].

    The method works best on smooth meshes and is rather sensitive to errors in
    the mesh such as incorrect normals (see `skeletor.pre.fix_mesh`), internal
    faces, noisy surface (try smoothing or downsampling) or holes in the mesh.

    Parameters
    ----------
    mesh :              mesh obj
                        The mesh to be skeletonized. Can be an object that has
                        ``.vertices`` and ``.faces`` properties  (e.g. a
                        trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                        dictionary ``{'vertices': vertices, 'faces': faces}``.

    Returns
    -------
    skeletor.Skeleton
                        Holds results of the skeletonization and enables quick
                        visualization.

    Examples
    --------
    >>> import skeletor as sk
    >>> mesh = sk.example_mesh()
    >>> fixed = sk.pre.fix_mesh(mesh, fix_normals=True, remove_disconnected=10)
    >>> skel = sk.skeletonize.by_tangent_ball(fixed)

    References
    ----------
    [1] Ma, J., Bae, S.W. & Choi, S. 3D medial axis point approximation using
        nearest neighbors and the normal field. Vis Comput 28, 7–19 (2012).
        https://doi.org/10.1007/s00371-011-0594-7

    """
    mesh = make_trimesh(mesh, validate=False)

    # Generate the KD tree
    tree = scipy.spatial.cKDTree(mesh.vertices)

    # Distance of each vertex to its nearest other vertex (k=1 is the vertex
    # itself); only the maximum is used below to size the rays
    dist = tree.query(mesh.vertices, k=2)[0][:, 1]

    n_verts = mesh.vertices.shape[0]
    centers = np.zeros(mesh.vertices.shape)
    radii = np.zeros(n_verts)

    # Cast a ray from just below each vertex along its inverted normal; the
    # ray length is 10x the largest nearest-neighbour distance
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)
    sources = mesh.vertices - mesh.vertex_normals * 0.01
    targets = mesh.vertices - mesh.vertex_normals * (dist.max() * 10)
    ix, loc, is_backface = coll.intersections(sources, targets)

    # Only hits on the inside of a face define the opposite side of a tangent
    # ball; if a ray produced several, use the one nearest to its vertex
    is_backface = np.asarray(is_backface, dtype=bool)
    ix = np.asarray(ix, dtype=int)[is_backface]
    loc = np.asarray(loc)[is_backface]
    hit_dist = np.linalg.norm(loc - mesh.vertices[ix], axis=1)
    order = np.lexsort((hit_dist, ix))
    ix, loc, hit_dist = ix[order], loc[order], hit_dist[order]
    _, first = np.unique(ix, return_index=True)
    ix, loc, hit_dist = ix[first], loc[first], hit_dist[first]

    # Ball center is the midpoint between vertex and hit; radius is half the
    # vertex -> hit distance
    intersects = np.zeros(n_verts, dtype=bool)
    intersects[ix] = True
    centers[ix] = mesh.vertices[ix] + (loc - mesh.vertices[ix]) / 2
    radii[ix] = hit_dist / 2

    # Now we need to post processing
    inv = np.zeros(n_verts, dtype=bool)

    # Invalidate vertices that didn't intersect
    inv[~intersects] = True

    # Now invalidate any ball that is outside the mesh
    inv[~coll.contains(centers)] = True

    # Find tangent balls that are fully contained in another tangent ball
    # (those are not maximal inscribed)
    original_ind = np.arange(n_verts)
    while True:
        # Need at least two valid centers to have an overlapping pair
        if np.count_nonzero(~inv) < 2:
            break

        tree2 = scipy.spatial.cKDTree(centers[~inv])
        # For any not-yet-invalidated center find the closest other center
        dist, ix = tree2.query(centers[~inv], k=2)

        # Drop self-hits
        ix, dist = ix[:, 1], dist[:, 1]

        # Neighbouring center lies within this ball's radius
        in_radius = dist < radii[~inv]

        # Stop if no more overlapping pairs
        if not in_radius.any():
            break

        # Collect radii to determine which of the overlapping ball survives
        pair_rad = np.vstack((radii[~inv][in_radius],
                              radii[~inv][ix[in_radius]])).T
        pair_ix = np.vstack((original_ind[~inv][in_radius],
                             original_ind[~inv][ix[in_radius]])).T

        # Invalidate the losers: the SMALLER ball of each overlapping pair,
        # so that the larger (maximal) ball survives
        loses = np.argmin(pair_rad, axis=1)
        loser_ix = np.unique(pair_ix[np.arange(pair_ix.shape[0]), loses])
        inv[loser_ix] = True

    # Now we need to collapse nodes into the remaining centers
    G = ig.Graph(n=n_verts,
                 edges=mesh.edges_unique,
                 directed=False)

    # Make sure that every connected component has at least one valid target
    for cc in G.clusters():
        if not np.isin(cc, original_ind[~inv]).any():
            inv[cc[0]] = False
            centers[cc[0]] = mesh.vertices[cc[0]]

    # For each invalidated vertex, find the closest vertex that is still valid
    # This works on unweighted edges but should be good enough - way faster
    # than a proper path search for sure
    pairs = find_closest(G, sources=original_ind[inv],
                         targets=original_ind[~inv])

    # Generate a mesh vertex to skeleton node map
    mesh_map = original_ind.copy()
    mesh_map[pairs[:, 0]] = pairs[:, 1]

    # Renumber the vertices from 0 -> N_vertices (np.unique returns
    # (unique, index, inverse) in that order)
    uni, ind, mesh_map = np.unique(mesh_map, return_inverse=True, return_index=True)

    # Make sure centers and radii match the new order
    centers = centers[uni]
    radii = radii[uni]

    # Contract vertices to nodes according to the mesh
    G.contract_vertices(mesh_map, combine_attrs=None)

    # This only drops duplicate and self-loop edges
    G = G.simplify()

    # Generate weights between remaining centers
    el = np.array(G.get_edgelist())
    weights = np.linalg.norm(centers[el[:, 0]] - centers[el[:, 1]], axis=1)

    # Generate hierarchical tree
    tree = G.spanning_tree(weights=weights)

    # Create a directed acyclic and hierarchical graph
    G_nx = edges_to_graph(edges=np.array(tree.get_edgelist()),
                          nodes=np.arange(0, len(G.vs)),
                          fix_tree=True,
                          drop_disconnected=False)

    # Generate the SWC table
    swc = make_swc(G_nx, coords=centers, reindex=False)
    swc['radius'] = radii[swc.node_id.values]
    _, new_ids = reindex_swc(swc, inplace=True)

    # Update vertex to node ID map
    mesh_map = np.array([new_ids[n] for n in mesh_map])

    return Skeleton(swc=swc, mesh=mesh, mesh_map=mesh_map,
                    method='tangent_ball')
Ejemplo n.º 8
0
def collapse_nodes(A, B, limit=1, base_neuron=None, mesh=None):
    """Merge neuron A into neuron(s) B creating a union of both.

    This implementation uses edge contraction on the neurons' graph to ensure
    maximum connectivity. Only works if the fragments collectively form a
    continuous tree (i.e. you must be certain that they partially overlap).

    Note that this adds an ``origin_skeletons`` column to the node tables of
    ``A`` and of the neurons in ``B``.

    Parameters
    ----------
    A :                 CatmaidNeuron
                        Neuron to be collapsed into neurons B.
    B :                 CatmaidNeuronList
                        Neurons to collapse neuron A into.
    limit :             int, optional
                        Max distance [microns] for nearest neighbour search.
    base_neuron :       skeleton ID | CatmaidNeuron, optional
                        Neuron from B to use as template for union, given as
                        the neuron itself or its ID. If not provided, the first
                        neuron in the list is used as template!
    mesh :              navis.Volume | navis.MeshNeuron | mesh-like object
                        If provided, will use the mesh to check if nodes are
                        in line of sight to each other before collapsing them.

    Returns
    -------
    core.CatmaidNeuron
                        Union of all input neurons.
    new_edges :         pandas.DataFrame
                        Subset of the ``.nodes`` table that represent newly
                        added edges.
    collapsed_nodes :   dict
                        Map of collapsed nodes::

                            NodeA -collapsed-into-> NodeB

    Raises
    ------
    TypeError
                        If ``A`` or ``B`` have the wrong type.
    ValueError
                        If skeleton IDs or node IDs are not unique, or if a
                        ``base_neuron`` ID is not found in ``B``.

    """
    if isinstance(A, navis.NeuronList):
        if len(A) == 1:
            A = A[0]
        else:
            A = navis.stitch_neurons(A, method="NONE")

    if not isinstance(A, navis.TreeNeuron):
        raise TypeError('`A` must be a TreeNeuron, got "{}"'.format(type(A)))

    if isinstance(B, navis.TreeNeuron):
        B = navis.NeuronList(B)

    if not isinstance(B, navis.NeuronList):
        raise TypeError('`B` must be a NeuronList, got "{}"'.format(type(B)))

    if base_neuron is None:
        base_neuron = B[0]
    elif not isinstance(base_neuron, navis.TreeNeuron):
        # `base_neuron` is an ID -> resolve it to the neuron in B. Compare as
        # strings so that e.g. 123 and "123" match.
        matches = [n for n in B.neurons if str(n.id) == str(base_neuron)]
        if not matches:
            raise ValueError(f'base_neuron "{base_neuron}" not found in `B`')
        base_neuron = matches[0]

    # This is just check on the off-chance that skeleton IDs are not unique
    # (e.g. if neurons come from different projects) -> this is relevant because
    # we identify the master ("base_neuron") via its skeleton ID
    skids = [n.id for n in B + A]
    if len(skids) > len(set(skids)):
        raise ValueError(
            'Duplicate skeleton IDs found. Try manually assigning '
            'unique skeleton IDs.')

    # Convert distance threshold from microns to nanometres
    limit *= 1000

    # Before we start messing around, let's make sure we can keep track of
    # the origin of each node
    for n in B + A:
        n.nodes['origin_skeletons'] = n.id

    # Put the base neuron first so it acts as master. Build a new list rather
    # than reordering the caller's `B` in place (sorted() is stable, so the
    # other neurons keep their order).
    B = navis.NeuronList(sorted(B.neurons,
                                key=lambda x: 0 if x.id == base_neuron.id else 1))

    # First make a weak union by simply combining the node tables
    union_simple = navis.stitch_neurons(B + A, method='NONE', master='FIRST')

    # Check for duplicate node IDs
    if any(union_simple.nodes.node_id.duplicated()):
        raise ValueError('Duplicate node IDs found.')

    # Find nodes in A to be merged into B
    b_nodes = B.nodes
    tree = scipy.spatial.cKDTree(data=b_nodes[['x', 'y', 'z']].values)

    # For each node in A get the nearest neighbor in B; misses come back with
    # an infinite distance and are filtered out below
    coords = A.nodes[['x', 'y', 'z']].values
    nn_dist, nn_ix = tree.query(coords, k=1, distance_upper_bound=limit)

    # Find nodes that are close enough to collapse
    close = nn_dist <= limit
    collapsed = A.nodes.loc[close].node_id.values
    clps_into = b_nodes.iloc[nn_ix[close]].node_id.values

    # If we have a mesh, check if those collapsed nodes are in sight of each
    # other (nothing to check if no node is close enough)
    if mesh is not None and len(collapsed):
        import ncollpyde
        coll = ncollpyde.Volume(mesh.vertices, mesh.faces)

        # Produce start and end coordinates for the to collapse nodes
        starts = A.nodes.set_index('node_id').loc[collapsed,
                                                  ['x', 'y', 'z']].values
        ends = b_nodes.set_index('node_id').loc[clps_into,
                                                ['x', 'y', 'z']].values

        # Indices that show up in `intersects` cross a membrane
        intersects, _, _ = coll.intersections(starts, ends)

        # Keep only collapses that don't intersect
        not_intersects = ~np.isin(np.arange(starts.shape[0]), intersects)
        collapsed = collapsed[not_intersects]
        clps_into = clps_into[not_intersects]

    # Generate a map of which node in A is to be collapsed into which node in B
    clps_map = dict(zip(collapsed, clps_into))

    # The fastest way to collapse is to work on the edge list
    E = nx.to_pandas_edgelist(union_simple.graph)

    # Keep track of which edges were collapsed -> we will use this as weight
    # later on to prioritize existing edges over newly generated ones
    E['is_new'] = 1
    b_ids = b_nodes.node_id.values
    E.loc[E.source.isin(b_ids) | E.target.isin(b_ids), 'is_new'] = 0

    # Now map collapsed nodes onto the nodes they collapsed into
    E['target'] = E.target.map(lambda x: clps_map.get(x, x))
    E['source'] = E.source.map(lambda x: clps_map.get(x, x))

    # Make sure no self loops after collapsing. This happens if two adjacent
    # nodes collapse onto the same target node. Copy so that later column
    # assignments don't act on a slice.
    E = E[E.source != E.target].copy()

    # Remove duplicates. This happens e.g. when two adjacent nodes merge into
    # two other adjacent nodes: A->B C->D ----> A/B->C/D
    # A stable sort puts original edges first so they are the ones kept.
    E = E.sort_values('is_new', ascending=True, kind='stable')

    # Edges may exist in both directions (A->B and A<-B): order each pair so
    # duplicates are found regardless of direction
    und = np.sort(E[['source', 'target']].values, axis=1)
    E = E.assign(edge_u=und[:, 0], edge_v=und[:, 1])
    E = E.drop_duplicates(['edge_u', 'edge_v'], keep='first')

    # Regenerate graph from these new edges
    G = nx.Graph()
    G.add_weighted_edges_from(E[['source', 'target',
                                 'is_new']].values.astype(int))

    # At this point there might still be disconnected pieces -> we will create
    # separate neurons for each tree
    props = union_simple.nodes.loc[union_simple.nodes.node_id.isin(
        G.nodes)].set_index('node_id')
    nx.set_node_attributes(G, props.to_dict(orient='index'))
    fragments = []
    for n in nx.connected_components(G):
        c = G.subgraph(n)
        tree = nx.minimum_spanning_tree(c)
        fragments.append(
            navis.graph.nx2neuron(tree,
                                  name=base_neuron.name,
                                  id=base_neuron.id))
    fragments = navis.NeuronList(fragments)

    if len(fragments) > 1:
        print('Union incomplete - watch out for disconnected fragments!')
        # Now heal those fragments using a minimum spanning tree
        union = navis.stitch_neurons(*fragments, method='ALL')
    else:
        union = fragments[0]

    # Reroot to base neuron's root
    union.reroot(base_neuron.root[0], inplace=True)

    # Add tags back on
    if union_simple.has_tags:
        if not union.has_tags:
            union.tags = {}
        union.tags.update(union_simple.tags)

    # Add connectors back on (if there are any), pointing collapsed nodes to
    # the nodes they were collapsed into
    connectors = getattr(union_simple, 'connectors', None)
    if connectors is not None:
        cn = connectors.drop_duplicates(subset='connector_id').copy()
        cn['node_id'] = cn.node_id.map(lambda x: clps_map.get(x, x))
        union.connectors = cn

    # Find the newly added edges (existing edges should not have been modified
    # - except for changing direction due to reroot)
    # The basic logic here is that new edges were only added between two
    # previously separate skeletons, i.e. where the skeleton ID changes between
    # parent and child node
    node2skid = union_simple.nodes.set_index(
        'node_id').origin_skeletons.to_dict()
    union.nodes['parent_skeleton'] = union.nodes.parent_id.map(node2skid)
    new_edges = union.nodes[
        union.nodes.origin_skeletons != union.nodes.parent_skeleton]
    # Remove root edges
    new_edges = new_edges[~new_edges.parent_id.isnull()]

    return union, new_edges, clps_map