コード例 #1
0
def test_cmc_stopping_criterion():
    """Check the statuses and map values reported by CmcStoppingCriterion.

    Three things are verified on a tiny one-slice volume:

    * the plain constructor (include=GM, exclude=CSF) and ``from_pve`` give
      the same include/exclude values at every voxel centre;
    * ``check_point`` at voxel centres returns INVALIDPOINT in CSF, ENDPOINT
      in GM and TRACKPOINT elsewhere;
    * points outside the volume return OUTSIDEIMAGE and read 0 from both maps.
    """
    # 1 x 3 x 2 volume: row 0 is GM, row 1 is WM, row 2 is CSF.
    gm = np.array([[[1, 1], [0, 0], [0, 0]]])
    wm = np.array([[[0, 0], [1, 1], [0, 0]]])
    csf = np.array([[[0, 0], [0, 0], [1, 1]]])

    cmc_tc = CmcStoppingCriterion(
        include_map=gm,
        exclude_map=csf,
        step_size=1,
        average_voxel_size=1,
    )
    cmc_tc_from_pve = CmcStoppingCriterion.from_pve(
        wm_map=wm,
        gm_map=gm,
        csf_map=csf,
        step_size=1,
        average_voxel_size=1,
    )

    # Both construction paths must expose identical maps.
    for voxel in np.ndindex(wm.shape):
        point = np.asarray(voxel, dtype="float64")
        npt.assert_almost_equal(cmc_tc.get_include(point),
                                cmc_tc_from_pve.get_include(point))
        npt.assert_almost_equal(cmc_tc.get_exclude(point),
                                cmc_tc_from_pve.get_exclude(point))

    # Status at every voxel centre, CSF taking precedence over GM.
    for voxel in ndindex(wm.shape):
        point = np.array(voxel, dtype='float64')
        observed = cmc_tc.check_point(point)
        if csf[voxel] == 1:
            expected = StreamlineStatus.INVALIDPOINT
        elif gm[voxel] == 1:
            expected = StreamlineStatus.ENDPOINT
        else:
            expected = StreamlineStatus.TRACKPOINT
        npt.assert_equal(observed, int(expected))

    # Points lying outside the volume.
    outside_pts = [[100, 100, 100], [0, -1, 1], [0, 10, 2], [0, 0.5, -0.51],
                   [0, -0.51, 0.1]]
    for raw_point in outside_pts:
        point = np.array(raw_point, dtype='float64')
        npt.assert_equal(cmc_tc.check_point(point),
                         int(StreamlineStatus.OUTSIDEIMAGE))
        npt.assert_equal(cmc_tc.get_exclude(point), 0)
        npt.assert_equal(cmc_tc.get_include(point), 0)
