Example No. 1
def matching_inputs_to_compartments(
        neuron_id: int,
        roi: navis.Volume,
        Ra: float,
        Rm: float,
        Cm: float):

    # Fetch the healed skeleton
    full_skeleton = nvneu.fetch_skeletons(neuron_id, heal=True)[0]
    # which of the whole neuron's nodes are in the roi
    skeleton_in_roi = navis.in_volume(full_skeleton.nodes[['x', 'y', 'z']], roi, mode='IN')

    # Fetch the neuron's postsynapses
    postsyn = nvneu.fetch_synapse_connections(target_criteria=neuron_id)
    # match the connectors to nodes
    syn_to_node = match_connectors_to_nodes(postsyn, full_skeleton, synapse_type='post')
    # which of these are in the roi
    roi_syn_con = syn_to_node[syn_to_node.node.isin(full_skeleton.nodes[skeleton_in_roi].node_id.tolist())].copy()

    # Count how many synapses the upstream neurons have on your neuron of interest
    n_syn = dict(Counter(roi_syn_con.bodyId_pre.tolist()).most_common())

    # fetch these neurons
    a, b = nvneu.fetch_neurons(roi_syn_con.bodyId_pre.unique())
    a['n_syn'] = [n_syn[i] for i in a.bodyId.tolist()]

    # adding the instance names to the synapses in the roi
    bid_to_instance = dict(zip(a.bodyId.tolist(), a.instance.tolist()))
    roi_syn_con['instance'] = [bid_to_instance[i] for i in roi_syn_con.bodyId_pre]

    # compartmentalise the neuron and find the nodes in each compartment for the prepared neuron
    compartments_in_roi, nodes_in_roi = find_compartments_in_roi(full_skeleton, roi=roi, min_samples=6, Cm=Cm, Ra=Ra, Rm=Rm)
    clusters = compartments_in_roi.nodes.node_cluster.unique()

    cluster_dict = {}

    # find the nodes that make up each compartment in the full neuron
    for i in clusters:

        clust_nodes = cluster_to_all_nodes(full_skeleton, start_end_node_pairs=permute_start_end(compartments_in_roi, nodes_in_roi, cluster=i))
        cluster_dict[i] = clust_nodes

    # cluster_dict = {k : s for s, k in cluster_dict.items()}
    # roi_syn_con['compartment'] = [cluster_dict[i] for i in roi_syn_con.node.tolist()]

    return cluster_dict
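
A minimal usage sketch for the function above (hedged: the body ID, ROI, and passive properties Ra/Rm/Cm are illustrative placeholders, and a configured neuPrint client plus the neuroboom helpers called inside the function are assumed):

import navis.interfaces.neuprint as nvneu

ca_r = nvneu.fetch_roi('CA(R)')          # example ROI: the right calyx
compartments = matching_inputs_to_compartments(
    neuron_id=722817260,                 # illustrative hemibrain body ID
    roi=ca_r,
    Ra=266.1,                            # illustrative passive properties
    Rm=20.8,
    Cm=0.8,
)
# compartments maps each cluster label to the nodes that make up that compartment
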
Example No. 2
def find_compartments_in_roi(neuron,
                            Rm,
                            Ra,
                            Cm,
                            roi,
                            min_samples):

    n = neuron.copy()
    n_prep = prepare_neuron(n, change_units=True, factor=1e3)
    n_m, n_memcap = calculate_M_mat(n_prep, Rm, Ra, Cm, solve=False)
    n_m_solved = np.linalg.inv(sparse.csr_matrix.todense(n_m))

    phate_operator = phate.PHATE(n_components=3, n_jobs=-2, verbose=False)

    mat_red, labels, n_clusters, n_noise = find_clusters(n_m_solved, phate_operator, eps=1e-02, min_samples=min_samples)

    node_to_label = dict(enumerate(labels))

    n_prep.nodes['node_cluster'] = n_prep.nodes.index.map(node_to_label)

    n_in_roi = navis.in_volume(n_prep.nodes[['x', 'y', 'z']] * 1e3, roi, mode='IN', inplace=False)

    return n_prep, n_in_roi
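
A hedged sketch of calling find_compartments_in_roi directly (same illustrative body ID, ROI, and passive properties as above):

import navis.interfaces.neuprint as nvneu

neuron = nvneu.fetch_skeletons(722817260, heal=True)[0]
roi = nvneu.fetch_roi('CA(R)')
n_prep, in_roi = find_compartments_in_roi(
    neuron, Rm=20.8, Ra=266.1, Cm=0.8, roi=roi, min_samples=6)
print(n_prep.nodes.node_cluster.value_counts())   # compartment sizes
print(in_roi.sum(), 'nodes fall inside the ROI')
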
Example No. 3
def interactive_dendrogram(
    z: Union[navis.TreeNeuron, navis.NeuronList],
    heal_neuron: bool = True,
    plot_nodes: bool = True,
    plot_connectors: bool = True,
    highlight_connectors: Optional[dict] = None,
    in_volume: Optional[navis.Volume] = None,
    prog: str = "dot",
    inscreen: bool = True,
    filename: Optional[str] = None,
):
    """
    Takes a navis neuron and returns an interactive 2D dendrogram.
    In this dendrogram, nodes or connector locations can be highlighted.


    Parameters
    ----------
    z:                    A navis neuron object

    heal_neuron: bool

                    Whether to heal the neuron. N.B. navis neurons
                    should be healed on import, i.e. navis.fetch_skeletons(bodyid, heal=True);
                    see navis.fetch_skeletons and navis.heal_fragmented_neuron for details.

    plot_nodes:           bool

                    Whether treenodes should be plotted

    plot_connectors:      bool

                    Whether connectors should be plotted

    highlight_connectors: dict

                    A dictionary with the connector IDs you want to highlight
                    as keys and the colours to draw them in as values.
                    This allows multiple colours to be plotted.

                    N.B. Plotly colours are in the range of 0 - 255
                    whereas matplotlib colours are between 0-1. For the
                    interactive dendrogram colours need to be in the
                    plotly range, whereas in the static dendrogram
                    the colours need to be in the matplotlib range.

    in_volume:            navis.Volume object

                    A navis.Volume object corresponding to an ROI in the brain.
                    This will then highlight the nodes of
                    the neuron which are in that volume


    prog:                 str

                    The layout program used by navis.nx_agraph.graphviz_layout().
                    Valid programs are 'dot' and 'neato'.
                    The dot program gives a hierarchical layout and is the fastest.
                    The neato program draws edges proportional to their real length,
                    but is the slowest; it can take ~2 hrs for a single neuron!

    inscreen:             bool

                    Whether to plot the graph inline (e.g. in Jupyter notebooks) or as a
                    separate HTML file that can be saved to disk and reopened in a browser at any time.

    filename:             str

                    The filename of the interactive dendrogram HTML file. This parameter
                    is only used when inscreen=False.



    Returns
    -------
    plotly.fig object containing the dendrogram. This is either displayed inline
    or saved as a separate HTML file with the name given by the filename parameter.


    Examples
    --------
    from neuroboom.utils import create_graph_structure
    from neuroboom.dendrogram import interactive_dendrogram
    import navis.interfaces.neuprint as nvneu
    from matplotlib import pyplot as plt


    test_neuron = nvneu.fetch_skeletons(722817260)


    interactive_dendrogram(test_neuron, prog = 'dot', inscreen = True)


    """

    z = check_valid_neuron_input(z)

    if heal_neuron:
        z = navis.heal_fragmented_neuron(z)

    valid_progs = ["neato", "dot"]
    if prog not in valid_progs:
        raise ValueError("Unknown program parameter!")

    # save start time

    start = time.time()

    # Necessary for neato layouts for preservation of segment lengths

    if "parent_dist" not in z.nodes:
        z = calc_cable(z, return_skdata=True)

    # Generation of networkx diagram

    g, pos = create_graph_structure(z,
                                    returned_object="graph_and_positions",
                                    prog=prog)

    print("Now creating plotly graph...")

    # Converting networkx nodes for plotly
    # NODES

    x = []
    y = []
    node_info = []

    if plot_nodes:

        for n in g.nodes():
            x_, y_ = pos[n]

            x.append(x_)
            y.append(y_)
            node_info.append("Node ID: {}".format(n))

        node_trace = go.Scatter(
            x=x,
            y=y,
            mode="markers",
            text=node_info,
            hoverinfo="text",
            marker=go.scatter.Marker(showscale=False),
        )
    else:

        node_trace = go.Scatter()

    # EDGES

    xe = []
    ye = []

    for e in g.edges():

        x0, y0 = pos[e[0]]
        x1, y1 = pos[e[1]]

        xe += [x0, x1, None]
        ye += [y0, y1, None]

    edge_trace = go.Scatter(
        x=xe,
        y=ye,
        line=go.scatter.Line(width=1.0, color="#000"),
        hoverinfo="none",
        mode="lines",
    )

    # SOMA

    xs = []
    ys = []

    for n in g.nodes():
        if n == z.soma:

            x__, y__ = pos[n]
            xs.append(x__)
            ys.append(y__)

    soma_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="markers",
        hoverinfo="text",
        marker=dict(size=20, color="rgb(0,0,0)"),
        text="Soma, node:{}".format(z.soma),
    )

    # CONNECTORS:
    # RELATION  = 0 ARE PRESYNAPSES, RELATION = 1 ARE POSTSYNAPSES

    if plot_connectors is False:

        presynapse_connector_trace = go.Scatter()

        postsynapse_connector_trace = go.Scatter()

    elif plot_connectors is True:

        presynapse_connector_list = list(
            z.connectors[z.connectors.type == "pre"].node_id.values)

        x_pre = []
        y_pre = []
        presynapse_connector_info = []

        for node in g.nodes():
            for tn in presynapse_connector_list:

                if node == tn:

                    x, y = pos[node]

                    x_pre.append(x)
                    y_pre.append(y)
                    presynapse_connector_info.append(
                        "Presynapse, connector_id: {}".format(tn))

        presynapse_connector_trace = go.Scatter(
            x=x_pre,
            y=y_pre,
            text=presynapse_connector_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=10, color="rgb(0,255,0)"),
        )

        postsynapse_connectors_list = list(
            z.connectors[z.connectors.type == "post"].node_id.values)

        x_post = []
        y_post = []
        postsynapse_connector_info = []

        for node in g.nodes():
            for tn in postsynapse_connectors_list:

                if node == tn:

                    x, y = pos[node]

                    x_post.append(x)
                    y_post.append(y)

                    postsynapse_connector_info.append(
                        "Postsynapse, connector id: {}".format(tn))

        postsynapse_connector_trace = go.Scatter(
            x=x_post,
            y=y_post,
            text=postsynapse_connector_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=10, color="rgb(0,0,255)"),
        )

    if highlight_connectors is None:

        HC_trace = go.Scatter()

    elif isinstance(highlight_connectors, dict):

        HC_nodes = []
        HC_color = []

        for i in list(highlight_connectors.keys()):

            HC_nodes.append(
                z.connectors[z.connectors.connector_id == i].node_id.values[0])
            HC_color.append("rgb({}, {}, {})".format(
                int(highlight_connectors[i][0]),
                int(highlight_connectors[i][1]),
                int(highlight_connectors[i][2]),
            ))

        # Build the highlight trace once, after all connectors of interest
        # have been collected
        HC_x = []
        HC_y = []

        HC_info = []

        for node in g.nodes():

            for tn in HC_nodes:

                if node == tn:

                    x, y = pos[node]

                    HC_x.append(x)
                    HC_y.append(y)

                    HC_info.append(
                        "Connector of Interest, connector_id: {}, treenode_id: {}"
                        .format(
                            z.connectors[z.connectors.node_id ==
                                         node].connector_id.values[0],
                            node,
                        ))

        HC_trace = go.Scatter(
            x=HC_x,
            y=HC_y,
            text=HC_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=12.5, color=HC_color),
        )

    # Highlight the nodes that are in a particular volume

    if in_volume is None:

        in_volume_trace = go.Scatter()

    elif in_volume is not None:

        # Volume of interest

        res = navis.in_volume(z.nodes, volume=in_volume, mode="IN")
        z.nodes["IN_VOLUME"] = res

        x_VOI = []
        y_VOI = []

        VOI_info = []

        for node in g.nodes():

            for tn in list(z.nodes[z.nodes.IN_VOLUME].node_id.values):

                if node == tn:

                    x, y = pos[tn]

                    x_VOI.append(x)
                    y_VOI.append(y)

                    VOI_info.append("Treenode {} is in {} volume".format(
                        tn, in_volume))

        in_volume_trace = go.Scatter(
            x=x_VOI,
            y=y_VOI,
            text=VOI_info,
            mode="markers",
            hoverinfo="text",
            marker=dict(size=5, color="rgb(35,119,0)"),
        )

    print("Creating Plotly Graph")

    fig = go.Figure(
        data=[
            edge_trace,
            node_trace,
            soma_trace,
            presynapse_connector_trace,
            postsynapse_connector_trace,
            HC_trace,
            in_volume_trace,
        ],
        layout=go.Layout(
            title="Plotly graph of {} with {} layout".format(z.name, prog),
            titlefont=dict(size=16),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=50, r=5, t=40),
            annotations=[
                dict(showarrow=False,
                     xref="paper",
                     yref="paper",
                     x=0.005,
                     y=-0.002)
            ],
            xaxis=go.layout.XAxis(showgrid=False,
                                  zeroline=False,
                                  showticklabels=False),
            yaxis=go.layout.YAxis(showgrid=False,
                                  zeroline=False,
                                  showticklabels=False),
        ),
    )

    print("Finished in {} seconds".format(time.time() - start))

    if inscreen:
        return iplot(fig)

    return plot(fig, filename=filename)
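
A minimal usage sketch for the dendrogram function above (hedged: the body ID is taken from the docstring example, and 'dendrogram.html' is an illustrative filename):

import navis.interfaces.neuprint as nvneu

# fetch_skeletons returns a NeuronList; take the first neuron
test_neuron = nvneu.fetch_skeletons(722817260)[0]

# Inline in a Jupyter notebook...
interactive_dendrogram(test_neuron, prog='dot', inscreen=True)

# ...or written to a standalone HTML file
interactive_dendrogram(test_neuron, prog='dot', inscreen=False,
                       filename='dendrogram.html')
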
Example No. 4
def find_fragments(x,
                   remote_instance,
                   min_node_overlap=3,
                   min_nodes=1,
                   mesh=None):
    """Find manual tracings overlapping with given autoseg neuron.

    This function is a generalization of ``find_autoseg_fragments`` and is
    designed to not require overlapping neurons to have references (e.g.
    in their name) to segmentation IDs:

        1. Traverse neurites of ``x`` and search within a 2.5 micron radius
           for potentially overlapping fragments.
        2. Either collect segmentation IDs for the input neuron and all
           potentially overlapping fragments or (if provided) use the mesh
           to check if the candidates are inside that mesh.
        3. Return fragments that overlap with the input neuron by at least
           ``min_node_overlap`` nodes.

    Parameters
    ----------
    x :                 pymaid.CatmaidNeuron | navis.TreeNeuron
                        Neuron to collect fragments for.
    remote_instance :   pymaid.CatmaidInstance
                        Catmaid instance in which to search for fragments.
    min_node_overlap :  int, optional
                        Minimal overlap between `x` and a fragment, in nodes. If
                        the fragment has fewer total nodes than ``min_node_overlap``,
                        the threshold will be lowered to:
                        ``min_node_overlap = min(min_node_overlap, fragment.n_nodes)``
    min_nodes :         int, optional
                        Minimum node count for returned neurons.
    mesh :              navis.Volume | trimesh.Trimesh | navis.MeshNeuron, optional
                        Mesh representation of ``x``. If provided, the mesh is
                        used instead of querying the segmentation to determine
                        whether fragments overlap. This is generally faster.

    Returns
    -------
    pymaid.CatmaidNeuronList
                        CatmaidNeurons of the overlapping fragments. Overlap
                        scores are attached to each neuron as ``.overlap_score``
                        attribute.

    Examples
    --------
    Setup:

    >>> import pymaid
    >>> import fafbseg

    >>> manual = pymaid.CatmaidInstance('MANUAL_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> # Set a source for segmentation data
    >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation")

    Find manually traced fragments overlapping with an autoseg neuron:

    >>> x = pymaid.get_neuron(204064470, remote_instance=auto)
    >>> frags_of_x = fafbseg.find_fragments(x, remote_instance=manual)

    See Also
    --------
    fafbseg.find_autoseg_fragments
                        Use this function if you are looking for autoseg
                        fragments overlapping with a given neuron. Because we
                        can use the reference to segment IDs (via names &
                        annotations), this function is much faster than
                        ``find_fragments``.

    """
    if not isinstance(x, navis.TreeNeuron):
        raise TypeError(
            'Expected navis.TreeNeuron or pymaid.CatmaidNeuron, got "{}"'.
            format(type(x)))

    meshtypes = (type(None), navis.MeshNeuron, navis.Volume, tm.Trimesh)
    if not isinstance(mesh, meshtypes):
        raise TypeError(f'Unexpected mesh of type "{type(mesh)}"')

    # Resample the autoseg neuron to 0.5 microns
    x_rs = x.resample(500, inplace=False)

    # For each node get skeleton IDs in a 0.25 micron radius
    r = 250
    # Generate bounding boxes around each node
    bboxes = [
        np.vstack([co - r, co + r]).T
        for co in x_rs.nodes[['x', 'y', 'z']].values
    ]

    # Query each bounding box
    urls = [
        remote_instance._get_skeletons_in_bbox(minx=min(b[0]),
                                               maxx=max(b[0]),
                                               miny=min(b[1]),
                                               maxy=max(b[1]),
                                               minz=min(b[2]),
                                               maxz=max(b[2]),
                                               min_nodes=min_nodes)
        for b in bboxes
    ]
    resp = remote_instance.fetch(urls,
                                 desc='Searching for overlapping neurons')
    skids = set(s for lst in resp for s in lst)

    # Return empty NeuronList if no skids found
    if not skids:
        return pymaid.CatmaidNeuronList([])

    # Get nodes for these candidates
    tn_table = pymaid.get_node_table(skids,
                                     include_details=False,
                                     convert_ts=False,
                                     remote_instance=remote_instance)
    # Keep track of total node counts
    node_counts = tn_table.groupby('skeleton_id').node_id.count().to_dict()

    # If no mesh, use segmentation
    if not mesh:
        # Get segment IDs for the input neuron
        x.nodes['seg_id'] = locs_to_segments(x.nodes[['x', 'y', 'z']].values,
                                             coordinates='nm',
                                             mip=0)

        # Count segment IDs
        x_seg_counts = x.nodes.groupby('seg_id').node_id.count().reset_index(
            drop=False)
        x_seg_counts.columns = ['seg_id', 'counts']

        # Remove seg IDs 0
        x_seg_counts = x_seg_counts[x_seg_counts.seg_id != 0]

        # Generate KDTree for nearest neighbor calculations
        tree = navis.neuron2KDTree(x)

        # Now remove nodes that aren't even close to our input neuron
        dist, ix = tree.query(tn_table[['x', 'y', 'z']].values,
                              distance_upper_bound=2500)
        tn_table = tn_table.loc[dist <= 2500]

        # Remove neurons that can't possibly have enough overlap
        node_counts2 = tn_table.groupby(
            'skeleton_id').node_id.count().to_dict()
        to_keep = [
            k for k, v in node_counts2.items()
            if v >= min(min_node_overlap, node_counts[k])
        ]
        tn_table = tn_table[tn_table.skeleton_id.isin(to_keep)]

        # Add segment IDs
        tn_table['seg_id'] = locs_to_segments(tn_table[['x', 'y', 'z']].values,
                                              coordinates='nm',
                                              mip=0)

        # Now group by neuron and by segment
        seg_counts = tn_table.groupby(
            ['skeleton_id', 'seg_id']).node_id.count().reset_index(drop=False)
        # Rename columns
        seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']

        # Remove seg IDs 0
        seg_counts = seg_counts[seg_counts.seg_id != 0]

        # Remove segments IDs that are not overlapping with input neuron
        seg_counts = seg_counts[np.isin(seg_counts.seg_id.values,
                                        x_seg_counts.seg_id.values)]

        # Now go over each candidate and see if there is enough overlap
        ol = []
        scores = []
        for s in seg_counts.skeleton_id.unique():
            # Subset to nodes of this neuron
            this_counts = seg_counts[seg_counts.skeleton_id == s]

            # If the neuron is smaller than min_overlap, lower the threshold
            this_min_ol = min(min_node_overlap, node_counts[s])

            # Sum up counts for both input neuron and this candidate
            c_count = this_counts.counts.sum()
            x_count = x_seg_counts[x_seg_counts.seg_id.isin(
                this_counts.seg_id.values)].counts.sum()

            # If there is enough overlap, keep this neuron
            # and add score as `overlap_score`
            if (c_count >= this_min_ol) and (x_count >= this_min_ol):
                # The score is the minimal overlap
                scores.append(min(c_count, x_count))
                ol.append(s)
    else:
        # Check if nodes are inside or outside the mesh
        tn_table['in_mesh'] = navis.in_volume(tn_table[['x', 'y', 'z']].values,
                                              mesh).astype(bool)
        # Count the number of in-mesh nodes for each neuron
        # (skeletons without any in-mesh nodes are dropped below)
        ol_counts = tn_table.groupby('skeleton_id',
                                     as_index=False).in_mesh.sum()

        # Rename columns
        ol_counts.columns = ['skeleton_id', 'counts']

        # Now subset to those that are overlapping sufficiently
        # First drop all non-overlapping fragments
        ol_counts = ol_counts[ol_counts.counts > 0]

        # Add column with total node counts
        ol_counts['node_count'] = ol_counts.skeleton_id.map(node_counts)

        # Generate a threshold array such that the threshold is the minimum
        # of the total node count and min_node_overlap
        mno_array = np.repeat([min_node_overlap], ol_counts.shape[0])
        mnc_array = ol_counts.skeleton_id.map(node_counts).values
        thr_array = np.vstack([mno_array, mnc_array]).min(axis=0)

        # Subset to candidates that meet the threshold
        ol_counts = ol_counts[ol_counts.counts >= thr_array]

        ol = ol_counts.skeleton_id.tolist()
        scores = ol_counts.counts.tolist()

    if ol:
        ol = pymaid.get_neurons(ol, remote_instance=remote_instance)

        # Make sure it's a neuronlist
        if not isinstance(ol, pymaid.CatmaidNeuronList):
            ol = pymaid.CatmaidNeuronList(ol)

        for n, s in zip(ol, scores):
            n.overlap_score = s
    else:
        ol = pymaid.CatmaidNeuronList([])

    return ol
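
The docstring example covers the segmentation-based path; below is a hedged sketch of the mesh-based path (server URLs and credentials are placeholders as in the docstring, and tree2meshneuron is one assumed way to obtain a mesh for x):

import pymaid
import navis
import fafbseg

manual = pymaid.CatmaidInstance('MANUAL_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

x = pymaid.get_neuron(204064470, remote_instance=auto)

# With a mesh, overlap is decided by point-in-mesh tests rather than
# segmentation lookups, which is generally faster
mesh = navis.conversion.tree2meshneuron(x)   # assumed helper; any Trimesh-like mesh works
frags_of_x = fafbseg.find_fragments(x, remote_instance=manual, mesh=mesh)
print([(f.skeleton_id, f.overlap_score) for f in frags_of_x])
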
Example No. 5
def compartmentalise_neuron(
        neuron_id: int,
        Rm: float,
        Ra: float,
        Cm: float,
        roi: navis.Volume,
        return_electromodel: bool):

    # Fetching the neuron
    ds_neuron = nvneu.fetch_skeletons(neuron_id, heal=True)[0]
    original_neuron = ds_neuron.copy()

    # Electrotonic model
    DS_NEURON = prepare_neuron(ds_neuron, change_units=True, factor=1e3)
    test_m, test_memcap = calculate_M_mat(DS_NEURON, Rm, Ra, Cm, solve=False)
    test_m_solved = np.linalg.inv(sparse.csr_matrix.todense(test_m))

    # running PHATE
    phate_operator = phate.PHATE(n_components=3, n_jobs=-2, verbose=False)

    mat_red, labels, n_clusters, n_noise = find_clusters(
            test_m_solved,
            phate_operator,
            eps=1e-02,
            min_samples=6)

    # index and labels
    index_to_label = dict(zip(range(0, len(labels)), labels))
    ds_neuron.nodes['node_cluster'] = ds_neuron.nodes.index.map(index_to_label)

    # node_to_compartment = dict(zip(ds_neuron.nodes.node_id.tolist(),
    #                                ds_neuron.nodes.node_cluster.tolist()))

    unique_compartments = ds_neuron.nodes.node_cluster.unique()

    # Finding the cluster of the nodes that were removed when downsampling the neuron
    whole_neuron_node_to_cluster = []

    for i in unique_compartments:

        nodes_to_permute = ds_neuron.nodes[ds_neuron.nodes.node_cluster == i].node_id.tolist()

        start_end = list(itertools.permutations(nodes_to_permute, 2))

        # start_end = [i for i in itertools.permutations(
        # ds_neuron.nodes[ds_neuron.nodes.node_cluster == i].node_id.tolist(), 2)]

        nodes_of_cluster = cluster_to_all_nodes(original_neuron, start_end)

        node_to_cluster_dictionary = dict(zip(nodes_of_cluster, [i] * len(nodes_of_cluster)))

        whole_neuron_node_to_cluster.append(node_to_cluster_dictionary)

    whole_neuron_node_to_cluster_dict = {k:v for d in whole_neuron_node_to_cluster for k, v in d.items()}

    # Fetching postsynapses
    ds_neuron_postsynapses = nvneu.fetch_synapse_connections(target_criteria=neuron_id)

    ds_neuron_synapse_to_node = match_connectors_to_nodes(ds_neuron_postsynapses,
                                                          original_neuron,
                                                          synapse_type='post')

    # Which nodes are in the ROI?

    if roi is not None:

        skeleton_in_roi = navis.in_volume(original_neuron.nodes[['x', 'y', 'z']].values, roi, inplace=False)
        ds_isin = ds_neuron_synapse_to_node.node.isin(original_neuron.nodes[skeleton_in_roi].node_id.tolist())
        roi_syn_con = ds_neuron_synapse_to_node[ds_isin].copy()

    else:

        roi_syn_con = ds_neuron_synapse_to_node.copy()

    # roi_syn_con = ds_neuron_synapse_to_node[ds_neuron_synapse_to_node.node.isin(
    #                    original_neuron.nodes[skeleton_in_roi].node_id.tolist())].copy()

    a, b = nvneu.fetch_neurons(roi_syn_con.bodyId_pre.unique())
    bid_to_instance = dict(zip(a.bodyId.tolist(), a.instance.tolist()))
    roi_syn_con['instance'] = [bid_to_instance[i] for i in roi_syn_con.bodyId_pre]

    nodes_to_query = roi_syn_con[~roi_syn_con.node.isin(whole_neuron_node_to_cluster_dict.keys())].node.tolist()
    comp_of_missing_nodes = find_compartments_of_missing_nodes(roi_syn_con,
                                                               list(whole_neuron_node_to_cluster_dict.keys()),
                                                               original_neuron,
                                                               ds_neuron)

    whole_neuron_node_to_cluster_dict = {**whole_neuron_node_to_cluster_dict, **comp_of_missing_nodes}

    roi_syn_con['compartment'] = [whole_neuron_node_to_cluster_dict[i] for i in roi_syn_con.node.tolist()]

    original_neuron = node_to_compartment_full_neuron(original_neuron, ds_neuron)

    if return_electromodel:
        return original_neuron, ds_neuron, roi_syn_con, test_m, test_memcap

    return original_neuron, ds_neuron, roi_syn_con
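
A hedged usage sketch (illustrative body ID, ROI, and passive properties; a configured neuPrint client and the neuroboom helpers are assumed):

import navis.interfaces.neuprint as nvneu

ca_r = nvneu.fetch_roi('CA(R)')
neuron, ds_neuron, syn_con = compartmentalise_neuron(
    neuron_id=722817260,
    Rm=20.8, Ra=266.1, Cm=0.8,
    roi=ca_r,
    return_electromodel=False,
)
# Each postsynapse now carries the partner's instance name and the
# compartment it lands on
print(syn_con.groupby(['instance', 'compartment']).size().head())
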
Example No. 6
def get_volume_pruned_neurons_by_skid(skids,
                                      volume_id,
                                      mode='fele',
                                      resample=0,
                                      only_keep_largest_fragment=False,
                                      verbose=False,
                                      remote_instance=None):
    """
    mode : 'fele' -   Keep all parts of the neuron between its primary
                      neurite's First Entry to and Last Exit from the volume.
                      So if a segment of the primary neurite leaves and then
                      re-enters the volume, that segment is not removed.
           'strict' - All nodes outside the volume are pruned.
    resample : If set to a positive value, the neuron will be resampled before
               pruning to have treenodes placed every `resample` nanometers. If
               left at 0, resampling is not performed.
    In both cases, if a branch point is encountered before the first entry or
    last exit, that branch point is used as the prune point.
    """
    if remote_instance is None:
        remote_instance = source_project

    #if exit_volume_id is None:
    #    exit_volume_id = entry_volume_id

    neurons = pymaid.get_neuron(skids, remote_instance=remote_instance)
    if volume_id not in volumes:
        try:
            print(f'Pulling volume {volume_id} from project'
                  f' {remote_instance.project_id}.')
            volumes[volume_id] = pymaid.get_volume(
                volume_id, remote_instance=remote_instance)
        except:
            print(f"Couldn't find volume {volume_id} in project_id"
                  f" {remote_instance.project_id}! Exiting.")
            raise
    else:
        print(f'Loading volume {volume_id} from cache.')
    volume = volumes[volume_id]

    if isinstance(neurons, pymaid.core.CatmaidNeuron):
        neurons = pymaid.core.CatmaidNeuronList(neurons)

    if resample > 0:
        #TODO find the last node on the primary neurite and store its position
        neurons.resample(
            resample)  # This throws out radius info except for root
        #TODO find the new node closest to the stored node and set all nodes
        #between that node and root to have radius 500

    for neuron in neurons:
        if 'pruned by vol' in neuron.neuron_name:
            raise Exception(
                'Volume pruning was requested for'
                f' "{neuron.neuron_name}". You probably didn\'t mean to do'
                ' this since it was already pruned. Exiting.')
        print(f'Pruning neuron {neuron.neuron_name}')
        if mode == 'fele':
            """
            First, find the most distal primary neurite node. Then, walk
            backward until either finding a node within the volume or a branch
            point. Prune distal to one distal to that point (so it only gets
            the primary neurite and not the offshoot).
            Then, start from the primary neurite node that's a child of the
            soma node, and walk forward (how?) until finding a node within the
            volume or a branch point. Prune proximal to that.
            """
            nodes = neuron.nodes.set_index('node_id')
            # Find end of the primary neurite
            nodes['has_fat_child'] = False
            for tid in nodes.index:
                if nodes.at[tid, 'radius'] == PRIMARY_NEURITE_RADIUS:
                    parent = nodes.at[tid, 'parent_id']
                    nodes.at[parent, 'has_fat_child'] = True
            is_prim_neurite_end = (~nodes['has_fat_child']) & (
                nodes['radius'] == PRIMARY_NEURITE_RADIUS)
            prim_neurite_end = nodes.index[is_prim_neurite_end]
            if len(prim_neurite_end) == 0:
                raise ValueError(f"{neuron.neuron_name} doesn't look like a"
                                 " motor neuron. Exiting.")
            elif len(prim_neurite_end) != 1:
                raise ValueError('Multiple primary neurite ends for'
                                 f' {neuron.neuron_name}: {prim_neurite_end}.'
                                 '\nExiting.')

            nodes['is_in_vol'] = navis.in_volume(nodes, volume)

            # Walk backwards until at a point inside the volume, or at a branch
            # point
            current_node = prim_neurite_end[0]
            parent_node = nodes.at[current_node, 'parent_id']
            while (nodes.at[parent_node, 'type'] != 'branch'
                   and not nodes.at[parent_node, 'is_in_vol']):
                current_node = parent_node
                if verbose: print(f'Walk back to {current_node}')
                parent_node = nodes.at[parent_node, 'parent_id']
            if verbose: print(f'Pruning distal to {current_node}')
            neuron.prune_distal_to(current_node, inplace=True)

            # Start at the first primary neurite node downstream of root
            current_node = nodes.index[
                (nodes.parent_id == neuron.root[0])
                #& (nodes.radius == PRIMARY_NEURITE_RADIUS)][0]
                & (nodes.radius > 0)][0]
            #Walking downstream is a bit slow, but probably acceptable
            while (not nodes.at[current_node, 'is_in_vol']
                   and nodes.at[current_node, 'type'] == 'slab'):
                current_node = nodes.index[nodes.parent_id == current_node][0]
                if verbose: print(f'Walk forward to {current_node}')
            if not nodes.at[current_node, 'is_in_vol']:
                input('WARNING: Hit a branch before hitting the volume for'
                      f' neuron {neuron.neuron_name}. This is unusual.'
                      ' Press enter to acknowledge.')

            if verbose: print(f'Pruning proximal to {current_node}')
            neuron.prune_proximal_to(current_node, inplace=True)

        elif mode == 'strict':
            neuron.prune_by_volume(volume)  #This does in-place pruning

        if neuron.n_skeletons > 1:
            if only_keep_largest_fragment:
                print('Neuron has multiple disconnected fragments after'
                      ' pruning. Will only keep the largest fragment.')
                frags = morpho.break_fragments(neuron)
                neuron.nodes = frags[frags.n_nodes == max(
                    frags.n_nodes)][0].nodes
                #print(neuron)
            #else, the neuron will get healed and print a message about being
            #healed during upload_neuron, so nothing needs to be done here.

        if mode == 'fele':
            neuron.annotations.append(
                f'pruned (first entry, last exit) by vol {volume_id}')
        elif mode == 'strict':
            neuron.annotations.append(f'pruned (strict) by vol {volume_id}')

        neuron.neuron_name = neuron.neuron_name + f' - pruned by vol {volume_id}'

        if verbose: print('\n')

    return neurons
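
A hedged usage sketch (the skeleton IDs and volume name are illustrative; source_project is assumed to be a configured pymaid.CatmaidInstance, as the function's default):

# Keep everything between the primary neurite's first entry into and
# last exit from the volume
pruned = get_volume_pruned_neurons_by_skid(
    [423456, 423457],        # illustrative skeleton IDs
    volume_id='LH_R',        # illustrative volume name
    mode='fele')

# Or strictly remove every node outside the volume
pruned_strict = get_volume_pruned_neurons_by_skid(
    [423456], volume_id='LH_R', mode='strict')
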
Example No. 7
def adjx_from_syn_conn(
        x: List[int],
        presyn_postsyn: str = "pre",
        roi: Optional[navis.Volume] = None,
        ct: tuple = (0.0, 0.0),
        rename_index: bool = False,
):
    """
    Creates an adjacency matrix from synapse connections

    Parameters
    ----------
    x :                 list
                        a list of bodyIds for which to find synaptic (pre/post) connections

    presyn_postsyn :    str
                        a string of either 'pre' or 'post'
                        If 'pre', the function will search for the
                        presynaptic connections / downstream neurons of x

                        If 'post', the function will search for the
                        postsynaptic connections / upstream neurons of x

    roi :               navis.Volume
                        A region of interest to filter the connections by


    ct :                tuple
                        Confidence threshold tuple containing the confidence value to filter above
                        The first value is presynaptic confidence and the second value postsynaptic confidence
                        e.g. (0.9, 0.8) will filter for connections where the presynaptic confidence > 0.9
                        and the postsynaptic confidence > 0.8

    rename_index :      bool
                        Whether to rename the index using the type of the connected neuron


    Returns
    -------
    df :                a DataFrame of synaptic connection counts, with x on one
                        axis and the connected partners on the other

    partner_type_dict : a dictionary where the keys are bodyIds of the
                        upstream/downstream neurons and the values are their types

    Examples
    --------
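    A hedged sketch (the body ID and ROI are illustrative; a configured
    neuPrint client is assumed):

    >>> import navis.interfaces.neuprint as nvneu
    >>> roi = nvneu.fetch_roi('CA(R)')
    >>> df, partner_type_dict = adjx_from_syn_conn(
    ...     x=[722817260], presyn_postsyn='pre',
    ...     roi=roi, ct=(0.9, 0.8), rename_index=True)
    >>> df.head()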
    """

    if presyn_postsyn == "pre":

        con = nvneu.fetch_synapse_connections(source_criteria=x)

        if roi:

            tt = navis.in_volume(con[["x_pre", "y_pre", "z_pre"]].values, roi)
            con = con[tt].copy()

        if ct[0] > 0.0 or ct[1] > 0.0:

            con = con[(con.confidence_pre > ct[0])
                      & (con.confidence_post > ct[1])].copy()

        neurons = con.bodyId_post.unique()
        n, _ = nvneu.fetch_neurons(neurons)
        partner_type_dict = dict(zip(n.bodyId.tolist(), n.type.tolist()))

        count = Counter(con.bodyId_post)
        count = count.most_common()
        count_dict = dict(count)

        df = pd.DataFrame(columns=[x],
                          index=[i for i, j in count],
                          data=[count_dict[i] for i, j in count])

    elif presyn_postsyn == "post":

        con = nvneu.fetch_synapse_connections(target_criteria=x)

        if roi:

            tt = navis.in_volume(con[["x_post", "y_post", "z_post"]].values,
                                 roi)
            con = con[tt].copy()

        if ct[0] > 0.0 or ct[1] > 0.0:

            con = con[(con.confidence_pre > ct[0])
                      & (con.confidence_post > ct[1])].copy()

        neurons = con.bodyId_pre.unique()
        n, _ = nvneu.fetch_neurons(neurons)
        partner_type_dict = dict(zip(n.bodyId.tolist(), n.type.tolist()))

        count = Counter(con.bodyId_pre)
        count = count.most_common()
        count_dict = dict(count)

        df = pd.DataFrame(index=[i for i, j in count],
                          columns=[x],
                          data=[count_dict[i] for i, j in count])
        df = df.T.copy()

    # df = pd.DataFrame(
    #     columns=[x], index=[i for i, j in count], data=[count_dict[i] for i, j in count])

    if rename_index:

        df.index = [partner_type_dict[i] for i in df.index]

    return df, partner_type_dict
Example No. 8
def random_sample_from_volume(
        volume: navis.Volume,
        supervoxels: cloudvolume.frontends.precomputed.CloudVolumePrecomputed,
        segmentation: cloudvolume.frontends.graphene.CloudVolumeGraphene,
        amount_to_query: Union[float, int] = 10,
        mip_level: Tuple[int, int, int] = (64, 64, 40),
        n_points: int = int(1e6),
        bbox_dim: int = 20,
        disable_rand_point_progress: bool = False):
    """
    Takes a FAFB14 volume and generates random points within it.
    These points are then used to create a series of small bounding boxes.
    The supervoxels within these bounding boxes are then queried.
    These supervoxel ids are then mapped to root ids.

    Parameters
    ----------

    volume: a navis.Volume object retrieved either through FAFB14 or Janelia Hemibrain.
                If the latter, then you need to transform the volume into flywire
                space using a bridging transformation before using this function

    mip_level: the mip resolution level.
                Available resolutions are (16, 16, 40), (32, 32, 40), (64, 64, 40).
                I recommend the lowest voxel resolution (64, 64, 40).

    n_points: the number of randomly generated points to create.
                Note that not all of these will fall inside the volume; only those that do are kept.

    bbox_dim: dimension of the query cube - small values (<100) are encouraged.

    supervoxels: A supervoxel cloudvolume.Volume object
                specifying the supervoxel instance you are querying.

    e.g. cloudvolume.CloudVolume("precomputed://https://s3-hpcrc.rc.princeton.edu/fafbv14-ws/ws_190410_FAFB_v02_ws_size_threshold_200",
                                mip=mip_level)

    segmentation: A segmentation cloudvolume.Volume object specifying the segmentation instance you are querying.
                e.g. cloudvolume.CloudVolume("graphene://https://prodv1.flywire-daf.com/segmentation/table/fly_v31")

    amount_to_query: Either a float or an int.
                If a float, it is treated as the fraction of the randomly generated points inside the volume to query (e.g. 0.1 = 10%).
                If an int, the first n randomly generated points that are in the volume will be used to query.

    Returns
    ----------

    root_ids: an array of root ids that were within the randomly generated query boxes.


    Example
    ----------
    import pymaid
    import cloudvolume
    import numpy as np
    import navis
    from random import randrange
    from itertools import chain
    import time

    MB_R = pymaid.get_volume('MB_whole_R')

    mip_level = (64, 64, 40)

    cv_svs = cloudvolume.CloudVolume("precomputed://https://s3-hpcrc.rc.princeton.edu/fafbv14-ws/ws_190410_FAFB_v02_ws_size_threshold_200", mip=mip_level)

    cv_seg = cloudvolume.CloudVolume("graphene://https://prodv1.flywire-daf.com/segmentation/table/fly_v31")

    root_ids = random_sample_from_volume(volume = MB_R, mip_level = mip_level,
                                 n_points = int(1e6), bbox_dim = 20,
                                 supervoxels=cv_svs, segmentation=cv_seg)

    """
    start_time = time.time()

    assert isinstance(
        volume, navis.Volume
    ), f'You need to pass a navis.Volume object. You passed {type(volume)}.'
    assert isinstance(
        supervoxels, cloudvolume.frontends.precomputed.CloudVolumePrecomputed
    ), f'You need to pass a cloudvolume.frontends.precomputed.CloudVolumePrecomputed object. You passed {type(supervoxels)}.'
    assert isinstance(
        segmentation, cloudvolume.frontends.graphene.CloudVolumeGraphene
    ), f'You need to pass a cloudvolume.frontends.graphene.CloudVolumeGraphene object. You passed {type(segmentation)}.'

    n_points = int(n_points)

    # Scale the volume's vertices into voxel space at this mip level.
    # N.B. this modifies the passed volume in place.
    volume.vertices /= mip_level

    vertices_array = np.array(volume.vertices)
    vertices_array = vertices_array.astype(int)

    # generating random points
    n_point_string = format(n_points, ',d')
    print(f'Generating {n_point_string} random points... \n')

    # finding the min and max values of each xyz dimension
    x_min = min(vertices_array[:, 0])
    x_max = max(vertices_array[:, 0])

    y_min = min(vertices_array[:, 1])
    y_max = max(vertices_array[:, 1])

    z_min = min(vertices_array[:, 2])
    z_max = max(vertices_array[:, 2])

    # randomly generating integers inbetween these max and min values
    rand_x = [
        randrange(x_min, x_max)
        for i in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]
    rand_y = [
        randrange(y_min, y_max)
        for i in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]
    rand_z = [
        randrange(z_min, z_max)
        for i in tqdm(range(n_points), disable=disable_rand_point_progress)
    ]

    xyz_arr = np.array([[rand_x, rand_y, rand_z]])
    xyz_arr = xyz_arr.T
    xyz_arr = xyz_arr.reshape(n_points, 3)

    # How many randomly generated points are in the volume?
    in_vol = navis.in_volume(xyz_arr, volume)
    print(
        f"""Of {n_point_string} random points generated, {xyz_arr[in_vol].shape[0] / n_points * 1e2 :.2f}% of them are in the volume."""
    )
    print(f"""This equals {xyz_arr[in_vol].shape[0]} points. \n""")
    xyz_in_vol = xyz_arr[in_vol]

    # generating query cubes
    xyz_start = xyz_in_vol
    xyz_end = xyz_in_vol + bbox_dim

    # querying flywire
    print('Querying flywire... \n')

    supervoxel_ids = []

    if isinstance(amount_to_query, float):

        assert amount_to_query <= 1.0, (
            'If using a fraction of the total number of randomly generated points, '
            'you cannot use more than 100% of the points generated. '
            'Did you intend to pass an integer?')

        print(
            f'You are passing a float, so using {amount_to_query * 1e2}% of the total number ({len(xyz_start)}) of randomly generated points. \n'
        )

        if amount_to_query == 1.0:

            print(
                'You are using 100% of the randomly generated points - this can take a long time to complete.'
            )

        n_query = int(len(xyz_start) * amount_to_query)

        print(f'{amount_to_query * 1e2}% = {n_query} points')

        xyz_start = xyz_start[:n_query]
        xyz_end = xyz_end[:n_query]

        print(
            f'Coverage: {((bbox_dim ** 3) * n_query / volume.volume) * 1e2 :.2g}% of the total {volume.name} volume is covered by {n_query} bounding boxes of side {bbox_dim} voxels. \n'
        )

        for i in range(n_query):

            q = supervoxels[xyz_start[i][0]:xyz_end[i][0],
                            xyz_start[i][1]:xyz_end[i][1],
                            xyz_start[i][2]:xyz_end[i][2]]

            supervoxel_ids.append(q)

    elif isinstance(amount_to_query, int):

        assert amount_to_query <= len(xyz_start), (
            f'You cannot use the first {amount_to_query} randomly generated '
            f'points when only {len(xyz_start)} exist in the volume. '
            'Increase the number of randomly generated points via n_points.')

        print(
            f'You are passing an integer, so the first {amount_to_query} randomly generated points will be used to query. \n'
        )

        xyz_start = xyz_start[:amount_to_query]
        xyz_end = xyz_end[:amount_to_query]

        print(
            f'Coverage: {((bbox_dim ** 3) * amount_to_query / volume.volume) * 1e2 :.2g}% of the total {volume.name} volume is covered by {amount_to_query} bounding boxes of side {bbox_dim} voxels. \n'
        )

        for i in range(amount_to_query):

            q = supervoxels[xyz_start[i][0]:xyz_end[i][0],
                            xyz_start[i][1]:xyz_end[i][1],
                            xyz_start[i][2]:xyz_end[i][2]]

            supervoxel_ids.append(q)

    sv_ids_unique = [np.unique(i) for i in supervoxel_ids]
    sv_ids_unique = np.unique(list(chain.from_iterable(sv_ids_unique)))
    print('Fetching root ids... \n')
    root_ids = segmentation.get_roots(sv_ids_unique)

    # Removing zeros
    root_ids = root_ids[root_ids != 0]

    root_ids = np.unique(root_ids)

    print(
        f'Random sampling complete: {len(root_ids)} unique root ids found \n')
    print(
        f'This function took {(time.time() - start_time) / 60 :.2f} minutes to complete'
    )
    return root_ids
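
A hedged sketch reusing the cloudvolume handles from the docstring example and showing both forms of amount_to_query. Note that the function scales volume.vertices in place, so pass a fresh copy per call:

# Query the first 500 in-volume points...
root_ids_int = random_sample_from_volume(
    volume=MB_R.copy(), supervoxels=cv_svs, segmentation=cv_seg,
    amount_to_query=500, n_points=int(1e6), bbox_dim=20)

# ...or 1% of all in-volume points
root_ids_pct = random_sample_from_volume(
    volume=MB_R.copy(), supervoxels=cv_svs, segmentation=cv_seg,
    amount_to_query=0.01, n_points=int(1e6), bbox_dim=20)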