Ejemplo n.º 1
0
def test_save_seeds():
    """Check that trackers yield ``(streamline, seed)`` pairs on request.

    With ``save_seeds=True`` both ``LocalTracking`` and
    ``ParticleFilteringTracking`` are iterated as pairs; the seed of the
    first two pairs must equal the first two seeds that were supplied.
    """
    # 2D label pattern; ``[None]`` prepends a singleton axis to make it 3D.
    tissue = np.array([[2, 1, 1, 2, 1],
                       [2, 2, 1, 1, 2],
                       [1, 1, 1, 1, 1],
                       [1, 1, 1, 2, 2],
                       [0, 1, 1, 1, 2],
                       [0, 1, 1, 0, 2],
                       [1, 0, 1, 1, 1]])[None]

    sphere = HemiSphere.from_sphere(unit_octahedron)

    # Row 0 (label == 0) is an all-zero pmf; row 1 (any non-zero label) puts
    # all the weight on the third entry.
    pmf_table = np.array([[0., 0., 0.],
                          [0., 0., 1.]])
    pmf = pmf_table[(tissue > 0).astype("int")]

    # Seven seeds with x = 0, y = 0..6 and z = 1 everywhere except the 4th.
    seed_x = np.zeros(7)
    seed_y = np.arange(7.)
    seed_z = np.array([1., 1, 1, 0, 1, 1, 1])
    seeds = np.column_stack([seed_x, seed_y, seed_z])

    # Set up tracking
    tc = ActTissueClassifier(tissue == TissueTypes.ENDPOINT,
                             tissue == TissueTypes.INVALIDPOINT)
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)

    def shared_kwargs():
        # Fresh keyword set (including a fresh identity affine) per tracker;
        # valid streamlines only, seeds returned alongside them.
        return dict(direction_getter=dg,
                    tissue_classifier=tc,
                    seeds=seeds,
                    affine=np.eye(4),
                    step_size=1.,
                    return_all=False,
                    save_seeds=True)

    def check_first_two_seeds(tracker):
        # Pull exactly two results and compare their seeds, in order.
        results = iter(tracker)
        for expected in (seeds[0], seeds[1]):
            _, seed = next(results)
            npt.assert_equal(seed, expected)

    # Verify that seeds are returned by the LocalTracker
    check_first_two_seeds(LocalTracking(**shared_kwargs()))

    # Verify that seeds are returned by the PFTTracker also
    check_first_two_seeds(
        ParticleFilteringTracking(max_cross=1, **shared_kwargs()))
Ejemplo n.º 2
0
                                              img_pve_csf.get_data(),
                                              step_size=step_size,
                                              average_voxel_size=voxel_size)

# Seeds are placed in the voxels labelled 2 (the corpus callosum, per the
# original note) whose white matter partial volume estimate is at least 0.5;
# voxels below that threshold are cleared from the seed mask.
seed_mask = labels == 2
seed_mask[img_pve_wm.get_data() < 0.5] = 0
seeds = utils.seeds_from_mask(seed_mask, density=2, affine=affine)

# Particle Filtering Tractography: build the streamline generator from the
# direction getter, the CMC classifier and the seeds defined above.
pft_streamline_generator = ParticleFilteringTracking(dg,
                                                     cmc_classifier,
                                                     seeds,
                                                     affine,
                                                     max_cross=1,
                                                     step_size=step_size,
                                                     maxlen=1000,
                                                     pft_back_tracking_dist=2,
                                                     pft_front_tracking_dist=1,
                                                     particle_count=15,
                                                     return_all=False)

# Consume the generator into a Streamlines container (a plain
# ``list(pft_streamline_generator)`` was the earlier alternative), then save.
streamlines = Streamlines(pft_streamline_generator)
save_trk("pft_streamline.trk", streamlines, affine, shape)