コード例 #2
0
ファイル: track.py プロジェクト: CaseyWeiner/m2g
    def prep_tracking(self):
        """Load the DWI, brain mask and tissue maps and build the stopping criterion.

        The kind of criterion follows ``self.track_type``: ``"local"``
        tracking gets a ``BinaryStoppingCriterion`` over the boolean
        white-matter mask, ``"particle"`` tracking gets a
        ``CmcStoppingCriterion`` built from the WM / GM / CSF maps.

        As a side effect the loaded images and arrays are stored on the
        instance (``dwi_img``, ``data``, ``mask_img``, ``mask``, ``gm_mask``,
        ``gm_mask_data``, ``wm_mask``, ``wm_mask_data``, ``wm_in_dwi_data``
        and ``tiss_classifier``).

        Returns
        -------
        BinaryStoppingCriterion or CmcStoppingCriterion
            The stopping criterion, also stored as ``self.tiss_classifier``.

        Raises
        ------
        ValueError
            If ``self.track_type`` is neither ``"local"`` nor ``"particle"``.
        """
        # Resolve the criterion type before any file I/O so that an
        # unsupported track_type fails with a clear message instead of an
        # UnboundLocalError further down.
        if self.track_type == "local":
            tiss_class = "bin"
        elif self.track_type == "particle":
            tiss_class = "cmc"
        else:
            raise ValueError(
                "Unsupported track_type {!r}; expected 'local' or "
                "'particle'.".format(self.track_type))

        # np.asanyarray(img.dataobj) is nibabel's documented replacement for
        # the removed img.get_data(): same values, on-disk dtype preserved.
        self.dwi_img = nib.load(self.dwi)
        self.data = np.asanyarray(self.dwi_img.dataobj)
        # Loads mask and ensures it's a true binary mask
        self.mask_img = nib.load(self.nodif_B0_mask)
        self.mask = np.asanyarray(self.mask_img.dataobj) > 0
        # Load tissue maps and prepare tissue classifier
        self.gm_mask = nib.load(self.gm_in_dwi)
        self.gm_mask_data = np.asanyarray(self.gm_mask.dataobj)
        self.wm_mask = nib.load(self.wm_in_dwi)
        self.wm_mask_data = np.asanyarray(self.wm_mask.dataobj)
        # astype() copies, so the boolean mask is independent of wm_mask_data
        # without reading the file a second time.
        self.wm_in_dwi_data = self.wm_mask_data.astype("bool")

        if tiss_class == "act":
            # NOTE(review): not selectable through track_type at the moment;
            # kept for manual experimentation. Confirm the include map before
            # enabling it.
            # NOTE: this rebinds self.vent_csf_in_dwi from a path to the
            # loaded image, so the method cannot be called twice.
            self.vent_csf_in_dwi = nib.load(self.vent_csf_in_dwi)
            self.vent_csf_in_dwi_data = np.asanyarray(
                self.vent_csf_in_dwi.dataobj)
            self.background = np.ones(self.gm_mask.shape)
            self.background[(self.gm_mask_data + self.wm_mask_data +
                             self.vent_csf_in_dwi_data) > 0] = 0
            self.include_map = self.wm_mask_data
            self.include_map[self.background > 0] = 0
            self.exclude_map = self.vent_csf_in_dwi_data
            self.tiss_classifier = ActStoppingCriterion(
                self.include_map, self.exclude_map)
        elif tiss_class == "bin":
            self.tiss_classifier = BinaryStoppingCriterion(self.wm_in_dwi_data)
        elif tiss_class == "cmc":
            # NOTE: this rebinds self.vent_csf_in_dwi from a path to the
            # loaded image, so the method cannot be called twice.
            self.vent_csf_in_dwi = nib.load(self.vent_csf_in_dwi)
            self.vent_csf_in_dwi_data = np.asanyarray(
                self.vent_csf_in_dwi.dataobj)
            # pixdim[1:4] holds the three spatial voxel dimensions;
            # .header replaces the removed get_header().
            voxel_size = np.average(self.wm_mask.header["pixdim"][1:4])
            step_size = 0.2
            self.tiss_classifier = CmcStoppingCriterion.from_pve(
                self.wm_mask_data,
                self.gm_mask_data,
                self.vent_csf_in_dwi_data,
                step_size=step_size,
                average_voxel_size=voxel_size,
            )
        else:
            raise ValueError(
                "Unsupported tissue classifier {!r}.".format(tiss_class))
        return self.tiss_classifier
コード例 #3
0
    def _cmc_sc(self):
        """Build a CMC stopping criterion from the PVE maps and store it.

        The first three paths returned by ``load_pve_files`` are taken as the
        CSF, GM and WM maps (in that order). Step size and voxel size are read
        from ``self.parameters_dict``; the resulting criterion is assigned to
        ``self.classifier``.
        """
        from dipy.tracking.stopping_criterion import CmcStoppingCriterion

        pve_paths = load_pve_files(self.subj_folder, self.pve_file_name)
        csf_path, gm_path, wm_path = pve_paths[:3]

        csf_map = load_nifti_data(csf_path)
        gm_map = load_nifti_data(gm_path)
        wm_map = load_nifti_data(wm_path)

        settings = self.parameters_dict
        self.classifier = CmcStoppingCriterion.from_pve(
            wm_map,
            gm_map,
            csf_map,
            step_size=settings['step_size'],
            average_voxel_size=settings['voxel_size'],
        )
