def skeletonize_neuron_parallel(ids, n_cores=os.cpu_count() // 2, **kwargs):
    """Skeletonization on parallel cores [WIP].

    Parameters
    ----------
    ids :       iterable
                Root IDs of neurons you want to skeletonize.
    n_cores :   int
                Number of cores to use. Don't go too crazy on this as the
                downloading of meshes becomes a bottleneck if you try to do
                too many at the same time. Keep your internet speed in mind.
    **kwargs
                Keyword arguments are passed on to `skeletonize_neuron`.

    Returns
    -------
    navis.NeuronList

    """
    if n_cores < 2 or n_cores > os.cpu_count():
        raise ValueError('`n_cores` must be between 2 and max number of cores.')

    sig = inspect.signature(skeletonize_neuron)
    for k in kwargs:
        if k not in sig.parameters:
            raise ValueError('unexpected keyword argument for '
                             f'`skeletonize_neuron`: {k}')

    # Make sure IDs are all integers
    ids = np.asarray(ids).astype(int)

    # Prepare the calls and parameters
    kwargs['progress'] = False
    funcs = [skeletonize_neuron] * len(ids)
    parsed_kwargs = [kwargs] * len(ids)
    combinations = list(zip(funcs, [[i] for i in ids], parsed_kwargs))

    # Run the actual skeletonization
    with mp.Pool(n_cores) as pool:
        chunksize = 1
        res = list(navis.config.tqdm(pool.imap(_worker_wrapper,
                                               combinations,
                                               chunksize=chunksize),
                                     total=len(combinations),
                                     desc='Skeletonizing',
                                     disable=False,
                                     leave=True))

    # Check if any skeletonizations failed
    failed = np.array([r for r in res
                       if not isinstance(r, navis.TreeNeuron)]).astype(str)
    if any(failed):
        print(f'{len(failed)} neurons failed to skeletonize: '
              f'{". ".join(failed)}')

    return navis.NeuronList([r for r in res
                             if isinstance(r, navis.TreeNeuron)])
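
# Usage sketch (not part of the original module): the root IDs below are
# placeholders and `n_cores` should be adapted to your machine and connection.
# Assumes FlyWire credentials are configured and the function is exposed under
# `fafbseg.flywire` like `skeletonize_neuron`.
#
# >>> from fafbseg import flywire
# >>> ids = [720575940614131061, 720575940618131011]   # hypothetical root IDs
# >>> skels = flywire.skeletonize_neuron_parallel(ids, n_cores=4,
# ...                                             shave_skeleton=True)
# >>> len(skels)
# 2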
def neuron_to_segments(x, dataset='production', coordinates='voxel'):
    """Get root IDs overlapping with a given neuron.

    Parameters
    ----------
    x :             Neuron/List
                    Neurons for which to return root IDs. Neurons must be in
                    flywire (FAFB14.1) space.
    dataset :       str | CloudVolume
                    Against which flywire dataset to query::

                        - "production" (current production dataset, fly_v31)
                        - "sandbox" (i.e. fly_v26)

    coordinates :   "voxel" | "nm"
                    Units the neuron(s) are in. "voxel" is assumed to be
                    4x4x40 (x/y/z) nanometers.

    Returns
    -------
    overlap_matrix :    pandas.DataFrame
                        DataFrame of root IDs (rows) and neuron IDs (columns)
                        with overlap in nodes as values::

                            id           id1  id2
                            root_id
                            10336680915    5    0
                            10336682132    0    1

    """
    if isinstance(x, navis.TreeNeuron):
        x = navis.NeuronList(x)

    assert isinstance(x, navis.NeuronList)

    # We must not perform this on `x.nodes` as this is a temporary property
    nodes = x.nodes

    # Get segmentation IDs
    nodes['root_id'] = locs_to_segments(nodes[['x', 'y', 'z']].values,
                                        coordinates=coordinates,
                                        root_ids=True,
                                        dataset=dataset)

    # Count segment IDs
    seg_counts = nodes.groupby(['neuron', 'root_id'],
                               as_index=False).node_id.count()
    seg_counts.columns = ['id', 'root_id', 'counts']

    # Remove seg IDs 0
    seg_counts = seg_counts[seg_counts.root_id != 0]

    # Turn into matrix where columns are neuron IDs, rows are root IDs
    # and values are the overlap counts
    matrix = seg_counts.pivot(index='root_id', columns='id', values='counts')

    return matrix
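
# Usage sketch (assumes `x` is a navis/pymaid neuron already in FlyWire
# FAFB14.1 voxel space): the overlap matrix can be reduced to the
# best-matching root ID per neuron with pandas' idxmax.
#
# >>> ol = neuron_to_segments(x, dataset='production', coordinates='voxel')
# >>> best_match = ol.idxmax(axis=0)   # root ID with most node overlap per neuron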
def get_mesh_neuron(id, with_synapses=False, dataset='production'):
    """Fetch flywire neuron as navis.MeshNeuron.

    Parameters
    ----------
    id :                int | list of int
                        Segment ID(s) to fetch meshes for.
    with_synapses :     bool
                        If True, will also load a connector table with
                        synapses predicted by Buhmann et al. (2020). A
                        "synapse score" (confidence) threshold of 30 is
                        applied.
    dataset :           str | CloudVolume
                        Against which flywire dataset to query::

                            - "production" (currently fly_v31)
                            - "sandbox" (currently fly_v26)

    Returns
    -------
    navis.MeshNeuron

    Examples
    --------
    >>> from fafbseg import flywire
    >>> m = flywire.get_mesh_neuron(720575940614131061)
    >>> m.plot3d()                                      # doctest: +SKIP

    """
    vol = parse_volume(dataset)

    if navis.utils.is_iterable(id):
        return navis.NeuronList([get_mesh_neuron(n, dataset=dataset,
                                                 with_synapses=with_synapses)
                                 for n in navis.config.tqdm(id,
                                                            desc='Fetching',
                                                            leave=False)])

    # Make sure the ID is integer
    id = int(id)

    # Fetch mesh
    mesh = vol.mesh.get(id, remove_duplicate_vertices=True)[id]

    # Turn into a MeshNeuron
    n = navis.MeshNeuron(mesh, id=id, units='nm', dataset=dataset)

    if with_synapses:
        _ = fetch_synapses(n, attach=True, min_score=30, dataset=dataset,
                           progress=False)

    return n
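
# Usage sketch: fetching a mesh together with the Buhmann et al. (2020)
# synapses. The root ID is the same example ID used in the docstring above.
#
# >>> from fafbseg import flywire
# >>> m = flywire.get_mesh_neuron(720575940614131061, with_synapses=True)
# >>> m.connectors.head()   # synapse table attached by fetch_synapses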
def neuron_to_segments(x):
    """Get segment IDs overlapping with a given neuron.

    Parameters
    ----------
    x :                 Neuron/List
                        Neurons for which to return segment IDs.

    Returns
    -------
    overlap_matrix :    pandas.DataFrame
                        DataFrame of segment IDs (rows) and skeleton IDs
                        (columns) with overlap in nodes as values::

                            skeleton_id   id  3245
                            seg_id
                            10336680915    5     0
                            10336682132    0     1

    """
    if isinstance(x, navis.TreeNeuron):
        x = navis.NeuronList(x)

    assert isinstance(x, navis.NeuronList)

    # We must not perform this on `x.nodes` as this is a temporary property
    nodes = x.nodes

    # Get segmentation IDs
    nodes['seg_id'] = locs_to_segments(nodes[['x', 'y', 'z']].values,
                                       coordinates='nm', mip=0)

    # Count segment IDs
    seg_counts = nodes.groupby(['neuron', 'seg_id'],
                               as_index=False).node_id.count()
    seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']

    # Remove seg IDs 0
    seg_counts = seg_counts[seg_counts.seg_id != 0]

    # Turn into matrix where columns are skeleton IDs, rows are segment IDs
    # and values are the overlap counts
    matrix = seg_counts.pivot(index='seg_id', columns='skeleton_id',
                              values='counts')

    return matrix
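
# Usage sketch (this older, Google-segmentation variant expects neuron
# coordinates in nanometers): a common follow-up is to threshold the overlap
# matrix before using the segment IDs. The cutoff of 5 is an arbitrary example.
#
# >>> ol = neuron_to_segments(x)
# >>> strong = ol[ol.max(axis=1) >= 5]   # keep segments with >=5 overlapping nodes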
def get_mesh_neuron(id, dataset='production'):
    """Fetch flywire neuron as navis.MeshNeuron.

    Parameters
    ----------
    id :        int | list of int
                Segment ID(s) to fetch meshes for.
    dataset :   str | CloudVolume
                Against which flywire dataset to query::

                    - "production" (currently fly_v31)
                    - "sandbox" (currently fly_v26)

    Returns
    -------
    navis.MeshNeuron

    Examples
    --------
    >>> from fafbseg import flywire
    >>> m = flywire.get_mesh_neuron(720575940614131061)
    >>> m.plot3d()                                      # doctest: +SKIP

    """
    vol = parse_volume(dataset)

    if navis.utils.is_iterable(id):
        return navis.NeuronList([get_mesh_neuron(n, dataset=dataset)
                                 for n in tqdm(id, desc='Fetching',
                                               leave=False)])

    # Make sure the ID is integer
    id = int(id)

    # Fetch mesh
    mesh = vol.mesh.get(id, remove_duplicate_vertices=True)[id]

    # Turn into a MeshNeuron
    return navis.MeshNeuron(mesh, id=id, units='nm', dataset=dataset)
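
# Usage sketch: passing a list of IDs returns a navis.NeuronList (the IDs
# below are placeholders).
#
# >>> nl = get_mesh_neuron([720575940614131061, 720575940618131011])
# >>> len(nl)
# 2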
def merge_into_catmaid(x, target_instance, tag, min_node_overlap=4, min_overlap_size=1, merge_limit=1, min_upload_size=0, min_upload_nodes=1, update_radii=True, import_tags=False, label_joins=True, sid_from_nodes=True, mesh=None): """Merge neuron into target CATMAID instance. This function will attempt to: 1. Find fragments in ``target_instance`` that overlap with ``x`` using whatever segmentation data source you have set using ``fafbseg.use_...``. 2. Generate a union of these fragments and ``x``. 3. Make a differential upload of the union leaving existing nodes untouched. 4. Join uploaded and existing tracings into a single continuous neuron. This will also upload connectors but no node tags. Parameters ---------- x : pymaid.CatmaidNeuron/List | navis.TreeNeuron/List Neuron(s)/fragment(s) to commit to ``target_instance``. target_instance : pymaid.CatmaidInstance Target Catmaid instance to commit the neuron to. tag : str A tag to be added as part of a ``{URL} upload {tag}`` annotation. This should be something identifying your group - e.g. ``tag='WTCam'`` for the Cambridge Wellcome Trust group. min_node_overlap : int, optional Minimal overlap between `x` and a potentially overlapping neuron in ``target_instance``. If the fragment has less total nodes than `min_overlap`, the threshold will be lowered to: ``min_overlap = min(min_overlap, fragment.n_nodes)`` min_overlap_size : int, optional Minimum node count for potentially overlapping neurons in ``target_instance``. Use this to e.g. exclude single-node synapse orphans. merge_limit : int, optional Distance threshold [um] for collapsing nodes of ``x`` into overlapping fragments in target instance. Decreasing this will help if your neuron has complicated branching patterns (e.g. uPN dendrites) at the cost of potentially creating duplicate parallel tracings in the neuron's backbone. min_upload_size : float, optional Minimum size in microns for upload of new branches: branches found in ``x`` but not in the overlapping neuron(s) in ``target_instance`` are uploaded in fragments. Use this parameter to exclude small branches that might not be worth the additional review time. min_upload_nodes : int, optional As ``min_upload_size`` but for number of nodes instead of cable length. update_radii : bool, optional If True, will use radii in ``x`` to update radii of overlapping fragments if (and only if) the nodes do not currently have a radius (i.e. radius<=0). import_tags : bool, optional If True, will import node tags. Please note that this will NOT import tags of nodes that have been collapsed into manual tracings. label_joins : bool, optional If True, will label nodes at which old and new tracings have been joined with tags ("Joined from ..." and "Joined with ...") and with a lower confidence of 1. sid_from_nodes : bool, optional If True and the to-be-merged neuron has a "skeleton_id" column it will be used to set the ``source_id`` upon uploading new branches. This is relevant if your neuron is a virtual chimera of several neurons: in order to preserve provenance (i.e. correctly associating each node with a ``source_id`` origin). mesh : Volume | MeshNeuron | mesh-like object | list thereof Mesh representation of ``x``. If provided, will use to improve merging. If ``x`` is a list of neurons, must provide a mesh for each of them. Returns ------- Nothing If all went well. dict If something failed, returns server responses with error logs. 
Examples -------- Setup >>> import fafbseg >>> import pymaid >>> # Set up connections to manual and autoseg CATMAID >>> manual = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN') >>> auto = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN') >>> # Set a segmentation data source >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation") Merge a neuron from autoseg into v14 >>> # Fetch the autoseg neuron to transfer to v14 >>> x = pymaid.get_neuron(267355161, remote_instance=auto) >>> # Get the neuron's annotations so that they can be merged too >>> x.get_annotations(remote_instance=auto) >>> # Start the commit >>> # See online documentation for video of merge process >>> resp = fafbseg.merge_neuron(x, target_instance=manual) """ if not isinstance(x, navis.NeuronList): if not isinstance(x, navis.TreeNeuron): raise TypeError('Expected TreeNeuron/List, got "{}"'.format( type(x))) x = navis.NeuronList(x) if not isinstance(mesh, (np.ndarray, list)): if isinstance(mesh, type(None)): mesh = [mesh] * len(x) else: mesh = [mesh] if len(mesh) != len(x): raise ValueError(f'Got {len(mesh)} meshes for {len(x)} neurons.') # Make a copy - in case we make any changes to the neurons # (like changing duplicate skeleton IDs) x = x.copy() if not isinstance(tag, (str, type(None))): raise TypeError('Tag must be string, got "{}"'.format(type(tag))) # Check user permissions perm = target_instance.fetch(target_instance.make_url('permissions')) requ_perm = ['can_annotate', 'can_annotate_with_token', 'can_import'] miss_perm = [ p for p in requ_perm if target_instance.project_id not in perm[0].get(p, []) ] if miss_perm: msg = 'You lack permissions: {}. Please contact an administrator.' raise PermissionError(msg.format(', '.join(miss_perm))) pymaid.set_loggers('WARNING') # Throttle requests just to play it safe # On a bad connection one might have to decrease max_threads further target_instance.max_threads = min(target_instance.max_threads, 50) # For user convenience, we will do all the stuff that needs user # interaction first and then run the automatic merge: # Start by find all overlapping fragments overlapping = [] for n, m in tqdm(zip(x, mesh), desc='Pre-processing neuron(s)', leave=False, disable=not use_pbars, total=len(x)): ol = find_fragments(n, min_node_overlap=min_node_overlap, min_nodes=min_overlap_size, mesh=m, remote_instance=target_instance) if ol: # Add number of samplers to each neuron n_samplers = pymaid.get_sampler_counts( ol, remote_instance=target_instance) for nn in ol: nn.sampler_count = n_samplers[str(nn.id)] overlapping.append(ol) # Now have the user confirm merges before we actually make them viewer = navis.Viewer(title='Confirm merges') viewer.clear() overlap_cnf = [] base_neurons = [] try: for n, ol in zip(x, overlapping): # This asks user a bunch of questions prior to merge and upload ol, bn = confirm_overlap(n, ol, viewer=viewer) overlap_cnf.append(ol) base_neurons.append(bn) except BaseException: raise finally: viewer.close() for i, (n, ol, bn, m) in enumerate(zip(x, overlap_cnf, base_neurons, mesh)): print(f'Processing neuron "{n.name}" ({n.id}) [{i}/{len(x)}]', flush=True) # If no overlapping neurons proceed with just uploading. if not ol: print( 'No overlapping fragments found. 
Uploading without merging...', end='', flush=True) resp = pymaid.upload_neuron(n, import_tags=import_tags, import_annotations=True, import_connectors=True, remote_instance=target_instance) if 'error' in resp: return resp # Add annotations _ = __merge_annotations(n, resp['skeleton_id'], tag, target_instance) msg = '\nNeuron "{}" successfully uploaded to target instance as "{}" #{}' print(msg.format(n.name, n.name, resp['skeleton_id']), flush=True) continue # Check if there is a duplicate skeleton ID between the to-be-merged # neuron and the to-merge-into neurons original_skid = None if n.id in ol.id: print('Fixing duplicate skeleton IDs.', flush=True) # Keep track of old skid original_skid = n.id # Skeleton ID must stay convertable to integer n.id = str(random.randint(1, 1000000)) n._clear_temp_attr() # Check if there are any duplicate node IDs between neuron ``x`` and the # overlapping fragments and create new IDs for ``x`` if necessary duplicated = n.nodes[n.nodes.node_id.isin(ol.nodes.node_id.values)] if not duplicated.empty: print('Duplicate node IDs found. Regenerating node tables... ', end='', flush=True) max_ix = max(ol.nodes.node_id.max(), n.nodes.node_id.max()) + 1 new_ids = range(max_ix, max_ix + duplicated.shape[0]) id_map = { old: new for old, new in zip(duplicated.node_id, new_ids) } n.nodes['node_id'] = n.nodes.node_id.map( lambda n: id_map.get(n, n)) n.nodes['parent_id'] = n.nodes.parent_id.map( lambda n: id_map.get(n, n)) if n.has_connectors: n.connectors['node_id'] = n.connectors.node_id.map( lambda n: id_map.get(n, n)) n._clear_temp_attr() print('Done.', flush=True) # Combining the fragments into a single neuron is actually non-trivial: # 1. Collapse nodes of our input neuron `x` into within-distance nodes # in the overlapping fragments (never the other way around!) # 2. At the same time keep connectivity (i.e. edges) of the input-neuron # 3. Keep track of the nodes' provenance (i.e. the contractions) # # In addition there are a lot of edge-cases to consider. For example: # - multiple nodes collapsing onto the same node # - nodes of overlapping fragments that are close enough to be collapsed # (e.g. orphan synapse nodes) # Keep track of original skeleton IDs for a in ol + n: # Original skeleton of each node a.nodes['origin_skeletons'] = a.id if a.has_connectors: # Original skeleton of each connector a.connectors['origin_skeletons'] = a.id print('Generating union of all fragments... ', end='', flush=True) union, new_edges, collapsed_into = collapse_nodes(n, ol, limit=merge_limit, base_neuron=bn, mesh=m) print('Done.', flush=True) print('Extracting new nodes to upload... 
', end='', flush=True) # Now we have to break the neuron into "new" fragments that we can upload # First get the new and old nodes new_nodes = union.nodes[union.nodes.origin_skeletons == n.id].node_id.values old_nodes = union.nodes[ union.nodes.origin_skeletons != n.id].node_id.values # Now remove the already existing nodes from the union only_new = navis.subset_neuron(union, new_nodes) # And then break into continuous fragments for upload frags = navis.break_fragments(only_new) print('Done.', flush=True) # Also get the new edges we need to generate to_stitch = new_edges[~new_edges.parent_id.isnull()] # We need this later -> no need to compute this for every uploaded fragment cond1b = to_stitch.node_id.isin(old_nodes) cond2b = to_stitch.parent_id.isin(old_nodes) # Now upload each fragment and keep track of new node IDs tn_map = {} for f in tqdm(frags, desc='Merging new arbors', leave=False, disable=not use_pbars): # In cases of complete merging into existing neurons, the fragment # will have no nodes if f.n_nodes < 1: continue # Check if fragment is a "linker" and as such can not be skipped lcond1 = np.isin(f.nodes.node_id.values, new_edges.node_id.values) lcond2 = np.isin(f.nodes.node_id.values, new_edges.parent_id.values) # If not linker, check skip conditions if sum(lcond1) + sum(lcond2) <= 1: if f.cable_length < min_upload_size: continue if f.n_nodes < min_upload_nodes: continue # Collect origin info for this neuron if it's a CatmaidNeuron if isinstance(n, pymaid.CatmaidNeuron): source_info = {'source_type': 'segmentation'} if not sid_from_nodes or 'origin_skeletons' not in f.nodes.columns: # If we had to change the skeleton ID due to duplication, make # sure to pass the original skid as source ID if original_skid: source_info['source_id'] = int(original_skid) else: source_info['source_id'] = int(n.id) else: if f.nodes.origin_skeletons.unique().shape[0] == 1: skid = f.nodes.origin_skeletons.unique()[0] else: print( 'Warning: uploading chimera fragment with multiple ' 'skeleton IDs! 
Using largest contributor ID.') # Use the skeleton ID that has the most nodes by_skid = f.nodes.groupby('origin_skeletons').x.count() skid = by_skid.sort_values( ascending=False).index.values[0] source_info['source_id'] = int(skid) if not isinstance(getattr(n, '_remote_instance', None), type(None)): source_info[ 'source_project_id'] = n._remote_instance.project_id source_info['source_url'] = n._remote_instance.server else: # Unknown source source_info = {} resp = pymaid.upload_neuron(f, import_tags=import_tags, import_annotations=False, import_connectors=True, remote_instance=target_instance, **source_info) # Stop if there was any error while uploading if 'error' in resp: return resp # Collect old -> new node IDs tn_map.update(resp['node_id_map']) # Now check if we can create any of the new edges by joining nodes # Both treenode and parent ID have to be either existing nodes or # newly uploaded cond1a = to_stitch.node_id.isin(tn_map) cond2a = to_stitch.parent_id.isin(tn_map) to_gen = to_stitch.loc[(cond1a | cond1b) & (cond2a | cond2b)] # Join nodes for node in to_gen.itertuples(): # Make sure our base_neuron always come out as winner on top if node.node_id in bn.nodes.node_id.values: winner, looser = node.node_id, node.parent_id else: winner, looser = node.parent_id, node.node_id # We need to map winner and looser to the new node IDs winner = tn_map.get(winner, winner) looser = tn_map.get(looser, looser) # And now do the join resp = pymaid.join_nodes(winner, looser, no_prompt=True, tag_nodes=label_joins, remote_instance=target_instance) # See if there was any error while uploading if 'error' in resp: print('Skipping joining nodes ' '{} and {}: {} - '.format(node.node_id, node.parent_id, resp['error'])) # Skip changing confidences continue # Pop this edge from new_edges and from condition new_edges.drop(node.Index, inplace=True) cond1b.drop(node.Index, inplace=True) cond2b.drop(node.Index, inplace=True) # Change node confidences at new join if label_joins: new_conf = {looser: 1} resp = pymaid.update_node_confidence( new_conf, remote_instance=target_instance) # Add annotations if n.has_annotations: _ = __merge_annotations(n, bn, tag, target_instance) # Update node radii if update_radii and 'radius' in n.nodes.columns and np.all( n.nodes.radius): print('Updating radii of existing nodes... ', end='', flush=True) resp = update_node_radii(source=n, target=ol, remote_instance=target_instance, limit=merge_limit, skip_existing=True) print('Done.', flush=True) print( 'Neuron "{}" successfully merged into target instance as "{}" #{}'. format(n.name, bn.name, bn.id), flush=True) return
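
# Usage sketch (hedged, not from the docstring above): like the docstring
# example but additionally passing a mesh representation of the neuron so that
# collapse_nodes can run its line-of-sight check. `x`, `m` and `manual` are
# assumed to come from a setup as in the Examples section (e.g. a skeleton
# plus its mesh and a CATMAID target instance).
#
# >>> resp = merge_into_catmaid(x, target_instance=manual, tag='WTCam',
# ...                           merge_limit=1, mesh=m)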
def l2_skeleton(root_id, refine=False, drop_missing=True, threads=10,
                progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id :       int | list of ints
                    Root ID(s) of the flywire neuron(s) you want to
                    skeletonize.
    refine :        bool
                    If True, will refine skeleton nodes by moving them to the
                    center of their corresponding chunk meshes.

                    Only relevant if ``refine=True``:

    drop_missing :  bool
                    If True, will drop nodes that don't have a corresponding
                    chunk mesh. These are typically chunks that are very small
                    and dropping them might actually be beneficial.
    threads :       int
                    How many parallel threads to use for fetching the chunk
                    meshes. Reduce the number if you run into ``HTTPErrors``.
                    Only relevant if `use_flycache=False`.
    progress :      bool
                    Whether to show a progress bar.

    Returns
    -------
    skeleton :      navis.TreeNeuron
                    The extracted skeleton.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    """
    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma
    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        nl = []
        for id in navis.config.tqdm(root_id, desc='Skeletonizing',
                                    disable=not progress, leave=False):
            n = l2_skeleton(id, refine=refine, drop_missing=drop_missing,
                            threads=threads, progress=progress,
                            dataset=dataset)
            nl.append(n)
        return navis.NeuronList(nl)

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N, 2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l))
              for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchical tree by removing cycles and
    # ensuring all edges point towards a root
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidean space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids,
                                                        token=token,
                                                        progress=progress)

            # Drop missing (i.e. [0, 0, 0]) meshes
            centroids = {k: v for k, v in centroids.items()
                         if v != [0, 0, 0]}
        else:
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads,
                                         progress=progress)

        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
        swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
        swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
    else:
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    return tn
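
# Usage sketch: refined L2 skeletonization. `threads` only matters for
# fetching the chunk meshes; `drop_missing=True` removes nodes whose chunk
# mesh could not be retrieved.
#
# >>> from fafbseg import flywire
# >>> n = flywire.l2_skeleton(720575940614131061, refine=True,
# ...                         drop_missing=True, threads=10)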
def find_missed_branches(x, autoseg_instance, tag=False, tag_size_thresh=10, min_node_overlap=4, **kwargs): """Use autoseg to find (and annotate) potential missed branches. Parameters ---------- x : pymaid.CatmaidNeuron/List Neuron(s) to search for missed branches. autoseg_instance : pymaid.CatmaidInstance CATMAID instance containing the autoseg skeletons. tag : bool, optional If True, will tag nodes of ``x`` that might have missed branches with "missed branch?". tag_size_thresh : int, optional Size threshold in microns of cable for tagging potentially missed branches. min_node_overlap : int, optional Minimum number of nodes that input neuron(s) x must overlap with given segmentation ID for it to be included. **kwargs Keyword arguments passed to ``fafbseg.neuron_from_segments``. Returns ------- summary : pandas.DataFrame DataFrame containing a summary of potentially missed branches. If input is a single neuron: fragments : pymaid.CatmaidNeuronList Fragments found to be potentially overlapping with the input neuron. branches : pymaid.CatmaidNeuronList Potentially missed branches extracted from ``fragments``. Examples -------- Setup >>> import fafbseg >>> import pymaid >>> # Set up connections to manual and autoseg CATMAID >>> manual = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN') >>> auto = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN') >>> # Set a source for segmentation data >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation") Find missed branches and tag them >>> # Fetch a neuron >>> x = pymaid.get_neuron(16, remote_instance=manual) >>> # Find and tag missed branches >>> (summary, ... fragments, ... branches) = fafbseg.find_missed_branches(x, autoseg_instance=auto) >>> # Show summary of missed branches >>> summary.head() n_nodes cable_length node_id 0 110 28.297424 3306395 1 90 23.976504 20676047 2 64 15.851333 23419997 3 29 7.494350 6298769 4 16 3.509739 15307841 >>> # Co-visualize your neuron and potentially overlapping autoseg fragments >>> x.plot3d(color='w') >>> fragments.plot3d() >>> # Visualize the potentially missed branches >>> pymaid.clear3d() >>> x.plot3d(color='w') >>> branches.plot3d(color='r') """ if isinstance(x, navis.NeuronList): to_concat = [] for n in tqdm(x, desc='Processing neurons', disable=not use_pbars, leave=False): (summary, frags, branches) = find_missed_branches( n, autoseg_instance=autoseg_instance, tag=tag, tag_size_thresh=tag_size_thresh, **kwargs) summary['skeleton_id'] = n.id to_concat.append(summary) return pd.concat(to_concat, ignore_index=True) elif not isinstance(x, navis.TreeNeuron): raise TypeError(f'Input must be TreeNeuron/List, got "{type(x)}"') # Find autoseg neurons overlapping with input neuron nl = find_autoseg_fragments(x, autoseg_instance=autoseg_instance, min_node_overlap=min_node_overlap, verbose=False, raise_none_found=False) # Next create a union if not nl.empty: for n in nl: n.nodes['origin'] = 'autoseg' n.nodes['origin_skid'] = n.skeleton_id # Create a simple union union = navis.stitch_neurons(nl, method='NONE') # Merge into target neuron union, new_edges, clps_map = move.merge_utils.collapse_nodes(union, x, limit=2) # Subset to autoseg nodes autoseg_nodes = union.nodes[union.nodes.origin == 'autoseg'].node_id.values else: autoseg_nodes = np.empty((0, 5)) # Process fragments if any autoseg nodes left data = [] frags = navis.NeuronList([]) if autoseg_nodes.shape[0]: autoseg = navis.subset_neuron(union, autoseg_nodes) # Split into fragments frags = 
navis.break_fragments(autoseg) # Generate summary nodes = union.nodes.set_index('node_id') for n in frags: # Find parent node in union pn = nodes.loc[n.root[0], 'parent_id'] pn_co = nodes.loc[pn, ['x', 'y', 'z']].values org_skids = n.nodes.origin_skid.unique().tolist() data.append([n.n_nodes, n.cable_length, pn, pn_co, org_skids]) df = pd.DataFrame(data, columns=[ 'n_nodes', 'cable_length', 'node_id', 'node_loc', 'autoseg_skids' ]) df.sort_values('cable_length', ascending=False, inplace=True) if tag and not df.empty: to_tag = df[df.cable_length >= tag_size_thresh].node_id.values resp = pymaid.add_tags(to_tag, tags='missed branch?', node_type='TREENODE', remote_instance=x._remote_instance) if 'error' in resp: return df, resp return df, nl, frags
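
# Usage sketch (hedged): for a single neuron the function returns the summary
# plus the overlapping fragments and the extracted branches; the summary's
# `node_loc` column gives the coordinates of the putative branch point. The
# 10-micron cutoff below mirrors the default `tag_size_thresh`.
#
# >>> summary, fragments, branches = find_missed_branches(x, autoseg_instance=auto)
# >>> big = summary[summary.cable_length >= 10]   # candidates worth reviewing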
def get_neuron_connections(sources, targets=None, agglomerate=True, score_thresh=30, ol_thresh=5, dist_thresh=2000, drop_duplicates=True, drop_autapses=True, db=None, verbose=True): """Fetch connections between sets of neurons. Works by: 1. Fetch segment IDs corresponding to given neurons 2. Fetch Buhmann et al. synaptic connections between them 3. Do some clean-up. See parameters for details - defaults are those used by Buhmann et al Parameters ---------- sources : navis.Neuron | pymaid.CatmaidNeuron | NeuronList Presynaptic neurons to fetch connections for. targets : navis.Neuron | pymaid.CatmaidNeuron | NeuronList | None Postsynaptic neurons to fetch connections for. If ``None``, ``targets = sources``. agglomerate : bool If True, will agglomerate connectivity by ID and return a weighted edge list. If False, will return a list of individual synapses. ol_thresh : int, optional If provided, will required a minimum node overlap between a neuron and a segment for that segment to be included (see step 1 above). score_thresh : int, optional If provided will only return synapses with a cleft score of this or higher. dist_thresh : int Synapses further away than the given distance [nm] from any node in the neuron than given distance will be discarded. drop_duplicates : bool If True, will merge synapses which connect the same pair of pre-post segmentation IDs and are within less than 250nm. drop_autapses : bool If True, will automatically drop autapses. db : str, optional Must point to SQL database containing the synapse data. If not provided will look for a ```BUHMANN_SYNAPSE_DB`` environment variable. Return ------ pd.DataFrame Either edge list or list of synapses - see ``agglomerate`` parameter. """ # This is just so we throw a DB exception early and not wait for fetching # segment Ids first _ = get_connection(db) assert isinstance(sources, (navis.BaseNeuron, navis.NeuronList)) if isinstance(targets, type(None)): targets = sources assert isinstance(targets, (navis.BaseNeuron, navis.NeuronList)) if not isinstance(sources, navis.NeuronList): sources = navis.NeuronList([sources]) if not isinstance(targets, navis.NeuronList): targets = navis.NeuronList([targets]) # Get segments for this neuron(s) unique_neurons = (sources + targets).remove_duplicates(key='id') seg_ids = google.neuron_to_segments(unique_neurons) # Drop segments with overlap below threshold if ol_thresh: seg_ids = seg_ids.loc[seg_ids.max(axis=1) >= ol_thresh] # We need to make sure that every segment ID is only attributed to a single # neuron is_top = seg_ids.values != seg_ids.max(axis=1).values.reshape( (seg_ids.shape[0], 1)) where_top = np.where(is_top) seg_ids.values[where_top] = 0 # do not change this to None pre_ids = seg_ids.loc[seg_ids[sources.id].max(axis=1) > 0].index.values post_ids = seg_ids.loc[seg_ids[targets.id].max(axis=1) > 0].index.values # Fetch pre- and postsynapses associated with these segments # It's cheaper to get them all in one go syn = query_connections(pre_ids, post_ids, score_thresh=score_thresh, db=None) # Now associate synapses with neurons seg2neuron = dict( zip(seg_ids.index.values, seg_ids.columns[np.argmax(seg_ids.values, axis=1)])) syn['id_pre'] = syn.segmentid_pre.map(seg2neuron) syn['id_post'] = syn.segmentid_post.map(seg2neuron) # Let the clean-up BEGIN! 
# First drop nasty autapses if drop_autapses: syn = syn[syn.id_pre != syn.id_post] # Next drop synapses far away from our neurons if dist_thresh: syn['pre_close'] = False syn['post_close'] = False for id in np.unique(syn[['id_pre', 'id_post']].values.flatten()): neuron = unique_neurons.idx[id] tree = navis.neuron2KDTree(neuron) is_pre = syn.id_pre == id if np.any(is_pre): dist, ix = tree.query( syn.loc[is_pre, ['pre_x', 'pre_y', 'pre_z']].values, distance_upper_bound=dist_thresh) syn.loc[is_pre, 'pre_close'] = dist < float('inf') is_post = syn.id_post == id if np.any(is_post): dist, ix = tree.query( syn.loc[is_post, ['post_x', 'post_y', 'post_z']].values, distance_upper_bound=dist_thresh) syn.loc[is_post, 'post_close'] = dist < float('inf') # Drop connections where either pre- or postsynaptic site are too far # away from the neuron syn = syn[syn.pre_close & syn.post_close] # Drop duplicate connections, i.e. connections that connect the same pre- # and postsynaptic segmentation ID and are within a distance of 250nm dupl_thresh = 250 if drop_duplicates: # We are dealing with this from a presynaptic perspective while True: pre_tree = cKDTree(syn[['pre_x', 'pre_y', 'pre_z']].values) pairs = pre_tree.query_pairs(r=dupl_thresh, output_type='ndarray') same_pre = syn.iloc[pairs[:, 0]].segmentid_pre.values == syn.iloc[ pairs[:, 1]].segmentid_pre.values same_post = syn.iloc[pairs[:, 0]].segmentid_post.values == syn.iloc[ pairs[:, 1]].segmentid_post.values same_cn = same_pre & same_post # If no more pairs to collapse break if same_cn.sum() == 0: break # Generate a graph from pairs G = nx.Graph() G.add_edges_from(pairs[same_cn]) # Find the minimum number of nodes we need to remove # to separate the connectors to_rm = [] for cn in nx.connected_components(G): to_rm += list(nx.minimum_node_cut(nx.subgraph(G, cn))) syn = syn.drop(index=syn.index.values[to_rm]) if agglomerate: edges = syn.groupby(['id_pre', 'id_post'], as_index=False).cleft_id.count() edges.rename({'cleft_id': 'weight'}, axis=1, inplace=True) edges.sort_values('weight', ascending=False, inplace=True) return edges.reset_index(drop=True) return syn
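
# Usage sketch (hedged): building a weighted edge list between two sets of
# neurons. `pns` and `kcs` are hypothetical NeuronLists; `db` must point to the
# Buhmann synapse database unless the BUHMANN_SYNAPSE_DB environment variable
# is set.
#
# >>> edges = get_neuron_connections(sources=pns, targets=kcs,
# ...                                agglomerate=True, score_thresh=30)
# >>> edges.head()   # columns: id_pre, id_post, weight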
def get_neuron_synapses(x, pre=True, post=True, collapse_connectors=False, score_thresh=30, ol_thresh=2, dist_thresh=1000, attach=True, drop_autapses=True, drop_duplicates=True, db=None, verbose=True, ret='catmaid', progress=True): """Fetch synapses for a given neuron. Works by: 1. Fetch segment IDs corresponding to a neuron 2. Fetch Buhmann et al. synapses associated with these segment IDs 3. Do some clean-up. See parameters for details. Notes ----- 1. If ``collapse_connectors=False`` the ``connector_id`` column is a simple enumeration and is effectively meaningless. 2. The x/y/z coordinates always correspond to the presynaptic site (like with CATMAID connectors). These columns are renamed from "pre_{x|y|z}" in the database. 3. Some of the clean-up assumes that the query neurons are unique neurons and not fragments of the same neuron. If that is the case, you might be better of running this function for each fragment individually. Parameters ---------- x : navis.Neuron | pymaid.CatmaidNeuron | NeuronList Neurons to fetch synapses for. pre/post : bool Whether to fetch pre- and/or postsynapses. collapse_connectors : bool If True, we will pool presynaptic connections into CATMAID-like connectors. If False each row represents a connections. ol_thresh : int, optional If provided, will required a minimum node overlap between a neuron and a segment for that segment to be included (see step 1 above). score_thresh : int, optional If provided will only return synapses with a cleft score of this or higher. dist_thresh : int Synapses further away than the given distance [nm] from any node in the neuron than given distance will be discarded. This is always with respect to the cleft site irrespective of whether our neuron is pre- or postsynaptic to it. attach : bool If True, will attach synapses as `.connectors` to neurons. If False, will return DataFrame with synapses. drop_autapses : bool If True, will drop synapses where pre- and postsynapse point to the same neuron. Autapses are wrong in 99.9% of all cases we've seen. drop_duplicates : bool If True, will merge synapses which connect the same pair of pre-post segmentation IDs and are within less than 2500nm. db : str, optional Must point to SQL database containing the synapse data. If not provided will look for a `BUHMANN_SYNAPSE_DB` environment variable. ret : "catmaid" | "brief" | "full" If "full" will return all synapse properties. If "brief" will return more relevant subset. If "catmaid" will return only CATMAID-like columns. progress : bool Whether to show progress bars or not. Return ------ pd.DataFrame Only if ``attach=False``. 
""" # This is just so we throw a DB exception early and not wait for fetching # segment Ids first _ = get_connection(db) assert isinstance(x, (navis.BaseNeuron, navis.NeuronList)) assert ret in ("catmaid", "brief", "full") if not isinstance(x, navis.NeuronList): x = navis.NeuronList([x]) # Get segments for this neuron(s) seg_ids = google.neuron_to_segments(x) # Drop segments with overlap below threshold if ol_thresh: seg_ids = seg_ids.loc[seg_ids.max(axis=1) >= ol_thresh] # We will make sure that every segment ID is only attributed to a single # neuron not_top = seg_ids.values != seg_ids.max(axis=1).values.reshape( (seg_ids.shape[0], 1)) where_not_top = np.where(not_top) if np.any(where_not_top[0]): seg_ids.values[where_not_top] = None # do not change this to 0 # Fetch pre- and postsynapses associated with these segments # It's cheaper to get them all in one go syn = query_synapses(seg_ids.index.values, score_thresh=score_thresh, ret=ret if ret != 'catmaid' else 'brief', db=None) # Drop autapses - they are most likely wrong if drop_autapses: syn = syn[syn.segmentid_pre != syn.segmentid_post] if drop_duplicates: dupl_thresh = 250 # Deal with pre- and postsynapses separately while True: # Generate pairs of pre- and postsynaptic coordinates that are # suspiciously close pre_tree = cKDTree(syn[['pre_x', 'pre_y', 'pre_z']].values) post_tree = cKDTree(syn[['post_x', 'post_y', 'post_z']].values) pre_pairs = pre_tree.query_pairs(r=dupl_thresh) post_pairs = post_tree.query_pairs(r=dupl_thresh) # We will consider pairs for removal where both pre- OR postsynapse # are close - this is easy to change via below operator pairs = pre_pairs | post_pairs # union of both pairs = np.array(list(pairs)) # For each pair check if they connect the same IDs same_pre = syn.iloc[pairs[:, 0]].segmentid_pre.values == syn.iloc[ pairs[:, 1]].segmentid_pre.values same_post = syn.iloc[pairs[:, 0]].segmentid_post.values == syn.iloc[ pairs[:, 1]].segmentid_post.values same_cn = same_pre & same_post # If no more pairs to collapse break if same_cn.sum() == 0: break # Generate a graph from pairs G = nx.Graph() G.add_edges_from(pairs[same_cn]) # Find the minimum number of nodes we need to remove # to separate the connectors to_rm = [] for cn in nx.connected_components(G): to_rm += list(nx.minimum_node_cut(nx.subgraph(G, cn))) # Drop those nodes syn = syn.drop(index=syn.index.values[to_rm]) # Reset index syn.reset_index(drop=True, inplace=True) if collapse_connectors: assign_connectors(syn) else: # Make fake IDs syn['connector_id'] = np.arange(syn.shape[0]).astype(np.int32) # Now associate synapses with neurons tables = [] for c in tqdm(seg_ids.columns, desc='Proc. neurons', disable=not progress or seg_ids.shape[1] == 1, leave=False): this_segs = seg_ids.loc[seg_ids[c].notnull(), c] is_pre = syn.segmentid_pre.isin(this_segs.index.values) is_post = syn.segmentid_post.isin(this_segs.index.values) # At this point we might see the exact same connection showing up in # `is_pre` and in `is_post`. This happens when we mapped both the # pre- and the postsynaptic segment to this neuron - likely an error. # In these cases we have to decide whether our neuron is truely pre- # or postsynaptic. 
For this we will use the overlap counts: # First find connections that would show up twice is_dupl = is_pre & is_post if any(is_dupl): dupl = syn[is_dupl] # Next get the overlap counts for the pre- and postsynaptic seg IDs dupl_pre_ol = seg_ids.loc[dupl.segmentid_pre, c].values dupl_post_ol = seg_ids.loc[dupl.segmentid_post, c].values # We go for the one with more overlap true_pre = dupl_pre_ol > dupl_post_ol # Propagate that decision is_pre[is_dupl] = true_pre is_post[is_dupl] = ~true_pre # Now get our synapses this_pre = syn[is_pre] this_post = syn[is_post] # Keep only one connector per presynapse # -> just like in CATMAID connector tables # Postsynaptic connectors will still show up multiple times if collapse_connectors: this_pre = this_pre.drop_duplicates('connector_id') # Combine pre- and postsynapses and keep track of the type connectors = pd.concat([this_pre, this_post], axis=0).reset_index(drop=True) connectors['type'] = 'post' connectors.iloc[:this_pre.shape[0], connectors.columns.get_loc('type')] = 'pre' connectors['type'] = connectors['type'].astype('category') # Rename columns such that x/y/z corresponds to presynaptic sites connectors.rename({ 'pre_x': 'x', 'pre_y': 'y', 'pre_z': 'z' }, axis=1, inplace=True) # For CATMAID-like connector tables subset to relevant columns if ret == 'catmaid': connectors = connectors[[ 'connector_id', 'x', 'y', 'z', 'cleft_scores', 'type' ]].copy() # Map connectors to nodes # Note that this is where we enforce `dist_thresh` neuron = x.idx[c] tree = navis.neuron2KDTree(neuron) dist, ix = tree.query(connectors[['x', 'y', 'z']].values, distance_upper_bound=dist_thresh) # Drop far away connectors connectors = connectors.loc[dist < np.inf] # Assign node IDs connectors['node_id'] = neuron.nodes.iloc[ix[ dist < np.inf]].node_id.values # Somas can end up having synapses, which we know is wrong and is # relatively easy to fix if np.any(neuron.soma): somata = navis.utils.make_iterable(neuron.soma) s_locs = neuron.nodes.loc[neuron.nodes.node_id.isin(somata), ['x', 'y', 'z']].values # Find all nodes within 2 micron around the somas soma_node_ix = tree.query_ball_point(s_locs, r=2000) soma_node_ix = [n for l in soma_node_ix for n in l] soma_node_id = neuron.nodes.iloc[soma_node_ix].node_id.values # Drop connectors attached to these soma nodes connectors = connectors[~connectors.node_id.isin(soma_node_id)] if attach: neuron.connectors = connectors.reset_index(drop=True) else: connectors['neuron'] = neuron.id # do NOT change the type of this tables.append(connectors) if not attach: connectors = pd.concat(tables, axis=0, sort=True).reset_index(drop=True) return connectors
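
# Usage sketch (hedged): attaching CATMAID-like connector tables to neurons.
# With `attach=True` nothing is returned; synapses end up in each neuron's
# `.connectors` table. `x` is assumed to be a navis.NeuronList.
#
# >>> get_neuron_synapses(x, pre=True, post=True, attach=True, ret='catmaid')
# >>> x[0].connectors.type.value_counts()   # pre vs post counts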
def skeletonize_neuron(x, shave_skeleton=True, remove_soma_hairball=False,
                       assert_id_match=False, dataset='production',
                       progress=True, **kwargs):
    """Skeletonize FlyWire neuron.

    Note that this is optimized primarily for speed, which comes at the cost
    of (some) quality.

    Parameters
    ----------
    x :                     int | trimesh.Trimesh | list thereof
                            ID(s) or trimesh of the FlyWire neuron(s) you want
                            to skeletonize.
    shave_skeleton :        bool
                            If True, we will "shave" the skeleton by removing
                            all single-node terminal twigs. This should get
                            rid of hairs on the backbone that can occur if the
                            neurites are very big.
    remove_soma_hairball :  bool
                            If True, we will try to drop the hairball that is
                            typically created inside the soma. Note that while
                            this should work just fine for 99% of neurons,
                            it's not very smart and there is always a chance
                            that we remove stuff that should not have been
                            removed. Also only works if the neuron has a
                            recognizable soma.
    assert_id_match :       bool
                            If True, will check if skeleton nodes map to the
                            correct segment ID and if not will move them back
                            into the segment. This is potentially very slow!
    dataset :               str | CloudVolume
                            Against which FlyWire dataset to query::

                                - "production" (current production dataset, fly_v31)
                                - "sandbox" (i.e. fly_v26)

    progress :              bool
                            Whether to show a progress bar or not.

    Returns
    -------
    skeleton :              navis.TreeNeuron
                            The extracted skeleton.

    See Also
    --------
    :func:`fafbseg.flywire.skeletonize_neuron_parallel`
                            Use this if you want to skeletonize many neurons
                            in parallel.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.skeletonize_neuron(720575940614131061)

    """
    if int(sk.__version__.split('.')[0]) < 1:
        raise ImportError('Please update skeletor to version >= 1.0.0: '
                          'pip3 install skeletor -U')

    vol = parse_volume(dataset)

    if navis.utils.is_iterable(x):
        return navis.NeuronList([skeletonize_neuron(n,
                                                    progress=False,
                                                    remove_soma_hairball=remove_soma_hairball,
                                                    assert_id_match=assert_id_match,
                                                    dataset=dataset,
                                                    **kwargs)
                                 for n in navis.config.tqdm(x,
                                                            desc='Skeletonizing',
                                                            disable=not progress,
                                                            leave=False)])

    if not navis.utils.is_mesh(x):
        vol = parse_volume(dataset)

        # Make sure this is a valid integer
        id = int(x)

        # Download the mesh
        mesh = vol.mesh.get(id,
                            deduplicate_chunk_boundaries=False,
                            remove_duplicate_vertices=True)[id]
    else:
        mesh = x
        id = getattr(mesh, 'segid', 0)

    mesh = sk.utilities.make_trimesh(mesh, validate=False)

    # Fix things before we skeletonize
    # This also drops fluff
    mesh = sk.pre.fix_mesh(mesh, inplace=True, remove_disconnected=100)

    # Skeletonize
    defaults = dict(waves=1, step_size=1)
    defaults.update(kwargs)
    s = sk.skeletonize.by_wavefront(mesh, progress=progress, **defaults)

    # Skeletor indexes node IDs at zero but to avoid potential issues we want
    # node IDs to start at 1
    s.swc['node_id'] += 1
    s.swc.loc[s.swc.parent_id >= 0, 'parent_id'] += 1

    # We will also round the radius and make it an integer to save some
    # memory. We could do the same with x/y/z coordinates but that could
    # potentially move nodes outside the mesh
    s.swc['radius'] = s.swc.radius.round().astype(int)

    # Turn into a neuron
    tn = navis.TreeNeuron(s.swc, units='1 nm', id=id, soma=None)

    if shave_skeleton:
        # Get branch points
        bp = tn.nodes.loc[tn.nodes.type == 'branch', 'node_id'].values

        # Get single-node twigs
        is_end = tn.nodes.type == 'end'
        parent_is_bp = tn.nodes.parent_id.isin(bp)
        twigs = tn.nodes.loc[is_end & parent_is_bp, 'node_id'].values

        # Drop terminal twigs
        tn._nodes = tn.nodes.loc[~tn.nodes.node_id.isin(twigs)].copy()
        tn._clear_temp_attr()

    # See if we can find a soma
    soma = detect_soma_skeleton(tn, min_rad=800, N=3)
    if soma:
        tn.soma = soma

        # Reroot to soma
        tn.reroot(tn.soma, inplace=True)

        if remove_soma_hairball:
            soma = tn.nodes.set_index('node_id').loc[soma]
            soma_loc = soma[['x', 'y', 'z']].values

            # Find all nodes within 2x the soma radius
            tree = navis.neuron2KDTree(tn)
            ix = tree.query_ball_point(soma_loc, max(4000, soma.radius * 2))

            # Translate indices into node IDs
            ids = tn.nodes.iloc[ix].node_id.values

            # Find segments that contain these nodes
            segs = [s for s in tn.segments if any(np.isin(ids, s))]

            # Sort segs by length
            segs = sorted(segs, key=lambda x: len(x))

            # Keep only the longest segment in that initial list
            to_drop = np.array([n for s in segs[:-1] for n in s])
            to_drop = to_drop[~np.isin(to_drop, segs[-1] + [soma.name])]

            navis.remove_nodes(tn, to_drop, inplace=True)

    if assert_id_match:
        if id == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(tn.nodes[['x', 'y', 'z']].values,
                              id=id,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        tn.nodes[['x', 'y', 'z']] = new_locs

    return tn
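
# Usage sketch: extra keyword arguments are handed through to skeletor's
# wavefront skeletonization (see `defaults = dict(waves=1, step_size=1)`
# above). A larger `step_size` is faster but coarser.
#
# >>> from fafbseg import flywire
# >>> n = flywire.skeletonize_neuron(720575940614131061,
# ...                                shave_skeleton=True, step_size=2)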
def collapse_nodes(A, B, limit=1, base_neuron=None, mesh=None): """Merge neuron A into neuron(s) B creating a union of both. This implementation uses edge contraction on the neurons' graph to ensure maximum connectivity. Only works if the fragments collectively form a continuous tree (i.e. you must be certain that they partially overlap). Parameters ---------- A : CatmaidNeuron Neuron to be collapsed into neurons B. B : CatmaidNeuronList Neurons to collapse neuron A into. limit : int, optional Max distance [microns] for nearest neighbour search. base_neuron : skeleton ID | CatmaidNeuron, optional Neuron from B to use as template for union. If not provided, the first neuron in the list is used as template! mesh : navis.Volume | navis.MeshNeuron | mesh-like object If provided, will use the mesh to check if nodes are in line of sight to each other before collapsing them. Returns ------- core.CatmaidNeuron Union of all input neurons. new_edges : pandas.DataFrame Subset of the ``.nodes`` table that represent newly added edges. collapsed_nodes : dict Map of collapsed nodes:: NodeA -collapsed-into-> NodeB """ if isinstance(A, navis.NeuronList): if len(A) == 1: A = A[0] else: A = navis.stitch_neurons(A, method="NONE") if not isinstance(A, navis.TreeNeuron): raise TypeError('`A` must be a TreeNeuron, got "{}"'.format(type(A))) if isinstance(B, navis.TreeNeuron): B = navis.NeuronList(B) if not isinstance(B, navis.NeuronList): raise TypeError('`B` must be a NeuronList, got "{}"'.format(type(B))) if isinstance(base_neuron, type(None)): base_neuron = B[0] # This is just check on the off-chance that skeleton IDs are not unique # (e.g. if neurons come from different projects) -> this is relevant because # we identify the master ("base_neuron") via it's skeleton ID skids = [n.id for n in B + A] if len(skids) > len(set(skids)): raise ValueError( 'Duplicate skeleton IDs found. 
Try manually assigning ' 'unique skeleton IDs.') # Convert distance threshold from microns to nanometres limit *= 1000 # Before we start messing around, let's make sure we can keep track of # the origin of each node for n in B + A: n.nodes['origin_skeletons'] = n.id # First make a weak union by simply combining the node tables B.neurons = sorted(B.neurons, key=lambda x: {base_neuron: 0}.get(x, 2)) union_simple = navis.stitch_neurons(B + A, method='NONE', master='FIRST') # Check for duplicate node IDs if any(union_simple.nodes.node_id.duplicated()): raise ValueError('Duplicate node IDs found.') # Find nodes in A to be merged into B tree = scipy.spatial.cKDTree(data=B.nodes[['x', 'y', 'z']].values) # For each node in A get the nearest neighbor in B coords = A.nodes[['x', 'y', 'z']].values nn_dist, nn_ix = tree.query(coords, k=1, distance_upper_bound=limit) # Find nodes that are close enough to collapse collapsed = A.nodes.loc[nn_dist <= limit].node_id.values clps_into = B.nodes.iloc[nn_ix[nn_dist <= limit]].node_id.values # If we have a mesh, check if those collapsed nodes are in sight of each # other if mesh: import ncollpyde coll = ncollpyde.Volume(mesh.vertices, mesh.faces) # Produce start and end coordinates for the to collapse nodes starts = A.nodes.set_index('node_id').loc[collapsed, ['x', 'y', 'z']].values ends = B.nodes.set_index('node_id').loc[clps_into, ['x', 'y', 'z']].values # Check if the line between start and end intersects the mesh intersects, _, _ = coll.intersections(starts, ends) # Indices that show up in `intersects` cross a membrane # -> we need to invert this to get those nodes that don't not_intersects = ~np.isin(np.arange(starts.shape[0]), intersects) # Keep only collapses that don't intersect collapsed = collapsed[not_intersects] clps_into = clps_into[not_intersects] # Generate a map of which node in A is to be collapsed into which node in B clps_map = {n1: n2 for n1, n2 in zip(collapsed, clps_into)} # The fastest way to collapse is to work on the edge list E = nx.to_pandas_edgelist(union_simple.graph) # Keep track of which edges were collapsed -> we will use this as weight # later on to prioritize existing edges over newly generated ones E['is_new'] = 1 source_in_B = E.source.isin(B.nodes.node_id.values) target_in_B = E.target.isin(B.nodes.node_id.values) E.loc[source_in_B | target_in_B, 'is_new'] = 0 # Now map collapsed nodes onto the nodes they collapsed into E['target'] = E.target.map(lambda x: clps_map.get(x, x)) E['source'] = E.source.map(lambda x: clps_map.get(x, x)) # Make sure no self loops after collapsing. This happens if two adjacent # nodes collapse onto the same target node E = E[E.source != E.target] # Remove duplicates. This happens e.g. 
when two adjaceny nodes merge into # two other adjaceny nodes: A->B C->D ----> A/B->C/D # By sorting first, we make sure original edges are kept first E.sort_values('is_new', ascending=True, inplace=True) # Because edges may exist in both directions (A->B and A<-B) we have to # generate a column that's agnostic to directionality using frozensets E['frozen_edge'] = E[['source', 'target']].apply(frozenset, axis=1) E.drop_duplicates(['frozen_edge'], keep='first', inplace=True) # Regenerate graph from these new edges G = nx.Graph() G.add_weighted_edges_from(E[['source', 'target', 'is_new']].values.astype(int)) # At this point there might still be disconnected pieces -> we will create # separate neurons for each tree props = union_simple.nodes.loc[union_simple.nodes.node_id.isin( G.nodes)].set_index('node_id') nx.set_node_attributes(G, props.to_dict(orient='index')) fragments = [] for n in nx.connected_components(G): c = G.subgraph(n) tree = nx.minimum_spanning_tree(c) fragments.append( navis.graph.nx2neuron(tree, name=base_neuron.name, id=base_neuron.id)) fragments = navis.NeuronList(fragments) if len(fragments) > 1: print('Union incomplete - watch out for disconnected fragments!') # Now heal those fragments using a minimum spanning tree union = navis.stitch_neurons(*fragments, method='ALL') else: union = fragments[0] # Reroot to base neuron's root union.reroot(base_neuron.root[0], inplace=True) # Add tags back on if union_simple.has_tags: if not union.has_tags: union.tags = {} union.tags.update(union_simple.tags) # Add connectors back on union.connectors = union_simple.connectors.drop_duplicates( subset='connector_id').copy() union.connectors.loc[:, 'node_id'] = union.connectors.node_id.map( lambda x: clps_map.get(x, x)) # Find the newly added edges (existing edges should not have been modified # - except for changing direction due to reroot) # The basic logic here is that new edges were only added between two # previously separate skeletons, i.e. where the skeleton ID changes between # parent and child node node2skid = union_simple.nodes.set_index( 'node_id').origin_skeletons.to_dict() union.nodes['parent_skeleton'] = union.nodes.parent_id.map(node2skid) new_edges = union.nodes[ union.nodes.origin_skeletons != union.nodes.parent_skeleton] # Remove root edges new_edges = new_edges[~new_edges.parent_id.isnull()] return union, new_edges, clps_map
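
# Usage sketch (hedged): `A` is the neuron to be merged, `B` the overlapping
# CATMAID neurons; `limit` is in microns. Passing a mesh enables the
# line-of-sight check before nodes are collapsed. `A`, `B` and `m` are
# placeholders for data prepared as in merge_into_catmaid above.
#
# >>> union, new_edges, clps_map = collapse_nodes(A, B, limit=1,
# ...                                             base_neuron=B[0], mesh=m)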