Exemple #1
0
    def swc_from_data(self, ims_data):
        """Build a ``navis.TreeNeuron`` from the filament graph in ``ims_data``.

        Parameters
        ----------
        ims_data :  mapping
                    Must provide
                    ``['Scene']['Content']['Filaments0']['Graphs']`` holding
                    a ``'Segments'`` table of (source, target) vertex-index
                    pairs and a ``'Vertices'`` table whose first five columns
                    are read as x, y, z, radius and label.

        Returns
        -------
        navis.TreeNeuron
                    Neuron built from the depth-first tree rooted at the
                    single head vertex. Vertices that are not reachable from
                    the head are not part of that tree.

        Raises
        ------
        ValueError
                    If the segments do not define exactly one head, i.e. a
                    vertex that is a source but never a target.
        """
        graphs = ims_data['Scene']["Content"]['Filaments0']['Graphs']
        segments = np.array(graphs['Segments'])
        vertices = np.array(graphs['Vertices'])

        swc = nx.DiGraph()
        swc.add_edges_from(segments)
        swc.add_nodes_from(np.arange(len(vertices)))

        # Heads are sources that never show up as a target. ``setdiff1d``
        # returns the sorted unique values in one vectorised pass instead of
        # rebuilding a list of all targets for every single segment.
        heads = np.setdiff1d(segments[:, 0], segments[:, 1])
        if len(heads) != 1:
            # ``nx.dfs_tree`` needs a single hashable source node; fail with
            # a clear message instead of an unhashable-type error.
            raise ValueError(
                'Expected exactly one head vertex in the segment graph, '
                'found {}'.format(len(heads)))

        swc = nx.dfs_tree(swc, heads[0])

        # Column i of the vertex table becomes node attribute attrs[i], keyed
        # by the vertex's row index.
        for col, attr in enumerate(('x', 'y', 'z', 'radius', 'label')):
            values = {index: vert[col] for index, vert in enumerate(vertices)}
            nx.set_node_attributes(swc, values, attr)

        return navis.TreeNeuron(swc)
Exemple #2
0
def pymaid_to_navis(x: Union[pymaid.core.CatmaidNeuron,
                             pymaid.core.CatmaidNeuronList]):
    """
    Convert a pymaid/Catmaid neuron into a ``navis.TreeNeuron``.

    The nodes of the validated input are ranked with
    ``pymaid_topological_sort`` and its node table is sorted by that rank
    (in place, adding a ``rank`` column) before the neuron's graph is handed
    to ``navis.TreeNeuron``. Coordinates, node types, connectors, soma and
    name are then copied over from the input.

    Parameters
    ----------
    x:                a pymaid/Catmaid neuron object

    Returns
    --------
    navis_neuron:     navis.TreeNeuron
                      Connector ``type`` values are relabelled: ``0`` becomes
                      ``"pre"``, everything else becomes ``"post"``.
    """

    x = check_valid_pymaid_input(x)

    # Rank every treenode by its position in the topological sort and order
    # the node table accordingly
    x.nodes["rank"] = x.nodes.treenode_id.map(
        pymaid_topological_sort(x, return_object="dict"))
    x.nodes.sort_values(by=["rank"], ascending=True, inplace=True)

    # Build the navis neuron from the pymaid neuron's graph
    navis_neuron = navis.TreeNeuron(x.graph)

    # Populating the xyz columns of nodes, row by row
    # NOTE(review): this assumes row i of the navis node table corresponds to
    # row i of the rank-sorted pymaid node table - confirm
    for row, xyz in enumerate(x.nodes[["x", "y", "z"]].values):
        navis_neuron.nodes.loc[row, "x"] = xyz[0]
        navis_neuron.nodes.loc[row, "y"] = xyz[1]
        navis_neuron.nodes.loc[row, "z"] = xyz[2]

    # adding type column
    navis_neuron.nodes["type"] = x.nodes.type.copy()

    # adding connectors: relabel the ``type`` column on a copy of the table.
    # The relabelled values must go into the column; assigning the list to
    # ``connectors`` itself would replace the whole table with a list of
    # strings.
    connectors = x.connectors.copy()
    connectors["type"] = [
        "pre" if kind == 0 else "post" for kind in connectors["type"]
    ]
    navis_neuron.connectors = connectors

    # adding soma & name
    navis_neuron.soma = x.soma
    navis_neuron.name = x.neuron_name

    return navis_neuron
