def matching_inputs_to_compartments(
        neuron_id: int,
        roi: navis.Volume,
        Ra: float,
        Rm: float,
        Cm: float):
    """Map each electrotonic compartment of a neuron to its full-skeleton nodes.

    The healed skeleton of ``neuron_id`` is fetched from neuPrint. Its
    postsynapses are matched to skeleton nodes and subset to those on nodes
    inside ``roi``, and the presynaptic partners of the in-ROI synapses are
    fetched. The neuron is then compartmentalised with
    ``find_compartments_in_roi``. For every cluster label,
    ``cluster_to_all_nodes`` collects the nodes of the full skeleton that
    belong to that compartment.

    Parameters
    ----------
    neuron_id : int
        neuPrint bodyId of the neuron of interest.
    roi : navis.Volume
        Region of interest used to subset skeleton nodes and synapses.
    Ra, Rm, Cm : float
        Electrical parameters forwarded unchanged to
        ``find_compartments_in_roi``.

    Returns
    -------
    dict
        Cluster label -> the value returned by ``cluster_to_all_nodes`` for
        that cluster (presumably the node ids of the compartment; confirm
        against that helper).

    Notes
    -----
    The in-ROI synapse table (with per-partner synapse counts and instance
    names) is built here but not returned.
    """
    skeleton = nvneu.fetch_skeletons(neuron_id, heal=True)[0]
    coords = skeleton.nodes[['x', 'y', 'z']]
    in_roi_mask = navis.in_volume(coords, roi, mode='IN')

    # Postsynapses of the neuron, matched to skeleton nodes and restricted
    # to nodes that lie inside the ROI.
    post_synapses = nvneu.fetch_synapse_connections(target_criteria=neuron_id)
    synapse_nodes = match_connectors_to_nodes(post_synapses, skeleton,
                                              synapse_type='post')
    roi_node_ids = skeleton.nodes[in_roi_mask].node_id.tolist()
    roi_synapses = synapse_nodes[synapse_nodes.node.isin(roi_node_ids)].copy()

    # Per-partner synapse counts and instance names.
    partner_counts = dict(Counter(roi_synapses.bodyId_pre.tolist()).most_common())
    partners, _ = nvneu.fetch_neurons(roi_synapses.bodyId_pre.unique())
    partners['n_syn'] = [partner_counts[bid] for bid in partners.bodyId.tolist()]

    instance_of = dict(zip(partners.bodyId.tolist(), partners.instance.tolist()))
    roi_synapses['instance'] = [instance_of[bid] for bid in roi_synapses.bodyId_pre]

    # Compartmentalise, then expand each compartment to the full skeleton.
    compartments, roi_mask = find_compartments_in_roi(
        skeleton, roi=roi, min_samples=6, Cm=Cm, Ra=Ra, Rm=Rm)

    return {
        label: cluster_to_all_nodes(
            skeleton,
            start_end_node_pairs=permute_start_end(compartments, roi_mask,
                                                   cluster=label))
        for label in compartments.nodes.node_cluster.unique()
    }
def find_compartments_in_roi(neuron, Rm, Ra, Cm, roi, min_samples):
    """Cluster a neuron's nodes into electrotonic compartments and flag ROI nodes.

    A copy of ``neuron`` is passed through ``prepare_neuron`` (with
    ``change_units=True, factor=1e3``). Its matrix from ``calculate_M_mat``
    is densified and inverted, and the inverse is clustered with a 3-D PHATE
    embedding via ``find_clusters`` (``eps=1e-02``). Row ``i`` of the result
    is assumed to correspond to index label ``i`` of the prepared node table.

    Parameters
    ----------
    neuron : navis.TreeNeuron
        Neuron to compartmentalise. It is copied, not modified.
    Rm, Ra, Cm : float
        Electrical parameters forwarded to ``calculate_M_mat``.
    roi : navis.Volume
        Region of interest used for the in-volume test.
    min_samples : int
        Forwarded to ``find_clusters``.

    Returns
    -------
    tuple
        ``(prepared_neuron, in_roi)``. ``prepared_neuron.nodes`` gains a
        ``node_cluster`` column, and ``in_roi`` is the result of
        ``navis.in_volume`` for the prepared node coordinates scaled by 1e3.
    """
    prepared = prepare_neuron(neuron.copy(), change_units=True, factor=1e3)

    m_mat, _memcap = calculate_M_mat(prepared, Rm, Ra, Cm, solve=False)
    dense = sparse.csr_matrix.todense(m_mat)
    solved = np.linalg.inv(dense)

    embedder = phate.PHATE(n_components=3, n_jobs=-2, verbose=False)
    _reduced, labels, _n_clusters, _n_noise = find_clusters(
        solved, embedder, eps=1e-02, min_samples=min_samples)

    label_of_row = {row: label for row, label in enumerate(labels)}
    prepared.nodes['node_cluster'] = prepared.nodes.index.map(label_of_row)

    # Coordinates are multiplied by 1e3 before the ROI test, presumably to
    # undo the unit change made by prepare_neuron(factor=1e3) -- confirm.
    scaled_coords = prepared.nodes[['x', 'y', 'z']] * 1e3
    in_roi = navis.in_volume(scaled_coords, roi, mode='IN', inplace=False)
    return prepared, in_roi
