Esempio n. 1
0
def test_random_seeds_from_mask():
    """Check seed counts, placement and reproducibility of random seeding.

    Exercises both modes of ``random_seeds_from_mask``: a fixed number of
    seeds per mask voxel (``seed_count_per_voxel=True``) and a fixed total
    number of seeds (``seed_count_per_voxel=False``).
    """
    # Random binary mask. ``np.random.random_integers`` is deprecated;
    # ``randint`` excludes its upper bound, so (0, 2) draws from {0, 1}
    # exactly like ``random_integers(0, 1)`` did.
    mask = np.random.randint(0, 2, size=(4, 6, 3))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=24,
                                   seed_count_per_voxel=True)
    assert_equal(mask.sum() * 24, len(seeds))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=0,
                                   seed_count_per_voxel=True)
    assert_equal(0, len(seeds))

    # Single-voxel mask at (2, 2, 2): every seed must lie strictly inside
    # (1.5, 2.5) on each axis.
    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=8,
                                   seed_count_per_voxel=True)
    assert_equal(mask.sum() * 8, len(seeds))
    assert_true(np.all((seeds > 1.5) & (seeds < 2.5)))

    # Total-count mode: the number of seeds is independent of mask size.
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=24,
                                   seed_count_per_voxel=False)
    assert_equal(24, len(seeds))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=0,
                                   seed_count_per_voxel=False)
    assert_equal(0, len(seeds))

    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=100,
                                   seed_count_per_voxel=False)
    assert_equal(100, len(seeds))
    assert_true(np.all((seeds > 1.5) & (seeds < 2.5)))

    # Reproducibility: with the same ``random_seed`` the first 150 seeds must
    # not depend on the requested count, in either mode.
    mask = np.zeros((15, 15, 15))
    mask[2:14, 2:14, 2:14] = 1
    seeds_npv_2 = random_seeds_from_mask(mask,
                                         seeds_count=2,
                                         seed_count_per_voxel=True,
                                         random_seed=0)[:150]
    seeds_npv_3 = random_seeds_from_mask(mask,
                                         seeds_count=3,
                                         seed_count_per_voxel=True,
                                         random_seed=0)[:150]
    assert_true(np.all(seeds_npv_2 == seeds_npv_3))

    seeds_nt_150 = random_seeds_from_mask(mask,
                                          seeds_count=150,
                                          seed_count_per_voxel=False,
                                          random_seed=0)[:150]
    seeds_nt_500 = random_seeds_from_mask(mask,
                                          seeds_count=500,
                                          seed_count_per_voxel=False,
                                          random_seed=0)[:150]
    assert_true(np.all(seeds_nt_150 == seeds_nt_500))
Esempio n. 2
0
def test_random_seeds_from_mask():
    """Check per-voxel seed counts and placement of random seeding."""
    # Random binary mask. ``np.random.random_integers`` is deprecated;
    # ``randint`` excludes its upper bound, so (0, 2) draws from {0, 1}.
    mask = np.random.randint(0, 2, size=(4, 6, 3))
    seeds = random_seeds_from_mask(mask, seeds_per_voxel=24)
    assert_equal(mask.sum() * 24, len(seeds))

    # Single-voxel mask at (2, 2, 2): every seed must lie in (1.5, 2.5).
    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask, seeds_per_voxel=8)
    assert_equal(mask.sum() * 8, len(seeds))
    assert_true(np.all((seeds > 1.5) & (seeds < 2.5)))
Esempio n. 3
0
def test_random_seeds_from_mask():
    """Check per-voxel seed counts and placement of random seeding."""
    # Random binary mask. ``np.random.random_integers`` is deprecated;
    # ``randint`` excludes its upper bound, so (0, 2) draws from {0, 1}.
    mask = np.random.randint(0, 2, size=(4, 6, 3))
    seeds = random_seeds_from_mask(mask, seeds_per_voxel=24)
    assert_equal(mask.sum() * 24, len(seeds))

    # Single-voxel mask at (2, 2, 2): every seed must lie in (1.5, 2.5).
    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask, seeds_per_voxel=8)
    assert_equal(mask.sum() * 8, len(seeds))
    assert_true(np.all((seeds > 1.5) & (seeds < 2.5)))
Esempio n. 4
0
def test_random_seeds_from_mask():
    """Check seed counts, placement and reproducibility of random seeding.

    Exercises both modes of ``random_seeds_from_mask``: a fixed number of
    seeds per mask voxel (``seed_count_per_voxel=True``) and a fixed total
    number of seeds (``seed_count_per_voxel=False``).
    """
    # Random binary mask. ``randint``'s upper bound is exclusive: the
    # previous ``randint(0, 1, ...)`` always produced an all-zero mask,
    # which made the per-voxel count checks below vacuous (0 == 0).
    mask = np.random.randint(0, 2, size=(4, 6, 3))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=24,
                                   seed_count_per_voxel=True)
    npt.assert_equal(mask.sum() * 24, len(seeds))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=0,
                                   seed_count_per_voxel=True)
    npt.assert_equal(0, len(seeds))

    # Single-voxel mask at (2, 2, 2): every seed must lie strictly inside
    # (1.5, 2.5) on each axis.
    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=8,
                                   seed_count_per_voxel=True)
    npt.assert_equal(mask.sum() * 8, len(seeds))
    npt.assert_(np.all((seeds > 1.5) & (seeds < 2.5)))

    # Total-count mode: the number of seeds is independent of mask size.
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=24,
                                   seed_count_per_voxel=False)
    npt.assert_equal(24, len(seeds))
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=0,
                                   seed_count_per_voxel=False)
    npt.assert_equal(0, len(seeds))

    mask[:] = False
    mask[2, 2, 2] = True
    seeds = random_seeds_from_mask(mask,
                                   seeds_count=100,
                                   seed_count_per_voxel=False)
    npt.assert_equal(100, len(seeds))
    npt.assert_(np.all((seeds > 1.5) & (seeds < 2.5)))

    # Reproducibility: with the same ``random_seed`` the first 150 seeds must
    # not depend on the requested count, in either mode.
    mask = np.zeros((15, 15, 15))
    mask[2:14, 2:14, 2:14] = 1
    seeds_npv_2 = random_seeds_from_mask(mask, seeds_count=2,
                                         seed_count_per_voxel=True,
                                         random_seed=0)[:150]
    seeds_npv_3 = random_seeds_from_mask(mask, seeds_count=3,
                                         seed_count_per_voxel=True,
                                         random_seed=0)[:150]
    npt.assert_array_equal(seeds_npv_2, seeds_npv_3)

    seeds_nt_150 = random_seeds_from_mask(mask, seeds_count=150,
                                          seed_count_per_voxel=False,
                                          random_seed=0)[:150]
    seeds_nt_500 = random_seeds_from_mask(mask, seeds_count=500,
                                          seed_count_per_voxel=False,
                                          random_seed=0)[:150]
    npt.assert_array_equal(seeds_nt_150, seeds_nt_500)
Esempio n. 5
0
def build_seed_list(mask_img_file, stream_affine, dens):
    """Create a list of random tractography seeds inside a mask image.

    Uses dipy's ``random_seeds_from_mask`` with a fixed number of seeds per
    mask voxel.

    Parameters
    ----------
    mask_img_file : str
        Path to the mask image (loaded with ``nibabel.load``) of the area to
        generate seeds for; its data is converted to ``bool``.
    stream_affine : ndarray
        4x4 affine forwarded to ``random_seeds_from_mask`` to map voxel
        coordinates to the output space (e.g. ``np.eye(4)`` keeps seeds in
        voxel coordinates).
    dens : int
        Seed density: number of seeds per mask voxel (cast with ``int``).

    Returns
    -------
    ndarray
        Locations of the seeds, as returned by ``random_seeds_from_mask``.
    """
    import numpy as np

    mask_img = nib.load(mask_img_file)
    # ``get_data()`` is deprecated and raises in nibabel >= 5.0;
    # ``np.asanyarray(img.dataobj)`` yields the same scaled voxel array.
    mask_img_data = np.asanyarray(mask_img.dataobj).astype(bool)
    seeds = utils.random_seeds_from_mask(
        mask_img_data,
        affine=stream_affine,
        seeds_count=int(dens),
        seed_count_per_voxel=True,
    )
    return seeds
Esempio n. 6
0
 def track(self):
     """Run the base tracking step, then compute seeds if needed.

     Seeds are only generated when ``self.streamlines`` is still ``None``
     after ``Tracker.track``. They come from ``self.data.binarymask``: via
     ``random_seeds_from_mask`` when ``self.options.random_seeds`` is truthy,
     otherwise via ``seeds_from_mask``. The result is stored in ``self.seeds``.
     """
     Tracker.track(self)
     if self.streamlines is not None:
         return

     opts = self.options
     if opts.random_seeds:
         self.seeds = random_seeds_from_mask(
             self.data.binarymask,
             seeds_count=opts.seeds_count,
             seed_count_per_voxel=opts.seeds_per_voxel,
             affine=self.data.aff)
     else:
         self.seeds = seeds_from_mask(self.data.binarymask,
                                      affine=self.data.aff)
Esempio n. 7
0
def _get_seeds(data_container,
               random_seeds=False,
               seeds_count=30000,
               seeds_per_voxel=False):
    """Return tracking seeds located in ``data_container.binary_mask``.

    Parameters
    ----------
    data_container : object
        Must expose ``binary_mask`` and ``aff`` (the affine forwarded to the
        seeding function).
    random_seeds : bool
        If true, seeds are drawn with ``random_seeds_from_mask``; otherwise
        ``seeds_from_mask`` is used.
    seeds_count : int
        Number of random seeds; only used when ``random_seeds`` is true.
    seeds_per_voxel : bool
        Forwarded as ``seed_count_per_voxel``: whether ``seeds_count`` is
        per voxel rather than a total.
    """
    mask = data_container.binary_mask
    affine = data_container.aff
    if random_seeds:
        return random_seeds_from_mask(mask,
                                      seeds_count=seeds_count,
                                      seed_count_per_voxel=seeds_per_voxel,
                                      affine=affine)
    return seeds_from_mask(mask, affine=affine)