# Render the streamlines and write the figure to disk.
renderer.clear()
renderer.add(actor.line(streamlines, cmap.line_colors(streamlines)))
window.record(renderer, out_path='pft_streamlines.png', size=(600, 600))
"""
.. figure:: pft_streamlines.png
Ejemplo n.º 3
0
    def run(self,
            pam_files,
            wm_files,
            gm_files,
            csf_files,
            seeding_files,
            step_size=0.2,
            seed_density=1,
            pmf_threshold=0.1,
            max_angle=20.,
            pft_back=2,
            pft_front=1,
            pft_count=15,
            out_dir='',
            out_tractogram='tractogram.trk',
            save_seeds=False):
        """Workflow for Particle Filtering Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. This path may contain
            wildcards to use multiple files at once.
        wm_files : string
            Path to white matter partial volume estimate for tracking (CMC).
        gm_files : string
            Path to grey matter partial volume estimate for tracking (CMC).
        csf_files : string
            Path to cerebrospinal fluid partial volume estimate for tracking
            (CMC).
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size used for tracking (default 0.2mm).
        seed_density : int, optional
            Number of seeds per dimension inside voxel (default 1).
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions (default 0.1).
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90],
            default 20).
        pft_back : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography (default 2mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_front : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance (default 1mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_count : int, optional
            Number of particles to use in the particle filter (default 15).
        out_dir : string, optional
            Output directory (default input file directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved (default 'tractogram.trk')
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key

        References
        ----------
        Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M. Towards
        quantitative connectivity analysis: reducing tractography biases.
        NeuroImage, 98, 266-278, 2014.

        """
        # NOTE(review): out_dir and out_tractogram are not referenced in this
        # body; presumably they are consumed by get_io_iterator() -- confirm.
        io_it = self.get_io_iterator()

        for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract \
                in io_it:

            logging.info(
                'Particle Filtering tracking on {0}'.format(pams_path))

            pam = load_peaks(pams_path, verbose=False)

            # The white matter image also supplies the affine and voxel size.
            wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
            gm, _ = load_nifti(gm_path)
            csf, _ = load_nifti(csf_path)
            avs = sum(voxel_size) / len(voxel_size)  # average_voxel_size
            classifier = CmcTissueClassifier.from_pve(wm,
                                                      gm,
                                                      csf,
                                                      step_size=step_size,
                                                      average_voxel_size=avs)
            logging.info('classifier done')
            seed_mask, _ = load_nifti(seeding_path)
            seeds = utils.seeds_from_mask(
                seed_mask,
                density=[seed_density, seed_density, seed_density],
                affine=affine)
            logging.info('seeds done')
            dg = ProbabilisticDirectionGetter

            direction_getter = dg.from_shcoeff(pam.shm_coeff,
                                               max_angle=max_angle,
                                               sphere=pam.sphere,
                                               pmf_threshold=pmf_threshold)

            tracking_result = ParticleFilteringTracking(
                direction_getter,
                classifier,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=pft_back,
                pft_front_tracking_dist=pft_front,
                pft_max_trial=20,
                particle_count=pft_count,
                save_seeds=save_seeds)

            logging.info('ParticleFilteringTracking initiated')

            if save_seeds:
                # With save_seeds the tracker yields (streamline, seed) pairs.
                # Materialise them first: unpacking ``zip(*pairs)`` directly
                # raises ValueError when tracking produced no streamline.
                pairs = list(tracking_result)
                if pairs:
                    streamlines, kept_seeds = zip(*pairs)
                else:
                    streamlines, kept_seeds = (), ()
                tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
                tractogram.data_per_streamline['seeds'] = kept_seeds
            else:
                tractogram = Tractogram(tracking_result,
                                        affine_to_rasmm=np.eye(4))

            save(tractogram, out_tract)

            logging.info('Saved {0}'.format(out_tract))
Ejemplo n.º 4
0
def main():
    """Run particle filtering tractography from command-line arguments.

    Validates the parsed arguments, builds a direction getter from the SH
    file, a tissue classifier from the include/exclude maps and random seeds
    from the seed mask, then tracks in voxel space (identity affine), filters
    streamlines by length, optionally compresses them and streams the result
    to ``args.output_file``.

    Argument errors are reported through ``parser.error``; a
    ``RuntimeError`` is raised if the tracking sphere is not unit normed.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [
        args.sh_file, args.seed_file, args.map_include_file,
        args.map_exclude_file
    ])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'.
                format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    # Compare against None rather than relying on truthiness, otherwise an
    # explicit 0 is falsy and slips through the "> 0" validation.
    # Assumes the options are None when not given on the command line.
    if args.npv is not None and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt is not None and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.sh_file)
    # Every spatial zoom must match the first one. Comparing only the mean
    # to the first zoom would accept anisotropic voxels such as (2, 1, 3).
    zooms = fodf_sh_img.header.get_zooms()[:3]
    if not np.allclose(zooms, zooms[0], atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_data().astype(np.double),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    # ``header`` attribute instead of the deprecated ``get_header()``,
    # matching how the SH image header is accessed above.
    map_voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if args.act:
        tissue_classifier = ActTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data())
    else:
        tissue_classifier = CmcTissueClassifier(
            map_include_img.get_data(),
            map_exclude_img.get_data(),
            step_size=args.step_size,
            average_voxel_size=map_voxel_size)

    # --npv takes precedence over --nt; default is one seed per voxel.
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Tracking runs with an identity affine, so mm quantities are divided by
    # the voxel size of the SH image.
    voxel_size = zooms[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_data(),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed)

    # The real length limits are enforced here, on the generated streamlines.
    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size
    filtered_streamlines = (
        s for s in pft_streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Ejemplo n.º 5
0
def test_particle_filtering_tractography():
    """Exercise ParticleFilteringTracking on a small synthetic volume.

    Checks that PFT returns at least as many streamlines as LocalTracking,
    that consecutive points are ``step_size`` apart, that points stay inside
    the volume, that ``return_all=True`` gives one streamline per seed, the
    streamline lengths for three non white matter seeds, that invalid
    constructor arguments raise ``ValueError`` and that a fixed
    ``random_seed`` reproduces the same result.
    """
    sphere = get_sphere('repulsion100')
    step_size = 0.2

    def stack_slices(plane):
        # Five slices along the third axis: zeros, three copies, zeros.
        blank = np.zeros(plane.shape)
        return np.dstack([blank, plane, plane, plane, blank])

    # Simple tissue masks
    simple_wm = stack_slices(np.array([[0, 0, 0, 0, 0, 0],
                                       [0, 0, 1, 0, 0, 0],
                                       [0, 1, 1, 1, 0, 0],
                                       [0, 1, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0]]))
    simple_gm = stack_slices(np.array([[1, 1, 0, 0, 0, 0],
                                       [1, 1, 0, 0, 0, 0],
                                       [0, 1, 0, 0, 1, 0],
                                       [0, 0, 0, 0, 1, 0],
                                       [0, 0, 0, 0, 0, 0]]))
    simple_csf = np.ones(simple_wm.shape) - simple_wm - simple_gm
    tc = ActTissueClassifier.from_pve(simple_wm, simple_gm, simple_csf)
    seeds = seeds_from_mask(simple_wm, density=2)

    # Random pmf in every voxel: one value per sphere vertex.
    pmf_shape = list(simple_wm.shape) + [sphere.vertices.shape[0]]
    np.random.seed(0)  # Random number generator initialization
    pmf = np.random.random(pmf_shape)

    # PFT must recover at least as many streamlines as LocalTracking.
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)
    local_streamlines = Streamlines(
        LocalTracking(dg, tc, seeds, np.eye(4), step_size,
                      max_cross=1, return_all=False))

    pft_streamlines = Streamlines(
        ParticleFilteringTracking(dg, tc, seeds, np.eye(4), step_size,
                                  max_cross=1,
                                  return_all=False,
                                  pft_back_tracking_dist=1,
                                  pft_front_tracking_dist=0.5))

    npt.assert_(len(pft_streamlines) > 0)
    npt.assert_(len(pft_streamlines) >= len(local_streamlines))

    # All consecutive points must be exactly one step apart.
    for max_len in (1, 2, 5, 10, 100):
        tracker = ParticleFilteringTracking(dg, tc, seeds, np.eye(4),
                                            step_size,
                                            max_cross=1,
                                            return_all=True,
                                            maxlen=max_len)
        for sl in tracker:
            for point, following in zip(sl[:-1], sl[1:]):
                npt.assert_almost_equal(np.linalg.norm(point - following),
                                        step_size)

    # All points must fall within the image volume.
    seeds = seeds_from_mask(np.ones(simple_wm.shape), density=1)
    pft_streamlines = Streamlines(
        ParticleFilteringTracking(dg, tc, seeds, np.eye(4), step_size,
                                  max_cross=1, return_all=True))

    for sl in pft_streamlines:
        voxel_idx = (sl + 0.5).astype(int)
        npt.assert_(np.all(voxel_idx >= 0))
        npt.assert_(np.all(voxel_idx < simple_wm.shape))

    # With return_all=True there is one streamline per seed placed.
    npt.assert_(len(pft_streamlines) == len(seeds))

    # Seeds located outside the white matter.
    seeds = [[0, 5, 4], [0, 0, 1], [50, 50, 50]]
    pft_streamlines = Streamlines(
        ParticleFilteringTracking(dg, tc, seeds, np.eye(4), step_size,
                                  max_cross=1, return_all=True))

    expected_lengths = (3,  # INVALIDPOINT
                        3,  # ENDPOINT
                        1)  # OUTSIDEIMAGE
    for idx, n_points in enumerate(expected_lengths):
        npt.assert_equal(len(pft_streamlines[idx]), n_points)

    # Wrong tissue classifier type
    tc_bin = BinaryTissueClassifier(simple_wm)
    npt.assert_raises(ValueError, ParticleFilteringTracking,
                      dg, tc_bin, seeds, np.eye(4), step_size)

    # Invalid back/front tracking distances
    for bad_dist in (dict(pft_back_tracking_dist=0,
                          pft_front_tracking_dist=0),
                     dict(pft_back_tracking_dist=-1),
                     dict(pft_back_tracking_dist=0,
                          pft_front_tracking_dist=-2)):
        npt.assert_raises(ValueError, ParticleFilteringTracking,
                          dg, tc, seeds, np.eye(4), step_size, **bad_dist)

    # Invalid affine shape
    npt.assert_raises(ValueError, ParticleFilteringTracking,
                      dg, tc, seeds, np.eye(3), step_size)

    # Invalid maxlen
    for bad_maxlen in (0, -1):
        npt.assert_raises(ValueError, ParticleFilteringTracking,
                          dg, tc, seeds, np.eye(4), step_size,
                          maxlen=bad_maxlen)

    # Invalid particle count
    for bad_count in (0, -1):
        npt.assert_raises(ValueError, ParticleFilteringTracking,
                          dg, tc, seeds, np.eye(4), step_size,
                          particle_count=bad_count)

    # Reproducibility: two runs with the same random seed must match.
    def seeded_run():
        return Streamlines(
            ParticleFilteringTracking(dg, tc, seeds, np.eye(4), step_size,
                                      random_seed=0)).data

    tracking_1 = seeded_run()
    tracking_2 = seeded_run()
    npt.assert_equal(tracking_1, tracking_2)
Ejemplo n.º 6
0
    def run(self,
            pam_files,
            wm_files,
            gm_files,
            csf_files,
            seeding_files,
            step_size=0.2,
            back_tracking_dist=2,
            front_tracking_dist=1,
            max_trial=20,
            particle_count=15,
            seed_density=1,
            pmf_threshold=0.1,
            max_angle=30.,
            out_dir='',
            out_tractogram='tractogram.trk'):
        """Workflow for Particle Filtering Tracking.

        Takes a saved peaks and metrics (PAM) file as input and writes one
        tractogram per set of inputs.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. Wildcards may be used to
            process several files at once.
        wm_files : string
            Path of the white matter image used by the stopping criteria.
        gm_files : string
            Path of the grey matter image used by the stopping criteria.
        csf_files : string
            Path of the cerebrospinal fluid image used by the stopping
            criteria.
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size used for tracking.
        back_tracking_dist : float, optional
            Distance in mm to back track before starting the particle
            filtering tractography (2 mm by default). The total particle
            filtering tractography distance is back_tracking_dist +
            front_tracking_dist.
        front_tracking_dist : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance (1 mm by default). The total particle
            filtering tractography distance is back_tracking_dist +
            front_tracking_dist.
        max_trial : int, optional
            Maximum number of trials for the particle filtering tractography
            (prevents infinite loops, default=20).
        particle_count : int, optional
            Number of particles to use in the particle filter. (default 15)
        seed_density : int, optional
            Number of seeds per dimension inside voxel (default 1).
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions. (default 0.1)
        max_angle : float, optional
            Maximum angle between tract segments. This angle can be more
            generous (larger) than values typically used with probabilistic
            direction getters. The angle range is (0, 90)
        out_dir : string, optional
            Output directory (default input file directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved (default 'tractogram.trk')

        References
        ----------
        Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M.
               Towards quantitative connectivity analysis: reducing
               tractography biases. NeuroImage, 98, 266-278, 2014.

        """
        for (pams_path, wm_path, gm_path, csf_path, seeding_path,
             out_tract) in self.get_io_iterator():

            logging.info(
                'Particle Filtering tracking on {0}'.format(pams_path))

            peaks = load_peaks(pams_path, verbose=False)

            # The white matter image also provides the affine and voxel size.
            wm_pve, affine, voxel_size = load_nifti(wm_path,
                                                    return_voxsize=True)
            gm_pve, _ = load_nifti(gm_path)
            csf_pve, _ = load_nifti(csf_path)
            mean_voxel_size = sum(voxel_size) / len(voxel_size)
            classifier = CmcTissueClassifier.from_pve(
                wm_pve,
                gm_pve,
                csf_pve,
                step_size=step_size,
                average_voxel_size=mean_voxel_size)
            logging.info('classifier done')

            seed_mask, _ = load_nifti(seeding_path)
            # Same seed density along each of the three axes.
            seeds = utils.seeds_from_mask(seed_mask,
                                          density=[seed_density] * 3,
                                          affine=affine)
            logging.info('seeds done')

            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                peaks.shm_coeff,
                max_angle=max_angle,
                sphere=peaks.sphere,
                pmf_threshold=pmf_threshold)

            tracker = ParticleFilteringTracking(
                direction_getter,
                classifier,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=back_tracking_dist,
                pft_front_tracking_dist=front_tracking_dist,
                pft_max_trial=max_trial,
                particle_count=particle_count)

            logging.info('ParticleFilteringTracking initiated')

            save(Tractogram(tracker, affine_to_rasmm=np.eye(4)), out_tract)

            logging.info('Saved {0}'.format(out_tract))