def interactive_dendrogram(
    z: Union[navis.TreeNeuron, navis.NeuronList],
    heal_neuron: bool = True,
    plot_nodes: bool = True,
    plot_connectors: bool = True,
    highlight_connectors: Optional = None,
    in_volume: Optional = None,
    prog: str = "dot",
    inscreen: bool = True,
    filename: Optional = None,
):
    """
    Takes a navis neuron and returns an interactive 2D dendrogram.

    In this dendrogram, nodes or connector locations can be highlighted.

    Parameters
    ----------
    z: navis.TreeNeuron | navis.NeuronList
        The neuron to plot; validated with ``check_valid_neuron_input``.
    heal_neuron: bool
        Whether to heal the neuron with ``navis.heal_fragmented_neuron``
        before plotting. N.B. Navis neurons should be healed on import, i.e.
        navis.fetch_skeletons(bodyid, heal = True); see
        navis.fetch_skeletons and navis.heal_fragmented_neuron for details.
    plot_nodes: bool
        Whether treenodes should be plotted.
    plot_connectors: bool
        Whether presynapses (green) and postsynapses (blue) should be
        plotted at the treenodes they are attached to.
    highlight_connectors: dict, optional
        A dictionary mapping connector IDs (values of
        ``z.connectors.connector_id``) to the colour each connector should be
        highlighted in, as an (r, g, b) sequence. This allows several colours
        to be plotted. N.B. Plotly colours are in the range 0 - 255, whereas
        matplotlib colours are between 0 - 1. For the interactive dendrogram
        colours need to be in the plotly range, whereas in the static
        dendrogram they need to be in the matplotlib range.
    in_volume: navis.Volume, optional
        A navis.Volume corresponding to an ROI in the brain. Nodes of the
        neuron that lie in that volume are highlighted. ``z.nodes`` is not
        modified.
    prog: str
        The graphviz layout program passed to ``create_graph_structure``.
        Must be "dot" or "neato". The dot program provides a hierarchical
        layout and is the fastest. The neato program creates edges between
        nodes proportional to their real length; it is by far the slowest
        and can take ~2hrs for a single neuron!
    inscreen: bool
        Whether to plot the graph inscreen (jupyter notebooks) or as a
        separate HTML file that can be saved, opened in the browser and
        reopened at any time.
    filename: str
        The filename of the interactive dendrogram HTML file. Only used
        when inscreen = False.

    Returns
    -------
    Whatever plotly's ``iplot(fig)`` (inscreen) or
    ``plot(fig, filename=filename)`` (HTML file) returns.

    Raises
    ------
    ValueError
        If ``prog`` is not "dot" or "neato".
    TypeError
        If ``highlight_connectors`` is neither None nor a dict.
    IndexError
        If a connector ID in ``highlight_connectors`` is not present in
        ``z.connectors``.

    Examples
    --------
    from neuroboom.utils import create_graph_structure
    from neuroboom.dendrogram import interactive_dendrogram
    import navis.interfaces.neuprint as nvneu
    from matplotlib import pyplot as plt

    test_neuron = nvneu.fetch_skeletons(722817260)

    interactive_dendrogram(test_neuron, prog = 'dot', inscreen = True)
    """
    from collections import Counter

    import numpy as np

    z = check_valid_neuron_input(z)

    if heal_neuron:
        z = navis.heal_fragmented_neuron(z)

    valid_progs = ["neato", "dot"]
    if prog not in valid_progs:
        raise ValueError("Unknown program parameter!")

    # Validate up front: a non-dict previously left HC_trace undefined and
    # crashed with a NameError when the figure was assembled.
    if highlight_connectors is not None and not isinstance(highlight_connectors, dict):
        raise TypeError(
            "highlight_connectors must be a dict or None, got {}".format(
                type(highlight_connectors)))

    # save start time
    start = time.time()

    # Necessary for neato layouts for preservation of segment lengths
    if "parent_dist" not in z.nodes:
        z = calc_cable(z, return_skdata=True)

    # Generation of networkx diagram
    g, pos = create_graph_structure(z, returned_object="graph_and_positions", prog=prog)

    print("Now creating plotly graph...")

    def _points_at(counts, label):
        """Return x, y, hover text for graph nodes, once per occurrence in ``counts``."""
        px, py, info = [], [], []
        for node in g.nodes():
            # Counter returns 0 for nodes without an entry.
            for _ in range(counts[node]):
                x_, y_ = pos[node]
                px.append(x_)
                py.append(y_)
                info.append(label.format(node))
        return px, py, info

    # NODES
    if plot_nodes:
        node_ids = list(g.nodes())
        node_trace = go.Scatter(
            x=[pos[n][0] for n in node_ids],
            y=[pos[n][1] for n in node_ids],
            mode="markers",
            text=["Node ID: {}".format(n) for n in node_ids],
            hoverinfo="text",
            marker=go.scatter.Marker(showscale=False),
        )
    else:
        node_trace = go.Scatter()

    # EDGES: one line segment per edge, separated by None breaks.
    xe = []
    ye = []
    for u, v in g.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        xe += [x0, x1, None]
        ye += [y0, y1, None]

    edge_trace = go.Scatter(
        x=xe,
        y=ye,
        line=go.scatter.Line(width=1.0, color="#000"),
        hoverinfo="none",
        mode="lines",
    )

    # SOMA: z.soma may be None, a single node id or several ids.
    soma = z.soma
    soma_ids = set() if soma is None else set(np.atleast_1d(soma).tolist())
    soma_nodes = [n for n in g.nodes() if n in soma_ids]

    soma_trace = go.Scatter(
        x=[pos[n][0] for n in soma_nodes],
        y=[pos[n][1] for n in soma_nodes],
        mode="markers",
        hoverinfo="text",
        marker=dict(size=20, color="rgb(0,0,0)"),
        text="Soma, node:{}".format(z.soma),
    )

    # CONNECTORS: marked at the treenode they attach to; a treenode carrying
    # k connectors of a type is plotted k times, as before.
    if plot_connectors:
        pre_counts = Counter(z.connectors[z.connectors.type == "pre"].node_id.values)
        x_pre, y_pre, pre_info = _points_at(pre_counts, "Presynapse, treenode_id: {}")
        presynapse_connector_trace = go.Scatter(
            x=x_pre,
            y=y_pre,
            text=pre_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=10, color="rgb(0,255,0)"),
        )

        post_counts = Counter(z.connectors[z.connectors.type == "post"].node_id.values)
        x_post, y_post, post_info = _points_at(post_counts, "Postsynapse, treenode_id: {}")
        postsynapse_connector_trace = go.Scatter(
            x=x_post,
            y=y_post,
            text=post_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=10, color="rgb(0,0,255)"),
        )
    else:
        presynapse_connector_trace = go.Scatter()
        postsynapse_connector_trace = go.Scatter()

    # HIGHLIGHTED CONNECTORS
    if highlight_connectors is None:
        HC_trace = go.Scatter()
    else:
        # treenode -> [(connector_id, colour), ...]. Each colour stays paired
        # with its own connector, so the marker colours line up with the
        # points (they used to be in dict-key order while points were in
        # graph-node order).
        hc_by_node = {}
        for connector_id, rgb in highlight_connectors.items():
            node = z.connectors[z.connectors.connector_id == connector_id].node_id.values[0]
            colour = "rgb({}, {}, {})".format(int(rgb[0]), int(rgb[1]), int(rgb[2]))
            hc_by_node.setdefault(node, []).append((connector_id, colour))

        HC_x, HC_y, HC_info, HC_color = [], [], [], []
        for node in g.nodes():
            for connector_id, colour in hc_by_node.get(node, ()):
                x_, y_ = pos[node]
                HC_x.append(x_)
                HC_y.append(y_)
                HC_info.append(
                    "Connector of Interest, connector_id: {}, treenode_id: {}".format(
                        connector_id, node))
                HC_color.append(colour)

        HC_trace = go.Scatter(
            x=HC_x,
            y=HC_y,
            text=HC_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=12.5, color=HC_color),
        )

    # Highlight the nodes that are in a particular volume
    if in_volume is None:
        in_volume_trace = go.Scatter()
    else:
        inside = np.asarray(navis.in_volume(z.nodes, volume=in_volume, mode="IN"), dtype=bool)
        # Build the id set once instead of re-filtering the table per node.
        in_vol_ids = set(z.nodes.node_id.values[inside])
        voi_nodes = [n for n in g.nodes() if n in in_vol_ids]

        in_volume_trace = go.Scatter(
            x=[pos[n][0] for n in voi_nodes],
            y=[pos[n][1] for n in voi_nodes],
            text=["Treenode {} is in {} volume".format(n, in_volume) for n in voi_nodes],
            mode="markers",
            hoverinfo="text",
            marker=dict(size=5, color="rgb(35,119,0)"),
        )

    print("Creating Plotly Graph")

    fig = go.Figure(
        data=[
            edge_trace,
            node_trace,
            soma_trace,
            presynapse_connector_trace,
            postsynapse_connector_trace,
            HC_trace,
            in_volume_trace,
        ],
        layout=go.Layout(
            title="Plotly graph of {} with {} layout".format(z.name, prog),
            titlefont=dict(size=16),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=50, r=5, t=40),
            annotations=[
                dict(showarrow=False, xref="paper", yref="paper", x=0.005, y=-0.002)
            ],
            xaxis=go.layout.XAxis(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=go.layout.YAxis(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    print("Finished in {} seconds".format(time.time() - start))
    if inscreen is True:
        return iplot(fig)
    return plot(fig, filename=filename)
def _fragments_overlap_by_segmentation(x, tn_table, node_counts, min_node_overlap):
    """Return ``(skeleton_ids, scores)`` of candidates sharing enough segments with ``x``."""
    # Segment IDs of the input neuron, kept in a local frame so that the
    # caller's ``x.nodes`` is not modified.
    x_nodes = x.nodes[['node_id']].copy()
    x_nodes['seg_id'] = locs_to_segments(x.nodes[['x', 'y', 'z']].values,
                                         coordinates='nm',
                                         mip=0)
    x_seg_counts = x_nodes.groupby('seg_id').node_id.count().reset_index(drop=False)
    x_seg_counts.columns = ['seg_id', 'counts']
    # Segment ID 0 is excluded
    x_seg_counts = x_seg_counts[x_seg_counts.seg_id != 0]

    # Drop candidate nodes further than 2.5 microns from the input neuron
    tree = navis.neuron2KDTree(x)
    dist, _ = tree.query(tn_table[['x', 'y', 'z']].values,
                         distance_upper_bound=2500)
    tn_table = tn_table.loc[dist <= 2500]

    # Remove neurons that can't possibly have enough overlap
    near_counts = tn_table.groupby('skeleton_id').node_id.count().to_dict()
    to_keep = [
        k for k, v in near_counts.items()
        if v >= min(min_node_overlap, node_counts[k])
    ]
    # .copy() so the seg_id assignment below does not write to a slice
    tn_table = tn_table[tn_table.skeleton_id.isin(to_keep)].copy()
    if tn_table.empty:
        return [], []

    tn_table['seg_id'] = locs_to_segments(tn_table[['x', 'y', 'z']].values,
                                          coordinates='nm',
                                          mip=0)

    # Node counts per (neuron, segment)
    seg_counts = tn_table.groupby(['skeleton_id', 'seg_id'
                                   ]).node_id.count().reset_index(drop=False)
    seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']
    seg_counts = seg_counts[seg_counts.seg_id != 0]
    # Keep only segments that the input neuron also occupies
    seg_counts = seg_counts[np.isin(seg_counts.seg_id.values,
                                    x_seg_counts.seg_id.values)]

    ol, scores = [], []
    for s in seg_counts.skeleton_id.unique():
        this_counts = seg_counts[seg_counts.skeleton_id == s]
        # If the neuron is smaller than min_node_overlap, lower the threshold
        this_min_ol = min(min_node_overlap, node_counts[s])
        # Overlapping node counts on the candidate and on the input neuron
        c_count = this_counts.counts.sum()
        x_count = x_seg_counts[x_seg_counts.seg_id.isin(
            this_counts.seg_id.values)].counts.sum()
        if c_count >= this_min_ol and x_count >= this_min_ol:
            # The score is the smaller of the two overlaps
            scores.append(min(c_count, x_count))
            ol.append(s)
    return ol, scores


def _fragments_overlap_by_mesh(tn_table, mesh, node_counts, min_node_overlap):
    """Return ``(skeleton_ids, scores)`` of candidates with enough nodes inside ``mesh``."""
    in_mesh = navis.in_volume(tn_table[['x', 'y', 'z']].values,
                              mesh).astype(bool)
    # Number of in-mesh nodes per candidate
    ol_counts = tn_table.assign(in_mesh=in_mesh).groupby(
        'skeleton_id', as_index=False).in_mesh.sum()
    ol_counts.columns = ['skeleton_id', 'counts']
    # Drop fragments with no node inside the mesh
    ol_counts = ol_counts[ol_counts.counts > 0]
    # Threshold per candidate: min(min_node_overlap, total node count)
    thr_array = np.minimum(min_node_overlap,
                           ol_counts.skeleton_id.map(node_counts).values)
    ol_counts = ol_counts[ol_counts.counts >= thr_array]
    return ol_counts.skeleton_id.tolist(), ol_counts.counts.tolist()


def find_fragments(x, remote_instance, min_node_overlap=3, min_nodes=1, mesh=None):
    """Find manual tracings overlapping with given autoseg neuron.

    This function is a generalization of ``find_autoseg_fragments`` and is
    designed to not require overlapping neurons to have references (e.g.
    in their name) to segmentation IDs:

        1. Resample ``x`` to 0.5 microns and query a box of +/- 250 nm
           around each node for potentially overlapping fragments.
        2. Either collect segmentation IDs for the input neuron and all
           potentially overlapping fragments (candidate nodes further than
           2.5 microns from ``x`` are ignored) or, if provided, use the mesh
           to check whether the candidates' nodes are inside that mesh.
        3. Return fragments that overlap with at least ``min_node_overlap``
           nodes with the input neuron.

    Parameters
    ----------
    x :                 pymaid.CatmaidNeuron | navis.TreeNeuron
                        Neuron to collect fragments for. Not modified.
    remote_instance :   pymaid.CatmaidInstance
                        Catmaid instance in which to search for fragments.
    min_node_overlap :  int, optional
                        Minimal overlap between `x` and a fragment in nodes.
                        If the fragment has fewer total nodes than
                        `min_node_overlap`, the threshold is lowered to:
                        ``min(min_node_overlap, fragment.n_nodes)``
    min_nodes :         int, optional
                        Minimum node count for returned neurons.
    mesh :              navis.Volume | trimesh.Trimesh | navis.MeshNeuron, optional
                        Mesh representation of ``x``. If provided, the mesh
                        is used instead of querying the segmentation to
                        determine whether fragments overlap. This is
                        generally faster.

    Returns
    -------
    pymaid.CatmaidNeuronList
                        CatmaidNeurons of the overlapping fragments. Overlap
                        scores are attached to each neuron as the
                        ``.overlap_score`` attribute.

    Raises
    ------
    TypeError
                        If ``x`` is not a navis.TreeNeuron or ``mesh`` has an
                        unsupported type.

    Examples
    --------
    Setup:

    >>> import pymaid
    >>> import fafbseg

    >>> manual = pymaid.CatmaidInstance('MANUAL_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> # Set a source for segmentation data
    >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation")

    Find manually traced fragments overlapping with an autoseg neuron:

    >>> x = pymaid.get_neuron(204064470, remote_instance=auto)
    >>> frags_of_x = fafbseg.find_fragments(x, remote_instance=manual)

    See Also
    --------
    fafbseg.find_autoseg_fragments
                        Use this function if you are looking for autoseg
                        fragments overlapping with a given neuron. Because we
                        can use the reference to segment IDs (via names &
                        annotations), this function is much faster than
                        ``find_fragments``.

    """
    if not isinstance(x, navis.TreeNeuron):
        raise TypeError(
            'Expected navis.TreeNeuron or pymaid.CatmaidNeuron, got "{}"'.
            format(type(x)))

    meshtypes = (type(None), navis.MeshNeuron, navis.Volume, tm.Trimesh)
    if not isinstance(mesh, meshtypes):
        raise TypeError(f'Unexpected mesh of type "{type(mesh)}"')

    # Resample the autoseg neuron to 0.5 microns
    x_rs = x.resample(500, inplace=False)

    # Half-width (nm) of the query box around each node
    r = 250

    # Bounding boxes around each node: rows are x, y, z; columns lower/upper
    bboxes = [
        np.vstack([co - r, co + r]).T
        for co in x_rs.nodes[['x', 'y', 'z']].values
    ]

    # Query each bounding box
    urls = [
        remote_instance._get_skeletons_in_bbox(minx=min(b[0]),
                                               maxx=max(b[0]),
                                               miny=min(b[1]),
                                               maxy=max(b[1]),
                                               minz=min(b[2]),
                                               maxz=max(b[2]),
                                               min_nodes=min_nodes)
        for b in bboxes
    ]
    resp = remote_instance.fetch(urls, desc='Searching for overlapping neurons')
    skids = {s for batch in resp for s in batch}

    # Return empty NeuronList if no skids found
    if not skids:
        return pymaid.CatmaidNeuronList([])

    # Get nodes for these candidates
    tn_table = pymaid.get_node_table(skids,
                                     include_details=False,
                                     convert_ts=False,
                                     remote_instance=remote_instance)

    # Total node count per candidate
    node_counts = tn_table.groupby('skeleton_id').node_id.count().to_dict()

    # Explicit None check: mesh objects should not be tested for truthiness
    if mesh is None:
        ol, scores = _fragments_overlap_by_segmentation(
            x, tn_table, node_counts, min_node_overlap)
    else:
        ol, scores = _fragments_overlap_by_mesh(
            tn_table, mesh, node_counts, min_node_overlap)

    if not ol:
        return pymaid.CatmaidNeuronList([])

    # Attach scores by skeleton ID: the fetched neurons are not guaranteed
    # to come back in the requested order, so zipping could mismatch them.
    score_by_skid = {str(s): sc for s, sc in zip(ol, scores)}

    neurons = pymaid.get_neurons(ol, remote_instance=remote_instance)
    # Make sure it's a neuronlist
    if not isinstance(neurons, pymaid.CatmaidNeuronList):
        neurons = pymaid.CatmaidNeuronList(neurons)

    for n in neurons:
        n.overlap_score = score_by_skid.get(str(n.skeleton_id))

    return neurons
def compartmentalise_neuron(
        neuron_id: int,
        Rm: float,
        Ra: float,
        Cm: float,
        roi: Optional[navis.Volume],
        return_electromodel: bool):
    """Compartmentalise a neuron electrotonically and assign its inputs to compartments.

    Steps:

    1. Fetch the healed skeleton of ``neuron_id`` and keep an untouched copy
       (``original_neuron``).
    2. Build the electrotonic model with ``prepare_neuron`` and
       ``calculate_M_mat``, invert the matrix, and cluster it with a 3-D
       PHATE embedding via ``find_clusters`` (``eps=1e-02, min_samples=6``).
    3. Label the nodes of ``ds_neuron`` with their cluster. Extend each
       cluster to every node of ``original_neuron`` with
       ``cluster_to_all_nodes``, and fill remaining synapse nodes with
       ``find_compartments_of_missing_nodes``.
    4. Fetch the neuron's postsynapses and match them to nodes. Optionally
       keep only those on nodes inside ``roi``. Annotate each synapse with
       the presynaptic partner's instance and the compartment of its node.

    Parameters
    ----------
    neuron_id : int
        neuPrint bodyId of the neuron.
    Rm, Ra, Cm : float
        Electrical parameters forwarded to ``calculate_M_mat``.
    roi : navis.Volume or None
        If given, only synapses on nodes inside this volume are kept;
        if None, all postsynapses are kept.
    return_electromodel : bool
        If True, also return the model matrix and membrane capacitance
        returned by ``calculate_M_mat``.

    Returns
    -------
    tuple
        ``(original_neuron, ds_neuron, roi_syn_con)``, or
        ``(original_neuron, ds_neuron, roi_syn_con, test_m, test_memcap)``
        when ``return_electromodel`` is True. ``original_neuron`` is the
        result of ``node_to_compartment_full_neuron``, and ``roi_syn_con``
        gains ``instance`` and ``compartment`` columns.
    """
    # Fetching the neuron; keep an untouched copy of the full skeleton.
    ds_neuron = nvneu.fetch_skeletons(neuron_id, heal=True)[0]
    original_neuron = ds_neuron.copy()

    # Electrotonic model.
    # NOTE(review): ds_neuron is later treated as the downsampled neuron, so
    # prepare_neuron presumably modifies it in place -- confirm.
    DS_NEURON = prepare_neuron(ds_neuron, change_units=True, factor=1e3)
    test_m, test_memcap = calculate_M_mat(DS_NEURON, Rm, Ra, Cm, solve=False)
    test_m_solved = np.linalg.inv(sparse.csr_matrix.todense(test_m))

    # Running PHATE and clustering the embedding.
    phate_operator = phate.PHATE(n_components=3, n_jobs=-2, verbose=False)
    _mat_red, labels, _n_clusters, _n_noise = find_clusters(
        test_m_solved, phate_operator, eps=1e-02, min_samples=6)

    # Row i of the model -> cluster label; assumes ds_neuron.nodes has a
    # 0..n-1 index aligned with the model rows -- TODO confirm.
    index_to_label = dict(enumerate(labels))
    ds_neuron.nodes['node_cluster'] = ds_neuron.nodes.index.map(index_to_label)

    # Find the cluster of the nodes that were removed when downsampling:
    # each cluster is expanded to every original node lying on paths
    # between pairs of its nodes. Later clusters overwrite earlier ones.
    whole_neuron_node_to_cluster_dict = {}
    for cluster in ds_neuron.nodes.node_cluster.unique():
        cluster_nodes = ds_neuron.nodes[
            ds_neuron.nodes.node_cluster == cluster].node_id.tolist()
        start_end = list(itertools.permutations(cluster_nodes, 2))
        nodes_of_cluster = cluster_to_all_nodes(original_neuron, start_end)
        for node in nodes_of_cluster:
            whole_neuron_node_to_cluster_dict[node] = cluster

    # Fetching postsynapses and matching them to skeleton nodes
    ds_neuron_postsynapses = nvneu.fetch_synapse_connections(target_criteria=neuron_id)
    ds_neuron_synapse_to_node = match_connectors_to_nodes(
        ds_neuron_postsynapses, original_neuron, synapse_type='post')

    # Keep only synapses on nodes inside the ROI, if one was given
    if roi is not None:
        skeleton_in_roi = navis.in_volume(
            original_neuron.nodes[['x', 'y', 'z']].values, roi, inplace=False)
        roi_node_ids = original_neuron.nodes[skeleton_in_roi].node_id.tolist()
        ds_isin = ds_neuron_synapse_to_node.node.isin(roi_node_ids)
        roi_syn_con = ds_neuron_synapse_to_node[ds_isin].copy()
    else:
        roi_syn_con = ds_neuron_synapse_to_node.copy()

    # Instance names of the presynaptic partners
    a, _ = nvneu.fetch_neurons(roi_syn_con.bodyId_pre.unique())
    bid_to_instance = dict(zip(a.bodyId.tolist(), a.instance.tolist()))
    roi_syn_con['instance'] = [bid_to_instance[i] for i in roi_syn_con.bodyId_pre]

    # Assign compartments to synapse nodes not covered above
    comp_of_missing_nodes = find_compartments_of_missing_nodes(
        roi_syn_con, list(whole_neuron_node_to_cluster_dict.keys()),
        original_neuron, ds_neuron)
    whole_neuron_node_to_cluster_dict.update(comp_of_missing_nodes)
    roi_syn_con['compartment'] = [
        whole_neuron_node_to_cluster_dict[i] for i in roi_syn_con.node.tolist()
    ]

    original_neuron = node_to_compartment_full_neuron(original_neuron, ds_neuron)

    if return_electromodel:
        return original_neuron, ds_neuron, roi_syn_con, test_m, test_memcap
    return original_neuron, ds_neuron, roi_syn_con
def _prune_first_entry_last_exit(neuron, volume, verbose=False):
    """Prune ``neuron`` in place to the part of its primary neurite within ``volume``.

    The primary neurite consists of nodes whose radius equals
    ``PRIMARY_NEURITE_RADIUS``. Starting from its distal end, walk towards
    the root until the parent is a branch point or lies inside ``volume``,
    then prune distal to that node. Next, starting from the first child of
    the root with a positive radius, walk downstream along slab nodes until
    reaching a node inside ``volume`` (or a non-slab node), then prune
    proximal to it.

    Raises
    ------
    ValueError
        If no or several primary neurite ends are found, or if the backwards
        walk runs past the root without finding a stopping point.
    """
    nodes = neuron.nodes.set_index('node_id')

    # A primary-neurite node without a primary-neurite child is its end.
    # Vectorised so that a root carrying the primary radius cannot insert a
    # phantom row via ``nodes.at[<missing parent>, ...]``.
    is_fat = nodes['radius'] == PRIMARY_NEURITE_RADIUS
    nodes['has_fat_child'] = nodes.index.isin(nodes.loc[is_fat, 'parent_id'].unique())
    is_prim_neurite_end = (~nodes['has_fat_child']) & is_fat
    prim_neurite_end = nodes.index[is_prim_neurite_end]
    if len(prim_neurite_end) == 0:
        raise ValueError(f"{neuron.neuron_name} doesn't look like a"
                         " motor neuron. Exiting.")
    elif len(prim_neurite_end) != 1:
        raise ValueError('Multiple primary neurite ends for'
                         f' {neuron.neuron_name}: {prim_neurite_end}.'
                         '\nExiting.')

    nodes['is_in_vol'] = navis.in_volume(nodes, volume)

    # Walk backwards until the parent is inside the volume or a branch point
    current_node = prim_neurite_end[0]
    parent_node = nodes.at[current_node, 'parent_id']
    while True:
        if parent_node not in nodes.index:
            raise ValueError(
                f'Walked past the root of {neuron.neuron_name} without'
                ' reaching the volume or a branch point. Exiting.')
        if (nodes.at[parent_node, 'type'] == 'branch'
                or nodes.at[parent_node, 'is_in_vol']):
            break
        current_node = parent_node
        if verbose:
            print(f'Walk back to {current_node}')
        parent_node = nodes.at[parent_node, 'parent_id']
    if verbose:
        print(f'Pruning distal to {current_node}')
    neuron.prune_distal_to(current_node, inplace=True)

    # Start at the first positive-radius child of the root. ``nodes`` still
    # describes the neuron as it was before the distal prune.
    current_node = nodes.index[(nodes.parent_id == neuron.root[0])
                               & (nodes.radius > 0)][0]
    # Walking downstream is a bit slow, but probably acceptable
    while (not nodes.at[current_node, 'is_in_vol']
           and nodes.at[current_node, 'type'] == 'slab'):
        current_node = nodes.index[nodes.parent_id == current_node][0]
        if verbose:
            print(f'Walk forward to {current_node}')
    if not nodes.at[current_node, 'is_in_vol']:
        input('WARNING: Hit a branch before hitting the volume for'
              f' neuron {neuron.neuron_name}. This is unusual.'
              ' Press enter to acknowledge.')
    if verbose:
        print(f'Pruning proximal to {current_node}')
    neuron.prune_proximal_to(current_node, inplace=True)


def get_volume_pruned_neurons_by_skid(skids, volume_id, mode='fele',
                                      resample=0,
                                      only_keep_largest_fragment=False,
                                      verbose=False, remote_instance=None):
    """Fetch CATMAID neurons and prune each of them to a volume.

    Parameters
    ----------
    skids :
        Skeleton ID(s) passed to ``pymaid.get_neuron``.
    volume_id :
        Volume to prune to. It is looked up in the module-level ``volumes``
        cache and fetched with ``pymaid.get_volume`` if missing.
    mode : 'fele' | 'strict'
        'fele' - Keep all parts of the neuron between its primary neurite's
        First Entry to and Last Exit from the volume. So if a segment of
        the primary neurite leaves and then re-enters the volume, that
        segment is not removed. If a branch point is encountered before
        the first entry or last exit, that branch point is used as the
        prune point.
        'strict' - All nodes outside the volume are pruned.
    resample :
        If set to a positive value, the neurons are resampled before
        pruning to have treenodes placed every `resample` nanometers.
        If left at 0, resampling is not performed.
    only_keep_largest_fragment : bool
        If pruning leaves more than one disconnected fragment, keep only the
        fragment with the most nodes. Otherwise the neuron is left as is
        (according to the original note, it gets healed during
        upload_neuron).
    verbose : bool
        Print progress of the 'fele' walks.
    remote_instance : pymaid.CatmaidInstance, optional
        Instance to fetch the neurons and the volume from. Defaults to the
        module-level ``source_project``.

    Returns
    -------
    pymaid.core.CatmaidNeuronList
        The fetched neurons, pruned in place, with a 'pruned ... by vol'
        annotation appended and ' - pruned by vol {volume_id}' added to
        their names.

    Raises
    ------
    ValueError
        If ``mode`` is unknown, or (in 'fele' mode) if no or several primary
        neurite ends are found.
    Exception
        If a neuron's name shows it has already been pruned by a volume.
    """
    # Reject unknown modes up front; previously they renamed neurons as
    # pruned without pruning them.
    if mode not in ('fele', 'strict'):
        raise ValueError(f"mode must be 'fele' or 'strict', got {mode!r}")

    if remote_instance is None:
        remote_instance = source_project

    # Use the requested instance (this used to be hard-coded to source_project).
    neurons = pymaid.get_neuron(skids, remote_instance=remote_instance)

    if volume_id not in volumes:
        print(f'Pulling volume {volume_id} from project'
              f' {remote_instance.project_id}.')
        try:
            volumes[volume_id] = pymaid.get_volume(
                volume_id, remote_instance=remote_instance)
        except Exception:
            print(f"Couldn't find volume {volume_id} in project_id"
                  f" {remote_instance.project_id}! Exiting.")
            raise
    else:
        print(f'Loading volume {volume_id} from cache.')
    volume = volumes[volume_id]

    if isinstance(neurons, pymaid.core.CatmaidNeuron):
        neurons = pymaid.core.CatmaidNeuronList(neurons)

    if resample > 0:
        # TODO find the last node on the primary neurite and store its position
        neurons.resample(resample)  # This throws out radius info except for root
        # TODO find the new node closest to the stored node and set all nodes
        # between that node and root to have radius 500

    for neuron in neurons:
        if 'pruned by vol' in neuron.neuron_name:
            raise Exception(
                'Volume pruning was requested for '
                f' "{neuron.neuron_name}". You probably didn\'t mean to do'
                ' this since it was already pruned. Exiting.')
        print(f'Pruning neuron {neuron.neuron_name}')

        if mode == 'fele':
            _prune_first_entry_last_exit(neuron, volume, verbose=verbose)
        else:
            neuron.prune_by_volume(volume)  # This does in-place pruning

        if neuron.n_skeletons > 1 and only_keep_largest_fragment:
            print('Neuron has multiple disconnected fragments after'
                  ' pruning. Will only keep the largest fragment.')
            frags = morpho.break_fragments(neuron)
            neuron.nodes = frags[frags.n_nodes == max(
                frags.n_nodes)][0].nodes
        # Otherwise the neuron will get healed, and a message printed about
        # the healing, during upload_neuron, so nothing needs to be done here.

        if mode == 'fele':
            neuron.annotations.append(
                f'pruned (first entry, last exit) by vol {volume_id}')
        else:
            neuron.annotations.append(f'pruned (strict) by vol {volume_id}')
        neuron.neuron_name = neuron.neuron_name + f' - pruned by vol {volume_id}'

        if verbose:
            print('\n')

    return neurons
def adjx_from_syn_conn(
    x: List[int],
    presyn_postsyn: str = "pre",
    roi: Optional = None,
    ct: tuple = (0.0, 0.0),
    rename_index: bool = False,
):
    """
    Creates an adjacency matrix from synapse connections.

    Parameters
    ----------
    x : list
        A list of bodyIds whose synaptic (pre/post) connections you want to
        find.
    presyn_postsyn : str
        Either 'pre' or 'post'.
        If 'pre', the function searches for the presynaptic connections /
        downstream neurons of x.
        If 'post', the function searches for the postsynaptic connections /
        upstream neurons of x.
    roi : navis.Volume, optional
        A region of interest within which connections are kept. For 'pre'
        the presynaptic location is tested; for 'post' the postsynaptic one.
    ct : tuple
        Confidence threshold tuple (presynaptic, postsynaptic). If either
        value is > 0, only connections with confidence_pre > ct[0] and
        confidence_post > ct[1] are kept, e.g. (0.9, 0.8).
    rename_index : bool
        Whether to relabel the partner axis with the partner neurons' types
        (the index for 'pre', the columns for 'post').

    Returns
    -------
    df : pandas.DataFrame
        Synapse counts per partner, pooled over all bodyIds in x and ordered
        by decreasing count. For 'pre' the partners are the rows and the
        column label is built from x (``columns=[x]``); for 'post' the frame
        is transposed, so x labels the row and the partners are the columns.
    partner_type_dict : dict
        Maps the bodyId of each upstream/downstream partner to its type.

    Raises
    ------
    ValueError
        If presyn_postsyn is neither 'pre' nor 'post'.
    """
    if presyn_postsyn == "pre":
        con = nvneu.fetch_synapse_connections(source_criteria=x)
        location_cols = ["x_pre", "y_pre", "z_pre"]
        partner_col = "bodyId_post"
    elif presyn_postsyn == "post":
        con = nvneu.fetch_synapse_connections(target_criteria=x)
        location_cols = ["x_post", "y_post", "z_post"]
        partner_col = "bodyId_pre"
    else:
        raise ValueError(
            "presyn_postsyn must be 'pre' or 'post', got {!r}".format(presyn_postsyn))

    if roi is not None:
        tt = navis.in_volume(con[location_cols].values, roi)
        con = con[tt].copy()

    # Explicit comparisons: `ct[0] or ct[1] > 0.0` parsed as
    # `ct[0] or (ct[1] > 0.0)`.
    if ct[0] > 0.0 or ct[1] > 0.0:
        con = con[(con.confidence_pre > ct[0]) & (con.confidence_post > ct[1])].copy()

    n, _ = nvneu.fetch_neurons(con[partner_col].unique())
    partner_type_dict = dict(zip(n.bodyId.tolist(), n.type.tolist()))

    count = Counter(con[partner_col]).most_common()
    # NOTE(review): columns=[x] with several bodyIds in x probably does not
    # match the 1-D data; counts are pooled over x regardless -- confirm.
    df = pd.DataFrame(columns=[x],
                      index=[partner for partner, _ in count],
                      data=[c for _, c in count])

    if presyn_postsyn == "post":
        df = df.T.copy()
        # After transposing, the partners are on the columns, so relabel
        # those (relabelling the index looked up x and raised KeyError).
        if rename_index:
            df.columns = [partner_type_dict[i] for i in df.columns]
    elif rename_index:
        df.index = [partner_type_dict[i] for i in df.index]

    return (df, partner_type_dict)
def random_sample_from_volume(
        volume: navis.Volume,
        supervoxels: cloudvolume.frontends.precomputed.CloudVolumePrecomputed,
        segmentation: cloudvolume.frontends.graphene.CloudVolumeGraphene,
        amount_to_query: Union[float, int] = 10,
        mip_level: Tuple[int, int, int] = (64, 64, 40),
        n_points: int = int(1e6),
        bbox_dim: int = 20,
        disable_rand_point_progress: bool = False):
    """
    Takes a FAFB14 volume and generates random points within it.

    These points are used to create a series of small bounding boxes. The
    supervoxels within these bounding boxes are queried, and the supervoxel
    ids are then mapped to root ids.

    Parameters
    ----------
    volume: a navis.Volume object retrieved either through FAFB14 or Janelia Hemibrain.
            If the latter, you need to transform the volume into flywire
            space using a bridging transformation before using this function.
            The volume itself is not modified.

    supervoxels: A supervoxel cloudvolume object specifying the supervoxel instance you are querying.
            e.g. cloudvolume.CloudVolume("precomputed://https://s3-hpcrc.rc.princeton.edu/fafbv14-ws/ws_190410_FAFB_v02_ws_size_threshold_200",
                                        mip=mip_level)

    segmentation: A segmentation cloudvolume object specifying the segmentation instance you are querying.
            e.g. cloudvolume.CloudVolume("graphene://https://prodv1.flywire-daf.com/segmentation/table/fly_v31")

    amount_to_query: Either a float or an int.
            If a float, the fraction (0 - 1.0) of the randomly generated
            points that fall in the volume to use for querying.
            If an int, the first n randomly generated points that are in the
            volume are used for querying.

    mip_level: the mip resolution level. Volume coordinates are divided by
            this to convert them to voxels. Available resolutions are
            (16, 16, 40), (32, 32, 40), (64, 64, 40). I recommend the lowest
            voxel resolution (64, 64, 40).

    n_points: the number of random points to generate. Not all of them will
            be in the volume; only those that are are kept.

    bbox_dim: edge length (in voxels) of each query cube - small values (<100) are encouraged.

    disable_rand_point_progress: whether to hide the progress bars while generating points.

    Returns
    ----------
    root_ids: an array of the unique, non-zero root ids found in the randomly generated query boxes.

    Raises
    ----------
    AssertionError: on wrong argument types, or when amount_to_query asks
            for more points than are available.
    TypeError: if amount_to_query is neither a float nor an int.
    ValueError: if amount_to_query is negative.

    Example
    ----------

    import pymaid
    import cloudvolume
    import numpy as np
    import navis
    from random import randrange
    from itertools import chain
    import time

    MB_R = pymaid.get_volume('MB_whole_R')
    mip_level = (64, 64, 40)

    cv_svs = cloudvolume.CloudVolume("precomputed://https://s3-hpcrc.rc.princeton.edu/fafbv14-ws/ws_190410_FAFB_v02_ws_size_threshold_200",
                                    mip=mip_level)
    cv_seg = cloudvolume.CloudVolume("graphene://https://prodv1.flywire-daf.com/segmentation/table/fly_v31")

    root_ids = random_sample_from_volume(volume = MB_R, mip_level = mip_level,
                                        n_points = int(1e6), bbox_dim = 20,
                                        supervoxels=cv_svs, segmentation=cv_seg)
    """
    start_time = time.time()

    assert isinstance(
        volume, navis.Volume
    ), f'You need to pass a navis.Volume object. You passed {type(volume)}.'

    assert isinstance(
        supervoxels, cloudvolume.frontends.precomputed.CloudVolumePrecomputed
    ), f'You need to pass a cloudvolume.frontends.precomputed.CloudVolumePrecomputed object. You passed {type(supervoxels)}.'

    assert isinstance(
        segmentation, cloudvolume.frontends.graphene.CloudVolumeGraphene
    ), f'You need to pass a cloudvolume.frontends.graphene.CloudVolumeGraphene object. You passed {type(segmentation)}.'

    n_points = int(n_points)

    # Rescale to voxel units on a copy. Previously ``volume.vertices`` was
    # divided in place, which modified the caller's volume and made repeated
    # calls rescale it again. The vertices are reassigned (not divided in
    # place) so the copy cannot share memory with the original.
    volume_name = volume.name
    voxel_volume = volume.copy()
    voxel_volume.vertices = (np.asarray(volume.vertices, dtype=float)
                             / np.asarray(mip_level, dtype=float))
    vertices_array = np.array(voxel_volume.vertices).astype(int)

    # generating random points
    n_point_string = format(n_points, ',d')
    print(f'Generating {n_point_string} random points... \n')

    # the min and max of each xyz dimension
    x_min, y_min, z_min = vertices_array.min(axis=0)
    x_max, y_max, z_max = vertices_array.max(axis=0)

    # Random integers between these bounds. The axes are drawn one after
    # another with `randrange`, so results stay reproducible under
    # `random.seed`.
    rand_x = [
        randrange(x_min, x_max)
        for _ in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]
    rand_y = [
        randrange(y_min, y_max)
        for _ in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]
    rand_z = [
        randrange(z_min, z_max)
        for _ in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]

    xyz_arr = np.column_stack([rand_x, rand_y, rand_z])

    # How many randomly generated points are in the volume?
    in_vol = navis.in_volume(xyz_arr, voxel_volume)

    print(
        f"""Of {n_point_string} random points generated, {xyz_arr[in_vol].shape[0] / n_points * 1e2 :.2f}% of them are in the volume."""
    )
    print(f"""This equals {xyz_arr[in_vol].shape[0]} points. \n""")

    xyz_in_vol = xyz_arr[in_vol]

    # query cubes: [start, start + bbox_dim) along each axis
    xyz_start = xyz_in_vol
    xyz_end = xyz_in_vol + bbox_dim

    # querying flywire
    print('Querying flywire... \n')

    if isinstance(amount_to_query, float):
        assert amount_to_query <= 1.0, '''If using percentages of the total number of randomly generated points, you cannot use more than 100% of the points generated. Did you intend to use an integer?'''
        print(
            f'You are passing a float, so using {amount_to_query * 1e2}% of the total number ({len(xyz_start)}) of randomly points generated. \n'
        )
        if amount_to_query == 1.0:
            print(
                'You are using 100% of the randomly generated points - this can take a long time to complete.'
            )
        n_query = int(len(xyz_start) * amount_to_query)
        print(f'{amount_to_query * 1e2}% = {n_query} points')
    elif isinstance(amount_to_query, int):
        assert amount_to_query <= len(
            xyz_start
        ), f'''You cannot use the first {amount_to_query} randomly generated points when you have only {len(xyz_start)} exist in the volume. Increase the number of randomly generated points in n_points.'''
        print(
            f'You are passing an integer, so the first {amount_to_query} randomly generated points will be used to query. \n'
        )
        n_query = amount_to_query
    else:
        raise TypeError(
            f'amount_to_query must be a float or an int, got {type(amount_to_query)}.')

    if n_query < 0:
        raise ValueError('amount_to_query must be non-negative.')

    xyz_start = xyz_start[:n_query]
    xyz_end = xyz_end[:n_query]

    print(
        f'Coverage: {((bbox_dim ** 3) * n_query / voxel_volume.volume) * 1e2 :.2g} % of the total {volume_name} volume is covered with {n_query} bounding boxes of {bbox_dim} cubic dimensions. \n'
    )

    supervoxel_ids = [
        supervoxels[s[0]:e[0], s[1]:e[1], s[2]:e[2]]
        for s, e in zip(xyz_start, xyz_end)
    ]

    sv_ids_unique = [np.unique(q) for q in supervoxel_ids]
    sv_ids_unique = np.unique(list(chain.from_iterable(sv_ids_unique)))

    print('Fetching root ids... \n')

    root_ids = segmentation.get_roots(sv_ids_unique)

    # Removing zeros
    root_ids = root_ids[root_ids != 0]
    root_ids = np.unique(root_ids)

    print(
        f'Random sampling complete: {len(root_ids)} unique root ids found \n')
    print(
        f'This function took {(time.time() - start_time) / 60 :.2f} minutes to complete'
    )

    return root_ids