示例#1
0
def tens_mod_fa_est(gtab_file, dwi_file, B0_mask):
    """
    Estimate a tensor FA image to use for registrations.

    A DiPy ``TensorModel`` is fit to the diffusion data inside the B0 brain
    mask. Two images are written to the directory containing ``B0_mask``,
    both using the B0 mask's affine:

    * ``tensor_fa.nii.gz`` -- fractional anisotropy (float32), NaNs set to 0.
    * ``tensor_fa_md.nii.gz`` -- binary (float32) mask of voxels where
      ``FA >= 0.2``, or where ``FA >= 0.08`` and ``MD >= 0.0011``.

    Parameters
    ----------
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted image.
    B0_mask : str
        File path to B0 brain mask.

    Returns
    -------
    fa_path : str
        File path to FA Nifti1Image.
    B0_mask : str
        File path to B0 brain mask Nifti1Image.
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted Nifti1Image.

    Notes
    -----
    The FA/MD mask is written to ``tensor_fa_md.nii.gz`` next to the FA
    image, but its path is not part of the returned 4-tuple.

    ``gtab_file`` is unpickled, so only pass trusted files.
    """
    import os
    from dipy.io import load_pickle
    from dipy.reconst.dti import TensorModel
    from dipy.reconst.dti import fractional_anisotropy, mean_diffusivity

    gtab = load_pickle(gtab_file)

    data = nib.load(dwi_file).get_fdata()

    print("Generating tensor FA image to use for registrations...")
    nodif_B0_img = nib.load(B0_mask)
    # Zero out NaNs before the bool cast; otherwise NaN would become True.
    nodif_B0_mask_data = np.nan_to_num(np.asarray(
        nodif_B0_img.dataobj)).astype("bool")
    model = TensorModel(gtab)
    mod = model.fit(data, nodif_B0_mask_data)
    FA = fractional_anisotropy(mod.evals)
    MD = mean_diffusivity(mod.evals)
    # Comparisons with NaN are False, so voxels with undefined FA/MD are
    # already excluded here. The result is boolean and can never hold NaN.
    FA_MD = np.logical_or(FA >= 0.2,
                          np.logical_and(FA >= 0.08, MD >= 0.0011))
    FA[np.isnan(FA)] = 0

    # os.path.join avoids writing to "/" when B0_mask has no directory part.
    out_dir = os.path.dirname(B0_mask)
    fa_path = os.path.join(out_dir, "tensor_fa.nii.gz")
    nib.save(nib.Nifti1Image(FA.astype(np.float32), nodif_B0_img.affine),
             fa_path)

    fa_md_path = os.path.join(out_dir, "tensor_fa_md.nii.gz")
    nib.save(nib.Nifti1Image(FA_MD.astype(np.float32), nodif_B0_img.affine),
             fa_md_path)

    nodif_B0_img.uncache()
    del FA, FA_MD, MD

    return fa_path, B0_mask, gtab_file, dwi_file
示例#2
0
def tens_mod_fa_est(gtab_file, dwi_file, B0_mask):
    """
    Estimate a tensor FA image to use for registrations.

    A DiPy ``TensorModel`` is fit to the diffusion data inside the B0 brain
    mask. The resulting fractional anisotropy map (NaNs set to 0, float32)
    is saved as ``tensor_fa.nii.gz`` in the directory containing
    ``B0_mask``, using the B0 mask's affine.

    Parameters
    ----------
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted image.
    B0_mask : str
        File path to B0 brain mask.

    Returns
    -------
    fa_path : str
        File path to FA Nifti1Image.
    B0_mask : str
        File path to B0 brain mask Nifti1Image.
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted Nifti1Image.

    Notes
    -----
    ``gtab_file`` is unpickled, so only pass trusted files.
    """
    import os
    from dipy.io import load_pickle
    from dipy.reconst.dti import TensorModel
    from dipy.reconst.dti import fractional_anisotropy

    data = nib.load(dwi_file).get_fdata()
    gtab = load_pickle(gtab_file)

    print('Generating simple tensor FA image to use for registrations...')
    nodif_B0_img = nib.load(B0_mask)
    # Zero out NaNs before the bool cast; otherwise NaN would become True.
    B0_mask_data = np.nan_to_num(
        np.asarray(nodif_B0_img.dataobj)).astype('bool')
    nodif_B0_affine = nodif_B0_img.affine
    model = TensorModel(gtab)
    mod = model.fit(data, B0_mask_data)
    FA = fractional_anisotropy(mod.evals)
    FA[np.isnan(FA)] = 0
    fa_img = nib.Nifti1Image(FA.astype(np.float32), nodif_B0_affine)
    # os.path.join avoids writing to "/" when B0_mask has no directory part.
    fa_path = os.path.join(os.path.dirname(B0_mask), 'tensor_fa.nii.gz')
    nib.save(fa_img, fa_path)
    return fa_path, B0_mask, gtab_file, dwi_file