コード例 #4
0
ファイル: weighted_tracts.py プロジェクト: HilaGast/FT
def create_cmc_classifier(folder_name):
    """Build a CMC stopping criterion from the PVE maps found in a folder.

    Parameters
    ----------
    folder_name : str
        Folder handed to ``load_pve_files`` to locate the CSF, GM and WM
        partial-volume maps.

    Returns
    -------
    cmc_criterion : CmcStoppingCriterion
        Criterion built with ``CmcStoppingCriterion.from_pve``.
    step_size : float
        The tracking step size (0.2) the criterion was built for, so the
        caller can pass the same value to the tracker.
    """
    from dipy.tracking.stopping_criterion import CmcStoppingCriterion

    f_pve_csf, f_pve_gm, f_pve_wm = load_pve_files(folder_name)
    pve_csf_data = load_nifti_data(f_pve_csf)
    pve_gm_data = load_nifti_data(f_pve_gm)
    pve_wm_data, _, voxel_size = load_nifti(f_pve_wm, return_voxsize=True)
    # Assuming load_nifti is dipy.io.image.load_nifti: with
    # return_voxsize=True it returns the three spatial zooms (x, y, z).
    # The former `[1:4]` slice is the idiom for header["pixdim"] and, applied
    # to a 3-tuple, silently dropped the x dimension from the average.
    average_voxel_size = np.mean(voxel_size[:3])
    step_size = 0.2
    cmc_criterion = CmcStoppingCriterion.from_pve(
        pve_wm_data,
        pve_gm_data,
        pve_csf_data,
        step_size=step_size,
        average_voxel_size=average_voxel_size)

    return cmc_criterion, step_size
