Esempio n. 1
0
def test_particle_filtering_tractography():
    """Exercise ParticleFilteringTracking on a small synthetic volume.

    Checks, in order: PFT yields at least as many streamlines as
    LocalTracking, consecutive points are ``step_size`` apart, points stay
    inside the volume, ``return_all=True`` gives one streamline per seed,
    seeds outside the white matter are handled, invalid arguments raise
    ``ValueError``, and a fixed ``random_seed`` gives reproducible output.
    """
    sphere = get_sphere('repulsion100')
    step_size = 0.2
    affine = np.eye(4)

    # Simple tissue masks: one 2D pattern repeated on three slices and
    # padded with an empty slice on each side.
    wm_slice = np.array([[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0],
                         [0, 1, 1, 1, 0, 0], [0, 1, 1, 1, 0, 0],
                         [0, 0, 0, 0, 0, 0]])
    gm_slice = np.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                         [0, 1, 0, 0, 1, 0], [0, 0, 0, 0, 1, 0],
                         [0, 0, 0, 0, 0, 0]])
    empty_slice = np.zeros(wm_slice.shape)
    simple_wm = np.dstack(
        [empty_slice, wm_slice, wm_slice, wm_slice, empty_slice])
    simple_gm = np.dstack(
        [empty_slice, gm_slice, gm_slice, gm_slice, empty_slice])
    # Whatever is neither white nor gray matter is treated as CSF.
    simple_csf = np.ones(simple_wm.shape) - (simple_wm + simple_gm)

    sc = ActStoppingCriterion.from_pve(simple_wm, simple_gm, simple_csf)
    seeds = seeds_from_mask(simple_wm, affine, density=2)

    # Random pmf in every voxel (one value per sphere vertex)
    pmf_shape = tuple(simple_wm.shape) + (len(sphere.vertices),)
    np.random.seed(0)  # Random number generator initialization
    pmf = np.random.random(pmf_shape)

    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)

    def run_pft(seed_points, **options):
        # Build a PFT generator that shares dg / sc / affine / step_size.
        return ParticleFilteringTracking(dg, sc, seed_points, affine,
                                         step_size, **options)

    # Test that PFT recover equal or more streamlines than localTracking
    local_streamlines = Streamlines(
        LocalTracking(dg, sc, seeds, affine, step_size,
                      max_cross=1, return_all=False))
    pft_streamlines = Streamlines(
        run_pft(seeds,
                max_cross=1,
                return_all=False,
                pft_back_tracking_dist=1,
                pft_front_tracking_dist=0.5))

    npt.assert_(len(pft_streamlines) > 0)
    npt.assert_(len(pft_streamlines) >= len(local_streamlines))

    # Test that all points are equally spaced
    for max_points in (1, 2, 5, 10, 100):
        generator = run_pft(seeds, max_cross=1, return_all=True,
                            maxlen=max_points)
        for streamline in generator:
            for current, following in zip(streamline[:-1], streamline[1:]):
                npt.assert_almost_equal(
                    np.linalg.norm(current - following), step_size)

    # Test that all points are within the image volume
    seeds = seeds_from_mask(np.ones(simple_wm.shape), affine, density=1)
    pft_streamlines = Streamlines(
        run_pft(seeds, max_cross=1, return_all=True))

    for streamline in pft_streamlines:
        voxels = (streamline + 0.5).astype(int)
        npt.assert_(np.all(voxels >= 0))
        npt.assert_(np.all(voxels < simple_wm.shape))

    # Test that the number of streamline return with return_all=True equal the
    # number of seeds places
    npt.assert_(len(pft_streamlines) == len(seeds))

    # Test non WM seed position
    seeds = [[0, 5, 4], [0, 0, 1], [50, 50, 50]]
    pft_streamlines = Streamlines(
        run_pft(seeds, max_cross=1, return_all=True))

    npt.assert_equal(len(pft_streamlines[0]), 3)  # INVALIDPOINT
    npt.assert_equal(len(pft_streamlines[1]), 3)  # ENDPOINT
    npt.assert_equal(len(pft_streamlines[2]), 1)  # OUTSIDEIMAGE

    # Test with wrong StoppingCriterion type
    sc_bin = BinaryStoppingCriterion(simple_wm)
    npt.assert_raises(ValueError, ParticleFilteringTracking,
                      dg, sc_bin, seeds, affine, step_size)

    # Test with invalid back/front tracking distances
    invalid_distances = (
        {"pft_back_tracking_dist": 0, "pft_front_tracking_dist": 0},
        {"pft_back_tracking_dist": -1},
        {"pft_back_tracking_dist": 0, "pft_front_tracking_dist": -2},
    )
    for options in invalid_distances:
        npt.assert_raises(ValueError, run_pft, seeds, **options)

    # Test with invalid affine shape
    npt.assert_raises(ValueError, ParticleFilteringTracking,
                      dg, sc, seeds, np.eye(3), step_size)

    # Test with invalid maxlen
    for bad_maxlen in (0, -1):
        npt.assert_raises(ValueError, run_pft, seeds, maxlen=bad_maxlen)

    # Test with invalid particle count
    for bad_count in (0, -1):
        npt.assert_raises(ValueError, run_pft, seeds,
                          particle_count=bad_count)

    # Test reproducibility: same random_seed must give identical point data
    tracking1 = Streamlines(run_pft(seeds, random_seed=0))._data
    tracking2 = Streamlines(run_pft(seeds, random_seed=0))._data
    npt.assert_equal(tracking1, tracking2)