示例#3
0
def tens_mod_fa_est(gtab_file, dwi_file, nodif_B0_mask):
    """
    Estimate a tensor FA image to use for registrations.

    A DiPy ``TensorModel`` is fit to the diffusion data inside the B0 brain
    mask. The resulting fractional anisotropy map (NaNs set to 0, float32)
    is saved as ``tensor_fa.nii.gz`` in the directory containing
    ``nodif_B0_mask``, using the mask's affine.

    Parameters
    ----------
    gtab_file : str
        File path to pickled DiPy gradient table object (unpickled, so it
        must come from a trusted source).
    dwi_file : str
        File path to diffusion weighted image.
    nodif_B0_mask : str
        File path to B0 brain mask.

    Returns
    -------
    fa_path : str
        File path to FA Nifti1Image.
    nodif_B0_mask : str
        File path to B0 brain mask Nifti1Image (passed through unchanged).
    gtab_file : str
        File path to pickled DiPy gradient table object (passed through).
    dwi_file : str
        File path to diffusion weighted Nifti1Image (passed through).
    """
    import os
    from dipy.io import load_pickle
    from dipy.reconst.dti import TensorModel
    from dipy.reconst.dti import fractional_anisotropy

    data = nib.load(dwi_file).get_fdata()
    gtab = load_pickle(gtab_file)

    print('Generating simple tensor FA image to use for registrations...')
    nodif_B0_img = nib.load(nodif_B0_mask)
    # Zero out NaNs before the bool cast; otherwise NaN would become True.
    nodif_B0_mask_data = np.nan_to_num(
        np.asarray(nodif_B0_img.dataobj)).astype('bool')
    nodif_B0_affine = nodif_B0_img.affine
    model = TensorModel(gtab)
    mod = model.fit(data, nodif_B0_mask_data)
    FA = fractional_anisotropy(mod.evals)
    FA[np.isnan(FA)] = 0
    fa_img = nib.Nifti1Image(FA.astype(np.float32), nodif_B0_affine)
    # os.path.join avoids writing to "/" when the mask path has no directory.
    fa_path = os.path.join(os.path.dirname(nodif_B0_mask), 'tensor_fa.nii.gz')
    nib.save(fa_img, fa_path)
    return fa_path, nodif_B0_mask, gtab_file, dwi_file
示例#4
0
def test_evaluate_streamline_plausibility():
    """
    Smoke-test ``evaluate_streamline_plausibility`` on the bundled example
    tractography data.

    The cleaned streamlines must be non-empty and no more numerous than
    the input streamlines.
    """
    import nibabel as nib
    from dipy.io import load_pickle
    from dipy.io.stateful_tractogram import Origin, Space
    from dipy.io.streamline import load_tractogram
    from pynets.dmri.dmri_utils import evaluate_streamline_plausibility

    base_dir = str(Path(__file__).parent / "examples")
    gtab_path = f"{base_dir}/miscellaneous/tractography/gtab.pkl"
    dwi_path = (f"{base_dir}/miscellaneous/tractography/sub-OAS31172_"
                f"ses-d0407_dwi_reor-RAS_res-2mm.nii.gz")
    mask_path = f"{base_dir}/miscellaneous/tractography/mean_B0_bet_mask.nii.gz"
    trk_path = (f"{base_dir}/miscellaneous/tractography/streamlines_csa_"
                f"20000_parc_curv-[40_30]_step-[0.1_0.2_0.3_0.4_0.5]_"
                f"directget-prob_minlength-20.trk")

    # The B0 mask image serves both as the tractogram reference and as the
    # tissue mask handed to the evaluator.
    mask_img = nib.load(mask_path)
    input_streamlines = load_tractogram(
        trk_path,
        mask_img,
        to_origin=Origin.NIFTI,
        to_space=Space.VOXMM,
        bbox_valid_check=False,
    ).streamlines

    cleaned = evaluate_streamline_plausibility(
        nib.load(dwi_path).get_fdata(),
        load_pickle(gtab_path),
        mask_img.get_fdata(),
        input_streamlines,
    )

    assert 0 < len(cleaned) <= len(input_streamlines)
示例#5
0
def create_anisopowermap(gtab_file, dwi_file, B0_mask):
    """
    Estimate an anisotropic power map image to use for registrations.

    B0 volumes are removed from the DWI series. The remaining volumes are
    projected onto an order-2 spherical-harmonic basis, and the anisotropic
    power of those coefficients is masked by the B0 brain mask. The result
    is saved as ``aniso_power.nii.gz`` (float32, DWI affine) in the
    directory containing ``B0_mask``. If that file already exists, it is
    reused and nothing is recomputed.

    Parameters
    ----------
    gtab_file : str
        File path to pickled DiPy gradient table object (unpickled, so it
        must come from a trusted source).
    dwi_file : str
        File path to diffusion weighted image.
    B0_mask : str
        File path to B0 brain mask.

    Returns
    -------
    anisopwr_path : str
        File path to the anisotropic power Nifti1Image.
    B0_mask : str
        File path to B0 brain mask Nifti1Image.
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted Nifti1Image.

    References
    ----------
    .. [1] Chen, D. Q., Dell’Acqua, F., Rokem, A., Garyfallidis, E., Hayes, D.,
      Zhong, J., & Hodaie, M. (2018). Diffusion Weighted Image Co-registration:
      Investigation of Best Practices. PLoS ONE.

    """
    import os
    from dipy.io import load_pickle
    from dipy.reconst.shm import anisotropic_power
    from dipy.core.sphere import HemiSphere, Sphere
    from dipy.reconst.shm import sf_to_sh

    # os.path.join avoids writing to "/" when B0_mask has no directory part.
    anisopwr_path = os.path.join(os.path.dirname(B0_mask),
                                 "aniso_power.nii.gz")

    # Reuse a previously generated map.
    if os.path.isfile(anisopwr_path):
        return anisopwr_path, B0_mask, gtab_file, dwi_file

    print("Generating anisotropic power map to use for registrations...")
    gtab = load_pickle(gtab_file)

    # Boolean selector for the diffusion-weighted (non-B0) volumes.
    dwi_sel = ~np.asarray(gtab.b0s_mask, dtype=bool)
    dwi_vertices = gtab.bvecs[dwi_sel]

    # Use a full Sphere if the HemiSphere built from the gradient directions
    # does not keep exactly one vertex per DWI direction. This is an explicit
    # check, not an `assert`, because asserts are stripped under `python -O`.
    sphere = HemiSphere(xyz=dwi_vertices)
    if len(sphere.vertices) != len(dwi_vertices):
        sphere = Sphere(xyz=dwi_vertices)

    img = nib.load(dwi_file)
    nodif_B0_img = nib.load(B0_mask)

    # Drop every B0 volume along axis 3 in a single boolean-indexing step.
    # The original one-at-a-time np.delete loop copied the array per B0.
    dwi_data = np.asarray(img.dataobj, dtype=np.float32)[:, :, :, dwi_sel]

    anisomap = anisotropic_power(sf_to_sh(dwi_data, sphere, sh_order=2))
    anisomap[np.isnan(anisomap)] = 0
    brain_mask = np.asarray(nodif_B0_img.dataobj).astype("bool")
    masked_data = anisomap * brain_mask
    nib.Nifti1Image(masked_data.astype(np.float32),
                    img.affine).to_filename(anisopwr_path)
    nodif_B0_img.uncache()
    del anisomap, dwi_data, masked_data

    return anisopwr_path, B0_mask, gtab_file, dwi_file
