Example 1
    def prune_graph(self):
        import numpy as np
        import networkx as nx
        import graspologic.utils as gu
        from pynets.core import utils
        from pynets.statistics.individual.algorithms import defragment, \
            prune_small_components, most_important

        hardcoded_params = utils.load_runconfig()

        if int(self.prune) not in range(0, 4):
            raise ValueError(f"Pruning option {self.prune} invalid!")

        if int(self.prune) != 0:
            # Remove isolates
            G_tmp = self.G.copy()
            self.G = defragment(G_tmp)[0]
            del G_tmp

        if int(self.prune) == 1:
            try:
                self.G = prune_small_components(
                    self.G, min_nodes=hardcoded_params["min_nodes"][0])
            except BaseException:
                print(
                    UserWarning(f"Warning: pruning {self.est_path} "
                                f"failed..."))
        elif int(self.prune) == 2:
            try:
                hub_detection_method = \
                    hardcoded_params["hub_detection_method"][0]
                print(f"Filtering for hubs on the basis of "
                      f"{hub_detection_method}...\n")
                self.G = most_important(self.G, method=hub_detection_method)[0]
            except FileNotFoundError as e:
                print(e, "Failed to parse advanced.yaml")

        elif int(self.prune) == 3:
            print("Pruning all but the largest connected "
                  "component subgraph...")
            self.G = gu.largest_connected_component(self.G)
        else:
            print("No graph defragmentation applied...")

        # Rebuild the adjacency matrix from the pruned graph
        self.in_mat = nx.to_numpy_array(self.G)
        self.G = nx.from_numpy_array(self.in_mat)

        if nx.is_empty(self.G) or \
            (np.abs(self.in_mat) < 0.0000001).all() or \
                self.G.number_of_edges() == 0:
            print(
                UserWarning(f"Warning: {self.est_path} "
                            f"empty after pruning!"))
            return self.in_mat, None

        # Save the pruned matrix
        if (self.prune is not None) and (int(self.prune) != 0):
            final_mat_path = f"{self.est_path.split('.npy')[0]}_pruned"
            utils.save_mat(self.in_mat, final_mat_path, self.out_fmt)
            print(f"Source File: {final_mat_path}")
        else:
            final_mat_path = None

        return self.in_mat, final_mat_path
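
A minimal standalone sketch of the same pruning modes using networkx directly (defragment, prune_small_components, and most_important are PyNets internals; the input path and option value below are hypothetical):

import networkx as nx
import numpy as np

in_mat = np.load("graph.npy")  # hypothetical .npy adjacency matrix
G = nx.from_numpy_array(in_mat)

prune = 3  # 0=none, 1=small components, 2=hubs, 3=largest component
if prune != 0:
    # Drop isolates, mirroring the defragment step above
    G.remove_nodes_from(list(nx.isolates(G)))
if prune == 3 and not nx.is_empty(G):
    # Keep only the largest connected component, as in option 3
    G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
in_mat_pruned = nx.to_numpy_array(G)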
Example 2
def build_masetome(est_path_iterlist, ID):
    """
    Embeds structural-functional graph pairs into a common invariant subspace.

    Parameters
    ----------
    est_path_iterlist : list
        List of list of pairs of file paths (.npy) corresponding to
        structural and functional connectomes matched at a given node
        resolution.
    ID : str
        A subject id or other unique identifier.

    References
    ----------
    .. [1] Rosenthal, G., Váša, F., Griffa, A., Hagmann, P., Amico, E., Goñi, J.,
      Sporns, O. (2018). Mapping higher-order relations between brain structure
      and function with embedded vector representations of connectomes.
      Nature Communications. https://doi.org/10.1038/s41467-018-04614-w

    """
    import numpy as np
    from pynets.core.utils import prune_suffices
    from pynets.stats.embeddings import _mase_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"]["n_components"][0]
    except KeyError:
        import sys
        print("ERROR: available gradient dimensionality presets not "
              "sucessfully extracted from runconfig.yaml")
        sys.exit(1)

    out_paths = []
    for pairs in est_path_iterlist:
        pop_list = []
        for _file in pairs:
            pop_list.append(np.load(_file))
        atlas = prune_suffices(pairs[0].split("/")[-3])
        res = prune_suffices("_".join(pairs[0].split("/")[-1].split("modality")
                                      [1].split("_")[1:]).split("_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"
        out_path = _mase_embed(pop_list,
                               atlas,
                               pairs[0],
                               ID,
                               subgraph_name=subgraph,
                               n_components=n_components)
        out_paths.append(out_path)

    return out_paths
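
A hypothetical invocation of build_masetome, assuming PyNets-style derivative paths in which the atlas name is the third-to-last path segment and the filename encodes modality and model (the paths below are illustrative, not real outputs):

pairs_list = [
    ["/out/sub-01/dwi/atlas1/raw/graph_modality-dwi_model-csa_est.npy",
     "/out/sub-01/func/atlas1/raw/graph_modality-func_model-corr_est.npy"],
]
embedding_paths = build_masetome(pairs_list, ID="sub-01")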
Example 3
def sfm_mod_est(gtab, data, B0_mask, BACKEND='loky'):
    """
    Estimate a Sparse Fascicle Model (SFM) from dwi data.

    Parameters
    ----------
    gtab : Obj
        DiPy object storing diffusion gradient information.
    data : array
        4D numpy array of diffusion image data.
    B0_mask : str
        File path to B0 brain mask.
    BACKEND : str
        Joblib parallelization backend to use while fitting
        (default 'loky').

    Returns
    -------
    sf_odf : ndarray
        Orientation distribution function (ODF) values of the SFM fit,
        sampled on the sphere.
    model : obj
        Fitted sf model.

    References
    ----------
    .. [1] Ariel Rokem, Jason D. Yeatman, Franco Pestilli, Kendrick
      N. Kay, Aviv Mezer, Stefan van der Walt, Brian A. Wandell
      (2015). Evaluating the accuracy of diffusion MRI models in white
      matter. PLoS ONE 10(4): e0123272. doi:10.1371/journal.pone.0123272
    .. [2] Ariel Rokem, Kimberly L. Chan, Jason D. Yeatman, Franco
      Pestilli,  Brian A. Wandell (2014). Evaluating the accuracy of diffusion
      models at multiple b-values with cross-validation. ISMRM 2014.

    """
    from dipy.data import get_sphere
    import dipy.reconst.sfm as sfm
    import numpy as np
    import nibabel as nib
    from pynets.core.utils import load_runconfig

    sphere = get_sphere("repulsion724")
    print("Reconstructing using SFM...")

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["omp_threads"][0]

    model = sfm.SparseFascicleModel(gtab,
                                    sphere=sphere,
                                    l1_ratio=0.5,
                                    alpha=0.001)

    sf_mod = model.fit(data,
                       mask=np.nan_to_num(np.asarray(
                           nib.load(B0_mask).dataobj)).astype("bool"),
                       num_processes=nthreads,
                       parallel_backend=BACKEND)
    sf_odf = sf_mod.odf(sphere)
    sf_odf = np.clip(sf_odf, 0, np.max(sf_odf, -1)[..., None])
    return sf_odf.astype("float32"), model
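
A hedged usage sketch for sfm_mod_est, assuming standard DWI inputs on disk; the gradient table is built with DiPy's gradient_table helper and all file names are examples:

import numpy as np
import nibabel as nib
from dipy.core.gradients import gradient_table

# Hypothetical input files
dwi_data = np.asarray(nib.load("dwi.nii.gz").dataobj)
gtab = gradient_table(np.loadtxt("bvals"), np.loadtxt("bvecs"))

sf_odf, model = sfm_mod_est(gtab, dwi_data, "b0_brain_mask.nii.gz")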
Example 4
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    """
    import numpy as np
    from pynets.core.utils import prune_suffices, flatten
    from pynets.stats.embeddings import _ase_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"]["n_components"][0]
    except KeyError:
        import sys
        print("ERROR: available gradient dimensionality presets not "
              "sucessfully extracted from runconfig.yaml")
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(
            file_.split("/")[-1].split("modality")[1].split("_")[1:]).split(
                "_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"
        out_path = _ase_embed(mat,
                              atlas,
                              file_,
                              ID,
                              subgraph_name=subgraph,
                              n_components=n_components)
        out_paths.append(out_path)

    return out_paths
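
A hypothetical call to build_asetomes; as in build_masetome above, the atlas is parsed from the third-to-last path segment and the node resolution from the filename after "modality" (the path is illustrative):

est_paths = [
    "/out/sub-01/func/atlas1/raw/graph_modality-func_model-corr_est.npy",
]
embedding_paths = build_asetomes(est_paths, ID="sub-01")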
Example 5
    def __init__(
        self,
        net_parcels_nii_path,
        node_size,
        conf,
        func_file,
        roi,
        dir_path,
        ID,
        network,
        smooth,
        hpass,
        mask,
        extract_strategy,
    ):
        self.net_parcels_nii_path = net_parcels_nii_path
        self.node_size = node_size
        self.conf = conf
        self.func_file = func_file
        self.roi = roi
        self.dir_path = dir_path
        self.ID = ID
        self.network = network
        self.smooth = smooth
        self.mask = mask
        self.hpass = hpass
        self.extract_strategy = extract_strategy
        self.ts_within_nodes = None
        self._mask_img = None
        self._mask_path = None
        self._func_img = None
        self._t_r = None
        self._detrending = True
        self._net_parcels_nii_temp_path = None
        self._net_parcels_map_nifti = None
        self._parcel_masker = None

        from pynets.core.utils import load_runconfig
        hardcoded_params = load_runconfig()
        try:
            self.low_pass = hardcoded_params["low_pass"][0]
        except KeyError as e:
            print(e,
                  "ERROR: Low-pass filter setting not successfully "
                  "extracted from runconfig.yaml"
                  )
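
A sketch of instantiating this constructor, assuming the enclosing class is PyNets' time-series extraction interface (the class name and all paths below are assumptions for illustration):

te = TimeseriesExtraction(  # hypothetical name for the class defined above
    net_parcels_nii_path="parcels.nii.gz", node_size=4,
    conf="confounds.tsv", func_file="bold.nii.gz", roi=None,
    dir_path="/out/sub-01", ID="sub-01", network=None, smooth=4,
    hpass=0.008, mask="brain_mask.nii.gz", extract_strategy="mean")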
Example 6
    def __init__(
        self,
        net_parcels_nii_path,
        node_radius,
        conf,
        func_file,
        roi,
        dir_path,
        ID,
        subnet,
        smooth,
        hpass,
        mask,
        signal,
    ):
        self.net_parcels_nii_path = net_parcels_nii_path
        self.node_radius = node_radius
        self.conf = conf
        self.func_file = func_file
        self.roi = roi
        self.dir_path = dir_path
        self.ID = ID
        self.subnet = subnet
        self.smooth = smooth
        self.mask = mask
        self.hpass = hpass
        self.signal = signal
        self.ts_within_nodes = None
        self._mask_img = None
        self._mask_path = None
        self._func_img = None
        self._t_r = None
        self._detrending = True
        self._net_parcels_nii_temp_path = None
        self._net_parcels_map_nifti = None
        self._parcel_masker = None

        from pynets.core.utils import load_runconfig
        hardcoded_params = load_runconfig()
        try:
            self.low_pass = hardcoded_params["low_pass"][0]
        except KeyError as e:
            print(
                e, "ERROR: Low-pass filter setting not successfully "
                "extracted from advanced.yaml")
Example 7
File: track.py Project: dPys/PyNets
def track_ensemble(target_samples,
                   atlas_data_wm_gm_int,
                   labels_im_file,
                   recon_path,
                   sphere,
                   traversal,
                   curv_thr_list,
                   step_list,
                   track_type,
                   maxcrossing,
                   roi_neighborhood_tol,
                   min_length,
                   waymask,
                   B0_mask,
                   t1w2dwi,
                   gm_in_dwi,
                   vent_csf_in_dwi,
                   wm_in_dwi,
                   tiss_class,
                   BACKEND='threading'):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    tiss_class : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to tractography constraint mask in native diffusion space.
    B0_mask : str
        File path to B0 brain mask.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].
    BACKEND : str
        Joblib parallelization backend to use (default 'threading').

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import os
    import gc
    import time
    import warnings
    import tempfile
    from joblib import Parallel, delayed, Memory
    import itertools
    import pickle5 as pickle
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    from dipy.tracking import utils

    warnings.filterwarnings("ignore")

    pickle.HIGHEST_PROTOCOL = 5
    joblib_dir = tempfile.mkdtemp()
    os.makedirs(joblib_dir, exist_ok=True)

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["omp_threads"][0]
    os.environ['MKL_NUM_THREADS'] = str(nthreads)
    os.environ['OPENBLAS_NUM_THREADS'] = str(nthreads)
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    seeding_mask_thr = hardcoded_params['tracking']["seeding_mask_thr"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{os.path.dirname(labels_im_file)}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img(f"img > {seeding_mask_thr}",
                               img=nib.load(waymask))
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0

    memory = Memory(location=joblib_dir, mmap_mode='r+', verbose=0)
    os.chdir(f"{memory.location}/joblib")

    @memory.cache
    def load_recon_data(recon_path):
        import h5py
        with h5py.File(recon_path, 'r') as hf:
            recon_data = hf['reconstruction'][:].astype('float32')
        return recon_data

    recon_shelved = load_recon_data.call_and_shelve(recon_path)

    @memory.cache
    def load_tissue_data(tissues4d):
        return nib.load(tissues4d)

    tissue_shelved = load_tissue_data.call_and_shelve(tissues4d)

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50*float(len(all_combs)):
            with Parallel(n_jobs=nthreads,
                          backend=BACKEND,
                          mmap_mode='r+',
                          verbose=0) as parallel:

                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_shelved, n_seeds_per_iter, traversal,
                     maxcrossing, max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, min_length, track_type,
                     min_separation_angle, sphere, tiss_class, tissue_shelved)
                    for i in all_combs)

                out_streams = list(filter(None, out_streams))

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)
                else:
                    continue

                if waymask is not None and os.path.isfile(waymask):
                    try:
                        out_streams = out_streams[utils.near_roi(
                            out_streams,
                            np.eye(4),
                            np.asarray(
                                nib.load(waymask).dataobj).astype("bool"),
                            tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                            mode="all")]
                    except BaseException:
                        print(f"\n{Fore.RED}No streamlines generated in "
                              f"waymask vacinity\n")
                        print(Style.RESET_ALL)
                        return None

                if len(out_streams) < min_streams:
                    ix += 1
                    print(f"\n{Fore.YELLOW}Fewer than {min_streams} "
                          f"streamlines tracked "
                          f"on last iteration...\n")
                    print(Style.RESET_ALL)
                    if ix > 5:
                        print(f"\n{Fore.RED}No streamlines generated\n")
                        print(Style.RESET_ALL)
                        return None
                    continue
                else:
                    ix -= 1

                stream_counter += len(out_streams)
                all_streams.extend([generate_sl(i) for i in out_streams])
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)

                if time.time() - start > timeout:
                    print(f"\n{Fore.RED}Warning: Tractography timed "
                          f"out: {time.time() - start}")
                    print(Style.RESET_ALL)
                    memory.clear(warn=False)
                    return None

    except RuntimeError as e:
        print(f"\n{Fore.RED}Error: Tracking failed due to:\n{e}\n")
        print(Style.RESET_ALL)
        memory.clear(warn=False)
        return None

    print("Tracking Complete: ", str(time.time() - start))

    memory.clear(warn=False)

    del parallel, all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ArraySequence...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print(f"\n{Fore.RED}No streamlines generated!")
        print(Style.RESET_ALL)
        return None
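
The caching idiom above (Memory.cache plus call_and_shelve) keeps the large reconstruction and tissue volumes out of the per-task pickling path: workers receive a small shelved reference and re-hydrate the data themselves. A minimal, self-contained sketch of that pattern, with a random array standing in for the HDF5 read:

import tempfile
import numpy as np
from joblib import Memory, Parallel, delayed

memory = Memory(location=tempfile.mkdtemp(), mmap_mode='r', verbose=0)

@memory.cache
def load_big_array():
    return np.random.rand(1000, 1000)  # stand-in for the HDF5 read above

shelved = load_big_array.call_and_shelve()  # lightweight on-disk reference

def row_sum(ref, i):
    return float(ref.get()[i].sum())  # each task re-loads via memmap

sums = Parallel(n_jobs=2, backend="threading")(
    delayed(row_sum)(shelved, i) for i in range(4))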
Example 8
def plot_conn_mat_struct(conn_matrix, conn_model, atlas, dir_path, ID, network,
                         labels, roi, thr, node_size, target_samples,
                         track_type, directget, min_length, error_margin):
    """
    API for selecting among various structural connectivity matrix plotting
    approaches.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    conn_model : str
       Connectivity estimation model (e.g. corr for correlation, cov for
       covariance, sps for precision covariance, partcorr for partial
       correlation). sps type is used by default.
    atlas : str
        Name of atlas parcellation used.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    ID : str
        A subject id or other unique identifier.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming
        (e.g. 'Default') used to filter nodes in the study of brain subgraphs.
    labels : list
        List of string labels corresponding to ROI nodes.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    thr : float
        A value, between 0 and 1, to threshold the graph using any variety of
        methods triggered through other options.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's.
    target_samples : int
        Total number of streamline samples specified to generate streams.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    directget : str
        The statistical approach to tracking. Options are:
        det (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import sys
    import networkx as nx
    import os.path as op
    from pynets.core.utils import load_runconfig
    from pynets.plotting import plot_graphs

    rsn_str = f"rsn-{network}_" if network is not None else ""
    roi_str = f"roi-{op.basename(roi).split('.')[0]}_" \
        if roi is not None else ""
    nodetype_str = f"nodetype-spheres-{node_size}mm_" \
        if (node_size != "parc") and (node_size is not None) \
        else "nodetype-parc_"
    samples_str = f"samples-{int(target_samples)}streams_" \
        if float(target_samples) > 0 else "_"
    out_path_fig = (
        f"{dir_path}/adjacency_{ID}_modality-dwi_{rsn_str}{roi_str}"
        f"model-{conn_model}_{nodetype_str}{samples_str}"
        f"tracktype-{track_type}_directget-{directget}"
        f"_minlength-{min_length}_tol-{error_margin}_thr-{thr}.png"
    )

    hardcoded_params = load_runconfig()
    try:
        cmap_name = hardcoded_params["plotting"]["structural"]["adjacency"][
            "color_theme"][0]
    except KeyError as e:
        print(
            e, "Plotting configuration not successfully extracted from"
            " runconfig.yaml")
        sys.exit(1)

    plot_graphs.plot_conn_mat(conn_matrix,
                              labels,
                              out_path_fig,
                              cmap=plt.get_cmap(cmap_name))

    # Plot community adj. matrix
    try:
        from pynets.stats.netstats import community_resolution_selection

        G = nx.from_numpy_array(np.abs(conn_matrix))
        _, node_comm_aff_mat, resolution, num_comms = \
            community_resolution_selection(G)
        out_path_fig_comm = (
            f"{dir_path}/adjacency-communities_{ID}_modality-dwi_"
            f"{rsn_str}{roi_str}model-{conn_model}_{nodetype_str}"
            f"{samples_str}tracktype-{track_type}_directget-{directget}"
            f"_minlength-{min_length}_tol-{error_margin}_thr-{thr}.png"
        )
        plot_graphs.plot_community_conn_mat(
            conn_matrix,
            labels,
            out_path_fig_comm,
            node_comm_aff_mat,
            cmap=plt.get_cmap(cmap_name),
        )
    except BaseException:
        print("\nWARNING: Louvain community detection failed. Cannot plot"
              " community matrix...")

    return
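
A toy invocation; the connectivity matrix is a 3-node example and every metadata string is illustrative, so the only side effect is writing two small .png files under dir_path:

import numpy as np

plot_conn_mat_struct(
    conn_matrix=np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float),
    conn_model="csd", atlas="atlas1", dir_path="/out/sub-01", ID="sub-01",
    network="Default", labels=["A", "B", "C"], roi=None, thr=1.0,
    node_size="parc", target_samples=1000, track_type="local",
    directget="prob", min_length=20, error_margin=6)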
Example 9
def plot_conn_mat_func(
    conn_matrix,
    conn_model,
    atlas,
    dir_path,
    ID,
    network,
    labels,
    roi,
    thr,
    node_size,
    smooth,
    hpass,
    extract_strategy,
):
    """
    API for selecting among various functional connectivity matrix plotting
    approaches.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    conn_model : str
       Connectivity estimation model (e.g. corr for correlation, cov for
       covariance, sps for precision covariance, partcorr for partial
       correlation). sps type is used by default.
    atlas : str
        Name of atlas parcellation used.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    ID : str
        A subject id or other unique identifier.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g.
        'Default') used to filter nodes in the study of brain subgraphs.
    labels : list
        List of string labels corresponding to ROI nodes.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    thr : float
        A value, between 0 and 1, to threshold the graph using any variety of
        methods triggered through other options.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's.
    smooth : int
        Smoothing width (mm fwhm) to apply to time-series when extracting
        signal from ROI's.
    hpass : float
        High-pass filter values (Hz) to apply to node-extracted time-series.
    extract_strategy : str
        The name of a valid function used to reduce the time-series region
        extraction.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import sys
    import networkx as nx
    import os.path as op
    from pynets.core.utils import load_runconfig
    from pynets.plotting import plot_graphs

    rsn_str = f"rsn-{network}_" if network is not None else ""
    roi_str = f"roi-{op.basename(roi).split('.')[0]}_" \
        if roi is not None else ""
    nodetype_str = f"nodetype-spheres-{node_size}mm_" \
        if (node_size != "parc") and (node_size is not None) \
        else "nodetype-parc_"
    smooth_str = f"smooth-{smooth}fwhm_" if float(smooth) > 0 else ""
    hpass_str = f"hpass-{hpass}Hz_" if hpass is not None else ""
    extract_str = f"extract-{extract_strategy}" \
        if extract_strategy is not None else ""
    out_path_fig = (
        f"{dir_path}/adjacency_{ID}_modality-func_{rsn_str}{roi_str}"
        f"model-{conn_model}_{nodetype_str}{smooth_str}{hpass_str}"
        f"{extract_str}_thr-{thr}.png"
    )

    hardcoded_params = load_runconfig()
    try:
        cmap_name = hardcoded_params["plotting"]["functional"]["adjacency"][
            "color_theme"][0]
    except KeyError as e:
        print(
            e, "Plotting configuration not successfully extracted from"
            " runconfig.yaml")
        sys.exit(1)

    plot_graphs.plot_conn_mat(conn_matrix,
                              labels,
                              out_path_fig,
                              cmap=plt.get_cmap(cmap_name))

    # Plot community adj. matrix
    try:
        from pynets.stats.netstats import community_resolution_selection

        G = nx.from_numpy_array(np.abs(conn_matrix))
        _, node_comm_aff_mat, resolution, num_comms = \
            community_resolution_selection(G)
        out_path_fig_comm = (
            f"{dir_path}/adjacency-communities_{ID}_modality-func_"
            f"{rsn_str}{roi_str}model-{conn_model}_{nodetype_str}"
            f"{smooth_str}{hpass_str}{extract_str}_thr-{thr}.png"
        )
        plot_graphs.plot_community_conn_mat(
            conn_matrix,
            labels,
            out_path_fig_comm,
            node_comm_aff_mat,
            cmap=plt.get_cmap(cmap_name),
        )
    except BaseException:
        print("\nWARNING: Louvain community detection failed. Cannot plot "
              "community matrix...")

    return
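
An analogous toy call for the functional variant; with smooth=0 and hpass=None the corresponding filename fields are dropped, per the logic above (all values illustrative):

import numpy as np

plot_conn_mat_func(
    conn_matrix=np.corrcoef(np.random.rand(3, 100)),
    conn_model="corr", atlas="atlas1", dir_path="/out/sub-01", ID="sub-01",
    network=None, labels=["A", "B", "C"], roi=None, thr=1.0,
    node_size="parc", smooth=0, hpass=None, extract_strategy="mean")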
Example 10
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, directget, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, cache_dir):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    tiss_class : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to tractography constraint mask in native diffusion space.
    B0_mask : str
        File path to B0 brain mask.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].
    cache_dir : str
        Path to directory used for joblib caching and temporary tracking
        files.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692

    """
    import os
    import gc
    import time
    import warnings
    from joblib import Parallel, delayed
    import itertools
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    warnings.filterwarnings("ignore")

    tmp_files_dir = f"{cache_dir}/tmp_files"
    joblib_dir = f"{cache_dir}/joblib_tracking"
    os.makedirs(tmp_files_dir, exist_ok=True)
    os.makedirs(joblib_dir, exist_ok=True)

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["nthreads"][0]
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{tmp_files_dir}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img("img > 0.0075", img=nib.load(waymask))
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50*float(len(all_combs)):
            with Parallel(n_jobs=nthreads,
                          backend='loky',
                          mmap_mode='r+',
                          temp_folder=joblib_dir,
                          verbose=0,
                          timeout=timeout) as parallel:
                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_path, n_seeds_per_iter, directget, maxcrossing,
                     max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, waymask, min_length, track_type,
                     min_separation_angle, sphere, tiss_class, tissues4d,
                     tmp_files_dir) for i in all_combs)

                out_streams = [
                    i for i in out_streams if i is not None and len(i) > 0
                ]

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)

                if len(out_streams) < min_streams:
                    ix += 2
                    print(f"Fewer than {min_streams} streamlines tracked "
                          f"on last iteration with cache directory: "
                          f"{cache_dir}. Loosening tolerance and "
                          f"anatomical constraints. Check {tissues4d} or "
                          f"{recon_path} for errors...")
                    # if track_type != 'particle':
                    #     tiss_class = 'wb'
                    roi_neighborhood_tol = float(roi_neighborhood_tol) * 1.25
                    # min_length = float(min_length) * 0.9875
                    continue
                else:
                    ix -= 1

                # Append streamline generators to prevent exponential growth
                # in memory consumption
                all_streams.extend([generate_sl(i) for i in out_streams])
                stream_counter += len(out_streams)
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)
        os.system(f"rm -rf {joblib_dir}/*")
    except BaseException:
        os.system(f"rm -rf {tmp_files_dir} &")
        return None

    if ix >= 0.75*len(all_combs) and \
            float(stream_counter) < float(target_samples):
        print(f"Tractography failed. >{len(all_combs)} consecutive sampling "
              f"iterations with few streamlines.")
        os.system(f"rm -rf {tmp_files_dir} &")
        return None
    else:
        os.system(f"rm -rf {tmp_files_dir} &")
        print("Tracking Complete: ", str(time.time() - start))

    del parallel, all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ArraySequence...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print('No streamlines generated!')
        return None
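
The ensemble grid driving the while-loop above is simply the Cartesian product of step sizes and curvature thresholds; each pairing becomes one run_tracking job. A small sketch with example values:

import itertools

step_list = [0.2, 0.4, 0.6]   # example step sizes (mm)
curv_thr_list = [40, 30, 20]  # example curvature thresholds (degrees)
all_combs = list(itertools.product(step_list, curv_thr_list))
# 9 combinations: (0.2, 40), (0.2, 30), ..., (0.6, 20)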
Example 11
def build_multigraphs(est_path_iterlist, ID):
    """
    Constructs a multimodal multigraph for each available resolution of
    vertices.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy file containing graph.
    ID : str
        A subject id or other unique identifier.

    Returns
    -------
    multigraph_list_all : list
        List of multiplex graph dictionaries corresponding to
        each unique node resolution.
    graph_path_list_top : list
        List of lists consisting of pairs of most similar
        structural and functional connectomes for each unique node resolution.

    References
    ----------
    .. [1] Bullmore, E., & Sporns, O. (2009). Complex brain networks: Graph
      theoretical analysis of structural and functional systems.
      Nature Reviews Neuroscience. https://doi.org/10.1038/nrn2575
    .. [2] Vaiana, M., & Muldoon, S. F. (2018). Multilayer Brain Networks.
      Journal of Nonlinear Science. https://doi.org/10.1007/s00332-017-9436-8

    """
    import os
    import itertools
    import numpy as np
    from pathlib import Path
    from pynets.core.utils import flatten
    from pynets.stats.netmotifs import motif_matching
    from pynets.core.utils import load_runconfig

    raw_est_path_iterlist = list(
        set(
            [
                os.path.dirname(i) + '/raw' + os.path.basename(i).split(
                    "_thrtype")[0] + ".npy"
                for i in list(flatten(est_path_iterlist))
            ]
        )
    )

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print(
            "ERROR: available functional models not successfully extracted"
            " from runconfig.yaml"
        )
    try:
        struct_models = hardcoded_params["available_models"][
            "struct_models"]
    except KeyError:
        print(
            "ERROR: available structural models not successfully extracted"
            " from runconfig.yaml"
        )

    atlases = list(set([x.split("/")[-3].split("/")[0]
                        for x in raw_est_path_iterlist]))
    parcel_dict_func = dict.fromkeys(atlases)
    parcel_dict_dwi = dict.fromkeys(atlases)
    est_path_iterlist_dwi = list(
        set(
            [
                i
                for i in raw_est_path_iterlist
                if i.split("model-")[1].split("_")[0] in struct_models
            ]
        )
    )
    est_path_iterlist_func = list(
        set(
            [
                i
                for i in raw_est_path_iterlist
                if i.split("model-")[1].split("_")[0] in func_models
            ]
        )
    )

    if "_rsn" in ";".join(est_path_iterlist_func):
        func_subnets = list(
            set([i.split("_rsn-")[1].split("_")[0] for i in
                 est_path_iterlist_func])
        )
    else:
        func_subnets = []
    if "_rsn" in ";".join(est_path_iterlist_dwi):
        dwi_subnets = list(
            set([i.split("_rsn-")[1].split("_")[0] for i in
                 est_path_iterlist_dwi])
        )
    else:
        dwi_subnets = []

    dir_path = str(
        Path(
            os.path.dirname(
                est_path_iterlist_dwi[0])).parent.parent.parent)
    namer_dir = f"{dir_path}/graphs_multilayer"
    if not os.path.isdir(namer_dir):
        os.mkdir(namer_dir)

    name_list = []
    metadata_list = []
    multigraph_list_all = []
    graph_path_list_all = []
    for atlas in atlases:
        if len(func_subnets) >= 1:
            parcel_dict_func[atlas] = {}
            for sub_net in func_subnets:
                parcel_dict_func[atlas][sub_net] = []
        else:
            parcel_dict_func[atlas] = []

        if len(dwi_subnets) >= 1:
            parcel_dict_dwi[atlas] = {}
            for sub_net in dwi_subnets:
                parcel_dict_dwi[atlas][sub_net] = []
        else:
            parcel_dict_dwi[atlas] = []

        for graph_path in est_path_iterlist_dwi:
            if atlas in graph_path:
                if len(dwi_subnets) >= 1:
                    for sub_net in dwi_subnets:
                        if sub_net in graph_path:
                            parcel_dict_dwi[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_dwi[atlas].append(graph_path)

        for graph_path in est_path_iterlist_func:
            if atlas in graph_path:
                if len(func_subnets) >= 1:
                    for sub_net in func_subnets:
                        if sub_net in graph_path:
                            parcel_dict_func[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_func[atlas].append(graph_path)

        parcel_dict = {}
        # Create dictionary of all possible pairs of structural-functional
        # graphs for each unique resolution of vertices
        if len(dwi_subnets) >= 1 and len(func_subnets) >= 1:
            parcel_dict[atlas] = {}
            rsns = np.intersect1d(dwi_subnets, func_subnets).tolist()
            for rsn in rsns:
                parcel_dict[atlas][rsn] = list(set(itertools.product(
                    parcel_dict_dwi[atlas][rsn],
                    parcel_dict_func[atlas][rsn])))
                for paths in list(parcel_dict[atlas][rsn]):
                    [
                        name_list,
                        metadata_list,
                        multigraph_list_all,
                        graph_path_list_all,
                    ] = motif_matching(
                        paths,
                        ID,
                        atlas,
                        namer_dir,
                        name_list,
                        metadata_list,
                        multigraph_list_all,
                        graph_path_list_all,
                        rsn=rsn,
                    )
        else:
            parcel_dict[atlas] = list(set(itertools.product(
                parcel_dict_dwi[atlas], parcel_dict_func[atlas])))
            for paths in list(parcel_dict[atlas]):
                [
                    name_list,
                    metadata_list,
                    multigraph_list_all,
                    graph_path_list_all,
                ] = motif_matching(
                    paths,
                    ID,
                    atlas,
                    namer_dir,
                    name_list,
                    metadata_list,
                    multigraph_list_all,
                    graph_path_list_all,
                )

    graph_path_list_top = [list(i[0].values()) for i in graph_path_list_all]
    assert len(multigraph_list_all) == len(name_list) == len(metadata_list)

    return (
        multigraph_list_all,
        graph_path_list_top,
        len(name_list) * [namer_dir],
        name_list,
        metadata_list,
    )
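
A standalone sketch of the model-name parsing used above to split derivative paths into structural and functional pools (the model lists and paths are examples, not the values shipped with runconfig.yaml):

func_models, struct_models = ["corr", "partcorr"], ["csa", "csd"]
paths = ["/out/a1/graphs/g_model-corr_thrtype-prop.npy",
         "/out/a1/graphs/g_model-csd_thrtype-prop.npy"]
dwi_paths = [p for p in paths
             if p.split("model-")[1].split("_")[0] in struct_models]
func_paths = [p for p in paths
              if p.split("model-")[1].split("_")[0] in func_models]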
Example 12
def build_masetome(est_path_iterlist, ID):
    """
    Embeds structural-functional graph pairs into a common invariant subspace.

    Parameters
    ----------
    est_path_iterlist : list
        List of list of pairs of file paths (.npy) corresponding to
        structural and functional connectomes matched at a given node
        resolution.
    ID : str
        A subject id or other unique identifier.

    References
    ----------
    .. [1] Rosenthal, G., Váša, F., Griffa, A., Hagmann, P., Amico, E., Goñi,
      J., Sporns, O. (2018). Mapping higher-order relations between brain
      structure and function with embedded vector representations of
      connectomes. Nature Communications.
      https://doi.org/10.1038/s41467-018-04614-w

    """
    from pathlib import Path
    import os
    import numpy as np
    from pynets.core.utils import prune_suffices
    from pynets.stats.embeddings import _mase_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        import sys
        print(
            "ERROR: available gradient dimensionality presets not "
            "successfully extracted from runconfig.yaml"
        )
        sys.exit(1)

    out_paths = []
    for pairs in est_path_iterlist:
        pop_list = []
        for _file in pairs:
            mat = np.load(_file)
            if np.isfinite(mat).all():
                pop_list.append(mat)
        if len(pop_list) != len(pairs):
            continue
        atlas = prune_suffices(pairs[0].split("/")[-3])
        res = prune_suffices("_".join(pairs[0].split(
            "/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"
        out_path = _mase_embed(
            pop_list,
            atlas,
            pairs[0],
            ID,
            subgraph_name=subgraph, n_components=n_components)

        if out_path is not None:
            out_paths.append(out_path)
        else:
            # Add a null tmp file to prevent pool from breaking
            dir_path = str(Path(os.path.dirname(pairs[0])))
            namer_dir = f"{dir_path}/mplx_embeddings"
            if os.path.isdir(namer_dir) is False:
                os.makedirs(namer_dir, exist_ok=True)

            out_path = (
                f"{namer_dir}/gradient-MASE_{atlas}_{subgraph}"
                f"_{os.path.basename(pairs[0])}_NULL"
            )
            if not os.path.exists(out_path):
                os.mknod(out_path)
            out_paths.append(out_path)

    return out_paths
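
The NULL-placeholder branch above is a workflow-safety idiom: when an embedding fails, an empty sentinel file keeps pooled downstream nodes from crashing on a missing input. A minimal sketch of that pattern (the path is illustrative; note that os.mknod is POSIX-only):

import os

out_path = "/tmp/gradient-MASE_atlas1_all_nodes_graph.npy_NULL"  # example
if not os.path.exists(out_path):
    os.mknod(out_path)  # zero-byte sentinel consumed as a no-op downstream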
Example 13
    def tissue2dwi_align(self):
        """
        A function to perform alignment of ventricle ROI's from MNI
        space --> dwi and CSF from T1w space --> dwi. First generates and
        performs dwi space alignment of avoidance/waypoint masks for
        tractography. First creates ventricle ROI. Then creates transforms
        from stock MNI template to dwi space. For this to succeed, must first
        have called both t1w2dwi_align.
        """
        import sys
        import time
        import os.path as op
        import pkg_resources
        import indexed_gzip
        from pynets.core.utils import load_runconfig
        from nilearn.image import resample_to_img, math_img, index_img
        from nipype.utils.filemanip import fname_presuffix, copyfile
        from pynets.core.nodemaker import three_to_four_parcellation

        hardcoded_params = load_runconfig()
        tiss_class = hardcoded_params['tracking']["tissue_classifier"][0]

        fa_template_path = pkg_resources.resource_filename(
            "pynets", f"templates/standard/FA_{self.vox_size}.nii.gz")

        if sys.platform.startswith('win') is False:
            try:
                fa_template_img = nib.load(fa_template_path)
            except indexed_gzip.ZranError as e:
                print(
                    e, f"\nCannot load FA template. Do you have git-lfs "
                    f"installed?")
        else:
            try:
                fa_template_img = nib.load(fa_template_path)
            except ImportError as e:
                print(e, f"\nCannot load FA template. Do you have git-lfs ")

        mni_template_img = nib.load(self.input_mni_brain)

        if not np.allclose(fa_template_img.affine,
                           mni_template_img.affine) or \
                not np.allclose(fa_template_img.shape,
                                mni_template_img.shape):
            fa_template_img_res = resample_to_img(fa_template_img,
                                                  mni_template_img)
            nib.save(fa_template_img_res, self.fa_template_res)
        else:
            self.fa_template_res = fname_presuffix(fa_template_path,
                                                   suffix="_tmp",
                                                   newpath=op.dirname(
                                                       self.reg_path_img))
            copyfile(fa_template_path,
                     self.fa_template_res,
                     copy=True,
                     use_hardlink=False)

        # Register Lateral Ventricles and Corpus Callosum rois to t1w
        resample_to_img(nib.load(self.mni_atlas),
                        nib.load(self.input_mni_brain),
                        interpolation='nearest').to_filename(self.mni_roi_ref)

        roi_parcels = three_to_four_parcellation(self.mni_roi_ref)

        ventricle_roi = math_img("img1 + img2",
                                 img1=index_img(roi_parcels, 2),
                                 img2=index_img(roi_parcels, 13))

        self.mni_vent_loc = fname_presuffix(self.mni_vent_loc,
                                            suffix="_tmp",
                                            newpath=op.dirname(
                                                self.reg_path_img))
        ventricle_roi.to_filename(self.mni_vent_loc)
        del roi_parcels, ventricle_roi

        # Create transform from the HarvardOxford atlas in MNI to T1w.
        # This will be used to transform the ventricles to dwi space.
        regutils.align(
            self.mni_roi_ref,
            self.input_mni_brain,
            xfm=self.xfm_roi2mni_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=True,
            interp="spline",
            out=None,
        )

        # Apply the transform to bring the ventricle ROI into MNI space
        regutils.applyxfm(
            self.input_mni_brain,
            self.mni_vent_loc,
            self.xfm_roi2mni_init,
            self.vent_mask_mni,
        )
        time.sleep(0.5)
        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(
                self.t1w_brain,
                self.vent_mask_mni,
                self.vent_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
            time.sleep(0.5)

            if sys.platform.startswith('win') is False:
                try:
                    nib.load(self.corpuscallosum)
                except indexed_gzip.ZranError as e:
                    print(
                        e, f"\nCannot load Corpus Callosum ROI. "
                        f"Do you have git-lfs installed?")
            else:
                try:
                    nib.load(self.corpuscallosum)
                except ImportError as e:
                    print(
                        e, f"\nCannot load Corpus Callosum ROI. "
                        f"Do you have git-lfs installed?")

            regutils.apply_warp(
                self.t1w_brain,
                self.corpuscallosum,
                self.corpuscallosum_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
        else:
            regutils.applyxfm(self.vent_mask_mni, self.t1w_brain,
                              self.mni2t1_xfm, self.vent_mask_t1w)
            time.sleep(0.5)
            regutils.applyxfm(
                self.corpuscallosum,
                self.t1w_brain,
                self.mni2t1_xfm,
                self.corpuscallosum_mask_t1w,
            )
            time.sleep(0.5)

        # Applyxfm to map FA template image to T1w space
        regutils.applyxfm(self.t1w_brain, self.fa_template_res,
                          self.mni2t1_xfm, self.fa_template_t1w)
        time.sleep(0.5)

        # Applyxfm tissue maps to dwi space
        if self.t1w_brain_mask is not None:
            regutils.applyxfm(
                self.ap_path,
                self.t1w_brain_mask,
                self.t1wtissue2dwi_xfm,
                self.t1w_brain_mask_in_dwi,
            )
            time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.vent_mask_t1w,
                          self.t1wtissue2dwi_xfm, self.vent_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.csf_mask, self.t1wtissue2dwi_xfm,
                          self.csf_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.gm_mask, self.t1wtissue2dwi_xfm,
                          self.gm_in_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.wm_mask, self.t1wtissue2dwi_xfm,
                          self.wm_in_dwi)
        time.sleep(0.5)

        regutils.applyxfm(
            self.ap_path,
            self.corpuscallosum_mask_t1w,
            self.t1wtissue2dwi_xfm,
            self.corpuscallosum_dwi,
        )
        time.sleep(0.5)

        csf_thr = 0.95
        wm_thr = 0.05
        gm_thr = 0.05

        # Threshold WM to binary in dwi space
        nib.save(math_img(f"img > {wm_thr}", img=nib.load(self.wm_in_dwi)),
                 self.wm_in_dwi_bin)

        # Threshold GM to binary in dwi space
        nib.save(math_img(f"img > {gm_thr}", img=nib.load(self.gm_in_dwi)),
                 self.gm_in_dwi_bin)

        # Threshold CSF to binary in dwi space
        nib.save(math_img(f"img > {csf_thr}", img=nib.load(self.csf_mask_dwi)),
                 self.csf_mask_dwi_bin)

        # Mask the probabilistic WM map with its binarized version
        self.wm_in_dwi = regutils.apply_mask_to_image(self.wm_in_dwi,
                                                      self.wm_in_dwi_bin,
                                                      self.wm_in_dwi)

        # Mask the probabilistic GM map with its binarized version
        self.gm_in_dwi = regutils.apply_mask_to_image(self.gm_in_dwi,
                                                      self.gm_in_dwi_bin,
                                                      self.gm_in_dwi)

        # Mask the probabilistic CSF map with its binarized version
        self.csf_mask = regutils.apply_mask_to_image(self.csf_mask_dwi,
                                                     self.csf_mask_dwi_bin,
                                                     self.csf_mask_dwi)

        # Create ventricular CSF mask
        print("Creating Ventricular CSF mask...")
        math_img("(img1 + img2) > 0.0001",
                 img1=nib.load(self.csf_mask_dwi),
                 img2=nib.load(self.vent_mask_dwi)).to_filename(
                     self.vent_csf_in_dwi)

        print("Creating Corpus Callosum mask...")
        math_img("(img1*img2 - img3) > 0.0001",
                 img1=nib.load(self.corpuscallosum_dwi),
                 img2=nib.load(self.wm_in_dwi_bin),
                 img3=nib.load(self.vent_csf_in_dwi)).to_filename(
                     self.corpuscallosum_dwi)

        # Create GM-WM interface image
        math_img("((img1*img2 + img3)*img4) > 0.0001",
                 img1=nib.load(self.gm_in_dwi_bin),
                 img2=nib.load(self.wm_in_dwi_bin),
                 img3=nib.load(self.corpuscallosum_dwi),
                 img4=nib.load(self.B0_mask)).to_filename(
                     self.wm_gm_int_in_dwi)

        return
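
Note: the thresholding and mask-composition steps above rely entirely on
nilearn's math_img. Below is a minimal, self-contained sketch of that
pattern; the file paths are hypothetical and the thresholds mirror the
hardcoded values above.

# Sketch: binarize probabilistic tissue maps and compose them voxelwise
# (hypothetical inputs in a shared space; assumes nilearn is installed).
import nibabel as nib
from nilearn.image import math_img

wm_prob = nib.load("wm_prob.nii.gz")    # hypothetical WM probability map
csf_prob = nib.load("csf_prob.nii.gz")  # hypothetical CSF probability map

# Binarize each probability map at a fixed threshold
wm_bin = math_img("img > 0.05", img=wm_prob)
csf_bin = math_img("img > 0.95", img=csf_prob)

# Compose the binary masks with voxelwise arithmetic, then re-binarize
union_mask = math_img("(img1 + img2) > 0.0001", img1=wm_bin, img2=csf_bin)
union_mask.to_filename("vent_csf_union.nii.gz")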
Example No. 14
    def tissue2dwi_align(self):
        """
        A function to perform alignment of ventricle ROI's from MNI
        space --> dwi and CSF from T1w space --> dwi. It first generates
        and performs dwi space alignment of avoidance/waypoint masks for
        tractography: it creates the ventricle ROI, then creates
        transforms from the stock MNI template to dwi space. For this to
        succeed, t1w2dwi_align must have been called first.
        """
        import os
        import sys
        import time
        import os.path as op
        import indexed_gzip
        import pkg_resources
        from pynets.core.utils import load_runconfig
        from nilearn.image import math_img, resample_to_img

        hardcoded_params = load_runconfig()
        tiss_class = hardcoded_params['tracking']["tissue_classifier"][0]

        fa_template_path = pkg_resources.resource_filename(
            "pynets", f"templates/FA_{self.vox_size}.nii.gz")

        if sys.platform.startswith('win') is False:
            try:
                fa_template_img = nib.load(fa_template_path)
            except indexed_gzip.ZranError as e:
                print(
                    e, f"\nCannot load FA template. Do you have git-lfs "
                    f"installed?")
        else:
            try:
                fa_template_img = nib.load(fa_template_path)
            except ImportError as e:
                print(
                    e, "\nCannot load FA template. Do you have git-lfs "
                    "installed?")

        mni_template_img = nib.load(self.input_mni_brain)
        fa_template_img_res = resample_to_img(fa_template_img,
                                              mni_template_img)

        nib.save(fa_template_img_res, self.fa_template_res)

        # Register Lateral Ventricles and Corpus Callosum rois to t1w
        if not op.isfile(self.mni_atlas):
            raise FileNotFoundError("FSL atlas for ventricle reference not"
                                    " found!")

        # Create a transform aligning the MNI atlas to the MNI template
        # brain using flirt. This will later be used to transform the
        # ventricles to dwi space.
        regutils.align(
            self.mni_atlas,
            self.input_mni_brain,
            xfm=self.xfm_roi2mni_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=True,
            interp="spline",
            out=None,
        )
        time.sleep(0.5)

        if sys.platform.startswith('win') is False:
            try:
                nib.load(self.mni_vent_loc)
            except indexed_gzip.ZranError as e:
                print(
                    e, f"\nCannot load ventricle ROI. Do you have git-lfs "
                    f"installed?")
        else:
            try:
                nib.load(self.mni_vent_loc)
            except ImportError as e:
                print(
                    e, f"\nCannot load ventricle ROI. Do you have git-lfs "
                    f"installed?")

        # Apply the transform to bring the ventricle ROI into MNI space
        regutils.applyxfm(
            self.input_mni_brain,
            self.mni_vent_loc,
            self.xfm_roi2mni_init,
            self.vent_mask_mni,
        )
        time.sleep(0.5)
        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(
                self.t1w_brain,
                self.vent_mask_mni,
                self.vent_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
            time.sleep(0.5)

            if sys.platform.startswith('win') is False:
                try:
                    nib.load(self.corpuscallosum)
                except indexed_gzip.ZranError as e:
                    print(
                        e, f"\nCannot load Corpus Callosum ROI. "
                        f"Do you have git-lfs installed?")
            else:
                try:
                    nib.load(self.corpuscallosum)
                except ImportError as e:
                    print(
                        e, f"\nCannot load Corpus Callosum ROI. "
                        f"Do you have git-lfs installed?")

            regutils.apply_warp(
                self.t1w_brain,
                self.corpuscallosum,
                self.corpuscallosum_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
        else:
            regutils.applyxfm(self.vent_mask_mni, self.t1w_brain,
                              self.mni2t1_xfm, self.vent_mask_t1w)
            time.sleep(0.5)
            regutils.applyxfm(
                self.corpuscallosum,
                self.t1w_brain,
                self.mni2t1_xfm,
                self.corpuscallosum_mask_t1w,
            )
            time.sleep(0.5)

        # Applyxfm to map FA template image to T1w space
        regutils.applyxfm(self.t1w_brain, self.fa_template_res,
                          self.mni2t1_xfm, self.fa_template_t1w)
        time.sleep(0.5)

        # Applyxfm tissue maps to dwi space
        if self.t1w_brain_mask is not None:
            regutils.applyxfm(
                self.ap_path,
                self.t1w_brain_mask,
                self.t1wtissue2dwi_xfm,
                self.t1w_brain_mask_in_dwi,
            )
            time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.vent_mask_t1w,
                          self.t1wtissue2dwi_xfm, self.vent_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.csf_mask, self.t1wtissue2dwi_xfm,
                          self.csf_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.gm_mask, self.t1wtissue2dwi_xfm,
                          self.gm_in_dwi)
        time.sleep(0.5)
        regutils.applyxfm(self.ap_path, self.wm_mask, self.t1wtissue2dwi_xfm,
                          self.wm_in_dwi)
        time.sleep(0.5)

        regutils.applyxfm(
            self.ap_path,
            self.corpuscallosum_mask_t1w,
            self.t1wtissue2dwi_xfm,
            self.corpuscallosum_dwi,
        )
        time.sleep(0.5)

        if tiss_class == 'wb' or tiss_class == 'cmc':
            csf_thr = 0.50
            wm_thr = 0.15
            gm_thr = 0.10
        else:
            csf_thr = 0.99
            wm_thr = 0.10
            gm_thr = 0.075

        # Threshold WM to binary in dwi space
        thr_img = nib.load(self.wm_in_dwi)
        thr_img = math_img(f"img > {wm_thr}", img=thr_img)
        nib.save(thr_img, self.wm_in_dwi_bin)

        # Threshold GM to binary in dwi space
        thr_img = nib.load(self.gm_in_dwi)
        thr_img = math_img(f"img > {gm_thr}", img=thr_img)
        nib.save(thr_img, self.gm_in_dwi_bin)

        # Threshold CSF to binary in dwi space
        thr_img = nib.load(self.csf_mask_dwi)
        thr_img = math_img(f"img > {csf_thr}", img=thr_img)
        nib.save(thr_img, self.csf_mask_dwi_bin)

        # Mask the probabilistic WM map with its binarized version
        self.wm_in_dwi = regutils.apply_mask_to_image(self.wm_in_dwi,
                                                      self.wm_in_dwi_bin,
                                                      self.wm_in_dwi)
        time.sleep(0.5)
        # Mask the probabilistic GM map with its binarized version
        self.gm_in_dwi = regutils.apply_mask_to_image(self.gm_in_dwi,
                                                      self.gm_in_dwi_bin,
                                                      self.gm_in_dwi)
        time.sleep(0.5)
        # Mask the probabilistic CSF map with its binarized version
        self.csf_mask = regutils.apply_mask_to_image(self.csf_mask_dwi,
                                                     self.csf_mask_dwi_bin,
                                                     self.csf_mask_dwi)
        time.sleep(0.5)
        # Create ventricular CSF mask
        print("Creating Ventricular CSF mask...")
        os.system(f"fslmaths {self.vent_mask_dwi} -kernel sphere 10 -ero "
                  f"-bin {self.vent_mask_dwi}")
        time.sleep(1)
        os.system(f"fslmaths {self.csf_mask_dwi} -add {self.vent_mask_dwi} "
                  f"-bin {self.vent_csf_in_dwi}")
        time.sleep(1)
        print("Creating Corpus Callosum mask...")
        os.system(
            f"fslmaths {self.corpuscallosum_dwi} -mas {self.wm_in_dwi_bin} "
            f"-sub {self.vent_csf_in_dwi} "
            f"-bin {self.corpuscallosum_dwi}")
        time.sleep(1)
        # Create gm-wm interface image
        os.system(f"fslmaths {self.gm_in_dwi_bin} -mul {self.wm_in_dwi_bin} "
                  f"-add {self.corpuscallosum_dwi} "
                  f"-mas {self.B0_mask} -bin {self.wm_gm_int_in_dwi}")
        time.sleep(1)
        return
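
Note: the os.system calls above discard the exit status of each fslmaths
invocation. A minimal sketch of the first two mask operations via
subprocess.run, which raises on failure, is shown below (assumes fslmaths
is on PATH; all paths are hypothetical).

import subprocess

vent_mask_dwi = "vent_mask_dwi.nii.gz"      # hypothetical path
csf_mask_dwi = "csf_mask_dwi.nii.gz"        # hypothetical path
vent_csf_in_dwi = "vent_csf_in_dwi.nii.gz"  # hypothetical path

# Erode the ventricle mask with a 10 mm spherical kernel and binarize;
# check=True raises CalledProcessError on a non-zero exit status
subprocess.run(["fslmaths", vent_mask_dwi, "-kernel", "sphere", "10",
                "-ero", "-bin", vent_mask_dwi], check=True)

# Union of the CSF and ventricle masks, binarized
subprocess.run(["fslmaths", csf_mask_dwi, "-add", vent_mask_dwi,
                "-bin", vent_csf_in_dwi], check=True)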
Example No. 15
def direct_streamline_norm(streams, fa_path, ap_path, dir_path, track_type,
                           target_samples, conn_model, network, node_size,
                           dens_thresh, ID, roi, min_span_tree, disp_filt,
                           parc, prune, atlas, labels_im_file, uatlas, labels,
                           coords, norm, binary, atlas_t1w, basedir_path,
                           curv_thr_list, step_list, directget, min_length,
                           t1w_brain):
    """
    A function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    Parameters
    ----------
    streams : str
        File path to save streamline array sequence in .trk format.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_t1w : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    t1w_brain : str
        File path to the T1w Nifti1Image.

    Returns
    -------
    streams_warp : str
        File path to normalized streamline array sequence in .trk format.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in the same
        morphological space as the streamlines.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and structural
      brain networks. Network Neuroscience, 1-19.
    """
    import sys
    import gc
    import os
    import os.path as op
    import numpy as np
    import nibabel as nib
    import pkg_resources
    from dipy.tracking.streamline import transform_streamlines
    from pynets.registration import utils as regutils
    # from pynets.plotting import plot_gen
    from pynets.registration.utils import vdc
    from nilearn.image import resample_to_img
    from dipy.io.streamline import load_tractogram, save_tractogram
    from dipy.tracking import utils
    from dipy.tracking._utils import _mapping_to_voxel
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from pynets.core.utils import load_runconfig

    # from pynets.core.utils import missing_elements

    hardcoded_params = load_runconfig()
    try:
        run_dsn = hardcoded_params['tracking']["DSN"][0]
    except KeyError as e:
        print(e, "Failed to parse runconfig.yaml")
        run_dsn = False

    if run_dsn is True:
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        if not op.isdir(dsn_dir):
            os.mkdir(dsn_dir)

        namer_dir = f"{dir_path}/tractography"
        if not op.isdir(namer_dir):
            os.mkdir(namer_dir)

        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)
        vox_size = fa_img.header.get_zooms()[0]

        atlas_for_streams = atlas_t1w

        atlas_t1w_img = nib.load(atlas_t1w)
        t1w_brain_img = nib.load(t1w_brain)
        brain_mask = np.asarray(t1w_brain_img.dataobj).astype("bool")

        streams_t1w = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/streamlines_t1w_",
            "%s" % (network + "_" if network is not None else ""),
            "%s" %
            (op.basename(roi).split(".")[0] + "_" if roi is not None else ""),
            conn_model,
            "_",
            target_samples,
            "%s" % ("%s%s" % ("_" + str(node_size), "mm_") if
                    ((node_size != "parc") and
                     (node_size is not None)) else "_"),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_directget-",
            directget,
            "_minlength-",
            min_length,
            ".trk",
        )

        density_t1w = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/density_map_t1w_",
            "%s" % (network + "_" if network is not None else ""),
            "%s" %
            (op.basename(roi).split(".")[0] + "_" if roi is not None else ""),
            conn_model,
            "_",
            target_samples,
            "%s" % ("%s%s" % ("_" + str(node_size), "mm_") if
                    ((node_size != "parc") and
                     (node_size is not None)) else "_"),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_directget-",
            directget,
            "_minlength-",
            min_length,
            ".nii.gz",
        )

        # streams_warp_png = '/tmp/dsn.png'

        # Nonlinear SyN registration between the T1w brain and the AP map
        [mapping, affine_map,
         warped_fa] = regutils.wm_syn(t1w_brain, ap_path, dsn_dir)

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        adjusted_affine = affine_map.affine.copy()
        adjusted_affine[1][3] = -adjusted_affine[1][3]
        adjusted_affine[2][3] = -adjusted_affine[2][3] * 0.95

        streams_in_curr_grid = transform_streamlines(streamlines,
                                                     warped_fa_affine)

        streams_final_filt = regutils.warp_streamlines(adjusted_affine,
                                                       fa_img.affine, mapping,
                                                       warped_fa_img,
                                                       streams_in_curr_grid,
                                                       brain_mask)

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=atlas_t1w_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_t1w, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        # plot_gen.show_template_bundles(streams_final_filt_final, atlas_t1w,
        # streams_warp_png)

        # Create and save MNI density map
        nib.save(
            nib.Nifti1Image(
                utils.density_map(streams_final_filt_final,
                                  affine=np.eye(4),
                                  vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_t1w,
        )

        # Map parcellation from native space back to MNI-space and create an
        # 'uncertainty-union' parcellation with original mni-space uatlas
        warped_uatlas = affine_map.transform_inverse(
            mapping.transform(
                np.asarray(atlas_img.dataobj).astype("int"),
                interpolation="nearestneighbour",
            ),
            interp="nearest",
        )
        atlas_img.uncache()
        warped_uatlas_img_res_data = np.asarray(
            resample_to_img(
                nib.Nifti1Image(warped_uatlas, affine=warped_fa_affine),
                atlas_t1w_img,
                interpolation="nearest",
                clip=False,
            ).dataobj)
        uatlas_t1w_data = np.asarray(atlas_t1w_img.dataobj)
        atlas_t1w_img.uncache()
        overlap_mask = np.invert(
            warped_uatlas_img_res_data.astype("bool") *
            uatlas_t1w_data.astype("bool"))
        os.makedirs(f"{dir_path}/parcellations", exist_ok=True)
        atlas_for_streams = f"{dir_path}/parcellations/" \
                            f"{op.basename(uatlas).split('.nii')[0]}" \
                            f"_t1w_liberal.nii.gz"

        nib.save(
            nib.Nifti1Image(
                warped_uatlas_img_res_data * overlap_mask.astype("int") +
                uatlas_t1w_data * overlap_mask.astype("int") +
                np.invert(overlap_mask).astype("int") *
                warped_uatlas_img_res_data.astype("int"),
                affine=atlas_t1w_img.affine,
            ),
            atlas_for_streams,
        )

        del (
            tractogram,
            streamlines,
            warped_uatlas_img_res_data,
            uatlas_t1w_data,
            overlap_mask,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        assert len(coords) == len(labels)

    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_t1w = streams
        warped_fa = fa_path
        atlas_for_streams = labels_im_file

    return (streams_t1w, dir_path, track_type, target_samples, conn_model,
            network, node_size, dens_thresh, ID, roi, min_span_tree, disp_filt,
            parc, prune, atlas, uatlas, labels, coords, norm, binary,
            atlas_for_streams, directget, warped_fa, min_length)
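
Note: the core of the normalization above is DIPY's streamline transform
plus a round-trip through StatefulTractogram. Below is a minimal sketch of
just the affine step under hypothetical inputs; the full function also
applies the nonlinear SyN warp via regutils.warp_streamlines.

import numpy as np
import nibabel as nib
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
from dipy.tracking.streamline import transform_streamlines

ref_img = nib.load("template.nii.gz")  # hypothetical reference image
sft = load_tractogram("native.trk", ref_img, to_space=Space.VOXMM,
                      to_origin=Origin.NIFTI, bbox_valid_check=False)

affine = np.eye(4)  # stand-in for a real affine mapping
moved = transform_streamlines(sft.streamlines, affine)

out = StatefulTractogram(moved, reference=ref_img, space=Space.VOXMM,
                         origin=Origin.NIFTI)
out.remove_invalid_streamlines()
save_tractogram(out, "normalized.trk", bbox_valid_check=True)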
Example No. 16
def make_subject_dict(
    modalities, base_dir, thr_type, mets, embedding_types, template, sessions,
    rsns):
    from joblib.externals.loky import get_reusable_executor
    from joblib import Parallel, delayed
    from pynets.core.utils import mergedicts
    from pynets.core.utils import load_runconfig
    import os
    import glob
    import tempfile
    import psutil
    import shutil
    import gc
    import pandas as pd

    hardcoded_params = load_runconfig()
    embedding_methods = hardcoded_params["embed"]
    metaparams_func = hardcoded_params["metaparams_func"]
    metaparams_dwi = hardcoded_params["metaparams_dwi"]

    miss_frames_all = []
    subject_dict_all = {}
    modality_grids = {}
    for modality in modalities:
        print(f"MODALITY: {modality}")
        metaparams = {"func": metaparams_func,
                      "dwi": metaparams_dwi}[modality]
        for alg in embedding_types:
            print(f"EMBEDDING TYPE: {alg}")
            for ses_name in sessions:
                ids = [
                    f"{os.path.basename(i)}_ses-{ses_name}"
                    for i in glob.glob(f"{base_dir}/pynets/*")
                    if os.path.basename(i).startswith("sub")
                ]

                if alg != "topology" and alg in embedding_methods:
                    df_top = None
                    ensembles = get_ensembles_embedding(modality, alg,
                                                        base_dir)
                    if ensembles is None:
                        print("No ensembles found.")
                        continue
                elif alg == "topology":
                    ensembles, df_top = get_ensembles_top(
                        modality, thr_type, f"{base_dir}/pynets"
                    )
                    if "missing" in df_top.columns:
                        df_top.drop(columns="missing", inplace=True)

                    if ensembles is None or df_top is None:
                        print("Missing topology outputs.")
                        continue
                else:
                    continue

                ensembles = list(set([i for i in ensembles if i is not None]))

                hyperparam_dict = {}

                grid = build_grid(
                    modality, hyperparam_dict, sorted(list(set(metaparams))),
                    ensembles)[1]

                grid = list(set([i for i in grid if i != () and
                                 len(list(i)) > 0]))

                modality_grids[modality] = grid

                par_dict = subject_dict_all.copy()
                cache_dir = tempfile.mkdtemp()

                with Parallel(
                    n_jobs=-1,
                    backend='loky',
                    verbose=1,
                    max_nbytes=f"{int(psutil.virtual_memory().free / len(ids))}M",
                    temp_folder=cache_dir,
                ) as parallel:
                    outs_tup = parallel(
                        delayed(populate_subject_dict)(
                            id,
                            modality,
                            grid,
                            par_dict,
                            alg,
                            base_dir,
                            template,
                            thr_type,
                            embedding_methods,
                            mets,
                            df_top
                        )
                        for id in ids
                    )
                outs = [i[0] for i in outs_tup]
                miss_frames = [i[1] for i in outs_tup if not i[1].empty]
                if len(miss_frames) > 1:
                    miss_frames = pd.concat(miss_frames)
                miss_frames_all.append(miss_frames)
                for d in outs:
                    subject_dict_all = dict(mergedicts(subject_dict_all, d))
                shutil.rmtree(cache_dir, ignore_errors=True)
                get_reusable_executor().shutdown(wait=True)
                del par_dict, outs_tup, outs, df_top, miss_frames, ses_name, \
                    grid, hyperparam_dict, parallel
                gc.collect()
            del alg
        del metaparams
    del modality
    gc.collect()

    return subject_dict_all, modality_grids, miss_frames_all
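
Note: the function above follows a common fan-out/merge pattern with
joblib. Below is a minimal sketch of that pattern with a hypothetical
per-subject worker; the shallow dict update stands in for pynets'
recursive mergedicts.

import shutil
import tempfile
from joblib import Parallel, delayed

def process_subject(subject_id):
    # Hypothetical worker returning a (dict, missingness) pair
    return {subject_id: {"ok": True}}, None

ids = ["sub-01_ses-1", "sub-02_ses-1"]
cache_dir = tempfile.mkdtemp()
try:
    # loky spawns worker processes; temp_folder holds memmapped arrays
    with Parallel(n_jobs=-1, backend="loky",
                  temp_folder=cache_dir) as parallel:
        outs = parallel(delayed(process_subject)(i) for i in ids)
finally:
    shutil.rmtree(cache_dir, ignore_errors=True)

merged = {}
for d, _ in outs:
    merged.update(d)  # stand-in for the recursive mergedicts above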
Example No. 17
def create_feature_space(base_dir, df, grid_param, subject_dict, ses,
                         modality, alg, mets=None):
    from colorama import Fore, Style
    import numpy as np
    import pandas as pd
    from pynets.core.utils import load_runconfig
    df_tmps = []

    hardcoded_params = load_runconfig()
    embedding_methods = hardcoded_params["embed"]

    for ID in df["participant_id"]:
        if ID not in subject_dict.keys():
            print(f"ID: {ID} not found...")
            continue

        if str(ses) not in subject_dict[ID].keys():
            print(f"Session: {ses} not found for ID {ID}...")
            continue

        if modality not in subject_dict[ID][str(ses)].keys():
            print(f"Modality: {modality} not found for ID {ID}, "
                  f"ses-{ses}...")
            continue

        if alg not in subject_dict[ID][str(ses)][modality].keys():
            print(
                f"Embedding: {alg} not found for ID {ID}, ses-{ses}, "
                f"modality {modality}..."
            )
            continue

        if grid_param not in subject_dict[ID][str(ses)][modality][alg].keys():
            print(
                f"Grid param {grid_param} not found for ID {ID}, ses-{ses}, "
                f"{alg} and {modality}..."
            )
            continue

        if alg != "topology" and alg in embedding_methods:
            print(f"{Fore.GREEN}✓{Style.RESET_ALL} Grid Param: {grid_param} "
                  f"found for {ID}")
            df_lps = flatten_latent_positions(
                base_dir, subject_dict, ID, ses, modality, grid_param, alg
            )
        else:
            if grid_param in subject_dict[ID][str(ses)][modality][alg].keys():
                df_lps = pd.DataFrame(
                    subject_dict[ID][str(ses)][modality][alg][grid_param].T,
                    columns=mets,
                )
            else:
                df_lps = None

        if df_lps is not None:
            df_tmp = (
                df[df["participant_id"] == ID]
                .reset_index()
                .drop(columns="index")
                .join(df_lps, how="right")
            )
            df_tmps.append(df_tmp)
            del df_tmp
        else:
            print(f"Feature-space null for ID {ID} & ses-{ses}, modality: "
                  f"{modality}, embedding: {alg}...")
            continue

    if len(df_tmps) > 0:
        dfs = [dff.set_index("participant_id")
               for dff in df_tmps if not dff.empty]
        df_all = pd.concat(dfs, axis=0)
        df_all = df_all.replace({0: np.nan})
        # df_all = df_all.apply(lambda x: np.where(x < 0.00001, np.nan, x))
        #print(len(df_all))
        del df_tmps
        return df_all, grid_param
    else:
        return pd.Series(np.nan), grid_param
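
Note: the frame assembly above joins one row of phenotype data to one row
of embedding features per subject and then stacks the rows. A minimal
sketch with hypothetical column names and values:

import numpy as np
import pandas as pd

df = pd.DataFrame({"participant_id": ["sub-01", "sub-02"],
                   "age": [25, 31]})
feature_rows = {"sub-01": [0.1, 0.4], "sub-02": [0.0, 0.7]}

df_tmps = []
for ID, feats in feature_rows.items():
    df_lps = pd.DataFrame([feats], columns=["met1", "met2"])
    # Both frames have index 0, so the row-wise join aligns one-to-one
    df_tmp = (df[df["participant_id"] == ID]
              .reset_index(drop=True)
              .join(df_lps, how="right"))
    df_tmps.append(df_tmp)

df_all = pd.concat([d.set_index("participant_id") for d in df_tmps],
                   axis=0)
df_all = df_all.replace({0: np.nan})  # treat exact zeros as missing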
Example No. 18
def streams2graph(atlas_for_streams, streams, dir_path, track_type, conn_model,
                  subnet, node_radius, dens_thresh, ID, roi, min_span_tree,
                  disp_filt, parc, prune, atlas, parcellation, labels, coords,
                  norm, binary, traversal, warped_fa, min_length,
                  error_margin):
    """
    Use tracked streamlines as a basis for estimating a structural connectome.

    Parameters
    ----------
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    streams : str
        File path to streamline array sequence in .trk format.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    traversal : str
        The statistical approach to tracking. Options are:
        det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a
        connection to an ROI. Default is 2 voxels.

    Returns
    -------
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    streams : str
        File path to streamline array sequence in .trk format.
    conn_matrix : array
        Adjacency matrix stored as an m x n array of nodes and edges.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    dir_path : str
        Path to directory containing subject derivative data for given run.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a
        connection to an ROI. Default is 2 voxels.

    References
    ----------
    .. [1] Sporns, O., Tononi, G., & Kötter, R. (2005). The human connectome:
      A structural description of the human brain. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.0010042
    .. [2] Sotiropoulos, S. N., & Zalesky, A. (2019). Building connectomes
      using diffusion MRI: why, how and but. NMR in Biomedicine.
      https://doi.org/10.1002/nbm.3752
    .. [3] Chung, M. K., Hanson, J. L., Adluru, N., Alexander, A. L., Davidson,
      R. J., & Pollak, S. D. (2017). Integrative Structural Brain Network
      Analysis in Diffusion Tensor Imaging. Brain Connectivity.
      https://doi.org/10.1089/brain.2016.0481
    """
    import gc
    import time
    import numpy as np
    import nibabel as nib
    import networkx as nx
    from itertools import combinations
    from collections import defaultdict
    from dipy.tracking.streamline import Streamlines, values_from_volume
    from dipy.tracking._utils import _mapping_to_voxel, _to_voxel_coordinates
    from dipy.io.streamline import load_tractogram
    from dipy.io.stateful_tractogram import Space, Origin
    from pynets.core import utils, nodemaker
    from pynets.dmri.utils import generate_sl
    from pynets.core.utils import load_runconfig

    hardcoded_params = load_runconfig()
    fa_wei = hardcoded_params["StructuralNetworkWeighting"]["fa_weighting"][0]
    fiber_density = hardcoded_params["StructuralNetworkWeighting"][
        "fiber_density"][0]
    overlap_thr = hardcoded_params["StructuralNetworkWeighting"][
        "overlap_thr"][0]
    roi_neighborhood_tol = \
        hardcoded_params['tracking']["roi_neighborhood_tol"][0]

    start = time.time()

    if float(roi_neighborhood_tol) <= float(error_margin):
        raise ValueError('roi_neighborhood_tol preset cannot be less than '
                         'or equal to the structural connectome '
                         'error_margin parameter.')
    else:
        print(f"Using fiber-roi intersection tolerance: {error_margin}...")

    # Load FA
    fa_img = nib.load(warped_fa)

    # Load parcellation
    roi_img = nib.load(atlas_for_streams)
    atlas_data = np.around(np.asarray(roi_img.dataobj))
    roi_zooms = roi_img.header.get_zooms()
    roi_shape = roi_img.shape

    # Read Streamlines
    if streams is not None:
        streamlines = [
            i.astype(np.float32) for i in Streamlines(
                load_tractogram(streams,
                                fa_img,
                                to_origin=Origin.NIFTI,
                                to_space=Space.VOXMM).streamlines)
        ]

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_filtered = []
        neg_vox = False
        for sl in streamlines:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_filtered.append(sl)
            else:
                neg_vox = True

        if neg_vox is True:
            print(UserWarning("Negative voxel indices detected! " "Check FOV"))

        streamlines = streams_filtered
        del streams_filtered
        # from fury import actor, window, colormap
        # renderer = window.Renderer()
        # template_actor = actor.contour_from_roi(roi_img.get_fdata(),
        #                                         color=(50, 50, 50),
        #                                         opacity=1)
        # renderer.add(template_actor)
        # lines_actor = actor.line(streamlines,
        #                                colormap.line_colors(streamlines))
        # renderer.add(lines_actor)
        # window.show(renderer)
        #
        # roi_img.uncache()

        if fa_wei is True:
            fa_weights = values_from_volume(
                np.asarray(fa_img.dataobj, dtype=np.float32), streamlines,
                np.eye(4))
            global_fa_weights = list(utils.flatten(fa_weights))
            min_global_fa_wei = min([i for i in global_fa_weights if i > 0])
            max_global_fa_wei = max(global_fa_weights)
            fa_weights_norm = []
            # Here we normalize by global FA
            for val_list in fa_weights:
                fa_weights_norm.append(
                    np.nanmean((val_list - min_global_fa_wei) /
                               (max_global_fa_wei - min_global_fa_wei)))

        # Make streamlines into generators to keep memory at a minimum
        total_streamlines = len(streamlines)
        sl = [generate_sl(i) for i in streamlines]
        del streamlines
        gc.collect()

        # Instantiate empty networkX graph object & dictionary and create
        # voxel-affine mapping
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        mx = len(np.unique(atlas_data.astype("uint16"))) - 1
        g = nx.Graph(ecount=0, vcount=mx)
        edge_dict = defaultdict(int)
        node_dict = dict(
            zip(np.unique(atlas_data.astype("uint16"))[1:],
                np.arange(mx) + 1))

        # Add empty vertices with label volume attributes
        for node in range(1, mx + 1):
            g.add_node(node,
                       roi_volume=np.sum(atlas_data.astype("uint16") == node))

        # Build graph
        pc = 0
        bad_idxs = []
        fiberlengths = {}
        fa_weights_dict = {}
        print(f"Quantifying fiber-ROI intersection for {atlas}:")
        for ix, s in enumerate(sl):
            # Percent counter
            pcN = int(round(100 * float(ix / total_streamlines)))
            if pcN % 10 == 0 and ix > 0 and pcN > pc:
                pc = pcN
                print(f"{pcN}%")

            # Map the streamlines coordinates to voxel coordinates and get
            # labels for label_volume
            s = Streamlines(s)
            if s.data.shape[0] == 0:
                continue
            vox_coords = _to_voxel_coordinates(s, lin_T, offset)

            [i, j, k] = np.vstack(
                np.array([
                    nodemaker.get_sphere(coord, error_margin, roi_zooms,
                                         roi_shape) for coord in vox_coords
                ])).T

            # get labels for label_volume
            lab_arr = atlas_data[i, j, k]
            # print(lab_arr)
            endlabels = []
            for jx, lab in enumerate(np.unique(lab_arr).astype("uint32")):
                if (lab > 0) and (np.sum(lab_arr == lab) >= overlap_thr):
                    try:
                        endlabels.append(node_dict[lab])
                    except BaseException:
                        bad_idxs.append(jx)
                        print(f"Label {lab} missing from parcellation. Check "
                              f"registration and ensure valid input "
                              f"parcellation file.")

            for edge in combinations(endlabels, 2):
                # Get fiber lengths along edge
                if fiber_density is True:
                    if not (edge[0], edge[1]) in fiberlengths.keys():
                        fiberlengths[(edge[0], edge[1])] = [len(vox_coords)]
                    else:
                        fiberlengths[(edge[0],
                                      edge[1])].append(len(vox_coords))

                # Get FA values along edge
                if fa_wei is True:
                    if not (edge[0], edge[1]) in fa_weights_dict.keys():
                        fa_weights_dict[(edge[0],
                                         edge[1])] = [fa_weights_norm[ix]]
                    else:
                        fa_weights_dict[(edge[0],
                                         edge[1])].append(fa_weights_norm[ix])

                edge_dict[tuple(sorted(tuple([int(node)
                                              for node in edge])))] += 1

            del lab_arr, endlabels
            gc.collect()

        # Add the accumulated edge counts to the graph once, after the
        # loop, rather than re-adding every edge on each iteration
        g.add_weighted_edges_from([(k[0], k[1], count)
                                   for k, count in edge_dict.items()])

        del sl
        gc.collect()

        # Add fiber density attributes for each edge
        # Adapted from the normalized fiber-density estimation routines of
        # Sebastian Tourbier.
        if fiber_density is True:
            print("Redefining edges on the basis of fiber density...")
            # Summarize total fibers and total label volumes. Edge weights
            # hold streamline counts, so the fiber total is the sum of
            # weights (len(d) would merely count edge attributes).
            total_fibers = 0
            total_volume = 0
            u_start = -1
            for u, v, d in g.edges(data=True):
                total_fibers += d['weight']
                if u != u_start:
                    total_volume += g.nodes[int(u)]['roi_volume']
                u_start = u

            for u, v, d in g.edges(data=True):
                if d['weight'] > 0:
                    # Use a distinct name so the `fiber_density` flag used
                    # below is not clobbered by the per-edge value
                    edge_fiber_density = (float(
                        ((float(d['weight']) / float(total_fibers)) /
                         float(np.nanmean(fiberlengths[(u, v)]))) *
                        ((2.0 * float(total_volume)) /
                         (g.nodes[int(u)]['roi_volume'] +
                          g.nodes[int(v)]['roi_volume'])))) * 1000
                else:
                    edge_fiber_density = 0
                g.edges[u, v].update({"fiber_density": edge_fiber_density})

        if fa_wei is True:
            print("Re-weighting edges by mean FA along each edge's associated "
                  "bundles...")
            # Add FA attributes for each edge
            for u, v, d in g.edges(data=True):
                if d['weight'] > 0:
                    edge_average_fa = np.nanmean(fa_weights_dict[(u, v)])
                else:
                    edge_average_fa = np.nan
                g.edges[u, v].update({"fa_weight": edge_average_fa})

        # Summarize weights
        if fa_wei is True and fiber_density is True:
            for u, v, d in g.edges(data=True):
                g.edges[u, v].update(
                    {"final_weight": (d['fa_weight']) * d['fiber_density']})
        elif fiber_density is True and fa_wei is False:
            for u, v, d in g.edges(data=True):
                g.edges[u, v].update({"final_weight": d['fiber_density']})
        elif fa_wei is True and fiber_density is False:
            for u, v, d in g.edges(data=True):
                g.edges[u, v].update(
                    {"final_weight": d['fa_weight'] * d['weight']})
        else:
            for u, v, d in g.edges(data=True):
                g.edges[u, v].update({"final_weight": d['weight']})

        # Convert weighted graph to numpy matrix
        conn_matrix_raw = nx.to_numpy_array(g, weight='final_weight')

        # Enforce symmetry
        conn_matrix = np.maximum(conn_matrix_raw, conn_matrix_raw.T)

        print("Structural graph completed:\n", str(time.time() - start))

        if len(bad_idxs) > 0:
            bad_idxs = sorted(list(set(bad_idxs)), reverse=True)
            for j in bad_idxs:
                del labels[j], coords[j]
    else:
        print(
            UserWarning('No valid streamlines detected. '
                        'Proceeding with an empty graph...'))
        mx = len(np.unique(atlas_data.astype("uint16"))) - 1
        conn_matrix = np.zeros((mx, mx))

    assert len(coords) == len(labels) == conn_matrix.shape[0]

    if subnet is not None:
        atlas_name = f"{atlas}_{subnet}_stage-rawgraph"
    else:
        atlas_name = f"{atlas}_stage-rawgraph"

    utils.save_coords_and_labels_to_json(coords,
                                         labels,
                                         dir_path,
                                         atlas_name,
                                         indices=None)

    coords = np.array(coords)
    labels = np.array(labels)

    if parc is True:
        node_radius = "parc"

    # Save unthresholded
    utils.save_mat(
        conn_matrix,
        utils.create_raw_path_diff(ID, subnet, conn_model, roi, dir_path,
                                   node_radius, track_type, parc, traversal,
                                   min_length, error_margin),
    )

    return (atlas_for_streams, streams, conn_matrix, track_type, dir_path,
            conn_model, subnet, node_radius, dens_thresh, ID, roi,
            min_span_tree, disp_filt, parc, prune, atlas, parcellation, labels,
            coords, norm, binary, traversal, min_length, error_margin)
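
Note: stripped of I/O and the optional FA/density weighting, streams2graph
reduces to counting, for each streamline, every pair of parcels that the
streamline intersects. A minimal sketch of that counting scheme with toy
label hits:

from collections import defaultdict
from itertools import combinations
import networkx as nx
import numpy as np

# Hypothetical per-streamline parcel hits, already filtered by overlap_thr
streamline_labels = [[1, 4], [1, 4, 7], [2, 7]]

edge_dict = defaultdict(int)
for labels_hit in streamline_labels:
    # Each intersected parcel pair gains one count per streamline
    for edge in combinations(sorted(set(labels_hit)), 2):
        edge_dict[edge] += 1

g = nx.Graph()
g.add_weighted_edges_from((u, v, w) for (u, v), w in edge_dict.items())

conn_matrix = nx.to_numpy_array(g, weight="weight")
conn_matrix = np.maximum(conn_matrix, conn_matrix.T)  # enforce symmetry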
Example No. 19
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    """
    from pathlib import Path
    import os
    import numpy as np
    from pynets.core.utils import prune_suffices, flatten
    from pynets.stats.embeddings import _ase_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        import sys
        print(
            "ERROR: available gradient dimensionality presets not "
            "sucessfully extracted from runconfig.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        if not np.isfinite(mat).all():
            continue

        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(file_.split(
            "/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"
        out_path = _ase_embed(mat, atlas, file_, ID, subgraph_name=subgraph,
                              n_components=n_components)
        if out_path is not None:
            out_paths.append(out_path)
        else:
            # Add a null tmp file to prevent pool from breaking
            dir_path = str(Path(os.path.dirname(file_)).parent)
            namer_dir = f"{dir_path}/embeddings"
            if os.path.isdir(namer_dir) is False:
                os.makedirs(namer_dir, exist_ok=True)
            out_path = f"{namer_dir}/gradient-ASE" \
                       f"_{atlas}_{subgraph}_{os.path.basename(file_)}_NULL"
            if not os.path.exists(out_path):
                os.mknod(out_path)
            out_paths.append(out_path)

    return out_paths
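
Note: a plausible core of the _ase_embed helper is graspologic's
Adjacency Spectral Embedding. The sketch below uses a toy symmetric
matrix and is an assumption about the wrapped call, not pynets' exact
implementation; pynets reads n_components from its runconfig presets.

import numpy as np
from graspologic.embed import AdjacencySpectralEmbed

rng = np.random.default_rng(0)
A = rng.random((10, 10))
A = np.maximum(A, A.T)   # toy symmetric adjacency matrix
np.fill_diagonal(A, 0)

ase = AdjacencySpectralEmbed(n_components=2)
latent_positions = ase.fit_transform(A)  # shape: (n_nodes, 2)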
Example No. 20
def build_asetomes(est_path_iterlist):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.

    """
    from pathlib import Path
    import os
    import numpy as np
    from pynets.statistics.individual.spectral import _ase_embed
    from pynets.core.utils import prune_suffices, flatten, load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        import sys
        print(
            "ERROR: available gradient dimensionality presets not "
            "sucessfully extracted from advanced.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        if not np.isfinite(mat).all():
            continue

        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(file_.split(
            "/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
        if "subnet" in res:
            subgraph = res.split("subnet-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"

            out_path = _ase_embed(mat, atlas, file_, subgraph_name=subgraph,
                                  n_components=n_components, prune=0, norm=1)

        if out_path is not None:
            out_paths.append(out_path)
        else:
            # Add a null tmp file to prevent pool from breaking
            dir_path = str(Path(os.path.dirname(file_)).parent)
            namer_dir = f"{dir_path}/embeddings"
            if not os.path.isdir(namer_dir):
                os.makedirs(namer_dir, exist_ok=True)
            out_path = f"{namer_dir}/gradient-ASE" \
                       f"_subnet-{atlas}_granularity-{subgraph}_" \
                       f"{os.path.basename(file_)}_NULL"
            # TODO: Replace this band-aid solution with the real fix
            out_path = out_path.replace('subnet-subnet-',
                                        'subnet-').replace(
                'granularity-granularity-', 'granularity-')
            if not os.path.exists(out_path):
                open(out_path, 'w').close()
            out_paths.append(out_path)

    return out_paths
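
The atlas and subgraph labels above are recovered purely from the path
structure (the real code additionally strips suffices with
prune_suffices). A quick trace of that parsing with a hypothetical path:

file_ = ("/output/sub-01/DesikanKlein2012/func/"
         "graph_modality-func_subnet-Default_est-corr.npy")
atlas = file_.split("/")[-3]                      # 'DesikanKlein2012'
res = "_".join(file_.split("/")[-1].split("modality")[1]
               .split("_")[1:]).split("_est")[0]  # 'subnet-Default'
subgraph = res.split("subnet-")[1].split('_')[0]  # 'Default'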
Example 21
def build_omnetome(est_path_iterlist, ID):
    """
    Embeds an ensemble population of graphs into a common feature vector
    via omnibus embedding.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    References
    ----------
    .. [1] Liu, Y., He, L., Cao, B., Yu, P. S., Ragin, A. B., & Leow, A. D.
      (2018). Multi-view multi-graph embedding for brain network clustering
      analysis. 32nd AAAI Conference on Artificial Intelligence, AAAI 2018.
    .. [2] Levin, K., Athreya, A., Tang, M., Lyzinski, V., & Priebe, C. E.
      (2017, November). A central limit theorem for an omnibus embedding of
      multiple random dot product graphs. In Data Mining Workshops (ICDMW),
      2017 IEEE International Conference on (pp. 964-967). IEEE.

    """
    import sys
    import numpy as np
    from pynets.core.utils import flatten
    from pynets.stats.embeddings import _omni_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()

    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print(
            "ERROR: available functional models not sucessfully extracted"
            " from runconfig.yaml"
        )
        sys.exit(1)
    try:
        struct_models = hardcoded_params["available_models"][
            "struct_models"]
    except KeyError:
        print(
            "ERROR: available structural models not sucessfully extracted"
            " from runconfig.yaml"
        )
        sys.exit(1)
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        print(
            "ERROR: available gradient dimensionality presets not "
            "sucessfully extracted from runconfig.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    if len(est_path_iterlist) > 1:
        atlases = list(set([x.split("/")[-3].split("/")[0]
                            for x in est_path_iterlist]))
        parcel_dict_func = dict.fromkeys(atlases)
        parcel_dict_dwi = dict.fromkeys(atlases)

        est_path_iterlist_dwi = list(
            set(
                [
                    i
                    for i in est_path_iterlist
                    if i.split("model-")[1].split("_")[0] in struct_models
                ]
            )
        )
        est_path_iterlist_func = list(
            set(
                [
                    i
                    for i in est_path_iterlist
                    if i.split("model-")[1].split("_")[0] in func_models
                ]
            )
        )

        if "_rsn" in ";".join(est_path_iterlist_func):
            func_subnets = list(
                set([i.split("_rsn-")[1].split("_")[0] for i in
                     est_path_iterlist_func])
            )
        else:
            func_subnets = []
        if "_rsn" in ";".join(est_path_iterlist_dwi):
            dwi_subnets = list(
                set([i.split("_rsn-")[1].split("_")[0] for i in
                     est_path_iterlist_dwi])
            )
        else:
            dwi_subnets = []

        out_paths_func = []
        out_paths_dwi = []
        for atlas in atlases:
            if len(func_subnets) >= 1:
                parcel_dict_func[atlas] = {}
                for sub_net in func_subnets:
                    parcel_dict_func[atlas][sub_net] = []
            else:
                parcel_dict_func[atlas] = []

            if len(dwi_subnets) >= 1:
                parcel_dict_dwi[atlas] = {}
                for sub_net in dwi_subnets:
                    parcel_dict_dwi[atlas][sub_net] = []
            else:
                parcel_dict_dwi[atlas] = []

            for graph_path in est_path_iterlist_dwi:
                if atlas in graph_path:
                    if len(dwi_subnets) >= 1:
                        for sub_net in dwi_subnets:
                            if sub_net in graph_path:
                                parcel_dict_dwi[atlas][sub_net].append(
                                    graph_path)
                    else:
                        parcel_dict_dwi[atlas].append(graph_path)

            for graph_path in est_path_iterlist_func:
                if atlas in graph_path:
                    if len(func_subnets) >= 1:
                        for sub_net in func_subnets:
                            if sub_net in graph_path:
                                parcel_dict_func[atlas][sub_net].append(
                                    graph_path)
                    else:
                        parcel_dict_func[atlas].append(graph_path)
            if len(parcel_dict_func[atlas]) > 0:
                if isinstance(parcel_dict_func[atlas], dict):
                    # RSN case
                    for rsn in parcel_dict_func[atlas]:
                        pop_rsn_list = []
                        graph_path_list = []
                        for graph in parcel_dict_func[atlas][rsn]:
                            pop_rsn_list.append(np.load(graph))
                            graph_path_list.append(graph)
                        if len(pop_rsn_list) > 1:
                            if len(
                                    list(set([i.shape for i in
                                              pop_rsn_list]))) > 1:
                                raise RuntimeWarning(
                                    "Inconsistent number of"
                                    " vertices in graph population "
                                    "that precludes embedding...")
                            out_path = _omni_embed(
                                pop_rsn_list, atlas, graph_path_list, ID, rsn,
                                n_components
                            )
                            out_paths_func.append(out_path)
                        else:
                            print(
                                "WARNING: Only one graph sampled, omnibus"
                                " embedding not appropriate."
                            )
                else:
                    pop_list = []
                    graph_path_list = []
                    for pop_ref in parcel_dict_func[atlas]:
                        pop_list.append(np.load(pop_ref))
                        graph_path_list.append(pop_ref)
                    if len(pop_list) > 1:
                        if len(list(set([i.shape for i in pop_list]))) > 1:
                            raise RuntimeWarning(
                                "Inconsistent number of vertices in "
                                "graph population that precludes embedding")
                        out_path = _omni_embed(pop_list, atlas,
                                               graph_path_list, ID,
                                               n_components=n_components)
                        out_paths_func.append(out_path)
                    else:
                        print(
                            "WARNING: Only one graph sampled, omnibus "
                            "embedding not appropriate."
                        )

            if len(parcel_dict_dwi[atlas]) > 0:
                if isinstance(parcel_dict_dwi[atlas], dict):
                    # RSN case
                    graph_path_list = []
                    for rsn in parcel_dict_dwi[atlas]:
                        pop_rsn_list = []
                        for graph in parcel_dict_dwi[atlas][rsn]:
                            pop_rsn_list.append(np.load(graph))
                            graph_path_list.append(graph)
                        if len(pop_rsn_list) > 1:
                            if len(
                                    list(set([i.shape for i in
                                              pop_rsn_list]))) > 1:
                                raise RuntimeWarning(
                                    "Inconsistent number of"
                                    " vertices in graph population "
                                    "that precludes embedding")
                            out_path = _omni_embed(
                                pop_rsn_list, atlas, graph_path_list, ID,
                                rsn, n_components
                            )
                            out_paths_dwi.append(out_path)
                        else:
                            print(
                                "WARNING: Only one graph sampled, omnibus"
                                " embedding not appropriate."
                            )
                else:
                    pop_list = []
                    graph_path_list = []
                    for pop_ref in parcel_dict_dwi[atlas]:
                        pop_list.append(np.load(pop_ref))
                        graph_path_list.append(pop_ref)
                    if len(pop_list) > 1:
                        if len(list(set([i.shape for i in pop_list]))) > 1:
                            raise RuntimeWarning(
                                "Inconsistent number of vertices in graph"
                                " population that precludes embedding")
                        out_path = _omni_embed(pop_list, atlas,
                                               graph_path_list,
                                               ID, n_components=n_components)
                        out_paths_dwi.append(out_path)
                    else:
                        print(
                            "WARNING: Only one graph sampled, omnibus "
                            "embedding not appropriate."
                        )
    else:
        print("At least two graphs required to build an omnetome...")
        out_paths_func = []
        out_paths_dwi = []

    return out_paths_dwi, out_paths_func
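
A hedged usage sketch for build_omnetome: it requires at least two graphs
and splits them into structural and functional sets by matching each
path's "model-" tag against the models listed in runconfig.yaml. The
paths and model names below are hypothetical stand-ins:

# 'model-corr' and 'model-csa' stand in for entries of available_models
est_paths = [
    "/output/sub-01/DesikanKlein2012/func/graph_model-corr_a_est.npy",
    "/output/sub-01/DesikanKlein2012/func/graph_model-corr_b_est.npy",
    "/output/sub-01/DesikanKlein2012/dwi/graph_model-csa_a_est.npy",
    "/output/sub-01/DesikanKlein2012/dwi/graph_model-csa_b_est.npy",
]
out_paths_dwi, out_paths_func = build_omnetome(est_paths, ID="sub-01")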
Example 22
    def _run_interface(self, runtime):
        import gc
        import os
        import time
        import numpy as np
        import nibabel as nib
        import os.path as op
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.core.utils import load_runconfig
        from pynets.dmri.estimation import reconstruction
        from pynets.dmri.track import (
            create_density_map,
            track_ensemble,
        )
        from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
            Origin
        from dipy.io.streamline import save_tractogram
        from nipype.utils.filemanip import copyfile, fname_presuffix

        hardcoded_params = load_runconfig()
        use_life = hardcoded_params['tracking']["use_life"][0]
        roi_neighborhood_tol = hardcoded_params['tracking'][
            "roi_neighborhood_tol"][0]
        sphere = hardcoded_params['tracking']["sphere"][0]
        target_samples = hardcoded_params['tracking']["tracking_samples"][0]

        dir_path = utils.do_dir_path(self.inputs.atlas,
                                     os.path.dirname(self.inputs.dwi_file))

        namer_dir = "{}/tractography".format(dir_path)
        if not os.path.isdir(namer_dir):
            os.makedirs(namer_dir, exist_ok=True)

        # Load diffusion data
        dwi_file_tmp_path = fname_presuffix(self.inputs.dwi_file,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
        copyfile(self.inputs.dwi_file,
                 dwi_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
        dwi_data = dwi_img.get_fdata(dtype=np.float32)

        # Load FA data
        fa_file_tmp_path = fname_presuffix(self.inputs.fa_path,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.fa_path,
                 fa_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        fa_img = nib.load(fa_file_tmp_path, mmap=True)

        labels_im_file_tmp_path = fname_presuffix(self.inputs.labels_im_file,
                                                  suffix="_tmp",
                                                  newpath=runtime.cwd)
        copyfile(self.inputs.labels_im_file,
                 labels_im_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        # Load B0 mask
        B0_mask_tmp_path = fname_presuffix(self.inputs.B0_mask,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.B0_mask,
                 B0_mask_tmp_path,
                 copy=True,
                 use_hardlink=False)

        streams = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            runtime.cwd,
            "/streamlines_",
            "%s" % (self.inputs.subnet +
                    "_" if self.inputs.subnet is not None else ""),
            "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                    "_" if self.inputs.roi is not None else ""),
            self.inputs.conn_model,
            "_",
            target_samples,
            "_",
            "%s" % ("%s%s" % (self.inputs.node_radius, "mm_") if
                    ((self.inputs.node_radius != "parc") and
                     (self.inputs.node_radius is not None)) else "parc_"),
            "curv-",
            str(self.inputs.curv_thr_list).replace(", ", "_"),
            "_step-",
            str(self.inputs.step_list).replace(", ", "_"),
            "_traversal-",
            self.inputs.traversal,
            "_minlength-",
            self.inputs.min_length,
            ".trk",
        )

        if os.path.isfile(f"{namer_dir}/{op.basename(streams)}"):
            from dipy.io.streamline import load_tractogram
            copyfile(
                f"{namer_dir}/{op.basename(streams)}",
                streams,
                copy=True,
                use_hardlink=False,
            )
            tractogram = load_tractogram(
                streams,
                fa_img,
                bbox_valid_check=False,
            )

            streamlines = tractogram.streamlines

            # Create streamline density map
            try:
                [dir_path, dm_path] = create_density_map(
                    fa_img,
                    dir_path,
                    streamlines,
                    self.inputs.conn_model,
                    self.inputs.node_radius,
                    self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.subnet,
                    self.inputs.roi,
                    self.inputs.traversal,
                    self.inputs.min_length,
                    namer_dir,
                )
            except BaseException:
                print('Density map failed. Check tractography output.')
                dm_path = None

            del streamlines, tractogram
            fa_img.uncache()
            dwi_img.uncache()
            gc.collect()
            self._results["dm_path"] = dm_path
            self._results["streams"] = streams
            recon_path = None
        else:
            # Fit the diffusion model and save the reconstruction to .hdf5
            recon_node_str = (f"{self.inputs.node_radius}mm"
                              if (self.inputs.node_radius != "parc"
                                  and self.inputs.node_radius is not None)
                              else "parc")
            recon_path = (
                f"{runtime.cwd}/reconstruction_{subnet_str}{roi_str}"
                f"{self.inputs.conn_model}_{recon_node_str}.hdf5"
            )

            gtab_file_tmp_path = fname_presuffix(self.inputs.gtab_file,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gtab_file,
                     gtab_file_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gtab = load_pickle(gtab_file_tmp_path)

            # Only re-run the reconstruction if we have to
            if not os.path.isfile(f"{namer_dir}/{op.basename(recon_path)}"):
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(2)
                del model
            elif os.path.getsize(f"{namer_dir}/{op.basename(recon_path)}") > 0:
                print(f"Found existing reconstruction with "
                      f"{self.inputs.conn_model}. Loading...")
                copyfile(
                    f"{namer_dir}/{op.basename(recon_path)}",
                    recon_path,
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
            else:
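                # The staged reconstruction file exists but is empty;
                # re-run the model fit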
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
                del model
            dwi_img.uncache()
            del dwi_data

            # Load atlas wm-gm interface reduced version for seeding
            labels_im_file_tmp_path_wm_gm_int = fname_presuffix(
                self.inputs.labels_im_file_wm_gm_int,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.labels_im_file_wm_gm_int,
                     labels_im_file_tmp_path_wm_gm_int,
                     copy=True,
                     use_hardlink=False)

            t1w2dwi_tmp_path = fname_presuffix(self.inputs.t1w2dwi,
                                               suffix="_tmp",
                                               newpath=runtime.cwd)
            copyfile(self.inputs.t1w2dwi,
                     t1w2dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gm_in_dwi_tmp_path = fname_presuffix(self.inputs.gm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gm_in_dwi,
                     gm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            vent_csf_in_dwi_tmp_path = fname_presuffix(
                self.inputs.vent_csf_in_dwi,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.vent_csf_in_dwi,
                     vent_csf_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            wm_in_dwi_tmp_path = fname_presuffix(self.inputs.wm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.wm_in_dwi,
                     wm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            if self.inputs.waymask:
                waymask_tmp_path = fname_presuffix(self.inputs.waymask,
                                                   suffix="_tmp",
                                                   newpath=runtime.cwd)
                copyfile(self.inputs.waymask,
                         waymask_tmp_path,
                         copy=True,
                         use_hardlink=False)
            else:
                waymask_tmp_path = None

            # Iteratively build a list of streamlines for each ROI while
            # tracking
            print(f"{Fore.GREEN}Target streamlines per iteration: "
                  f"{Fore.BLUE} "
                  f"{target_samples}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
                  f"{self.inputs.curv_thr_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
                  f"{self.inputs.step_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
                  f"{self.inputs.track_type}")
            print(Style.RESET_ALL)
            if self.inputs.traversal == "prob":
                print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}"
                      f"Probabilistic")
            elif self.inputs.traversal == "clos":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Closest Peak")
            elif self.inputs.traversal == "det":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Deterministic Maximum")
            else:
                raise ValueError("Direction-getting type not recognized!")

            print(Style.RESET_ALL)

            # Commence Ensemble Tractography
            try:
                streamlines = track_ensemble(
                    target_samples, labels_im_file_tmp_path_wm_gm_int,
                    labels_im_file_tmp_path, recon_path, get_sphere(sphere),
                    self.inputs.traversal, self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.track_type, self.inputs.maxcrossing,
                    int(roi_neighborhood_tol), self.inputs.min_length,
                    waymask_tmp_path, B0_mask_tmp_path, t1w2dwi_tmp_path,
                    gm_in_dwi_tmp_path, vent_csf_in_dwi_tmp_path,
                    wm_in_dwi_tmp_path, self.inputs.tiss_class)
                gc.collect()
            except BaseException as w:
                print(f"\n{Fore.RED}Tractography failed: {w}")
                print(Style.RESET_ALL)
                streamlines = None

            if streamlines is not None:
                # import multiprocessing
                # from pynets.core.utils import kill_process_family
                # return kill_process_family(int(
                # multiprocessing.current_process().pid))

                # Linear Fascicle Evaluation (LiFE)
                if use_life is True:
                    print('Using LiFE to evaluate streamline plausibility...')
                    from pynets.dmri.utils import \
                        evaluate_streamline_plausibility
                    dwi_img = nib.load(dwi_file_tmp_path)
                    dwi_data = dwi_img.get_fdata(dtype=np.float32)
                    orig_count = len(streamlines)

                    if self.inputs.waymask:
                        mask_data = nib.load(waymask_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    else:
                        mask_data = nib.load(wm_in_dwi_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    try:
                        streamlines = evaluate_streamline_plausibility(
                            dwi_data,
                            gtab,
                            mask_data,
                            streamlines,
                            sphere=sphere)
                    except BaseException:
                        print(f"Linear Fascicle Evaluation failed. "
                              f"Visually checking streamlines output "
                              f"{namer_dir}/{op.basename(streams)} is "
                              f"recommended.")
                    if len(streamlines) < 0.5 * orig_count:
                        raise ValueError('LiFE revealed no plausible '
                                         'streamlines in the tractogram!')
                    del dwi_data, mask_data

                # Save streamlines to trk
                stf = StatefulTractogram(streamlines,
                                         fa_img,
                                         origin=Origin.NIFTI,
                                         space=Space.VOXMM)
                stf.remove_invalid_streamlines()

                save_tractogram(
                    stf,
                    streams,
                )

                del stf

                copyfile(
                    streams,
                    f"{namer_dir}/{op.basename(streams)}",
                    copy=True,
                    use_hardlink=False,
                )

                # Create streamline density map
                try:
                    [dir_path, dm_path] = create_density_map(
                        dwi_img,
                        dir_path,
                        streamlines,
                        self.inputs.conn_model,
                        self.inputs.node_radius,
                        self.inputs.curv_thr_list,
                        self.inputs.step_list,
                        self.inputs.subnet,
                        self.inputs.roi,
                        self.inputs.traversal,
                        self.inputs.min_length,
                        namer_dir,
                    )
                except BaseException:
                    print('Density map failed. Check tractography output.')
                    dm_path = None

                del streamlines
                dwi_img.uncache()
                gc.collect()
                self._results["dm_path"] = dm_path
                self._results["streams"] = streams
            else:
                self._results["streams"] = None
                self._results["dm_path"] = None
            tmp_files = [
                gtab_file_tmp_path, wm_in_dwi_tmp_path, gm_in_dwi_tmp_path,
                vent_csf_in_dwi_tmp_path, t1w2dwi_tmp_path
            ]

            for j in tmp_files:
                if j is not None:
                    if os.path.isfile(j):
                        os.system(f"rm -f {j} &")

        self._results["track_type"] = self.inputs.track_type
        self._results["conn_model"] = self.inputs.conn_model
        self._results["dir_path"] = dir_path
        self._results["subnet"] = self.inputs.subnet
        self._results["node_radius"] = self.inputs.node_radius
        self._results["dens_thresh"] = self.inputs.dens_thresh
        self._results["ID"] = self.inputs.ID
        self._results["roi"] = self.inputs.roi
        self._results["min_span_tree"] = self.inputs.min_span_tree
        self._results["disp_filt"] = self.inputs.disp_filt
        self._results["parc"] = self.inputs.parc
        self._results["prune"] = self.inputs.prune
        self._results["atlas"] = self.inputs.atlas
        self._results["parcellation"] = self.inputs.parcellation
        self._results["labels"] = self.inputs.labels
        self._results["coords"] = self.inputs.coords
        self._results["norm"] = self.inputs.norm
        self._results["binary"] = self.inputs.binary
        self._results["atlas_t1w"] = self.inputs.atlas_t1w
        self._results["curv_thr_list"] = self.inputs.curv_thr_list
        self._results["step_list"] = self.inputs.step_list
        self._results["fa_path"] = fa_file_tmp_path
        self._results["traversal"] = self.inputs.traversal
        self._results["labels_im_file"] = labels_im_file_tmp_path
        self._results["min_length"] = self.inputs.min_length

        tmp_files = [B0_mask_tmp_path, dwi_file_tmp_path]

        for j in tmp_files:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

        # Exercise caution when deleting copied recon_path
        # if recon_path is not None:
        #     if os.path.isfile(recon_path):
        #         os.remove(recon_path)

        return runtime
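
The interface above repeatedly stages each input into the node's working
directory before reading it, then removes the copies at the end. A
condensed sketch of that staging pattern using nipype's filemanip helpers
(the helper name and the commented paths are hypothetical):

import os
from nipype.utils.filemanip import copyfile, fname_presuffix

def stage_input(src_path, work_dir):
    """Copy src_path into work_dir with a _tmp suffix; return the copy."""
    tmp_path = fname_presuffix(src_path, suffix="_tmp", newpath=work_dir)
    copyfile(src_path, tmp_path, copy=True, use_hardlink=False)
    return tmp_path

# Hypothetical usage inside a nipype interface:
# dwi_tmp = stage_input(self.inputs.dwi_file, runtime.cwd)
# ... run tracking against dwi_tmp ...
# os.remove(dwi_tmp)  # clean up once tracking completes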
Example 23
    def _run_interface(self, runtime):
        import os
        import gc
        import time
        import nibabel as nib
        from pynets.core.utils import load_runconfig
        from nipype.utils.filemanip import fname_presuffix, copyfile
        from pynets.fmri import clustering
        from pynets.registration.utils import orient_reslice
        from joblib import Parallel, delayed
        from joblib.externals.loky.backend import resource_tracker
        from pynets.registration import utils as regutils
        from pynets.core.utils import decompress_nifti
        import pkg_resources
        import shutil
        import tempfile
        resource_tracker.warnings = None

        template = pkg_resources.resource_filename(
            "pynets", f"templates/standard/{self.inputs.template_name}_brain_"
            f"{self.inputs.vox_size}.nii.gz")

        template_tmp_path = fname_presuffix(template,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
        copyfile(template, template_tmp_path, copy=True, use_hardlink=False)

        hardcoded_params = load_runconfig()

        c_boot = hardcoded_params["c_boot"][0]
        nthreads = hardcoded_params["omp_threads"][0]

        clust_list = ["kmeans", "ward", "complete", "average", "ncut", "rena"]

        clust_mask_temp_path = orient_reslice(self.inputs.clust_mask,
                                              runtime.cwd,
                                              self.inputs.vox_size)
        cm_suf = os.path.basename(self.inputs.clust_mask).split('.nii')[0]
        clust_mask_in_t1w_path = f"{runtime.cwd}/clust_mask-" \
                                 f"{cm_suf}_in_t1w.nii.gz"

        t1w_brain_tmp_path = fname_presuffix(self.inputs.t1w_brain,
                                             suffix="_tmp",
                                             newpath=runtime.cwd)
        copyfile(self.inputs.t1w_brain,
                 t1w_brain_tmp_path,
                 copy=True,
                 use_hardlink=False)

        mni2t1w_warp_tmp_path = fname_presuffix(self.inputs.mni2t1w_warp,
                                                suffix="_tmp",
                                                newpath=runtime.cwd)
        copyfile(
            self.inputs.mni2t1w_warp,
            mni2t1w_warp_tmp_path,
            copy=True,
            use_hardlink=False,
        )

        mni2t1_xfm_tmp_path = fname_presuffix(self.inputs.mni2t1_xfm,
                                              suffix="_tmp",
                                              newpath=runtime.cwd)
        copyfile(self.inputs.mni2t1_xfm,
                 mni2t1_xfm_tmp_path,
                 copy=True,
                 use_hardlink=False)

        clust_mask_in_t1w = regutils.roi2t1w_align(
            clust_mask_temp_path,
            t1w_brain_tmp_path,
            mni2t1_xfm_tmp_path,
            mni2t1w_warp_tmp_path,
            clust_mask_in_t1w_path,
            template_tmp_path,
            self.inputs.simple,
        )
        time.sleep(0.5)

        if self.inputs.mask:
            out_name_mask = fname_presuffix(self.inputs.mask,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
            copyfile(self.inputs.mask,
                     out_name_mask,
                     copy=True,
                     use_hardlink=False)
        else:
            out_name_mask = None

        out_name_func_file = fname_presuffix(self.inputs.func_file,
                                             suffix="_tmp",
                                             newpath=runtime.cwd)
        copyfile(self.inputs.func_file,
                 out_name_func_file,
                 copy=True,
                 use_hardlink=False)
        out_name_func_file = decompress_nifti(out_name_func_file)

        if self.inputs.conf:
            out_name_conf = fname_presuffix(self.inputs.conf,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
            copyfile(self.inputs.conf,
                     out_name_conf,
                     copy=True,
                     use_hardlink=False)
        else:
            out_name_conf = None

        nip = clustering.NiParcellate(
            func_file=out_name_func_file,
            clust_mask=clust_mask_in_t1w,
            k=int(self.inputs.k),
            clust_type=self.inputs.clust_type,
            local_corr=self.inputs.local_corr,
            outdir=self.inputs.outdir,
            conf=out_name_conf,
            mask=out_name_mask,
        )

        atlas = nip.create_clean_mask()
        nip.create_local_clustering(overwrite=True, r_thresh=0.4)

        if self.inputs.clust_type in clust_list:
            if float(c_boot) > 1:
                import random
                from joblib import Memory
                from joblib.externals.loky import get_reusable_executor
                print(f"Performing circular block bootstrapping with {c_boot}"
                      f" iterations...")
                ts_data, block_size = nip.prep_boot()

                cache_dir = tempfile.mkdtemp()
                memory = Memory(cache_dir, verbose=0)
                ts_data = memory.cache(ts_data)
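                # memory.cache wraps ts_data in a joblib MemorizedFunc;
                # the raw array is recovered below via its .func attribute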

                def create_bs_imgs(ts_data, block_size, clust_mask_corr_img):
                    import nibabel as nib
                    from nilearn.masking import unmask
                    from pynets.fmri.estimation import timeseries_bootstrap
                    boot_series = timeseries_bootstrap(
                        ts_data.func, block_size)[0].astype('float32')
                    return unmask(boot_series, clust_mask_corr_img)

                def run_bs_iteration(i, ts_data, work_dir, local_corr,
                                     clust_type, _local_conn_mat_path,
                                     num_conn_comps, _clust_mask_corr_img,
                                     _standardize, _detrending, k, _local_conn,
                                     conf, _dir_path, _conn_comps):
                    import os
                    import time
                    import gc
                    from pynets.fmri.clustering import parcellate
                    print(f"\nBootstrapped iteration: {i}")
                    out_path = f"{work_dir}/boot_parc_tmp_{str(i)}.nii.gz"

                    boot_img = create_bs_imgs(ts_data, block_size,
                                              _clust_mask_corr_img)
                    try:
                        parcellation = parcellate(
                            boot_img, local_corr, clust_type,
                            _local_conn_mat_path, num_conn_comps,
                            _clust_mask_corr_img, _standardize, _detrending, k,
                            _local_conn, conf, _dir_path, _conn_comps)
                        parcellation.to_filename(out_path)
                        parcellation.uncache()
                        boot_img.uncache()
                        gc.collect()
                    except BaseException:
                        boot_img.uncache()
                        gc.collect()
                        return None
                    _clust_mask_corr_img.uncache()
                    return out_path

                time.sleep(random.randint(1, 5))
                counter = 0
                boot_parcellations = []
                while float(counter) < float(c_boot):
                    with Parallel(n_jobs=nthreads,
                                  max_nbytes='8000M',
                                  backend='loky',
                                  mmap_mode='r+',
                                  temp_folder=cache_dir,
                                  verbose=10) as parallel:
                        iter_bootedparcels = parallel(
                            delayed(run_bs_iteration)
                            (i, ts_data, runtime.cwd, nip.local_corr,
                             nip.clust_type, nip._local_conn_mat_path,
                             nip.num_conn_comps, nip._clust_mask_corr_img,
                             nip._standardize, nip._detrending, nip.k,
                             nip._local_conn, nip.conf, nip._dir_path,
                             nip._conn_comps) for i in range(c_boot))

                        boot_parcellations.extend(
                            [i for i in iter_bootedparcels if i is not None])
                        counter = len(boot_parcellations)
                        del iter_bootedparcels
                        gc.collect()

                print('Bootstrapped samples complete:')
                print(boot_parcellations)
                print("Creating spatially-constrained consensus "
                      "parcellation...")
                consensus_parcellation = clustering.ensemble_parcellate(
                    boot_parcellations, int(self.inputs.k))
                nib.save(consensus_parcellation, nip.parcellation)
                memory.clear(warn=False)
                shutil.rmtree(cache_dir, ignore_errors=True)
                del parallel, memory, cache_dir
                get_reusable_executor().shutdown(wait=True)
                gc.collect()

                for i in boot_parcellations:
                    if i is not None:
                        if os.path.isfile(i):
                            os.system(f"rm -f {i} &")
            else:
                print("Creating spatially-constrained parcellation...")
                out_path = f"{runtime.cwd}/{atlas}_{str(self.inputs.k)}.nii.gz"
                func_img = nib.load(out_name_func_file)
                parcellation = clustering.parcellate(
                    func_img, self.inputs.local_corr, self.inputs.clust_type,
                    nip._local_conn_mat_path, nip.num_conn_comps,
                    nip._clust_mask_corr_img, nip._standardize,
                    nip._detrending, nip.k, nip._local_conn, nip.conf,
                    nip._dir_path, nip._conn_comps)
                parcellation.to_filename(out_path)

        else:
            raise ValueError("Clustering method not recognized. See: "
                             "https://nilearn.github.io/modules/generated/"
                             "nilearn.regions.Parcellations."
                             "html#nilearn.regions.Parcellations")

        # Give it a minute
        ix = 0
        while not os.path.isfile(nip.parcellation) and ix < 60:
            print('Waiting for clustered parcellation...')
            time.sleep(1)
            ix += 1

        if not os.path.isfile(nip.parcellation):
            raise FileNotFoundError(f"Parcellation clustering failed for"
                                    f" {nip.parcellation}")

        self._results["atlas"] = atlas
        self._results["parcellation"] = nip.parcellation
        self._results["clust_mask"] = clust_mask_in_t1w_path
        self._results["k"] = self.inputs.k
        self._results["clust_type"] = self.inputs.clust_type
        self._results["clustering"] = True
        self._results["func_file"] = self.inputs.func_file

        reg_tmp = [
            t1w_brain_tmp_path, mni2t1w_warp_tmp_path, mni2t1_xfm_tmp_path,
            template_tmp_path, out_name_func_file
        ]
        for j in reg_tmp:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

        gc.collect()

        return runtime
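
The bootstrap loop above relies on circular block resampling of the BOLD
time series (the real implementation is
pynets.fmri.estimation.timeseries_bootstrap). A minimal standalone sketch
of the resampling idea, independent of pynets:

import numpy as np

def circular_block_bootstrap(ts, block_size, rng=None):
    """Resample a (timepoints x voxels) array in contiguous circular
    blocks, preserving local autocorrelation within each block."""
    rng = np.random.default_rng() if rng is None else rng
    n = ts.shape[0]
    n_blocks = int(np.ceil(n / block_size))
    starts = rng.integers(0, n, size=n_blocks)
    idx = np.concatenate(
        [np.arange(s, s + block_size) % n for s in starts])[:n]
    return ts[idx]

# Toy example: 100 timepoints x 10 voxels, blocks of 8
boot = circular_block_bootstrap(np.random.randn(100, 10), block_size=8)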
Example 24
def build_multigraphs(est_path_iterlist):
    """
    Constructs a multimodal multigraph for each available resolution of
    vertices.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.

    Returns
    -------
    multigraph_list_all : list
        List of multiplex graph dictionaries corresponding to
        each unique node resolution.
    graph_path_list_all : list
        List of lists consisting of pairs of most similar
        structural and functional connectomes for each unique node
        resolution.
    namer_dirs : list
        List of output directory paths, one entry per multigraph.

    References
    ----------
    .. [1] Bullmore, E., & Sporns, O. (2009). Complex brain networks: Graph
      theoretical analysis of structural and functional systems.
      Nature Reviews Neuroscience. https://doi.org/10.1038/nrn2575
    .. [2] Vaiana, M., & Muldoon, S. F. (2018). Multilayer Brain Networks.
      Journal of Nonlinear Science. https://doi.org/10.1007/s00332-017-9436-8

    """
    import os
    import itertools
    import numpy as np
    from pathlib import Path
    from pynets.statistics.individual.multiplex import matching
    from pynets.core.utils import flatten, load_runconfig

    raw_est_path_iterlist = list(flatten(est_path_iterlist))

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print("ERROR: available functional models not sucessfully extracted"
              " from advanced.yaml")
    try:
        dwi_models = hardcoded_params["available_models"]["dwi_models"]
    except KeyError:
        print("ERROR: available structural models not sucessfully extracted"
              " from advanced.yaml")

    atlases = list(
        set([x.split("/")[-3].split("/")[0] for x in raw_est_path_iterlist]))
    parcel_dict_func = dict.fromkeys(atlases)
    parcel_dict_dwi = dict.fromkeys(atlases)
    est_path_iterlist_dwi = list(
        set([i for i in raw_est_path_iterlist if "dwi" in i]))
    est_path_iterlist_func = list(
        set([i for i in raw_est_path_iterlist if "func" in i]))

    if "_subnet" in ";".join(est_path_iterlist_func):
        func_subnets = list(
            set([
                i.split("_subnet-")[1].split("_")[0]
                for i in est_path_iterlist_func
            ]))
    else:
        func_subnets = []
    if "_subnet" in ";".join(est_path_iterlist_dwi):
        dwi_subnets = list(
            set([
                i.split("_subnet-")[1].split("_")[0]
                for i in est_path_iterlist_dwi
            ]))
    else:
        dwi_subnets = []

    dir_path = str(
        Path(os.path.dirname(est_path_iterlist_dwi[0])).parent.parent.parent)
    namer_dir = f"{dir_path}/dwi-func"
    if not os.path.isdir(namer_dir):
        os.mkdir(namer_dir)

    multigraph_list_all = []
    graph_path_list_all = []
    for atlas in atlases:
        if len(func_subnets) >= 1:
            parcel_dict_func[atlas] = {}
            for sub_net in func_subnets:
                parcel_dict_func[atlas][sub_net] = []
        else:
            parcel_dict_func[atlas] = []

        if len(dwi_subnets) >= 1:
            parcel_dict_dwi[atlas] = {}
            for sub_net in dwi_subnets:
                parcel_dict_dwi[atlas][sub_net] = []
        else:
            parcel_dict_dwi[atlas] = []

        for graph_path in est_path_iterlist_dwi:
            if atlas in graph_path:
                if len(dwi_subnets) >= 1:
                    for sub_net in dwi_subnets:
                        if sub_net in graph_path:
                            parcel_dict_dwi[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_dwi[atlas].append(graph_path)

        for graph_path in est_path_iterlist_func:
            if atlas in graph_path:
                if len(func_subnets) >= 1:
                    for sub_net in func_subnets:
                        if sub_net in graph_path:
                            parcel_dict_func[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_func[atlas].append(graph_path)

        parcel_dict = {}
        # Create dictionary of all possible pairs of structural-functional
        # graphs for each unique resolution of vertices
        if len(dwi_subnets) >= 1 and len(func_subnets) >= 1:
            parcel_dict[atlas] = {}
            subnets = np.intersect1d(dwi_subnets, func_subnets).tolist()
            for subnet in subnets:
                parcel_dict[atlas][subnet] = list(
                    set(
                        itertools.product(parcel_dict_dwi[atlas][subnet],
                                          parcel_dict_func[atlas][subnet])))
                for paths in list(parcel_dict[atlas][subnet]):
                    [mG_nx, mG, out_dwi_mat, out_func_mat] = matching(
                        paths,
                        atlas,
                        namer_dir,
                    )
        else:
            parcel_dict[atlas] = list(
                set(
                    itertools.product(parcel_dict_dwi[atlas],
                                      parcel_dict_func[atlas])))
            for paths in list(parcel_dict[atlas]):
                [mG_nx, mG, out_dwi_mat, out_func_mat] = matching(
                    paths,
                    atlas,
                    namer_dir,
                )
        # Note: only the most recent matching() result for this atlas is
        # retained here
        multigraph_list_all.append((mG_nx, mG))
        graph_path_list_all.append((out_dwi_mat, out_func_mat))

    return (
        multigraph_list_all,
        graph_path_list_all,
        len(multigraph_list_all) * [namer_dir],
    )
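
The pairing step above reduces to a cartesian product of the structural
and functional graph paths per atlas. A toy illustration with invented
paths (the set() wrapper deduplicates but does not preserve order):

import itertools

dwi_graphs = ["dwi_graph_a.npy", "dwi_graph_b.npy"]
func_graphs = ["func_graph_a.npy"]

# Every structural-functional pair that matching() would receive
pairs = list(set(itertools.product(dwi_graphs, func_graphs)))
# e.g. [('dwi_graph_a.npy', 'func_graph_a.npy'),
#       ('dwi_graph_b.npy', 'func_graph_a.npy')]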