Example 1
    def plot_morpho(self, figsize, save_path=None, alpha=1, color=None, volume=None, vol_color=(250, 250, 250, .05), azim=-90, elev=-90, dist=6, xlim3d=(-4500, 110000), ylim3d=(-4500, 110000), linewidth=1.5, connectors=False):
        # recommended volume for L1 dataset, 'PS_Neuropil_manual'

        neurons = pymaid.get_neurons(self.skids)

        if color is None:
            color = self.color

        if volume is not None:
            neuropil = pymaid.get_volume(volume)
            neuropil.color = vol_color
            fig, ax = navis.plot2d([neurons, neuropil], method='3d_complex', color=color, linewidth=linewidth, connectors=connectors, cn_size=2, alpha=alpha)
        else:
            fig, ax = navis.plot2d([neurons], method='3d_complex', color=color, linewidth=linewidth, connectors=connectors, cn_size=2, alpha=alpha)

        ax.azim = azim
        ax.elev = elev
        ax.dist = dist
        ax.set_xlim3d(xlim3d)
        ax.set_ylim3d(ylim3d)

        plt.show()

        if save_path is not None:
            fig.savefig(f'{save_path}.png', format='png', dpi=300, transparent=True)
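A usage sketch; `NeuronPlotter` is a hypothetical stand-in for the (unshown) class that defines `plot_morpho` and provides `.skids` and `.color`, and the skeleton IDs are placeholders:

# hypothetical object exposing .skids and .color as plot_morpho expects
plotter = NeuronPlotter(skids=[4205424, 4206885], color='#1f77b4')
plotter.plot_morpho(figsize=(6, 6),
                    volume='PS_Neuropil_manual',  # recommended volume for the L1 dataset
                    save_path='morpho_overview')  # writes morpho_overview.png at 300 dpi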
Example 2
def segments_to_neuron(seg_ids,
                       autoseg_instance,
                       name_pattern="Google: {id}",
                       verbose=True,
                       raise_none_found=True):
    """Retrieve autoseg neurons of given segmentation ID(s).

    If a given segmentation ID has been merged into another fragment, will try
    retrieving by annotation.

    Parameters
    ----------
    seg_ids :           int | list of int
                        Segmentation ID(s) of autoseg skeletons to retrieve.
    autoseg_instance :  pymaid.CatmaidInstance
                        Instance with autoseg skeletons.
    name_pattern :      str, optional
                        Segmentation IDs are encoded in the name. Use this
                        parameter to define that pattern.
    verbose :           bool, optional
                        If True, will print status messages.
    raise_none_found :  bool, optional
                        If True and none of the requested segments were
                        found, will raise a ValueError.

    Returns
    -------
    CatmaidNeuronList
                        Neurons representing given segmentation IDs.

    """
    assert isinstance(autoseg_instance, pymaid.CatmaidInstance)

    seg2skid = segments_to_skids(seg_ids,
                                 autoseg_instance=autoseg_instance,
                                 verbose=verbose)

    to_fetch = list(set([v for v in seg2skid.values() if v]))

    if not to_fetch:
        if raise_none_found:
            raise ValueError(
                "None of the provided segmentation IDs could be found")
        else:
            # Return empty list
            return pymaid.CatmaidNeuronList([])

    nl = pymaid.get_neurons(to_fetch, remote_instance=autoseg_instance)

    # Make sure we're dealing with a list of neurons
    if isinstance(nl, pymaid.CatmaidNeuron):
        nl = pymaid.CatmaidNeuronList(nl)

    # Invert seg2skid
    skid2seg = {}
    for k, v in seg2skid.items():
        skid2seg[v] = skid2seg.get(v, []) + [k]

    for n in nl:
        n.seg_ids = skid2seg[int(n.skeleton_id)]

    return nl
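A usage sketch with placeholder credentials and segmentation IDs:

import pymaid

# placeholder server/credentials; any CATMAID instance hosting autoseg skeletons works
auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

nl = segments_to_neuron([1234567890, 1234567891],  # placeholder segmentation IDs
                        autoseg_instance=auto)
for n in nl:
    print(n.skeleton_id, n.seg_ids)  # each neuron carries the segment IDs it represents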
Example 3
def get_in_out(ids):
    nl = pymaid.get_neurons(ids)
    connectors = get_connectors(nl)
    outputs = connectors[connectors["presynaptic_to"].isin(ids)].copy()
    # assumes any connector whose presynaptic partner is outside `ids` is an input
    inputs = connectors[~connectors["presynaptic_to"].isin(ids)].copy()
    return inputs, outputs
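The split rests on the assumption noted in the comment: any connector whose presynaptic skeleton is not in `ids` counts as an input. A sketch with placeholder skeleton IDs:

ids = [4205424, 4206885]  # placeholder skeleton IDs
inputs, outputs = get_in_out(ids)
print(f'{len(inputs)} input connectors, {len(outputs)} output connectors')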
Example 4
def plot_celltype(path, pairids, n_rows, n_cols, celltypes, pairs_path, plot_pairs=True, connectors=False, cn_size=0.25, color=None, names=False, plot_padding=[0,0]):

    pairs = Promat.get_pairs(pairs_path)
    # pull specific cell type identities
    celltype_ct = [Celltype(f'{pairid}-ipsi-bi', Promat.get_paired_skids(pairid, pairs)) for pairid in pairids]
    celltype_ct = Celltype_Analyzer(celltype_ct)
    celltype_ct.set_known_types(celltypes)
    members = celltype_ct.memberships()

    # link identities to official celltype colors 
    celltype_identities = [np.where(members.iloc[:, i] == 1.0)[0][0] for i in range(len(members.columns))]
    if plot_pairs:
        celltype_ct = [Celltype(celltypes[celltype_identities[i]].name.replace('s', ''), Promat.get_paired_skids(pairid, pairs), celltypes[celltype_identities[i]].color)
                       if celltype_identities[i] < 17
                       else Celltype(f'{pairid}', Promat.get_paired_skids(pairid, pairs), '#7F7F7F')
                       for i, pairid in enumerate(pairids)]
    else:
        celltype_ct = [Celltype(celltypes[celltype_identities[i]].name.replace('s', ''), pairid, celltypes[celltype_identities[i]].color)
                       if celltype_identities[i] < 17
                       else Celltype('Other', pairid, '#7F7F7F')
                       for i, pairid in enumerate(pairids)]

    # plot neuron morphologies
    neuropil = pymaid.get_volume('PS_Neuropil_manual')
    neuropil.color = (250, 250, 250, .05)

    alpha = 1

    fig = plt.figure(figsize=(n_cols*2, n_rows*2))
    gs = plt.GridSpec(n_rows, n_cols, figure=fig, wspace=plot_padding[0], hspace=plot_padding[1])
    axs = np.empty((n_rows, n_cols), dtype=object)

    for i, skids in enumerate([x.skids for x in celltype_ct]):
        col = color if color is not None else celltype_ct[i].color
        neurons = pymaid.get_neurons(skids)

        inds = np.unravel_index(i, shape=(n_rows, n_cols))
        ax = fig.add_subplot(gs[inds], projection="3d")
        axs[inds] = ax
        navis.plot2d(x=[neurons, neuropil], connectors=connectors, cn_size=cn_size, color=col, alpha=alpha, ax=ax, method='3d_complex')

        ax.azim = -90
        ax.elev = -90
        ax.dist = 6
        ax.set_xlim3d((-4500, 110000))
        ax.set_ylim3d((-4500, 110000))
        if names:
            ax.text(x=(ax.get_xlim()[0] + ax.get_xlim()[1])/2 - ax.get_xlim()[1]*0.05, y=ax.get_ylim()[1]*4/5, z=0, 
                    s=celltype_ct[i].name, transform=ax.transData, color=col, alpha=1)

    fig.savefig(f'{path}.png', format='png', dpi=300, transparent=True)
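A hypothetical call; the pair IDs and paths are placeholders, and `known_celltypes` is assumed to be a list of `Celltype` objects built elsewhere:

plot_celltype(path='plots/celltype_morpho',  # saved as plots/celltype_morpho.png
              pairids=[4205424, 4206885],    # placeholder pair IDs
              n_rows=1, n_cols=2,
              celltypes=known_celltypes,     # assumed list of Celltype objects
              pairs_path='data/pairs.csv',   # placeholder path
              names=True)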
Example 5
def get_split_in_out(labels, splits):
    nl = pymaid.get_neurons(labels)

    all_axon_treenodes = []
    all_dend_treenodes = []
    for n in nl:
        skid = int(n.skeleton_id)
        # order of output is axon, dendrite
        fragments = pymaid.cut_neuron(n, splits[skid])

        axon_treenodes = fragments[0].nodes.treenode_id.values
        dend_treenodes = fragments[1].nodes.treenode_id.values
        all_axon_treenodes.append(axon_treenodes)
        all_dend_treenodes.append(dend_treenodes)

    all_axon_treenodes = np.concatenate(all_axon_treenodes)
    all_dend_treenodes = np.concatenate(all_dend_treenodes)
    all_treenodes = np.concatenate((all_axon_treenodes, all_dend_treenodes))

    def filter_treenodes(nodes):
        if not isinstance(nodes, list):
            nodes = [nodes]
        relevant_nodes = [node for node in nodes if node in all_treenodes]
        is_axon = np.isin(relevant_nodes, all_axon_treenodes)
        is_dend = np.isin(relevant_nodes, all_dend_treenodes)
        if is_axon.all():
            return "axon"
        elif is_dend.all():
            return "dend"
        elif (is_axon | is_dend).any():
            # afaik can only happen if synapse is polyadic and some inputs to group are
            # onto dendrites and some are onto axons
            return "mixed"
        else:
            return ""

    inputs, outputs = get_in_out(labels)

    inputs["postsynaptic_type"] = inputs["postsynaptic_to_node"].map(
        filter_treenodes)
    outputs["presynaptic_type"] = outputs["presynaptic_to_node"].map(
        filter_treenodes)
    return inputs, outputs
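A sketch, assuming `splits` maps each skeleton ID to the treenode ID at which `pymaid.cut_neuron` should separate axon from dendrite (all IDs below are placeholders):

splits = {4205424: 7000001, 4206885: 7000002}  # placeholder skid -> split-node map
inputs, outputs = get_split_in_out([4205424, 4206885], splits)
# outputs are now labelled by the compartment of their presynaptic node
print(outputs['presynaptic_type'].value_counts())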
Example 6
def autoreview_edges(x, conf_threshold=1, vol=None, remote_instance=None):
    """Automatically review (low-confidence) edges between nodes.

    The way this works:
      1. Fetch the live version of the neuron(s) from the CATMAID instance
      2. Use raycasting to test (low-confidence) edges
      3. Edge confidence is set to ``5`` if test is passed and to ``1`` if not

    You *can* use this function to test all edges in a neuron by increasing
    ``conf_threshold`` to 5. Please note that this could produce a lot of false
    positives (i.e. edges will be flagged as incorrect even though they
    aren't). Part of the problem is that mitochondria are segmented as
    separate entities and hence introduce membranes inside a neuron.

    Parameters
    ----------
    x :                 skeleton ID(s) | pymaid.CatmaidNeuron/List
                        Neuron(s) to review.
    conf_threshold :    int, optional
                        Confidence threshold for edges to be tested. By
                        default only reviews edges with confidence <= 1.
    vol :               cloudvolume.CloudVolume, optional
                        CloudVolume pointing to segmentation data.
    remote_instance :   pymaid.CatmaidInstance, optional
                        CATMAID instance. If ``None``, will use the globally
                        defined instance.

    Returns
    -------
    server response
                        CATMAID server response from updating node
                        confidences.

    See Also
    --------
    :func:`fafbseg.test_edges`
                        If you only need to test without changing confidences.

    Examples
    --------
    >>> # Set up CloudVolume from the publicly hosted FAFB segmentation data
    >>> # (if you have a local copy, use that instead)
    >>> from cloudvolume import CloudVolume
    >>> vol = CloudVolume('https://storage.googleapis.com/fafb-ffn1-20190805/segmentation',
    ...                   cache=True,
    ...                   progress=False)
    >>> # Autoreview edges
    >>> _ = fafbseg.autoreview_edges(14401884, vol=vol, remote_instance=manual)

    """
    # Fetch neuron(s)
    n = pymaid.get_neurons(x, remote_instance=remote_instance)

    # Extract low confidence edges
    not_root = ~n.nodes.parent_id.isnull()
    is_low_conf = n.nodes.confidence <= conf_threshold
    to_test = n.nodes[is_low_conf & not_root]

    if to_test.empty:
        print('No low-confidence edges to test found in neuron(s) '
              '{}'.format(n.skeleton_id))
        return

    # Test edges
    verdict = test_edges(n,
                         edges=to_test[['treenode_id', 'parent_id']].values,
                         vol=vol)

    # Update node confidences
    new_confidences = {tn: 5 for tn in to_test[verdict].treenode_id.values}
    new_confidences.update(
        {tn: 1
         for tn in to_test[~verdict].treenode_id.values})
    resp = pymaid.update_node_confidence(new_confidences,
                                         remote_instance=remote_instance)

    msg = '{} of {} tested low-confidence edges were found to be correct.'
    msg = msg.format(sum(verdict), to_test.shape[0])
    print(msg)

    return resp
Example 7
def find_fragments(x,
                   remote_instance,
                   min_node_overlap=3,
                   min_nodes=1,
                   mesh=None):
    """Find manual tracings overlapping with given autoseg neuron.

    This function is a generalization of ``find_autoseg_fragments`` and is
    designed to not require overlapping neurons to have references (e.g.
    in their name) to segmentation IDs:

        1. Traverse neurites of ``x`` and search within a 2.5 micron radius
           for potentially overlapping fragments.
        2. Either collect segmentation IDs for the input neuron and all
           potentially overlapping fragments or (if provided) use the mesh
           to check if the candidates are inside that mesh.
        3. Return fragments that overlap with the input neuron by at least
           ``min_node_overlap`` nodes.

    Parameters
    ----------
    x :                 pymaid.CatmaidNeuron | navis.TreeNeuron
                        Neuron to collect fragments for.
    remote_instance :   pymaid.CatmaidInstance
                        Catmaid instance in which to search for fragments.
    min_node_overlap :  int, optional
                        Minimal overlap between `x` and a fragment in nodes.
                        If the fragment has fewer total nodes than
                        `min_node_overlap`, the threshold will be lowered to:
                        ``min_node_overlap = min(min_node_overlap, fragment.n_nodes)``
    min_nodes :         int, optional
                        Minimum node count for returned neurons.
    mesh :              navis.Volume | trimesh.Trimesh | navis.MeshNeuron, optional
                        Mesh representation of ``x``. If provided, will use the
                        mesh instead of querying the segmentation to determine
                        if fragments overlap. This is generally faster.

    Returns
    -------
    pymaid.CatmaidNeuronList
                        CatmaidNeurons of the overlapping fragments. Overlap
                        scores are attached to each neuron as ``.overlap_score``
                        attribute.

    Examples
    --------
    Setup:

    >>> import pymaid
    >>> import fafbseg

    >>> manual = pymaid.CatmaidInstance('MANUAL_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> # Set a source for segmentation data
    >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation")

    Find manually traced fragments overlapping with an autoseg neuron:

    >>> x = pymaid.get_neuron(204064470, remote_instance=auto)
    >>> frags_of_x = fafbseg.find_fragments(x, remote_instance=manual)

    See Also
    --------
    fafbseg.find_autoseg_fragments
                        Use this function if you are looking for autoseg
                        fragments overlapping with a given neuron. Because we
                        can use the reference to segment IDs (via names &
                        annotations), this function is much faster than
                        ``find_fragments``.

    """
    if not isinstance(x, navis.TreeNeuron):
        raise TypeError(
            'Expected navis.TreeNeuron or pymaid.CatmaidNeuron, got "{}"'.
            format(type(x)))

    meshtypes = (type(None), navis.MeshNeuron, navis.Volume, tm.Trimesh)
    if not isinstance(mesh, meshtypes):
        raise TypeError(f'Unexpected mesh of type "{type(mesh)}"')

    # Resample the autoseg neuron to 0.5 microns
    x_rs = x.resample(500, inplace=False)

    # For each node get skeleton IDs in a 0.25 micron radius
    r = 250
    # Generate bounding boxes around each node
    bboxes = [
        np.vstack([co - r, co + r]).T
        for co in x_rs.nodes[['x', 'y', 'z']].values
    ]

    # Query each bounding box
    urls = [
        remote_instance._get_skeletons_in_bbox(minx=min(b[0]),
                                               maxx=max(b[0]),
                                               miny=min(b[1]),
                                               maxy=max(b[1]),
                                               minz=min(b[2]),
                                               maxz=max(b[2]),
                                               min_nodes=min_nodes)
        for b in bboxes
    ]
    resp = remote_instance.fetch(urls,
                                 desc='Searching for overlapping neurons')
    skids = set([s for l in resp for s in l])

    # Return empty NeuronList if no skids found
    if not skids:
        return pymaid.CatmaidNeuronList([])

    # Get nodes for these candidates
    tn_table = pymaid.get_node_table(skids,
                                     include_details=False,
                                     convert_ts=False,
                                     remote_instance=remote_instance)
    # Keep track of total node counts
    node_counts = tn_table.groupby('skeleton_id').node_id.count().to_dict()

    # If no mesh, use segmentation
    if mesh is None:
        # Get segment IDs for the input neuron
        x.nodes['seg_id'] = locs_to_segments(x.nodes[['x', 'y', 'z']].values,
                                             coordinates='nm',
                                             mip=0)

        # Count segment IDs
        x_seg_counts = x.nodes.groupby('seg_id').node_id.count().reset_index(
            drop=False)
        x_seg_counts.columns = ['seg_id', 'counts']

        # Remove seg IDs 0
        x_seg_counts = x_seg_counts[x_seg_counts.seg_id != 0]

        # Generate KDTree for nearest neighbor calculations
        tree = navis.neuron2KDTree(x)

        # Now remove nodes that aren't even close to our input neuron
        dist, ix = tree.query(tn_table[['x', 'y', 'z']].values,
                              distance_upper_bound=2500)
        tn_table = tn_table.loc[dist <= 2500]

        # Remove neurons that can't possibly have enough overlap
        node_counts2 = tn_table.groupby(
            'skeleton_id').node_id.count().to_dict()
        to_keep = [
            k for k, v in node_counts2.items()
            if v >= min(min_node_overlap, node_counts[k])
        ]
        tn_table = tn_table[tn_table.skeleton_id.isin(to_keep)]

        # Add segment IDs
        tn_table['seg_id'] = locs_to_segments(tn_table[['x', 'y', 'z']].values,
                                              coordinates='nm',
                                              mip=0)

        # Now group by neuron and by segment
        seg_counts = tn_table.groupby(
            ['skeleton_id', 'seg_id']).node_id.count().reset_index(drop=False)
        # Rename columns
        seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']

        # Remove seg IDs 0
        seg_counts = seg_counts[seg_counts.seg_id != 0]

        # Remove segments IDs that are not overlapping with input neuron
        seg_counts = seg_counts[np.isin(seg_counts.seg_id.values,
                                        x_seg_counts.seg_id.values)]

        # Now go over each candidate and see if there is enough overlap
        ol = []
        scores = []
        for s in seg_counts.skeleton_id.unique():
            # Subset to segment counts for this neuron
            this_counts = seg_counts[seg_counts.skeleton_id == s]

            # If the neuron is smaller than min_overlap, lower the threshold
            this_min_ol = min(min_node_overlap, node_counts[s])

            # Sum up counts for both input neuron and this candidate
            c_count = this_counts.counts.sum()
            x_count = x_seg_counts[x_seg_counts.seg_id.isin(
                this_counts.seg_id.values)].counts.sum()

            # If there is enough overlap, keep this neuron
            # and add score as `overlap_score`
            if (c_count >= this_min_ol) and (x_count >= this_min_ol):
                # The score is the minimal overlap
                scores.append(min(c_count, x_count))
                ol.append(s)
    else:
        # Check if nodes are inside or outside the mesh
        tn_table['in_mesh'] = navis.in_volume(tn_table[['x', 'y', 'z']].values,
                                              mesh).astype(bool)
        # Count the number of in-mesh nodes for each neuron; fragments
        # without any in-mesh nodes are dropped further below
        ol_counts = tn_table.groupby('skeleton_id',
                                     as_index=False).in_mesh.sum()

        # Rename columns
        ol_counts.columns = ['skeleton_id', 'counts']

        # Now subset to those that are overlapping sufficiently
        # First drop all non-overlapping fragments
        ol_counts = ol_counts[ol_counts.counts > 0]

        # Add column with total node counts
        ol_counts['node_count'] = ol_counts.skeleton_id.map(node_counts)

        # Generate a threshold array such that each threshold is the minimum
        # of min_node_overlap and the fragment's total node count
        mno_array = np.repeat([min_node_overlap], ol_counts.shape[0])
        mnc_array = ol_counts.skeleton_id.map(node_counts).values
        thr_array = np.vstack([mno_array, mnc_array]).min(axis=0)

        # Subset to candidates that meet the threshold
        ol_counts = ol_counts[ol_counts.counts >= thr_array]

        ol = ol_counts.skeleton_id.tolist()
        scores = ol_counts.counts.tolist()

    if ol:
        ol = pymaid.get_neurons(ol, remote_instance=remote_instance)

        # Make sure it's a neuronlist
        if not isinstance(ol, pymaid.CatmaidNeuronList):
            ol = pymaid.CatmaidNeuronList(ol)

        for n, s in zip(ol, scores):
            n.overlap_score = s
    else:
        ol = pymaid.CatmaidNeuronList([])

    return ol
Example 8
        np.vectorize(CLASS_COLOR_DICT.get)(meta["merge_class"])))
connection_types = ["axon", "dendrite", "unsplittable"]
pal = sns.color_palette("deep", 5)
colors = [pal[1], pal[2], pal[4]]
connection_colors = dict(zip(connection_types, colors))

splits = pymaid.find_treenodes(tags="mw axon split")
splits = splits.set_index("skeleton_id")["treenode_id"].squeeze()
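# note: .squeeze() collapses the Series to a scalar if only one split node exists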

# plot paired neurons
pair_meta = meta[meta["pair_id"] != -1]
pairs = pair_meta["pair_id"].unique()

for p in pairs:
    temp_meta = pair_meta[pair_meta["pair_id"] == p]
    skids = temp_meta.index.values.astype(int)
    neuron_class = temp_meta["merge_class"].iloc[0]
    nl = pymaid.get_neurons(skids)
    plot_fragments(nl, splits, neuron_class=neuron_class)
    stashfig(get_savename(nl, neuron_class=neuron_class))

# plot unpaired neurons
unpair_meta = meta[meta["pair_id"] == -1]

for skid, row in unpair_meta.iterrows():
    neuron_class = row["merge_class"]
    nl = pymaid.get_neurons([skid])
    nl = pymaid.CatmaidNeuronList(nl)
    plot_fragments(nl, splits, neuron_class=neuron_class)
    stashfig(get_savename(nl, neuron_class=neuron_class))
Example 9
def find_fragments(x, remote_instance, min_node_overlap=3, min_nodes=1):
    """Find manual tracings overlapping with given autoseg neuron.

    This function is a generalization of ``find_autoseg_fragments`` and is
    designed to not require overlapping neurons to have references (e.g.
    in their name) to segmentation IDs:

        1. Traverse neurites of ``x`` and search within a 2.5 micron radius
           for potentially overlapping fragments.
        2. Collect segmentation IDs for the input neuron and all potentially
           overlapping fragments using the brainmaps API.
        3. Return fragments that overlap with the input neuron by at least
           ``min_node_overlap`` nodes.

    Parameters
    ----------
    x :                 pymaid.CatmaidNeuron
                        Neuron to collect fragments for.
    remote_instance :   pymaid.CatmaidInstance
                        Catmaid instance in which to search for fragments.
    min_node_overlap :  int, optional
                        Minimal overlap between `x` and a fragment in nodes.
                        If the fragment has fewer total nodes than
                        `min_node_overlap`, the threshold will be lowered to:
                        ``min_node_overlap = min(min_node_overlap, fragment.n_nodes)``
    min_nodes :         int, optional
                        Minimum node count for returned neurons.

    Returns
    -------
    pymaid.CatmaidNeuronList
                        CatmaidNeurons of the overlapping fragments. Overlap
                        scores are attached to each neuron as ``.overlap_score``
                        attribute.

    Examples
    --------
    Setup:

    >>> import pymaid
    >>> import fafbseg
    >>> import brainmappy as bm

    >>> manual = pymaid.CatmaidInstance('MANUAL_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('AUTO_SERVER_URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> flow = bm.acquire_credentials()
    >>> # Note that volume ID must match with the autoseg CatmaidInstance!
    >>> bm.set_global_volume('some_volume_id')

    Find manually traced fragments overlapping with an autoseg neuron:

    >>> x = pymaid.get_neuron(204064470, remote_instance=auto)
    >>> frags_of_x = fafbseg.find_fragments(x, remote_instance=manual)

    See Also
    --------
    fafbseg.find_autoseg_fragments
                        Use this function if you are looking for autoseg
                        fragments overlapping with a given neuron. Because we
                        can use the reference to segment IDs (via names &
                        annotations), this function is much faster than
                        ``find_fragments``.

    """
    if not isinstance(x, pymaid.CatmaidNeuron):
        raise TypeError('Expected pymaid.CatmaidNeuron, got "{}"'.format(
            type(x)))

    # Resample the autoseg neuron to 0.5 microns
    x_rs = x.resample(500, inplace=False)

    # For each node get skeleton IDs in a 0.25 micron radius
    r = 250
    # Generate bounding boxes around each node
    bboxes = [
        np.vstack([co - r, co + r]).T
        for co in x_rs.nodes[['x', 'y', 'z']].values
    ]

    # Query each bounding box
    urls = [
        remote_instance._get_skeletons_in_bbox(minx=min(b[0]),
                                               maxx=max(b[0]),
                                               miny=min(b[1]),
                                               maxy=max(b[1]),
                                               minz=min(b[2]),
                                               maxz=max(b[2]),
                                               min_nodes=min_nodes)
        for b in bboxes
    ]
    resp = remote_instance.fetch(urls,
                                 desc='Searching for overlapping neurons')
    skids = set([s for l in resp for s in l])

    # Return empty NeuronList if no skids found
    if not skids:
        return pymaid.CatmaidNeuronList([])

    # Get nodes for these candidates
    tn_table = pymaid.get_treenode_table(skids,
                                         include_details=False,
                                         convert_ts=False,
                                         remote_instance=remote_instance)
    # Keep track of total node counts
    node_counts = tn_table.groupby('skeleton_id').treenode_id.count().to_dict()

    # Get segment IDs for the input neuron
    x.nodes['seg_id'] = segmentation.get_seg_ids(x.nodes[['x', 'y',
                                                          'z']].values)

    # Count segment IDs
    x_seg_counts = x.nodes.groupby('seg_id').treenode_id.count().reset_index(
        drop=False)
    x_seg_counts.columns = ['seg_id', 'counts']

    # Remove seg IDs 0
    x_seg_counts = x_seg_counts[x_seg_counts.seg_id != 0]

    # Generate KDTree for nearest neighbor calculations
    tree = pymaid.neuron2KDTree(x)

    # Now remove nodes that aren't even close to our input neuron
    dist, ix = tree.query(tn_table[['x', 'y', 'z']].values,
                          distance_upper_bound=2500)
    tn_table = tn_table.loc[dist <= 2500]

    # Remove neurons that can't possibly have enough overlap
    node_counts2 = tn_table.groupby(
        'skeleton_id').treenode_id.count().to_dict()
    to_keep = [
        k for k, v in node_counts2.items()
        if v >= min(min_node_overlap, node_counts[k])
    ]
    tn_table = tn_table[tn_table.skeleton_id.isin(to_keep)]

    # Add segment IDs
    tn_table['seg_id'] = segmentation.get_seg_ids(tn_table[['x', 'y',
                                                            'z']].values)

    # Now group by neuron and by segment
    seg_counts = tn_table.groupby(
        ['skeleton_id', 'seg_id']).treenode_id.count().reset_index(drop=False)
    # Rename columns
    seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']

    # Remove seg IDs 0
    seg_counts = seg_counts[seg_counts.seg_id != 0]

    # Remove segments IDs that are not overlapping with input neuron
    seg_counts = seg_counts[np.isin(seg_counts.seg_id.values,
                                    x_seg_counts.seg_id.values)]

    # Now go over each candidate and see if there is enough overlap
    ol = []
    scores = []
    for s in seg_counts.skeleton_id.unique():
        # Subset to segment counts for this neuron
        this_counts = seg_counts[seg_counts.skeleton_id == s]

        # If the neuron is smaller than min_overlap, lower the threshold
        this_min_ol = min(min_node_overlap, node_counts[s])

        # Sum up counts for both input neuron and this candidate
        c_count = this_counts.counts.sum()
        x_count = x_seg_counts[x_seg_counts.seg_id.isin(
            this_counts.seg_id.values)].counts.sum()

        # If there is enough overlap, keep this neuron
        # and add score as `overlap_score`
        if (c_count >= this_min_ol) and (x_count >= this_min_ol):
            # The score is the minimal overlap
            scores.append(min(c_count, x_count))
            ol.append(s)

    if ol:
        ol = pymaid.get_neurons(ol, remote_instance=remote_instance)

        # Make sure it's a neuronlist
        if not isinstance(ol, pymaid.CatmaidNeuronList):
            ol = pymaid.CatmaidNeuronList(ol)

        for n, s in zip(ol, scores):
            n.overlap_score = s
    else:
        ol = pymaid.CatmaidNeuronList([])

    return ol
Example 10
FNAME = os.path.basename(__file__)[:-3]
print(FNAME)


def stashfig(name, **kws):
    savefig(name, foldername=FNAME, save_on=True, **kws)


VERSION = "2020-03-09"
print(f"Using version {VERSION}")

mg = load_metagraph("G", version=VERSION)
start_instance()

nl = pymaid.get_neurons([mg.meta.index[2]])

# Plot using default settings
fig, ax = nl.plot2d()

# %% [markdown]
# #
mg = mg.sort_values("Pair ID")
nl = pymaid.get_neurons(mg.meta[mg.meta["Merge Class"] == "sens-ORN"].index.values)
fig, ax = nl.plot2d()


# %%
nl = pymaid.get_neurons(mg.meta.index.values)
print(len(nl))
# %% [markdown]
Example 11
    def get_cremi_score(self, score_thr=0, skel_ids=None):

        if skel_ids is None:
            assert self.skeleton_ids is not None
            skel_ids = csv_to_list(self.skeleton_ids, 0)
        else:
            assert self.skeleton_ids is None
            assert isinstance(skel_ids, list)

        fpcountall, fncountall, predall, gtall, tpcountall, num_clustered_synapsesall = 0, 0, 0, 0, 0, 0

        pred_synapses_all = []
        for skel_id in skel_ids:
            logger.debug('evaluating skeleton {}'.format(skel_id))
            if not self.only_output_synapses and not self.only_input_synapses:
                pred_synapses = self.pred_df[
                    (self.pred_df.id_skel_pre == skel_id) |
                    (self.pred_df.id_skel_post == skel_id)]
                gt_synapses = self.gt_df[(self.gt_df.id_skel_pre == skel_id) |
                                         (self.gt_df.id_skel_post == skel_id)]

            elif self.only_input_synapses:
                pred_synapses = self.pred_df[self.pred_df.id_skel_post ==
                                             skel_id]
                gt_synapses = self.gt_df[self.gt_df.id_skel_post == skel_id]
            elif self.only_output_synapses:
                pred_synapses = self.pred_df[self.pred_df.id_skel_pre ==
                                             skel_id]
                gt_synapses = self.gt_df[self.gt_df.id_skel_pre == skel_id]
            else:
                raise Exception(
                    'Unclear parameter configuration: {}, {}'.format(
                        self.only_output_synapses, self.only_input_synapses))

            pred_synapses = [
                synapse.Synapse(**dic)
                for dic in pred_synapses.to_dict(orient='records')
            ]

            if self.filter_seg_ids:
                pred_synapses = [
                    syn for syn in pred_synapses
                    if not (syn.id_segm_pre in self.filter_seg_ids
                            or syn.id_segm_post in self.filter_seg_ids)
                ]

            pred_synapses = [
                syn for syn in pred_synapses if syn.score >= score_thr
            ]
            if self.filter_same_id:
                if self.filter_same_id_type == 'seg':
                    pred_synapses = [
                        syn for syn in pred_synapses
                        if syn.id_segm_pre != syn.id_segm_post
                    ]
                elif self.filter_same_id_type == 'skel':
                    pred_synapses = [
                        syn for syn in pred_synapses
                        if syn.id_skel_pre != syn.id_skel_post
                    ]
            if self.filter_redundant:
                assert self.filter_redundant_dist_thr is not None
                num_synapses = len(pred_synapses)
                if self.filter_redundant_dist_type == 'geodesic':
                    # Get skeleton
                    skeleton = pymaid.get_neurons([skel_id])
                else:
                    skeleton = None
                __, removed_ids = synapse.cluster_synapses(
                    pred_synapses,
                    self.filter_redundant_dist_thr,
                    fuse_strategy='max_score',
                    id_type=self.filter_redundant_id_type,
                    skeleton=skeleton,
                    ignore_ids=self.filter_redundant_ignore_ids)
                pred_synapses = [
                    syn for syn in pred_synapses if syn.id not in removed_ids
                ]
                num_clustered_synapses = num_synapses - len(pred_synapses)
                logger.debug(
                    'num of clustered synapses: {}, skel id: {}'.format(
                        num_clustered_synapses, skel_id))
            else:
                num_clustered_synapses = 0

            logger.debug('found {} predicted synapses'.format(
                len(pred_synapses)))

            gt_synapses = [
                synapse.Synapse(**dic)
                for dic in gt_synapses.to_dict(orient='records')
            ]
            stats = evaluation.synaptic_partners_fscore(
                pred_synapses,
                gt_synapses,
                matching_threshold=self.matching_threshold,
                all_stats=True,
                use_only_pre=self.matching_threshold_only_pre,
                use_only_post=self.matching_threshold_only_post)
            fscore, precision, recall, fpcount, fncount, tp_fp_fn_syns = stats

            # tp_syns, fp_syns, fn_syns_gt, tp_syns_gt = evaluation.from_synapsematches_to_syns(
            #     matches, pred_synapses, gt_synapses)
            tp_syns, fp_syns, fn_syns_gt, tp_syns_gt = tp_fp_fn_syns
            fpcountall += fpcount
            fncountall += fncount
            tpcountall += len(tp_syns_gt)
            predall += len(pred_synapses)
            gtall += len(gt_synapses)
            num_clustered_synapsesall += num_clustered_synapses

            assert len(fp_syns) == fpcount
            pred_synapses_all.extend(pred_synapses)
            logger.info(
                f'skel id {skel_id} with fscore {float(fscore):0.2}, precision: {float(precision):0.2}, recall: {float(recall):0.2}'
            )
            logger.info(f'fp: {fpcount}, fn: {fncount}')
            logger.info(
                f'total predicted {len(pred_synapses)}; total gt: {len(gt_synapses)}\n'
            )

        pred_dic = {}
        for syn in pred_synapses_all:
            pred_dic[syn.id] = syn
        logger.debug('Number of duplicated syn ids: {} versus {}'.format(
            len(pred_synapses_all), len(pred_dic)))

        precision = float(tpcountall) / (tpcountall + fpcountall) if (
            tpcountall + fpcountall) > 0 else 0.
        recall = float(tpcountall) / (tpcountall + fncountall) if (
            tpcountall + fncountall) > 0 else 0.
        if (precision + recall) > 0:
            fscore = 2.0 * precision * recall / (precision + recall)
        else:
            fscore = 0.0

        # Collect all in a single document in order to enable quick queries.
        result_dic = {}
        result_dic['fscore'] = fscore
        result_dic['precision'] = precision
        result_dic['recall'] = recall
        result_dic['fpcount'] = fpcountall
        result_dic['fncount'] = fncountall
        result_dic['tpcount'] = tpcountall
        result_dic['predcount'] = predall
        result_dic['gtcount'] = gtall
        result_dic['score_thr'] = score_thr

        settings = {}
        settings['pred_synlinks'] = self.pred_synlinks
        settings['gt_synlinks'] = self.gt_synlinks

        settings['filter_same_id'] = self.filter_same_id
        settings['filter_same_id_type'] = self.filter_same_id_type
        settings['filter_redundant'] = self.filter_redundant
        settings['filter_redundant_id_type'] = self.filter_redundant_id_type
        settings['dist_thr'] = self.filter_redundant_dist_thr
        settings['skel_ids'] = self.skeleton_ids
        settings['matching_threshold'] = self.matching_threshold
        settings[
            'matching_threshold_only_post'] = self.matching_threshold_only_post
        settings[
            'matching_threshold_only_pre'] = self.matching_threshold_only_pre
        settings['only_output_synapses'] = self.only_output_synapses
        settings['only_input_synapses'] = self.only_input_synapses
        settings['num_clustered_synapses'] = num_clustered_synapsesall
        settings[
            'filter_redundant_dist_type'] = self.filter_redundant_dist_type
        settings['filter_seg_ids'] = str(self.filter_seg_ids)

        result_dic.update(settings)
        if self.results_dir is not None:
            resultsfile = os.path.join(
                self.results_dir, 'results_thr{}.json'.format(1000 * score_thr))
            logger.info('writing results to {}'.format(resultsfile))
            with open(resultsfile, 'w') as f:
                json.dump(result_dic, f)

        print('final fscore {:0.2}'.format(fscore))
        print('final precision {:0.2}, recall {:0.2}'.format(
            precision, recall))
        return result_dic
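The final numbers are micro-averaged: true positive, false positive, and false negative counts are pooled across all skeletons first, then precision, recall, and F-score are computed once from the pooled counts. A self-contained sketch of that aggregation:

def micro_fscore(tp, fp, fn):
    # micro-averaged precision/recall/F-score from pooled counts
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall > 0:
        fscore = 2.0 * precision * recall / (precision + recall)
    else:
        fscore = 0.0
    return fscore, precision, recall

print(micro_fscore(tp=80, fp=10, fn=20))  # -> (0.842..., 0.889..., 0.8)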
Example 12
    def get_cremi_score(self, score_thr=0):
        gt_db = database.SynapseDatabase(self.gt_db_name,
                                         db_host=self.gt_db_host,
                                         db_col_name=self.gt_db_col,
                                         mode='r')

        pred_db = database.SynapseDatabase(self.pred_db,
                                           db_host=self.pred_db_host,
                                           db_col_name=self.pred_db_col,
                                           mode='r')

        client_out = MongoClient(self.res_db_host)
        db_out = client_out[self.res_db_name]
        db_out.drop_collection(
            self.res_db_col + '.thr{}'.format(1000 * score_thr))

        skel_ids = csv_to_list(self.skeleton_ids, 0)

        fpcountall, fncountall, predall, gtall, tpcountall, num_clustered_synapsesall = 0, 0, 0, 0, 0, 0

        pred_synapses_all = []
        for skel_id in skel_ids:
            logger.debug('evaluating skeleton {}'.format(skel_id))
            if not self.only_output_synapses and not self.only_input_synapses:
                pred_synapses = pred_db.synapses.find(
                    {'$or': [{'pre_skel_id': skel_id},
                             {'post_skel_id': skel_id}]})
                gt_synapses = gt_db.synapses.find(
                    {'$or': [{'pre_skel_id': skel_id},
                             {'post_skel_id': skel_id}]})

            elif self.only_input_synapses:
                pred_synapses = pred_db.synapses.find({'post_skel_id': skel_id})
                gt_synapses = gt_db.synapses.find({'post_skel_id': skel_id})
            elif self.only_output_synapses:
                pred_synapses = pred_db.synapses.find({'pre_skel_id': skel_id})
                gt_synapses = gt_db.synapses.find({'pre_skel_id': skel_id})
            else:
                raise Exception(
                    'Unclear parameter configuration: {}, {}'.format(
                        self.only_output_synapses, self.only_input_synapses))

            pred_synapses = synapse.create_synapses_from_db(pred_synapses)
            if self.filter_seg_ids:
                pred_synapses = [
                    syn for syn in pred_synapses
                    if not (syn.id_segm_pre in self.filter_seg_ids
                            or syn.id_segm_post in self.filter_seg_ids)
                ]
            if self.syn_score_db is not None:
                score_host = self.syn_score_db['db_host']
                score_db = self.syn_score_db['db_name']
                score_col = self.syn_score_db['db_col_name']
                score_db = MongoClient(host=score_host)[score_db][score_col]
                score_cursor = score_db.find({'synful_id': {'$in': [syn.id for syn in pred_synapses]}})
                df = pd.DataFrame(score_cursor)
                for syn in pred_synapses:
                    # look up the externally stored score for this synapse
                    new_score = float(df.loc[df.synful_id == syn.id, 'score'].iloc[0])
                    if self.syn_score_db_comb is None:
                        syn.score = new_score
                    elif self.syn_score_db_comb == 'multiplication':
                        syn.score *= new_score
                    elif self.syn_score_db_comb == 'filter':
                        if new_score == 0.:
                            syn.score = 0.
                    else:
                        raise Exception(f'Syn_score_db_comb incorrectly set: {self.syn_score_db_comb}')


            pred_synapses = [syn for syn in pred_synapses if
                             syn.score >= score_thr]
            if self.filter_same_id:
                if self.filter_same_id_type == 'seg':
                    pred_synapses = [syn for syn in pred_synapses if
                                     syn.id_segm_pre != syn.id_segm_post]
                elif self.filter_same_id_type == 'skel':
                    pred_synapses = [syn for syn in pred_synapses if
                                     syn.id_skel_pre != syn.id_skel_post]
            removed_ids = []
            if self.filter_redundant:
                assert self.filter_redundant_dist_thr is not None
                num_synapses = len(pred_synapses)
                if self.filter_redundant_dist_type == 'geodesic':
                    # Get skeleton
                    skeleton = pymaid.get_neurons([skel_id])
                else:
                    skeleton = None
                __, removed_ids = synapse.cluster_synapses(pred_synapses,
                                                           self.filter_redundant_dist_thr,
                                                           fuse_strategy='max_score',
                                                           id_type=self.filter_redundant_id_type,
                                                           skeleton=skeleton,
                                                           ignore_ids=self.filter_redundant_ignore_ids)
                pred_synapses = [syn for syn in pred_synapses
                                 if syn.id not in removed_ids]
                num_clustered_synapses = num_synapses - len(pred_synapses)
                logger.debug(
                    'num of clustered synapses: {}, skel id: {}'.format(
                        num_clustered_synapses, skel_id))
            else:
                num_clustered_synapses = 0


            logger.debug(
                'found {} predicted synapses'.format(len(pred_synapses)))

            gt_synapses = synapse.create_synapses_from_db(gt_synapses)
            stats = evaluation.synaptic_partners_fscore(pred_synapses,
                                                        gt_synapses,
                                                        matching_threshold=self.matching_threshold,
                                                        all_stats=True,
                                                        use_only_pre=self.matching_threshold_only_pre,
                                                        use_only_post=self.matching_threshold_only_post)
            fscore, precision, recall, fpcount, fncount, tp_fp_fn_syns = stats

            # tp_syns, fp_syns, fn_syns_gt, tp_syns_gt = evaluation.from_synapsematches_to_syns(
            #     matches, pred_synapses, gt_synapses)
            tp_syns, fp_syns, fn_syns_gt, tp_syns_gt = tp_fp_fn_syns
            tp_ids = [tp_syn.id for tp_syn in tp_syns]
            tp_ids_gt = [syn.id for syn in tp_syns_gt]
            matched_synapse_ids = list(zip(tp_ids, tp_ids_gt))
            fpcountall += fpcount
            fncountall += fncount
            tpcountall += len(tp_syns_gt)
            predall += len(pred_synapses)
            gtall += len(gt_synapses)
            num_clustered_synapsesall += num_clustered_synapses

            assert len(fp_syns) == fpcount
            db_dic = {
                'skel_id': skel_id,
                'tp_pred': [syn.id for syn in tp_syns],
                'tp_gt': [syn.id for syn in tp_syns_gt],
                'fp_pred': [syn.id for syn in fp_syns],
                'fn_gt': [syn.id for syn in fn_syns_gt],
                'gtcount': len(gt_synapses),
                'predcount': len(pred_synapses),
                'matched_synapse_ids': matched_synapse_ids,
                'fscore': stats[0],
                'precision': stats[1],
                'recall': stats[2],
                'fpcount': stats[3],
                'fncount': stats[4],
                'removed_ids': removed_ids,
            }

            db_out[self.res_db_col + '.thr{}'.format(1000 * score_thr)].insert_one(
                db_dic)
            pred_synapses_all.extend(pred_synapses)
            logger.info(f'skel id {skel_id} with fscore {fscore:0.2}, precision: {precision:0.2}, recall: {recall:0.2}')
            logger.info(f'fp: {fpcount}, fn: {fncount}')
            logger.info(f'total predicted {len(pred_synapses)}; total gt: {len(gt_synapses)}\n')

        # Also write out synapses:
        pred_dic = {}
        for syn in pred_synapses_all:
            pred_dic[syn.id] = syn
        print('Number of duplicated syn ids: {} versus {}'.format(len(pred_synapses_all), len(pred_dic)))
        syn_out = database.SynapseDatabase(self.res_db_name,
                                           db_host=self.res_db_host,
                                           db_col_name=self.res_db_col + '.syn_thr{}'.format(
                                               1000 * score_thr),
                                           mode='w')
        syn_out.write_synapses(pred_dic.values())

        precision = float(tpcountall) / (tpcountall + fpcountall) if (
            tpcountall + fpcountall) > 0 else 0.
        recall = float(tpcountall) / (tpcountall + fncountall) if (
            tpcountall + fncountall) > 0 else 0.
        if (precision + recall) > 0:
            fscore = 2.0 * precision * recall / (precision + recall)
        else:
            fscore = 0.0

        # Collect all in a single document in order to enable quick queries.
        result_dic = {}
        result_dic['fscore'] = fscore
        result_dic['precision'] = precision
        result_dic['recall'] = recall
        result_dic['fpcount'] = fpcountall
        result_dic['fncount'] = fncountall
        result_dic['tpcount'] = tpcountall
        result_dic['predcount'] = predall
        result_dic['gtcount'] = gtall
        result_dic['score_thr'] = score_thr

        settings = {}
        settings['pred_db_col'] = self.pred_db_col
        settings['pred_db_name'] = self.pred_db
        settings['gt_db_col'] = self.gt_db_col
        settings['gt_db_name'] = self.gt_db_name
        settings['filter_same_id'] = self.filter_same_id
        settings['filter_same_id_type'] = self.filter_same_id_type
        settings['filter_redundant'] = self.filter_redundant
        settings['filter_redundant_id_type'] = self.filter_redundant_id_type
        settings['dist_thr'] = self.filter_redundant_dist_thr
        settings['skel_ids'] = self.skeleton_ids
        settings['matching_threshold'] = self.matching_threshold
        settings[
            'matching_threshold_only_post'] = self.matching_threshold_only_post
        settings[
            'matching_threshold_only_pre'] = self.matching_threshold_only_pre
        settings['only_output_synapses'] = self.only_output_synapses
        settings['only_input_synapses'] = self.only_input_synapses
        settings['num_clustered_synapses'] = num_clustered_synapsesall
        settings['filter_redundant_dist_type'] = self.filter_redundant_dist_type
        new_score_db_name = (self.syn_score_db['db_name'] + self.syn_score_db['db_col_name']
                             if self.syn_score_db is not None else 'original score')
        if self.syn_score_db_comb is not None and new_score_db_name is not None:
            new_score_db_name += self.syn_score_db_comb
        settings['new_score_db'] = new_score_db_name
        settings['filter_seg_ids'] = str(self.filter_seg_ids)

        result_dic.update(settings)

        db_out[self.res_db_col_summary].insert_one(result_dic)

        print('final fscore {:0.2}'.format(fscore))
        print('final precision {:0.2}, recall {:0.2}'.format(precision, recall))
Example 13
    treenode_df = pd.DataFrame(treenode_info)
    # a split node is included in both the pre- and postsynaptic fragments;
    # drop it here, assuming no synapse sits exactly on a split node
    treenode_df = treenode_df[~treenode_df["treenode_id"].duplicated(
        keep=False)]
    treenode_series = treenode_df.set_index("treenode_id")["treenode_type"]
    return treenode_series


# %% [markdown]
# ##

# params
ids = meta.index.values
ids = [int(i) for i in ids]
nl = pymaid.get_neurons(ids)

# %% [markdown]
# ##
print("Getting connectors...")
connectors = get_connectors(nl)

explode_cols = ["postsynaptic_to", "postsynaptic_to_node"]
index_cols = np.setdiff1d(connectors.columns, explode_cols)

print("Exploding connector DataFrame...")
# explode the lists within the connectors dataframe
connectors = (connectors.set_index(list(index_cols)).apply(
    pd.Series.explode).reset_index())
# TODO figure out these nans
connectors = connectors[~connectors.isnull().any(axis=1)]
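The `set_index(...).apply(pd.Series.explode).reset_index()` idiom expands the list-valued postsynaptic columns into one row per postsynaptic partner while keeping the scalar columns aligned. A toy illustration:

import pandas as pd

df = pd.DataFrame({
    'connector_id': [1, 2],
    'presynaptic_to': [10, 11],
    'postsynaptic_to': [[20, 21], [22]],
    'postsynaptic_to_node': [[200, 210], [220]],
})
exploded = (df.set_index(['connector_id', 'presynaptic_to'])
              .apply(pd.Series.explode)
              .reset_index())
print(exploded)  # connector 1 now spans two rows, one per postsynaptic partner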
Example 14

def set_view_params(ax, azim=-90, elev=0, dist=5):
    ax.azim = azim
    ax.elev = elev
    ax.dist = dist
    set_axes_equal(ax)


# params
label = "KC"
volume_names = ["PS_Neuropil_manual"]

class_ids = meta[meta["class1"] == label].index.values
ids = [int(i) for i in class_ids]
nl = pymaid.get_neurons(class_ids)
print(f"Plotting {len(nl)} neurons for label {label}.")
connectors = get_connectors(nl)
outputs = connectors[connectors["presynaptic_to"].isin(class_ids)]
# assumes any connector whose presynaptic partner is outside class_ids is an input
inputs = connectors[~connectors["presynaptic_to"].isin(class_ids)]

# %% [markdown]
# ##
sns.set_context("talk", font_scale=1.5)
fig = plt.figure(figsize=(30, 30))
fig.suptitle(label, y=0.93)
gs = plt.GridSpec(3, 3, figure=fig, wspace=0, hspace=0)

views = ["front", "side", "top"]
view_params = [
Example 15
    def test_adjacency_matrix2(self):
        nl = pymaid.get_neurons(
            self.cn_table[self.cn_table.relation == 'upstream'].iloc[:10].skeleton_id.values)
        self.assertIsInstance(pymaid.adjacency_matrix(nl, use_connectors=True),
                              pd.DataFrame)
Example 16
    def test_adjacency_from_connectors(self):
        nl = pymaid.get_neurons(
            self.cn_table[self.cn_table.relation == 'upstream'].iloc[:10].skeleton_id.values)
        self.assertIsInstance(pymaid.adjacency_from_connectors(nl,
                                                               remote_instance=self.rm),
                              pd.DataFrame)
Example 17
    weight="weight",
)
meta = mg.meta


# %% [markdown]
# ##
start_instance()


# %% [markdown]
# #
import pandas as pd
import seaborn as sns

nl = pymaid.get_neurons(meta[meta["class1"] == "uPN"].index.values)
print(len(nl))

connectors = pymaid.get_connectors(nl)
connectors.set_index("connector_id", inplace=True)
connectors.drop(
    [
        "confidence",
        "creation_time",
        "edition_time",
        "tags",
        "creator",
        "editor",
        "type",
    ],
    inplace=True,