Esempio n. 8
0
def main():
    """Command-line entry point: seed, track in voxel space and save.

    Random seeds are drawn inside the seed image and tracked with
    ``LocalTracking`` inside the tracking mask. Streamlines are then
    filtered by length, optionally compressed, and streamed to
    ``args.output_file``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    for param in ['theta', 'curvature']:
        # Default was removed for consistency.
        if param not in args:
            setattr(args, param, None)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exists(parser, args, [args.output_file])

    np.random.seed(args.seed)

    mask_img = nib.load(args.mask_file)
    # ``get_data()`` is deprecated and raises in nibabel >= 5.0;
    # ``np.asanyarray(img.dataobj)`` yields the same scaled voxel array.
    mask_data = np.asanyarray(mask_img.dataobj)

    # ``'nts' in args`` tests whether the namespace has an ``nts`` attribute
    # at all (presumably the parser suppresses unset defaults -- verify in
    # _build_arg_parser). If present it is a total seed count; otherwise
    # ``npv`` seeds are placed in every seed voxel.
    seed_data = np.asanyarray(nib.load(args.seed_file).dataobj)
    seeds = random_seeds_from_mask(
        seed_data,
        seeds_count=args.nts if 'nts' in args else args.npv,
        seed_count_per_voxel='nts' not in args)

    # Tracking is performed in voxel space
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryTissueClassifier(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=args.step_size,
                                max_cross=1,
                                maxlen=int(args.max_len / args.step_size) + 1,
                                fixedstep=True,
                                return_all=True)

    filtered_streamlines = (s for s in streamlines
                            if args.min_len <= length(s) <= args.max_len)
    if args.compress_streamlines:
        filtered_streamlines = (compress_streamlines(s, args.tolerance_error)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=mask_img.affine)

    # Header with the affine/shape from mask image
    header = {
        Field.VOXEL_TO_RASMM: mask_img.affine.copy(),
        Field.VOXEL_SIZES: mask_img.header.get_zooms(),
        Field.DIMENSIONS: mask_img.shape,
        Field.VOXEL_ORDER: ''.join(aff2axcodes(mask_img.affine))
    }

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Esempio n. 9
0
def build_seed_list(wm_mask, dens):
    """Place random tractography seeds inside a white-matter mask.

    Parameters
    ----------
    wm_mask : np.array
        Mask in which the seeds are placed.
    dens : int
        Seed density: number of seeds per mask voxel (cast with ``int``).

    Returns
    -------
    ndarray
        Locations for the seeds, computed with an identity affine
        (i.e. in the mask's voxel coordinates).
    """
    identity = np.eye(4)
    return utils.random_seeds_from_mask(
        wm_mask,
        affine=identity,
        seeds_count=int(dens),
        seed_count_per_voxel=True,
    )
Esempio n. 10
0
def execution(self, context):
    """Draw random seeds inside ``self.mask`` and save them to ``self.seeds``.

    Seed coordinates are scaled by the mask's voxel size through a diagonal
    affine. They are written as text with ``np.savetxt``, and the mask's
    referential is copied onto the seeds output.
    """
    volume = aims.read(self.mask.fullPath())
    header = volume.header()
    # Keep index 0 of the trailing axis and binarize.
    mask = np.asarray(volume)[..., 0].astype(bool)

    voxel_size = np.array(header['voxel_size'])
    n_sizes = len(voxel_size)
    if n_sizes == 4:
        # Reset the 4th entry (presumably a non-spatial step -- confirm) so
        # the homogeneous coordinate of the diagonal affine is 1.
        voxel_size[-1] = 1
    elif n_sizes == 3:
        voxel_size = np.concatenate((voxel_size, np.ones(1)))

    seeds = random_seeds_from_mask(mask,
                                   seeds_count=self.seed_number,
                                   seed_count_per_voxel=self.number_per_voxel,
                                   affine=np.diag(voxel_size))
    np.savetxt(self.seeds.fullPath(), seeds)
    manager = getTransformationManager()
    manager.copyReferential(self.mask, self.seeds)
Esempio n. 11
0
    print 'We use the roi_radius={},\nand the response is {},\nthe ratio is {},\nusing {} of voxels'.format(radius, response, ratio, nvl)


    print 'fitting CSD model'
    st2 = time.time()
    sphere = get_sphere('symmetric724')
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=8)
    csd_peaks = peaks_from_model(csd_model, data, sphere = sphere, relative_peak_threshold=0.1, min_separation_angle=25,mask=mask, return_sh=True, parallel=True, normalize_peaks=True)
    et2 = time.time() - st2
    print 'fitting CSD model finished, running time is {}'.format(et2)


    print 'seeding begins, using np.random.seed(123)'
    st3 = time.time()
    np.random.seed(123)
    seeds = utils.random_seeds_from_mask(mask, 4)
    for i in range(len(seeds)):
            if seeds[i][0]>199.:
                seeds[i][0]=398-seeds[i][0]
            if seeds[i][1]>399.:
                seeds[i][1]=798-seeds[i][1]
            if seeds[i][2]>199.:
                seeds[i][2]=398-seeds[i][2]
            for j in range(3):
                if seeds[i][j]<0.:
                    seeds[i][j]=-seeds[i][j]
    et3 = time.time() - st3
    print 'seeding transformation finished, the total seeds are {}, running time is {}'.format(seeds.shape[0], et3)

    print 'generating streamlines begins'
    st4 = time.time()
Esempio n. 12
0
def track(params_file, directions="det", max_angle=30., sphere=None,
          seed_mask=None, n_seeds=1, random_seeds=False, stop_mask=None,
          stop_threshold=0, step_size=0.5, min_length=10, max_length=1000,
          odf_model="DTI"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters, or nibabel
        img with model params. For "DTI"/"DKI" the last axis holds 3
        eigenvalues followed by 9 eigenvector components; for "CSD" it holds
        spherical harmonic coefficients.
    directions : str
        How tracking directions are determined.
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Binary mask describing the ROI within which we seed for tracking.
        Default to the entire volume.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    stop_mask : array, optional.
        A floating point value that determines a stopping criterion (e.g. FA).
        Default to no stopping (all ones).
    stop_threshold : float, optional.
        A value of the stop_mask below which tracking is terminated. Default to
        0 (this means that if no stop_mask is passed, we will stop only at
        the edge of the image)
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "DKI", "CSD"}. Defaults to use "DTI"

    Returns
    -------
    list of streamlines ()

    Raises
    ------
    ValueError
        If `directions` or `odf_model` is not one of the supported values.
    """
    import numbers

    # Validate options up-front: previously an unknown value left `dg`
    # unbound (or an undecorated class) and failed much later.
    if directions not in ("det", "prob"):
        raise ValueError(
            "directions must be 'det' or 'prob', got {!r}".format(directions))
    if odf_model not in ("DTI", "DKI", "CSD"):
        raise ValueError(
            "odf_model must be one of 'DTI', 'DKI', 'CSD', got {!r}".format(
                odf_model))

    logger = logging.getLogger('AFQ.tractography')

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    # numbers.Integral also accepts numpy integer scalars, which previously
    # fell through to the "array of seed coordinates" branch.
    if isinstance(n_seeds, numbers.Integral):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask, seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:  # "prob"
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI"):
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:  # "CSD"
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if stop_mask is None:
        stop_mask = np.ones(params_img.shape[:3])

    threshold_classifier = ThresholdStoppingCriterion(stop_mask,
                                                      stop_threshold)
    logger.info("Tracking...")

    return _local_tracking(seeds, dg, threshold_classifier, params_img,
                           step_size=step_size, min_length=min_length,
                           max_length=max_length)
Esempio n. 13
0
def run_tractography(fdwi,
                     fbval,
                     fbvec,
                     fwmparc,
                     mod_func,
                     mod_type,
                     seed_density=20):
    """Run local tractography seeded and constrained by a white-matter mask.

    Parameters
    ----------
    fdwi, fbval, fbvec, fwmparc : str
        Inputs forwarded to ``load_data`` (presumably the DWI image,
        b-values, b-vectors and white-matter parcellation -- confirm).
    mod_func : 'str'
        'csd' or 'csa'
    mod_type : 'str'
        'det' or 'prob'
    seed_density : int, default=20
        Number of random seeds placed in each white-matter voxel.

    Returns
    -------
    Streamlines
        Tracked streamlines that have more than 60 points.

    Raises
    ------
    ValueError
        If ``mod_func`` or ``mod_type`` is not one of the values above.
    """
    # Validate before the (slow) data loading. An unknown value used to
    # surface only later, as an UnboundLocalError on ``mod`` or
    # ``direction_getter``.
    if mod_func not in ("csd", "csa"):
        raise ValueError(
            "mod_func must be 'csd' or 'csa', got {!r}".format(mod_func))
    if mod_type not in ("det", "prob"):
        raise ValueError(
            "mod_type must be 'det' or 'prob', got {!r}".format(mod_type))

    # Getting default params
    sphere = get_sphere("repulsion724")
    stream_affine = np.eye(4)

    # Loading data
    print("Loading Data...")
    dwi, gtab, wm_mask = load_data(fdwi, fbval, fbvec, fwmparc)

    # Make tissue classifier
    tiss_classifier = BinaryStoppingCriterion(wm_mask)

    if mod_func == "csd":
        mod = csd_mod_est(gtab, dwi, wm_mask)
    else:  # "csa"
        mod = odf_mod_est(gtab)

    # Build seed list
    seeds = utils.random_seeds_from_mask(
        wm_mask,
        affine=stream_affine,
        seeds_count=int(seed_density),
        seed_count_per_voxel=True,
    )

    # Make streamlines
    if mod_type == "det":
        print("Obtaining peaks from model...")
        direction_getter = peaks_from_model(
            mod,
            dwi,
            sphere,
            relative_peak_threshold=0.5,
            min_separation_angle=25,
            mask=wm_mask,
            npeaks=5,
            normalize_peaks=True,
        )
    else:  # "prob"
        print("Preparing probabilistic tracking...")
        print("Fitting model to data...")
        mod_fit = mod.fit(dwi, wm_mask)
        print("Building direction-getter...")
        print(
            "Proceeding using spherical harmonic coefficient from model estimation..."
        )
        try:
            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                mod_fit.shm_coeff, max_angle=60.0, sphere=sphere)
        except Exception:
            # Fall back to the model's ODF clipped to a non-negative PMF
            # (e.g. when the fit exposes no usable ``shm_coeff``). A bare
            # ``except:`` here also swallowed KeyboardInterrupt/SystemExit.
            print("Proceeding using FOD PMF from model estimation...")
            fod = mod_fit.odf(sphere)
            pmf = fod.clip(min=0)
            direction_getter = ProbabilisticDirectionGetter.from_pmf(
                pmf, max_angle=60.0, sphere=sphere)

    print("Running Local Tracking")
    streamline_generator = LocalTracking(
        direction_getter,
        tiss_classifier,
        seeds,
        stream_affine,
        step_size=0.5,
        return_all=True,
    )

    print("Reconstructing tractogram streamlines...")
    streamlines = Streamlines(streamline_generator)
    tracks = Streamlines([track for track in streamlines if len(track) > 60])
    return tracks