Ejemplo n.º 7
0
def track_ensemble(target_samples,
                   atlas_data_wm_gm_int,
                   parcels,
                   mod_fit,
                   tiss_classifier,
                   sphere,
                   directget,
                   curv_thr_list,
                   step_list,
                   track_type,
                   maxcrossing,
                   max_length,
                   n_seeds_per_iter=200,
                   pft_back_tracking_dist=2,
                   pft_front_tracking_dist=1,
                   particle_count=15,
                   roi_neighborhood_tol=8):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI masks.

    Tracking is repeated over every (curvature threshold, step size)
    combination -- one "hyperparameter circuit" -- drawing a fresh batch of
    random seeds from the wm-gm interface for each combination. Circuits are
    repeated until the cumulative number of streamline samples reaches
    `target_samples`; tracking stops as soon as that happens.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples to generate. The running count is
        the summed length (number of points) of the retained streamlines.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from
        Nifti1Image in T1w-warped native diffusion space, restricted to the
        wm-gm interface. Voxels with a value > 0 form the seeding mask.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from
        a Nifti1Image in T1w-warped native diffusion space.
    mod_fit : array
        Coefficients of the fitted reconstruction model; passed as the first
        argument to the direction getter's `from_shcoeff` constructor.
    tiss_classifier : obj
        Tissue classifier object handed to the tracker.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), clos or closest (closest peak), boot (bootstrapped),
        and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used ('local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    max_length : int
        Maximum fiber length threshold to restrict tracking (passed to the
        tracker as `maxlen`).
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 200.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter. By default this is
        set to 15.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. By default this is set to 8.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    Raises
    ------
    ValueError
        If `directget` or `track_type` is not a recognised option, or if
        `curv_thr_list` / `step_list` is empty while samples are requested.
    RuntimeWarning
        If no seed point could be drawn from the wm-gm interface.
    """
    # Validate the cheap arguments first, before the heavy optional imports.
    target = int(target_samples)
    curv_thr_list = list(curv_thr_list)
    step_list = list(step_list)
    if target > 0 and not (curv_thr_list and step_list):
        # Without at least one combination no streamline can ever be
        # produced and the sampling loop below would never terminate.
        raise ValueError(
            'ERROR: curv_thr_list and step_list must each contain at least '
            'one value.')

    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local import LocalTracking, ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                BootDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)

    direction_getters = {
        'prob': ProbabilisticDirectionGetter,
        'boot': BootDirectionGetter,
        'clos': ClosestPeakDirectionGetter,
        'closest': ClosestPeakDirectionGetter,
        'det': DeterministicMaximumDirectionGetter,
    }

    # Commence Ensemble Tractography
    # Every parcel is flagged as an inclusion ROI.
    parcel_vec = np.ones(len(parcels)).astype('bool')
    # The seeding mask does not change between iterations: compute it once.
    seed_mask = atlas_data_wm_gm_int > 0
    streamlines = nib.streamlines.array_sequence.ArraySequence()
    circuit_ix = 0
    stream_counter = 0
    while stream_counter < target:
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            dg_class = direction_getters.get(directget)
            if dg_class is None:
                raise ValueError(
                    'ERROR: No valid direction getter(s) specified.')
            dg = dg_class.from_shcoeff(
                mod_fit, max_angle=float(curv_thr), sphere=sphere)

            for step in step_list:
                print("%s%s" % ('Step: ', step))
                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(
                    seed_mask,
                    seeds_count=n_seeds_per_iter,
                    seed_count_per_voxel=False,
                    affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning(
                        'Warning: No valid seed points found in wm-gm interface...'
                    )

                print(seeds)
                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        maxlen=int(max_length),
                        step_size=float(step),
                        return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        step_size=float(step),
                        maxlen=int(max_length),
                        pft_back_tracking_dist=pft_back_tracking_dist,
                        pft_front_tracking_dist=pft_front_tracking_dist,
                        particle_count=particle_count,
                        return_all=True)
                else:
                    raise ValueError(
                        'ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by roi-intersection characteristics
                roi_proximal_streamlines = Streamlines(
                    select_by_rois(streamline_generator,
                                   parcels,
                                   parcel_vec,
                                   mode='any',
                                   affine=np.eye(4),
                                   tol=roi_neighborhood_tol))

                # Accumulate until the target samples condition is met
                for s in roi_proximal_streamlines:
                    stream_counter += len(s)
                    streamlines.append(s)
                    if stream_counter >= target:
                        break

                # Leave the step loop too: running the remaining combinations
                # would cost a full tracking pass each for no needed output.
                if stream_counter >= target:
                    break

            if stream_counter >= target:
                break

        circuit_ix += 1
        print(
            "%s%s%s%s%s" %
            ('Completed hyperparameter circuit: ', circuit_ix,
             '...\nCumulative Streamline Count: ', Fore.CYAN, stream_counter))
        print(Style.RESET_ALL)

    print('\n')
    return streamlines
Ejemplo n.º 8
0
def track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, parcel_vec, mod_fit,
                   tiss_classifier, sphere, directget, curv_thr_list, step_list, track_type, maxcrossing, max_length,
                   n_seeds_per_iter=200, pft_back_tracking_dist=2, pft_front_tracking_dist=1, particle_count=15,
                   roi_neighborhood_tol=8):
    """
    Perform ensemble tractography, restricted to a vector of ROI masks.

    Tracking is repeated over every (curvature threshold, step size)
    combination -- one "hyperparameter circuit" -- drawing a fresh batch of
    random seeds for each combination. Circuits are repeated until the
    cumulative number of streamline samples reaches `target_samples`;
    tracking stops as soon as that happens.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples to generate. The running count is
        the summed length (number of points) of the retained streamlines.
    atlas_data_wm_gm_int : array
        Array whose voxels with a value > 0 form the seeding mask.
    parcels : list
        ROI masks handed to `select_by_rois`.
    parcel_vec : array
        Per-ROI flags, cast to bool and handed to `select_by_rois` together
        with `parcels`.
    mod_fit : array
        Coefficients of the fitted reconstruction model; passed as the first
        argument to the direction getter's `from_shcoeff` constructor.
    tiss_classifier : obj
        Tissue classifier object handed to the tracker.
    sphere : obj
        Sphere object handed to the direction getter.
    directget : str
        Direction getter to use: 'prob' (probabilistic), 'boot'
        (bootstrapped), 'closest' or 'clos' (closest peak), or 'det'
        (deterministic maximum).
    curv_thr_list : list
        Curvature thresholds (used as `max_angle`) to cycle through.
    step_list : list
        Step sizes to cycle through.
    track_type : str
        Tracking algorithm used ('local' or 'particle').
    maxcrossing : int
        Passed to the tracker as `max_cross`.
    max_length : int
        Passed to the tracker as `maxlen`.
    n_seeds_per_iter : int
        Number of seeds drawn for each unique ensemble combination. By
        default this is set to 200.
    pft_back_tracking_dist : float
        Back-tracking distance for particle filtering tracking. By default
        this is set to 2.
    pft_front_tracking_dist : float
        Front-tracking distance for particle filtering tracking. By default
        this is set to 1.
    particle_count : int
        Number of particles to use in the particle filter. By default this is
        set to 15.
    roi_neighborhood_tol : float
        Tolerance passed to `select_by_rois` as `tol`. By default this is set
        to 8.

    Returns
    -------
    streamlines : ArraySequence
        Streamlines retained by the ROI filter, accumulated over all
        iterations.

    Raises
    ------
    ValueError
        If `directget` or `track_type` is not a recognised option, or if
        `curv_thr_list` / `step_list` is empty while samples are requested.
    RuntimeWarning
        If no seed point could be drawn from the seeding mask.
    """
    # Validate the cheap arguments first, before the heavy optional imports.
    target = int(target_samples)
    curv_thr_list = list(curv_thr_list)
    step_list = list(step_list)
    if target > 0 and not (curv_thr_list and step_list):
        # Without at least one combination no streamline can ever be
        # produced and the sampling loop below would never terminate.
        raise ValueError('ERROR: curv_thr_list and step_list must each contain at least one value.')

    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local import LocalTracking, ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter, BootDirectionGetter, ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)

    direction_getters = {
        'prob': ProbabilisticDirectionGetter,
        'boot': BootDirectionGetter,
        'closest': ClosestPeakDirectionGetter,
        'clos': ClosestPeakDirectionGetter,
        'det': DeterministicMaximumDirectionGetter,
    }

    # Commence Ensemble Tractography
    # Loop invariants: the seeding mask and the boolean ROI flags.
    seed_mask = atlas_data_wm_gm_int > 0
    include_flags = np.asarray(parcel_vec).astype('bool')
    streamlines = nib.streamlines.array_sequence.ArraySequence()
    circuit_ix = 0
    stream_counter = 0
    while stream_counter < target:
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            dg_class = direction_getters.get(directget)
            if dg_class is None:
                raise ValueError('ERROR: No valid direction getter(s) specified.')
            dg = dg_class.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere)

            for step in step_list:
                print("%s%s" % ('Step: ', step))
                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(seed_mask, seeds_count=n_seeds_per_iter,
                                                     seed_count_per_voxel=False, affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning('Warning: No valid seed points found in wm-gm interface...')

                print(seeds)
                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                         max_cross=int(maxcrossing), maxlen=int(max_length),
                                                         step_size=float(step), return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(
                        dg, tiss_classifier, seeds, np.eye(4),
                        max_cross=int(maxcrossing),
                        step_size=float(step),
                        maxlen=int(max_length),
                        pft_back_tracking_dist=pft_back_tracking_dist,
                        pft_front_tracking_dist=pft_front_tracking_dist,
                        particle_count=particle_count, return_all=True)
                else:
                    raise ValueError('ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by roi-intersection characteristics
                streamlines_more = Streamlines(select_by_rois(streamline_generator, parcels, include_flags,
                                                              mode='any', affine=np.eye(4),
                                                              tol=roi_neighborhood_tol))

                # Accumulate until the target samples condition is met
                for s in streamlines_more:
                    stream_counter += len(s)
                    streamlines.append(s)
                    if stream_counter >= target:
                        break

                # Leave the step loop too: running the remaining combinations
                # would cost a full tracking pass each for no needed output.
                if stream_counter >= target:
                    break

            if stream_counter >= target:
                break

        circuit_ix += 1
        print("%s%s%s%s%s" % ('Completed hyperparameter circuit: ', circuit_ix, '...\nCumulative Streamline Count: ',
                              Fore.CYAN, stream_counter))
        print(Style.RESET_ALL)

    print('\n')
    return streamlines
Ejemplo n.º 9
0
def execution(self, context):
    """Run particle filtering tractography and save the result as a TRK file.

    Reads the SH coefficient volume, the seed coordinates and the CSF / GM /
    WM partial-volume maps referenced by the process parameters on `self`,
    builds a direction getter and a tissue classifier from them, runs
    `ParticleFilteringTracking` with an identity affine, and writes the
    streamlines with `save_trk`.

    Parameters
    ----------
    context : obj
        Object whose `write` method is used to emit messages (presumably the
        BrainVISA execution context -- confirm against the caller).
    """

    sh_coeff_vol = aims.read(self.sh_coefficients.fullPath())
    header = sh_coeff_vol.header()

    # Transformation from Aims LPI mm space to RAS mm (reference space):
    # first entry of the header's 'transformations', reshaped to 4x4.

    aims_mm_to_ras_mm = np.array(header['transformations'][0]).reshape((4, 4))
    voxel_size = np.array(header['voxel_size'])
    # Keep only the first three entries when the header stores four of them
    # (presumably the 4th is the size along the SH/time axis -- TODO confirm).
    if len(voxel_size) == 4:
        voxel_size = voxel_size[:-1]
    # 4x4 diagonal matrix: voxel sizes on the diagonal, 1 in the last slot.
    scaling = np.concatenate((voxel_size, np.ones(1)))
    context.write(voxel_size.shape)
    scaling_mat = np.diag(scaling)
    context.write(scaling_mat.shape, aims_mm_to_ras_mm.shape)
    # NOTE(review): this computes scaling_mat @ aims_mm_to_ras_mm. Mapping a
    # voxel coordinate to RAS mm by scaling first and transforming second
    # would be aims_mm_to_ras_mm @ scaling_mat; the two differ whenever the
    # transform has a translation or the voxels are anisotropic -- confirm
    # the intended order.
    aims_voxel_to_ras_mm = np.dot(scaling_mat, aims_mm_to_ras_mm)

    # Identity affine handed to the tracker (see the tracking comment below).
    affine_tracking = np.eye(4)

    # Work on a float64 copy of the SH coefficients; the spatial shape is
    # everything but the last axis.
    sh = np.array(sh_coeff_vol, copy=True)
    sh = sh.astype(np.float64)
    vol_shape = sh.shape[:-1]
    if self.sphere is not None:
        sphere = read_sphere(self.sphere.fullPath())
    else:
        context.write(
            'No Projection Sphere provided. Default dipy sphere symmetric 362 is used'
        )
        sphere = get_sphere()

    # `DirectionGetter` maps the selected tracking type to a class exposing
    # `from_shcoeff`.
    dg = DirectionGetter[self.type].from_shcoeff(
        sh,
        self.max_angle,
        sphere,
        basis_type=None,
        relative_peak_threshold=self.relative_peak_threshold,
        min_separation_angle=self.min_separation_angle)

    # Handling seeds in both deterministic and probabilistic framework
    s = np.loadtxt(self.seeds.fullPath())
    s = s.astype(np.float32)
    i = np.arange(self.nb_samples)
    if self.nb_samples <= 1:
        seeds = s
    else:
        # Replicate the seed array nb_samples times (the assignment
        # broadcasts `s` into every slot), then flatten to rows of 3
        # coordinates.
        seeds = np.zeros((self.nb_samples, ) + s.shape)
        seeds[i] = s
        seeds = seeds.reshape((-1, 3))
    # Divide the seed coordinates by the voxel size (inverse of the scaling
    # matrix); presumably the seed file is in Aims mm and this converts to
    # voxel coordinates -- TODO confirm.
    seeds = nib.affines.apply_affine(np.linalg.inv(scaling_mat), seeds)
    # Building classifier

    csf_vol = aims.read(self.csf_pve.fullPath())
    grey_vol = aims.read(self.gm_pve.fullPath())
    white_vol = aims.read(self.wm_pve.fullPath())

    # Keep index 0 along the last axis of each partial-volume map.
    csf = np.array(csf_vol)
    csf = csf[..., 0]
    gm = np.array(grey_vol)
    gm = gm[..., 0]
    wm = np.array(white_vol)
    wm = wm[..., 0]

    # Renormalise the volumes (interpolation can yield values > 1): zero the
    # maps where their sum is <= 0, and divide each map by the sum elsewhere.
    # `total` is a snapshot taken before any in-place modification.
    # NOTE(review): the assignments are in place and keep each array's dtype;
    # if the maps are integer typed the fractions are truncated -- confirm
    # they are floating point.
    total = (csf + gm + wm).copy()
    csf[total <= 0] = 0
    gm[total <= 0] = 0
    wm[total <= 0] = 0
    csf[total != 0] = (csf[total != 0]) / (total[total != 0])
    wm[total != 0] = (wm[total != 0]) / (total[total != 0])
    gm[total != 0] = gm[total != 0] / (total[total != 0])

    # `Classifiers` maps the selected constraint to a class exposing
    # `from_pve`.
    classif = Classifiers[self.constraint]
    classifier = classif.from_pve(wm_map=wm, gm_map=gm, csf_map=csf)

    # Tracking is done in the LPI voxel space so that no affine has to be
    # imposed on the data. The seeds are also expected to be in LPI voxel
    # space.
    streamlines_generator = ParticleFilteringTracking(
        dg,
        classifier,
        seeds,
        affine_tracking,
        step_size=self.step_size,
        max_cross=self.crossing_max,
        maxlen=self.nb_iter_max,
        pft_back_tracking_dist=self.back_tracking_dist,
        pft_front_tracking_dist=self.front_tracking_dist,
        pft_max_trial=self.max_trial,
        particle_count=self.nb_particles,
        return_all=self.return_all)
    # Store fibers directly in LPI orientation with the appropriate
    # transformation
    save_trk(self.streamlines.fullPath(),
             streamlines_generator,
             affine=aims_voxel_to_ras_mm,
             vox_size=voxel_size,
             shape=vol_shape)

    # Give the output streamlines the same referential as the SH volume.
    transformManager = getTransformationManager()
    transformManager.copyReferential(self.sh_coefficients, self.streamlines)