コード例 #1
def skeletonize_neuron_parallel(ids, n_cores=os.cpu_count() // 2, **kwargs):
    """Skeletonization on parallel cores [WIP].

    ids :       iterable
                Root IDs of neurons you want to skeletonize.
    n_cores :   int
                Number of cores to use. Don't go too crazy on this as the
                downloading of meshes becomes a bottle neck if you try to do
                too many at the same time. Keep your internet speed in
                Keyword arguments are passed on to `skeletonize_neuron`.


    if n_cores < 2 or n_cores > os.cpu_count():
        raise ValueError(
            '`n_cores` must be between 2 and max number of cores.')

    sig = inspect.signature(skeletonize_neuron)
    for k in kwargs:
        if k not in sig.parameters:
            raise ValueError('unexpected keyword argument for '
                             f'`skeletonize_neuron`: {k}')

    # Make sure IDs are all integers
    ids = np.asarray(ids).astype(int)

    # Prepare the calls and parameters
    kwargs['progress'] = False
    funcs = [skeletonize_neuron] * len(ids)
    parsed_kwargs = [kwargs] * len(ids)
    combinations = list(zip(funcs, [[i] for i in ids], parsed_kwargs))

    # Run the actual skeletonization
    with mp.Pool(n_cores) as pool:
        chunksize = 1
        res = list(

    # Check if any skeletonizations failed
    failed = np.array([r for r in res
                       if not isinstance(r, navis.TreeNeuron)]).astype(str)
    if any(failed):
        print(f'{len(failed)} neurons failed to skeletonize: '
              f'{". ".join(failed)}')

    return navis.NeuronList(
        [r for r in res if isinstance(r, navis.TreeNeuron)])
コード例 #2
def neuron_to_segments(x, dataset='production', coordinates='voxel'):
    """Get root IDs overlapping with a given neuron.

    x :                 Neuron/List
                        Neurons for which to return root IDs. Neurons must be
                        in flywire (FAFB14.1) space.
    dataset :           str | CloudVolume
                        Against which flywire dataset to query::
                            - "production" (current production dataset, fly_v31)
                            - "sandbox" (i.e. fly_v26)

    coordinates :       "voxel" | "nm"
                        Units the neuron(s) are in. "voxel" is assumed to be
                        4x4x40 (x/y/z) nanometers.

    overlap_matrix :    pandas.DataFrame
                        DataFrame of root IDs (rows) and IDs
                        (columns) with overlap in nodes as values::

                                 id     id1   id2
                            10336680915   5     0
                            10336682132   0     1

    if isinstance(x, navis.TreeNeuron):
        x = navis.NeuronList(x)

    assert isinstance(x, navis.NeuronList)

    # We must not perform this on x.nodes as this is a temporary property
    nodes = x.nodes

    # Get segmentation IDs
    nodes['root_id'] = locs_to_segments(nodes[['x', 'y', 'z']].values,

    # Count segment IDs
    seg_counts = nodes.groupby(['neuron', 'root_id'],
    seg_counts.columns = ['id', 'root_id', 'counts']

    # Remove seg IDs 0
    seg_counts = seg_counts[seg_counts.root_id != 0]

    # Turn into matrix where columns are skeleton IDs, segment IDs are rows
    # and values are the overlap counts
    matrix = seg_counts.pivot(index='root_id', columns='id', values='counts')

    return matrix
コード例 #3
ファイル: meshes.py プロジェクト: flyconnectome/fafbseg-py
def get_mesh_neuron(id, with_synapses=False, dataset='production'):
    """Fetch flywire neuron as navis.MeshNeuron.

    id  :                int | list of int
                         Segment ID(s) to fetch meshes for.
    with_synapses :      bool
                         If True, will also load a connector table with
                         synapse predicted by Buhmann et al. (2020).
                         A "synapse score" (confidence) threshold of 30 is
    dataset :            str | CloudVolume
                         Against which flywire dataset to query::
                           - "production" (currently fly_v31)
                           - "sandbox" (currently fly_v26)


    >>> from fafbseg import flywire
    >>> m = flywire.get_mesh_neuron(720575940614131061)
    >>> m.plot3d()  # doctest: +SKIP

    vol = parse_volume(dataset)

    if navis.utils.is_iterable(id):
        return navis.NeuronList([
            get_mesh_neuron(n, dataset=dataset, with_synapses=with_synapses)
            for n in navis.config.tqdm(id, desc='Fetching', leave=False)

    # Make sure the ID is integer
    id = int(id)

    # Fetch mesh
    mesh = vol.mesh.get(id, remove_duplicate_vertices=True)[id]

    # Turn into meshneuron
    n = navis.MeshNeuron(mesh, id=id, units='nm', dataset=dataset)

    if with_synapses:
        _ = fetch_synapses(n,

    return n
コード例 #4
def neuron_to_segments(x):
    """Get segment IDs overlapping with a given neuron.

    x :                 Neuron/List
                        Neurons for which to return segment IDs.

    overlap_matrix :    pandas.DataFrame
                        DataFrame of segment IDs (rows) and IDs
                        (columns) with overlap in nodes as values::

                            skeleton_id  id  3245
                            10336680915   5     0
                            10336682132   0     1

    if isinstance(x, navis.TreeNeuron):
        x = navis.NeuronList(x)

    assert isinstance(x, navis.NeuronList)

    # We must not perform this on x.nodes as this is a temporary property
    nodes = x.nodes

    # Get segmentation IDs
    nodes['seg_id'] = locs_to_segments(nodes[['x', 'y', 'z']].values,

    # Count segment IDs
    seg_counts = nodes.groupby(['neuron', 'seg_id'],
    seg_counts.columns = ['skeleton_id', 'seg_id', 'counts']

    # Remove seg IDs 0
    seg_counts = seg_counts[seg_counts.seg_id != 0]

    # Turn into matrix where columns are skeleton IDs, segment IDs are rows
    # and values are the overlap counts
    matrix = seg_counts.pivot(index='seg_id',

    return matrix
コード例 #5
ファイル: meshes.py プロジェクト: tomka/fafbseg-py
def get_mesh_neuron(id, dataset='production'):
    """Fetch flywire neuron as navis.MeshNeuron.

    id  :                int | list of int
                         Segment ID(s) to fetch meshes for.
    dataset :            str | CloudVolume
                         Against which flywire dataset to query::
                           - "production" (currently fly_v31)
                           - "sandbox" (currently fly_v26)


    >>> from fafbseg import flywire
    >>> m = flywire.get_mesh_neuron(720575940614131061)
    >>> m.plot3d()  # doctest

    vol = parse_volume(dataset)

    if navis.utils.is_iterable(id):
        return navis.NeuronList([
            get_mesh_neuron(n, dataset=dataset)
            for n in tqdm(id, desc='Fetching', leave=False)

    # Make sure the ID is integer
    id = int(id)

    # Fetch mesh
    mesh = vol.mesh.get(id, remove_duplicate_vertices=True)[id]

    # Turn into meshneuron
    return navis.MeshNeuron(mesh, id=id, units='nm', dataset=dataset)
コード例 #6
ファイル: merge.py プロジェクト: tomka/fafbseg-py
def merge_into_catmaid(x,
    """Merge neuron into target CATMAID instance.

    This function will attempt to:

        1. Find fragments in ``target_instance`` that overlap with ``x``
           using whatever segmentation data source you have set using
        2. Generate a union of these fragments and ``x``.
        3. Make a differential upload of the union leaving existing nodes
        4. Join uploaded and existing tracings into a single continuous
           neuron. This will also upload connectors but no node tags.

    x :                 pymaid.CatmaidNeuron/List | navis.TreeNeuron/List
                        Neuron(s)/fragment(s) to commit to ``target_instance``.
    target_instance :   pymaid.CatmaidInstance
                        Target Catmaid instance to commit the neuron to.
    tag :               str
                        A tag to be added as part of a ``{URL} upload {tag}``
                        annotation. This should be something identifying your
                        group - e.g. ``tag='WTCam'`` for the Cambridge Wellcome
                        Trust group.
    min_node_overlap :  int, optional
                        Minimal overlap between `x` and a potentially
                        overlapping neuron in ``target_instance``. If
                        the fragment has less total nodes than `min_overlap`,
                        the threshold will be lowered to:
                        ``min_overlap = min(min_overlap, fragment.n_nodes)``
    min_overlap_size :  int, optional
                        Minimum node count for potentially overlapping neurons
                        in ``target_instance``. Use this to e.g. exclude
                        single-node synapse orphans.
    merge_limit :       int, optional
                        Distance threshold [um] for collapsing nodes of ``x``
                        into overlapping fragments in target instance. Decreasing
                        this will help if your neuron has complicated branching
                        patterns (e.g. uPN dendrites) at the cost of potentially
                        creating duplicate parallel tracings in the neuron's
    min_upload_size :   float, optional
                        Minimum size in microns for upload of new branches:
                        branches found in ``x`` but not in the overlapping
                        neuron(s) in ``target_instance`` are uploaded in
                        fragments. Use this parameter to exclude small branches
                        that might not be worth the additional review time.
    min_upload_nodes :  int, optional
                        As ``min_upload_size`` but for number of nodes instead
                        of cable length.
    update_radii :      bool, optional
                        If True, will use radii in ``x`` to update radii of
                        overlapping fragments if (and only if) the nodes
                        do not currently have a radius (i.e. radius<=0).
    import_tags :       bool, optional
                        If True, will import node tags. Please note that this
                        will NOT import tags of nodes that have been collapsed
                        into manual tracings.
    label_joins :       bool, optional
                        If True, will label nodes at which old and new
                        tracings have been joined with tags ("Joined from ..."
                        and "Joined with ...") and with a lower confidence of
    sid_from_nodes :    bool, optional
                        If True and the to-be-merged neuron has a "skeleton_id"
                        column it will be used to set the ``source_id`` upon
                        uploading new branches. This is relevant if your neuron
                        is a virtual chimera of several neurons: in order to
                        preserve provenance (i.e. correctly associating each
                        node with a ``source_id`` origin).
    mesh :              Volume | MeshNeuron | mesh-like object | list thereof
                        Mesh representation of ``x``. If provided, will use to
                        improve merging. If ``x`` is a list of neurons, must
                        provide a mesh for each of them.

                        If all went well.
                        If something failed, returns server responses with
                        error logs.


    >>> import fafbseg
    >>> import pymaid

    >>> # Set up connections to manual and autoseg CATMAID
    >>> manual = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> # Set a segmentation data source
    >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation")

    Merge a neuron from autoseg into v14

    >>> # Fetch the autoseg neuron to transfer to v14
    >>> x = pymaid.get_neuron(267355161, remote_instance=auto)

    >>> # Get the neuron's annotations so that they can be merged too
    >>> x.get_annotations(remote_instance=auto)

    >>> # Start the commit
    >>> # See online documentation for video of merge process
    >>> resp = fafbseg.merge_neuron(x, target_instance=manual)

    if not isinstance(x, navis.NeuronList):
        if not isinstance(x, navis.TreeNeuron):
            raise TypeError('Expected TreeNeuron/List, got "{}"'.format(
        x = navis.NeuronList(x)

    if not isinstance(mesh, (np.ndarray, list)):
        if isinstance(mesh, type(None)):
            mesh = [mesh] * len(x)
            mesh = [mesh]

    if len(mesh) != len(x):
        raise ValueError(f'Got {len(mesh)} meshes for {len(x)} neurons.')

    # Make a copy - in case we make any changes to the neurons
    # (like changing duplicate skeleton IDs)
    x = x.copy()

    if not isinstance(tag, (str, type(None))):
        raise TypeError('Tag must be string, got "{}"'.format(type(tag)))

    # Check user permissions
    perm = target_instance.fetch(target_instance.make_url('permissions'))
    requ_perm = ['can_annotate', 'can_annotate_with_token', 'can_import']
    miss_perm = [
        p for p in requ_perm
        if target_instance.project_id not in perm[0].get(p, [])

    if miss_perm:
        msg = 'You lack permissions: {}. Please contact an administrator.'
        raise PermissionError(msg.format(', '.join(miss_perm)))


    # Throttle requests just to play it safe
    # On a bad connection one might have to decrease max_threads further
    target_instance.max_threads = min(target_instance.max_threads, 50)

    # For user convenience, we will do all the stuff that needs user
    # interaction first and then run the automatic merge:

    # Start by find all overlapping fragments
    overlapping = []
    for n, m in tqdm(zip(x, mesh),
                     desc='Pre-processing neuron(s)',
                     disable=not use_pbars,
        ol = find_fragments(n,

        if ol:
            # Add number of samplers to each neuron
            n_samplers = pymaid.get_sampler_counts(
                ol, remote_instance=target_instance)

            for nn in ol:
                nn.sampler_count = n_samplers[str(nn.id)]


    # Now have the user confirm merges before we actually make them
    viewer = navis.Viewer(title='Confirm merges')
    overlap_cnf = []
    base_neurons = []
        for n, ol in zip(x, overlapping):
            # This asks user a bunch of questions prior to merge and upload
            ol, bn = confirm_overlap(n, ol, viewer=viewer)
    except BaseException:

    for i, (n, ol, bn, m) in enumerate(zip(x, overlap_cnf, base_neurons,
        print(f'Processing neuron "{n.name}" ({n.id}) [{i}/{len(x)}]',
        # If no overlapping neurons proceed with just uploading.
        if not ol:
                'No overlapping fragments found. Uploading without merging...',
            resp = pymaid.upload_neuron(n,
            if 'error' in resp:
                return resp

            # Add annotations
            _ = __merge_annotations(n, resp['skeleton_id'], tag,

            msg = '\nNeuron "{}" successfully uploaded to target instance as "{}" #{}'
            print(msg.format(n.name, n.name, resp['skeleton_id']), flush=True)

        # Check if there is a duplicate skeleton ID between the to-be-merged
        # neuron and the to-merge-into neurons
        original_skid = None
        if n.id in ol.id:
            print('Fixing duplicate skeleton IDs.', flush=True)
            # Keep track of old skid
            original_skid = n.id
            # Skeleton ID must stay convertable to integer
            n.id = str(random.randint(1, 1000000))

        # Check if there are any duplicate node IDs between neuron ``x`` and the
        # overlapping fragments and create new IDs for ``x`` if necessary
        duplicated = n.nodes[n.nodes.node_id.isin(ol.nodes.node_id.values)]
        if not duplicated.empty:
            print('Duplicate node IDs found. Regenerating node tables... ',
            max_ix = max(ol.nodes.node_id.max(), n.nodes.node_id.max()) + 1
            new_ids = range(max_ix, max_ix + duplicated.shape[0])
            id_map = {
                old: new
                for old, new in zip(duplicated.node_id, new_ids)
            n.nodes['node_id'] = n.nodes.node_id.map(
                lambda n: id_map.get(n, n))
            n.nodes['parent_id'] = n.nodes.parent_id.map(
                lambda n: id_map.get(n, n))
            if n.has_connectors:
                n.connectors['node_id'] = n.connectors.node_id.map(
                    lambda n: id_map.get(n, n))
            print('Done.', flush=True)

        # Combining the fragments into a single neuron is actually non-trivial:
        # 1. Collapse nodes of our input neuron `x` into within-distance nodes
        #    in the overlapping fragments (never the other way around!)
        # 2. At the same time keep connectivity (i.e. edges) of the input-neuron
        # 3. Keep track of the nodes' provenance (i.e. the contractions)
        # In addition there are a lot of edge-cases to consider. For example:
        # - multiple nodes collapsing onto the same node
        # - nodes of overlapping fragments that are close enough to be collapsed
        #   (e.g. orphan synapse nodes)

        # Keep track of original skeleton IDs
        for a in ol + n:
            # Original skeleton of each node
            a.nodes['origin_skeletons'] = a.id
            if a.has_connectors:
                # Original skeleton of each connector
                a.connectors['origin_skeletons'] = a.id

        print('Generating union of all fragments... ', end='', flush=True)
        union, new_edges, collapsed_into = collapse_nodes(n,
        print('Done.', flush=True)

        print('Extracting new nodes to upload... ', end='', flush=True)
        # Now we have to break the neuron into "new" fragments that we can upload
        # First get the new and old nodes
        new_nodes = union.nodes[union.nodes.origin_skeletons ==
        old_nodes = union.nodes[
            union.nodes.origin_skeletons != n.id].node_id.values

        # Now remove the already existing nodes from the union
        only_new = navis.subset_neuron(union, new_nodes)

        # And then break into continuous fragments for upload
        frags = navis.break_fragments(only_new)
        print('Done.', flush=True)

        # Also get the new edges we need to generate
        to_stitch = new_edges[~new_edges.parent_id.isnull()]

        # We need this later -> no need to compute this for every uploaded fragment
        cond1b = to_stitch.node_id.isin(old_nodes)
        cond2b = to_stitch.parent_id.isin(old_nodes)

        # Now upload each fragment and keep track of new node IDs
        tn_map = {}
        for f in tqdm(frags,
                      desc='Merging new arbors',
                      disable=not use_pbars):
            # In cases of complete merging into existing neurons, the fragment
            # will have no nodes
            if f.n_nodes < 1:

            # Check if fragment is a "linker" and as such can not be skipped
            lcond1 = np.isin(f.nodes.node_id.values, new_edges.node_id.values)
            lcond2 = np.isin(f.nodes.node_id.values,

            # If not linker, check skip conditions
            if sum(lcond1) + sum(lcond2) <= 1:
                if f.cable_length < min_upload_size:
                if f.n_nodes < min_upload_nodes:

            # Collect origin info for this neuron if it's a CatmaidNeuron
            if isinstance(n, pymaid.CatmaidNeuron):
                source_info = {'source_type': 'segmentation'}

                if not sid_from_nodes or 'origin_skeletons' not in f.nodes.columns:
                    # If we had to change the skeleton ID due to duplication, make
                    # sure to pass the original skid as source ID
                    if original_skid:
                        source_info['source_id'] = int(original_skid)
                        source_info['source_id'] = int(n.id)
                    if f.nodes.origin_skeletons.unique().shape[0] == 1:
                        skid = f.nodes.origin_skeletons.unique()[0]
                            'Warning: uploading chimera fragment with multiple '
                            'skeleton IDs! Using largest contributor ID.')
                        # Use the skeleton ID that has the most nodes
                        by_skid = f.nodes.groupby('origin_skeletons').x.count()
                        skid = by_skid.sort_values(

                    source_info['source_id'] = int(skid)

                if not isinstance(getattr(n, '_remote_instance', None),
                        'source_project_id'] = n._remote_instance.project_id
                    source_info['source_url'] = n._remote_instance.server
                # Unknown source
                source_info = {}

            resp = pymaid.upload_neuron(f,

            # Stop if there was any error while uploading
            if 'error' in resp:
                return resp

            # Collect old -> new node IDs

            # Now check if we can create any of the new edges by joining nodes
            # Both treenode and parent ID have to be either existing nodes or
            # newly uploaded
            cond1a = to_stitch.node_id.isin(tn_map)
            cond2a = to_stitch.parent_id.isin(tn_map)

            to_gen = to_stitch.loc[(cond1a | cond1b) & (cond2a | cond2b)]

            # Join nodes
            for node in to_gen.itertuples():
                # Make sure our base_neuron always come out as winner on top
                if node.node_id in bn.nodes.node_id.values:
                    winner, looser = node.node_id, node.parent_id
                    winner, looser = node.parent_id, node.node_id

                # We need to map winner and looser to the new node IDs
                winner = tn_map.get(winner, winner)
                looser = tn_map.get(looser, looser)

                # And now do the join
                resp = pymaid.join_nodes(winner,

                # See if there was any error while uploading
                if 'error' in resp:
                    print('Skipping joining nodes '
                          '{} and {}: {} - '.format(node.node_id,
                    # Skip changing confidences

                # Pop this edge from new_edges and from condition
                new_edges.drop(node.Index, inplace=True)
                cond1b.drop(node.Index, inplace=True)
                cond2b.drop(node.Index, inplace=True)

                # Change node confidences at new join
                if label_joins:
                    new_conf = {looser: 1}
                    resp = pymaid.update_node_confidence(
                        new_conf, remote_instance=target_instance)

        # Add annotations
        if n.has_annotations:
            _ = __merge_annotations(n, bn, tag, target_instance)

        # Update node radii
        if update_radii and 'radius' in n.nodes.columns and np.all(
            print('Updating radii of existing nodes... ', end='', flush=True)
            resp = update_node_radii(source=n,
            print('Done.', flush=True)

            'Neuron "{}" successfully merged into target instance as "{}" #{}'.
            format(n.name, bn.name, bn.id),

コード例 #7
def l2_skeleton(root_id, refine=False, drop_missing=True,
                threads=10, progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    root_id  :          int | list of ints
                        Root ID(s) of the flywire neuron(s) you want to
    refine :            bool
                        If True, will refine skeleton nodes by moving them in
                        the center of their corresponding chunk meshes.

    Only relevant if ``refine=True``:

    drop_missing :      bool
                        If True, will drop nodes that don't have a corresponding
                        chunk mesh. These are typically chunks that are very
                        small and dropping them might actually be benefitial.
    threads :           int
                        How many parallel threads to use for fetching the
                        chunk meshes. Reduce the number if you run into
                        ``HTTPErrors``. Only relevant if `use_flycache=False`.
    progress :          bool
                        Whether to show a progress bar.

    skeleton :          navis.TreeNeuron
                        The extracted skeleton.

    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma

    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        nl = []
        for id in navis.config.tqdm(root_id, desc='Skeletonizing',
                                    disable=not progress, leave=False):
            n = l2_skeleton(id, refine=refine, drop_missing=drop_missing,
                            threads=threads, progress=progress, dataset=dataset)
        return navis.NeuronList(nl)

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N,2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l)) for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchal tree by removing cycles and
    # ensuring all edges point towards a root
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidian space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids,

            # Drop missing (i.e. [0,0,0]) meshes
            centroids = {k: v for k, v in centroids.items() if v != [0, 0, 0]}
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads, progress=progress)

        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
        swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
        swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    return tn
コード例 #8
def find_missed_branches(x,
    """Use autoseg to find (and annotate) potential missed branches.

    x :                 pymaid.CatmaidNeuron/List
                        Neuron(s) to search for missed branches.
    autoseg_instance :  pymaid.CatmaidInstance
                        CATMAID instance containing the autoseg skeletons.
    tag :               bool, optional
                        If True, will tag nodes of ``x`` that might have missed
                        branches with "missed branch?".
    tag_size_thresh :   int, optional
                        Size threshold in microns of cable for tagging
                        potentially missed branches.
    min_node_overlap :  int, optional
                        Minimum number of nodes that input neuron(s) x must
                        overlap with given segmentation ID for it to be
                        Keyword arguments passed to

    summary :           pandas.DataFrame
                        DataFrame containing a summary of potentially missed

                        If input is a single neuron:

    fragments :         pymaid.CatmaidNeuronList
                        Fragments found to be potentially overlapping with the
                        input neuron.
    branches :          pymaid.CatmaidNeuronList
                        Potentially missed branches extracted from ``fragments``.


    >>> import fafbseg
    >>> import pymaid

    >>> # Set up connections to manual and autoseg CATMAID
    >>> manual = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')
    >>> auto = pymaid.CatmaidInstance('URL', 'HTTP_USER', 'HTTP_PW', 'API_TOKEN')

    >>> # Set a source for segmentation data
    >>> fafbseg.use_google_storage("https://storage.googleapis.com/fafb-ffn1-20190805/segmentation")

    Find missed branches and tag them

    >>> # Fetch a neuron
    >>> x = pymaid.get_neuron(16, remote_instance=manual)
    >>> # Find and tag missed branches
    >>> (summary,
    ...  fragments,
    ...  branches) = fafbseg.find_missed_branches(x, autoseg_instance=auto)

    >>> # Show summary of missed branches
    >>> summary.head()
       n_nodes  cable_length   node_id
    0      110     28.297424   3306395
    1       90     23.976504  20676047
    2       64     15.851333  23419997
    3       29      7.494350   6298769
    4       16      3.509739  15307841

    >>> # Co-visualize your neuron and potentially overlapping autoseg fragments
    >>> x.plot3d(color='w')
    >>> fragments.plot3d()

    >>> # Visualize the potentially missed branches
    >>> pymaid.clear3d()
    >>> x.plot3d(color='w')
    >>> branches.plot3d(color='r')

    if isinstance(x, navis.NeuronList):
        to_concat = []
        for n in tqdm(x,
                      desc='Processing neurons',
                      disable=not use_pbars,
            (summary, frags, branches) = find_missed_branches(
            summary['skeleton_id'] = n.id

        return pd.concat(to_concat, ignore_index=True)
    elif not isinstance(x, navis.TreeNeuron):
        raise TypeError(f'Input must be TreeNeuron/List, got "{type(x)}"')

    # Find autoseg neurons overlapping with input neuron
    nl = find_autoseg_fragments(x,

    # Next create a union
    if not nl.empty:
        for n in nl:
            n.nodes['origin'] = 'autoseg'
            n.nodes['origin_skid'] = n.skeleton_id

        # Create a simple union
        union = navis.stitch_neurons(nl, method='NONE')

        # Merge into target neuron
        union, new_edges, clps_map = move.merge_utils.collapse_nodes(union,

        # Subset to autoseg nodes
        autoseg_nodes = union.nodes[union.nodes.origin ==
        autoseg_nodes = np.empty((0, 5))

    # Process fragments if any autoseg nodes left
    data = []
    frags = navis.NeuronList([])
    if autoseg_nodes.shape[0]:
        autoseg = navis.subset_neuron(union, autoseg_nodes)

        # Split into fragments
        frags = navis.break_fragments(autoseg)

        # Generate summary
        nodes = union.nodes.set_index('node_id')
        for n in frags:
            # Find parent node in union
            pn = nodes.loc[n.root[0], 'parent_id']
            pn_co = nodes.loc[pn, ['x', 'y', 'z']].values
            org_skids = n.nodes.origin_skid.unique().tolist()
            data.append([n.n_nodes, n.cable_length, pn, pn_co, org_skids])

    df = pd.DataFrame(data,
                          'n_nodes', 'cable_length', 'node_id', 'node_loc',
    df.sort_values('cable_length', ascending=False, inplace=True)

    if tag and not df.empty:
        to_tag = df[df.cable_length >= tag_size_thresh].node_id.values

        resp = pymaid.add_tags(to_tag,
                               tags='missed branch?',

        if 'error' in resp:
            return df, resp

    return df, nl, frags
コード例 #9
ファイル: synapses.py プロジェクト: tomka/fafbseg-py
def get_neuron_connections(sources,
    """Fetch connections between sets of neurons.

    Works by:
     1. Fetch segment IDs corresponding to given neurons
     2. Fetch Buhmann et al. synaptic connections between them
     3. Do some clean-up. See parameters for details - defaults are those used
        by Buhmann et al

    Parameters
    ----------
    sources :       navis.Neuron | pymaid.CatmaidNeuron | NeuronList
                    Presynaptic neurons to fetch connections for.
    targets :       navis.Neuron | pymaid.CatmaidNeuron | NeuronList | None
                    Postsynaptic neurons to fetch connections for. If ``None``,
                    ``targets = sources``.
    agglomerate :   bool
                    If True, will agglomerate connectivity by ID and return a
                    weighted edge list. If False, will return a list of
                    individual synapses.
    ol_thresh :     int, optional
                    If provided, will require a minimum node overlap between
                    a neuron and a segment for that segment to be included
                    (see step 1 above).
    score_thresh :  int, optional
                    If provided will only return synapses with a cleft score of
                    this or higher.
    dist_thresh :   int
                    Synapses further away than the given distance [nm] from any
                    node in the neuron will be discarded. This is applied to
                    both the pre- and the postsynaptic site.
    drop_duplicates :   bool
                        If True, will merge synapses which connect the same
                        pair of pre-post segmentation IDs and whose
                        presynaptic sites are less than 250nm apart.
    drop_autapses :     bool
                        If True, will drop synapses whose pre- and
                        postsynaptic segments were attributed to the same
                        neuron.
    db :            str, optional
                    Must point to SQL database containing the synapse data. If
                    not provided will look for a ``BUHMANN_SYNAPSE_DB``
                    environment variable.

    Returns
    -------
    pandas.DataFrame
                    Either edge list or list of synapses - see ``agglomerate``.
                    The edge list is sorted by descending ``weight``.

    """
    # NOTE(review): this copy of the function is missing lines (the rest of
    # the signature, the closing lines of several multi-line calls such as
    # `query_connections(...)`, and the body of the "no more pairs" `if` in
    # the de-duplication loop). Compare against the upstream fafbseg-py source
    # before relying on it.

    # This is just so we throw a DB exception early and not wait for fetching
    # segment Ids first
    _ = get_connection(db)

    # NOTE(review): `assert` is stripped under `python -O`; raising a
    # TypeError would make this input validation unconditional.
    assert isinstance(sources, (navis.BaseNeuron, navis.NeuronList))

    if isinstance(targets, type(None)):
        targets = sources

    assert isinstance(targets, (navis.BaseNeuron, navis.NeuronList))

    # Wrap single neurons in NeuronLists so both sides are handled alike
    if not isinstance(sources, navis.NeuronList):
        sources = navis.NeuronList([sources])
    if not isinstance(targets, navis.NeuronList):
        targets = navis.NeuronList([targets])

    # Get segments for this neuron(s)
    # `seg_ids` is used below as a table indexed by segment ID with one column
    # per neuron ID holding that neuron's overlap with the segment.
    unique_neurons = (sources + targets).remove_duplicates(key='id')
    seg_ids = google.neuron_to_segments(unique_neurons)

    # Drop segments with overlap below threshold
    if ol_thresh:
        seg_ids = seg_ids.loc[seg_ids.max(axis=1) >= ol_thresh]

    # We need to make sure that every segment ID is only attributed to a single
    # neuron.
    # Despite its name, `is_top` is True wherever a neuron does NOT have the
    # highest overlap for that segment (row). Those entries are zeroed so only
    # the best-overlapping neuron(s) keep a positive count; ties survive.
    is_top = seg_ids.values != seg_ids.max(axis=1).values.reshape(
        (seg_ids.shape[0], 1))
    where_top = np.where(is_top)
    # NOTE(review): writing through `DataFrame.values` only changes `seg_ids`
    # when `.values` returns a view (single dtype, no pandas copy-on-write) -
    # confirm for the pandas version in use.
    seg_ids.values[where_top] = 0  # do not change this to None

    # Segment IDs still attributed to at least one source / target neuron
    pre_ids = seg_ids.loc[seg_ids[sources.id].max(axis=1) > 0].index.values
    post_ids = seg_ids.loc[seg_ids[targets.id].max(axis=1) > 0].index.values

    # Fetch pre- and postsynapses associated with these segments
    # It's cheaper to get them all in one go
    syn = query_connections(pre_ids,

    # Now associate synapses with neurons: each segment ID maps to the column
    # (neuron ID) picked by `np.argmax` (presumably row-wise), i.e. the first
    # neuron with the largest overlap for that segment.
    seg2neuron = dict(
        zip(seg_ids.index.values, seg_ids.columns[np.argmax(seg_ids.values,

    syn['id_pre'] = syn.segmentid_pre.map(seg2neuron)
    syn['id_post'] = syn.segmentid_post.map(seg2neuron)

    # Let the clean-up BEGIN!

    # First drop nasty autapses, i.e. synapses whose pre- and postsynaptic
    # segments were attributed to the same neuron
    if drop_autapses:
        syn = syn[syn.id_pre != syn.id_post]

    # Next drop synapses far away from our neurons
    if dist_thresh:
        syn['pre_close'] = False
        syn['post_close'] = False
        # Note: the loop variable `id` shadows the builtin inside this block
        for id in np.unique(syn[['id_pre', 'id_post']].values.flatten()):
            neuron = unique_neurons.idx[id]
            tree = navis.neuron2KDTree(neuron)

            # Presynaptic sites belonging to this neuron. A scipy cKDTree query
            # reports an infinite distance for points with no neighbour inside
            # `distance_upper_bound`, so `dist < inf` marks sites close to the
            # neuron. Assumes the (missing) closing argument of the call passes
            # `distance_upper_bound=dist_thresh` - TODO confirm.
            is_pre = syn.id_pre == id
            if np.any(is_pre):
                dist, ix = tree.query(
                    syn.loc[is_pre, ['pre_x', 'pre_y', 'pre_z']].values,
                syn.loc[is_pre, 'pre_close'] = dist < float('inf')

            # Same check for postsynaptic sites belonging to this neuron
            is_post = syn.id_post == id
            if np.any(is_post):
                dist, ix = tree.query(
                    syn.loc[is_post, ['post_x', 'post_y', 'post_z']].values,
                syn.loc[is_post, 'post_close'] = dist < float('inf')

        # Drop connections where either pre- or postsynaptic site are too far
        # away from the neuron
        syn = syn[syn.pre_close & syn.post_close]

    # Drop duplicate connections, i.e. connections that connect the same pre-
    # and postsynaptic segmentation ID and are within a distance of 250nm
    dupl_thresh = 250
    if drop_duplicates:
        # We are dealing with this from a presynaptic perspective
        while True:
            # `pairs` is an (N, 2) array of positional row indices of synapses
            # whose presynaptic sites lie within `dupl_thresh` of each other
            pre_tree = cKDTree(syn[['pre_x', 'pre_y', 'pre_z']].values)
            pairs = pre_tree.query_pairs(r=dupl_thresh, output_type='ndarray')

            # For each pair check whether both synapses connect the same pre-
            # and postsynaptic segmentation IDs
            same_pre = syn.iloc[pairs[:, 0]].segmentid_pre.values == syn.iloc[
                pairs[:, 1]].segmentid_pre.values
            same_post = syn.iloc[pairs[:,
                                       0]].segmentid_post.values == syn.iloc[
                                           pairs[:, 1]].segmentid_post.values
            same_cn = same_pre & same_post

            # If no more pairs to collapse break
            if same_cn.sum() == 0:

            # Generate a graph from pairs
            # NOTE(review): the graph is created empty here; presumably the
            # (missing) next line adds the `pairs[same_cn]` edges.
            G = nx.Graph()

            # Find the minimum number of nodes we need to remove
            # to separate the connectors. Assuming graph nodes are the
            # positional indices from `pairs`, they are translated into index
            # labels for `drop` below.
            to_rm = []
            for cn in nx.connected_components(G):
                to_rm += list(nx.minimum_node_cut(nx.subgraph(G, cn)))
            syn = syn.drop(index=syn.index.values[to_rm])

    if agglomerate:
        # Collapse individual synapses into one row per (id_pre, id_post)
        # pair; the aggregated `cleft_id` column is renamed to `weight` and
        # the edge list is sorted by descending weight.
        edges = syn.groupby(['id_pre', 'id_post'],
        edges.rename({'cleft_id': 'weight'}, axis=1, inplace=True)
        edges.sort_values('weight', ascending=False, inplace=True)
        return edges.reset_index(drop=True)

    return syn
コード例 #10
ファイル: synapses.py プロジェクト: tomka/fafbseg-py
def get_neuron_synapses(x,
    """Fetch synapses for a given neuron.

    Works by:
     1. Fetch segment IDs corresponding to a neuron
     2. Fetch Buhmann et al. synapses associated with these segment IDs
     3. Do some clean-up. See parameters for details.

    Notes
    -----
    1. If ``collapse_connectors=False`` the ``connector_id`` column is a
       simple enumeration and is effectively meaningless.
    2. The x/y/z coordinates always correspond to the presynaptic site (like
       with CATMAID connectors). These columns are renamed from "pre_{x|y|z}"
       in the database.
    3. Some of the clean-up assumes that the query neurons are unique neurons
       and not fragments of the same neuron. If that is the case, you might be
       better off running this function for each fragment individually.

    Parameters
    ----------
    x :             navis.Neuron | pymaid.CatmaidNeuron | NeuronList
                    Neurons to fetch synapses for.
    pre/post :      bool
                    Whether to fetch pre- and/or postsynapses.
    collapse_connectors : bool
                    If True, we will pool presynaptic connections into
                    CATMAID-like connectors. If False each row represents a
                    single synapse.
    ol_thresh :     int, optional
                    If provided, will require a minimum node overlap between
                    a neuron and a segment for that segment to be included
                    (see step 1 above).
    score_thresh :  int, optional
                    If provided will only return synapses with a cleft score of
                    this or higher.
    dist_thresh :   int
                    Synapses further away than the given distance [nm] from any
                    node in the neuron will be discarded. The distance is
                    measured from the presynaptic site (the ``x/y/z``
                    columns) irrespective of whether our neuron is pre- or
                    postsynaptic to it.
    attach :        bool
                    If True, will attach synapses as `.connectors` to neurons.
                    If False, will return DataFrame with synapses.
    drop_autapses : bool
                    If True, will drop synapses where the pre- and
                    postsynaptic sites map to the same segmentation ID.
                    Autapses are wrong in 99.9% of all cases we've seen.
    drop_duplicates :   bool
                        If True, will merge synapses which connect the same
                        pair of pre-post segmentation IDs and whose pre- or
                        postsynaptic sites are less than 250nm apart.
    db :            str, optional
                    Must point to SQL database containing the synapse data. If
                    not provided will look for a `BUHMANN_SYNAPSE_DB`
                    environment variable.
    ret :           "catmaid" | "brief" | "full"
                    If "full" will return all synapse properties. If "brief"
                    will return more relevant subset. If "catmaid" will return
                    only CATMAID-like columns.
    progress :      bool
                    Whether to show progress bars or not.

    Returns
    -------
    pandas.DataFrame
                    Only if ``attach=False``. Otherwise the synapses are
                    attached to the neurons in ``x`` as ``.connectors``.

    """
    # NOTE(review): this copy of the function is missing lines (the rest of
    # the signature, the closing lines of several multi-line calls, the
    # `break` in the de-duplication loop, and whatever appends to `tables` -
    # nothing visible does). Compare against the upstream fafbseg-py source
    # before relying on it.

    # This is just so we throw a DB exception early and not wait for fetching
    # segment Ids first
    _ = get_connection(db)

    # NOTE(review): `assert` is stripped under `python -O`; raising would make
    # this input validation unconditional.
    assert isinstance(x, (navis.BaseNeuron, navis.NeuronList))
    assert ret in ("catmaid", "brief", "full")

    if not isinstance(x, navis.NeuronList):
        x = navis.NeuronList([x])

    # Get segments for this neuron(s)
    # (used below as a table indexed by segment ID, one column per neuron)
    seg_ids = google.neuron_to_segments(x)

    # Drop segments with overlap below threshold
    if ol_thresh:
        seg_ids = seg_ids.loc[seg_ids.max(axis=1) >= ol_thresh]

    # We will make sure that every segment ID is only attributed to a single
    # neuron: entries where a neuron does not have the highest overlap for
    # that segment (row) are blanked out.
    not_top = seg_ids.values != seg_ids.max(axis=1).values.reshape(
        (seg_ids.shape[0], 1))
    where_not_top = np.where(not_top)
    # NOTE(review): `where_not_top[0]` holds row positions, so `np.any` is
    # False when every non-top entry sits in row 0 and the blanking is then
    # skipped; `where_not_top[0].size` would test "any entries found".
    if np.any(where_not_top[0]):
        # None (not 0) so that `.notnull()` in the per-neuron loop below
        # excludes these entries. NOTE(review): numpy refuses to store None in
        # an integer array - presumably the overlap table is float; confirm.
        seg_ids.values[where_not_top] = None  # do not change this to 0

    # Fetch pre- and postsynapses associated with these segments
    # It's cheaper to get them all in one go
    syn = query_synapses(seg_ids.index.values,
                         ret=ret if ret != 'catmaid' else 'brief',

    # Drop autapses - they are most likely wrong. Note this compares
    # segmentation IDs, not neurons.
    if drop_autapses:
        syn = syn[syn.segmentid_pre != syn.segmentid_post]

    if drop_duplicates:
        dupl_thresh = 250
        # Deal with pre- and postsynapses separately
        while True:
            # Generate pairs of pre- and postsynaptic coordinates that are
            # suspiciously close
            pre_tree = cKDTree(syn[['pre_x', 'pre_y', 'pre_z']].values)
            post_tree = cKDTree(syn[['post_x', 'post_y', 'post_z']].values)
            pre_pairs = pre_tree.query_pairs(r=dupl_thresh)
            post_pairs = post_tree.query_pairs(r=dupl_thresh)

            # We will consider pairs for removal where both pre- OR postsynapse
            # are close - this is easy to change via below operator
            pairs = pre_pairs | post_pairs  # union of both
            # NOTE(review): if no pairs are found this is a 1-D empty array
            # and `pairs[:, 0]` below raises IndexError;
            # `np.array(list(pairs), dtype=int).reshape(-1, 2)` would keep it
            # 2-D.
            pairs = np.array(list(pairs))

            # For each pair check if they connect the same IDs
            same_pre = syn.iloc[pairs[:, 0]].segmentid_pre.values == syn.iloc[
                pairs[:, 1]].segmentid_pre.values
            same_post = syn.iloc[pairs[:,
                                       0]].segmentid_post.values == syn.iloc[
                                           pairs[:, 1]].segmentid_post.values
            same_cn = same_pre & same_post

            # If no more pairs to collapse break
            if same_cn.sum() == 0:

            # Generate a graph from pairs
            # NOTE(review): the graph is created empty here; presumably the
            # (missing) next line adds the `pairs[same_cn]` edges.
            G = nx.Graph()

            # Find the minimum number of nodes we need to remove
            # to separate the connectors
            to_rm = []
            for cn in nx.connected_components(G):
                to_rm += list(nx.minimum_node_cut(nx.subgraph(G, cn)))

            # Drop those nodes (positional indices -> index labels)
            syn = syn.drop(index=syn.index.values[to_rm])
        # Reset index
        syn.reset_index(drop=True, inplace=True)

    # NOTE(review): as written, enumerated IDs are assigned when
    # `collapse_connectors` is True, which makes the
    # `drop_duplicates('connector_id')` further down a no-op and contradicts
    # note 1 of the docstring; an `else:` branch may be missing here.
    if collapse_connectors:
        # Make fake IDs
        syn['connector_id'] = np.arange(syn.shape[0]).astype(np.int32)

    # Now associate synapses with neurons
    tables = []
    for c in tqdm(seg_ids.columns,
                  desc='Proc. neurons',
                  disable=not progress or seg_ids.shape[1] == 1,
        # Segments (still) attributed to neuron `c`
        this_segs = seg_ids.loc[seg_ids[c].notnull(), c]
        is_pre = syn.segmentid_pre.isin(this_segs.index.values)
        is_post = syn.segmentid_post.isin(this_segs.index.values)

        # At this point we might see the exact same connection showing up in
        # `is_pre` and in `is_post`. This happens when we mapped both the
        # pre- and the postsynaptic segment to this neuron - likely an error.
        # In these cases we have to decide whether our neuron is truly pre-
        # or postsynaptic. For this we will use the overlap counts:
        # First find connections that would show up twice
        is_dupl = is_pre & is_post
        if any(is_dupl):
            dupl = syn[is_dupl]
            # Next get the overlap counts for the pre- and postsynaptic seg IDs
            dupl_pre_ol = seg_ids.loc[dupl.segmentid_pre, c].values
            dupl_post_ol = seg_ids.loc[dupl.segmentid_post, c].values

            # We go for the one with more overlap (ties count as post)
            true_pre = dupl_pre_ol > dupl_post_ol

            # Propagate that decision
            is_pre[is_dupl] = true_pre
            is_post[is_dupl] = ~true_pre

        # Now get our synapses
        this_pre = syn[is_pre]
        this_post = syn[is_post]

        # Keep only one connector per presynapse
        # -> just like in CATMAID connector tables
        # Postsynaptic connectors will still show up multiple times
        if collapse_connectors:
            this_pre = this_pre.drop_duplicates('connector_id')

        # Combine pre- and postsynapses and keep track of the type: every row
        # starts as 'post' and (presumably) the leading `this_pre` rows are
        # then marked 'pre' - the start of that assignment is missing here
        connectors = pd.concat([this_pre, this_post],
        connectors['type'] = 'post'
                        connectors.columns.get_loc('type')] = 'pre'
        connectors['type'] = connectors['type'].astype('category')

        # Rename columns such that x/y/z corresponds to presynaptic sites
            'pre_x': 'x',
            'pre_y': 'y',
            'pre_z': 'z'

        # For CATMAID-like connector tables subset to relevant columns
        if ret == 'catmaid':
            connectors = connectors[[
                'connector_id', 'x', 'y', 'z', 'cleft_scores', 'type'

        # Map connectors to nodes
        # Note that this is where we enforce `dist_thresh`: connectors with no
        # node within the query's distance bound get an infinite distance
        # (scipy cKDTree behaviour) and are dropped below.
        neuron = x.idx[c]
        tree = navis.neuron2KDTree(neuron)
        dist, ix = tree.query(connectors[['x', 'y', 'z']].values,

        # Drop far away connectors
        connectors = connectors.loc[dist < np.inf]

        # Assign node IDs
        connectors['node_id'] = neuron.nodes.iloc[ix[
            dist < np.inf]].node_id.values

        # Somas can end up having synapses, which we know is wrong and is
        # relatively easy to fix
        if np.any(neuron.soma):
            somata = navis.utils.make_iterable(neuron.soma)
            s_locs = neuron.nodes.loc[neuron.nodes.node_id.isin(somata),
                                      ['x', 'y', 'z']].values
            # Find all nodes within 2 micron around the somas
            # (r=2000 assumes coordinates in nm)
            soma_node_ix = tree.query_ball_point(s_locs, r=2000)
            # Flatten the per-soma lists of node positions
            soma_node_ix = [n for l in soma_node_ix for n in l]
            soma_node_id = neuron.nodes.iloc[soma_node_ix].node_id.values

            # Drop connectors attached to these soma nodes
            connectors = connectors[~connectors.node_id.isin(soma_node_id)]

        if attach:
            neuron.connectors = connectors.reset_index(drop=True)
            # NOTE(review): as shown this line sits in the `attach` branch; an
            # `else:` (and a `tables.append(...)`) appears to be missing.
            connectors['neuron'] = neuron.id  # do NOT change the type of this

    # Stack the per-neuron tables into a single DataFrame
    if not attach:
        connectors = pd.concat(tables, axis=0,
        return connectors
コード例 #11
def skeletonize_neuron(x,
    """Skeletonize FlyWire neuron.

    Note that this is optimized to be primarily fast which comes at the cost
    of (some) quality.

    Parameters
    ----------
    x  :                 int | trimesh.TriMesh | list thereof
                         ID(s) or trimesh of the FlyWire neuron(s) you want to
                         skeletonize. For a mesh, the segment ID is taken from
                         its ``segid`` attribute (0 if absent).
    shave_skeleton :     bool
                         If True, we will "shave" the skeleton by removing all
                         single-node terminal twigs. This should get rid of
                         hairs on the backbone that can occur if the neurites
                         are very big.
    remove_soma_hairball : bool
                         If True, we will try to drop the hairball that is
                         typically created inside the soma. Note that while this
                         should work just fine for 99% of neurons, it's not very
                         smart and there is always a chance that we remove stuff
                         that should not have been removed. Also only works if
                         the neuron has a recognizable soma.
    assert_id_match :    bool
                         If True, will check if skeleton nodes map to the
                         correct segment ID and if not will move them back into
                         the segment. This is potentially very slow! Requires
                         a non-zero segment ID.
    dataset :            str | CloudVolume
                         Against which FlyWire dataset to query::
                           - "production" (current production dataset, fly_v31)
                           - "sandbox" (i.e. fly_v26)
    progress :           bool
                         Whether to show a progress bar or not.

    Returns
    -------
    skeleton :          navis.TreeNeuron
                        The extracted skeleton (node coordinates in nm). A
                        ``navis.NeuronList`` of skeletons if ``x`` is
                        iterable.

    Raises
    ------
    ImportError
                        If the installed skeletor is older than 1.0.0.
    ValueError
                        If ``assert_id_match=True`` and the segment ID is 0.

    See Also
    --------
    skeletonize_neuron_parallel
                        Use this if you want to skeletonize many neurons in
                        parallel.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.skeletonize_neuron(720575940614131061)

    """
    # NOTE(review): this copy of the function is missing lines (the rest of
    # the signature, the per-item call inside the list comprehension for
    # iterable input, the `else:` before `mesh = x`, and the closing lines of
    # several calls). Compare against the upstream fafbseg-py source.

    if int(sk.__version__.split('.')[0]) < 1:
        raise ImportError('Please update skeletor to version >= 1.0.0: '
                          'pip3 install skeletor -U')

    # NOTE(review): `vol` is parsed again below for ID input; this first call
    # also runs for mesh and list input and looks redundant.
    vol = parse_volume(dataset)

    # Iterables are skeletonized item by item and returned as a NeuronList
    if navis.utils.is_iterable(x):
        return navis.NeuronList([
            for n in navis.config.tqdm(
                x, desc='Skeletonizing', disable=not progress, leave=False)

    if not navis.utils.is_mesh(x):
        vol = parse_volume(dataset)

        # Make sure this is a valid integer
        # (note: `id` shadows the builtin for the rest of this function)
        id = int(x)

        # Download the mesh
        mesh = vol.mesh.get(id,
        mesh = x
        id = getattr(mesh, 'segid', 0)

    mesh = sk.utilities.make_trimesh(mesh, validate=False)

    # Fix things before we skeletonize
    # This also drops fluff
    mesh = sk.pre.fix_mesh(mesh, inplace=True, remove_disconnected=100)

    # Skeletonize
    defaults = dict(waves=1, step_size=1)
    s = sk.skeletonize.by_wavefront(mesh, progress=progress, **defaults)

    # Skeletor indexes node IDs at zero but to avoid potential issues we want
    # node IDs to start at 1. Only non-negative parent IDs are shifted, so the
    # root's negative "no parent" marker is left alone.
    s.swc['node_id'] += 1
    s.swc.loc[s.swc.parent_id >= 0, 'parent_id'] += 1

    # We will also round the radius and make it an integer to save some
    # memory. We could do the same with x/y/z coordinates but that could
    # potentially move nodes outside the mesh
    s.swc['radius'] = s.swc.radius.round().astype(int)

    # Turn into a neuron
    tn = navis.TreeNeuron(s.swc, units='1 nm', id=id, soma=None)

    if shave_skeleton:
        # Get branch points
        bp = tn.nodes.loc[tn.nodes.type == 'branch', 'node_id'].values

        # Get single-node twigs: end nodes whose parent is a branch point
        is_end = tn.nodes.type == 'end'
        parent_is_bp = tn.nodes.parent_id.isin(bp)
        twigs = tn.nodes.loc[is_end & parent_is_bp, 'node_id'].values

        # Drop terminal twigs
        # NOTE(review): assigning the private `_nodes` may bypass whatever
        # cache invalidation navis' public `nodes` setter performs - confirm.
        tn._nodes = tn.nodes.loc[~tn.nodes.node_id.isin(twigs)].copy()

    # See if we can find a soma
    soma = detect_soma_skeleton(tn, min_rad=800, N=3)
    if soma:
        tn.soma = soma

        # Reroot to soma
        tn.reroot(tn.soma, inplace=True)

        if remove_soma_hairball:
            # From here on `soma` is the soma node's row; since the table is
            # indexed by node ID, `soma.name` is the soma's node ID. Assumes
            # `detect_soma_skeleton` returned a single node ID - TODO confirm.
            soma = tn.nodes.set_index('node_id').loc[soma]
            soma_loc = soma[['x', 'y', 'z']].values

            # Find all nodes within 2x the soma radius, but at least within
            # 4000 nm (4 microns) of the soma
            tree = navis.neuron2KDTree(tn)
            ix = tree.query_ball_point(soma_loc, max(4000, soma.radius * 2))

            # Translate indices into node IDs
            ids = tn.nodes.iloc[ix].node_id.values

            # Find segments that contain these nodes
            segs = [s for s in tn.segments if any(np.isin(ids, s))]

            # Sort segs by length
            segs = sorted(segs, key=lambda x: len(x))

            # Keep only the longest segment in that initial list: drop the
            # nodes of all shorter segments, except nodes that also lie on the
            # longest segment and the soma node itself
            to_drop = np.array([n for s in segs[:-1] for n in s])
            to_drop = to_drop[~np.isin(to_drop, segs[-1] + [soma.name])]

            navis.remove_nodes(tn, to_drop, inplace=True)

    # Optionally snap nodes back into the segment they should belong to
    if assert_id_match:
        if id == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(tn.nodes[['x', 'y', 'z']].values,
        tn.nodes[['x', 'y', 'z']] = new_locs

    return tn
コード例 #12
ファイル: merge_utils.py プロジェクト: tomka/fafbseg-py
def collapse_nodes(A, B, limit=1, base_neuron=None, mesh=None):
    """Merge neuron A into neuron(s) B creating a union of both.

    This implementation uses edge contraction on the neurons' graph to ensure
    maximum connectivity. Only works if the fragments collectively form a
    continuous tree (i.e. you must be certain that they partially overlap).

    Note that this modifies the inputs: an ``origin_skeletons`` column is added
    to the node tables of ``A`` and ``B``, and ``B`` is reordered so that
    ``base_neuron`` comes first.

    Parameters
    ----------
    A :                 navis.TreeNeuron | NeuronList
                        Neuron to be collapsed into neurons B. A single-neuron
                        NeuronList is unwrapped; longer lists are stitched
                        into one neuron.
    B :                 navis.TreeNeuron | NeuronList
                        Neurons to collapse neuron A into.
    limit :             int, optional
                        Max distance [microns] for nearest neighbour search.
                        Converted to nm internally, i.e. coordinates are
                        assumed to be in nm.
    base_neuron :       navis.TreeNeuron, optional
                        Neuron from B to use as template for union. If not
                        provided, the first neuron in the list is used as
                        template. Must be the neuron object itself: it is
                        matched against ``B`` by equality and its ``.root``
                        is used, so a bare skeleton ID is not supported as
                        written.
    mesh :              navis.Volume | navis.MeshNeuron | mesh-like object
                        If provided, will use the mesh to check if nodes are
                        in line of sight to each other before collapsing them.
                        Must expose ``.vertices`` and ``.faces``.

    Returns
    -------
    union :             navis.TreeNeuron
                        Union of all input neurons.
    new_edges :         pandas.DataFrame
                        Subset of the ``.nodes`` table that represent newly
                        added edges.
    collapsed_nodes :   dict
                        Map of collapsed nodes::

                            NodeA -collapsed-into-> NodeB

    Raises
    ------
    TypeError
                        If ``A`` or ``B`` are not TreeNeurons/NeuronLists.
    ValueError
                        If skeleton IDs or node IDs are not unique.

    """
    # NOTE(review): this copy of the function is missing lines (the `else:`
    # before stitching a multi-neuron `A`, the closing lines of several
    # multi-line calls, the fragment construction inside the
    # connected-components loop - `fragments` stays empty as shown - the
    # `else:` before `union = fragments[0]`, and the tag merging). Compare
    # against the upstream fafbseg-py source.

    if isinstance(A, navis.NeuronList):
        if len(A) == 1:
            A = A[0]
            A = navis.stitch_neurons(A, method="NONE")

    if not isinstance(A, navis.TreeNeuron):
        raise TypeError('`A` must be a TreeNeuron, got "{}"'.format(type(A)))

    if isinstance(B, navis.TreeNeuron):
        B = navis.NeuronList(B)

    if not isinstance(B, navis.NeuronList):
        raise TypeError('`B` must be a NeuronList, got "{}"'.format(type(B)))

    if isinstance(base_neuron, type(None)):
        base_neuron = B[0]

    # This is just check on the off-chance that skeleton IDs are not unique
    # (e.g. if neurons come from different projects) -> this is relevant because
    # each node's origin is tracked by skeleton ID (`origin_skeletons` below),
    # which is how the newly added edges are identified at the end
    skids = [n.id for n in B + A]
    if len(skids) > len(set(skids)):
        raise ValueError(
            'Duplicate skeleton IDs found. Try manually assigning '
            'unique skeleton IDs.')

    # Convert distance threshold from microns to nanometres
    limit *= 1000

    # Before we start messing around, let's make sure we can keep track of
    # the origin of each node (this modifies the input neurons' node tables)
    for n in B + A:
        n.nodes['origin_skeletons'] = n.id

    # First make a weak union by simply combining the node tables.
    # The sort key puts `base_neuron` first (0; everything else 2) so that
    # `master='FIRST'` uses it; matching is by equality with neuron objects.
    B.neurons = sorted(B.neurons, key=lambda x: {base_neuron: 0}.get(x, 2))
    union_simple = navis.stitch_neurons(B + A, method='NONE', master='FIRST')

    # Check for duplicate node IDs
    if any(union_simple.nodes.node_id.duplicated()):
        raise ValueError('Duplicate node IDs found.')

    # Find nodes in A to be merged into B
    # (`B.nodes` presumably concatenates the node tables of all neurons in B)
    tree = scipy.spatial.cKDTree(data=B.nodes[['x', 'y', 'z']].values)

    # For each node in A get the nearest neighbor in B. Nodes without a
    # neighbour within `limit` get an infinite distance (scipy cKDTree).
    coords = A.nodes[['x', 'y', 'z']].values
    nn_dist, nn_ix = tree.query(coords, k=1, distance_upper_bound=limit)

    # Find nodes that are close enough to collapse
    collapsed = A.nodes.loc[nn_dist <= limit].node_id.values
    clps_into = B.nodes.iloc[nn_ix[nn_dist <= limit]].node_id.values

    # If we have a mesh, check if those collapsed nodes are in sight of each
    # other
    if mesh:
        import ncollpyde
        coll = ncollpyde.Volume(mesh.vertices, mesh.faces)

        # Produce start and end coordinates for the to collapse nodes
        starts = A.nodes.set_index('node_id').loc[collapsed,
                                                  ['x', 'y', 'z']].values
        ends = B.nodes.set_index('node_id').loc[clps_into,
                                                ['x', 'y', 'z']].values

        # Check if the line between start and end intersects the mesh
        intersects, _, _ = coll.intersections(starts, ends)

        # Indices that show up in `intersects` cross a membrane
        # -> we need to invert this to get those nodes that don't
        not_intersects = ~np.isin(np.arange(starts.shape[0]), intersects)

        # Keep only collapses that don't intersect
        collapsed = collapsed[not_intersects]
        clps_into = clps_into[not_intersects]

    # Generate a map of which node in A is to be collapsed into which node in B
    clps_map = {n1: n2 for n1, n2 in zip(collapsed, clps_into)}

    # The fastest way to collapse is to work on the edge list
    E = nx.to_pandas_edgelist(union_simple.graph)

    # Keep track of which edges were collapsed -> we will use this as weight
    # later on to prioritize existing edges over newly generated ones.
    # Edges touching B get 0; all other edges (A's edges, which may be
    # rewired by the collapse below) get 1.
    E['is_new'] = 1
    source_in_B = E.source.isin(B.nodes.node_id.values)
    target_in_B = E.target.isin(B.nodes.node_id.values)
    E.loc[source_in_B | target_in_B, 'is_new'] = 0

    # Now map collapsed nodes onto the nodes they collapsed into
    E['target'] = E.target.map(lambda x: clps_map.get(x, x))
    E['source'] = E.source.map(lambda x: clps_map.get(x, x))

    # Make sure no self loops after collapsing. This happens if two adjacent
    # nodes collapse onto the same target node
    E = E[E.source != E.target]

    # Remove duplicates. This happens e.g. when two adjacent nodes merge into
    # two other adjacent nodes: A->B C->D ----> A/B->C/D
    # By sorting first, we make sure original edges are kept first
    E.sort_values('is_new', ascending=True, inplace=True)

    # Because edges may exist in both directions (A->B and A<-B) we have to
    # generate a column that's agnostic to directionality using frozensets
    E['frozen_edge'] = E[['source', 'target']].apply(frozenset, axis=1)
    E.drop_duplicates(['frozen_edge'], keep='first', inplace=True)

    # Regenerate graph from these new edges
    G = nx.Graph()
    G.add_weighted_edges_from(E[['source', 'target',

    # At this point there might still be disconnected pieces -> we will create
    # separate neurons for each tree. Node attributes are copied over from the
    # node table of the simple union.
    props = union_simple.nodes.loc[union_simple.nodes.node_id.isin(
    nx.set_node_attributes(G, props.to_dict(orient='index'))
    fragments = []
    for n in nx.connected_components(G):
        c = G.subgraph(n)
        # Minimum spanning tree per component; assuming `is_new` is the edge
        # weight, existing (weight 0) edges are preferred
        tree = nx.minimum_spanning_tree(c)
    fragments = navis.NeuronList(fragments)

    if len(fragments) > 1:
        print('Union incomplete - watch out for disconnected fragments!')
        # Now heal those fragments using a minimum spanning tree
        union = navis.stitch_neurons(*fragments, method='ALL')
        union = fragments[0]

    # Reroot to base neuron's root
    union.reroot(base_neuron.root[0], inplace=True)

    # Add tags back on
    if union_simple.has_tags:
        if not union.has_tags:
            union.tags = {}

    # Add connectors back on, re-pointing connectors of collapsed nodes to the
    # nodes they were collapsed into
    union.connectors = union_simple.connectors.drop_duplicates(
    union.connectors.loc[:, 'node_id'] = union.connectors.node_id.map(
        lambda x: clps_map.get(x, x))

    # Find the newly added edges (existing edges should not have been modified
    # - except for changing direction due to reroot)
    # The basic logic here is that new edges were only added between two
    # previously separate skeletons, i.e. where the skeleton ID changes between
    # parent and child node
    node2skid = union_simple.nodes.set_index(
    union.nodes['parent_skeleton'] = union.nodes.parent_id.map(node2skid)
    new_edges = union.nodes[
        union.nodes.origin_skeletons != union.nodes.parent_skeleton]
    # Remove root edges
    new_edges = new_edges[~new_edges.parent_id.isnull()]

    return union, new_edges, clps_map