示例#6
0
文件: track.py 项目: devhliu/PyNets
def run_track(B0_mask,
              gm_in_dwi,
              vent_csf_in_dwi,
              wm_in_dwi,
              tiss_class,
              labels_im_file_wm_gm_int,
              labels_im_file,
              target_samples,
              curv_thr_list,
              step_list,
              track_type,
              max_length,
              maxcrossing,
              directget,
              conn_model,
              gtab_file,
              dwi_file,
              network,
              node_size,
              dens_thresh,
              ID,
              roi,
              min_span_tree,
              disp_filt,
              parc,
              prune,
              atlas,
              uatlas,
              labels,
              coords,
              norm,
              binary,
              atlas_mni,
              min_length,
              fa_path,
              waymask,
              roi_neighborhood_tol=8,
              sphere='repulsion724'):
    """
    Run all ensemble tractography and filtering routines.

    The atlas parcellation is loaded and validated first, then the diffusion
    model is fit. Ensemble tracking runs over the curvature/step grid, and a
    streamline density map is written.

    Parameters
    ----------
    B0_mask : str
        File path to B0 brain mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method.
    labels_im_file_wm_gm_int : str
        File path to atlas parcellation Nifti1Image in T1w-warped native diffusion space, restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native diffusion space.
    target_samples : int
        Total number of streamline samples specified to generate streams.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel while tracking.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    gtab_file : str
        File path to pickled DiPy gradient table object (unpickled, so it
        must come from a trusted source).
    dwi_file : str
        File path to diffusion weighted image.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    min_length : int
        Minimum fiber length threshold in mm.
    fa_path : str
        File path to FA Nifti1Image.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain tractography.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Default is 8.
    sphere : str
        Provide triangulated spheres. Default is repulsion724. Options are:
        `symmetric362`, `symmetric642`, `symmetric724`, `repulsion724`, `repulsion100`, or `repulsion200`

    Returns
    -------
    streams : str
        File path to save streamline array sequence in .trk format.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    dir_path : str
        Path to directory containing subject derivative data for a given pynets run.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    fa_path : str
        File path to FA Nifti1Image.
    dm_path : str
        File path to fiber density map Nifti1Image.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native diffusion space.
    roi_neighborhood_tol : float
        ROI neighborhood distance tolerance (passed through unchanged).
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.

    Raises
    ------
    ValueError
        If the atlas parcellation in ``labels_im_file`` contains no non-zero
        voxels.
    """
    import gc
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets.core import utils
    from pynets.dmri.track import prep_tissues, reconstruction, create_density_map, track_ensemble

    # Load and validate the atlas parcellation before the expensive model fit
    # so that an empty atlas fails fast.
    atlas_data = nib.load(labels_im_file).get_fdata().astype('uint16')
    if not np.any(atlas_data):
        raise ValueError(
            'ERROR: No non-zero voxels found in atlas. Check any roi masks and/or wm-gm interface images '
            'to verify overlap with dwi-registered atlas.')
    # wm-gm interface reduced version of the atlas, used for seeding.
    atlas_data_wm_gm_int = nib.load(
        labels_im_file_wm_gm_int).get_fdata().astype('uint16')

    # One boolean mask per non-zero label, for later roi filtering. Zero is
    # excluded explicitly rather than by slicing off the first unique value,
    # which would drop a real label whenever the atlas contains no 0 voxels.
    roi_vals = np.unique(atlas_data)
    parcels = [atlas_data == roi_val for roi_val in roi_vals[roi_vals != 0]]

    # Load diffusion data and fit the diffusion model.
    dwi_img = nib.load(dwi_file)
    dwi_data = dwi_img.get_fdata()
    mod_fit = reconstruction(conn_model, load_pickle(gtab_file), dwi_data,
                             B0_mask)

    # Report the tracking configuration.
    print(f"{Fore.GREEN}Target number of samples: {Fore.BLUE}{target_samples}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Using curvature threshold(s): {Fore.BLUE}{curv_thr_list}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Using step size(s): {Fore.BLUE}{step_list}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Tracking type: {Fore.BLUE}{track_type}")
    print(Style.RESET_ALL)
    # 'clos' is accepted as an alias of 'closest', as listed in the docstring.
    direction_labels = {
        'prob': 'Probabilistic',
        'boot': 'Bootstrapped',
        'closest': 'Closest Peak',
        'clos': 'Closest Peak',
        'det': 'Deterministic Maximum',
    }
    direction_label = direction_labels.get(directget)
    if direction_label is not None:
        print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}{direction_label}")
    print(Style.RESET_ALL)

    # Commence Ensemble Tractography
    streamlines = track_ensemble(
        dwi_data, target_samples, atlas_data_wm_gm_int, parcels, mod_fit,
        prep_tissues(B0_mask,
                     gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class),
        get_sphere(sphere), directget, curv_thr_list, step_list, track_type,
        maxcrossing, max_length, roi_neighborhood_tol, min_length, waymask)
    print('Tracking Complete')

    # Create streamline density map
    [streams, dir_path,
     dm_path] = create_density_map(dwi_img, utils.do_dir_path(atlas, dwi_file),
                                   streamlines, conn_model, target_samples,
                                   node_size, curv_thr_list, step_list,
                                   network, roi, directget, max_length)

    # Release large arrays before returning.
    del streamlines, dwi_data, atlas_data_wm_gm_int, atlas_data, mod_fit, parcels
    dwi_img.uncache()

    gc.collect()

    return (streams, track_type, target_samples, conn_model, dir_path,
            network, node_size, dens_thresh, ID, roi, min_span_tree,
            disp_filt, parc, prune, atlas, uatlas, labels, coords, norm,
            binary, atlas_mni, curv_thr_list, step_list, fa_path, dm_path,
            directget, labels_im_file, roi_neighborhood_tol, max_length)
