예제 #1
0
def test_act_stopping_criterion():
    """Check the statuses returned by ``ActStoppingCriterion.check_point``.

    Covers voxel centres, points inside voxels (compared against linearly
    interpolated tissue maps) and points outside the image volume.
    """

    def expected_status(csf_value, gm_value):
        # The exclude (CSF) map is tested before the include (GM) map.
        if csf_value > 0.5:
            return int(StreamlineStatus.INVALIDPOINT)
        if gm_value > 0.5:
            return int(StreamlineStatus.ENDPOINT)
        return int(StreamlineStatus.TRACKPOINT)

    # Random partial-volume maps, normalised so gm + wm + csf == 1 per voxel.
    volume_shape = (4, 4, 4)
    gm, wm, csf = (np.random.random(volume_shape) for _ in range(3))
    normaliser = gm + wm + csf
    gm /= normaliser
    wm /= normaliser
    csf /= normaliser

    act_tc = ActStoppingCriterion(include_map=gm, exclude_map=csf)

    # Voxel centres: compare against the raw map values.
    for voxel in ndindex(wm.shape):
        status = act_tc.check_point(np.array(voxel, dtype='float64'))
        npt.assert_equal(status, expected_status(csf[voxel], gm[voxel]))

    # Points inside voxels: compare against order-1 interpolated values.
    inside_points = [[0, 1.4, 2.2], [0, 2.3, 2.3], [0, 2.2, 1.3],
                     [0, 0.9, 2.2], [0, 2.8, 1.1], [0, 1.1, 3.3],
                     [0, 2.1, 1.9], [0, 3.1, 3.1], [0, 0.1, 0.1],
                     [0, 0.9, 0.5], [0, 0.9, 0.5], [0, 2.9, 0.1]]
    for coords in inside_points:
        point = np.array(coords, dtype='float64')
        status = act_tc.check_point(point)
        column = np.reshape(point, (3, 1))
        gm_value = scipy.ndimage.map_coordinates(gm, column, order=1,
                                                 mode='nearest')
        csf_value = scipy.ndimage.map_coordinates(csf, column, order=1,
                                                  mode='nearest')
        npt.assert_equal(status, expected_status(csf_value, gm_value))

    # Points outside the volume.
    outside_status = int(StreamlineStatus.OUTSIDEIMAGE)
    for coords in ([100, 100, 100], [0, -1, 1], [0, 10, 2], [0, 0.5, -0.51],
                   [0, -0.51, 0.1]):
        status = act_tc.check_point(np.array(coords, dtype='float64'))
        npt.assert_equal(status, outside_status)
예제 #2
0
파일: track.py 프로젝트: CaseyWeiner/m2g
    def prep_tracking(self):
        """Load the DWI, brain mask and tissue maps and build a stopping criterion.

        The criterion is chosen from ``self.track_type``:

        * ``"local"``    -> ``BinaryStoppingCriterion`` on the binarised WM map;
        * ``"particle"`` -> ``CmcStoppingCriterion.from_pve`` built from the WM,
          GM and ventricular-CSF maps.

        An ACT branch (``ActStoppingCriterion``) is kept below, but no
        ``track_type`` currently selects it.

        Returns
        -------
        BinaryStoppingCriterion or CmcStoppingCriterion (or ActStoppingCriterion)
            The resulting stopping criterion, also stored on
            ``self.tiss_classifier``.

        Raises
        ------
        ValueError
            If ``self.track_type`` is neither ``"local"`` nor ``"particle"``.
        """
        if self.track_type == "local":
            tiss_class = "bin"
        elif self.track_type == "particle":
            tiss_class = "cmc"
        else:
            # Without this, ``tiss_class`` would be unbound below.
            raise ValueError(
                "Unsupported track_type %r; expected 'local' or 'particle'."
                % (self.track_type,))

        # ``np.asanyarray(img.dataobj)`` is nibabel's replacement for the
        # deprecated ``img.get_data()`` (same values, on-disk dtype).
        self.dwi_img = nib.load(self.dwi)
        self.data = np.asanyarray(self.dwi_img.dataobj)
        # Load the brain mask and make it a true binary mask.
        self.mask_img = nib.load(self.nodif_B0_mask)
        self.mask = np.asanyarray(self.mask_img.dataobj) > 0
        # Load tissue maps and prepare the stopping criterion.
        self.gm_mask = nib.load(self.gm_in_dwi)
        self.gm_mask_data = np.asanyarray(self.gm_mask.dataobj)
        self.wm_mask = nib.load(self.wm_in_dwi)
        self.wm_mask_data = np.asanyarray(self.wm_mask.dataobj)
        # Boolean WM map derived from the already-loaded data (same file),
        # instead of reading ``self.wm_in_dwi`` from disk a second time.
        self.wm_in_dwi_data = self.wm_mask_data.astype("bool")

        if tiss_class == "act":
            self.vent_csf_in_dwi = nib.load(self.vent_csf_in_dwi)
            self.vent_csf_in_dwi_data = np.asanyarray(
                self.vent_csf_in_dwi.dataobj)
            # Background = voxels with no GM, WM or CSF signal.
            self.background = np.ones(self.gm_mask.shape)
            self.background[(self.gm_mask_data + self.wm_mask_data +
                             self.vent_csf_in_dwi_data) > 0] = 0
            # Copy so zeroing the background does not modify
            # ``self.wm_mask_data`` in place.
            # NOTE(review): ACT reports ENDPOINT where include_map > 0.5;
            # using the WM map (rather than GM) as include map looks
            # suspicious — confirm before enabling this branch.
            self.include_map = self.wm_mask_data.copy()
            self.include_map[self.background > 0] = 0
            self.exclude_map = self.vent_csf_in_dwi_data
            self.tiss_classifier = ActStoppingCriterion(
                self.include_map, self.exclude_map)
        elif tiss_class == "bin":
            self.tiss_classifier = BinaryStoppingCriterion(self.wm_in_dwi_data)
        elif tiss_class == "cmc":
            self.vent_csf_in_dwi = nib.load(self.vent_csf_in_dwi)
            self.vent_csf_in_dwi_data = np.asanyarray(
                self.vent_csf_in_dwi.dataobj)
            # Mean voxel dimension (header pixdim[1:4]) of the WM map.
            voxel_size = np.average(self.wm_mask.header["pixdim"][1:4])
            step_size = 0.2
            self.tiss_classifier = CmcStoppingCriterion.from_pve(
                self.wm_mask_data,
                self.gm_mask_data,
                self.vent_csf_in_dwi_data,
                step_size=step_size,
                average_voxel_size=voxel_size,
            )
        return self.tiss_classifier