Esempio n. 14
0
def main():
    """Command-line entry point for local tracking from an SH image.

    Validates the arguments and draws random seeds inside the seed image
    (``--npv`` per voxel, or ``--nt`` in total; 1 per voxel by default).
    It then tracks in voxel space inside the tracking mask, filters by
    length, optionally compresses, and streams the tractogram to
    ``args.output_file``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'.
                format(args.compress))

    # Compare against None rather than relying on truthiness: with the old
    # ``args.npv and ...`` form, ``--npv 0`` skipped this check and then
    # silently fell through to the default of 1 seed per voxel below.
    if args.npv is not None and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt is not None and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    mask_img = nib.load(args.mask_file)
    # ``get_data()`` is deprecated and raises in nibabel >= 5.0;
    # ``np.asanyarray(img.dataobj)`` yields the same scaled voxel array.
    mask_data = np.asanyarray(mask_img.dataobj)

    # Make sure the mask is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    # --npv takes precedence over --nt; default is 1 seed per voxel.
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Distances given in mm are converted to voxel units (isotropy was
    # checked above, so the first zoom is representative).
    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        np.asanyarray(seed_img.dataobj),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryTissueClassifier(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=vox_step_size,
                                max_cross=1,
                                maxlen=max_steps,
                                fixedstep=True,
                                return_all=True,
                                random_seed=args.seed)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    filtered_streamlines = (
        s for s in streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Esempio n. 15
0
def main():
    """Command-line entry point for GPU-based tractography from an ODF image.

    Loads the SH/ODF image, the tracking mask and the seed mask, and draws
    random seeds inside the seed mask. It then runs ``GPUTacker`` and
    streams the resulting tractogram to ``args.out_tractogram``, logging
    timings at INFO level when ``--verbose`` is given.
    """
    t_init = perf_counter()
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_odf, args.in_mask, args.in_seed])
    assert_outputs_exist(parser, args, args.out_tractogram)
    if args.compress is not None:
        verify_compression_th(args.compress)

    odf_sh_img = nib.load(args.in_odf)
    mask = get_data_as_mask(nib.load(args.in_mask))
    seed_mask = get_data_as_mask(nib.load(args.in_seed))
    odf_sh = odf_sh_img.get_fdata(dtype=np.float32)

    # Seed-count selection: --npv (per voxel) wins over --nt (total);
    # with neither, one seed per voxel is used.
    t0 = perf_counter()
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Seeds are returned with origin `center`.
    # However, GPUTracker expects origin to be `corner`.
    # Therefore, we need to shift the seed positions by half voxel.
    seeds = random_seeds_from_mask(seed_mask,
                                   np.eye(4),
                                   seeds_count=nb_seeds,
                                   seed_count_per_voxel=seed_per_vox,
                                   random_seed=args.rng_seed) + 0.5
    logging.info('Generated {0} seed positions in {1:.2f}s.'.format(
        len(seeds),
        perf_counter() - t0))

    # Convert mm quantities to voxel units using the first zoom only
    # (NOTE(review): assumes isotropic voxels -- not checked here).
    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    vox_max_length = args.max_length / voxel_size
    vox_min_length = args.min_length / voxel_size
    # Lengths expressed as step counts + 1 (presumably number of points
    # per streamline -- confirm against GPUTacker).
    min_strl_len = int(vox_min_length / vox_step_size) + 1
    max_strl_len = int(vox_max_length / vox_step_size) + 1

    # initialize tracking
    tracker = GPUTacker(odf_sh, mask, seeds, vox_step_size, min_strl_len,
                        max_strl_len, args.theta, args.sh_basis,
                        args.batch_size, args.forward_only, args.rng_seed)

    # wrapper for tracker.track() yielding one TractogramItem per
    # streamline for use with the LazyTractogram.
    def tracks_generator_wrapper():
        for strl, seed in tracker.track():
            # seed must be saved in voxel space, with origin `center`.
            dps = {'seeds': seed - 0.5} if args.save_seeds else {}

            # TODO: Investigate why the streamline must NOT be shifted to
            # origin `corner` for LazyTractogram.
            strl *= voxel_size  # in mm.
            if args.compress:
                strl = compress_streamlines(strl, args.compress)
            yield TractogramItem(strl, dps, {})

    # instantiate tractogram
    tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper)
    tractogram.affine_to_rasmm = odf_sh_img.affine

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(odf_sh_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
    logging.info('Saved tractogram to {0}.'.format(args.out_tractogram))

    # Total runtime
    logging.info('Total runtime of {0:.2f}s.'.format(perf_counter() - t_init))
Esempio n. 16
0
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters, or nibabel
        img with model params. For "DTI"/"DKI"/"FWDTI" the last axis holds 3
        eigenvalues followed by 9 eigenvector components; for "CSD"/"MSMT"
        it holds spherical harmonic coefficients.
    directions : str
        How tracking directions are determined (case-insensitive).
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        For a non-boolean seed_mask, only voxels with values strictly above
        this threshold are seeded.
        Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array).
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "CSD", "DKI", "FWDTI", "MSMT"} (case-insensitive).
        Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    list of streamlines ()

    Raises
    ------
    ValueError
        If `directions`, `odf_model`, `tracker`, or (for PFT) the string
        `stop_threshold` is not one of the supported values.
    RuntimeError
        If PFT tracking is requested without a string `stop_threshold` or
        without a length-3 `stop_mask`.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1924-1938, 2012.
    """
    import numbers

    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()
    # Fail fast on unsupported options. Previously these left `dg`,
    # `my_tracker` or `stopping_criterion` unbound, or silently took the
    # SH branch (`odf_model == "CSD" or "MSMT"` was always true).
    if directions not in ("det", "prob"):
        raise ValueError(
            "directions must be 'det' or 'prob', got {!r}".format(directions))
    if odf_model not in ("DTI", "DKI", "FWDTI", "CSD", "MSMT"):
        raise ValueError(
            "odf_model must be one of 'DTI', 'DKI', 'FWDTI', 'CSD', 'MSMT', "
            "got {!r}".format(odf_model))
    if tracker not in ("local", "pft"):
        raise ValueError(
            "tracker must be 'local' or 'pft', got {!r}".format(tracker))

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    # numbers.Integral also accepts numpy integer scalars, which previously
    # fell through to the "array of seed coordinates" branch.
    if isinstance(n_seeds, numbers.Integral):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:  # "prob"
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI", "FWDTI"):
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:  # "CSD" or "MSMT"
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(params_img.shape[:3])

        if stop_mask.dtype == 'bool':
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # "pft"
        if not isinstance(stop_threshold, str):
            # The message parts are concatenated; passing them as separate
            # arguments turned the exception message into a tuple.
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if stop_threshold not in ("CMC", "ACT"):
            raise ValueError(
                "stop_threshold must be 'CMC' or 'ACT' for PFT tracking, "
                "got {!r}".format(stop_threshold))
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")
        pves = []
        pve_imgs = []
        for pve in stop_mask:
            img = nib.load(pve) if isinstance(pve, str) else pve
            pve_imgs.append(img)
            pves.append(img.get_fdata())
        pve_wm_img, pve_gm_img, pve_csf_img = pve_imgs
        pve_wm_data, pve_gm_data, pve_csf_data = pves
        # Resample each PVE map onto the grid of the model parameters.
        pve_wm_data = resample(pve_wm_data, model_params[..., 0],
                               pve_wm_img.affine,
                               params_img.affine).get_fdata()
        pve_gm_data = resample(pve_gm_data, model_params[..., 0],
                               pve_gm_img.affine,
                               params_img.affine).get_fdata()
        pve_csf_data = resample(pve_csf_data, model_params[..., 0],
                                pve_csf_img.affine,
                                params_img.affine).get_fdata()

        # Average voxel size of the parameter image. This used to be taken
        # as np.mean() of a list that was still empty at that point (nan).
        average_voxel_size = np.mean(params_img.header.get_zooms()[:3])

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT"
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)
Esempio n. 17
0
def main():
    """Run particle-filtering tractography (PFT) from an SH fODF image.

    Validates the command-line arguments, builds the direction getter from
    the SH coefficients, builds a CMC (default) or ACT tissue classifier from
    the include/exclude maps, draws seeds from the seed mask, tracks in voxel
    space and streams the length-filtered (optionally compressed) streamlines
    to ``args.output_file``. Invalid arguments are reported via
    ``parser.error``.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [
        args.sh_file, args.seed_file, args.map_include_file,
        args.map_exclude_file
    ])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'.
                format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    # NOTE(review): a value of 0 is falsy and skips these two checks, falling
    # through to the next seeding option below — confirm this is intended.
    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    # Load the SH image once (it used to be loaded twice in a row) and require
    # isotropic voxels, since tracking below is done in voxel space.
    fodf_sh_img = nib.load(args.sh_file)
    sh_zooms = fodf_sh_img.header.get_zooms()
    if not np.allclose(np.mean(sh_zooms[:3]), sh_zooms[0], atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_data().astype(np.double),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    # `.header` replaces the deprecated `get_header()` (same header object).
    voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data(),
                                                step_size=args.step_size,
                                                average_voxel_size=voxel_size)
    else:
        tissue_classifier = ActTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data())

    # Seeding: seeds per voxel (--npv), total seeds (--nt), else 1 per voxel.
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Step size and length limits are converted from mm to voxel units.
    voxel_size = sh_zooms[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_data(),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size
    filtered_streamlines = (
        s for s in pft_streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Esempio n. 18
0
from dipy.io.streamline import save_trk
from dipy.tracking import utils

# Input DWI volume (presumably brain-extracted, judging by the "_bet"
# suffix — TODO confirm) and its b-values / b-vectors files.
fimg = "DTICAP_bet.nii.gz"
img = nib.load(fimg)
data = img.get_data()
fbval = "DTICAP.bval"
fbvec = "DTICAP.bvec"
affine = img.affine

# Build the acquisition gradient table from the .bval/.bvec files.
bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
gtab = gradient_table(bvals, bvecs)

# Median/Otsu masking of the first volume along the 4th axis.
# NOTE(review): dipy documents median_otsu as returning
# (masked_volume, mask) in that order; if so, `mask` here holds the masked
# volume and `S0_mask` the binary mask — confirm which one is intended for
# seeding below.
mask, S0_mask = median_otsu(data[:, :, :, 0])
# create seeds
# seeds_count=1 is presumably one seed per nonzero voxel of `mask`
# (dipy's default seed_count_per_voxel) — TODO confirm; `affine` maps the
# seeds out of voxel coordinates.
seeds = random_seeds_from_mask(mask, affine, seeds_count=1)
 

"""
Fit the data to a Constant Solid Angle (CSA) ODF model. This model estimates the
Orientation Distribution Function (ODF) at each voxel. The ODF is the
distribution of water diffusion as a function of direction. The peaks of an ODF
are good estimates of the orientation of tract segments at a point in the
image. Here, we use ``peaks_from_model`` to fit the data and calculate the
fiber directions in all voxels of the white matter.
"""
 
response, ratio = auto_response(gtab, data, roi_radius=10, fa_thr=0.7)
csa_model = CsaOdfModel(gtab, sh_order=6)
csa_peaks = peaks_from_model(csa_model, data, default_sphere,
                             relative_peak_threshold=.5,
Esempio n. 19
0
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count, roi_neighborhood_tol,
                 waymask, min_length, track_type, min_separation_angle, sphere,
                 tiss_class, tissues4d, cache_dir):
    """Track and filter streamlines for one (step, curvature) combination.

    The reconstruction file (and the waymask, if any) are copied into
    ``cache_dir`` with a suffix derived from ``step_curv_combinations``.
    The copies are removed before returning on every exit path, including
    early ``None`` returns and raised exceptions.

    Parameters
    ----------
    step_curv_combinations : sequence
        Element 0 is the tracking step size and element 1 the maximum angle
        passed to the direction getter.
    recon_path : str
        HDF5 file whose ``'reconstruction'`` dataset is given to the
        direction getter's ``from_shcoeff``.
    n_seeds_per_iter : int
        Number of random seeds drawn from the WM-GM interface.
    directget : str
        ``'prob'``/``'probabilistic'``, ``'clos'``/``'closest'`` or
        ``'det'``/``'deterministic'`` (the latter forces ``maxcrossing=1``).
    maxcrossing : int
        ``max_cross`` passed to the tracker.
    max_length : int
        ``maxlen`` passed to the tracker.
    pft_back_tracking_dist, pft_front_tracking_dist, particle_count
        Particle-filtering parameters, used when ``track_type='particle'``.
    roi_neighborhood_tol : float
        ``tol`` used for parcel and waymask proximity filtering.
    waymask : str or None
        Optional waymask image path. When given, parcel filtering uses
        ``mode='any'`` (otherwise ``'both_end'``) and streamlines must lie
        near the waymask thresholded at 0.0075.
    min_length : float
        Streamlines with ``len(s) < min_length`` are dropped. NOTE(review):
        ``len(s)`` is a point count while the log message says "mm" —
        confirm units.
    track_type : str
        ``'local'`` or ``'particle'``.
    min_separation_angle : float
        Passed to the direction getter.
    sphere : obj
        Sphere passed to the direction getter.
    tiss_class : str
        Tissue-classification method passed to ``prep_tissues``.
    tissues4d : str
        4D image whose volumes 0-6 are: B0 mask, atlas, WM-GM interface,
        T1w in DWI space, GM, ventricle/CSF and WM maps.
    cache_dir : str
        Directory receiving the temporary file copies.

    Returns
    -------
    ArraySequence or None
        Float32 streamlines that survived every filter, or ``None`` if no
        seeds were found or a filtering step failed.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not recognised.
    """
    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    recon_path_tmp_path = fname_presuffix(recon_path,
                                          suffix=f"_{step_curv_combinations}",
                                          newpath=cache_dir)
    copyfile(recon_path, recon_path_tmp_path, copy=True, use_hardlink=False)

    waymask_tmp_path = None
    try:
        if waymask is not None:
            waymask_tmp_path = fname_presuffix(
                waymask,
                suffix=f"_{step_curv_combinations}",
                newpath=cache_dir)
            copyfile(waymask, waymask_tmp_path, copy=True,
                     use_hardlink=False)

        tissue_img = nib.load(tissues4d)

        # Order:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        atlas_data_wm_gm_int = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")
        atlas_data_wm_gm_int_data = np.asarray(
            atlas_data_wm_gm_int.dataobj).astype("bool").astype("int16")

        # Build mask vector from atlas for later roi filtering: one boolean
        # mask per nonzero atlas intensity.
        parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)
                   if roi_val != 0]

        del atlas_data

        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        # Read-only access is enough: the dataset is only read here.
        with h5py.File(recon_path_tmp_path, 'r') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        if directget == "prob" or directget == "probabilistic":
            dg = ProbabilisticDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget == "clos" or directget == "closest":
            dg = ClosestPeakDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget == "det" or directget == "deterministic":
            maxcrossing = 1
            dg = DeterministicMaximumDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        else:
            raise ValueError("ERROR: No valid direction getter(s) specified.")

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            atlas_data_wm_gm_int_data > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) == 0:
            print(
                UserWarning("No valid seed points found in wm-gm "
                            "interface..."))
            return None

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
            )
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                particle_count=particle_count,
                return_all=True,
            )
        else:
            # Previously this called sys.exit(0), terminating the whole
            # process with a *success* status on a configuration error.
            raise ValueError("ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        try:
            roi_proximal_streamlines = utils.target(streamline_generator,
                                                    np.eye(4),
                                                    B0_mask_data,
                                                    include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        # Filter resulting streamlines by roi-intersection
        # characteristics
        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    select_by_rois(
                        roi_proximal_streamlines,
                        affine=np.eye(4),
                        rois=parcels,
                        include=parcel_vec,
                        mode="any" if waymask is not None else "both_end",
                        tol=roi_neighborhood_tol,
                    )
                )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    [
                        s for s in roi_proximal_streamlines
                        if len(s) >= float(min_length)
                    ]
                )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            from nilearn.image import math_img
            mask = math_img("img > 0.0075", img=nib.load(waymask_tmp_path))
            waymask_data = np.asarray(mask.dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=roi_neighborhood_tol,
                        mode="all")]
                print("%s%s" %
                      ("Waymask proximity: ", len(roi_proximal_streamlines)))
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        out_streams = [s.astype("float32") for s in roi_proximal_streamlines]

        del dg, seeds, roi_proximal_streamlines, streamline_generator, \
            atlas_data_wm_gm_int_data, mod_fit, B0_mask_data

        try:
            return ArraySequence(out_streams)
        except Exception:
            return None
    finally:
        # Remove the per-run copies on every exit path; previously they
        # leaked on early returns/exceptions (the waymask copy always did).
        for tmp_path in (recon_path_tmp_path, waymask_tmp_path):
            if tmp_path is not None and os.path.isfile(tmp_path):
                os.remove(tmp_path)
        gc.collect()
def main():
    """Run local tractography from an ODF SH image within a binary mask.

    Validates arguments and inputs, requires an isotropic ODF image (tracking
    is done in voxel space), draws seeds from ``args.in_seed``, tracks with a
    ``BinaryStoppingCriterion`` built from ``args.in_mask``, keeps streamlines
    whose length lies within [min_length, max_length] (converted to voxel
    units), optionally compresses them, and saves the result to
    ``args.out_tractogram``. With ``--save_seeds`` the seed of every kept
    streamline is stored as per-streamline data under ``'seeds'``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.in_odf, args.in_seed, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    mask_img = nib.load(args.in_mask)
    mask_data = get_data_as_mask(mask_img, dtype=bool)

    # Make sure the data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    odf_sh_img = nib.load(args.in_odf)
    if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
                       odf_sh_img.header.get_zooms()[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be ran robustly.')

    # Seeding: seeds per voxel (--npv), total seeds (--nt), else 1 per voxel.
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(dtype=np.float32),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines_generator = LocalTracking(
        _get_direction_getter(args),
        BinaryStoppingCriterion(mask_data),
        seeds, np.eye(4),
        step_size=vox_step_size, max_cross=1,
        maxlen=max_steps,
        fixedstep=True, return_all=True,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        # The generator yields (streamline, seed) pairs in this mode.
        kept = [(s, p) for s, p in streamlines_generator
                if scaled_min_length <= length(s) <= scaled_max_length]
        # zip(*[]) produces nothing, so unpacking it into two names raised
        # ValueError whenever no streamline passed the length filter.
        filtered_streamlines, seeds = zip(*kept) if kept else ((), ())
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines_generator
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
Esempio n. 21
0
# FA map of a tensor fit (``tensor_fit`` is defined outside this excerpt).
fa = tensor_fit.fa

"""
In this simple example we can use FA to stop tracking. Here we stop tracking
when FA < 0.1.
"""

# Threshold classifier on the FA map with threshold 0.1, as described above.
tissue_classifier = ThresholdTissueClassifier(fa, 0.1)

"""
Now, we need to set starting points for propagating each track. We call those
seeds. Using ``random_seeds_from_mask`` we can select a specific number of
seeds (``seeds_count``) in each voxel where the mask ``fa > 0.3`` is true.
"""

# Seeds drawn from the boolean mask ``fa > 0.3`` with seeds_count=1.
# NOTE(review): no affine is passed here; confirm this call matches the
# installed dipy's ``random_seeds_from_mask`` signature.
seeds = random_seeds_from_mask(fa > 0.3, seeds_count=1)

"""
For quality assurance we can also visualize a slice from the direction field
which we will use as the basis to perform the tracking.
"""

# Render the peak directions/values of ``csd_peaks`` (defined outside this
# excerpt) with default coloring.
ren = window.Renderer()
ren.add(actor.peak_slicer(csd_peaks.peak_dirs,
                          csd_peaks.peak_values,
                          colors=None))

# Open an interactive 900x900 window, or save the same view to
# 'csd_direction_field.png' when not running interactively.
if interactive:
    window.show(ren, size=(900, 900))
else:
    window.record(ren, out_path='csd_direction_field.png', size=(900, 900))
Esempio n. 22
0
def track_ensemble(dwi_data,
                   target_samples,
                   atlas_data_wm_gm_int,
                   parcels,
                   mod_fit,
                   tiss_classifier,
                   sphere,
                   directget,
                   curv_thr_list,
                   step_list,
                   track_type,
                   maxcrossing,
                   max_length,
                   roi_neighborhood_tol,
                   min_length,
                   waymask,
                   n_seeds_per_iter=100,
                   pft_back_tracking_dist=2,
                   pft_front_tracking_dist=1,
                   particle_count=15):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI masks.

    Tracking is repeated over every curvature threshold in ``curv_thr_list``
    and every step size in ``step_list`` ("hyperparameter circuits") until
    ``target_samples`` streamlines have been collected.

    Parameters
    ----------
    dwi_data : array
        4D array of dwi data. Only used by the bootstrapped ('boot')
        direction getter.
    target_samples : int
        Number of streamlines to collect before returning.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface. Seeds are drawn from its nonzero voxels.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a Nifti1Image in T1w-warped native
        diffusion space.
    mod_fit : obj
        Reconstruction passed to the direction getter: the SH-coefficient
        argument of ``from_shcoeff`` for 'prob', 'clos' and 'det', or the
        model argument of ``BootDirectionGetter.from_data`` for 'boot'.
    tiss_classifier : obj
        Tissue classifier (stopping criterion) passed directly to the tracker.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), clos (closest), boot (bootstrapped),
        and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel while tracking.
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False.
    min_length : int
        Minimum fiber length threshold. NOTE(review): it is compared against
        ``len(s)``, the number of points of a streamline, not a length in
        mm — confirm units.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain tractography.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique ensemble combination.
        By default this is set to 100.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        the back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter. By default this is
        set to 15.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not recognised.
    RuntimeWarning
        If no seed could be drawn from the wm-gm interface.
    """
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking
    from dipy.direction import ProbabilisticDirectionGetter, BootDirectionGetter, ClosestPeakDirectionGetter, DeterministicMaximumDirectionGetter

    if waymask:
        waymask_data = nib.load(waymask).get_fdata().astype('bool')

    # Commence Ensemble Tractography
    parcel_vec = list(np.ones(len(parcels)).astype('bool'))
    streamlines = nib.streamlines.array_sequence.ArraySequence()
    target = int(target_samples)
    circuit_ix = 0
    stream_counter = 0
    while stream_counter < target:
        for curv_thr in curv_thr_list:
            # Stop the circuit early once enough streamlines were collected.
            if stream_counter >= target:
                break
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            elif directget == 'boot':
                dg = BootDirectionGetter.from_data(dwi_data,
                                                   mod_fit,
                                                   max_angle=float(curv_thr),
                                                   sphere=sphere)
            elif directget == 'clos':
                dg = ClosestPeakDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            elif directget == 'det':
                dg = DeterministicMaximumDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            else:
                raise ValueError(
                    'ERROR: No valid direction getter(s) specified.')

            for step in step_list:
                if stream_counter >= target:
                    break
                print("%s%s" % ('Step: ', step))

                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(
                    atlas_data_wm_gm_int > 0,
                    seeds_count=n_seeds_per_iter,
                    seed_count_per_voxel=False,
                    affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning(
                        'Warning: No valid seed points found in wm-gm interface...'
                    )

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        maxlen=int(max_length),
                        step_size=float(step),
                        return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        step_size=float(step),
                        maxlen=int(max_length),
                        pft_back_tracking_dist=pft_back_tracking_dist,
                        pft_front_tracking_dist=pft_front_tracking_dist,
                        particle_count=particle_count,
                        return_all=True)
                else:
                    raise ValueError(
                        'ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by roi-intersection characteristics
                roi_proximal_streamlines = Streamlines(
                    select_by_rois(streamline_generator,
                                   affine=np.eye(4),
                                   rois=parcels,
                                   include=parcel_vec,
                                   mode='any',
                                   tol=roi_neighborhood_tol))

                print("%s%s" %
                      ('Qualifying Streamlines by node intersection: ',
                       len(roi_proximal_streamlines)))

                # len(s) is the number of points of streamline s.
                roi_proximal_streamlines = nib.streamlines.array_sequence.ArraySequence(
                    [
                        s for s in roi_proximal_streamlines
                        if len(s) > float(min_length)
                    ])

                print("%s%s" %
                      ('Qualifying Streamlines by minimum length criterion: ',
                       len(roi_proximal_streamlines)))

                if waymask:
                    roi_proximal_streamlines = roi_proximal_streamlines[
                        utils.near_roi(roi_proximal_streamlines,
                                       np.eye(4),
                                       waymask_data,
                                       tol=roi_neighborhood_tol,
                                       mode='any')]
                    print("%s%s" %
                          ('Qualifying Streamlines by waymask proximity: ',
                           len(roi_proximal_streamlines)))

                # Collect streamlines one by one until the target is met.
                # Each streamline counts once; the previous code added
                # len(s) (its number of points), stopping far too early.
                for s in roi_proximal_streamlines:
                    streamlines.append(s)
                    stream_counter += 1
                    if stream_counter >= target:
                        break

                # Cleanup memory
                del seeds, roi_proximal_streamlines, streamline_generator

            del dg

        circuit_ix += 1
        print(
            "%s%s%s%s%s" %
            ('Completed hyperparameter circuit: ', circuit_ix,
             '...\nCumulative Streamline Count: ', Fore.CYAN, stream_counter))
        print(Style.RESET_ALL)

    print('\n')

    return streamlines
    #                           colors=None)
    # slice_actor.RotateX(90)

    # ren.add(slice_actor)
    # if interactive:
    #     window.show(ren, size=(900, 900))
    # else:
    #     ren.set_camera(position=[0,-1,0], focal_point=[0,0,0], view_up=[0,0,1])
    #     window.record(ren, out_path='csd_direction_bm.png', size=(900, 900))

    print 'seeding begins, using np.random.seed(123)'
    st3 = time.time()
    np.random.seed(123)
    #seeds = utils.random_seeds_from_mask(mask, 1) #does crash because of memory limitations
    seeds = utils.random_seeds_from_mask(mask,
                                         20000,
                                         seed_count_per_voxel=False)
    for i in range(len(seeds)):
        if seeds[i][0] > 199.:
            seeds[i][0] = 398 - seeds[i][0]
        if seeds[i][1] > 399.:
            seeds[i][1] = 798 - seeds[i][1]
        if seeds[i][2] > 199.:
            seeds[i][2] = 398 - seeds[i][2]
        for j in range(3):
            if seeds[i][j] < 0.:
                seeds[i][j] = -seeds[i][j]
    et3 = time.time() - st3
    print 'seeding transformation finished, the total seeds are {}, running time is {}'.format(
        seeds.shape[0], et3)
Esempio n. 24
0
def run_tracking(step_curv_combinations,
                 recon_path,
                 n_seeds_per_iter,
                 directget,
                 maxcrossing,
                 max_length,
                 pft_back_tracking_dist,
                 pft_front_tracking_dist,
                 particle_count,
                 roi_neighborhood_tol,
                 waymask,
                 min_length,
                 track_type,
                 min_separation_angle,
                 sphere,
                 tiss_class,
                 tissues4d,
                 cache_dir,
                 min_seeds=100):
    """Run one tractography pass for a single (step, curvature) combination.

    Private copies of the reconstruction file, the 4D tissue image and the
    optional waymask are made in ``cache_dir`` under a run-unique suffix.
    These copies are removed before the function returns, on success, on
    every early ``return None`` and when an exception propagates.

    Parameters
    ----------
    step_curv_combinations : sequence
        Index 0 is the tracking step size; index 1 is the direction
        getter's ``max_angle``.
    recon_path : str
        HDF5 file whose ``'reconstruction'`` dataset is passed (as float32)
        to the direction getter's ``from_shcoeff``.
    n_seeds_per_iter : int
        Number of random seeds drawn from the seeding mask.
    directget : str
        ``'probabilistic'``/``'prob'``, ``'closestpeaks'``/``'cp'`` or
        ``'deterministic'``/``'det'`` (case-insensitive). ``'det'`` forces
        ``maxcrossing`` to 1.
    maxcrossing, max_length : int-like
        Forwarded as ``max_cross`` / ``maxlen`` to the tracker.
    pft_back_tracking_dist, pft_front_tracking_dist, particle_count
        Forwarded to ``ParticleFilteringTracking`` when
        ``track_type == 'particle'``.
    roi_neighborhood_tol : float
        ``tol`` for parcel selection; half of it (rounded, cast to int) is
        used for the waymask proximity test.
    waymask : str or None
        Optional mask image; streamlines must be near it at all points.
    min_length : float-like
        Streamlines with fewer points than this are dropped.
    track_type : str
        ``'local'`` or ``'particle'``.
    min_separation_angle, sphere
        Forwarded to the direction getter.
    tiss_class : str
        Tissue-classifier type forwarded to ``prep_tissues``.
    tissues4d : str
        4D image whose volumes are, in order: B0 mask, atlas, seeding mask,
        T1w in DWI space, GM, ventricle CSF, WM.
    cache_dir : str
        Directory that receives the private temporary copies.
    min_seeds : int, optional
        Minimum number of seeds required to attempt tracking.

    Returns
    -------
    ArraySequence or None
        Surviving streamlines cast to float32. Returns ``None`` when fewer
        than ``min_seeds`` seeds are found or when a filtering stage fails
        or leaves no streamlines.

    Raises
    ------
    ValueError
        For an unknown ``directget`` or ``track_type``.
    """
    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix
    import uuid
    from time import strftime

    run_uuid = f"{strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4()}"
    # One run-unique suffix shared by every private copy made below, so
    # concurrent runs never touch each other's files.
    tmp_suffix = (f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
                  f"{run_uuid}")

    # Every private copy is registered here and deleted in ``finally``.
    tmp_files = []

    def _stage_copy(path):
        # Copy *path* into cache_dir under the run-unique name. The
        # destination is registered before copying so that a partially
        # written copy is cleaned up too.
        tmp_path = fname_presuffix(path, suffix=tmp_suffix, newpath=cache_dir)
        tmp_files.append(tmp_path)
        copyfile(path, tmp_path, copy=True, use_hardlink=False)
        return tmp_path

    try:
        recon_path_tmp_path = _stage_copy(recon_path)
        tissues4d_tmp_path = _stage_copy(tissues4d)
        waymask_tmp_path = (_stage_copy(waymask) if waymask is not None
                            else None)

        tissue_img = nib.load(tissues4d_tmp_path)

        # Volume order inside the 4D tissue stack:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        seeding_mask = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")

        seeding_mask = np.asarray(
            seeding_mask.dataobj).astype("bool").astype("int16")

        # Read the coefficients fully into memory. The context manager
        # closes the file, and read-only access is all that is needed.
        with h5py.File(recon_path_tmp_path, 'r') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        max_angle = float(step_curv_combinations[1])
        if directget.lower() in ["probabilistic", "prob"]:
            dg = ProbabilisticDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=max_angle,
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget.lower() in ["closestpeaks", "cp"]:
            dg = ClosestPeakDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=max_angle,
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget.lower() in ["deterministic", "det"]:
            maxcrossing = 1
            dg = DeterministicMaximumDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=max_angle,
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        else:
            raise ValueError("ERROR: No valid direction getter(s) specified.")

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            seeding_mask > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) < min_seeds:
            print(
                UserWarning(
                    f"<{min_seeds} valid seed points found in wm-gm "
                    f"interface..."))
            return None

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
                random_seed=42)
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                pft_max_trial=20,
                particle_count=particle_count,
                return_all=True,
                random_seed=42)
        else:
            raise ValueError("ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        try:
            roi_proximal_streamlines = utils.target(
                streamline_generator,
                np.eye(4),
                B0_mask_data.astype('bool'),
                include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        # Drop large intermediates before building the parcel masks.
        del mod_fit, seeds, tiss_classifier, streamline_generator, \
            B0_mask_data, seeding_mask, dg

        for img in (B0_mask, atlas_img, t1w2dwi, gm_in_dwi,
                    vent_csf_in_dwi, wm_in_dwi, tissue_img):
            img.uncache()
        gc.collect()

        # Filter resulting streamlines by roi-intersection
        # characteristics
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")

        # One boolean mask per non-zero atlas label, all of them included.
        intensities = [i for i in np.unique(atlas_data) if i != 0]
        parcels = [atlas_data == roi_val for roi_val in intensities]
        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    select_by_rois(
                        roi_proximal_streamlines,
                        affine=np.eye(4),
                        rois=parcels,
                        include=parcel_vec,
                        mode="any",
                        tol=roi_neighborhood_tol,
                    )
                )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = nib.streamlines. \
                array_sequence.ArraySequence(
                    [
                        s for s in roi_proximal_streamlines
                        if len(s) >= float(min_length)
                    ]
                )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            waymask_data = np.asarray(
                nib.load(waymask_tmp_path).dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                        mode="all")]
                print("%s%s" %
                      ("Waymask proximity: ", len(roi_proximal_streamlines)))
                del waymask_data
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        del parcels, atlas_data

        if len(roi_proximal_streamlines) > 0:
            return ArraySequence(
                [s.astype("float32") for s in roi_proximal_streamlines])
        return None
    finally:
        # Best-effort cleanup. A copy that is already gone or cannot be
        # removed must not mask the result or an in-flight exception.
        for tmp_path in tmp_files:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
Esempio n. 25
0
def dwi_dipy_run(dwi_dir,
                 node_size,
                 dir_path,
                 conn_model,
                 parc,
                 atlas_select,
                 network,
                 wm_mask=None):
    """Track a diffusion dataset with DIPY and return a connectivity matrix.

    NOTE(review): development scaffold. ``dwi_dir`` is overwritten below
    with a hard-coded local path, and the tissue maps and atlas are loaded
    from hard-coded paths. ``dir_path``, ``atlas_select``, ``network`` and
    ``wm_mask`` are not used, and ``node_size`` is only reassigned.

    Side effect: writes ``prob_streamlines.trk`` to the current working
    directory.

    Parameters
    ----------
    conn_model : str
        ``'csd'`` (LocalTracking with an ACT tissue classifier) or
        ``'tensor'`` (EuDX on the tensor's quantized eigenvectors).
    parc : bool
        When True, ``node_size`` is set to ``'parc'``.

    Returns
    -------
    conn_matrix : ndarray
        Result of ``utils.connectivity_matrix(..., symmetric=True)`` over
        streamlines with more than one point, with its first three rows and
        columns zeroed.

    Raises
    ------
    RuntimeError
        If ``conn_model`` is neither ``'csd'`` nor ``'tensor'``.
    ValueError
        If the hard-coded ``tracking_method`` is not a supported option.
    """
    from dipy.reconst.dti import TensorModel, quantize_evecs
    from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel, recursive_response
    from dipy.tracking.local import LocalTracking, ActTissueClassifier
    from dipy.tracking import utils
    from dipy.direction import peaks_from_model
    from dipy.tracking.eudx import EuDX
    from dipy.data import get_sphere, default_sphere
    from dipy.core.gradients import gradient_table
    from dipy.io import read_bvals_bvecs
    from dipy.tracking.streamline import Streamlines
    from dipy.direction import ProbabilisticDirectionGetter, ClosestPeakDirectionGetter, BootDirectionGetter
    from nibabel.streamlines import save as save_trk
    from nibabel.streamlines import Tractogram

    ## Hard-coded development inputs (NOTE(review): override the arguments).
    dwi_dir = '/Users/PSYC-dap3463/Downloads/bedpostx_s002'
    img_pve_csf = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_vent_csf_diff_dwi.nii.gz'
    )
    img_pve_wm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_wm_in_dwi_bin.nii.gz'
    )
    img_pve_gm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_gm_mask_dwi.nii.gz'
    )
    labels_img = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/dwi_aligned_atlas.nii.gz'
    )
    num_total_samples = 10000
    tracking_method = 'boot'  # Options are 'boot', 'prob', 'peaks', 'closest'
    procmem = [2, 4]
    ##

    if parc is True:
        node_size = 'parc'

    dwi_img = "%s%s" % (dwi_dir, '/dwi.nii.gz')
    nodif_brain_mask_path = "%s%s" % (dwi_dir, '/nodif_brain_mask.nii.gz')
    bvals = "%s%s" % (dwi_dir, '/bval')
    bvecs = "%s%s" % (dwi_dir, '/bvec')

    dwi_img = nib.load(dwi_img)
    data = dwi_img.get_data()
    [bvals, bvecs] = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = min(bvals)
    sphere = get_sphere('symmetric724')

    # Loads mask and ensures it's a true binary mask
    mask_img = nib.load(nodif_brain_mask_path)
    mask = mask_img.get_data()
    mask = mask > 0

    # Fit a basic tensor model first
    model = TensorModel(gtab)
    ten = model.fit(data, mask)
    fa = ten.fa

    # Tractography
    if conn_model == 'csd':
        print('Tracking with csd model...')
    elif conn_model == 'tensor':
        print('Tracking with tensor model...')
    else:
        raise RuntimeError("%s%s" % (conn_model, ' is not a valid model.'))

    # Combine seed counts from voxel with seed counts total.
    # NOTE: get_data() returns the image's cached array, so zeroing the
    # borders here also affects later img_pve_wm.get_data() calls.
    wm_mask_data = img_pve_wm.get_data()
    wm_mask_data[0, :, :] = False
    wm_mask_data[:, 0, :] = False
    wm_mask_data[:, :, 0] = False
    seeds = utils.seeds_from_mask(wm_mask_data,
                                  density=1,
                                  affine=dwi_img.affine)
    # NOTE(review): no affine is passed here while the seeds above use
    # dwi_img.affine, and seeds_count is per voxel; confirm both are intended.
    seeds_rnd = utils.random_seeds_from_mask(ten.fa > 0.02,
                                             seeds_count=num_total_samples,
                                             seed_count_per_voxel=True)
    seeds_all = np.vstack([seeds, seeds_rnd])

    # Load tissue maps and prepare tissue classifier (Anatomically-Constrained Tractography (ACT))
    background = np.ones(img_pve_gm.shape)
    background[(img_pve_gm.get_data() + img_pve_wm.get_data() +
                img_pve_csf.get_data()) > 0] = 0
    include_map = img_pve_gm.get_data()
    include_map[background > 0] = 1
    exclude_map = img_pve_csf.get_data()
    act_classifier = ActTissueClassifier(include_map, exclude_map)

    if conn_model == 'tensor':
        ind = quantize_evecs(ten.evecs, sphere.vertices)
        streamline_generator = EuDX(a=fa,
                                    ind=ind,
                                    seeds=seeds_all,
                                    odf_vertices=sphere.vertices,
                                    a_low=0.05,
                                    step_sz=.5)
    elif conn_model == 'csd':
        print('Tracking with CSD model...')
        response = recursive_response(
            gtab,
            data,
            mask=img_pve_wm.get_data().astype('bool'),
            sh_order=8,
            peak_thr=0.01,
            init_fa=0.05,
            init_trace=0.0021,
            iter=8,
            convergence=0.001,
            parallel=True)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response)
        if tracking_method == 'boot':
            dg = BootDirectionGetter.from_data(data,
                                               csd_model,
                                               max_angle=30.,
                                               sphere=default_sphere)
        elif tracking_method == 'prob':
            try:
                print(
                    'First attempting to build the direction getter directly from the spherical harmonic representation of the FOD...'
                )
                csd_fit = csd_model.fit(
                    data, mask=img_pve_wm.get_data().astype('bool'))
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
            except Exception:
                print(
                    'Sphereical harmonic not available for this model. Using peaks_from_model to represent the ODF of the model on a spherical harmonic basis instead...'
                )
                peaks = peaks_from_model(
                    csd_model,
                    data,
                    default_sphere,
                    .5,
                    25,
                    mask=img_pve_wm.get_data().astype('bool'),
                    return_sh=True,
                    parallel=True,
                    nbr_processes=procmem[0])
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    peaks.shm_coeff, max_angle=30., sphere=default_sphere)
        elif tracking_method == 'peaks':
            dg = peaks_from_model(model=csd_model,
                                  data=data,
                                  sphere=default_sphere,
                                  relative_peak_threshold=.5,
                                  min_separation_angle=25,
                                  mask=img_pve_wm.get_data().astype('bool'),
                                  parallel=True,
                                  nbr_processes=procmem[0])
        elif tracking_method == 'closest':
            csd_fit = csd_model.fit(data,
                                    mask=img_pve_wm.get_data().astype('bool'))
            pmf = csd_fit.odf(default_sphere).clip(min=0)
            dg = ClosestPeakDirectionGetter.from_pmf(pmf,
                                                     max_angle=30.,
                                                     sphere=default_sphere)
        else:
            raise ValueError("%s%s" % (tracking_method,
                                       ' is not a valid tracking method.'))
        streamline_generator = LocalTracking(dg,
                                             act_classifier,
                                             seeds_all,
                                             affine=dwi_img.affine,
                                             step_size=0.5)
        # Release the model objects. csd_fit only exists for the 'prob'
        # and 'closest' paths.
        del dg
        try:
            del csd_fit
        except NameError:
            pass
        del response, csd_model

    # Materialize the streamlines for either model.
    streamlines = Streamlines(streamline_generator, buffer_size=512)

    save_trk(Tractogram(streamlines, affine_to_rasmm=dwi_img.affine),
             'prob_streamlines.trk')
    tracks = [sl for sl in streamlines if len(sl) > 1]
    labels_data = labels_img.get_data().astype('int')
    labels_affine = labels_img.affine
    conn_matrix, grouping = utils.connectivity_matrix(
        tracks,
        labels_data,
        affine=labels_affine,
        return_mapping=True,
        mapping_as_streamlines=True,
        symmetric=True)
    conn_matrix[:3, :] = 0
    conn_matrix[:, :3] = 0

    return conn_matrix
Esempio n. 26
0
def _save_thresholded_density_volumes(streamlines, affine, ref_img,
                                      volume_folder, zip_file_path):
    """Write thresholded log-density masks of *streamlines* and zip them.

    Computes the visitation count map of the streamlines on ``ref_img``'s
    grid and smooths it with a Gaussian (sigma 0.5). It then applies
    ``log10(1 + x)`` and thresholds the result at evenly spaced levels
    ``np.arange(0, max, max / 200)``. Each binary mask is saved as an int32
    NIfTI into the existing directory *volume_folder*, which is then
    archived. The last four characters of *zip_file_path* (``.zip``) are
    stripped to form the archive base name.
    """
    density = utils.density_map(streamlines, affine, ref_img.shape)
    density = scipy.ndimage.gaussian_filter(density.astype("float32"), 0.5)

    log_density = np.log10(density + 1)
    max_density = np.max(log_density)
    # np.arange with a zero step raises, so an all-zero map produces no
    # volumes (and an empty archive) instead of crashing.
    if max_density > 0:
        thresholds = np.arange(0, max_density, max_density / 200)
    else:
        thresholds = []
    for i, t in enumerate(thresholds):
        nbr = str(i).zfill(3)
        mask = log_density >= t
        vol_filename = os.path.join(volume_folder,
                                    "vol" + nbr + "_t" + str(t) + ".nii.gz")
        nib.Nifti1Image(mask.astype("int32"), affine,
                        ref_img.header).to_filename(vol_filename)
    shutil.make_archive(zip_file_path[:-4], 'zip', volume_folder)


def run(context):
    """IronTract entry point: fetch inputs, track, post-process and upload.

    The ``settings`` dictionary of ``context.fetch_analysis_data()`` selects
    the ``dataset`` (``'HCPL'`` uses CSD, ``'DSI'`` uses DSI ODFs) and the
    ``postprocessing`` (``'EPFL'``, ``'VUMC'`` or ``'ALL'``). For an unknown
    dataset the progress message is set and the function returns without
    doing any further work.
    """

    ####################################################
    # Get the path to input files  and other parameter #
    ####################################################
    analysis_data = context.fetch_analysis_data()
    settings = analysis_data['settings']
    postprocessing = settings['postprocessing']
    dataset = settings['dataset']

    if dataset == "HCPL":
        dwi_file_handle = context.get_files('input', modality='HARDI')[0]
        dwi_file_path = dwi_file_handle.download('/root/')

        bvalues_file_handle = context.get_files(
            'input', reg_expression='.*prep.bvalues.hcpl.txt')[0]
        bvalues_file_path = bvalues_file_handle.download('/root/')
        bvecs_file_handle = context.get_files(
            'input', reg_expression='.*prep.gradients.hcpl.txt')[0]
        bvecs_file_path = bvecs_file_handle.download('/root/')
    elif dataset == "DSI":
        dwi_file_handle = context.get_files('input', modality='DSI')[0]
        dwi_file_path = dwi_file_handle.download('/root/')
        bvalues_file_handle = context.get_files(
            'input', reg_expression='.*prep.bvalues.txt')[0]
        bvalues_file_path = bvalues_file_handle.download('/root/')
        bvecs_file_handle = context.get_files(
            'input', reg_expression='.*prep.gradients.txt')[0]
        bvecs_file_path = bvecs_file_handle.download('/root/')
    else:
        context.set_progress(message='Wrong dataset parameter')
        # Nothing was downloaded, so stop here rather than fail later on
        # undefined input paths.
        return

    inject_file_handle = context.get_files(
        'input', reg_expression='.*prep.inject.nii.gz')[0]
    inject_file_path = inject_file_handle.download('/root/')

    VUMC_ROIs_file_handle = context.get_files(
        'input', reg_expression='.*VUMC_ROIs.nii.gz')[0]
    VUMC_ROIs_file_path = VUMC_ROIs_file_handle.download('/root/')

    ###############################
    # _____ _____ _______     __  #
    # |  __ \_   _|  __ \ \   / / #
    # | |  | || | | |__) \ \_/ /  #
    # | |  | || | |  ___/ \   /   #
    # | |__| || |_| |      | |    #
    # |_____/_____|_|      |_|    #
    #                             #
    ###############################

    ########################################################################################
    #  _______             _          __  __   _______             _     __                #
    # |__   __|           | |        |  \/  | |__   __|           | |   / _|               #
    #    | |_ __ __ _  ___| | ___   _| \  / | ___| |_ __ __ _  ___| | _| |_ __ _  ___ ___  #
    #    | | '__/ _` |/ __| |/ / | | | |\/| |/ __| | '__/ _` |/ __| |/ /  _/ _` |/ __/ _ \ #
    #    | | | | (_| | (__|   <| |_| | |  | | (__| | | | (_| | (__|   <| || (_| | (_|  __/ #
    #    |_|_|  \__,_|\___|_|\_\\__, |_|  |_|\___|_|_|  \__,_|\___|_|\_\_| \__,_|\___\___| #
    #                            __/ |                                                     #
    #                           |___/                                                      #
    #                                                                                      #
    #                                                                                      #
    #                               IronTract Team                                         #
    ########################################################################################

    #################
    # Load the data #
    #################
    dwi_img = nib.load(dwi_file_path)
    bvals, bvecs = read_bvals_bvecs(bvalues_file_path,
                                    bvecs_file_path)
    gtab = gradient_table(bvals, bvecs)

    ############################################
    # Extract the brain mask from the b0 image #
    ############################################
    _, brain_mask = median_otsu(dwi_img.get_data()[:, :, :, 0],
                                median_radius=2, numpass=1)

    ##################################################################
    # Fit the tensor model and compute the fractional anisotropy map #
    ##################################################################
    context.set_progress(message='Processing voxel-wise DTI metrics.')
    tenmodel = TensorModel(gtab)
    tenfit = tenmodel.fit(dwi_img.get_data(), mask=brain_mask)
    FA = fractional_anisotropy(tenfit.evals)
    stopping_criterion = ThresholdStoppingCriterion(FA, 0.2)

    sphere = get_sphere("repulsion724")
    seed_mask_img = nib.load(inject_file_path)
    affine = seed_mask_img.affine
    seeds = utils.random_seeds_from_mask(seed_mask_img.get_data(),
                                         affine,
                                         seed_count_per_voxel=True,
                                         seeds_count=5000)

    if dataset == "HCPL":
        ################################################
        # Compute Fiber Orientation Distribution (CSD) #
        ################################################
        context.set_progress(message='Processing voxel-wise FOD estimation.')

        response, _ = auto_response_ssst(gtab, dwi_img.get_data(),
                                         roi_radii=10, fa_thr=0.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=8)
        csd_fit = csd_model.fit(dwi_img.get_data(), mask=brain_mask)
        shm = csd_fit.shm_coeff

        prob_dg = ProbabilisticDirectionGetter.from_shcoeff(shm,
                                                            max_angle=20.,
                                                            sphere=sphere,
                                                            pmf_threshold=0.1)
    elif dataset == "DSI":
        context.set_progress(message='Processing voxel-wise DSI estimation.')
        dsmodel = DiffusionSpectrumModel(gtab)
        dsfit = dsmodel.fit(dwi_img.get_data())
        ODFs = dsfit.odf(sphere)
        prob_dg = ProbabilisticDirectionGetter.from_pmf(ODFs,
                                                        max_angle=20.,
                                                        sphere=sphere,
                                                        pmf_threshold=0.01)

    ###########################################
    # Compute DIPY Probabilistic Tractography #
    ###########################################
    context.set_progress(message='Processing tractography.')
    streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds,
                                         affine, step_size=.2, max_cross=1)
    streamlines = Streamlines(streamline_generator)

    ###########################################################################
    # Compute 3D volumes for the IronTract Challenge. For 'EPFL', we only     #
    # keep streamlines with length > 1mm. We compute the visitation count     #
    # image and apply a small gaussian smoothing. The gaussian smoothing      #
    # is especially useful to increase voxel coverage of deterministic        #
    # algorithms. The log of the smoothed visitation count map is then        #
    # iteratively thresholded producing 200 volumes/operation points.         #
    # For VUMC, additional streamline filtering is done using anatomical      #
    # priors (keeping only streamlines that intersect with at least one ROI). #
    ###########################################################################
    if postprocessing in ["EPFL", "ALL"]:
        context.set_progress(message='Processing density map (EPFL)')
        volume_folder = "/root/vol_epfl"
        output_epfl_zip_file_path = "/root/TrackyMcTrackface_EPFL_example.zip"
        os.mkdir(volume_folder)
        lengths = length(streamlines)
        streamlines = streamlines[lengths > 1]
        _save_thresholded_density_volumes(streamlines, affine, seed_mask_img,
                                          volume_folder,
                                          output_epfl_zip_file_path)

    if postprocessing in ["VUMC", "ALL"]:
        context.set_progress(message='Processing density map (VUMC)')
        ROIs_img = nib.load(VUMC_ROIs_file_path)
        volume_folder = "/root/vol_vumc"
        output_vumc_zip_file_path = "/root/TrackyMcTrackface_VUMC_example.zip"
        os.mkdir(volume_folder)
        lengths = length(streamlines)
        streamlines = streamlines[lengths > 1]

        # Keep the streamlines assigned to the (0, 1) label pair.
        rois = ROIs_img.get_fdata().astype(int)
        _, grouping = utils.connectivity_matrix(streamlines, affine, rois,
                                                inclusive=True,
                                                return_mapping=True,
                                                mapping_as_streamlines=False)
        streamlines = streamlines[grouping[(0, 1)]]

        _save_thresholded_density_volumes(streamlines, affine, seed_mask_img,
                                          volume_folder,
                                          output_vumc_zip_file_path)

    ###################
    # Upload the data #
    ###################
    context.set_progress(message='Uploading results...')
    if postprocessing in ["EPFL", "ALL"]:
        context.upload_file(output_epfl_zip_file_path,
                            'TrackyMcTrackface_' + dataset + '_EPFL.zip')
    if postprocessing in ["VUMC", "ALL"]:
        context.upload_file(output_vumc_zip_file_path,
                            'TrackyMcTrackface_' + dataset + '_VUMC.zip')
Esempio n. 27
0
def track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, parcel_vec, mod_fit,
                   tiss_classifier, sphere, directget, curv_thr_list, step_list, track_type, maxcrossing, max_length,
                   n_seeds_per_iter=200):
    """Ensemble tractography cycling over curvature and step combinations.

    Repeats full "circuits" over ``curv_thr_list`` x ``step_list``. For each
    combination it draws ``n_seeds_per_iter`` random seeds from the non-zero
    voxels of ``atlas_data_wm_gm_int`` and tracks with the chosen direction
    getter and tracker. It keeps only streamlines that pass near any parcel
    in ``parcels`` (``mode='any'``, ``tol=8``) and appends them to the
    result. Work stops as soon as the cumulative counter reaches
    ``target_samples``, even part-way through a circuit.

    Parameters
    ----------
    target_samples : int-like
        Stopping threshold for the counter. NOTE: the counter accumulates
        ``len(s)`` of each kept streamline, not the number of streamlines.
    atlas_data_wm_gm_int : ndarray
        Seeding region; voxels ``> 0`` are eligible seeds.
    parcels, parcel_vec
        ROI masks and their include flags for ``select_by_rois``.
        ``parcel_vec`` must support ``.astype('bool')``.
    mod_fit
        Coefficients passed to the direction getter's ``from_shcoeff``.
    tiss_classifier, sphere
        Forwarded to the tracker and the direction getter respectively.
    directget : str
        ``'prob'``, ``'boot'``, ``'closest'`` or ``'det'``.
        NOTE(review): confirm that ``BootDirectionGetter`` provides
        ``from_shcoeff`` in the installed DIPY version.
    curv_thr_list, step_list : iterables
        Maximum angles and step sizes to cycle through.
    track_type : str
        ``'local'`` or ``'particle'``.
    maxcrossing, max_length : int-like
        Forwarded as ``max_cross`` / ``maxlen``.
    n_seeds_per_iter : int, optional
        Seeds drawn per (curvature, step) combination.

    Returns
    -------
    ArraySequence
        The accumulated streamlines.

    Raises
    ------
    ValueError
        For an unknown ``directget`` or ``track_type``.
    RuntimeWarning
        If no seed points are found in the seeding region.
    """
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local import LocalTracking, ParticleFilteringTracking
    from dipy.direction import ProbabilisticDirectionGetter, BootDirectionGetter, ClosestPeakDirectionGetter, DeterministicMaximumDirectionGetter

    target = int(target_samples)

    # Commence Ensemble Tractography
    streamlines = nib.streamlines.array_sequence.ArraySequence()
    circuit_ix = 0
    stream_counter = 0
    while int(stream_counter) < target:
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg = ProbabilisticDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr),
                                                               sphere=sphere)
            elif directget == 'boot':
                dg = BootDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr),
                                                      sphere=sphere)
            elif directget == 'closest':
                dg = ClosestPeakDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr),
                                                             sphere=sphere)
            elif directget == 'det':
                dg = DeterministicMaximumDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr),
                                                                      sphere=sphere)
            else:
                raise ValueError('ERROR: No valid direction getter(s) specified.')

            for step in step_list:
                print("%s%s" % ('Step: ', step))
                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(atlas_data_wm_gm_int > 0, seeds_count=n_seeds_per_iter,
                                                     seed_count_per_voxel=False, affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning('Warning: No valid seed points found in wm-gm interface...')

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                         max_cross=int(maxcrossing), maxlen=int(max_length),
                                                         step_size=float(step), return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                                     max_cross=int(maxcrossing),
                                                                     step_size=float(step),
                                                                     maxlen=int(max_length),
                                                                     pft_back_tracking_dist=2,
                                                                     pft_front_tracking_dist=1,
                                                                     particle_count=15, return_all=True)
                else:
                    raise ValueError('ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by roi-intersection characteristics
                streamlines_more = Streamlines(select_by_rois(streamline_generator, parcels, parcel_vec.astype('bool'),
                                                              mode='any', affine=np.eye(4), tol=8))

                # Accumulate until the target samples condition is met
                for s in streamlines_more:
                    stream_counter = stream_counter + len(s)
                    streamlines.append(s)
                    if int(stream_counter) >= target:
                        break

                # Stop tracking the remaining steps once the target is met;
                # otherwise each remaining combination would append one more
                # streamline and waste a full tracking run.
                if int(stream_counter) >= target:
                    break
            if int(stream_counter) >= target:
                break

        circuit_ix = circuit_ix + 1
        print("%s%s%s%s%s" % ('Completed hyperparameter circuit: ', circuit_ix, '...\nCumulative Streamline Count: ',
                              Fore.CYAN, stream_counter))
        print(Style.RESET_ALL)

    print('\n')
    return streamlines
Esempio n. 28
0
def tracking(folder):
    """Run whole-brain CSD peak-based tracking for one subject directory.

    Reads ``data.nii.gz``, ``bvals`` and ``bvecs`` from *folder* and writes to
    ``<folder>/dipy_out/``:

    * ``peaks.pam5`` -- CSD peaks; reused on later runs if it already exists.
    * ``peaks.nii.gz`` -- first three peak directions scaled by their peak values,
      flattened to 9 channels (only written when peaks are (re)computed).
    * ``whole_brain.trk`` -- the tracked streamlines, saved in RAS+ mm space.

    Parameters
    ----------
    folder : str or os.PathLike
        Subject directory. A trailing path separator is optional.

    Returns
    -------
    None
        Returns early (after printing a message) if the output directory cannot
        be created.
    """
    import os

    print('Tracking in ' + os.fspath(folder))
    # Normalise to a string ending in a separator so the concatenations below work
    # whether or not the caller supplied a trailing slash.
    base = os.path.join(os.fspath(folder), '')
    output_folder = base + 'dipy_out/'

    # make a folder to save new data into
    try:
        Path(output_folder).mkdir(parents=True, exist_ok=True)
    except OSError:
        print('Could not create output dir. Aborting...')
        return

    # load data
    print('Loading data...')
    img = nib.load(base + 'data.nii.gz')
    dmri = np.asarray(img.dataobj)
    affine = img.affine
    bvals, bvecs = read_bvals_bvecs(base + 'bvals', base + 'bvecs')
    gtab = gradient_table(bvals, bvecs)

    # Extract CSD peaks, or reuse the ones saved by a previous run.
    peaks_path = output_folder + 'peaks.pam5'
    if Path(peaks_path).exists():
        peaks = load_peaks(peaks_path)
    else:
        print('Extracting peaks...')
        response, _ = auto_response(gtab, dmri, roi_radius=10, fa_thr=.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response)

        peaks = peaks_from_model(model=csd_model,
                                 data=dmri,
                                 sphere=default_sphere,
                                 relative_peak_threshold=.5,
                                 min_separation_angle=25,
                                 parallel=True)

        save_peaks(peaks_path, peaks, affine)

        # Scale each peak direction by its peak value (broadcast over the xyz axis),
        # keep the first three peaks and flatten (3 peaks x 3 components) to 9 channels.
        scaled = peaks.peak_dirs * np.repeat(
            np.expand_dims(peaks.peak_values, -1), 3, -1)
        cropped = scaled[:, :, :, :3, :].reshape(dmri.shape[:3] + (9, ))
        save_nifti(output_folder + 'peaks.nii.gz', cropped, affine)

    # tracking
    print('Tracking...')
    # autocrop must stay False: a cropped volume would no longer share the voxel
    # grid of `peaks` (computed on the full `dmri`) or of `affine`. The FA-based
    # stopping criterion and the seeds would then be misaligned with the
    # direction getter.
    maskdata, _ = median_otsu(dmri,
                              vol_idx=range(0, dmri.shape[3]),
                              median_radius=3,
                              numpass=1,
                              autocrop=False,
                              dilate=2)
    tensor_model = TensorModel(gtab, fit_method='WLS')
    tensor_fit = tensor_model.fit(maskdata)
    fa = fractional_anisotropy(tensor_fit.evals)
    fa[np.isnan(fa)] = 0  # guard against NaN FA (e.g. from zeroed, masked-out voxels)

    # Stop tracking where FA < 0.1; seed once in every voxel with non-zero FA.
    tissue_classifier = ThresholdStoppingCriterion(fa, .1)
    seeds = random_seeds_from_mask(fa > 1e-5, affine, seeds_count=1)

    streamline_generator = LocalTracking(direction_getter=peaks,
                                         stopping_criterion=tissue_classifier,
                                         seeds=seeds,
                                         affine=affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)
    save_trk(StatefulTractogram(streamlines, img, Space.RASMM),
             output_folder + 'whole_brain.trk')
Esempio n. 29
0
                             mask=mask,
                             relative_peak_threshold=.5,
                             min_separation_angle=25,
                             parallel=True)

# Visualise the CSD peaks with fury's peak_slicer. `slider` is a helper defined
# elsewhere in this script (presumably an interactive slice viewer -- confirm).
peak_actor = actor.peak_slicer(csd_peaks.peak_dirs,
                               csd_peaks.peak_values,
                               colors=None)
slider(peak_actor, None)

# Generating streamlines

# Stop tracking wherever the tensor FA falls below 0.1.
tissue_classifier = ThresholdTissueClassifier(tenfit.fa, 0.1)

# Seed inside voxels with FA > 0.3. No affine is passed, so the seeds are in voxel
# coordinates; seeds_count=5 is presumably per voxel (default
# seed_count_per_voxel=True in this dipy version -- confirm).
seeds = random_seeds_from_mask(tenfit.fa > 0.3, seeds_count=5)

# Track in voxel space (identity affine, matching the voxel-space seeds).
# return_all=True is a dipy option; presumably it keeps streamlines regardless of
# how tracking stopped -- confirm against the installed dipy docs.
streamline_generator = LocalTracking(csd_peaks,
                                     tissue_classifier,
                                     seeds,
                                     affine=np.eye(4),
                                     step_size=0.5,
                                     return_all=True)

# Materialise the lazy generator so the streamlines can be shown and then saved.
streamlines = Streamlines(streamline_generator)

# `show_streamlines` is a helper defined elsewhere in this script.
show_streamlines(streamlines)

save_trk_n('streamlines.trk',
           streamlines,
           affine=affine,
Esempio n. 30
0
# FA map from a tensor fit performed earlier in the script (tensor_fit is defined
# outside this view).
fa = tensor_fit.fa

"""
In this simple example we can use FA to stop tracking. Here we stop tracking
when FA < 0.1.
"""

tissue_classifier = ThresholdTissueClassifier(fa, 0.1)

"""
Now, we need to set starting points for propagating each track. We call those
seeds. Using ``random_seeds_from_mask`` we can select a specific number of
seeds (``seeds_count``) in each voxel where the mask ``fa > 0.3`` is true.
"""

# One seed per voxel of the FA > 0.3 mask; no affine, so voxel coordinates.
seeds = random_seeds_from_mask(fa > 0.3, seeds_count=1)

"""
For quality assurance we can also visualize a slice from the direction field
which we will use as the basis to perform the tracking.
"""

# Render the CSD peak directions (csd_peaks is defined earlier in the script).
ren = window.Renderer()
ren.add(actor.peak_slicer(csd_peaks.peak_dirs,
                          csd_peaks.peak_values,
                          colors=None))

# Open an interactive window when `interactive` is set; otherwise write a 900x900
# snapshot to csd_direction_field.png.
if interactive:
    window.show(ren, size=(900, 900))
else:
    window.record(ren, out_path='csd_direction_field.png', size=(900, 900))
Esempio n. 31
0
def track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, mod_fit, tiss_classifier, sphere, directget,
                   curv_thr_list, step_list, track_type, maxcrossing, roi_neighborhood_tol, min_length, waymask,
                   B0_mask, max_length=1000, n_seeds_per_iter=500, pft_back_tracking_dist=2, pft_front_tracking_dist=1,
                   particle_count=15, min_separation_angle=20):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI masks.

    Every combination of curvature threshold (``curv_thr_list``) and step size
    (``step_list``) is tracked in turn. One full pass over all combinations is a
    "hyperparameter circuit", and circuits repeat until at least ``target_samples``
    streamlines have survived filtering. Seeding, tracking and filtering all use an
    identity affine, i.e. everything happens in voxel coordinates.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples to generate. Whole filtered batches are
        appended before the count is checked, so the result may exceed this number.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface. Voxels with a value > 0 form the seeding mask.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a Nifti1Image in T1w-warped native
        diffusion space.
    mod_fit : obj
        Fitted reconstruction model coefficients, passed to the direction getter's
        ``from_shcoeff`` constructor.
    tiss_classifier : obj
        Tissue classifier / stopping criterion object handed to ``LocalTracking`` or
        ``ParticleFilteringTracking``.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: 'det' (deterministic maximum),
        'clos' (closest peak) and 'prob' (probabilistic).
    curv_thr_list : list
        List of curvature thresholds used to perform ensemble tracking; each is passed
        as the direction getter's ``max_angle``.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used: 'local' or 'particle'.
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, here voxels). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Used for both the ROI-endpoint and the
        waymask filters.
    min_length : int, float, str or None
        Minimum streamline length threshold. Streamlines with fewer than ``min_length``
        points are discarded; 0 (or None) disables this filter.
        NOTE(review): this was documented as millimetres, but the comparison counts
        points (``len(s)``), not arc length -- confirm the intended units.
    waymask : str or None
        Path to a Nifti1Image in native diffusion space to constrain tractography.
        When falsy, no waymask filtering is applied.
    B0_mask : str
        File path to B0 brain mask. Streamlines that never pass through it are dropped.
    max_length : int
        Maximum number of steps to restrict tracking.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique ensemble combination.
        By default this is set to 500.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter. By default this is set to 15.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not one of the supported options.
    RuntimeWarning
        Raised as an exception (not emitted via ``warnings``) when the wm-gm
        interface mask yields no seed points.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692

    """
    import gc
    import time
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter, ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)

    start = time.time()

    B0_mask_data = nib.load(B0_mask).get_fdata()

    waymask_data = None
    if waymask:
        waymask_data = np.asarray(nib.load(waymask).dataobj).astype('bool')

    # Resolve the length threshold once. 0 / '0' / None all disable the filter.
    # The former string test (str(min_length) != '0') raised TypeError on None.
    min_length_thr = 0.0 if min_length is None else float(min_length)

    # Commence Ensemble Tractography
    parcel_vec = list(np.ones(len(parcels)).astype('bool'))
    streamlines = nib.streamlines.array_sequence.ArraySequence()

    circuit_ix = 0
    stream_counter = 0
    # NOTE(review): there is no iteration cap -- if every circuit yields zero
    # surviving streamlines this loop never terminates.
    while int(stream_counter) < int(target_samples):
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg = ProbabilisticDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                               min_separation_angle=min_separation_angle)
            elif directget == 'clos':
                dg = ClosestPeakDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                             min_separation_angle=min_separation_angle)
            elif directget == 'det':
                dg = DeterministicMaximumDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                                      min_separation_angle=min_separation_angle)
            else:
                raise ValueError('ERROR: No valid direction getter(s) specified.')

            for step in step_list:
                print("%s%s" % ('Step: ', step))

                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(atlas_data_wm_gm_int > 0, seeds_count=n_seeds_per_iter,
                                                     seed_count_per_voxel=False, affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning('Warning: No valid seed points found in wm-gm interface...')

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                         max_cross=int(maxcrossing), maxlen=int(max_length),
                                                         step_size=float(step), fixedstep=False, return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                                     max_cross=int(maxcrossing),
                                                                     step_size=float(step),
                                                                     maxlen=int(max_length),
                                                                     pft_back_tracking_dist=pft_back_tracking_dist,
                                                                     pft_front_tracking_dist=pft_front_tracking_dist,
                                                                     particle_count=particle_count,
                                                                     return_all=True)
                else:
                    raise ValueError('ERROR: No valid tracking method(s) specified.')

                # Keep only streamlines that pass through the B0 brain mask
                # (target with include=True; it does not require the streamline
                # to stay entirely inside the mask).
                roi_proximal_streamlines = utils.target(streamline_generator, np.eye(4), B0_mask_data,
                                                        include=True)

                # Keep streamlines whose both ends lie within tol of some ROI in `parcels`.
                roi_proximal_streamlines = Streamlines(select_by_rois(roi_proximal_streamlines, affine=np.eye(4),
                                                                      rois=parcels, include=parcel_vec,
                                                                      mode='both_end',
                                                                      tol=roi_neighborhood_tol))

                print("%s%s" % ('Filtering by: \nnode intersection: ', len(roi_proximal_streamlines)))

                if min_length_thr != 0:
                    # NOTE(review): len(s) is a point count, not a length in mm.
                    roi_proximal_streamlines = nib.streamlines.array_sequence.ArraySequence(
                        [s for s in roi_proximal_streamlines if len(s) >= min_length_thr])

                    print("%s%s" % ('Minimum length criterion: ', len(roi_proximal_streamlines)))

                if waymask:
                    roi_proximal_streamlines = roi_proximal_streamlines[utils.near_roi(roi_proximal_streamlines,
                                                                                       np.eye(4),
                                                                                       waymask_data,
                                                                                       tol=roi_neighborhood_tol,
                                                                                       mode='any')]
                    print("%s%s" % ('Waymask proximity: ', len(roi_proximal_streamlines)))

                out_streams = [s.astype('float32') for s in roi_proximal_streamlines]
                streamlines.extend(out_streams)
                stream_counter = stream_counter + len(out_streams)

                # Release the per-step tracking objects before the next iteration.
                del seeds, roi_proximal_streamlines, streamline_generator, out_streams
                gc.collect()
            del dg

        circuit_ix = circuit_ix + 1
        print("%s%s%s%s%s%s" % ('Completed Hyperparameter Circuit: ', circuit_ix,
                                '\nCumulative Streamline Count: ', Fore.CYAN, stream_counter, "\n"))
        print(Style.RESET_ALL)

    print('Tracking Complete:\n', str(time.time() - start))

    return streamlines