# Third-party imports used by the functions below
from typing import Union

import fastremap
import networkx as nx
import numpy as np

import navis
import pymaid
import skeletor as sk

# Note: helpers such as parse_volume, chunks_to_nm, get_L2_centroids,
# get_chunkedgraph_secret, snap_to_id, detect_soma, detect_soma_skeleton,
# check_valid_pymaid_input, pymaid_topological_sort, spine and
# FrameworkClient are provided elsewhere in their respective packages and
# are assumed to be importable in this module.


def swc_from_data(self, ims_data):
    """Convert an Imaris (.ims) filament graph into a navis.TreeNeuron."""
    def get_head_ims(segments):
        # Root candidates are nodes that appear as a segment start but
        # never as a segment end
        heads = []
        for i, segment in enumerate(segments):
            if segment[0] not in segments[:, 1]:
                heads.append(segments[i][0])
        return np.unique(np.array(heads))

    # Map node index -> value of the given vertex column
    gen_node_attr = lambda verts, col: {
        index: vert[col] for index, vert in enumerate(verts)
    }

    segments = np.array(
        ims_data['Scene']['Content']['Filaments0']['Graphs']['Segments'])
    vertices = np.array(
        ims_data['Scene']['Content']['Filaments0']['Graphs']['Vertices'])

    swc = nx.DiGraph()
    swc.add_edges_from(segments)
    swc.add_nodes_from(np.arange(len(vertices)))

    # Re-orient all edges away from the root via a depth-first search
    head = get_head_ims(segments)
    if len(head) != 1:
        raise ValueError(f'Expected exactly one root node, got {len(head)}')
    swc = nx.dfs_tree(swc, head[0])

    # Columns of the Imaris vertex table, in order
    attrs = ['x', 'y', 'z', 'radius', 'label']
    for i, attr in enumerate(attrs):
        attr_dict = gen_node_attr(vertices, i)
        nx.set_node_attributes(swc, attr_dict, attr)

    return navis.TreeNeuron(swc)
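
# Usage sketch: Imaris .ims files are HDF5 containers, so they can be opened
# with h5py. `reader` is a hypothetical instance of whatever class this
# method is bound to.
#
# >>> import h5py
# >>> with h5py.File('filaments.ims', 'r') as ims_data:
# ...     neuron = reader.swc_from_data(ims_data)
# >>> neuron.n_nodes
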
def pymaid_to_navis(x: Union[pymaid.core.CatmaidNeuron,
                             pymaid.core.CatmaidNeuronList]):
    """Convert a pymaid/CatmaidNeuron to a navis.TreeNeuron.

    The neuron's nodes are topologically sorted before conversion.

    Parameters
    ----------
    x :         pymaid.CatmaidNeuron | pymaid.CatmaidNeuronList
                The CATMAID neuron to convert.

    Returns
    -------
    navis.TreeNeuron

    """
    x = check_valid_pymaid_input(x)

    # Getting the topological sort of the pymaid neuron
    x.nodes["rank"] = x.nodes.treenode_id.map(
        pymaid_topological_sort(x, return_object="dict"))
    x.nodes.sort_values(by=["rank"], ascending=True, inplace=True)

    x_graph = x.graph
    navis_neuron = navis.TreeNeuron(x_graph)

    # Populating the xyz columns of the nodes
    for i, j in enumerate(x.nodes[["x", "y", "z"]].values):
        navis_neuron.nodes.loc[i, "x"] = j[0]
        navis_neuron.nodes.loc[i, "y"] = j[1]
        navis_neuron.nodes.loc[i, "z"] = j[2]

    # Adding type column
    navis_neuron.nodes["type"] = x.nodes.type.copy()

    # Adding connectors; relabel connector types as "pre"/"post"
    navis_neuron.connectors = x.connectors.copy()
    navis_neuron.connectors["type"] = [
        "pre" if i == 0 else "post" for i in navis_neuron.connectors.type
    ]

    # Adding soma & name
    navis_neuron.soma = x.soma
    navis_neuron.name = x.neuron_name

    return navis_neuron
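
# Usage sketch (server URL, token and skeleton ID below are placeholders):
#
# >>> import pymaid
# >>> rm = pymaid.CatmaidInstance('https://catmaid.example.com',
# ...                             api_token='TOKEN')
# >>> cn = pymaid.get_neuron(16)
# >>> n = pymaid_to_navis(cn)
# >>> n.name, n.soma
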
def l2_skeleton(root_id, refine=False, drop_missing=True, threads=10,
                progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id :       int | list of ints
                    Root ID(s) of the FlyWire neuron(s) you want to
                    skeletonize.
    refine :        bool
                    If True, will refine skeleton nodes by moving them in
                    the center of their corresponding chunk meshes.

    Only relevant if ``refine=True``:

    drop_missing :  bool
                    If True, will drop nodes that don't have a corresponding
                    chunk mesh. These are typically chunks that are very
                    small and dropping them might actually be beneficial.
    threads :       int
                    How many parallel threads to use for fetching the chunk
                    meshes. Reduce the number if you run into ``HTTPErrors``.
                    Only relevant if `use_flycache=False`.
    progress :      bool
                    Whether to show a progress bar.

    Returns
    -------
    skeleton :      navis.TreeNeuron
                    The extracted skeleton.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    """
    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma
    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        nl = []
        for id in navis.config.tqdm(root_id, desc='Skeletonizing',
                                    disable=not progress, leave=False):
            n = l2_skeleton(id, refine=refine, drop_missing=drop_missing,
                            threads=threads, progress=progress,
                            dataset=dataset)
            nl.append(n)
        return navis.NeuronList(nl)

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N, 2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l))
              for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchical tree by removing cycles and
    # ensuring all edges point towards a root
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidean space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids, token=token,
                                                        progress=progress)

            # Drop missing (i.e. [0, 0, 0]) meshes
            centroids = {k: v for k, v in centroids.items()
                         if v != [0, 0, 0]}
        else:
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads,
                                         progress=progress)

        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
        swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
        swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
    else:
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    return tn
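
# Usage sketch: a coarse L2 skeleton is fast; refining moves nodes to the
# centroids of their chunk meshes at the cost of fetching those meshes.
#
# >>> from fafbseg import flywire
# >>> coarse = flywire.l2_skeleton(720575940614131061)
# >>> refined = flywire.l2_skeleton(720575940614131061, refine=True,
# ...                               drop_missing=True, threads=10)
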
def skeletonize_neuron(x, shave_skeleton=True, remove_soma_hairball=False,
                       assert_id_match=False, dataset='production',
                       progress=True, **kwargs):
    """Skeletonize FlyWire neuron.

    Note that this is optimized to be primarily fast, which comes at the
    cost of (some) quality.

    Parameters
    ----------
    x :                     int | trimesh.TriMesh | list thereof
                            ID(s) or trimesh of the FlyWire neuron(s) you
                            want to skeletonize.
    shave_skeleton :        bool
                            If True, we will "shave" the skeleton by removing
                            all single-node terminal twigs. This should get
                            rid of hairs on the backbone that can occur if
                            the neurites are very big.
    remove_soma_hairball :  bool
                            If True, we will try to drop the hairball that is
                            typically created inside the soma. Note that
                            while this should work just fine for 99% of
                            neurons, it's not very smart and there is always
                            a chance that we remove stuff that should not
                            have been removed. Also only works if the neuron
                            has a recognizable soma.
    assert_id_match :       bool
                            If True, will check if skeleton nodes map to the
                            correct segment ID and if not will move them back
                            into the segment. This is potentially very slow!
    dataset :               str | CloudVolume
                            Against which FlyWire dataset to query::

                              - "production" (current production dataset, fly_v31)
                              - "sandbox" (i.e. fly_v26)

    progress :              bool
                            Whether to show a progress bar or not.

    Returns
    -------
    skeleton :              navis.TreeNeuron
                            The extracted skeleton.

    See Also
    --------
    :func:`fafbseg.flywire.skeletonize_neuron_parallel`
                            Use this if you want to skeletonize many neurons
                            in parallel.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.skeletonize_neuron(720575940614131061)

    """
    if int(sk.__version__.split('.')[0]) < 1:
        raise ImportError('Please update skeletor to version >= 1.0.0: '
                          'pip3 install skeletor -U')

    if navis.utils.is_iterable(x):
        return navis.NeuronList([skeletonize_neuron(n,
                                                    progress=False,
                                                    shave_skeleton=shave_skeleton,
                                                    remove_soma_hairball=remove_soma_hairball,
                                                    assert_id_match=assert_id_match,
                                                    dataset=dataset,
                                                    **kwargs)
                                 for n in navis.config.tqdm(x,
                                                            desc='Skeletonizing',
                                                            disable=not progress,
                                                            leave=False)])

    if not navis.utils.is_mesh(x):
        vol = parse_volume(dataset)

        # Make sure this is a valid integer
        id = int(x)

        # Download the mesh
        mesh = vol.mesh.get(id, deduplicate_chunk_boundaries=False,
                            remove_duplicate_vertices=True)[id]
    else:
        mesh = x
        id = getattr(mesh, 'segid', 0)

    mesh = sk.utilities.make_trimesh(mesh, validate=False)

    # Fix things before we skeletonize
    # This also drops fluff
    mesh = sk.pre.fix_mesh(mesh, inplace=True, remove_disconnected=100)

    # Skeletonize
    defaults = dict(waves=1, step_size=1)
    defaults.update(kwargs)
    s = sk.skeletonize.by_wavefront(mesh, progress=progress, **defaults)

    # Skeletor indexes node IDs at zero but to avoid potential issues we
    # want node IDs to start at 1
    s.swc['node_id'] += 1
    s.swc.loc[s.swc.parent_id >= 0, 'parent_id'] += 1

    # We will also round the radius and make it an integer to save some
    # memory. We could do the same with x/y/z coordinates but that could
    # potentially move nodes outside the mesh
    s.swc['radius'] = s.swc.radius.round().astype(int)

    # Turn into a neuron
    tn = navis.TreeNeuron(s.swc, units='1 nm', id=id, soma=None)

    if shave_skeleton:
        # Get branch points
        bp = tn.nodes.loc[tn.nodes.type == 'branch', 'node_id'].values

        # Get single-node twigs
        is_end = tn.nodes.type == 'end'
        parent_is_bp = tn.nodes.parent_id.isin(bp)
        twigs = tn.nodes.loc[is_end & parent_is_bp, 'node_id'].values

        # Drop terminal twigs
        tn._nodes = tn.nodes.loc[~tn.nodes.node_id.isin(twigs)].copy()
        tn._clear_temp_attr()

    # See if we can find a soma
    soma = detect_soma_skeleton(tn, min_rad=800, N=3)
    if soma:
        tn.soma = soma

        # Reroot to soma
        tn.reroot(tn.soma, inplace=True)

        if remove_soma_hairball:
            soma = tn.nodes.set_index('node_id').loc[soma]
            soma_loc = soma[['x', 'y', 'z']].values

            # Find all nodes within 2x the soma radius
            tree = navis.neuron2KDTree(tn)
            ix = tree.query_ball_point(soma_loc, max(4000, soma.radius * 2))

            # Translate indices into node IDs
            ids = tn.nodes.iloc[ix].node_id.values

            # Find segments that contain these nodes
            segs = [s for s in tn.segments if any(np.isin(ids, s))]

            # Sort segs by length
            segs = sorted(segs, key=lambda x: len(x))

            # Keep only the longest segment in that initial list
            to_drop = np.array([n for s in segs[:-1] for n in s])
            to_drop = to_drop[~np.isin(to_drop, segs[-1] + [soma.name])]

            navis.remove_nodes(tn, to_drop, inplace=True)

    if assert_id_match:
        if id == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(tn.nodes[['x', 'y', 'z']].values,
                              id=id,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        tn.nodes[['x', 'y', 'z']] = new_locs

    return tn
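
# Usage sketch: skeletonize a single neuron from its root ID, or skip the
# download by passing a mesh you already have. Extra keyword arguments are
# forwarded to skeletor's wavefront skeletonization (e.g. `waves`,
# `step_size`).
#
# >>> from fafbseg import flywire
# >>> n = flywire.skeletonize_neuron(720575940614131061)
# >>> n2 = flywire.skeletonize_neuron(720575940614131061, waves=2)
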
# NB: this is an earlier, contraction-based variant of `skeletonize_neuron`
# (it predates skeletor 1.0 and uses a simplify -> contract -> skeletonize
# pipeline instead of wavefront skeletonization).
def skeletonize_neuron(x, drop_soma_hairball=True, contraction_kws={},
                       skeletonization_kws={}, radius_kws={},
                       assert_id_match=False, dataset='production'):
    """Skeletonize FlyWire neuron.

    Parameters
    ----------
    x :                     int | trimesh.TriMesh
                            ID or trimesh of the FlyWire neuron you want to
                            skeletonize.
    drop_soma_hairball :    bool
                            If True, we will try to drop the hairball that is
                            typically created inside the soma.
    contraction_kws :       dict
                            Optional parameters for the contraction phase.
                            See ``skeletor.contract``.
    skeletonization_kws :   dict
                            Optional parameters for the skeletonization
                            phase. See ``skeletor.skeletonize``.
    radius_kws :            dict
                            Optional parameters for the radius extraction
                            phase. See ``skeletor.radius``.
    assert_id_match :       bool
                            If True, will check if skeleton nodes map to the
                            correct segment ID and if not will move them back
                            into the segment. This is potentially very slow!
    dataset :               str | CloudVolume
                            Against which FlyWire dataset to query::

                              - "production" (current production dataset, fly_v31)
                              - "sandbox" (i.e. fly_v26)

    Returns
    -------
    skeleton, simplified_mesh, contracted_mesh
                            The extracted skeleton, simplified and contracted
                            mesh, respectively.

    Examples
    --------
    >>> tn, simp, cntr = flywire.skeletonize_neuron(720575940614131061)

    """
    if not sk:
        raise ImportError('Must install skeletor: pip3 install skeletor')

    if not navis.utils.is_mesh(x):
        vol = parse_volume(dataset)

        # Make sure this is a valid integer
        id = int(x)

        # Download the mesh
        mesh = vol.mesh.get(id, deduplicate_chunk_boundaries=False)[id]
    else:
        mesh = x
        id = getattr(mesh, 'segid', 0)

    # Simplify
    simp = sk.simplify(mesh, ratio=.2)

    # Validate before we detect the soma verts
    simp = sk.utilities.fix_mesh(simp, inplace=True)

    # Try detecting the soma
    if drop_soma_hairball:
        soma_verts = detect_soma(simp)

    # Contract
    defaults = dict(SL=40, WH0=2, epsilon=0.1, precision=1e-7,
                    validate=False)
    defaults.update(contraction_kws)
    cntr = sk.contract(simp, **defaults)

    # Generate skeleton
    defaults = dict(method='vertex_clusters', sampling_dist=200,
                    vertex_map=True, validate=False)
    defaults.update(skeletonization_kws)
    swc = sk.skeletonize(cntr, **defaults)

    # Clean up
    cleaned = sk.clean(swc, mesh=mesh, validate=False)

    # Extract radii
    defaults = dict(validate=False)
    defaults.update(radius_kws)
    cleaned['radius'] = sk.radii(cleaned, mesh=mesh, **defaults)

    # Convert to neuron
    n = navis.TreeNeuron(cleaned, id=id, units='nm', soma=None)

    # Drop any nodes that are soma vertices
    if drop_soma_hairball and soma_verts.shape[0] > 0:
        keep = n.nodes.loc[~n.nodes.vertex_id.isin(soma_verts),
                           'node_id'].values
        n = navis.subset_neuron(n, keep)

    if assert_id_match:
        if id == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(n.nodes[['x', 'y', 'z']].values,
                              id=id,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        n.nodes[['x', 'y', 'z']] = new_locs

    return (n, simp, cntr)
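
# Usage sketch: the per-phase keyword dicts are forwarded to skeletor --
# e.g. lowering `sampling_dist` below its default of 200 yields a denser
# skeleton at the cost of speed.
#
# >>> n, simp, cntr = skeletonize_neuron(720575940614131061,
# ...                                    skeletonization_kws={'sampling_dist': 100})
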