예제 #3
0
def test_save_seeds():
    """Check that both LocalTracking and ParticleFilteringTracking yield
    ``(streamline, seed)`` pairs whose seed matches the input seed when
    ``save_seeds=True``.
    """
    tissue = np.array([[2, 1, 1, 2, 1],
                       [2, 2, 1, 1, 2],
                       [1, 1, 1, 1, 1],
                       [1, 1, 1, 2, 2],
                       [0, 1, 1, 1, 2],
                       [0, 1, 1, 0, 2],
                       [1, 0, 1, 1, 1]])[None]

    sphere = HemiSphere.from_sphere(unit_octahedron)
    pmf_lookup = np.array([[0., 0., 0.],
                           [0., 0., 1.]])
    # Voxels with tissue > 0 get pmf [0, 0, 1]; the others get all zeros.
    pmf = pmf_lookup[(tissue > 0).astype("int")]

    # One seed per row (y = 0..6) of the single-slice tissue volume.
    seeds = np.column_stack([np.zeros(7),
                             np.arange(7.),
                             [1., 1, 1, 0, 1, 1, 1]])

    # Set up tracking
    sc = ActStoppingCriterion(tissue == StreamlineStatus.ENDPOINT,
                              tissue == StreamlineStatus.INVALIDPOINT)
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)

    tracker_factories = (
        lambda: LocalTracking(direction_getter=dg,
                              stopping_criterion=sc,
                              seeds=seeds,
                              affine=np.eye(4),
                              step_size=1.,
                              return_all=False,
                              save_seeds=True),
        lambda: ParticleFilteringTracking(direction_getter=dg,
                                          stopping_criterion=sc,
                                          seeds=seeds,
                                          affine=np.eye(4),
                                          step_size=1.,
                                          max_cross=1,
                                          return_all=False,
                                          save_seeds=True),
    )

    # Verify that each tracker returns the seeds, in order.
    for make_tracker in tracker_factories:
        streamlines = iter(make_tracker())
        for expected_seed in seeds[:2]:
            _, seed = next(streamlines)
            npt.assert_equal(seed, expected_seed)
예제 #4
0
def create_act_classifier(fa, folder_name, labels, fa_threshold=0.18):
    """Build an ``ActStoppingCriterion`` from an FA map and label volumes.

    NOTE(review): the original author marked this function as not working;
    see the note on ``fa_threshold`` below.

    Parameters
    ----------
    fa : ndarray
        FA map. Its shape defines the include map, and voxels with
        ``fa > fa_threshold`` are added to the include map.
    folder_name : str
        Directory containing ``rMegaAtlas_cortex_Labels.nii``.
    labels : array-like
        Label volume. Voxels with a label ``> 0`` are not background, and
        voxels with label ``== 1`` form the exclude map.
    fa_threshold : float or None, optional
        FA value above which voxels are added to the include map. The
        default, 0.18, keeps the previous hard-coded behaviour. ``None``
        skips this step.

    Returns
    -------
    ActStoppingCriterion
        Criterion built from the boolean include and exclude maps.
    """
    from dipy.tracking.stopping_criterion import ActStoppingCriterion

    labels = np.asarray(labels)
    # Background = voxels without any positive label.
    background = np.ones(labels.shape)
    background[labels > 0] = 0

    lab_path = os.path.join(folder_name, 'rMegaAtlas_cortex_Labels.nii')
    # np.asanyarray(img.dataobj) replaces the deprecated img.get_data().
    cortex_labels = np.asanyarray(nib.load(lab_path).dataobj)

    include_map = np.zeros(fa.shape, dtype=bool)
    include_map[background > 0] = True
    include_map[cortex_labels > 0] = True
    if fa_threshold is not None:
        # NOTE(review): ACT reports ENDPOINT where include_map > 0.5, so
        # marking high-FA (presumably white-matter) voxels here may stop
        # streamlines as soon as they enter them — likely the cause of the
        # "does not work" remark. Confirm the intent.
        include_map[fa > fa_threshold] = True
    exclude_map = labels == 1

    return ActStoppingCriterion(include_map, exclude_map)