コード例 #5
0
anatomical images to determine when the tractography stops.
Both stopping criteria use trilinear interpolation
at the tracking position. The CMC stopping criterion uses a probability derived
from the PVE maps to determine if the streamline reaches a 'valid' or 'invalid'
region. ACT uses a fixed threshold on the PVE maps. Both stopping criteria can
be used in conjunction with PFT. In this example, we used CMC.
"""

from dipy.tracking.stopping_criterion import CmcStoppingCriterion

# Collapse the per-axis voxel sizes into the single average that CMC expects.
# NOTE(review): if `voxel_size` is the 3-tuple returned by
# load_nifti(..., return_voxsize=True), the [1:4] slice averages only the last
# two axes -- confirm where `voxel_size` comes from.
voxel_size = np.average(voxel_size[1:4])
step_size = 0.2

# The same step size is given to the criterion and to the tracker below.
cmc_criterion = CmcStoppingCriterion.from_pve(
    pve_wm_data,
    pve_gm_data,
    pve_csf_data,
    step_size=step_size,
    average_voxel_size=voxel_size,
)

# Particle Filtering Tractography
pft_streamline_generator = ParticleFilteringTracking(
    dg,
    cmc_criterion,
    seeds,
    affine,
    max_cross=1,
    step_size=step_size,
    maxlen=1000,
    pft_back_tracking_dist=2,
    pft_front_tracking_dist=1,
    particle_count=15,
    return_all=False,
)
コード例 #6
0
    def run(self,
            pam_files,
            wm_files,
            gm_files,
            csf_files,
            seeding_files,
            step_size=0.2,
            seed_density=1,
            pmf_threshold=0.1,
            max_angle=20.,
            pft_back=2,
            pft_front=1,
            pft_count=15,
            out_dir='',
            out_tractogram='tractogram.trk',
            save_seeds=False):
        """Workflow for Particle Filtering Tracking.

        This workflow use a saved peaks and metrics (PAM) file as input.

        Parameters
        ----------
        pam_files : string
           Path to the peaks and metrics files. This path may contain
            wildcards to use multiple masks at once.
        wm_files : string
            Path to white matter partial volume estimate for tracking (CMC).
        gm_files : string
            Path to grey matter partial volume estimate for tracking (CMC).
        csf_files : string
            Path to cerebrospinal fluid partial volume estimate for tracking
            (CMC).
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size used for tracking (default 0.2mm).
        seed_density : int, optional
            Number of seeds per dimension inside voxel (default 1).
             For example, seed_density of 2 means 8 regularly distributed
             points in the voxel. And seed density of 1 means 1 point at the
             center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions (default 0.1).
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90],
            default 20).
        pft_back : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography (default 2mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_front : float, optional
            Distance in mm to run the particle filtering tractography after the
            the back track distance (default 1mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_count : int, optional
            Number of particles to use in the particle filter (default 15).
        out_dir : string, optional
           Output directory (default input file directory)
        out_tractogram : string, optional
           Name of the tractogram file to be saved (default 'tractogram.trk')
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key

        References
        ----------
        Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M. Towards
        quantitative connectivity analysis: reducing tractography biases.
        NeuroImage, 98, 266-278, 2014.

        """
        # NOTE(review): the iterator is obtained without passing any of the
        # arguments above, so it presumably inspects this call (and possibly
        # the docstring) to pair inputs with outputs -- confirm before
        # renaming parameters or rewording the docstring.
        io_it = self.get_io_iterator()

        # Each iteration yields one set of input paths plus the path of the
        # tractogram to write for that set.
        for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract \
                in io_it:

            logging.info(
                'Particle Filtering tracking on {0}'.format(pams_path))

            pam = load_peaks(pams_path, verbose=False)

            # The WM map supplies the affine and the voxel size that are used
            # for seeding, tracking and the CMC criterion below.
            wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
            gm, _ = load_nifti(gm_path)
            csf, _ = load_nifti(csf_path)
            avs = sum(voxel_size) / len(voxel_size)  # average_voxel_size
            stopping_criterion = CmcStoppingCriterion.from_pve(
                wm, gm, csf, step_size=step_size, average_voxel_size=avs)
            logging.info('stopping criterion done')
            seed_mask, _ = load_nifti(seeding_path)
            # The same density is applied along each of the three axes.
            seeds = utils.seeds_from_mask(
                seed_mask,
                density=[seed_density, seed_density, seed_density],
                affine=affine)
            logging.info('seeds done')
            dg = ProbabilisticDirectionGetter

            # Directions are sampled from the SH coefficients stored in the
            # PAM file.
            direction_getter = dg.from_shcoeff(pam.shm_coeff,
                                               max_angle=max_angle,
                                               sphere=pam.sphere,
                                               pmf_threshold=pmf_threshold)

            # pft_max_trial is fixed at 20 and is not exposed as a workflow
            # parameter.
            tracking_result = ParticleFilteringTracking(
                direction_getter,
                stopping_criterion,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=pft_back,
                pft_front_tracking_dist=pft_front,
                pft_max_trial=20,
                particle_count=pft_count,
                save_seeds=save_seeds)

            logging.info('ParticleFilteringTracking initiated')

            # With save_seeds the tracker yields (streamline, seed) pairs;
            # `seeds` is rebound here to the data_per_streamline mapping.
            if save_seeds:
                streamlines, seeds = zip(*tracking_result)
                seeds = {'seeds': seeds}
            else:
                streamlines = list(tracking_result)
                seeds = {}

            # The seeding image serves as the spatial reference of the
            # tractogram.
            sft = StatefulTractogram(streamlines,
                                     seeding_path,
                                     Space.RASMM,
                                     data_per_streamline=seeds)
            save_tractogram(sft, out_tract, bbox_valid_check=False)
            logging.info('Saved {0}'.format(out_tract))
コード例 #7
0
ファイル: tractography.py プロジェクト: dweiss044/pyAFQ
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters (tensor
        eigenvalues/eigenvectors or spherical harmonic coefficients,
        depending on `odf_model`), or nibabel img with model params.
    directions : str
        How tracking directions are determined.
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        A non-boolean seed_mask is binarized as `seed_mask > seed_threshold`;
        only voxels above this value are seeded.
        Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array).
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "DKI", "FWDTI", "CSD", "MSMT"} (case-insensitive).
        Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    The streamlines, as returned by `_tracking`.

    Raises
    ------
    ValueError
        If `directions`, `odf_model` or `tracker` is not one of the
        supported values, or if `tracker` is "pft" and `stop_threshold` is a
        string other than "CMC" or "ACT".
    RuntimeError
        If `tracker` is "pft" and `stop_threshold` is not a string or
        `stop_mask` is not a length-3 iterable.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: Improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1924-1938, 2012.
    """
    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()

    # Validate the string selectors before any I/O: an unknown value used to
    # surface much later as an unbound-variable error (or, for odf_model, was
    # silently treated as spherical harmonics because
    # `odf_model == "CSD" or "MSMT"` is always true).
    tensor_models = ("DTI", "DKI", "FWDTI")
    sh_models = ("CSD", "MSMT")
    if directions not in ("det", "prob"):
        raise ValueError(
            "Unknown directions {!r}; expected 'det' or 'prob'.".format(
                directions))
    if odf_model not in tensor_models + sh_models:
        raise ValueError(
            "Unknown odf_model {!r}; expected one of {}.".format(
                odf_model, tensor_models + sh_models))
    if tracker not in ("local", "pft"):
        raise ValueError(
            "Unknown tracker {!r}; expected 'local' or 'pft'.".format(
                tracker))

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    if isinstance(n_seeds, int):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:
        dg = ProbabilisticDirectionGetter

    if odf_model in tensor_models:
        # Tensor-like parameters: 3 eigenvalues followed by 9 eigenvector
        # components per voxel.
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:
        # CSD / MSMT parameters are spherical harmonic coefficients.
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(params_img.shape[:3])

        if stop_mask.dtype == 'bool':
            # A boolean mask is thresholded halfway between False and True.
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # tracker == "pft"
        if not isinstance(stop_threshold, str):
            # Adjacent literals (no commas) so the exception carries one
            # message instead of a tuple of three strings.
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if stop_threshold not in ("CMC", "ACT"):
            raise ValueError(
                "Unknown stop_threshold {!r} for PFT tracking; expected "
                "'CMC' or 'ACT'.".format(stop_threshold))
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")

        # Load each PVE map and resample it onto the grid of the model
        # parameters.
        pves = []
        for pve in stop_mask:
            img = nib.load(pve) if isinstance(pve, str) else pve
            pves.append(
                resample(img.get_fdata(), model_params[..., 0], img.affine,
                         params_img.affine).get_fdata())
        pve_wm_data, pve_gm_data, pve_csf_data = pves

        # The maps now live on the params grid, so its voxel size applies.
        # (Previously the mean was taken over a list that was still empty,
        # which produced NaN.)
        average_voxel_size = np.mean(params_img.header.get_zooms()[:3])

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT"
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)
コード例 #8
0
ファイル: track.py プロジェクト: dPys/PyNets
def prep_tissues(t1_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 B0_mask,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Every input image is first binarized at 0.01. Depending on `tiss_class`
    the result is an ACT criterion ("act"), a CMC criterion ("cmc"), or a
    binary criterion over the intersection of the T1w mask, the B0 mask and
    the non-CSF voxels, with ("wm") or without ("wb") the white-matter mask.

    Parameters
    ----------
    t1_mask : Nifti1Image
        T1w mask img.
    gm_in_dwi : Nifti1Image
        Grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : Nifti1Image
        Ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : Nifti1Image
        White-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method. One of "act", "wm", "cmc", "wb".
    B0_mask : Nifti1Image
        B0 mask img.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If `tiss_class` is not one of the supported methods.

    References
    ----------
    .. [1] Zhang, Y., Brady, M. and Smith, S. Segmentation of Brain MR Images
      Through a Hidden Markov Random Field Model and the
      Expectation-Maximization Algorithm IEEE Transactions on Medical Imaging,
      20(1): 45-56, 2001
    .. [2] Avants, B. B., Tustison, N. J., Wu, J., Cook, P. A. and Gee, J. C.
      An open source multivariate framework for n-tissue segmentation with
      evaluation on public data. Neuroinformatics, 9(4): 381-400, 2011.
    """
    import gc
    from dipy.tracking.stopping_criterion import (
        ActStoppingCriterion,
        CmcStoppingCriterion,
        BinaryStoppingCriterion,
    )
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    def _binarize(img):
        # Threshold an image into a binary mask image.
        return math_img("img > 0.01", img=img)

    def _to_float32(img):
        # Pull the voxel data of an image out as a float32 array.
        return np.asarray(img.dataobj, dtype=np.float32)

    def _intersection_criterion(leading_masks, csf_array, affine):
        # Binary criterion over the voxels shared by every leading mask and
        # lying outside the CSF.
        non_csf_img = nib.Nifti1Image(
            np.invert(csf_array.astype('bool')).astype('int'),
            affine=affine)
        shared = intersect_masks(
            leading_masks + [non_csf_img],
            threshold=1,
            connected=False,
        )
        return BinaryStoppingCriterion(np.asarray(shared.dataobj))

    # Binarize the masks in the same order as before: B0, T1w, WM, GM, CSF.
    B0_mask_img = _binarize(B0_mask)
    mask_img = _binarize(t1_mask)
    wm_mask_img = _binarize(wm_in_dwi)
    gm_mask_img = _binarize(gm_in_dwi)
    csf_mask_img = _binarize(vent_csf_in_dwi)

    gm_array = _to_float32(gm_mask_img)
    wm_array = _to_float32(wm_mask_img)
    csf_array = _to_float32(csf_mask_img)

    if tiss_class == "act":
        # Voxels belonging to no tissue at all are added to the include map.
        no_tissue = np.ones(mask_img.shape)
        no_tissue[(gm_array + wm_array + csf_array) > 0] = 0
        gm_array[no_tissue > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_array, csf_array)
        del no_tissue
    elif tiss_class == "wm":
        tiss_classifier = _intersection_criterion(
            [mask_img, wm_mask_img, B0_mask_img], csf_array, mask_img.affine)
    elif tiss_class == "cmc":
        mean_voxel_size = np.average(mask_img.header["pixdim"][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_array,
            gm_array,
            csf_array,
            step_size=cmc_step_size,
            average_voxel_size=mean_voxel_size,
        )
    elif tiss_class == "wb":
        tiss_classifier = _intersection_criterion(
            [mask_img, B0_mask_img], csf_array, mask_img.affine)
    else:
        raise ValueError("Tissue classifier cannot be none.")

    # Drop cached image data and the local arrays before returning.
    for cached_img in (B0_mask_img, mask_img, wm_mask_img, gm_mask_img):
        cached_img.uncache()
    del gm_array, wm_array, csf_array
    gc.collect()

    return tiss_classifier
コード例 #9
0
def main():
    """Run particle filtering tractography from the command line.

    Parses and validates the arguments, builds the direction getter from the
    fODF SH file and the CMC (or, with ``args.act``, ACT) stopping criterion
    from the include/exclude maps, tracks in voxel space from random seeds,
    filters the streamlines by length, optionally compresses them and writes
    the tractogram to ``args.output_file``.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file,
                                 args.map_include_file,
                                 args.map_exclude_file])
    assert_outputs_exist(parser, args, args.output_file)

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'
                     .format(args.min_length))
    if args.max_length < args.min_length:
        parser.error('maxL must be > than minL, (minL={}mm, maxL={}mm).'
                     .format(args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'
                .format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    # Tracking runs in voxel space with a single step size, so the SH volume
    # must be isotropic (mean zoom equal to the first zoom).
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0], atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(
        fodf_sh_img.get_fdata(dtype=np.double),
        max_angle=theta,
        sphere=tracking_sphere,
        basis_type=sh_basis,
        pmf_threshold=args.sf_threshold,
        relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    # `.header` replaces get_header(), which was removed from nibabel;
    # pixdim[1:4] holds the three spatial voxel dimensions.
    voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcStoppingCriterion(map_include_img.get_fdata(),
                                                 map_exclude_img.get_fdata(),
                                                 step_size=args.step_size,
                                                 average_voxel_size=voxel_size)
    else:
        tissue_classifier = ActStoppingCriterion(map_include_img.get_fdata(),
                                                 map_exclude_img.get_fdata())

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # From here on voxel_size is the (isotropic) zoom of the SH volume, used
    # to convert mm quantities to voxel units.
    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    # Streamlines are in voxel space, so the length bounds are scaled too.
    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        kept_pairs = [(s, p) for s, p in pft_streamlines
                      if scaled_min_length <= length(s) <= scaled_max_length]
        if kept_pairs:
            filtered_streamlines, seeds = zip(*kept_pairs)
        else:
            # zip(*[]) yields nothing to unpack; fall back to empty results
            # so that an empty tractogram is still written.
            filtered_streamlines, seeds = (), ()
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in pft_streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
コード例 #10
0
ファイル: track.py プロジェクト: devhliu/PyNets
def prep_tissues(B0_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    B0_mask : str
        File path to B0 brain mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method. One of:

        - 'act' : `ActStoppingCriterion` whose include map is the grey-matter
          map plus every voxel where the three tissue maps sum to <= 0, and
          whose exclude map is the ventricular CSF map.
        - 'bin' : `BinaryStoppingCriterion` on the white-matter map cast to
          bool.
        - 'cmc' : `CmcStoppingCriterion.from_pve` on the WM, GM and CSF maps.
        - 'wb' : `BinaryStoppingCriterion` on the brain mask cast to bool.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If `tiss_class` is not one of the supported methods.
    """
    # Validate first so that a typo in `tiss_class` fails immediately instead
    # of after every volume has been read from disk.
    supported = ('act', 'bin', 'cmc', 'wb')
    if tiss_class not in supported:
        raise ValueError(
            'Unsupported tissue classifier {!r}; expected one of {}.'.format(
                tiss_class, ', '.join(repr(name) for name in supported)))

    from dipy.tracking.stopping_criterion import (ActStoppingCriterion,
                                                  CmcStoppingCriterion,
                                                  BinaryStoppingCriterion)

    # Load the brain mask (its shape and header are used below)
    mask_img = nib.load(B0_mask)
    # Load tissue maps and prepare tissue classifier
    gm_mask_data = nib.load(gm_in_dwi).get_fdata()
    wm_mask_data = nib.load(wm_in_dwi).get_fdata()
    vent_csf_in_dwi_data = nib.load(vent_csf_in_dwi).get_fdata()

    if tiss_class == 'act':
        # Background = voxels where the summed tissue maps are not positive.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data +
                    vent_csf_in_dwi_data) > 0] = 0
        # `include_map` aliases `gm_mask_data`, so the grey-matter array is
        # modified in place; it is not used for anything else afterwards.
        include_map = gm_mask_data
        include_map[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(include_map,
                                               vent_csf_in_dwi_data)
        del background
        del include_map
    elif tiss_class == 'bin':
        tiss_classifier = BinaryStoppingCriterion(wm_mask_data.astype('bool'))
    elif tiss_class == 'cmc':
        # Mean of the three spatial voxel dimensions stored in the header.
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_mask_data,
            gm_mask_data,
            vent_csf_in_dwi_data,
            step_size=cmc_step_size,
            average_voxel_size=voxel_size)
    else:
        # Only 'wb' can reach this point thanks to the validation above.
        tiss_classifier = BinaryStoppingCriterion(
            mask_img.get_fdata().astype('bool'))

    # Drop local references and nibabel's cached array for the mask.
    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()

    return tiss_classifier
コード例 #11
0
anatomical images to determine when the tractography stops.
Both stopping criteria use trilinear interpolation
at the tracking position. The CMC stopping criterion uses a probability derived
from the PVE maps to determine if the streamline reaches a 'valid' or 'invalid'
region. ACT uses a fixed threshold on the PVE maps. Both stopping criteria can
be used in conjunction with PFT. In this example, we use CMC.
"""

from dipy.tracking.stopping_criterion import CmcStoppingCriterion

# Mean of the three spatial `pixdim` entries of the white-matter PVE image
# header, used as the average voxel size for the CMC criterion.
voxel_size = np.average(img_pve_wm.header['pixdim'][1:4])
# The same step size is given to the stopping criterion and to the tracker.
step_size = 0.2

# `get_fdata()` is used instead of `get_data()`: nibabel deprecated
# `get_data()` in 3.0 and removed it in 5.0.
cmc_criterion = CmcStoppingCriterion.from_pve(img_pve_wm.get_fdata(),
                                              img_pve_gm.get_fdata(),
                                              img_pve_csf.get_fdata(),
                                              step_size=step_size,
                                              average_voxel_size=voxel_size)

# Particle Filtering Tractography
pft_streamline_generator = ParticleFilteringTracking(dg,
                                                     cmc_criterion,
                                                     seeds,
                                                     affine,
                                                     max_cross=1,
                                                     step_size=step_size,
                                                     maxlen=1000,
                                                     pft_back_tracking_dist=2,
                                                     pft_front_tracking_dist=1,
                                                     particle_count=15,
                                                     return_all=False)
コード例 #12
0
def prep_tissues(t1_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class, cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    t1_mask : str
        File path to a T1w mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method. One of:

        - 'act' : `ActStoppingCriterion` whose include map is the grey-matter
          map plus every voxel where the three tissue maps sum to <= 0, and
          whose exclude map is the ventricular CSF map.
        - 'bin' : `BinaryStoppingCriterion` on the intersection of the
          thresholded (> 0) T1w mask and white-matter map.
        - 'cmc' : `CmcStoppingCriterion.from_pve` on the WM, GM and CSF maps.
        - 'wb' : `BinaryStoppingCriterion` on the T1w mask cast to bool.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If `tiss_class` is not one of the supported methods.

    References
    ----------
    .. [1] Zhang, Y., Brady, M. and Smith, S. Segmentation of Brain MR Images
      Through a Hidden Markov Random Field Model and the Expectation-Maximization
      Algorithm IEEE Transactions on Medical Imaging, 20(1): 45-56, 2001
    .. [2] Avants, B. B., Tustison, N. J., Wu, J., Cook, P. A. and Gee, J. C.
      An open source multivariate framework for n-tissue segmentation with
      evaluation on public data. Neuroinformatics, 9(4): 381-400, 2011.

    """
    # Validate first so that a typo in `tiss_class` fails immediately instead
    # of after every volume has been read from disk.
    supported = ('act', 'bin', 'cmc', 'wb')
    if tiss_class not in supported:
        raise ValueError('Unsupported tissue classifier {!r}; expected one of {}.'.format(
            tiss_class, ', '.join(repr(name) for name in supported)))

    from dipy.tracking.stopping_criterion import ActStoppingCriterion, CmcStoppingCriterion, BinaryStoppingCriterion
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    # Loads mask
    mask_img = nib.load(t1_mask)
    # Load tissue maps and prepare tissue classifier
    wm_img = nib.load(wm_in_dwi)
    gm_img = nib.load(gm_in_dwi)
    gm_mask_data = np.asarray(gm_img.dataobj)
    wm_mask_data = np.asarray(wm_img.dataobj)
    vent_csf_in_dwi_data = np.asarray(nib.load(vent_csf_in_dwi).dataobj)

    if tiss_class == 'act':
        # Background = voxels where the summed tissue maps are not positive.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data + vent_csf_in_dwi_data) > 0] = 0
        # The grey-matter array is modified in place to become the include map.
        gm_mask_data[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_mask_data, vent_csf_in_dwi_data)
        del background
    elif tiss_class == 'bin':
        # Track only where both the T1w mask and the white-matter map are > 0.
        wm_in_mask_img = intersect_masks([math_img('img > 0.0', img=mask_img),
                                          math_img('img > 0.0', img=wm_img)],
                                         threshold=1, connected=False)
        tiss_classifier = BinaryStoppingCriterion(np.asarray(wm_in_mask_img.dataobj))
    elif tiss_class == 'cmc':
        # Mean of the three spatial voxel dimensions stored in the header.
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(wm_mask_data, gm_mask_data, vent_csf_in_dwi_data,
                                                        step_size=cmc_step_size, average_voxel_size=voxel_size)
    else:
        # Only 'wb' can reach this point thanks to the validation above.
        tiss_classifier = BinaryStoppingCriterion(np.asarray(mask_img.dataobj).astype('bool'))

    # Drop local references and any arrays nibabel cached on the images.
    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()
    gm_img.uncache()
    wm_img.uncache()

    return tiss_classifier
コード例 #13
0
def prep_tissues(t1_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    t1_mask : str
        File path to a T1w mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method. One of:

        - 'act' : `ActStoppingCriterion` whose include map is the grey-matter
          map plus every voxel where the three tissue maps sum to <= 0, and
          whose exclude map is the ventricular CSF map.
        - 'bin' : `BinaryStoppingCriterion` on the intersection of the
          thresholded (> 0) T1w mask and white-matter map.
        - 'cmc' : `CmcStoppingCriterion.from_pve` on the WM, GM and CSF maps.
        - 'wb' : `BinaryStoppingCriterion` on the T1w mask cast to bool.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If `tiss_class` is not one of the supported methods.
    """
    # Validate first so that a typo in `tiss_class` fails immediately instead
    # of after every volume has been read from disk.
    supported = ('act', 'bin', 'cmc', 'wb')
    if tiss_class not in supported:
        raise ValueError(
            'Unsupported tissue classifier {!r}; expected one of {}.'.format(
                tiss_class, ', '.join(repr(name) for name in supported)))

    from dipy.tracking.stopping_criterion import (ActStoppingCriterion,
                                                  CmcStoppingCriterion,
                                                  BinaryStoppingCriterion)
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    # Loads mask
    mask_img = nib.load(t1_mask)
    # Load tissue maps and prepare tissue classifier
    wm_img = nib.load(wm_in_dwi)
    gm_img = nib.load(gm_in_dwi)
    gm_mask_data = np.asarray(gm_img.dataobj)
    wm_mask_data = np.asarray(wm_img.dataobj)
    vent_csf_in_dwi_data = np.asarray(nib.load(vent_csf_in_dwi).dataobj)

    if tiss_class == 'act':
        # Background = voxels where the summed tissue maps are not positive.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data +
                    vent_csf_in_dwi_data) > 0] = 0
        # The grey-matter array is modified in place to become the include map.
        gm_mask_data[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_mask_data,
                                               vent_csf_in_dwi_data)
        del background
    elif tiss_class == 'bin':
        # Track only where both the T1w mask and the white-matter map are > 0.
        wm_in_mask_img = intersect_masks(
            [math_img('img > 0.0', img=mask_img),
             math_img('img > 0.0', img=wm_img)],
            threshold=1,
            connected=False)
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(wm_in_mask_img.dataobj))
    elif tiss_class == 'cmc':
        # Mean of the three spatial voxel dimensions stored in the header.
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_mask_data,
            gm_mask_data,
            vent_csf_in_dwi_data,
            step_size=cmc_step_size,
            average_voxel_size=voxel_size)
    else:
        # Only 'wb' can reach this point thanks to the validation above.
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(mask_img.dataobj).astype('bool'))

    # Drop local references and any arrays nibabel cached on the images.
    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()
    gm_img.uncache()
    wm_img.uncache()

    return tiss_classifier