示例#7
0
 def transform_affine(fname, fix):
     """
     Load a pickled transform from ``fname`` and apply it to ``fix``.

     ``dio`` is presumably ``dipy.io`` and the pickle presumably holds a DiPy
     affine map exposing ``.transform``. Confirm against the module imports.
     NOTE(review): unpickling can execute arbitrary code; only load trusted
     files.
     """
     return dio.load_pickle(fname).transform(fix)
示例#8
0
    def _run_interface(self, runtime):
        import gc
        import os
        import time
        import os.path as op
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.core.utils import load_runconfig
        from pynets.dmri.estimation import reconstruction
        from pynets.dmri.track import (
            create_density_map,
            track_ensemble,
        )
        from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
            Origin
        from dipy.io.streamline import save_tractogram
        from nipype.utils.filemanip import copyfile, fname_presuffix

        hardcoded_params = load_runconfig()
        use_life = hardcoded_params['tracking']["use_life"][0]
        roi_neighborhood_tol = hardcoded_params['tracking'][
            "roi_neighborhood_tol"][0]
        sphere = hardcoded_params['tracking']["sphere"][0]
        target_samples = hardcoded_params['tracking']["tracking_samples"][0]

        dir_path = utils.do_dir_path(self.inputs.atlas,
                                     os.path.dirname(self.inputs.dwi_file))

        namer_dir = "{}/tractography".format(dir_path)
        if not os.path.isdir(namer_dir):
            os.makedirs(namer_dir, exist_ok=True)

        # Load diffusion data
        dwi_file_tmp_path = fname_presuffix(self.inputs.dwi_file,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
        copyfile(self.inputs.dwi_file,
                 dwi_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
        dwi_data = dwi_img.get_fdata(dtype=np.float32)

        # Load FA data
        fa_file_tmp_path = fname_presuffix(self.inputs.fa_path,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.fa_path,
                 fa_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        fa_img = nib.load(fa_file_tmp_path, mmap=True)

        labels_im_file_tmp_path = fname_presuffix(self.inputs.labels_im_file,
                                                  suffix="_tmp",
                                                  newpath=runtime.cwd)
        copyfile(self.inputs.labels_im_file,
                 labels_im_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        # Load B0 mask
        B0_mask_tmp_path = fname_presuffix(self.inputs.B0_mask,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.B0_mask,
                 B0_mask_tmp_path,
                 copy=True,
                 use_hardlink=False)

        streams = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            runtime.cwd,
            "/streamlines_",
            "%s" % (self.inputs.subnet +
                    "_" if self.inputs.subnet is not None else ""),
            "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                    "_" if self.inputs.roi is not None else ""),
            self.inputs.conn_model,
            "_",
            target_samples,
            "_",
            "%s" % ("%s%s" % (self.inputs.node_radius, "mm_") if
                    ((self.inputs.node_radius != "parc") and
                     (self.inputs.node_radius is not None)) else "parc_"),
            "curv-",
            str(self.inputs.curv_thr_list).replace(", ", "_"),
            "_step-",
            str(self.inputs.step_list).replace(", ", "_"),
            "_traversal-",
            self.inputs.traversal,
            "_minlength-",
            self.inputs.min_length,
            ".trk",
        )

        if os.path.isfile(f"{namer_dir}/{op.basename(streams)}"):
            from dipy.io.streamline import load_tractogram
            copyfile(
                f"{namer_dir}/{op.basename(streams)}",
                streams,
                copy=True,
                use_hardlink=False,
            )
            tractogram = load_tractogram(
                streams,
                fa_img,
                bbox_valid_check=False,
            )

            streamlines = tractogram.streamlines

            # Create streamline density map
            try:
                [dir_path, dm_path] = create_density_map(
                    fa_img,
                    dir_path,
                    streamlines,
                    self.inputs.conn_model,
                    self.inputs.node_radius,
                    self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.subnet,
                    self.inputs.roi,
                    self.inputs.traversal,
                    self.inputs.min_length,
                    namer_dir,
                )
            except BaseException:
                print('Density map failed. Check tractography output.')
                dm_path = None

            del streamlines, tractogram
            fa_img.uncache()
            dwi_img.uncache()
            gc.collect()
            self._results["dm_path"] = dm_path
            self._results["streams"] = streams
            recon_path = None
        else:
            # Fit diffusion model
            # Save reconstruction to .npy
            recon_path = "%s%s%s%s%s%s%s%s" % (
                runtime.cwd,
                "/reconstruction_",
                "%s" % (self.inputs.subnet +
                        "_" if self.inputs.subnet is not None else ""),
                "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                        "_" if self.inputs.roi is not None else ""),
                self.inputs.conn_model,
                "_",
                "%s" % ("%s%s" % (self.inputs.node_radius, "mm") if
                        ((self.inputs.node_radius != "parc") and
                         (self.inputs.node_radius is not None)) else "parc"),
                ".hdf5",
            )

            gtab_file_tmp_path = fname_presuffix(self.inputs.gtab_file,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gtab_file,
                     gtab_file_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gtab = load_pickle(gtab_file_tmp_path)

            # Only re-run the reconstruction if we have to
            if not os.path.isfile(f"{namer_dir}/{op.basename(recon_path)}"):
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')
                hf.close()

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(2)
                del model
            elif os.path.getsize(f"{namer_dir}/{op.basename(recon_path)}") > 0:
                print(f"Found existing reconstruction with "
                      f"{self.inputs.conn_model}. Loading...")
                copyfile(
                    f"{namer_dir}/{op.basename(recon_path)}",
                    recon_path,
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
            else:
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')
                hf.close()

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
                del model
            dwi_img.uncache()
            del dwi_data

            # Load atlas wm-gm interface reduced version for seeding
            labels_im_file_tmp_path_wm_gm_int = fname_presuffix(
                self.inputs.labels_im_file_wm_gm_int,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.labels_im_file_wm_gm_int,
                     labels_im_file_tmp_path_wm_gm_int,
                     copy=True,
                     use_hardlink=False)

            t1w2dwi_tmp_path = fname_presuffix(self.inputs.t1w2dwi,
                                               suffix="_tmp",
                                               newpath=runtime.cwd)
            copyfile(self.inputs.t1w2dwi,
                     t1w2dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gm_in_dwi_tmp_path = fname_presuffix(self.inputs.gm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gm_in_dwi,
                     gm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            vent_csf_in_dwi_tmp_path = fname_presuffix(
                self.inputs.vent_csf_in_dwi,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.vent_csf_in_dwi,
                     vent_csf_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            wm_in_dwi_tmp_path = fname_presuffix(self.inputs.wm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.wm_in_dwi,
                     wm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            if self.inputs.waymask:
                waymask_tmp_path = fname_presuffix(self.inputs.waymask,
                                                   suffix="_tmp",
                                                   newpath=runtime.cwd)
                copyfile(self.inputs.waymask,
                         waymask_tmp_path,
                         copy=True,
                         use_hardlink=False)
            else:
                waymask_tmp_path = None

            # Iteratively build a list of streamlines for each ROI while
            # tracking
            print(f"{Fore.GREEN}Target streamlines per iteration: "
                  f"{Fore.BLUE} "
                  f"{target_samples}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
                  f"{self.inputs.curv_thr_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
                  f"{self.inputs.step_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
                  f"{self.inputs.track_type}")
            print(Style.RESET_ALL)
            if self.inputs.traversal == "prob":
                print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}"
                      f"Probabilistic")
            elif self.inputs.traversal == "clos":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Closest Peak")
            elif self.inputs.traversal == "det":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Deterministic Maximum")
            else:
                raise ValueError("Direction-getting type not recognized!")

            print(Style.RESET_ALL)

            # Commence Ensemble Tractography
            try:
                streamlines = track_ensemble(
                    target_samples, labels_im_file_tmp_path_wm_gm_int,
                    labels_im_file_tmp_path, recon_path, get_sphere(sphere),
                    self.inputs.traversal, self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.track_type, self.inputs.maxcrossing,
                    int(roi_neighborhood_tol), self.inputs.min_length,
                    waymask_tmp_path, B0_mask_tmp_path, t1w2dwi_tmp_path,
                    gm_in_dwi_tmp_path, vent_csf_in_dwi_tmp_path,
                    wm_in_dwi_tmp_path, self.inputs.tiss_class)
                gc.collect()
            except BaseException as w:
                print(f"\n{Fore.RED}Tractography failed: {w}")
                print(Style.RESET_ALL)
                streamlines = None

            if streamlines is not None:
                # import multiprocessing
                # from pynets.core.utils import kill_process_family
                # return kill_process_family(int(
                # multiprocessing.current_process().pid))

                # Linear Fascicle Evaluation (LiFE)
                if use_life is True:
                    print('Using LiFE to evaluate streamline plausibility...')
                    from pynets.dmri.utils import \
                        evaluate_streamline_plausibility
                    dwi_img = nib.load(dwi_file_tmp_path)
                    dwi_data = dwi_img.get_fdata(dtype=np.float32)
                    orig_count = len(streamlines)

                    if self.inputs.waymask:
                        mask_data = nib.load(waymask_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    else:
                        mask_data = nib.load(wm_in_dwi_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    try:
                        streamlines = evaluate_streamline_plausibility(
                            dwi_data,
                            gtab,
                            mask_data,
                            streamlines,
                            sphere=sphere)
                    except BaseException:
                        print(f"Linear Fascicle Evaluation failed. "
                              f"Visually checking streamlines output "
                              f"{namer_dir}/{op.basename(streams)} is "
                              f"recommended.")
                    if len(streamlines) < 0.5 * orig_count:
                        raise ValueError('LiFE revealed no plausible '
                                         'streamlines in the tractogram!')
                    del dwi_data, mask_data

                # Save streamlines to trk
                stf = StatefulTractogram(streamlines,
                                         fa_img,
                                         origin=Origin.NIFTI,
                                         space=Space.VOXMM)
                stf.remove_invalid_streamlines()

                save_tractogram(
                    stf,
                    streams,
                )

                del stf

                copyfile(
                    streams,
                    f"{namer_dir}/{op.basename(streams)}",
                    copy=True,
                    use_hardlink=False,
                )

                # Create streamline density map
                try:
                    [dir_path, dm_path] = create_density_map(
                        dwi_img,
                        dir_path,
                        streamlines,
                        self.inputs.conn_model,
                        self.inputs.node_radius,
                        self.inputs.curv_thr_list,
                        self.inputs.step_list,
                        self.inputs.subnet,
                        self.inputs.roi,
                        self.inputs.traversal,
                        self.inputs.min_length,
                        namer_dir,
                    )
                except BaseException:
                    print('Density map failed. Check tractography output.')
                    dm_path = None

                del streamlines
                dwi_img.uncache()
                gc.collect()
                self._results["dm_path"] = dm_path
                self._results["streams"] = streams
            else:
                self._results["streams"] = None
                self._results["dm_path"] = None
            tmp_files = [
                gtab_file_tmp_path, wm_in_dwi_tmp_path, gm_in_dwi_tmp_path,
                vent_csf_in_dwi_tmp_path, t1w2dwi_tmp_path
            ]

            for j in tmp_files:
                if j is not None:
                    if os.path.isfile(j):
                        os.system(f"rm -f {j} &")

        self._results["track_type"] = self.inputs.track_type
        self._results["conn_model"] = self.inputs.conn_model
        self._results["dir_path"] = dir_path
        self._results["subnet"] = self.inputs.subnet
        self._results["node_radius"] = self.inputs.node_radius
        self._results["dens_thresh"] = self.inputs.dens_thresh
        self._results["ID"] = self.inputs.ID
        self._results["roi"] = self.inputs.roi
        self._results["min_span_tree"] = self.inputs.min_span_tree
        self._results["disp_filt"] = self.inputs.disp_filt
        self._results["parc"] = self.inputs.parc
        self._results["prune"] = self.inputs.prune
        self._results["atlas"] = self.inputs.atlas
        self._results["parcellation"] = self.inputs.parcellation
        self._results["labels"] = self.inputs.labels
        self._results["coords"] = self.inputs.coords
        self._results["norm"] = self.inputs.norm
        self._results["binary"] = self.inputs.binary
        self._results["atlas_t1w"] = self.inputs.atlas_t1w
        self._results["curv_thr_list"] = self.inputs.curv_thr_list
        self._results["step_list"] = self.inputs.step_list
        self._results["fa_path"] = fa_file_tmp_path
        self._results["traversal"] = self.inputs.traversal
        self._results["labels_im_file"] = labels_im_file_tmp_path
        self._results["min_length"] = self.inputs.min_length

        tmp_files = [B0_mask_tmp_path, dwi_file_tmp_path]

        for j in tmp_files:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

        # Exercise caution when deleting copied recon_path
        # if recon_path is not None:
        #     if os.path.isfile(recon_path):
        #         os.remove(recon_path)

        return runtime
示例#9
0
def create_anisopowermap(gtab_file, dwi_file, B0_mask):
    '''
    Estimate an anisotropic power map image to use for registrations.

    The map is written next to ``B0_mask`` as ``aniso_power.nii.gz``. If that
    file already exists it is reused and nothing is recomputed.

    Parameters
    ----------
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted image.
    B0_mask : str
        File path to B0 brain mask.

    Returns
    -------
    anisopwr_path : str
        File path to the anisotropic power Nifti1Image.
    B0_mask : str
        File path to B0 brain mask Nifti1Image.
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted Nifti1Image.
    '''
    import os
    from dipy.io import load_pickle
    from dipy.reconst.shm import anisotropic_power
    from dipy.core.sphere import HemiSphere
    from dipy.reconst.shm import sf_to_sh

    # os.path.join avoids producing "/aniso_power.nii.gz" (the filesystem
    # root) when B0_mask is a bare filename whose dirname is "".
    anisopwr_path = os.path.join(os.path.dirname(B0_mask),
                                 'aniso_power.nii.gz')

    if not os.path.isfile(anisopwr_path):
        print('Generating anisotropic power map to use for registrations...')
        gtab = load_pickle(gtab_file)
        b0s_mask = np.asarray(gtab.b0s_mask, dtype=bool)

        # Only the diffusion-weighted (non-b0) directions define the
        # sampling sphere for the SH fit.
        gtab_hemisphere = HemiSphere(xyz=gtab.bvecs[~b0s_mask])

        img = nib.load(dwi_file)
        nodif_B0_img = nib.load(B0_mask)

        # Drop every b0 volume along the last (4th) axis in one indexing
        # step instead of re-copying the array once per b0 volume.
        dwi_data = np.asarray(img.dataobj)[..., ~b0s_mask]

        anisomap = anisotropic_power(
            sf_to_sh(dwi_data, gtab_hemisphere, sh_order=2))
        anisomap[np.isnan(anisomap)] = 0

        # NaNs in the mask would otherwise cast to True (inside the brain).
        mask = np.nan_to_num(np.asarray(nodif_B0_img.dataobj)).astype('bool')
        masked_data = anisomap * mask
        nib.Nifti1Image(masked_data.astype(np.float32),
                        img.affine).to_filename(anisopwr_path)

        nodif_B0_img.uncache()
        img.uncache()
        del anisomap, dwi_data, masked_data, mask

    return anisopwr_path, B0_mask, gtab_file, dwi_file
示例#10
0
    def _run_interface(self, runtime):
        """
        Fit the diffusion model, run ensemble tractography and build a
        streamline density map.

        Inputs are read from ``self.inputs``. Outputs and pass-through values
        are stored in ``self._results``.

        Raises
        ------
        ValueError
            If ``directget`` is not one of 'prob', 'boot', 'closest' or
            'det', or if the atlas contains no non-zero voxels.
        """
        import gc
        import numpy as np
        import nibabel as nib
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.dmri.track import prep_tissues, reconstruction, create_density_map, track_ensemble

        # Human-readable labels for the supported direction-getting types.
        directget_labels = {
            'prob': 'Probabilistic',
            'boot': 'Bootstrapped',
            'closest': 'Closest Peak',
            'det': 'Deterministic Maximum',
        }
        # Validate cheap inputs before any expensive model fitting.
        if self.inputs.directget not in directget_labels:
            raise ValueError('Direction-getting type not recognized!')

        # Load atlas parcellation (and its wm-gm interface reduced version for seeding).
        # The emptiness check runs here so a bad atlas fails before the costly fit.
        atlas_data = np.asarray(nib.load(self.inputs.labels_im_file).dataobj).astype('uint16')
        if not atlas_data.any():
            raise ValueError(
                'ERROR: No non-zero voxels found in atlas. Check any roi masks and/or wm-gm interface images '
                'to verify overlap with dwi-registered atlas.')
        atlas_data_wm_gm_int = np.asarray(nib.load(self.inputs.labels_im_file_wm_gm_int).dataobj).astype('uint16')

        # Build one boolean mask per non-zero ROI label for later roi filtering.
        # Filtering on != 0 (rather than dropping the first unique value) keeps
        # the lowest label when the atlas has no zero-valued background.
        roi_vals = np.unique(atlas_data)
        parcels = [atlas_data == roi_val for roi_val in roi_vals[roi_vals != 0]]

        # Read the diffusion data once; both the model fit and the tracker use it.
        dwi_img = nib.load(self.inputs.dwi_file)
        dwi_data = np.asarray(dwi_img.dataobj)

        # Fit diffusion model
        mod_fit = reconstruction(self.inputs.conn_model, load_pickle(self.inputs.gtab_file),
                                 dwi_data, self.inputs.B0_mask)

        # Iteratively build a list of streamlines for each ROI while tracking
        print(f"{Fore.GREEN}Target number of samples: {Fore.BLUE}{self.inputs.target_samples}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Using curvature threshold(s): {Fore.BLUE}{self.inputs.curv_thr_list}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Using step size(s): {Fore.BLUE}{self.inputs.step_list}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Tracking type: {Fore.BLUE}{self.inputs.track_type}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}{directget_labels[self.inputs.directget]}")
        print(Style.RESET_ALL)

        # Commence Ensemble Tractography
        streamlines = track_ensemble(dwi_data, self.inputs.target_samples, atlas_data_wm_gm_int,
                                     parcels, mod_fit,
                                     prep_tissues(self.inputs.t1w2dwi, self.inputs.gm_in_dwi,
                                                  self.inputs.vent_csf_in_dwi, self.inputs.wm_in_dwi,
                                                  self.inputs.tiss_class),
                                     get_sphere(self.inputs.sphere), self.inputs.directget,
                                     self.inputs.curv_thr_list, self.inputs.step_list,
                                     self.inputs.track_type, self.inputs.maxcrossing,
                                     int(self.inputs.roi_neighborhood_tol), self.inputs.min_length,
                                     self.inputs.waymask)

        # Create streamline density map
        dir_path = utils.do_dir_path(self.inputs.atlas, self.inputs.dwi_file)
        [streams, dir_path, dm_path] = create_density_map(dwi_img, dir_path, streamlines,
                                                          self.inputs.conn_model, self.inputs.target_samples,
                                                          self.inputs.node_size, self.inputs.curv_thr_list,
                                                          self.inputs.step_list, self.inputs.network,
                                                          self.inputs.roi, self.inputs.directget,
                                                          self.inputs.min_length)

        self._results['streams'] = streams
        self._results['track_type'] = self.inputs.track_type
        self._results['target_samples'] = self.inputs.target_samples
        self._results['conn_model'] = self.inputs.conn_model
        self._results['dir_path'] = dir_path
        self._results['network'] = self.inputs.network
        self._results['node_size'] = self.inputs.node_size
        self._results['dens_thresh'] = self.inputs.dens_thresh
        self._results['ID'] = self.inputs.ID
        self._results['roi'] = self.inputs.roi
        self._results['min_span_tree'] = self.inputs.min_span_tree
        self._results['disp_filt'] = self.inputs.disp_filt
        self._results['parc'] = self.inputs.parc
        self._results['prune'] = self.inputs.prune
        self._results['atlas'] = self.inputs.atlas
        self._results['uatlas'] = self.inputs.uatlas
        self._results['labels'] = self.inputs.labels
        self._results['coords'] = self.inputs.coords
        self._results['norm'] = self.inputs.norm
        self._results['binary'] = self.inputs.binary
        self._results['atlas_mni'] = self.inputs.atlas_mni
        self._results['curv_thr_list'] = self.inputs.curv_thr_list
        self._results['step_list'] = self.inputs.step_list
        self._results['fa_path'] = self.inputs.fa_path
        self._results['dm_path'] = dm_path
        self._results['directget'] = self.inputs.directget
        self._results['labels_im_file'] = self.inputs.labels_im_file
        self._results['roi_neighborhood_tol'] = self.inputs.roi_neighborhood_tol
        self._results['min_length'] = self.inputs.min_length

        # Release large arrays before returning to keep peak memory down.
        del streamlines, atlas_data_wm_gm_int, atlas_data, mod_fit, parcels, dwi_data
        dwi_img.uncache()
        gc.collect()

        return runtime
示例#11
0
def run_track(nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class, labels_im_file_wm_gm_int,
              labels_im_file, target_samples, curv_thr_list, step_list, track_type, max_length, maxcrossing, directget,
              conn_model, gtab_file, dwi_file, network, node_size, dens_thresh, ID, roi, min_span_tree, disp_filt, parc,
              prune, atlas_select, uatlas_select, label_names, coords, norm, binary, atlas_mni, life_run, min_length,
              fa_path):
    """
    Run ensemble tractography followed by streamline filtering.

    Parameters
    ----------
    nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class
        Passed to ``prep_tissues`` to build the tissue classifier.
        ``wm_in_dwi`` is also passed to ``reconstruction`` for the model fit.
    labels_im_file_wm_gm_int : str
        File path to the atlas reduced to the wm-gm interface (used for seeding).
    labels_im_file : str
        File path to the atlas parcellation. One boolean mask is built per
        non-zero label.
    directget : str
        Direction-getting type: 'prob', 'boot', 'closest' or 'det'.
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted image.

    The remaining parameters are tracking/filtering settings forwarded to
    ``track_ensemble`` / ``filter_streamlines``, or metadata returned
    unchanged.

    Returns
    -------
    tuple
        ``streams`` and ``dir_path`` from ``filter_streamlines`` together with
        the pass-through inputs, in the order given by the return statement.

    Raises
    ------
    ValueError
        If ``directget`` is not recognized or the atlas has no non-zero voxels.
    """
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets import utils
    from pynets.dmri.track import prep_tissues, reconstruction, filter_streamlines, track_ensemble

    # Printed descriptions of the supported direction-getting types.
    directget_labels = {
        'prob': 'Probabilistic Direction...',
        'boot': 'Bootstrapped Direction...',
        'closest': 'Closest Peak Direction...',
        'det': 'Deterministic Maximum Direction...',
    }
    # Reject unknown types up front instead of silently proceeding to tracking.
    if directget not in directget_labels:
        raise ValueError('Direction-getting type not recognized!')

    # Load atlas parcellation (and its wm-gm interface reduced version for seeding).
    # The emptiness check runs before the costly model fit.
    atlas_data = nib.load(labels_im_file).get_fdata().astype('int')
    if not atlas_data.any():
        raise ValueError('ERROR: No non-zero voxels found in atlas. Check any roi masks and/or wm-gm interface images '
                         'to verify overlap with dwi-registered atlas.')
    atlas_data_wm_gm_int = nib.load(labels_im_file_wm_gm_int).get_fdata().astype('int')

    # Build one boolean mask per non-zero ROI label for later roi filtering.
    # Filtering on != 0 (rather than dropping the first unique value) keeps
    # the lowest label when the atlas has no zero-valued background.
    roi_vals = np.unique(atlas_data)
    parcels = [atlas_data == roi_val for roi_val in roi_vals[roi_vals != 0]]
    parcel_vec = np.ones(len(parcels))

    # Load gradient table
    gtab = load_pickle(gtab_file)

    # Fit diffusion model
    mod_fit = reconstruction(conn_model, gtab, dwi_file, wm_in_dwi)

    # Get sphere
    sphere = get_sphere('repulsion724')

    # Instantiate tissue classifier
    tiss_classifier = prep_tissues(nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class)

    # Iteratively build a list of streamlines for each ROI while tracking
    print(f"{Fore.GREEN}Target number of samples: {Fore.BLUE}{target_samples}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Using curvature threshold(s): {Fore.BLUE}{curv_thr_list}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Using step size(s): {Fore.BLUE}{step_list}")
    print(Style.RESET_ALL)
    print(f"{Fore.GREEN}Tracking type: {Fore.BLUE}{track_type}")
    print(Style.RESET_ALL)
    print(f"Using {Fore.MAGENTA}{directget_labels[directget]}")
    print(Style.RESET_ALL)

    # Commence Ensemble Tractography
    streamlines = track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, parcel_vec,
                                 mod_fit, tiss_classifier, sphere, directget, curv_thr_list, step_list, track_type,
                                 maxcrossing, max_length)
    print('Tracking Complete')

    # Perform streamline filtering routines
    dir_path = utils.do_dir_path(atlas_select, dwi_file)
    [streams, dir_path] = filter_streamlines(dwi_file, dir_path, gtab, streamlines, life_run, min_length, conn_model,
                                             target_samples, node_size, curv_thr_list, step_list)

    return (streams, track_type, target_samples, conn_model, dir_path, network, node_size, dens_thresh, ID, roi,
            min_span_tree, disp_filt, parc, prune, atlas_select, uatlas_select, label_names, coords, norm, binary,
            atlas_mni, curv_thr_list, step_list, fa_path)