예제 #5
0
but there is no valid direction to follow.
- 'INVALIDPOINT': ``exclude_map`` > 0.5; the streamline reaches a position that
is not anatomically plausible.
"""

img_pve_csf, img_pve_gm, img_pve_wm = read_stanford_pve_maps()

# Load each partial volume estimation (PVE) map once. ``get_data()`` is
# deprecated in nibabel; ``get_fdata()`` returns floating-point data.
pve_csf_data = img_pve_csf.get_fdata()
pve_gm_data = img_pve_gm.get_fdata()
pve_wm_data = img_pve_wm.get_fdata()

# Background = voxels where no tissue (GM, WM or CSF) is present.
background = np.ones(pve_gm_data.shape)
background[(pve_gm_data + pve_wm_data + pve_csf_data) > 0] = 0

# Copy so that marking the background does not alter the image's cached data.
include_map = pve_gm_data.copy()
include_map[background > 0] = 1
exclude_map = pve_csf_data

act_criterion = ActStoppingCriterion(include_map, exclude_map)

fig = plt.figure()
plt.subplot(121)
plt.xticks([])
plt.yticks([])
plt.imshow(include_map[:, :, data.shape[2] // 2].T,
           cmap='gray',
           origin='lower',
           interpolation='nearest')

plt.subplot(122)
plt.xticks([])
plt.yticks([])
plt.imshow(exclude_map[:, :, data.shape[2] // 2].T,
           cmap='gray',
예제 #6
0
def test_particle_filtering_tractography():
    """This tests that the ParticleFilteringTracking produces
    more streamlines connecting the gray matter than LocalTracking.

    It also checks point spacing, that points stay inside the image volume,
    the number of streamlines returned with ``return_all=True``, tracking
    from non-WM seeds, constructor argument validation (``ValueError``) and
    reproducibility with ``random_seed``.

    NOTE: the global NumPy RNG is seeded once below and several trackers are
    then run in sequence. Results may depend on this order, so keep the
    statement order unchanged.
    """
    sphere = get_sphere('repulsion100')
    step_size = 0.2

    # Simple tissue masks
    # Each 5 x 6 pattern is stacked along the last axis as
    # [zeros, mask, mask, mask, zeros], giving a 5 x 6 x 5 volume.
    simple_wm = np.array([[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0],
                          [0, 1, 1, 1, 0, 0], [0, 1, 1, 1, 0, 0],
                          [0, 0, 0, 0, 0, 0]])
    simple_wm = np.dstack([
        np.zeros(simple_wm.shape), simple_wm, simple_wm, simple_wm,
        np.zeros(simple_wm.shape)
    ])
    simple_gm = np.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                          [0, 1, 0, 0, 1, 0], [0, 0, 0, 0, 1, 0],
                          [0, 0, 0, 0, 0, 0]])
    simple_gm = np.dstack([
        np.zeros(simple_gm.shape), simple_gm, simple_gm, simple_gm,
        np.zeros(simple_gm.shape)
    ])
    # CSF is whatever is neither WM nor GM.
    simple_csf = np.ones(simple_wm.shape) - simple_wm - simple_gm

    sc = ActStoppingCriterion.from_pve(simple_wm, simple_gm, simple_csf)
    seeds = seeds_from_mask(simple_wm, np.eye(4), density=2)

    # Random pmf in every voxel
    shape_img = list(simple_wm.shape)
    shape_img.extend([sphere.vertices.shape[0]])
    np.random.seed(0)  # Random number generator initialization
    pmf = np.random.random(shape_img)

    # Test that PFT recover equal or more streamlines than localTracking
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)
    local_streamlines_generator = LocalTracking(dg,
                                                sc,
                                                seeds,
                                                np.eye(4),
                                                step_size,
                                                max_cross=1,
                                                return_all=False)
    local_streamlines = Streamlines(local_streamlines_generator)

    pft_streamlines_generator = ParticleFilteringTracking(
        dg,
        sc,
        seeds,
        np.eye(4),
        step_size,
        max_cross=1,
        return_all=False,
        pft_back_tracking_dist=1,
        pft_front_tracking_dist=0.5)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    npt.assert_(np.array([len(pft_streamlines) > 0]))
    npt.assert_(np.array([len(pft_streamlines) >= len(local_streamlines)]))

    # Test that all points are equally spaced
    # (consecutive points are exactly step_size apart for several maxlen).
    for l in [1, 2, 5, 10, 100]:
        pft_streamlines = ParticleFilteringTracking(dg,
                                                    sc,
                                                    seeds,
                                                    np.eye(4),
                                                    step_size,
                                                    max_cross=1,
                                                    return_all=True,
                                                    maxlen=l)
        for s in pft_streamlines:
            for i in range(len(s) - 1):
                npt.assert_almost_equal(np.linalg.norm(s[i] - s[i + 1]),
                                        step_size)
    # Test that all points are within the image volume
    # (rounded voxel indices lie in [0, shape) on every axis).
    seeds = seeds_from_mask(np.ones(simple_wm.shape), np.eye(4), density=1)
    pft_streamlines_generator = ParticleFilteringTracking(dg,
                                                          sc,
                                                          seeds,
                                                          np.eye(4),
                                                          step_size,
                                                          max_cross=1,
                                                          return_all=True)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    for s in pft_streamlines:
        npt.assert_(np.all((s + 0.5).astype(int) >= 0))
        npt.assert_(np.all((s + 0.5).astype(int) < simple_wm.shape))

    # Test that return_all=True returns exactly one streamline per seed.
    npt.assert_(np.array([len(pft_streamlines) == len(seeds)]))

    # Test non WM seed position
    # Seeds in a CSF voxel, a GM voxel and outside the volume.
    seeds = [[0, 5, 4], [0, 0, 1], [50, 50, 50]]
    pft_streamlines_generator = ParticleFilteringTracking(dg,
                                                          sc,
                                                          seeds,
                                                          np.eye(4),
                                                          step_size,
                                                          max_cross=1,
                                                          return_all=True)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    npt.assert_equal(len(pft_streamlines[0]), 3)  # INVALIDPOINT
    npt.assert_equal(len(pft_streamlines[1]), 3)  # ENDPOINT
    npt.assert_equal(len(pft_streamlines[2]), 1)  # OUTSIDEIMAGE

    # Test with wrong StoppingCriterion type
    # (PFT rejects a BinaryStoppingCriterion with ValueError).
    sc_bin = BinaryStoppingCriterion(simple_wm)
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(dg, sc_bin, seeds,
                                                      np.eye(4), step_size))
    # Test with invalid back/front tracking distances
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg,
                                          sc,
                                          seeds,
                                          np.eye(4),
                                          step_size,
                                          pft_back_tracking_dist=0,
                                          pft_front_tracking_dist=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, pft_back_tracking_dist=-1))
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg,
                                          sc,
                                          seeds,
                                          np.eye(4),
                                          step_size,
                                          pft_back_tracking_dist=0,
                                          pft_front_tracking_dist=-2))

    # Test with invalid affine shape
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg, sc, seeds, np.eye(3), step_size))

    # Test with invalid maxlen
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, maxlen=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, maxlen=-1))

    # Test with invalid particle count
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, particle_count=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, particle_count=-1))

    # Test reproducibility
    # (two runs with the same random_seed must give identical point data).
    tracking1 = Streamlines(
        ParticleFilteringTracking(dg,
                                  sc,
                                  seeds,
                                  np.eye(4),
                                  step_size,
                                  random_seed=0))._data
    tracking2 = Streamlines(
        ParticleFilteringTracking(dg,
                                  sc,
                                  seeds,
                                  np.eye(4),
                                  step_size,
                                  random_seed=0))._data
    npt.assert_equal(tracking1, tracking2)
예제 #7
0
def test_stop_conditions():
    """Check how LocalTracking stops on a small labelled tissue volume,
    with both ``return_all=False`` and ``return_all=True``.

    Tissue codes used below: StreamlineStatus.TRACKPOINT = 1,
    StreamlineStatus.ENDPOINT = 2, StreamlineStatus.INVALIDPOINT = 0.
    """
    tissue = np.array([[2, 1, 1, 2, 1], [2, 2, 1, 1, 2], [1, 1, 1, 1, 1],
                       [1, 1, 1, 2, 2], [0, 1, 1, 1, 2], [0, 1, 1, 0, 2],
                       [1, 0, 1, 1, 1], [2, 1, 2, 0, 0]])[None]

    sphere = HemiSphere.from_sphere(unit_octahedron)
    pmf_lookup = np.array([[0., 0., 0.], [0., 0., 1.]])
    # Voxels with tissue > 0 get pmf [0, 0, 1]; the others get all zeros.
    pmf = pmf_lookup[(tissue > 0).astype("int")]

    # One seed per row (y = 0..7) of the single-slice tissue volume.
    seeds = np.column_stack([np.zeros(8),
                             np.arange(8.),
                             [1., 1, 1, 0, 1, 1, 1, 1]])

    sc = ActStoppingCriterion(tissue == StreamlineStatus.ENDPOINT,
                              tissue == StreamlineStatus.INVALIDPOINT)
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)

    def make_tracker(keep_all):
        return LocalTracking(direction_getter=dg,
                             stopping_criterion=sc,
                             seeds=seeds,
                             affine=np.eye(4),
                             step_size=1.,
                             return_all=keep_all)

    # Valid streamlines only, then all streamlines.
    streamlines_not_all = iter(make_tracker(False))
    streamlines_all = iter(make_tracker(True))

    def check(sl, first, last, length):
        npt.assert_equal(sl[0], first)
        npt.assert_equal(sl[-1], last)
        npt.assert_equal(len(sl), length)

    # Rows 0-2 give the same streamline in both modes:
    #   y=0 stops at z=1 and z=2 (ENDPOINT); y=1 stops at z=1 and z=3
    #   (ENDPOINT); y=2 has no ENDPOINT and stops at the image edge without
    #   including OUTSIDEIMAGE points.
    expectations = [(1, 2, 2), (1, 3, 3), (0, 4, 5)]
    for row, (z_first, z_last, n_points) in enumerate(expectations):
        for stream in (streamlines_not_all, streamlines_all):
            check(next(stream), [0, row, z_first], [0, row, z_last],
                  n_points)

    # A seed on the edge of the image is the first point of its streamline.
    for stream in (streamlines_not_all, streamlines_all):
        npt.assert_equal(next(stream)[0], seeds[3])

    # The remaining seeds (rows 4-7) produce no streamlines: INVALIDPOINT
    # streamlines are rejected with return_all=False.
    npt.assert_equal(len(list(streamlines_not_all)), 0)

    # With return_all=True they are kept.
    check(next(streamlines_all), [0, 4, 1], [0, 4, 3], 3)
    check(next(streamlines_all), [0, 5, 1], [0, 5, 2], 2)
    # Rows 6 and 7 yield only the seed point: no valid initial direction
    # (row 6) / no valid neighbouring voxel (row 7).
    for row in (6, 7):
        check(next(streamlines_all), seeds[row], seeds[row], 1)
예제 #8
0
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters (e.g. CSD
        spherical harmonic coefficients or tensor parameters), or a nibabel
        img with the model params.
    directions : str
        How tracking directions are determined (case-insensitive).
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        A non-boolean seed_mask is binarized as ``seed_mask > seed_threshold``.
        Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. Also passed to ``_tracking`` as ``random_seed``.
        Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array). A boolean stop_mask
        always uses a threshold of 0.5.
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "DKI", "FWDTI", "CSD", "MSMT"} (case-insensitive).
        "DTI", "DKI" and "FWDTI" parameters are converted to a tensor ODF;
        any other value is treated as spherical harmonic coefficients.
        Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    list of streamlines

    Raises
    ------
    ValueError
        If `directions` or `tracker` is not a supported value, or if
        `stop_threshold` is a string other than "CMC"/"ACT" with "pft".
    RuntimeError
        If tracker is "pft" and `stop_threshold` is not a string, or
        `stop_mask` is not a length-3 iterable.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: Improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1924-1938, 2012.
    """
    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()

    # Validate the options before loading any (possibly large) image, so bad
    # input fails fast with a clear error instead of an unbound variable.
    if directions not in ("det", "prob"):
        raise ValueError(
            f"Unknown directions {directions!r}; expected 'det' or 'prob'.")
    if tracker not in ("local", "pft"):
        raise ValueError(
            f"Unknown tracker {tracker!r}; expected 'local' or 'pft'.")
    if tracker == "pft":
        if not isinstance(stop_threshold, str):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")
        if stop_threshold not in ("CMC", "ACT"):
            raise ValueError(
                f"Unknown stop_threshold {stop_threshold!r} for PFT "
                "tracking; possible inputs are: 'CMC' or 'ACT'")

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    if isinstance(n_seeds, int):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:  # "prob", validated above
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI", "FWDTI"):
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:
        # Any other model (e.g. "CSD", "MSMT") is interpreted as spherical
        # harmonic coefficients.
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(params_img.shape[:3])

        if stop_mask.dtype == 'bool':
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # tracker == "pft", validated above
        pve_wm_img, pve_gm_img, pve_csf_img = [
            nib.load(pve) if isinstance(pve, str) else pve
            for pve in stop_mask]
        # Resample each PVE map onto the grid of the model parameters.
        static = model_params[..., 0]
        pve_wm_data = resample(pve_wm_img.get_fdata(), static,
                               pve_wm_img.affine,
                               params_img.affine).get_fdata()
        pve_gm_data = resample(pve_gm_img.get_fdata(), static,
                               pve_gm_img.affine,
                               params_img.affine).get_fdata()
        pve_csf_data = resample(pve_csf_img.get_fdata(), static,
                                pve_csf_img.affine,
                                params_img.affine).get_fdata()

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            # Mean voxel size of the grid the PVE maps were resampled onto.
            average_voxel_size = np.mean(params_img.header.get_zooms()[:3])
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT", validated above
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)
예제 #9
0
파일: track.py 프로젝트: dPys/PyNets
def prep_tissues(t1_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 B0_mask,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Every input image is binarized with an ``img > 0.01`` threshold before
    it is used.

    Parameters
    ----------
    t1_mask : Nifti1Image
        T1w mask img.
    gm_in_dwi : Nifti1Image
        Grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : Nifti1Image
        Ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : Nifti1Image
        White-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method: ``"act"``, ``"wm"``, ``"cmc"`` or
        ``"wb"``.
    B0_mask : Nifti1Image
        B0 brain mask img. Only the ``"wm"`` and ``"wb"`` methods use it, as
        one of the masks that are intersected.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If ``tiss_class`` is not one of the supported methods.

    References
    ----------
    .. [1] Zhang, Y., Brady, M. and Smith, S. Segmentation of Brain MR Images
      Through a Hidden Markov Random Field Model and the
      Expectation-Maximization Algorithm IEEE Transactions on Medical Imaging,
      20(1): 45-56, 2001
    .. [2] Avants, B. B., Tustison, N. J., Wu, J., Cook, P. A. and Gee, J. C.
      An open source multivariate framework for n-tissue segmentation with
      evaluation on public data. Neuroinformatics, 9(4): 381-400, 2011.
    """
    import gc
    from dipy.tracking.stopping_criterion import (
        ActStoppingCriterion,
        CmcStoppingCriterion,
        BinaryStoppingCriterion,
    )
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    # Binarize the B0 and T1w masks.
    B0_mask_img = math_img("img > 0.01", img=B0_mask)
    mask_img = math_img("img > 0.01", img=t1_mask)

    # Binarize the tissue maps and pull them into float32 arrays.
    wm_mask_img = math_img("img > 0.01", img=wm_in_dwi)
    gm_mask_img = math_img("img > 0.01", img=gm_in_dwi)
    vent_csf_in_dwi_img = math_img("img > 0.01", img=vent_csf_in_dwi)
    gm_data = np.asarray(gm_mask_img.dataobj, dtype=np.float32)
    wm_data = np.asarray(wm_mask_img.dataobj, dtype=np.float32)
    vent_csf_in_dwi_data = np.asarray(vent_csf_in_dwi_img.dataobj,
                                      dtype=np.float32)

    if tiss_class == "act":
        # Voxels covered by none of the tissue maps are added to the include
        # (GM) map, so reaching them ends tracking the same way GM does.
        background = np.ones(mask_img.shape)
        background[(gm_data + wm_data + vent_csf_in_dwi_data) > 0] = 0
        gm_data[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_data, vent_csf_in_dwi_data)
        del background
    elif tiss_class in ("wm", "wb"):
        # Trackable voxels lie inside every mask and outside ventricular CSF.
        csf_free_img = nib.Nifti1Image(
            np.invert(vent_csf_in_dwi_data.astype('bool')).astype('int'),
            affine=mask_img.affine)
        if tiss_class == "wm":
            masks = [mask_img, wm_mask_img, B0_mask_img, csf_free_img]
        else:
            masks = [mask_img, B0_mask_img, csf_free_img]
        # threshold=1 keeps only voxels present in *all* masks.
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(
                intersect_masks(masks, threshold=1, connected=False).dataobj))
        del csf_free_img, masks
    elif tiss_class == "cmc":
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_data,
            gm_data,
            vent_csf_in_dwi_data,
            step_size=cmc_step_size,
            average_voxel_size=np.average(mask_img.header["pixdim"][1:4]),
        )
    else:
        raise ValueError("Tissue classifier cannot be none.")

    # Drop cached image data so the arrays can be garbage collected.
    B0_mask_img.uncache()
    mask_img.uncache()
    wm_mask_img.uncache()
    gm_mask_img.uncache()
    vent_csf_in_dwi_img.uncache()
    del gm_data, wm_data, vent_csf_in_dwi_data
    gc.collect()

    return tiss_classifier
예제 #10
0
def main():
    """Run particle filtering tractography (PFT) from the command line.

    Validates the parsed arguments, then builds the tracking inputs:

    - a deterministic or probabilistic direction getter from the SH file;
    - a CMC stopping criterion (default) or an ACT one (``--act``) from the
      include/exclude maps;
    - seeds drawn from the seed mask.

    It then tracks and streams the length-filtered (and optionally
    compressed) streamlines to ``args.output_file``. Invalid arguments are
    reported through ``parser.error``.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file,
                                 args.map_include_file,
                                 args.map_exclude_file])
    assert_outputs_exist(parser, args, args.output_file)

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'
                     .format(args.min_length))
    if args.max_length < args.min_length:
        parser.error('maxL must be > than minL, (minL={}mm, maxL={}mm).'
                     .format(args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'
                .format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0], atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(
        fodf_sh_img.get_fdata(dtype=np.double),
        max_angle=theta,
        sphere=tracking_sphere,
        basis_type=sh_basis,
        pmf_threshold=args.sf_threshold,
        relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    # ``SpatialImage.get_header()`` no longer exists in current nibabel
    # releases; the ``header`` attribute is the supported accessor.
    map_voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcStoppingCriterion(
            map_include_img.get_fdata(),
            map_exclude_img.get_fdata(),
            step_size=args.step_size,
            average_voxel_size=map_voxel_size)
    else:
        tissue_classifier = ActStoppingCriterion(map_include_img.get_fdata(),
                                                 map_exclude_img.get_fdata())

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Seeding and tracking run in voxel space (identity affine), so the step
    # size and the length bounds are converted from mm to voxels using the
    # SH image voxel size.
    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        # Materialize the survivors explicitly. Unpacking ``zip(*...)``
        # raises ValueError when no streamline passes the length filter.
        kept = [(s, p) for s, p in pft_streamlines
                if scaled_min_length <= length(s) <= scaled_max_length]
        filtered_streamlines = [s for s, _ in kept]
        seeds = [p for _, p in kept]
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in pft_streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    # NOTE: ``filtered_streamlines`` may be a one-shot generator, so the
    # lambda below only yields data on its first call.
    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
예제 #11
0
파일: track.py 프로젝트: devhliu/PyNets
def prep_tissues(B0_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    B0_mask : str
        File path to B0 brain mask. Its uses depend on the method:

        - ``'act'``: its shape sets the background grid;
        - ``'cmc'``: its ``pixdim`` gives the average voxel size;
        - ``'wb'``: its non-zero voxels form the trackable region.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method: ``'act'``, ``'bin'``, ``'cmc'`` or
        ``'wb'``.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If ``tiss_class`` is not one of the supported methods.
    """
    from dipy.tracking.stopping_criterion import (ActStoppingCriterion,
                                                  CmcStoppingCriterion,
                                                  BinaryStoppingCriterion)

    # Load the B0 brain mask. It is only binarized (astype('bool')) in the
    # 'wb' branch.
    mask_img = nib.load(B0_mask)
    # Load tissue maps and prepare tissue classifier
    gm_mask_data = nib.load(gm_in_dwi).get_fdata()
    wm_mask_data = nib.load(wm_in_dwi).get_fdata()
    vent_csf_in_dwi_data = nib.load(vent_csf_in_dwi).get_fdata()
    if tiss_class == 'act':
        # Voxels covered by none of the tissue maps are added to the include
        # (GM) map, so reaching them ends tracking the same way GM does.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data +
                    vent_csf_in_dwi_data) > 0] = 0
        # Alias, not a copy: gm_mask_data is modified in place, which is
        # harmless because it is discarded at the end of the function.
        include_map = gm_mask_data
        include_map[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(include_map,
                                               vent_csf_in_dwi_data)
        del background
        del include_map
    elif tiss_class == 'bin':
        # Any non-zero WM voxel is trackable.
        tiss_classifier = BinaryStoppingCriterion(wm_mask_data.astype('bool'))
    elif tiss_class == 'cmc':
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_mask_data,
            gm_mask_data,
            vent_csf_in_dwi_data,
            step_size=cmc_step_size,
            average_voxel_size=voxel_size)
    elif tiss_class == 'wb':
        # Any non-zero voxel of the B0 mask is trackable.
        tiss_classifier = BinaryStoppingCriterion(
            mask_img.get_fdata().astype('bool'))
    else:
        raise ValueError('Tissue Classifier cannot be none.')

    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()

    return tiss_classifier
예제 #12
0
def prep_tissues(t1_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class, cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    t1_mask : str
        File path to a T1w mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method: ``'act'``, ``'bin'``, ``'cmc'`` or
        ``'wb'``.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If ``tiss_class`` is not one of the supported methods.

    References
    ----------
    .. [1] Zhang, Y., Brady, M. and Smith, S. Segmentation of Brain MR Images
      Through a Hidden Markov Random Field Model and the Expectation-Maximization
      Algorithm IEEE Transactions on Medical Imaging, 20(1): 45-56, 2001
    .. [2] Avants, B. B., Tustison, N. J., Wu, J., Cook, P. A. and Gee, J. C.
      An open source multivariate framework for n-tissue segmentation with
      evaluation on public data. Neuroinformatics, 9(4): 381-400, 2011.

    """
    from dipy.tracking.stopping_criterion import (ActStoppingCriterion,
                                                  CmcStoppingCriterion,
                                                  BinaryStoppingCriterion)
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    # Loads mask
    mask_img = nib.load(t1_mask)
    # Load tissue maps and prepare tissue classifier
    wm_img = nib.load(wm_in_dwi)
    gm_img = nib.load(gm_in_dwi)
    gm_mask_data = np.asarray(gm_img.dataobj)
    wm_mask_data = np.asarray(wm_img.dataobj)
    vent_csf_in_dwi_data = np.asarray(nib.load(vent_csf_in_dwi).dataobj)
    if tiss_class == 'act':
        # Voxels covered by none of the tissue maps are added to the include
        # (GM) map, so reaching them ends tracking the same way GM does.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data + vent_csf_in_dwi_data) > 0] = 0
        gm_mask_data[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_mask_data, vent_csf_in_dwi_data)
        del background
    elif tiss_class == 'bin':
        # Trackable voxels lie inside both the T1w mask and the WM map
        # (threshold=1 keeps only voxels present in every mask).
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(
                intersect_masks([math_img('img > 0.0', img=mask_img),
                                 math_img('img > 0.0', img=wm_img)],
                                threshold=1, connected=False).dataobj))
    elif tiss_class == 'cmc':
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_mask_data, gm_mask_data, vent_csf_in_dwi_data,
            step_size=cmc_step_size, average_voxel_size=voxel_size)
    elif tiss_class == 'wb':
        # Any non-zero voxel of the T1w mask is trackable.
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(mask_img.dataobj).astype('bool'))
    else:
        raise ValueError('Tissue classifier cannot be none.')

    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()
    gm_img.uncache()
    wm_img.uncache()

    return tiss_classifier
