Example 1
    def create_clean_mask(self, num_std_dev=1.5):
        """
        Create a subject-refined version of the clustering mask.
        """
        import os
        from pynets.core import utils
        from nilearn.masking import intersect_masks
        from nilearn.image import index_img, math_img, resample_img
        mask_name = os.path.basename(self.clust_mask).split('.nii')[0]
        self.atlas = "%s%s%s%s%s" % (mask_name, '_', self.clust_type, '_k', str(self.k))
        print("%s%s%s%s%s%s%s" % ('\nCreating atlas using ', self.clust_type, ' at cluster level ', str(self.k),
                                  ' for ', str(self.atlas), '...\n'))
        self._dir_path = utils.do_dir_path(self.atlas, self.func_file)
        self.uatlas = "%s%s%s%s%s%s%s%s" % (self._dir_path, '/', mask_name, '_clust-', self.clust_type, '_k',
                                            str(self.k), '.nii.gz')

        # Load clustering mask
        self._func_img.set_data_dtype(np.float32)
        func_vol_img = index_img(self._func_img, 1)
        func_vol_img.set_data_dtype(np.uint16)
        clust_mask_res_img = resample_img(nib.load(self.clust_mask), target_affine=func_vol_img.affine,
                                          target_shape=func_vol_img.shape, interpolation='nearest')
        clust_mask_res_img.set_data_dtype(np.uint16)
        func_data = np.asarray(func_vol_img.dataobj).astype('float32')
        func_int_thr = np.round(np.mean(func_data[func_data > 0]) - np.std(func_data[func_data > 0]) * num_std_dev, 3)
        if self.mask is not None:
            self._mask_img = nib.load(self.mask)
            self._mask_img.set_data_dtype(np.uint16)
            mask_res_img = resample_img(self._mask_img, target_affine=func_vol_img.affine,
                                        target_shape=func_vol_img.shape, interpolation='nearest')
            mask_res_img.set_data_dtype(np.uint16)
            self._clust_mask_corr_img = intersect_masks([math_img('img > ' + str(func_int_thr), img=func_vol_img),
                                                         math_img('img > 0.01', img=clust_mask_res_img),
                                                         math_img('img > 0.01', img=mask_res_img)],
                                                        threshold=1, connected=False)
            self._clust_mask_corr_img.set_data_dtype(np.uint16)
            self._mask_img.uncache()
            mask_res_img.uncache()
        else:
            self._clust_mask_corr_img = intersect_masks([math_img('img > ' + str(func_int_thr), img=func_vol_img),
                                                         math_img('img > 0.01', img=clust_mask_res_img)],
                                                        threshold=1, connected=False)
            self._clust_mask_corr_img.set_data_dtype(np.uint16)
        nib.save(self._clust_mask_corr_img,
                 f"{self._dir_path}/{mask_name}.nii.gz")

        del func_data
        func_vol_img.uncache()
        clust_mask_res_img.uncache()

        return self.atlas
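
A minimal driver for the method above is easy to sketch, since it reads only a handful of attributes. The `SubjectClusters` class below is a hypothetical stand-in for the pynets object that hosts `create_clean_mask`, and the file paths are placeholders that must point at real NIfTI images to actually run:

import nibabel as nib

class SubjectClusters:  # hypothetical host class for create_clean_mask above
    def __init__(self, func_file, clust_mask, clust_type, k, mask=None):
        self.func_file = func_file      # functional image supplying the grid
        self.clust_mask = clust_mask    # path to the clustering-mask NIfTI
        self.clust_type = clust_type    # e.g. 'kmeans'
        self.k = k                      # number of clusters
        self.mask = mask                # optional brain mask, or None
        self._func_img = nib.load(func_file)

    # create_clean_mask(self, num_std_dev=1.5) would be pasted here verbatim

# clusterer = SubjectClusters('/tmp/func.nii.gz', '/tmp/clust_mask.nii.gz',
#                             'kmeans', k=100)
# atlas_name = clusterer.create_clean_mask(num_std_dev=1.5)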
Example 2
def test_do_dir_path():
    """
    Test do_dir_path functionality
    """
    base_dir = str(Path(__file__).parent/"examples")
    func_path = base_dir + '/002/fmri'
    dwi_path = base_dir + '/002/dmri'
    in_func = func_path + '/002.nii.gz'
    in_dwi = dwi_path + '/std_dmri/iso_eddy_corrected_data_denoised_pre_reor.nii.gz'
    in_files = [in_func, in_dwi]

    atlases = ['Power', 'Shirer', 'Shen', 'Smith']
    for in_file in in_files:
        for atlas in atlases:
            dir_path = utils.do_dir_path(atlas, in_file)
            assert dir_path is not None
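
The test above only asserts that a path comes back. If `do_dir_path` also creates the directory, as the move-aside logic in Examples 4 and 5 suggests, the check can be strengthened. A hypothetical variant, assuming the same `utils` import as the surrounding test module and pytest's built-in `tmp_path` fixture:

import os

def test_do_dir_path_creates_dir(tmp_path):
    # tmp_path is pytest's temporary-directory fixture
    in_file = tmp_path / "002.nii.gz"
    in_file.touch()
    dir_path = utils.do_dir_path('Power', str(in_file))
    assert dir_path is not None
    # Holds only if do_dir_path actually mkdirs the atlas folder, an
    # assumption inferred from Examples 4 and 5
    assert os.path.isdir(dir_path)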
Example 3
def test_do_dir_path(atlas, input):
    """
    Test do_dir_path functionality
    """

    base_dir = os.path.abspath(
        pkg_resources.resource_filename("pynets", "../data/examples"))

    if input == 'fmri':
        in_file = f"{base_dir}/003/func/sub-003_ses-01_task-rest_bold.nii.gz"
    elif input == 'dmri':
        in_file = f"{base_dir}/003/dmri/sub-003_dwi.nii.gz"

    # do_dir_path operates on the parent directory of in_file here
    dir_path = utils.do_dir_path(
        atlas, os.path.dirname(os.path.realpath(in_file)))
    assert dir_path is not None
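
Examples 3 through 5 take `atlas` and `input` as arguments, which implies a `pytest.mark.parametrize` layer that the snippets omit. A plausible, purely illustrative decorator stack, reusing the atlas names from Example 2:

import pytest

# Hypothetical parametrization implied by the (atlas, input) signature;
# the real suite's atlas names and ids may differ.
# Note: 'input' shadows the Python builtin, matching the snippets above.
@pytest.mark.parametrize("atlas", ["Power", "Shirer", "Shen", "Smith"])
@pytest.mark.parametrize("input", ["fmri", "dmri"])
def test_do_dir_path(atlas, input):
    ...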
Example 4
def test_do_dir_path(atlas, input):
    """
    Test do_dir_path functionality
    """
    import tempfile

    # mkdtemp creates the directory and leaves cleanup to the caller;
    # TemporaryDirectory().name is racy because the directory is removed
    # whenever the object is garbage-collected
    dir_path = tempfile.mkdtemp()
    base_dir = str(Path(__file__).parent / "examples")

    if input == 'fmri':
        in_file = f"{base_dir}/BIDS/sub-25659/ses-1/func/sub-25659_ses-1_task-rest_space-T1w_desc-preproc_bold.nii.gz"
    elif input == 'dmri':
        in_file = f"{base_dir}/BIDS/sub-25659/ses-1/dwi/final_preprocessed_dwi.nii.gz"

    # Build the expected atlas dir path inside in_file's parent directory
    atlas_dir = os.path.dirname(os.path.realpath(in_file)) + '/' + str(atlas)

    dir_path = utils.do_dir_path(atlas, atlas_dir)
    assert dir_path is not None
Example 5
def test_do_dir_path(atlas, input):
    """
    Test do_dir_path functionality
    """
    base_dir = str(Path(__file__).parent / "examples")
    if input == 'fmri':
        func_path = base_dir + '/002/fmri'
        in_file = func_path + '/002.nii.gz'
    elif input == 'dmri':
        dwi_path = base_dir + '/002/dmri'
        in_file = dwi_path + '/std_dmri/iso_eddy_corrected_data_denoised_pre_reor.nii.gz'

    # Move any existing atlas dir in in_file's parent out of the way
    atlas_dir = os.path.dirname(os.path.realpath(in_file)) + '/' + str(atlas)
    if os.path.exists(atlas_dir):
        shutil.move(atlas_dir, atlas_dir + '_tmp')

    dir_path = utils.do_dir_path(atlas, in_file)
    assert dir_path is not None

    # Restore the original atlas dir
    if os.path.exists(atlas_dir + '_tmp'):
        shutil.move(atlas_dir + '_tmp', atlas_dir)
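
The move-aside/restore dance above only makes sense if `do_dir_path` creates (and returns) an atlas-named directory next to the input file. A minimal sketch consistent with that inferred behavior, not pynets' actual implementation:

import os

def do_dir_path_sketch(atlas, in_file):
    # Hypothetical stand-in: ensure an atlas-named directory exists
    # alongside in_file and return its path
    if atlas is None:
        raise ValueError('Cannot create a directory for a null atlas!')
    dir_path = os.path.join(
        os.path.dirname(os.path.realpath(in_file)), str(atlas))
    os.makedirs(dir_path, exist_ok=True)
    return dir_path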
Example 6
    def create_clean_mask(self, num_std_dev=1.5):
        """
        Create a subject-refined version of the clustering mask.
        """
        import os
        from pynets.core import utils
        from nilearn.masking import intersect_masks
        from nilearn.image import index_img, math_img, resample_img

        mask_name = os.path.basename(self.clust_mask).split(".nii")[0]
        self.atlas = f"{mask_name}{'_'}{self.clust_type}{'_k'}{str(self.k)}"
        print(f"\nCreating atlas using {self.clust_type} at cluster level"
              f" {str(self.k)} for {str(self.atlas)}...\n")
        self._dir_path = utils.do_dir_path(self.atlas, self.outdir)
        self.parcellation = f"{self._dir_path}/{mask_name}_" \
                            f"clust-{self.clust_type}" \
                            f"_k{str(self.k)}.nii.gz"

        # Load clustering mask
        self._func_img.set_data_dtype(np.float32)
        func_vol_img = index_img(self._func_img, 1)
        func_vol_img.set_data_dtype(np.uint16)
        clust_mask_res_img = resample_img(
            nib.load(self.clust_mask),
            target_affine=func_vol_img.affine,
            target_shape=func_vol_img.shape,
            interpolation="nearest",
        )
        clust_mask_res_img.set_data_dtype(np.uint16)
        func_data = np.asarray(func_vol_img.dataobj, dtype=np.float32)
        func_int_thr = np.round(
            np.mean(func_data[func_data > 0]) -
            np.std(func_data[func_data > 0]) * num_std_dev,
            3,
        )
        if self.mask is not None:
            self._mask_img = nib.load(self.mask)
            self._mask_img.set_data_dtype(np.uint16)
            mask_res_img = resample_img(
                self._mask_img,
                target_affine=func_vol_img.affine,
                target_shape=func_vol_img.shape,
                interpolation="nearest",
            )
            mask_res_img.set_data_dtype(np.uint16)
            self._clust_mask_corr_img = intersect_masks(
                [
                    math_img(f"img > {func_int_thr}", img=func_vol_img),
                    math_img("img > 0.01", img=clust_mask_res_img),
                    math_img("img > 0.01", img=mask_res_img),
                ],
                threshold=1,
                connected=False,
            )
            self._clust_mask_corr_img.set_data_dtype(np.uint16)
            self._mask_img.uncache()
            mask_res_img.uncache()
        else:
            self._clust_mask_corr_img = intersect_masks(
                [
                    math_img("img > " + str(func_int_thr), img=func_vol_img),
                    math_img("img > 0.01", img=clust_mask_res_img),
                ],
                threshold=1,
                connected=False,
            )
            self._clust_mask_corr_img.set_data_dtype(np.uint16)
        nib.save(self._clust_mask_corr_img,
                 f"{self._dir_path}/{mask_name}.nii.gz")

        del func_data
        func_vol_img.uncache()
        clust_mask_res_img.uncache()

        return self.atlas
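
Both versions of `create_clean_mask` hinge on nilearn's `intersect_masks`, where `threshold=1` keeps only voxels present in every input mask and `connected=False` retains all connected components. A self-contained toy illustration with synthetic images (not pynets code):

import numpy as np
import nibabel as nib
from nilearn.image import math_img
from nilearn.masking import intersect_masks

affine = np.eye(4)
rng = np.random.default_rng(0)
a = nib.Nifti1Image((rng.random((8, 8, 8)) > 0.3).astype(np.uint8), affine)
b = nib.Nifti1Image((rng.random((8, 8, 8)) > 0.3).astype(np.uint8), affine)

# threshold=1 -> strict intersection; connected=False keeps all
# components, matching the calls in create_clean_mask above
inter = intersect_masks(
    [math_img("img > 0.01", img=a), math_img("img > 0.01", img=b)],
    threshold=1,
    connected=False,
)
print(int(np.asarray(inter.dataobj).sum()))  # voxels surviving both masks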
Example 7
    def _run_interface(self, runtime):
        import gc
        import os
        import time
        import os.path as op
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.core.utils import load_runconfig
        from pynets.dmri.estimation import reconstruction
        from pynets.dmri.track import (
            create_density_map,
            track_ensemble,
        )
        from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
            Origin
        from dipy.io.streamline import save_tractogram
        from nipype.utils.filemanip import copyfile, fname_presuffix

        hardcoded_params = load_runconfig()
        use_life = hardcoded_params['tracking']["use_life"][0]
        roi_neighborhood_tol = hardcoded_params['tracking'][
            "roi_neighborhood_tol"][0]
        sphere = hardcoded_params['tracking']["sphere"][0]
        target_samples = hardcoded_params['tracking']["tracking_samples"][0]

        dir_path = utils.do_dir_path(self.inputs.atlas,
                                     os.path.dirname(self.inputs.dwi_file))

        namer_dir = "{}/tractography".format(dir_path)
        if not os.path.isdir(namer_dir):
            os.makedirs(namer_dir, exist_ok=True)

        # Load diffusion data
        dwi_file_tmp_path = fname_presuffix(self.inputs.dwi_file,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
        copyfile(self.inputs.dwi_file,
                 dwi_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
        dwi_data = dwi_img.get_fdata(dtype=np.float32)

        # Load FA data
        fa_file_tmp_path = fname_presuffix(self.inputs.fa_path,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.fa_path,
                 fa_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        fa_img = nib.load(fa_file_tmp_path, mmap=True)

        labels_im_file_tmp_path = fname_presuffix(self.inputs.labels_im_file,
                                                  suffix="_tmp",
                                                  newpath=runtime.cwd)
        copyfile(self.inputs.labels_im_file,
                 labels_im_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        # Load B0 mask
        B0_mask_tmp_path = fname_presuffix(self.inputs.B0_mask,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.B0_mask,
                 B0_mask_tmp_path,
                 copy=True,
                 use_hardlink=False)

        streams = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            runtime.cwd,
            "/streamlines_",
            "%s" % (self.inputs.subnet +
                    "_" if self.inputs.subnet is not None else ""),
            "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                    "_" if self.inputs.roi is not None else ""),
            self.inputs.conn_model,
            "_",
            target_samples,
            "_",
            "%s" % ("%s%s" % (self.inputs.node_radius, "mm_") if
                    ((self.inputs.node_radius != "parc") and
                     (self.inputs.node_radius is not None)) else "parc_"),
            "curv-",
            str(self.inputs.curv_thr_list).replace(", ", "_"),
            "_step-",
            str(self.inputs.step_list).replace(", ", "_"),
            "_traversal-",
            self.inputs.traversal,
            "_minlength-",
            self.inputs.min_length,
            ".trk",
        )

        if os.path.isfile(f"{namer_dir}/{op.basename(streams)}"):
            from dipy.io.streamline import load_tractogram
            copyfile(
                f"{namer_dir}/{op.basename(streams)}",
                streams,
                copy=True,
                use_hardlink=False,
            )
            tractogram = load_tractogram(
                streams,
                fa_img,
                bbox_valid_check=False,
            )

            streamlines = tractogram.streamlines

            # Create streamline density map
            try:
                [dir_path, dm_path] = create_density_map(
                    fa_img,
                    dir_path,
                    streamlines,
                    self.inputs.conn_model,
                    self.inputs.node_radius,
                    self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.subnet,
                    self.inputs.roi,
                    self.inputs.traversal,
                    self.inputs.min_length,
                    namer_dir,
                )
            except BaseException:
                print('Density map failed. Check tractography output.')
                dm_path = None

            del streamlines, tractogram
            fa_img.uncache()
            dwi_img.uncache()
            gc.collect()
            self._results["dm_path"] = dm_path
            self._results["streams"] = streams
            recon_path = None
        else:
            # Fit diffusion model
            # Save reconstruction to .npy
            recon_path = "%s%s%s%s%s%s%s%s" % (
                runtime.cwd,
                "/reconstruction_",
                "%s" % (self.inputs.subnet +
                        "_" if self.inputs.subnet is not None else ""),
                "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                        "_" if self.inputs.roi is not None else ""),
                self.inputs.conn_model,
                "_",
                "%s" % ("%s%s" % (self.inputs.node_radius, "mm") if
                        ((self.inputs.node_radius != "parc") and
                         (self.inputs.node_radius is not None)) else "parc"),
                ".hdf5",
            )

            gtab_file_tmp_path = fname_presuffix(self.inputs.gtab_file,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gtab_file,
                     gtab_file_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gtab = load_pickle(gtab_file_tmp_path)

            # Only re-run the reconstruction if we have to
            if not os.path.isfile(f"{namer_dir}/{op.basename(recon_path)}"):
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                # The context manager closes the file; no explicit close needed
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(2)
                del model
            elif os.path.getsize(f"{namer_dir}/{op.basename(recon_path)}") > 0:
                print(f"Found existing reconstruction with "
                      f"{self.inputs.conn_model}. Loading...")
                copyfile(
                    f"{namer_dir}/{op.basename(recon_path)}",
                    recon_path,
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
            else:
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                # The context manager closes the file; no explicit close needed
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
                del model
            dwi_img.uncache()
            del dwi_data

            # Load atlas wm-gm interface reduced version for seeding
            labels_im_file_tmp_path_wm_gm_int = fname_presuffix(
                self.inputs.labels_im_file_wm_gm_int,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.labels_im_file_wm_gm_int,
                     labels_im_file_tmp_path_wm_gm_int,
                     copy=True,
                     use_hardlink=False)

            t1w2dwi_tmp_path = fname_presuffix(self.inputs.t1w2dwi,
                                               suffix="_tmp",
                                               newpath=runtime.cwd)
            copyfile(self.inputs.t1w2dwi,
                     t1w2dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gm_in_dwi_tmp_path = fname_presuffix(self.inputs.gm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gm_in_dwi,
                     gm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            vent_csf_in_dwi_tmp_path = fname_presuffix(
                self.inputs.vent_csf_in_dwi,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.vent_csf_in_dwi,
                     vent_csf_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            wm_in_dwi_tmp_path = fname_presuffix(self.inputs.wm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.wm_in_dwi,
                     wm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            if self.inputs.waymask:
                waymask_tmp_path = fname_presuffix(self.inputs.waymask,
                                                   suffix="_tmp",
                                                   newpath=runtime.cwd)
                copyfile(self.inputs.waymask,
                         waymask_tmp_path,
                         copy=True,
                         use_hardlink=False)
            else:
                waymask_tmp_path = None

            # Iteratively build a list of streamlines for each ROI while
            # tracking
            print(f"{Fore.GREEN}Target streamlines per iteration: "
                  f"{Fore.BLUE} "
                  f"{target_samples}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
                  f"{self.inputs.curv_thr_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
                  f"{self.inputs.step_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
                  f"{self.inputs.track_type}")
            print(Style.RESET_ALL)
            if self.inputs.traversal == "prob":
                print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}"
                      f"Probabilistic")
            elif self.inputs.traversal == "clos":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Closest Peak")
            elif self.inputs.traversal == "det":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Deterministic Maximum")
            else:
                raise ValueError("Direction-getting type not recognized!")

            print(Style.RESET_ALL)

            # Commence Ensemble Tractography
            try:
                streamlines = track_ensemble(
                    target_samples, labels_im_file_tmp_path_wm_gm_int,
                    labels_im_file_tmp_path, recon_path, get_sphere(sphere),
                    self.inputs.traversal, self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.track_type, self.inputs.maxcrossing,
                    int(roi_neighborhood_tol), self.inputs.min_length,
                    waymask_tmp_path, B0_mask_tmp_path, t1w2dwi_tmp_path,
                    gm_in_dwi_tmp_path, vent_csf_in_dwi_tmp_path,
                    wm_in_dwi_tmp_path, self.inputs.tiss_class)
                gc.collect()
            except BaseException as w:
                print(f"\n{Fore.RED}Tractography failed: {w}")
                print(Style.RESET_ALL)
                streamlines = None

            if streamlines is not None:
                # import multiprocessing
                # from pynets.core.utils import kill_process_family
                # return kill_process_family(int(
                # multiprocessing.current_process().pid))

                # Linear Fascicle Evaluation (LiFE)
                if use_life is True:
                    print('Using LiFE to evaluate streamline plausibility...')
                    from pynets.dmri.utils import \
                        evaluate_streamline_plausibility
                    dwi_img = nib.load(dwi_file_tmp_path)
                    dwi_data = dwi_img.get_fdata(dtype=np.float32)
                    orig_count = len(streamlines)

                    if self.inputs.waymask:
                        mask_data = nib.load(waymask_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    else:
                        mask_data = nib.load(wm_in_dwi_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    try:
                        streamlines = evaluate_streamline_plausibility(
                            dwi_data,
                            gtab,
                            mask_data,
                            streamlines,
                            sphere=sphere)
                    except BaseException:
                        print(f"Linear Fascicle Evaluation failed. "
                              f"Visually checking streamlines output "
                              f"{namer_dir}/{op.basename(streams)} is "
                              f"recommended.")
                    if len(streamlines) < 0.5 * orig_count:
                        raise ValueError('LiFE revealed no plausible '
                                         'streamlines in the tractogram!')
                    del dwi_data, mask_data

                # Save streamlines to trk
                stf = StatefulTractogram(streamlines,
                                         fa_img,
                                         origin=Origin.NIFTI,
                                         space=Space.VOXMM)
                stf.remove_invalid_streamlines()

                save_tractogram(
                    stf,
                    streams,
                )

                del stf

                copyfile(
                    streams,
                    f"{namer_dir}/{op.basename(streams)}",
                    copy=True,
                    use_hardlink=False,
                )

                # Create streamline density map
                try:
                    [dir_path, dm_path] = create_density_map(
                        dwi_img,
                        dir_path,
                        streamlines,
                        self.inputs.conn_model,
                        self.inputs.node_radius,
                        self.inputs.curv_thr_list,
                        self.inputs.step_list,
                        self.inputs.subnet,
                        self.inputs.roi,
                        self.inputs.traversal,
                        self.inputs.min_length,
                        namer_dir,
                    )
                except BaseException:
                    print('Density map failed. Check tractography output.')
                    dm_path = None

                del streamlines
                dwi_img.uncache()
                gc.collect()
                self._results["dm_path"] = dm_path
                self._results["streams"] = streams
            else:
                self._results["streams"] = None
                self._results["dm_path"] = None
            tmp_files = [
                gtab_file_tmp_path, wm_in_dwi_tmp_path, gm_in_dwi_tmp_path,
                vent_csf_in_dwi_tmp_path, t1w2dwi_tmp_path
            ]

            for j in tmp_files:
                if j is not None:
                    if os.path.isfile(j):
                        os.system(f"rm -f {j} &")

        self._results["track_type"] = self.inputs.track_type
        self._results["conn_model"] = self.inputs.conn_model
        self._results["dir_path"] = dir_path
        self._results["subnet"] = self.inputs.subnet
        self._results["node_radius"] = self.inputs.node_radius
        self._results["dens_thresh"] = self.inputs.dens_thresh
        self._results["ID"] = self.inputs.ID
        self._results["roi"] = self.inputs.roi
        self._results["min_span_tree"] = self.inputs.min_span_tree
        self._results["disp_filt"] = self.inputs.disp_filt
        self._results["parc"] = self.inputs.parc
        self._results["prune"] = self.inputs.prune
        self._results["atlas"] = self.inputs.atlas
        self._results["parcellation"] = self.inputs.parcellation
        self._results["labels"] = self.inputs.labels
        self._results["coords"] = self.inputs.coords
        self._results["norm"] = self.inputs.norm
        self._results["binary"] = self.inputs.binary
        self._results["atlas_t1w"] = self.inputs.atlas_t1w
        self._results["curv_thr_list"] = self.inputs.curv_thr_list
        self._results["step_list"] = self.inputs.step_list
        self._results["fa_path"] = fa_file_tmp_path
        self._results["traversal"] = self.inputs.traversal
        self._results["labels_im_file"] = labels_im_file_tmp_path
        self._results["min_length"] = self.inputs.min_length

        tmp_files = [B0_mask_tmp_path, dwi_file_tmp_path]

        for j in tmp_files:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

        # Exercise caution when deleting copied recon_path
        # if recon_path is not None:
        #     if os.path.isfile(recon_path):
        #         os.remove(recon_path)

        return runtime
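
The interface above stages every input into the node's working directory before touching it, using the same `fname_presuffix` plus `copyfile` pair from `nipype.utils.filemanip` each time. That pattern generalizes to a small helper; a sketch, with a helper name of our own choosing:

from nipype.utils.filemanip import copyfile, fname_presuffix

def stage_input(path, cwd):
    # Copy an input into the working directory under a '_tmp' suffix,
    # mirroring the staging pattern repeated throughout _run_interface
    tmp_path = fname_presuffix(path, suffix="_tmp", newpath=cwd)
    copyfile(path, tmp_path, copy=True, use_hardlink=False)
    return tmp_path

# e.g. dwi_file_tmp_path = stage_input(self.inputs.dwi_file, runtime.cwd)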
Example 8
    def _run_interface(self, runtime):
        import gc
        import numpy as np
        import nibabel as nib
        try:
            import cPickle as pickle
        except ImportError:
            import _pickle as pickle
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.dmri.track import prep_tissues, reconstruction, create_density_map, track_ensemble

        # Load diffusion data
        dwi_img = nib.load(self.inputs.dwi_file)

        # Fit diffusion model
        mod_fit = reconstruction(self.inputs.conn_model, load_pickle(self.inputs.gtab_file),
                                 np.asarray(dwi_img.dataobj), self.inputs.B0_mask)

        # Load atlas parcellation (and its wm-gm interface reduced version for seeding)
        atlas_data = np.array(nib.load(self.inputs.labels_im_file).dataobj).astype('uint16')
        atlas_data_wm_gm_int = np.asarray(nib.load(self.inputs.labels_im_file_wm_gm_int).dataobj).astype('uint16')

        # Build mask vector from atlas for later roi filtering
        parcels = [atlas_data == roi_val
                   for roi_val in np.unique(atlas_data)[1:]]

        if np.sum(atlas_data) == 0:
            raise ValueError(
                'ERROR: No non-zero voxels found in atlas. Check any roi masks and/or wm-gm interface images '
                'to verify overlap with dwi-registered atlas.')

        # Iteratively build a list of streamlines for each ROI while tracking
        print("%s%s%s%s" % (Fore.GREEN, 'Target number of samples: ', Fore.BLUE, self.inputs.target_samples))
        print(Style.RESET_ALL)
        print("%s%s%s%s" % (Fore.GREEN, 'Using curvature threshold(s): ', Fore.BLUE, self.inputs.curv_thr_list))
        print(Style.RESET_ALL)
        print("%s%s%s%s" % (Fore.GREEN, 'Using step size(s): ', Fore.BLUE, self.inputs.step_list))
        print(Style.RESET_ALL)
        print("%s%s%s%s" % (Fore.GREEN, 'Tracking type: ', Fore.BLUE, self.inputs.track_type))
        print(Style.RESET_ALL)
        if self.inputs.directget == 'prob':
            print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ', Fore.BLUE, 'Probabilistic'))
        elif self.inputs.directget == 'boot':
            print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ', Fore.BLUE, 'Bootstrapped'))
        elif self.inputs.directget == 'closest':
            print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ', Fore.BLUE, 'Closest Peak'))
        elif self.inputs.directget == 'det':
            print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ', Fore.BLUE, 'Deterministic Maximum'))
        else:
            raise ValueError('Direction-getting type not recognized!')
        print(Style.RESET_ALL)

        # Commence Ensemble Tractography
        streamlines = track_ensemble(np.asarray(dwi_img.dataobj), self.inputs.target_samples, atlas_data_wm_gm_int,
                                     parcels, mod_fit,
                                     prep_tissues(self.inputs.t1w2dwi, self.inputs.gm_in_dwi,
                                                  self.inputs.vent_csf_in_dwi, self.inputs.wm_in_dwi,
                                                  self.inputs.tiss_class),
                                     get_sphere(self.inputs.sphere), self.inputs.directget, self.inputs.curv_thr_list,
                                     self.inputs.step_list, self.inputs.track_type, self.inputs.maxcrossing,
                                     int(self.inputs.roi_neighborhood_tol), self.inputs.min_length, self.inputs.waymask)

        # Create streamline density map
        [streams, dir_path, dm_path] = create_density_map(dwi_img, utils.do_dir_path(self.inputs.atlas,
                                                                                     self.inputs.dwi_file), streamlines,
                                                          self.inputs.conn_model, self.inputs.target_samples,
                                                          self.inputs.node_size, self.inputs.curv_thr_list,
                                                          self.inputs.step_list, self.inputs.network, self.inputs.roi,
                                                          self.inputs.directget, self.inputs.min_length)

        self._results['streams'] = streams
        self._results['track_type'] = self.inputs.track_type
        self._results['target_samples'] = self.inputs.target_samples
        self._results['conn_model'] = self.inputs.conn_model
        self._results['dir_path'] = dir_path
        self._results['network'] = self.inputs.network
        self._results['node_size'] = self.inputs.node_size
        self._results['dens_thresh'] = self.inputs.dens_thresh
        self._results['ID'] = self.inputs.ID
        self._results['roi'] = self.inputs.roi
        self._results['min_span_tree'] = self.inputs.min_span_tree
        self._results['disp_filt'] = self.inputs.disp_filt
        self._results['parc'] = self.inputs.parc
        self._results['prune'] = self.inputs.prune
        self._results['atlas'] = self.inputs.atlas
        self._results['uatlas'] = self.inputs.uatlas
        self._results['labels'] = self.inputs.labels
        self._results['coords'] = self.inputs.coords
        self._results['norm'] = self.inputs.norm
        self._results['binary'] = self.inputs.binary
        self._results['atlas_mni'] = self.inputs.atlas_mni
        self._results['curv_thr_list'] = self.inputs.curv_thr_list
        self._results['step_list'] = self.inputs.step_list
        self._results['fa_path'] = self.inputs.fa_path
        self._results['dm_path'] = dm_path
        self._results['directget'] = self.inputs.directget
        self._results['labels_im_file'] = self.inputs.labels_im_file
        self._results['roi_neighborhood_tol'] = self.inputs.roi_neighborhood_tol
        self._results['min_length'] = self.inputs.min_length

        del streamlines, atlas_data_wm_gm_int, atlas_data, mod_fit, parcels
        dwi_img.uncache()
        gc.collect()

        return runtime
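
The per-ROI boolean masks built before tracking can be illustrated in isolation with synthetic atlas data (not pynets code):

import numpy as np

atlas_data = np.array([[0, 1, 1],
                       [2, 2, 0],
                       [0, 3, 3]], dtype=np.uint16)

# One boolean mask per non-zero label, as in the snippet above
parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)[1:]]
print(len(parcels))        # 3 ROIs
print(parcels[0].sum())    # 2 voxels carry label 1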
Example 9
    def _run_interface(self, runtime):
        from pynets.core import utils, nodemaker
        from nipype.utils.filemanip import fname_presuffix, copyfile
        from nilearn.image import concat_imgs
        import pandas as pd
        import time
        import textwrap
        from pathlib import Path
        import os.path as op
        import glob
        import indexed_gzip

        base_path = utils.get_file()
        # Test if atlas is a nilearn atlas. If so, fetch coords, labels, and/or
        # networks.
        nilearn_parc_atlases = [
            "atlas_harvard_oxford",
            "atlas_aal",
            "atlas_destrieux_2009",
            "atlas_talairach_gyrus",
            "atlas_talairach_ba",
            "atlas_talairach_lobe",
        ]
        nilearn_coords_atlases = ["coords_power_2011", "coords_dosenbach_2010"]
        nilearn_prob_atlases = ["atlas_msdl", "atlas_pauli_2017"]
        local_atlases = [
            op.basename(i).split(".nii")[0]
            for i in glob.glob(f"{str(Path(base_path).parent.parent)}"
                               f"/templates/atlases/*.nii.gz")
            if "_4d" not in i
        ]

        if self.inputs.parcellation is None and self.inputs.atlas in \
                nilearn_parc_atlases:
            [labels, networks_list, parcellation
             ] = nodemaker.nilearn_atlas_helper(self.inputs.atlas,
                                                self.inputs.parc)
            if parcellation:
                if not isinstance(parcellation, str):
                    nib.save(
                        parcellation, f"{runtime.cwd}"
                        f"{self.inputs.atlas}{'.nii.gz'}")
                    parcellation = f"{runtime.cwd}" \
                                   f"{self.inputs.atlas}{'.nii.gz'}"
                if self.inputs.clustering is False:
                    [parcellation,
                     labels] = \
                        nodemaker.enforce_hem_distinct_consecutive_labels(
                        parcellation, label_names=labels)
                [coords, atlas, par_max, label_intensities] = \
                    nodemaker.get_names_and_coords_of_parcels(parcellation)
                if self.inputs.parc is True:
                    parcels_4d_img = nodemaker.three_to_four_parcellation(
                        parcellation)
                else:
                    parcels_4d_img = None
            else:
                raise FileNotFoundError(
                    f"\nAtlas file for {self.inputs.atlas} not found!")

            atlas = self.inputs.atlas
        elif (self.inputs.parcellation is None and self.inputs.parc is False
              and self.inputs.atlas in nilearn_coords_atlases):
            print("Fetching coords and labels from nilearn coordinate-based"
                  " atlas library...")
            # Fetch nilearn atlas coords
            [coords, _, networks_list,
             labels] = nodemaker.fetch_nilearn_atlas_coords(self.inputs.atlas)
            parcels_4d_img = None
            par_max = None
            atlas = self.inputs.atlas
            parcellation = None
            label_intensities = None
        elif (self.inputs.parcellation is None and self.inputs.parc is False
              and self.inputs.atlas in nilearn_prob_atlases):
            import matplotlib
            matplotlib.use("agg")
            from nilearn.plotting import find_probabilistic_atlas_cut_coords

            print("Fetching coords and labels from nilearn probabilistic atlas"
                  " library...")
            # Fetch nilearn atlas coords
            [labels, networks_list, parcellation
             ] = nodemaker.nilearn_atlas_helper(self.inputs.atlas,
                                                self.inputs.parc)
            coords = find_probabilistic_atlas_cut_coords(maps_img=parcellation)
            if parcellation:
                if not isinstance(parcellation, str):
                    nib.save(
                        parcellation, f"{runtime.cwd}"
                        f"{self.inputs.atlas}{'.nii.gz'}")
                    parcellation = f"{runtime.cwd}" \
                                   f"{self.inputs.atlas}{'.nii.gz'}"
                if self.inputs.clustering is False:
                    [parcellation,
                     labels] = \
                        nodemaker.enforce_hem_distinct_consecutive_labels(
                        parcellation, label_names=labels)
                if self.inputs.parc is True:
                    parcels_4d_img = nodemaker.three_to_four_parcellation(
                        parcellation)
                else:
                    parcels_4d_img = None
            else:
                raise FileNotFoundError(
                    f"\nAtlas file for {self.inputs.atlas} not found!")

            par_max = None
            atlas = self.inputs.atlas
            label_intensities = None
        elif self.inputs.parcellation is None and self.inputs.atlas in \
                local_atlases:
            parcellation_pre = (
                f"{str(Path(base_path).parent.parent)}/templates/atlases/"
                f"{self.inputs.atlas}.nii.gz")
            parcellation = fname_presuffix(parcellation_pre,
                                           newpath=runtime.cwd)
            copyfile(parcellation_pre,
                     parcellation,
                     copy=True,
                     use_hardlink=False)
            try:
                par_img = nib.load(parcellation)
            except indexed_gzip.ZranError as e:
                print(
                    e, "\nCannot load subnetwork reference image. "
                    "Do you have git-lfs installed?")
            try:
                if self.inputs.clustering is False:
                    [parcellation, _] = \
                        nodemaker.enforce_hem_distinct_consecutive_labels(
                            parcellation)

                # Fetch user-specified atlas coords
                [coords, _, par_max, label_intensities] = \
                    nodemaker.get_names_and_coords_of_parcels(parcellation)
                if self.inputs.parc is True:
                    parcels_4d_img = nodemaker.three_to_four_parcellation(
                        parcellation)
                else:
                    parcels_4d_img = None
                # Describe user atlas coords
                print(f"\n{self.inputs.atlas} comes with {par_max} parcels\n")
            except ValueError as e:
                print(
                    e, "Either you have specified the name of an atlas that "
                    "does not exist in the nilearn or local repository or "
                    "you have not supplied a 3d atlas parcellation image!")
            labels = None
            networks_list = None
            atlas = self.inputs.atlas
        elif self.inputs.parcellation:
            if self.inputs.clustering is True:
                while True:
                    if op.isfile(self.inputs.parcellation):
                        break
                    else:
                        print("Waiting for atlas file...")
                        time.sleep(5)

            try:
                parcellation_tmp_path = fname_presuffix(
                    self.inputs.parcellation, newpath=runtime.cwd)
                copyfile(self.inputs.parcellation,
                         parcellation_tmp_path,
                         copy=True,
                         use_hardlink=False)
                # Fetch user-specified atlas coords
                if self.inputs.clustering is False:
                    [parcellation,
                     _] = nodemaker.enforce_hem_distinct_consecutive_labels(
                         parcellation_tmp_path)
                else:
                    parcellation = parcellation_tmp_path
                [coords, atlas, par_max, label_intensities] = \
                    nodemaker.get_names_and_coords_of_parcels(parcellation)
                if self.inputs.parc is True:
                    parcels_4d_img = nodemaker.three_to_four_parcellation(
                        parcellation)
                else:
                    parcels_4d_img = None

                atlas = utils.prune_suffices(atlas)

                # Describe user atlas coords
                print(f"\n{atlas} comes with {par_max} parcels\n")
            except ValueError as e:
                print(
                    e, "Either you have specified the name of an atlas that "
                    "does not exist in the nilearn or local repository or "
                    "you have not supplied a 3d atlas parcellation image!")
            labels = None
            networks_list = None
        else:
            raise ValueError(
                "Either you have specified the name of an atlas that does"
                " not exist in the nilearn or local repository or you have"
                " not supplied a 3d atlas parcellation image!")

        # Labels prep
        if atlas and not labels:
            if (self.inputs.ref_txt is not None) and (op.exists(
                    self.inputs.ref_txt)):
                labels = pd.read_csv(self.inputs.ref_txt,
                                     sep=" ",
                                     header=None,
                                     names=["Index",
                                            "Region"])["Region"].tolist()
            else:
                if atlas in local_atlases:
                    ref_txt = (
                        f"{str(Path(base_path).parent.parent)}/templates/"
                        f"labels/"
                        f"{atlas}.txt")
                else:
                    ref_txt = self.inputs.ref_txt
                if ref_txt is not None:
                    try:
                        labels = pd.read_csv(ref_txt,
                                             sep=" ",
                                             header=None,
                                             names=["Index", "Region"
                                                    ])["Region"].tolist()
                    except BaseException:
                        if self.inputs.use_parcel_naming is True:
                            try:
                                labels = nodemaker.parcel_naming(
                                    coords, self.inputs.vox_size)
                            except BaseException:
                                print("AAL reference labeling failed!")
                                labels = np.arange(len(coords) + 1)[
                                    np.arange(len(coords) + 1) != 0].tolist()
                        else:
                            print("Using generic index labels...")
                            labels = np.arange(len(coords) +
                                               1)[np.arange(len(coords) +
                                                            1) != 0].tolist()
                else:
                    if self.inputs.use_parcel_naming is True:
                        try:
                            labels = nodemaker.parcel_naming(
                                coords, self.inputs.vox_size)
                        except BaseException:
                            print("AAL reference labeling failed!")
                            labels = np.arange(len(coords) +
                                               1)[np.arange(len(coords) +
                                                            1) != 0].tolist()
                    else:
                        print("Using generic index labels...")
                        labels = np.arange(len(coords) +
                                           1)[np.arange(len(coords) +
                                                        1) != 0].tolist()

        dir_path = utils.do_dir_path(atlas, self.inputs.outdir)

        if len(coords) != len(labels):
            labels = [
                i for i in labels if (i != 'Unknown' and i != 'Background')
            ]
            if len(coords) != len(labels):
                print("Length of coordinates is not equal to length of "
                      "label names...")
                if self.inputs.use_parcel_naming is True:
                    try:
                        print("Attempting consensus parcel naming instead...")
                        labels = nodemaker.parcel_naming(
                            coords, self.inputs.vox_size)
                    except BaseException:
                        print("Reverting to integer labels instead...")
                        labels = np.arange(len(coords) +
                                           1)[np.arange(len(coords) +
                                                        1) != 0].tolist()
                else:
                    print("Reverting to integer labels instead...")
                    labels = np.arange(len(coords) +
                                       1)[np.arange(len(coords) +
                                                    1) != 0].tolist()

        print(f"Coordinates:\n{coords}")
        print(f"Labels:\n"
              f"{textwrap.shorten(str(labels), width=1000, placeholder='...')}"
              f"")

        assert len(coords) == len(labels)

        if label_intensities is not None:
            self._results["labels"] = list(zip(labels, label_intensities))
        else:
            self._results["labels"] = labels
        self._results["coords"] = coords
        self._results["atlas"] = atlas
        self._results["networks_list"] = networks_list
        # TODO: Optimize this with 4d array concatenation and .npyz

        out_path = f"{runtime.cwd}/parcels_4d.nii.gz"
        nib.save(parcels_4d_img, out_path)
        self._results["parcels_4d"] = out_path
        self._results["par_max"] = par_max
        self._results["parcellation"] = parcellation
        self._results["dir_path"] = dir_path

        return runtime
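
The generic-label fallback that recurs throughout Examples 9 and 10 builds its list with an arange and a boolean index, which is just an elaborate spelling of the integers 1 through `len(coords)`:

import numpy as np

coords = [(0, 0, 0), (10, -20, 30), (5, 5, 5)]

labels_a = np.arange(len(coords) + 1)[
    np.arange(len(coords) + 1) != 0].tolist()
labels_b = list(range(1, len(coords) + 1))  # equivalent and clearer
assert labels_a == labels_b  # both are [1, 2, 3]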
Example 10
def fetch_nodes_and_labels(atlas, uatlas, ref_txt, parc, in_file, use_AAL_naming, clustering=False):
    """
    General API for fetching, identifying, and defining atlas nodes based on coordinates and/or labels.

    Parameters
    ----------
    atlas : str
        Name of a Nilearn-hosted coordinate or parcellation/label-based atlas supported for fetching.
        See Nilearn's datasets.atlas module for more detailed reference.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    ref_txt : str
        Path to an atlas reference .txt file that maps labels to intensities corresponding to uatlas.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    in_file : str
        File path to Nifti1Image object whose affine will provide sampling reference for fetching.
    use_AAL_naming : bool
        Indicates whether to perform Automated Anatomical Labeling (AAL) of each coordinate in the list
        of voxel coordinates.
    clustering : bool
        Indicates whether clustering was performed. Default is False.

    Returns
    -------
    labels : list
        List of string labels corresponding to ROI nodes.
    coords : list
        List of (x, y, z) tuples in mm-space corresponding to the coordinate atlas used, or representing
        the center of mass of each parcellation node.
    atlas : str
        Name of atlas parcellation (can differ slightly from fetch API string).
    networks_list : list
        List of RSNs and their associated coordinates, if predefined uniquely for a given atlas.
    parcel_list : list
        List of 3D boolean numpy arrays or binarized Nifti1Images corresponding to ROI masks.
    par_max : int
        The maximum label intensity in the parcellation image.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    """
    from pynets.core import utils, nodemaker
    import pandas as pd
    import time
    from pathlib import Path
    import os.path as op

    base_path = utils.get_file()
    # Test if atlas is a nilearn atlas. If so, fetch coords, labels, and/or networks.
    nilearn_parc_atlases = ['atlas_harvard_oxford', 'atlas_aal', 'atlas_destrieux_2009', 'atlas_talairach_gyrus',
                            'atlas_talairach_ba', 'atlas_talairach_lobe']
    nilearn_coords_atlases = ['coords_power_2011', 'coords_dosenbach_2010']
    nilearn_prob_atlases = ['atlas_msdl', 'atlas_pauli_2017']
    if uatlas is None and atlas in nilearn_parc_atlases:
        [labels, networks_list, uatlas] = nodemaker.nilearn_atlas_helper(atlas, parc)
        if uatlas:
            if not isinstance(uatlas, str):
                nib.save(uatlas, "%s%s%s" % ('/tmp/', atlas, '.nii.gz'))
                uatlas = "%s%s%s" % ('/tmp/', atlas, '.nii.gz')
            [coords, _, par_max] = nodemaker.get_names_and_coords_of_parcels(uatlas)
            if parc is True:
                parcel_list = nodemaker.gen_img_list(uatlas)
            else:
                parcel_list = None
        else:
            raise ValueError("%s%s%s" % ('\nERROR: Atlas file for ', atlas, ' not found!'))
    elif uatlas is None and parc is False and atlas in nilearn_coords_atlases:
        print('Fetching coords and labels from nilearn coordinate-based atlas library...')
        # Fetch nilearn atlas coords
        [coords, _, networks_list, labels] = nodemaker.fetch_nilearn_atlas_coords(atlas)
        parcel_list = None
        par_max = None
    elif uatlas is None and parc is False and atlas in nilearn_prob_atlases:
        from nilearn.plotting import find_probabilistic_atlas_cut_coords
        print('Fetching coords and labels from nilearn probabilistic atlas library...')
        # Fetch nilearn atlas coords
        [labels, networks_list, uatlas] = nodemaker.nilearn_atlas_helper(atlas, parc)
        coords = find_probabilistic_atlas_cut_coords(maps_img=uatlas)
        if uatlas:
            if not isinstance(uatlas, str):
                nib.save(uatlas, "%s%s%s" % ('/tmp/', atlas, '.nii.gz'))
                uatlas = "%s%s%s" % ('/tmp/', atlas, '.nii.gz')
            if parc is True:
                parcel_list = nodemaker.gen_img_list(uatlas)
            else:
                parcel_list = None
        else:
            raise ValueError("%s%s%s" % ('\nERROR: Atlas file for ', atlas, ' not found!'))
        par_max = None
    elif uatlas:
        if clustering is True:
            while True:
                if op.isfile(uatlas):
                    break
                else:
                    print('Waiting for atlas file...')
                    time.sleep(15)
        atlas = uatlas.split('/')[-1].split('.')[0]
        try:
            # Fetch user-specified atlas coords
            [coords, atlas, par_max] = nodemaker.get_names_and_coords_of_parcels(uatlas)
            if parc is True:
                parcel_list = nodemaker.gen_img_list(uatlas)
            else:
                parcel_list = None
            # Describe user atlas coords
            print("%s%s%s%s" % ('\n', atlas, ' comes with {0} '.format(par_max), 'parcels\n'))
        except ValueError:
            print('\n\nError: Either you have specified the name of a nilearn atlas that does not exist or '
                  'you have not supplied a 3d atlas parcellation image!\n\n')
            parcel_list = None
            par_max = None
            coords = None
        labels = None
        networks_list = None
    else:
        networks_list = None
        labels = None
        parcel_list = None
        par_max = None
        coords = None

    # Labels prep
    if atlas:
        if labels:
            pass
        else:
            if (ref_txt is not None) and (op.exists(ref_txt)) and (use_AAL_naming is False):
                labels = pd.read_csv(ref_txt, sep=" ", header=None, names=["Index", "Region"])['Region'].tolist()
            else:
                try:
                    ref_txt = "%s%s%s%s" % (str(Path(base_path).parent), '/labelcharts/', atlas, '.txt')
                    if op.exists(ref_txt) and (use_AAL_naming is False):
                        try:
                            labels = pd.read_csv(ref_txt,
                                                 sep="\t", header=None, names=["Index", "Region"])['Region'].tolist()
                        except Exception:
                            labels = np.arange(len(coords) + 1)[np.arange(len(coords) + 1) != 0].tolist()
                    else:
                        if use_AAL_naming is True:
                            try:
                                labels = nodemaker.AAL_naming(coords)
                            except Exception:
                                print('AAL reference labeling failed!')
                                labels = np.arange(len(coords) + 1)[np.arange(len(coords) + 1) != 0].tolist()
                        else:
                            print('Using generic index labels...')
                            labels = np.arange(len(coords) + 1)[np.arange(len(coords) + 1) != 0].tolist()
                except Exception:
                    print("Label reference file not found. Attempting AAL naming...")
                    if use_AAL_naming is True:
                        try:
                            labels = nodemaker.AAL_naming(coords)
                        except Exception:
                            print('AAL reference labeling failed!')
                            labels = np.arange(len(coords) + 1)[np.arange(len(coords) + 1) != 0].tolist()
                    else:
                        print('Using generic index labels...')
                        labels = np.arange(len(coords) + 1)[np.arange(len(coords) + 1) != 0].tolist()
    else:
        print('WARNING: No labels available since atlas name is not specified!')

    print("%s%s" % ('Labels:\n', labels))
    dir_path = utils.do_dir_path(atlas, in_file)

    if len(coords) != len(labels):
        # Fall back to NaN placeholders; lengths are equal afterward by
        # construction, so no further length check is needed
        labels = len(coords) * [np.nan]

    return labels, coords, atlas, networks_list, parcel_list, par_max, uatlas, dir_path
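
For completeness, a hypothetical call exercising the coordinate-atlas branch documented above; the atlas name comes from the `nilearn_coords_atlases` list and the input path is a placeholder:

labels, coords, atlas, networks_list, parcel_list, par_max, uatlas, \
    dir_path = fetch_nodes_and_labels(
        atlas='coords_dosenbach_2010',       # nilearn coordinate atlas
        uatlas=None,                         # no user-supplied parcellation
        ref_txt=None,
        parc=False,                          # coordinates, not parcels
        in_file='/tmp/sub-002_bold.nii.gz',  # placeholder path
        use_AAL_naming=False,
    )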