Esempio n. 2
0
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters (tensor
        eigenvalues/eigenvectors or spherical harmonic coefficients,
        depending on `odf_model`), or nibabel img with model params.
    directions : str
        How tracking directions are determined (case-insensitive).
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        Non-boolean seed masks are binarized as `seed_mask > seed_threshold`
        before seeds are generated. Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. It is also passed on to the tracking step as its
        `random_seed`. Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array).
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "CSD", "DKI", "MSMT"} (case-insensitive; "FWDTI" is
        handled like "DTI"). Any other value is treated as spherical
        harmonic coefficients, with a warning. Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    The streamlines, as returned by the module's `_tracking` helper.

    Raises
    ------
    ValueError
        If `directions` or `tracker` is not one of the supported values, or
        if `tracker` is "pft" and `stop_threshold` is a string other than
        "CMC" or "ACT".
    RuntimeError
        If `tracker` is "pft" and `stop_threshold` is not a string or
        `stop_mask` is not a length-3 iterable.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1924-1938, 2012.
    """
    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()

    # Validate the cheap options up front so that a typo fails with a clear
    # message before any image is loaded, instead of surfacing much later
    # as an unbound local variable.
    if directions not in ("det", "prob"):
        raise ValueError(
            "Unknown `directions` value {!r}. "
            "Possible inputs are: 'det' or 'prob'".format(directions))
    if tracker == "pft":
        if not isinstance(stop_threshold, str):
            # Adjacent literals (no commas) so this is one message string.
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")
        if stop_threshold not in ("CMC", "ACT"):
            raise ValueError(
                "Unknown `stop_threshold` value {!r} for PFT tracking. "
                "Possible inputs are: 'CMC' or 'ACT'".format(stop_threshold))
    elif tracker != "local":
        raise ValueError(
            "Unknown `tracker` value {!r}. "
            "Possible inputs are: 'local' or 'pft'".format(tracker))

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine
    volume_shape = params_img.shape[:3]

    logger.info("Generating Seeds...")
    # NumPy integer scalars are accepted as a seed count / density too.
    if isinstance(n_seeds, (int, np.integer)):
        if seed_mask is None:
            seed_mask = np.ones(volume_shape)
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:  # "prob" -- validated above
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI", "FWDTI"):
        # Tensor-like parameters: 3 eigenvalues followed by 9 eigenvector
        # components along the last axis.
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(volume_shape + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:
        if odf_model not in ("CSD", "MSMT"):
            # The previous condition (`odf_model == "CSD" or "MSMT"`) was
            # always true, so unknown models silently took this path. Keep
            # that behavior for existing callers but make it visible.
            logger.warning(
                "Unrecognized odf_model %r; treating the parameters as "
                "spherical harmonic coefficients.", odf_model)
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(volume_shape)

        if stop_mask.dtype == 'bool':
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # tracker == "pft" -- inputs validated above
        # Load each partial-volume map and resample it against the first
        # volume of the parameter image.
        static = model_params[..., 0]
        resampled_pves = []
        for pve in stop_mask:
            pve_img = nib.load(pve) if isinstance(pve, str) else pve
            resampled_pves.append(
                resample(pve_img.get_fdata(), static, pve_img.affine,
                         params_img.affine).get_fdata())
        pve_wm_data, pve_gm_data, pve_csf_data = resampled_pves

        # Previously this mean was taken over a list that was still empty,
        # which yields NaN; compute it directly from the image header.
        average_voxel_size = np.mean(params_img.header.get_zooms()[:3])

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT"
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)