Exemple #3
0
def l2_skeleton(root_id, refine=False, drop_missing=True,
                threads=10, progress=True, dataset='production', **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id  :          int | list of ints
                        Root ID(s) of the flywire neuron(s) you want to
                        skeletonize.
    refine :            bool
                        If True, will refine skeleton nodes by moving them in
                        the center of their corresponding chunk meshes.

    Only relevant if ``refine=True``:

    drop_missing :      bool
                        If True, will drop nodes that don't have a corresponding
                        chunk mesh. These are typically chunks that are very
                        small and dropping them might actually be beneficial.
    threads :           int
                        How many parallel threads to use for fetching the
                        chunk meshes. Reduce the number if you run into
                        ``HTTPErrors``. Only relevant if `use_flycache=False`.
    progress :          bool
                        Whether to show a progress bar.
    dataset :           str
                        Dataset to query. "production" and "sandbox" are
                        mapped to their hard-coded datastack names; any other
                        value is passed to the client unchanged.
    **kwargs
                        ``use_flycache`` (bool, default False): if True and
                        ``refine=True``, fetch the L2 centroids from the fly
                        cache instead of computing them from chunk meshes.
                        Keyword arguments are forwarded unchanged when
                        ``root_id`` is a list.

    Returns
    -------
    skeleton :          navis.TreeNeuron
                        The extracted skeleton. A ``navis.NeuronList`` if
                        ``root_id`` is iterable.

    Raises
    ------
    ValueError
                        If ``refine`` and ``use_flycache`` are both set for a
                        dataset other than "production".

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.l2_skeleton(720575940614131061)

    """
    # TODO:
    # - drop duplicate nodes in unrefined skeleton
    # - use L2 graph to find soma: highest degree is typically the soma

    use_flycache = kwargs.get('use_flycache', False)

    if refine and use_flycache and dataset != 'production':
        raise ValueError('Unable to use fly cache to fetch L2 centroids for '
                         'sandbox dataset. Please set `use_flycache=False`.')

    if navis.utils.is_iterable(root_id):
        # Forward **kwargs too: otherwise options such as `use_flycache`
        # would silently be dropped for every neuron in the list
        return navis.NeuronList([
            l2_skeleton(rid, refine=refine, drop_missing=drop_missing,
                        threads=threads, progress=progress,
                        dataset=dataset, **kwargs)
            for rid in navis.config.tqdm(root_id, desc='Skeletonizing',
                                         disable=not progress, leave=False)
        ])

    # Get the cloudvolume
    vol = parse_volume(dataset)

    # Hard-coded datastack names
    ds = {"production": "flywire_fafb_production",
          "sandbox": "flywire_fafb_sandbox"}
    # Note that the default server url is https://global.daf-apis.com/info/
    client = FrameworkClient(ds.get(dataset, dataset))

    # Load the L2 graph for given root ID
    # This is a (N,2) array of edges
    l2_eg = np.array(client.chunkedgraph.level2_chunk_graph(root_id))

    # Drop duplicate edges (sorting each row first makes (a, b) == (b, a))
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    # One chunk position per L2 ID, in the same order as `l2_ids`
    coords = np.vstack([
        np.array(vol.mesh.meta.meta.decode_chunk_position(l2))
        for l2 in l2_ids
    ])

    # This turns the graph into a hierarchal tree by removing cycles and
    # ensuring all edges point towards a root
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Convert to Euclidian space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    # Place each node at the center of its chunk
    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        if use_flycache:
            token = get_chunkedgraph_secret()
            centroids = spine.flycache.get_L2_centroids(l2_ids,
                                                        token=token,
                                                        progress=progress)

            # Drop missing (i.e. [0,0,0]) meshes
            centroids = {k: v for k, v in centroids.items() if v != [0, 0, 0]}
        else:
            # Get the centroids
            centroids = get_L2_centroids(l2_ids, vol, threads=threads, progress=progress)

        # Re-key the centroids from L2 ID to node index
        new_co = {l2dict[k]: v for k, v in centroids.items()}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)
        for col, axis in enumerate(('x', 'y', 'z')):
            swc.loc[has_new, axis] = swc.loc[has_new, 'node_id'].map(
                lambda node, col=col: new_co[node][col])

    # Turn into a proper neuron
    tn = navis.TreeNeuron(swc, id=root_id, units='1 nm')

    # Drop nodes that are still at their unrefined chunk position
    if refine and drop_missing:
        tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)

    return tn
def skeletonize_neuron(x,
                       shave_skeleton=True,
                       remove_soma_hairball=False,
                       assert_id_match=False,
                       dataset='production',
                       progress=True,
                       **kwargs):
    """Skeletonize FlyWire neuron.

    Note that this is optimized to be primarily fast which comes at the cost
    of (some) quality.

    Parameters
    ----------
    x  :                 int | trimesh.TriMesh | list thereof
                         ID(s) or trimesh of the FlyWire neuron(s) you want to
                         skeletonize.
    shave_skeleton :     bool
                         If True, we will "shave" the skeleton by removing all
                         single-node terminal twigs. This should get rid of
                         hairs on the backbone that can occur if the neurites
                         are very big.
    remove_soma_hairball : bool
                         If True, we will try to drop the hairball that is
                         typically created inside the soma. Note that while this
                         should work just fine for 99% of neurons, it's not very
                         smart and there is always a chance that we remove stuff
                         that should not have been removed. Also only works if
                         the neuron has a recognizable soma.
    assert_id_match :    bool
                         If True, will check if skeleton nodes map to the
                         correct segment ID and if not will move them back into
                         the segment. This is potentially very slow!
    dataset :            str | CloudVolume
                         Against which FlyWire dataset to query::
                           - "production" (current production dataset, fly_v31)
                           - "sandbox" (i.e. fly_v26)
    progress :           bool
                         Whether to show a progress bar or not.
    **kwargs
                         Passed through to ``skeletor.skeletonize.by_wavefront``
                         where they override the defaults used here
                         (``waves=1``, ``step_size=1``).

    Returns
    -------
    skeleton :          navis.TreeNeuron
                        The extracted skeleton. A ``navis.NeuronList`` if
                        ``x`` is iterable.

    Raises
    ------
    ImportError
                        If the installed skeletor version is older than 1.0.0.
    ValueError
                        If ``assert_id_match=True`` but the segment ID is 0
                        (e.g. a mesh without a ``segid`` attribute).

    See Also
    --------
    :func:`fafbseg.flywire.skeletonize_neuron_parallel`
                        Use this if you want to skeletonize many neurons in
                        parallel.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.skeletonize_neuron(720575940614131061)

    """

    if int(sk.__version__.split('.')[0]) < 1:
        raise ImportError('Please update skeletor to version >= 1.0.0: '
                          'pip3 install skeletor -U')

    vol = parse_volume(dataset)

    if navis.utils.is_iterable(x):
        # Forward every option, including `shave_skeleton`, so that a list
        # input behaves the same as calling this function once per neuron.
        # The per-neuron progress bar is disabled in favour of the outer one.
        return navis.NeuronList([
            skeletonize_neuron(n,
                               progress=False,
                               shave_skeleton=shave_skeleton,
                               remove_soma_hairball=remove_soma_hairball,
                               assert_id_match=assert_id_match,
                               dataset=dataset,
                               **kwargs)
            for n in navis.config.tqdm(
                x, desc='Skeletonizing', disable=not progress, leave=False)
        ])

    if navis.utils.is_mesh(x):
        mesh = x
        # Meshes without a `segid` fall back to ID 0
        segid = getattr(mesh, 'segid', 0)
    else:
        # Make sure this is a valid integer
        segid = int(x)

        # Download the mesh
        mesh = vol.mesh.get(segid,
                            deduplicate_chunk_boundaries=False,
                            remove_duplicate_vertices=True)[segid]

    mesh = sk.utilities.make_trimesh(mesh, validate=False)

    # Fix things before we skeletonize
    # This also drops fluff
    mesh = sk.pre.fix_mesh(mesh, inplace=True, remove_disconnected=100)

    # Skeletonize
    defaults = dict(waves=1, step_size=1)
    defaults.update(kwargs)
    skel = sk.skeletonize.by_wavefront(mesh, progress=progress, **defaults)

    # Skeletor indexes node IDs at zero but to avoid potential issues we want
    # node IDs to start at 1. Negative parent IDs (roots) are left untouched.
    skel.swc['node_id'] += 1
    skel.swc.loc[skel.swc.parent_id >= 0, 'parent_id'] += 1

    # We will also round the radius and make it an integer to save some
    # memory. We could do the same with x/y/z coordinates but that could
    # potentially move nodes outside the mesh
    skel.swc['radius'] = skel.swc.radius.round().astype(int)

    # Turn into a neuron
    tn = navis.TreeNeuron(skel.swc, units='1 nm', id=segid, soma=None)

    if shave_skeleton:
        # Get branch points
        bp = tn.nodes.loc[tn.nodes.type == 'branch', 'node_id'].values

        # Get single-node twigs: end nodes whose parent is a branch point
        is_end = tn.nodes.type == 'end'
        parent_is_bp = tn.nodes.parent_id.isin(bp)
        twigs = tn.nodes.loc[is_end & parent_is_bp, 'node_id'].values

        # Drop terminal twigs
        tn._nodes = tn.nodes.loc[~tn.nodes.node_id.isin(twigs)].copy()
        tn._clear_temp_attr()

    # See if we can find a soma
    soma = detect_soma_skeleton(tn, min_rad=800, N=3)
    if soma:
        tn.soma = soma

        # Reroot to soma
        tn.reroot(tn.soma, inplace=True)

        if remove_soma_hairball:
            # Node table row of the soma; its `.name` is the soma's node ID
            soma_node = tn.nodes.set_index('node_id').loc[soma]
            soma_loc = soma_node[['x', 'y', 'z']].values

            # Find all nodes within 2x the soma radius (but at least 4000)
            tree = navis.neuron2KDTree(tn)
            ix = tree.query_ball_point(soma_loc,
                                       max(4000, soma_node.radius * 2))

            # Translate indices into node IDs
            ids = tn.nodes.iloc[ix].node_id.values

            # Find segments that contain these nodes
            segs = [seg for seg in tn.segments if any(np.isin(ids, seg))]

            # Sort segs by length
            segs = sorted(segs, key=len)

            # Keep only the longest segment in that initial list
            to_drop = np.array([n for seg in segs[:-1] for n in seg])
            to_drop = to_drop[~np.isin(to_drop, segs[-1] + [soma_node.name])]

            navis.remove_nodes(tn, to_drop, inplace=True)

    if assert_id_match:
        if segid == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(tn.nodes[['x', 'y', 'z']].values,
                              id=segid,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        tn.nodes[['x', 'y', 'z']] = new_locs

    return tn
Exemple #5
0
def skeletonize_neuron(x,
                       drop_soma_hairball=True,
                       contraction_kws=None,
                       skeletonization_kws=None,
                       radius_kws=None,
                       assert_id_match=False,
                       dataset='production'):
    """Skeletonize flywire neuron.

    Parameters
    ----------
    x  :                 int | trimesh.TriMesh
                         ID or trimesh of the flywire neuron you want to
                         skeletonize.
    drop_soma_hairball : bool
                         If True, we will try to drop the hairball that is
                         typically created inside the soma.
    contraction_kws :    dict, optional
                         Optional parameters for the contraction phase. See
                         ``skeletor.contract``. Override the defaults used
                         here (``SL=40, WH0=2, epsilon=0.1, precision=1e-7,
                         validate=False``).
    skeletonization_kws : dict, optional
                         Optional parameters for the skeletonization phase. See
                         ``skeletor.skeletonize``. Override the defaults used
                         here (``method='vertex_clusters', sampling_dist=200,
                         vertex_map=True, validate=False``).
    radius_kws :         dict, optional
                         Optional parameters for the radius extraction phase.
                         See ``skeletor.radii``. Override the default used
                         here (``validate=False``).
    assert_id_match :    bool
                         If True, will check if skeleton nodes map to the
                         correct segment ID and if not will move them back into
                         the segment. This is potentially very slow!
    dataset :            str | CloudVolume
                         Against which flywire dataset to query::
                           - "production" (current production dataset, fly_v31)
                           - "sandbox" (i.e. fly_v26)

    Returns
    -------
    skeleton, simplified_mesh, contracted_mesh
                        The extracted skeleton, simplified and contracted mesh,
                        respectively.

    Raises
    ------
    ImportError
                        If skeletor is not available.
    ValueError
                        If ``assert_id_match=True`` but the segment ID is 0
                        (e.g. a mesh without a ``segid`` attribute).

    Examples
    --------
    >>> tn, simp, cntr = flywire.skeletonize_neuron(720575940614131061)

    """
    if not sk:
        raise ImportError('Must install skeletor: pip3 install skeletor')

    if navis.utils.is_mesh(x):
        mesh = x
        # Meshes without a `segid` fall back to ID 0
        segid = getattr(mesh, 'segid', 0)
    else:
        vol = parse_volume(dataset)

        # Make sure this is a valid integer
        segid = int(x)

        # Download the mesh
        mesh = vol.mesh.get(segid, deduplicate_chunk_boundaries=False)[segid]

    # Simplify
    simp = sk.simplify(mesh, ratio=.2)

    # Validate before we detect the soma verts
    simp = sk.utilities.fix_mesh(simp, inplace=True)

    # Try detecting the soma
    if drop_soma_hairball:
        soma_verts = detect_soma(simp)

    # Contract
    contract_args = dict(SL=40, WH0=2, epsilon=0.1, precision=1e-7,
                         validate=False)
    contract_args.update(contraction_kws or {})
    cntr = sk.contract(simp, **contract_args)

    # Generate skeleton
    skeletonize_args = dict(method='vertex_clusters',
                            sampling_dist=200,
                            vertex_map=True,
                            validate=False)
    skeletonize_args.update(skeletonization_kws or {})
    swc = sk.skeletonize(cntr, **skeletonize_args)

    # Clean up
    cleaned = sk.clean(swc, mesh=mesh, validate=False)

    # Extract radii
    radius_args = dict(validate=False)
    radius_args.update(radius_kws or {})
    cleaned['radius'] = sk.radii(cleaned, mesh=mesh, **radius_args)

    # Convert to neuron
    n = navis.TreeNeuron(cleaned, id=segid, units='nm', soma=None)

    # Drop any nodes that are soma vertices
    # (`soma_verts` only exists if `drop_soma_hairball` is True, hence the
    # order of the two conditions)
    if drop_soma_hairball and soma_verts.shape[0] > 0:
        keep = n.nodes.loc[~n.nodes.vertex_id.isin(soma_verts),
                           'node_id'].values
        n = navis.subset_neuron(n, keep)

    if assert_id_match:
        if segid == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(n.nodes[['x', 'y', 'z']].values,
                              id=segid,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        n.nodes[['x', 'y', 'z']] = new_locs

    return (n, simp, cntr)