예제 #13
0
 def _act_sc(self):
     """Build an ACT stopping criterion from ``self.tissue_labels``.

     Stores a dipy ``ActStoppingCriterion`` in ``self.classifier``. Its
     include map is ``tissue_labels < 3``. Its exclude map marks voxels
     whose label is 3 or 0.
     """
     from dipy.tracking.stopping_criterion import ActStoppingCriterion
     include_map = self.tissue_labels < 3
     # A voxel is excluded if its label is 3 *or* 0. Joining the two
     # equality tests with ``&`` can never be true for a single voxel, which
     # would leave the exclude map empty.
     # NOTE(review): label 0 now belongs to both maps. Confirm the label
     # scheme and which map label 0 is meant to be in.
     exclude_map = (self.tissue_labels == 3) | (self.tissue_labels == 0)
     act_classifier = ActStoppingCriterion(include_map, exclude_map)
     self.classifier = act_classifier
예제 #14
0
def prep_tissues(t1_mask,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 wm_in_dwi,
                 tiss_class,
                 cmc_step_size=0.2):
    """
    Estimate a tissue classifier for tractography.

    Parameters
    ----------
    t1_mask : str
        File path to a T1w mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method: ``'act'``, ``'bin'``, ``'cmc'`` or
        ``'wb'``.
    cmc_step_size : float
        Step size from CMC tissue classification method.

    Returns
    -------
    tiss_classifier : obj
        Tissue classifier object.

    Raises
    ------
    ValueError
        If ``tiss_class`` is not one of the supported methods.
    """
    from dipy.tracking.stopping_criterion import (ActStoppingCriterion,
                                                  CmcStoppingCriterion,
                                                  BinaryStoppingCriterion)
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img

    # Loads mask
    mask_img = nib.load(t1_mask)
    # Load tissue maps and prepare tissue classifier
    wm_img = nib.load(wm_in_dwi)
    gm_img = nib.load(gm_in_dwi)
    gm_mask_data = np.asarray(gm_img.dataobj)
    wm_mask_data = np.asarray(wm_img.dataobj)
    vent_csf_in_dwi_data = np.asarray(nib.load(vent_csf_in_dwi).dataobj)
    if tiss_class == 'act':
        # Voxels covered by none of the tissue maps are added to the include
        # (GM) map, so reaching them ends tracking the same way GM does.
        background = np.ones(mask_img.shape)
        background[(gm_mask_data + wm_mask_data +
                    vent_csf_in_dwi_data) > 0] = 0
        gm_mask_data[background > 0] = 1
        tiss_classifier = ActStoppingCriterion(gm_mask_data,
                                               vent_csf_in_dwi_data)
        del background
    elif tiss_class == 'bin':
        # Trackable voxels lie inside both the T1w mask and the WM map
        # (threshold=1 keeps only voxels present in every mask).
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(
                intersect_masks([
                    math_img('img > 0.0', img=mask_img),
                    math_img('img > 0.0', img=wm_img)
                ],
                                threshold=1,
                                connected=False).dataobj))
    elif tiss_class == 'cmc':
        voxel_size = np.average(mask_img.header['pixdim'][1:4])
        tiss_classifier = CmcStoppingCriterion.from_pve(
            wm_mask_data,
            gm_mask_data,
            vent_csf_in_dwi_data,
            step_size=cmc_step_size,
            average_voxel_size=voxel_size)
    elif tiss_class == 'wb':
        # Any non-zero voxel of the T1w mask is trackable.
        tiss_classifier = BinaryStoppingCriterion(
            np.asarray(mask_img.dataobj).astype('bool'))
    else:
        raise ValueError('Tissue Classifier cannot be none.')

    del gm_mask_data, wm_mask_data, vent_csf_in_dwi_data
    mask_img.uncache()
    gm_img.uncache()
    wm_img.uncache()

    return tiss_classifier