def run(context):

    ####################################################
    # Get the path to input files and other parameters #
    ####################################################
    analysis_data = context.fetch_analysis_data()
    settings = analysis_data['settings']
    postprocessing = settings['postprocessing']

    hcpl_dwi_file_handle = context.get_files('input', modality='HARDI')[0]
    hcpl_dwi_file_path = hcpl_dwi_file_handle.download('/root/')

    hcpl_bvalues_file_handle = context.get_files(
        'input', reg_expression='.*prep.bvalues.hcpl.txt')[0]
    hcpl_bvalues_file_path = hcpl_bvalues_file_handle.download('/root/')
    hcpl_bvecs_file_handle = context.get_files(
        'input', reg_expression='.*prep.gradients.hcpl.txt')[0]
    hcpl_bvecs_file_path = hcpl_bvecs_file_handle.download('/root/')

    dwi_file_handle = context.get_files('input', modality='DSI')[0]
    dwi_file_path = dwi_file_handle.download('/root/')
    bvalues_file_handle = context.get_files(
        'input', reg_expression='.*prep.bvalues.txt')[0]
    bvalues_file_path = bvalues_file_handle.download('/root/')
    bvecs_file_handle = context.get_files(
        'input', reg_expression='.*prep.gradients.txt')[0]
    bvecs_file_path = bvecs_file_handle.download('/root/')

    inject_file_handle = context.get_files(
        'input', reg_expression='.*prep.inject.nii.gz')[0]
    inject_file_path = inject_file_handle.download('/root/')

    VUMC_ROIs_file_handle = context.get_files(
        'input', reg_expression='.*VUMC_ROIs.nii.gz')[0]
    VUMC_ROIs_file_path = VUMC_ROIs_file_handle.download('/root/')

    ###############################
    # _____ _____ _______     __  #
    # |  __ \_   _|  __ \ \   / / #
    # | |  | || | | |__) \ \_/ /  #
    # | |  | || | |  ___/ \   /   #
    # | |__| || |_| |      | |    #
    # |_____/_____|_|      |_|    #
    #                             #
    # dipy.org/documentation      #
    ###############################
    #       IronTract Team        #
    #      TrackyMcTrackface      #
    ###############################

    #################
    # Load the data #
    #################
    dwi_img = nib.load(hcpl_dwi_file_path)
    bvals, bvecs = read_bvals_bvecs(hcpl_bvalues_file_path,
                                    hcpl_bvecs_file_path)
    gtab = gradient_table(bvals, bvecs)

    ############################################
    # Extract the brain mask from the b0 image #
    ############################################
    _, brain_mask = median_otsu(dwi_img.get_fdata()[:, :, :, 0],
                                median_radius=2,
                                numpass=1)

    ##################################################################
    # Fit the tensor model and compute the fractional anisotropy map #
    ##################################################################
    context.set_progress(message='Processing voxel-wise DTI metrics.')
    tenmodel = TensorModel(gtab)
    tenfit = tenmodel.fit(dwi_img.get_fdata(), mask=brain_mask)
    FA = fractional_anisotropy(tenfit.evals)
    # fa_file_path = "/root/fa.nii.gz"
    # nib.Nifti1Image(FA,dwi_img.affine).to_filename(fa_file_path)

    ################################################
    # Compute Fiber Orientation Distribution (CSD) #
    ################################################
    context.set_progress(message='Processing voxel-wise FOD estimation.')
    response, _ = auto_response_ssst(gtab,
                                     dwi_img.get_fdata(),
                                     roi_radii=10,
                                     fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(dwi_img.get_fdata(), mask=brain_mask)
    # fod_file_path = "/root/fod.nii.gz"
    # nib.Nifti1Image(csd_fit.shm_coeff,dwi_img.affine).to_filename(fod_file_path)

    ###########################################
    # Compute DIPY Probabilistic Tractography #
    ###########################################
    context.set_progress(message='Processing tractography.')
    sphere = get_sphere("repulsion724")
    seed_mask_img = nib.load(inject_file_path)
    affine = seed_mask_img.affine
    seeds = utils.seeds_from_mask(seed_mask_img.get_fdata(), affine, density=5)

    stopping_criterion = ThresholdStoppingCriterion(FA, 0.2)
    prob_dg = ProbabilisticDirectionGetter.from_shcoeff(csd_fit.shm_coeff,
                                                        max_angle=20.,
                                                        sphere=sphere)
    streamline_generator = LocalTracking(prob_dg,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.2,
                                         max_cross=1)
    streamlines = Streamlines(streamline_generator)
    # sft = StatefulTractogram(streamlines, seed_mask_img, Space.RASMM)
    # streamlines_file_path = "/root/streamlines.trk"
    # save_trk(sft, streamlines_file_path)

    ############################################################################
    # Compute 3D volumes for the IronTract Challenge. For 'EPFL', we only      #
    # keep streamlines with length > 1mm. We compute the visitation count      #
    # image and apply a small Gaussian smoothing. The Gaussian smoothing       #
    # is especially useful to increase voxel coverage of deterministic         #
    # algorithms. The log of the smoothed visitation count map is then         #
    # iteratively thresholded, producing 200 volumes/operating points.         #
    # For VUMC, additional streamline filtering is done using anatomical       #
    # priors (keeping only streamlines that intersect with at least one ROI).  #
    # (A toy sketch of this thresholding loop follows this function.)          #
    ############################################################################
    if postprocessing in ["EPFL", "ALL"]:
        context.set_progress(message='Processing density map (EPFL)')
        volume_folder = "/root/vol_epfl"
        output_epfl_zip_file_path = "/root/TrackyMcTrackface_EPFL_example.zip"
        os.mkdir(volume_folder)
        lengths = length(streamlines)
        streamlines = streamlines[lengths > 1]
        density = utils.density_map(streamlines, affine, seed_mask_img.shape)
        density = scipy.ndimage.gaussian_filter(density.astype("float32"), 0.5)

        log_density = np.log10(density + 1)
        max_density = np.max(log_density)
        for i, t in enumerate(np.arange(0, max_density, max_density / 200)):
            nbr = str(i)
            nbr = nbr.zfill(3)
            mask = log_density >= t
            vol_filename = os.path.join(
                volume_folder, "vol" + nbr + "_t" + str(t) + ".nii.gz")
            nib.Nifti1Image(mask.astype("int32"), affine,
                            seed_mask_img.header).to_filename(vol_filename)
        shutil.make_archive(output_epfl_zip_file_path[:-4], 'zip',
                            volume_folder)

    if postprocessing in ["VUMC", "ALL"]:
        context.set_progress(message='Processing density map (VUMC)')
        ROIs_img = nib.load(VUMC_ROIs_file_path)
        volume_folder = "/root/vol_vumc"
        output_vumc_zip_file_path = "/root/TrackyMcTrackface_VUMC_example.zip"
        os.mkdir(volume_folder)
        lengths = length(streamlines)
        streamlines = streamlines[lengths > 1]

        rois = ROIs_img.get_fdata().astype(int)
        _, grouping = utils.connectivity_matrix(streamlines,
                                                affine,
                                                rois,
                                                inclusive=True,
                                                return_mapping=True,
                                                mapping_as_streamlines=False)
        streamlines = streamlines[grouping[(0, 1)]]

        density = utils.density_map(streamlines, affine, seed_mask_img.shape)
        density = scipy.ndimage.gaussian_filter(density.astype("float32"), 0.5)

        log_density = np.log10(density + 1)
        max_density = np.max(log_density)
        for i, t in enumerate(np.arange(0, max_density, max_density / 200)):
            nbr = str(i)
            nbr = nbr.zfill(3)
            mask = log_density >= t
            vol_filename = os.path.join(
                volume_folder, "vol" + nbr + "_t" + str(t) + ".nii.gz")
            nib.Nifti1Image(mask.astype("int32"), affine,
                            seed_mask_img.header).to_filename(vol_filename)
        shutil.make_archive(output_vumc_zip_file_path[:-4], 'zip',
                            volume_folder)

    ###################
    # Upload the data #
    ###################
    context.set_progress(message='Uploading results...')
    # context.upload_file(fa_file_path, 'fa.nii.gz')
    # context.upload_file(fod_file_path, 'fod.nii.gz')
    # context.upload_file(streamlines_file_path, 'streamlines.trk')
    if postprocessing in ["EPFL", "ALL"]:
        context.upload_file(output_epfl_zip_file_path,
                            'TrackyMcTrackface_EPFL_example.zip')
    if postprocessing in ["VUMC", "ALL"]:
        context.upload_file(output_vumc_zip_file_path,
                            'TrackyMcTrackface_VUMC_example.zip')
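
# Hedged sketch (not part of the original example): the EPFL/VUMC blocks above
# threshold the log of the visitation-count map at 200 evenly spaced levels
# ("operating points"). This toy snippet reproduces only that thresholding
# logic on a synthetic density array; all values are made up for illustration.
import numpy as np

toy_density = np.random.RandomState(0).poisson(lam=3.0, size=(8, 8, 8))
toy_log_density = np.log10(toy_density.astype("float32") + 1)
toy_max = toy_log_density.max()
toy_volumes = []
for t in np.arange(0, toy_max, toy_max / 200):
    # Each operating point keeps only the voxels whose log-density reaches t.
    toy_volumes.append((toy_log_density >= t).astype("int32"))
print(len(toy_volumes), "binary volumes generated")  # typically 200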
Example No. 2
def test_bootstap_peak_tracker():
    """This tests that the Bootstrat Peak Direction Getter plays nice
    LocalTracking and produces reasonable streamlines in a simple example.
    """
    sphere = get_sphere('repulsion100')

    # A simple image with three possible configurations, a vertical tract,
    # a horizontal tract and a crossing
    simple_image = np.array([
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [2, 3, 2, 2, 2, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ])
    simple_image = simple_image[..., None]

    bvecs = sphere.vertices
    bvals = np.ones(len(bvecs)) * 1000
    bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0)
    bvals = np.insert(bvals, 0, 0)
    gtab = gradient_table(bvals, bvecs)
    angles = [(90, 90), (90, 0)]
    fracs = [50, 50]
    mevals = np.array([[1.5, 0.4, 0.4], [1.5, 0.4, 0.4]]) * 1e-3
    mevecs = [
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
        np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    ]
    voxel1 = single_tensor(gtab, 1, mevals[0], mevecs[0], snr=None)
    voxel2 = single_tensor(gtab, 1, mevals[0], mevecs[1], snr=None)
    voxel3, _ = multi_tensor(gtab,
                             mevals,
                             fractions=fracs,
                             angles=angles,
                             snr=None)
    data = np.tile(voxel3, [5, 6, 1, 1])
    data[simple_image == 1] = voxel1
    data[simple_image == 2] = voxel2

    response = (np.array(mevals[1]), 1)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)

    seeds = [np.array([0., 1., 0.]), np.array([2., 4., 0.])]

    sc = BinaryStoppingCriterion((simple_image > 0).astype(float))
    sphere = HemiSphere.from_sphere(get_sphere('symmetric724'))
    boot_dg = BootDirectionGetter.from_data(data, csd_model, 60, sphere=sphere)

    streamlines_generator = LocalTracking(boot_dg, sc, seeds, np.eye(4), 1.)
    streamlines = Streamlines(streamlines_generator)
    expected = [
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.], [3., 1., 0.],
                  [4., 1., 0.]]),
        np.array([
            [2., 4., 0.],
            [2., 3., 0.],
            [2., 2., 0.],
            [2., 1., 0.],
            [2., 0., 0.],
        ])
    ]

    def allclose(x, y):
        return x.shape == y.shape and np.allclose(x, y, atol=0.5)

    if not allclose(streamlines[0], expected[0]):
        raise AssertionError()
    if not allclose(streamlines[1], expected[1]):
        raise AssertionError()
Example No. 3
def tracking(folder):
    print('Tracking in ' + folder)
    output_folder = folder + 'dipy_out/'

    # make a folder to save new data into
    try:
        Path(output_folder).mkdir(parents=True, exist_ok=True)
    except OSError:
        print('Could not create output dir. Aborting...')
        return

    # load data
    print('Loading data...')
    img = nib.load(folder + 'data.nii.gz')
    dmri = np.asarray(img.dataobj)
    affine = img.affine
    mask, _ = load_nifti(folder + 'nodif_brain_mask.nii.gz')
    bvals, bvecs = read_bvals_bvecs(folder + 'bvals', folder + 'bvecs')
    gtab = gradient_table(bvals, bvecs)

    # extract peaks
    if Path(output_folder + 'peaks.pam5').exists():
        peaks = load_peaks(output_folder + 'peaks.pam5')
    else:
        print('Extracting peaks...')
        response, ratio = auto_response(gtab, dmri, roi_radius=10, fa_thr=.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response)

        peaks = peaks_from_model(model=csd_model,
                                 data=dmri,
                                 sphere=default_sphere,
                                 relative_peak_threshold=.5,
                                 min_separation_angle=25,
                                 parallel=True)

        save_peaks(output_folder + 'peaks.pam5', peaks, affine)
        scaled = peaks.peak_dirs * np.repeat(
            np.expand_dims(peaks.peak_values, -1), 3, -1)

        cropped = scaled[:, :, :, :3, :].reshape(dmri.shape[:3] + (9, ))
        save_nifti(output_folder + 'peaks.nii.gz', cropped, affine)
        #save_nifti(output_folder + 'peak_dirs.nii.gz', peaks.peak_dirs, affine)
        #save_nifti(output_folder + 'peak_vals.nii.gz', peaks.peak_values, affine)

    # tracking
    print('Tracking...')
    maskdata, mask = median_otsu(dmri,
                                 vol_idx=range(0, dmri.shape[3]),
                                 median_radius=3,
                                 numpass=1,
                                 autocrop=True,
                                 dilate=2)
    tensor_model = TensorModel(gtab, fit_method='WLS')
    tensor_fit = tensor_model.fit(maskdata)
    fa = fractional_anisotropy(tensor_fit.evals)
    fa[np.isnan(fa)] = 0
    bla = np.average(fa)
    tissue_classifier = ThresholdStoppingCriterion(fa, .1)
    seeds = random_seeds_from_mask(fa > 1e-5, affine, seeds_count=1)

    streamline_generator = LocalTracking(direction_getter=peaks,
                                         stopping_criterion=tissue_classifier,
                                         seeds=seeds,
                                         affine=affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)
    save_trk(StatefulTractogram(streamlines, img, Space.RASMM),
             output_folder + 'whole_brain.trk')
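
# Hedged usage sketch (hypothetical path, not from the original snippet):
# `tracking` expects a subject folder containing data.nii.gz, bvals, bvecs and
# nodif_brain_mask.nii.gz; outputs are written to <folder>/dipy_out/.
if __name__ == '__main__':
    tracking('/data/subject01/')  # hypothetical subject directory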
Example No. 4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exist(parser, args, args.output_file)

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts'.
                format(args.compress))

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    mask_img = nib.load(args.mask_file)
    mask_data = mask_img.get_fdata()

    # Make sure the mask is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryStoppingCriterion(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=vox_step_size,
                                max_cross=1,
                                maxlen=max_steps,
                                fixedstep=True,
                                return_all=True,
                                random_seed=args.seed,
                                save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in streamlines
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
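
# Hedged sketch (illustration only, not from the original script): a minimal
# LazyTractogram example showing the same "save streamlines on-the-fly from a
# generator" pattern used above. The two tiny streamlines and the identity
# affine are made up for demonstration.
import numpy as np
import nibabel as nib
from nibabel.streamlines import LazyTractogram


def _toy_streamlines():
    yield np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]], dtype=np.float32)
    yield np.array([[0., 1., 0.], [0., 2., 0.]], dtype=np.float32)


lazy = LazyTractogram(lambda: _toy_streamlines(), affine_to_rasmm=np.eye(4))
# Streamlines are only produced while the file is being written.
nib.streamlines.save(lazy, 'toy_streamlines.tck')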
Example No. 5
def run_tracking(step_curv_combinations, recon_path,
                 n_seeds_per_iter, directget, maxcrossing, max_length,
                 pft_back_tracking_dist, pft_front_tracking_dist,
                 particle_count, roi_neighborhood_tol, waymask, min_length,
                 track_type, min_separation_angle, sphere, tiss_class,
                 tissues4d, cache_dir, min_seeds=100):

    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (
        ProbabilisticDirectionGetter,
        ClosestPeakDirectionGetter,
        DeterministicMaximumDirectionGetter
    )
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix
    import uuid
    from time import strftime

    run_uuid = f"{strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4()}"

    recon_path_tmp_path = fname_presuffix(
        recon_path,
        suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
               f"{run_uuid}",
        newpath=cache_dir
    )
    copyfile(
        recon_path,
        recon_path_tmp_path,
        copy=True,
        use_hardlink=False)

    tissues4d_tmp_path = fname_presuffix(
        tissues4d,
        suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
               f"{run_uuid}",
        newpath=cache_dir
    )
    copyfile(
        tissues4d,
        tissues4d_tmp_path,
        copy=True,
        use_hardlink=False)

    if waymask is not None:
        waymask_tmp_path = fname_presuffix(
            waymask,
            suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
                   f"{run_uuid}",
            newpath=cache_dir
        )
        copyfile(
            waymask,
            waymask_tmp_path,
            copy=True,
            use_hardlink=False)
    else:
        waymask_tmp_path = None

    tissue_img = nib.load(tissues4d_tmp_path)

    # Order:
    B0_mask = index_img(tissue_img, 0)
    atlas_img = index_img(tissue_img, 1)
    seeding_mask = index_img(tissue_img, 2)
    t1w2dwi = index_img(tissue_img, 3)
    gm_in_dwi = index_img(tissue_img, 4)
    vent_csf_in_dwi = index_img(tissue_img, 5)
    wm_in_dwi = index_img(tissue_img, 6)

    tiss_classifier = prep_tissues(
        t1w2dwi,
        gm_in_dwi,
        vent_csf_in_dwi,
        wm_in_dwi,
        tiss_class,
        B0_mask
    )

    B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")

    seeding_mask = np.asarray(
        seeding_mask.dataobj
    ).astype("bool").astype("int16")

    with h5py.File(recon_path_tmp_path, 'r+') as hf:
        mod_fit = hf['reconstruction'][:].astype('float32')

    print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

    # Instantiate DirectionGetter
    if directget.lower() in ["probabilistic", "prob"]:
        dg = ProbabilisticDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget.lower() in ["closestpeaks", "cp"]:
        dg = ClosestPeakDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget.lower() in ["deterministic", "det"]:
        maxcrossing = 1
        dg = DeterministicMaximumDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    else:
        raise ValueError(
            "ERROR: No valid direction getter(s) specified."
        )

    print("%s%s" % ("Step: ", step_curv_combinations[0]))

    # Perform wm-gm interface seeding, using n_seeds at a time
    seeds = utils.random_seeds_from_mask(
        seeding_mask > 0,
        seeds_count=n_seeds_per_iter,
        seed_count_per_voxel=False,
        affine=np.eye(4),
    )
    if len(seeds) < min_seeds:
        print(UserWarning(
            f"<{min_seeds} valid seed points found in wm-gm interface..."
        ))
        return None

    # print(seeds)

    # Perform tracking
    if track_type == "local":
        streamline_generator = LocalTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            maxlen=int(max_length),
            step_size=float(step_curv_combinations[0]),
            fixedstep=False,
            return_all=True,
            random_seed=42
        )
    elif track_type == "particle":
        streamline_generator = ParticleFilteringTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            step_size=float(step_curv_combinations[0]),
            maxlen=int(max_length),
            pft_back_tracking_dist=pft_back_tracking_dist,
            pft_front_tracking_dist=pft_front_tracking_dist,
            pft_max_trial=20,
            particle_count=particle_count,
            return_all=True,
            random_seed=42
        )
    else:
        raise ValueError(
            "ERROR: No valid tracking method(s) specified.")

    # Filter resulting streamlines by those that stay entirely
    # inside the brain
    try:
        roi_proximal_streamlines = utils.target(
            streamline_generator, np.eye(4),
            B0_mask_data.astype('bool'), include=True
        )
    except BaseException:
        print('No streamlines found inside the brain! '
              'Check registrations.')
        return None

    del mod_fit, seeds, tiss_classifier, streamline_generator, \
        B0_mask_data, seeding_mask, dg

    B0_mask.uncache()
    atlas_img.uncache()
    t1w2dwi.uncache()
    gm_in_dwi.uncache()
    vent_csf_in_dwi.uncache()
    wm_in_dwi.uncache()
    atlas_img.uncache()
    tissue_img.uncache()
    gc.collect()

    # Filter resulting streamlines by roi-intersection
    # characteristics
    atlas_data = np.array(atlas_img.dataobj).astype("uint16")

    # Build mask vector from atlas for later roi filtering
    parcels = []
    i = 0
    intensities = [i for i in np.unique(atlas_data) if i != 0]
    for roi_val in intensities:
        parcels.append(atlas_data == roi_val)
        i += 1

    parcel_vec = list(np.ones(len(parcels)).astype("bool"))

    try:
        roi_proximal_streamlines = \
            nib.streamlines.array_sequence.ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="any",
                    tol=roi_neighborhood_tol,
                )
            )
        print("%s%s" % ("Filtering by: \nNode intersection: ",
                        len(roi_proximal_streamlines)))
    except BaseException:
        print('No streamlines found to connect any parcels! '
              'Check registrations.')
        return None

    try:
        roi_proximal_streamlines = nib.streamlines. \
            array_sequence.ArraySequence(
            [
                s for s in roi_proximal_streamlines
                if len(s) >= float(min_length)
            ]
        )
        print(f"Minimum fiber length >{min_length}mm: "
              f"{len(roi_proximal_streamlines)}")
    except BaseException:
        print('No streamlines remaining after minimal length criterion.')
        return None

    if waymask is not None and os.path.isfile(waymask_tmp_path):
        waymask_data = np.asarray(nib.load(waymask_tmp_path
                                           ).dataobj).astype("bool")
        try:
            roi_proximal_streamlines = roi_proximal_streamlines[
                utils.near_roi(
                    roi_proximal_streamlines,
                    np.eye(4),
                    waymask_data,
                    tol=int(round(roi_neighborhood_tol*0.50, 1)),
                    mode="all"
                )
            ]
            print("%s%s" % ("Waymask proximity: ",
                            len(roi_proximal_streamlines)))
            del waymask_data
        except BaseException:
            print('No streamlines remaining in waymask\'s vicinity.')
            return None

    hf.close()
    del parcels, atlas_data

    tmp_files = [tissues4d_tmp_path, waymask_tmp_path, recon_path_tmp_path]
    for j in tmp_files:
        if j is not None:
            if os.path.isfile(j):
                os.system(f"rm -f {j} &")

    if len(roi_proximal_streamlines) > 0:
        return ArraySequence([s.astype("float32") for s in
                              roi_proximal_streamlines])
    else:
        return None
Example No. 6
stopping_criterion = ThresholdStoppingCriterion(gfa, .25)
"""
Next, we need to set up our two direction getters
"""
"""
Example #1: Bootstrap direction getter with CSD Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

boot_dg_csd = BootDirectionGetter.from_data(data,
                                            csd_model,
                                            max_angle=30.,
                                            sphere=small_sphere)
boot_streamline_generator = LocalTracking(boot_dg_csd,
                                          stopping_criterion,
                                          seeds,
                                          affine,
                                          step_size=.5)
streamlines = Streamlines(boot_streamline_generator)
sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
save_trk(sft, "tractogram_bootstrap_dg.trk")

if has_fury:
    r = window.Renderer()
    r.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(r, out_path='tractogram_bootstrap_dg.png', size=(800, 800))
    if interactive:
        window.show(r)
"""
.. figure:: tractogram_bootstrap_dg.png
   :align: center
Example No. 7
    scene.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(scene, out_path='tractogram_pft.png', size=(800, 800))
    if interactive:
        window.show(scene)
"""
.. figure:: tractogram_pft.png
 :align: center

 **Corpus Callosum using particle filtering tractography**
"""

# Local Probabilistic Tractography
prob_streamline_generator = LocalTracking(dg,
                                          cmc_criterion,
                                          seeds,
                                          affine,
                                          max_cross=1,
                                          step_size=step_size,
                                          maxlen=1000,
                                          return_all=False)
streamlines = Streamlines(prob_streamline_generator)
sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
save_trk(sft, "tractogram_probabilistic_cmc.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(scene,
                  out_path='tractogram_probabilistic_cmc.png',
                  size=(800, 800))
    if interactive:
        window.show(scene)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.in_odf, args.in_seed, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    mask_img = nib.load(args.in_mask)
    mask_data = get_data_as_mask(mask_img, dtype=bool)

    # Make sure the data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    odf_sh_img = nib.load(args.in_odf)
    if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
                       odf_sh_img.header.get_zooms()[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(dtype=np.float32),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines_generator = LocalTracking(
        _get_direction_getter(args),
        BinaryStoppingCriterion(mask_data),
        seeds, np.eye(4),
        step_size=vox_step_size, max_cross=1,
        maxlen=max_steps,
        fixedstep=True, return_all=True,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in streamlines_generator
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines_generator
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
Example No. 9
def sfm_tracking(name=None,
                 data_path=None,
                 output_path='.',
                 Threshold=.20,
                 data_list=None,
                 return_streamlines=False,
                 save_track=True,
                 seed='.',
                 minus_ROI_mask='.',
                 one_node=False,
                 two_node=False):

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    if type(seed) != str:
        seed_mask = seed
    else:
        seed_mask = (labels == 2) * (head_mask == 1)

    white_matter = (labels == 2) * (head_mask == 1)
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print('begin reconstruction, time:', time.time() - time0)

    from dipy.reconst.csdeconv import auto_response_ssst
    from dipy.reconst.shm import CsaOdfModel
    from dipy.data import default_sphere
    from dipy.direction import peaks_from_model

    response, ratio = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)

    sphere = get_sphere()
    sf_model = sfm.SparseFascicleModel(gtab,
                                       sphere=sphere,
                                       l1_ratio=0.5,
                                       alpha=0.001,
                                       response=response[0])

    pnm = peaks_from_model(sf_model,
                           data,
                           sphere,
                           relative_peak_threshold=.5,
                           min_separation_angle=25,
                           mask=white_matter,
                           parallel=True)

    stopping_criterion = ThresholdStoppingCriterion(pnm.gfa, Threshold)
    #seeds = utils.seeds_from_mask(white_matter, affine, density=1)

    print('begin tracking, time:', time.time() - time0)

    streamline_generator = LocalTracking(pnm,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)

    print('begin saving, time:', time.time() - time0)

    from dipy.io.stateful_tractogram import Space, StatefulTractogram
    from dipy.io.streamline import save_trk

    if save_track:

        sft = StatefulTractogram(streamlines, img, Space.RASMM)

        if one_node or two_node:
            sft.to_vox()
            streamlines = reduct_seed_ROI(sft.streamlines, seed_mask, one_node,
                                          two_node)

            if type(minus_ROI_mask) != str:
                streamlines = minus_ROI(streamlines=streamlines,
                                        ROI=minus_ROI_mask)

            sft = StatefulTractogram(streamlines, img, Space.VOX)
            sft._vox_to_rasmm()

        output = output_path + '/tractogram_sfm_' + name + '.trk'
        save_trk(sft, output)

    if return_streamlines:
        return streamlines
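
# Hedged usage sketch (hypothetical subject name and paths; get_data and the
# ROI helpers are assumed to be defined elsewhere in the original project):
#
# streamlines = sfm_tracking(name='sub-01', data_path='/data/subjects',
#                            output_path='/data/tracts', Threshold=0.2,
#                            return_streamlines=True, save_track=True)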
Example No. 10
def tractography_estimation_data(dmri_estimation_data):
    path_tmp = tempfile.NamedTemporaryFile(mode='w+',
                                           suffix='.trk',
                                           delete=False)
    trk_path_tmp = str(path_tmp.name)
    dir_path = os.path.dirname(trk_path_tmp)

    gtab = dmri_estimation_data['gtab']
    wm_img = nib.load(dmri_estimation_data['f_pve_wm'])
    dwi_img = nib.load(dmri_estimation_data['dwi_file'])
    dwi_data = dwi_img.get_fdata()
    B0_mask_img = nib.load(dmri_estimation_data['B0_mask'])
    mask_img = intersect_masks(
        [
            nib.Nifti1Image(np.asarray(
                wm_img.dataobj).astype('bool').astype('int'),
                            affine=wm_img.affine),
            nib.Nifti1Image(np.asarray(
                B0_mask_img.dataobj).astype('bool').astype('int'),
                            affine=B0_mask_img.affine)
        ],
        threshold=1,
        connected=False,
    )

    mask_data = mask_img.get_fdata()
    mask_file = fname_presuffix(dmri_estimation_data['B0_mask'],
                                suffix="tracking_mask",
                                use_ext=True)
    mask_img.to_filename(mask_file)
    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model,
                                 dwi_data,
                                 default_sphere,
                                 relative_peak_threshold=.8,
                                 min_separation_angle=45,
                                 mask=mask_data)

    stopping_criterion = BinaryStoppingCriterion(mask_data)

    seed_mask = (mask_data == 1)
    seeds = seeds_from_mask(seed_mask, dwi_img.affine, density=[1, 1, 1])

    streamlines_generator = LocalTracking(csa_peaks,
                                          stopping_criterion,
                                          seeds,
                                          affine=dwi_img.affine,
                                          step_size=.5)
    streamlines = Streamlines(streamlines_generator)
    sft = StatefulTractogram(streamlines,
                             B0_mask_img,
                             origin=Origin.NIFTI,
                             space=Space.VOXMM)
    sft.remove_invalid_streamlines()
    trk = f"{dir_path}/tractogram.trk"
    os.rename(trk_path_tmp, trk)
    save_tractogram(sft, trk, bbox_valid_check=False)
    del streamlines, sft, streamlines_generator, seeds, seed_mask, csa_peaks, \
        csa_model, dwi_data, mask_data
    dwi_img.uncache()
    mask_img.uncache()
    gc.collect()

    yield {'trk': trk, 'mask': mask_file}
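
# Hedged sketch (hypothetical test, not part of the original code): a pytest
# test consuming the fixture above simply receives the dict it yields.
def test_tractography_files_exist(tractography_estimation_data):
    import os
    assert os.path.isfile(tractography_estimation_data['trk'])
    assert os.path.isfile(tractography_estimation_data['mask'])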
Example No. 11
def determine(name=None,
              data_path=None,
              output_path='.',
              Threshold=.20,
              data_list=None,
              seed='.',
              minus_ROI_mask='.',
              one_node=False,
              two_node=False):

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    print(type(seed))

    if type(seed) != str:
        seed_mask = seed
    else:
        seed_mask = (labels == 2) * (head_mask == 1)

    white_matter = (labels == 2) * (head_mask == 1)
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print("begin reconstruction, time:", time.time() - time0)
    response, ratio = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data, mask=white_matter)

    csa_model = CsaOdfModel(gtab, sh_order=6)
    gfa = csa_model.fit(data, mask=white_matter).gfa
    stopping_criterion = ThresholdStoppingCriterion(gfa, Threshold)

    #from dipy.data import small_sphere

    print("begin tracking, time:", time.time() - time0)
    detmax_dg = DeterministicMaximumDirectionGetter.from_shcoeff(
        csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
    streamline_generator = LocalTracking(detmax_dg,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)
    sft = StatefulTractogram(streamlines, img, Space.RASMM)

    if one_node or two_node:
        sft.to_vox()
        streamlines = reduct_seed_ROI(sft.streamlines, seed_mask, one_node,
                                      two_node)

        if type(minus_ROI_mask) != str:

            streamlines = minus_ROI(streamlines=streamlines,
                                    ROI=minus_ROI_mask)

        sft = StatefulTractogram(streamlines, img, Space.VOX)
        sft._vox_to_rasmm()

    print("begin saving, time:", time.time() - time0)

    output = output_path + '/tractogram_deterministic_' + name + '.trk'
    save_trk(sft, output)

    print("finished, time:", time.time() - time0)
Example No. 12
                                        max_angle=60.,
                                        sphere=small_sphere)

print('streamline gen')
global_chunk_size = args.chunk_size
nchunks = (seed_mask.shape[0] + global_chunk_size - 1) // global_chunk_size

t1 = time.time()
streamline_time = 0
io_time = 0
for idx in range(int(nchunks)):
    # Main streamline computation
    ts = time.time()
    streamline_generator = LocalTracking(
        boot_dg,
        tissue_classifier,
        seed_mask[idx * global_chunk_size:(idx + 1) * global_chunk_size],
        affine=np.eye(4),
        step_size=.5)
    streamlines = [s for s in streamline_generator]
    te = time.time()
    streamline_time += (te - ts)
    print("Generated {} streamlines from {} seeds, time: {} s".format(
        len(streamlines), seed_mask[idx * global_chunk_size:(idx + 1) *
                                    global_chunk_size].shape[0], te - ts))

    # Save tracklines file
    if args.output_prefix:
        fname = "{}.{}_{}.trk".format(args.output_prefix, idx + 1, nchunks)
        ts = time.time()
        #save_tractogram(fname, streamlines, img.affine, vox_size=roi.header.get_zooms(), shape=roi_data.shape)
        #save_tractogram(fname, streamlines)
Example No. 13
def run_tractography(fdwi,
                     fbval,
                     fbvec,
                     fwmparc,
                     mod_func,
                     mod_type,
                     seed_density=20):
    """
    mod_func : 'str'
        'csd' or 'csa'
    mod_type : 'str'
        'det' or 'prob'
    seed_density : int, default=20
        Seeding density for tractography
    """
    # Getting default params
    sphere = get_sphere("repulsion724")
    stream_affine = np.eye(4)

    # Loading data
    print("Loading Data...")
    dwi, gtab, wm_mask = load_data(fdwi, fbval, fbvec, fwmparc)

    # Make tissue classifier
    tiss_classifier = BinaryStoppingCriterion(wm_mask)

    if mod_func == "csd":
        mod = csd_mod_est(gtab, dwi, wm_mask)
    elif mod_func == "csa":
        mod = odf_mod_est(gtab)

    # Build seed list
    seeds = utils.random_seeds_from_mask(
        wm_mask,
        affine=stream_affine,
        seeds_count=int(seed_density),
        seed_count_per_voxel=True,
    )

    # Make streamlines
    if mod_type == "det":
        print("Obtaining peaks from model...")
        direction_getter = peaks_from_model(
            mod,
            dwi,
            sphere,
            relative_peak_threshold=0.5,
            min_separation_angle=25,
            mask=wm_mask,
            npeaks=5,
            normalize_peaks=True,
        )
    elif mod_type == "prob":
        print("Preparing probabilistic tracking...")
        print("Fitting model to data...")
        mod_fit = mod.fit(dwi, wm_mask)
        print("Building direction-getter...")
        try:
            print(
                "Proceeding using spherical harmonic coefficient from model estimation..."
            )
            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                mod_fit.shm_coeff, max_angle=60.0, sphere=sphere)
        except:
            print("Proceeding using FOD PMF from model estimation...")
            fod = mod_fit.odf(sphere)
            pmf = fod.clip(min=0)
            direction_getter = ProbabilisticDirectionGetter.from_pmf(
                pmf, max_angle=60.0, sphere=sphere)

    print("Running Local Tracking")
    streamline_generator = LocalTracking(
        direction_getter,
        tiss_classifier,
        seeds,
        stream_affine,
        step_size=0.5,
        return_all=True,
    )

    print("Reconstructing tractogram streamlines...")
    streamlines = Streamlines(streamline_generator)
    tracks = Streamlines([track for track in streamlines if len(track) > 60])
    return tracks
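
# Hedged usage sketch (hypothetical file names): probabilistic CSD tracking
# seeded in the white-matter parcellation, as exposed by the function above.
#
# tracks = run_tractography('dwi.nii.gz', 'dwi.bval', 'dwi.bvec',
#                           'wmparc.nii.gz', mod_func='csd', mod_type='prob',
#                           seed_density=20)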
Example No. 14
def test_affine_transformations():
    """This tests that the input affine is properly handled by
    LocalTracking and produces reasonable streamlines in a simple example.
    """
    sphere = HemiSphere.from_sphere(unit_octahedron)

    # A simple image with three possible configurations, a vertical tract,
    # a horizontal tract and a crossing
    pmf_lookup = np.array([[0., 0., 1.], [1., 0., 0.], [0., 1., 0.],
                           [.4, .6, 0.]])
    simple_image = np.array([
        [0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 3, 2, 2, 2, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
    ])

    simple_image = simple_image[..., None]
    pmf = pmf_lookup[simple_image]

    seeds = [np.array([1., 1., 0.]), np.array([2., 4., 0.])]

    expected = [
        np.array([[1., 1., 0.], [2., 1., 0.], [3., 1., 0.]]),
        np.array([[2., 1., 0.], [2., 2., 0.], [2., 3., 0.], [2., 4., 0.]])
    ]

    mask = (simple_image > 0).astype(float)
    sc = BinaryStoppingCriterion(mask)

    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf,
                                                      60,
                                                      sphere,
                                                      pmf_threshold=0.1)

    # TST- bad affine wrong shape
    bad_affine = np.eye(3)
    npt.assert_raises(ValueError, LocalTracking, dg, sc, seeds, bad_affine, 1.)

    # TST - bad affine with shearing
    bad_affine = np.eye(4)
    bad_affine[0, 1] = 1.
    npt.assert_raises(ValueError, LocalTracking, dg, sc, seeds, bad_affine, 1.)

    # TST - identity
    a0 = np.eye(4)
    # TST - affines with positive/negative offsets
    a1 = np.eye(4)
    a1[:3, 3] = [1, 2, 3]
    a2 = np.eye(4)
    a2[:3, 3] = [-2, 0, -1]
    # TST - affine with scaling
    a3 = np.eye(4)
    a3[0, 0] = a3[1, 1] = a3[2, 2] = 8
    # TST - affine with axes inverting (negative value)
    a4 = np.eye(4)
    a4[1, 1] = a4[2, 2] = -1
    # TST - combined affines
    a5 = a1 + a2 + a3
    a5[3, 3] = 1
    # TST - in vivo affine example
    # Sometimes data have affines with tiny shear components.
    # For example, the small_101D data-set has some of that:
    fdata, _, _ = get_fnames('small_101D')
    a6 = nib.load(fdata).affine

    for affine in [a0, a1, a2, a3, a4, a5, a6]:
        lin = affine[:3, :3]
        offset = affine[:3, 3]
        seeds_trans = [np.dot(lin, s) + offset for s in seeds]

        # We compute the voxel size to adjust the step size to one voxel
        voxel_size = np.mean(np.sqrt(np.dot(lin, lin).diagonal()))

        streamlines = LocalTracking(direction_getter=dg,
                                    stopping_criterion=sc,
                                    seeds=seeds_trans,
                                    affine=affine,
                                    step_size=voxel_size,
                                    return_all=True)

        # We apply the inverse affine transformation to the generated
        # streamlines. It should be equals to the expected streamlines
        # (generated with the identity affine matrix).
        affine_inv = np.linalg.inv(affine)
        lin = affine_inv[:3, :3]
        offset = affine_inv[:3, 3]
        streamlines_inv = []
        for line in streamlines:
            streamlines_inv.append([np.dot(pts, lin) + offset for pts in line])

        npt.assert_equal(len(streamlines_inv[0]), len(expected[0]))
        npt.assert_(np.allclose(streamlines_inv[0], expected[0], atol=0.3))
        npt.assert_equal(len(streamlines_inv[1]), len(expected[1]))
        npt.assert_(np.allclose(streamlines_inv[1], expected[1], atol=0.3))
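
# Hedged sketch (illustration only): the round trip the test above relies on --
# map a voxel-space point through an affine, then through its inverse, and
# recover the original point. The affine here is a made-up scaling + offset.
import numpy as np

demo_affine = np.eye(4)
demo_affine[:3, :3] *= 2.0             # isotropic scaling
demo_affine[:3, 3] = [10., -5., 3.]    # translation

point_vox = np.array([2., 1., 0.])
point_world = demo_affine[:3, :3] @ point_vox + demo_affine[:3, 3]
inv = np.linalg.inv(demo_affine)
point_back = inv[:3, :3] @ point_world + inv[:3, 3]
assert np.allclose(point_back, point_vox)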
Example No. 15
t1_data = t1.get_data()

white_matter = binary_dilation((labels == 1) | (labels == 2))
csamodel = shm.CsaOdfModel(gtab, 6)
csapeaks = peaks.peaks_from_model(model=csamodel,
                                  data=data,
                                  sphere=peaks.default_sphere,
                                  relative_peak_threshold=.8,
                                  min_separation_angle=45,
                                  mask=white_matter)

affine = np.eye(4)
seeds = utils.seeds_from_mask(white_matter, affine, density=1)
stopping_criterion = BinaryStoppingCriterion(white_matter)

streamline_generator = LocalTracking(csapeaks, stopping_criterion, seeds,
                                     affine=affine, step_size=0.5)
streamlines = Streamlines(streamline_generator)

cc_slice = labels == 2
cc_streamlines = utils.target(streamlines, affine, cc_slice)
cc_streamlines = Streamlines(cc_streamlines)

other_streamlines = utils.target(streamlines, affine, cc_slice,
                                 include=False)
other_streamlines = Streamlines(other_streamlines)
assert len(other_streamlines) + len(cc_streamlines) == len(streamlines)

from dipy.viz import window, actor, colormap as cmap

# Enables/disables interactive visualization
interactive = False
Example No. 16
def get_csd_streamlines(data_container,
                        random_seeds=False,
                        seeds_count=30000,
                        seeds_per_voxel=False,
                        step_width=1.0,
                        roi_r=10,
                        auto_response_fa_threshold=0.7,
                        fa_threshold=0.15,
                        relative_peak_threshold=0.5,
                        min_separation_angle=25):
    """
    Tracks and returns CSD Streamlines for the given DataContainer.

    Parameters
    ----------
    data_container
        The DataContainer we would like to track streamlines on
    random_seeds
        A boolean indicating whether we would like to use random seeds
    seeds_count
        If we use random seeds, this specifies the seed count
    seeds_per_voxel
        If True, the seed count is specified per voxel
    step_width
        The step width used while tracking
    roi_r
        The radii of the cuboid roi for the automatic estimation of single-shell single-tissue response function using FA.
    auto_response_fa_threshold
        The FA threshold for the automatic estimation of single-shell single-tissue response function using FA.
    fa_threshold
        The FA threshold to use to stop tracking
    relative_peak_threshold
        The relative peak threshold to use to get peaks from the CSDModel
    min_separation_angle
        The minimal separation angle of peaks
    Returns
    -------
    Streamlines
        A list of Streamlines
    """
    seeds = _get_seeds(data_container, random_seeds, seeds_count,
                       seeds_per_voxel)

    response, _ = auto_response_ssst(data_container.gtab,
                                     data_container.dwi,
                                     roi_radii=roi_r,
                                     fa_thr=auto_response_fa_threshold)
    csd_model = ConstrainedSphericalDeconvModel(data_container.gtab, response)

    direction_getter = peaks_from_model(
        model=csd_model,
        data=data_container.dwi,
        sphere=get_sphere('symmetric724'),
        mask=data_container.binary_mask,
        relative_peak_threshold=relative_peak_threshold,
        min_separation_angle=min_separation_angle,
        parallel=False)

    dti_fit = dti.TensorModel(data_container.gtab, fit_method='LS').fit(
        data_container.dwi, mask=data_container.binary_mask)
    classifier = ThresholdStoppingCriterion(dti_fit.fa, fa_threshold)

    streamlines_generator = LocalTracking(direction_getter,
                                          classifier,
                                          seeds,
                                          data_container.aff,
                                          step_size=step_width)
    streamlines = Streamlines(streamlines_generator)

    return streamlines
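
# Hedged usage sketch: the DataContainer interface is assumed from the
# docstring above (fields gtab, dwi, aff and binary_mask, plus a _get_seeds
# helper defined elsewhere); the call itself is hypothetical.
#
# streamlines = get_csd_streamlines(data_container,
#                                   random_seeds=True,
#                                   seeds_count=30000,
#                                   step_width=1.0,
#                                   fa_threshold=0.15)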
Example No. 17
plt.yticks([])
plt.imshow(mask_fa[:, :, data.shape[2] // 2].T, cmap='gray', origin='lower',
           interpolation='nearest')
fig.tight_layout()
fig.savefig('threshold_fa.png')

"""
.. figure:: threshold_fa.png
 :align: center

 **Thresholded fractional anisotropy map.**
"""

streamline_generator = LocalTracking(dg,
                                     threshold_criterion,
                                     seeds,
                                     affine,
                                     step_size=.5,
                                     return_all=True)
streamlines = Streamlines(streamline_generator)
sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
save_trk(sft, "tractogram_probabilistic_thresh_all.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(scene, out_path='tractogram_deterministic_thresh_all.png',
                  size=(800, 800))
    if interactive:
        window.show(scene)

"""
Example No. 18
def tracking(image,
             bvecs,
             bvals,
             wm,
             seeds,
             fibers,
             prune_length=3,
             rseed=42,
             plot=False,
             proba=False,
             verbose=False):
    # Pipelines transcribed from:
    #   https://dipy.org/documentation/1.1.1./examples_built/tracking_introduction_eudx/#example-tracking-introduction-eudx
    #   https://dipy.org/documentation/1.1.1./examples_built/tracking_probabilistic/

    # Load Images
    dwi_loaded = nib.load(image)
    dwi_data = dwi_loaded.get_fdata()

    wm_loaded = nib.load(wm)
    wm_data = wm_loaded.get_fdata()

    seeds_loaded = nib.load(seeds)
    seeds_data = seeds_loaded.get_fdata()
    seeds = utils.seeds_from_mask(seeds_data, dwi_loaded.affine, density=2)

    # Load B-values & B-vectors
    # NB. Use aligned b-vecs if providing eddy-aligned data
    bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)
    csa_model = CsaOdfModel(gtab, sh_order=6)

    # Set stopping criterion
    gfa = csa_model.fit(dwi_data, mask=wm_data).gfa
    stop_criterion = ThresholdStoppingCriterion(gfa, .25)

    if proba:
        # Establish ODF model
        response, ratio = auto_response(gtab,
                                        dwi_data,
                                        roi_radius=10,
                                        fa_thr=0.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
        csd_fit = csd_model.fit(dwi_data, mask=wm_data)

        # Create probabilistic direction getter
        fod = csd_fit.odf(default_sphere)
        pmf = fod.clip(min=0)
        prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                                        max_angle=30.,
                                                        sphere=default_sphere)
        # Use the probabilistic direction getter as the dg
        dg = prob_dg

    else:
        # Establish ODF model
        csa_peaks = peaks_from_model(csa_model,
                                     dwi_data,
                                     default_sphere,
                                     relative_peak_threshold=0.8,
                                     min_separation_angle=45,
                                     mask=wm_data)

        # Use the CSA peaks as the dg
        dg = csa_peaks

    # Create generator and perform tracking
    s_generator = LocalTracking(dg,
                                stop_criterion,
                                seeds,
                                dwi_loaded.affine,
                                0.5,
                                random_seed=rseed)
    streamlines = Streamlines(s_generator)

    # Prune streamlines
    streamlines = ArraySequence(
        [strline for strline in streamlines if len(strline) > prune_length])
    sft = StatefulTractogram(streamlines, dwi_loaded, Space.RASMM)

    # Save streamlines
    save_trk(sft, fibers + ".trk")

    # Visualize fibers
    if plot and has_fury:
        from dipy.viz import window, actor, colormap as cmap

        # Create the 3D display.
        r = window.Renderer()
        r.add(actor.line(streamlines, cmap.line_colors(streamlines)))
        window.record(r, out_path=fibers + '.png', size=(800, 800))
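

# A minimal usage sketch (hypothetical file names, not part of the original
# snippet): deterministic CSA-peaks tracking from a seed mask, writing
# "fibers.trk".
if __name__ == "__main__":
    tracking("dwi.nii.gz", "dwi.bvec", "dwi.bval",
             "wm_mask.nii.gz", "seed_mask.nii.gz", "fibers",
             proba=False, plot=False)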
Example No. 19
def test_probabilistic_odf_weighted_tracker():
    """This tests that the Probabalistic Direction Getter plays nice
    LocalTracking and produces reasonable streamlines in a simple example.
    """
    sphere = HemiSphere.from_sphere(unit_octahedron)

    # A simple image with three possible configurations, a vertical tract,
    # a horizontal tract and a crossing
    pmf_lookup = np.array([[0., 0., 1.], [1., 0., 0.], [0., 1., 0.],
                           [.6, .4, 0.]])
    simple_image = np.array([
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 3, 2, 2, 2, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ])

    simple_image = simple_image[..., None]
    pmf = pmf_lookup[simple_image]

    seeds = [np.array([1., 1., 0.])] * 30

    mask = (simple_image > 0).astype(float)
    sc = ThresholdStoppingCriterion(mask, .5)

    dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                               90,
                                               sphere,
                                               pmf_threshold=0.1)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    expected = [
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.], [2., 2., 0.],
                  [2., 3., 0.], [2., 4., 0.]]),
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.], [3., 1., 0.],
                  [4., 1., 0.]])
    ]

    def allclose(x, y):
        return x.shape == y.shape and np.allclose(x, y)

    path = [False, False]
    for sl in streamlines:
        if allclose(sl, expected[0]):
            path[0] = True
        elif allclose(sl, expected[1]):
            path[1] = True
        else:
            raise AssertionError()
    npt.assert_(all(path))

    # The first path is not possible if 90 degree turns are excluded
    dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                               80,
                                               sphere,
                                               pmf_threshold=0.1)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    for sl in streamlines:
        npt.assert_(np.allclose(sl, expected[1]))

    # The first path is not possible if pmf_threshold > 0.67
    # 0.4/0.6 < 2/3, multiplying the pmf should not change the ratio
    dg = ProbabilisticDirectionGetter.from_pmf(10 * pmf,
                                               90,
                                               sphere,
                                               pmf_threshold=0.67)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    for sl in streamlines:
        npt.assert_(np.allclose(sl, expected[1]))

    # Test non WM seed position
    seeds = [[0, 0, 0], [5, 5, 5]]
    streamlines = LocalTracking(dg,
                                sc,
                                seeds,
                                np.eye(4),
                                0.2,
                                max_cross=1,
                                return_all=True)
    streamlines = Streamlines(streamlines)
    npt.assert_(len(streamlines[0]) == 1)  # INVALIDPOINT
    npt.assert_(len(streamlines[1]) == 1)  # OUTSIDEIMAGE

    # Test that all points are within the image volume
    seeds = seeds_from_mask(np.ones(mask.shape), np.eye(4), density=2)
    streamline_generator = LocalTracking(dg,
                                         sc,
                                         seeds,
                                         np.eye(4),
                                         0.5,
                                         return_all=True)
    streamlines = Streamlines(streamline_generator)
    for s in streamlines:
        npt.assert_(np.all((s + 0.5).astype(int) >= 0))
        npt.assert_(np.all((s + 0.5).astype(int) < mask.shape))
    # Test that the number of streamlines returned with return_all=True
    # equals the number of seeds placed

    npt.assert_(np.array([len(streamlines) == len(seeds)]))

    # Test reproducibility
    tracking_1 = Streamlines(
        LocalTracking(dg, sc, seeds, np.eye(4), 0.5, random_seed=0)).data
    tracking_2 = Streamlines(
        LocalTracking(dg, sc, seeds, np.eye(4), 0.5, random_seed=0)).data
    npt.assert_equal(tracking_1, tracking_2)
Example No. 20
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count, roi_neighborhood_tol,
                 waymask, min_length, track_type, min_separation_angle, sphere,
                 tiss_class, tissues4d, cache_dir):

    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    recon_path_tmp_path = fname_presuffix(recon_path,
                                          suffix=f"_{step_curv_combinations}",
                                          newpath=cache_dir)
    copyfile(recon_path, recon_path_tmp_path, copy=True, use_hardlink=False)

    if waymask is not None:
        waymask_tmp_path = fname_presuffix(waymask,
                                           suffix=f"_{step_curv_combinations}",
                                           newpath=cache_dir)
        copyfile(waymask, waymask_tmp_path, copy=True, use_hardlink=False)
    else:
        waymask_tmp_path = None

    tissue_img = nib.load(tissues4d)

    # Order of volumes in the 4D tissue image:
    B0_mask = index_img(tissue_img, 0)
    atlas_img = index_img(tissue_img, 1)
    atlas_data_wm_gm_int = index_img(tissue_img, 2)
    t1w2dwi = index_img(tissue_img, 3)
    gm_in_dwi = index_img(tissue_img, 4)
    vent_csf_in_dwi = index_img(tissue_img, 5)
    wm_in_dwi = index_img(tissue_img, 6)

    tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                   wm_in_dwi, tiss_class, B0_mask)

    B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
    atlas_data = np.array(atlas_img.dataobj).astype("uint16")
    atlas_data_wm_gm_int_data = np.asarray(
        atlas_data_wm_gm_int.dataobj).astype("bool").astype("int16")

    # Build mask vector from atlas for later roi filtering
    parcels = []
    intensities = [i for i in np.unique(atlas_data) if i != 0]
    for roi_val in intensities:
        parcels.append(atlas_data == roi_val)

    del atlas_data

    parcel_vec = list(np.ones(len(parcels)).astype("bool"))

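    # Load the model reconstruction (spherical-harmonic coefficients) from the
    # temporary copy of the recon file.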
    with h5py.File(recon_path_tmp_path, 'r+') as hf:
        mod_fit = hf['reconstruction'][:].astype('float32')

    print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

    # Instantiate DirectionGetter
    if directget == "prob" or directget == "probabilistic":
        dg = ProbabilisticDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget == "clos" or directget == "closest":
        dg = ClosestPeakDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget == "det" or directget == "deterministic":
        maxcrossing = 1
        dg = DeterministicMaximumDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    else:
        raise ValueError("ERROR: No valid direction getter(s) specified.")

    print("%s%s" % ("Step: ", step_curv_combinations[0]))

    # Perform wm-gm interface seeding, using n_seeds at a time
    seeds = utils.random_seeds_from_mask(
        atlas_data_wm_gm_int_data > 0,
        seeds_count=n_seeds_per_iter,
        seed_count_per_voxel=False,
        affine=np.eye(4),
    )
    if len(seeds) == 0:
        print(
            UserWarning("No valid seed points found in wm-gm "
                        "interface..."))
        return None

    # print(seeds)

    # Perform tracking
    if track_type == "local":
        streamline_generator = LocalTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            maxlen=int(max_length),
            step_size=float(step_curv_combinations[0]),
            fixedstep=False,
            return_all=True,
        )
    elif track_type == "particle":
        streamline_generator = ParticleFilteringTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            step_size=float(step_curv_combinations[0]),
            maxlen=int(max_length),
            pft_back_tracking_dist=pft_back_tracking_dist,
            pft_front_tracking_dist=pft_front_tracking_dist,
            particle_count=particle_count,
            return_all=True,
        )
    else:
        try:
            raise ValueError("ERROR: No valid tracking method(s) specified.")
        except ValueError:
            import sys
            sys.exit(0)

    # Filter resulting streamlines by those that stay entirely
    # inside the brain
    try:
        roi_proximal_streamlines = utils.target(streamline_generator,
                                                np.eye(4),
                                                B0_mask_data,
                                                include=True)
    except BaseException:
        print('No streamlines found inside the brain! ' 'Check registrations.')
        return None

    # Filter resulting streamlines by roi-intersection
    # characteristics

    try:
        roi_proximal_streamlines = \
            nib.streamlines.array_sequence.ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="%s" % ("any" if waymask is not None else
                                 "both_end"),
                    tol=roi_neighborhood_tol,
                )
            )
        print("%s%s" % ("Filtering by: \nNode intersection: ",
                        len(roi_proximal_streamlines)))
    except BaseException:
        print('No streamlines found to connect any parcels! '
              'Check registrations.')
        return None

    try:
        roi_proximal_streamlines = nib.streamlines. \
            array_sequence.ArraySequence(
            [
                s for s in roi_proximal_streamlines
                if len(s) >= float(min_length)
            ]
        )
        print(f"Minimum fiber length >{min_length}mm: "
              f"{len(roi_proximal_streamlines)}")
    except BaseException:
        print('No streamlines remaining after minimal length criterion.')
        return None

    if waymask is not None and os.path.isfile(waymask_tmp_path):
        from nilearn.image import math_img
        mask = math_img("img > 0.0075", img=nib.load(waymask_tmp_path))
        waymask_data = np.asarray(mask.dataobj).astype("bool")
        try:
            roi_proximal_streamlines = roi_proximal_streamlines[utils.near_roi(
                roi_proximal_streamlines,
                np.eye(4),
                waymask_data,
                tol=roi_neighborhood_tol,
                mode="all")]
            print("%s%s" %
                  ("Waymask proximity: ", len(roi_proximal_streamlines)))
        except BaseException:
            print('No streamlines remaining in waymask\'s vicinity.')
            return None

    out_streams = [s.astype("float32") for s in roi_proximal_streamlines]

    del dg, seeds, roi_proximal_streamlines, streamline_generator, \
        atlas_data_wm_gm_int_data, mod_fit, B0_mask_data

    os.remove(recon_path_tmp_path)
    gc.collect()

    try:
        return ArraySequence(out_streams)
    except BaseException:
        return None
Example No. 21
def test_stop_conditions():
    """This tests that the Local Tracker behaves as expected for the
    following tissue types.
    """
    # StreamlineStatus.TRACKPOINT = 1
    # StreamlineStatus.ENDPOINT = 2
    # StreamlineStatus.INVALIDPOINT = 0
    tissue = np.array([[2, 1, 1, 2, 1], [2, 2, 1, 1, 2], [1, 1, 1, 1, 1],
                       [1, 1, 1, 2, 2], [0, 1, 1, 1, 2], [0, 1, 1, 0, 2],
                       [1, 0, 1, 1, 1], [2, 1, 2, 0, 0]])
    tissue = tissue[None]

    sphere = HemiSphere.from_sphere(unit_octahedron)
    pmf_lookup = np.array([[0., 0., 0.], [0., 0., 1.]])
    pmf = pmf_lookup[(tissue > 0).astype("int")]

    # Create a line of seeds along the y-axis
    x = np.array([0., 0, 0, 0, 0, 0, 0, 0])
    y = np.array([0., 1, 2, 3, 4, 5, 6, 7])
    z = np.array([1., 1, 1, 0, 1, 1, 1, 1])
    seeds = np.column_stack([x, y, z])

    # Set up tracking
    endpoint_mask = tissue == StreamlineStatus.ENDPOINT
    invalidpoint_mask = tissue == StreamlineStatus.INVALIDPOINT
    sc = ActStoppingCriterion(endpoint_mask, invalidpoint_mask)
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)

    # valid streamlines only
    streamlines_generator = LocalTracking(direction_getter=dg,
                                          stopping_criterion=sc,
                                          seeds=seeds,
                                          affine=np.eye(4),
                                          step_size=1.,
                                          return_all=False)
    streamlines_not_all = iter(streamlines_generator)

    # all streamlines
    streamlines_all_generator = LocalTracking(direction_getter=dg,
                                              stopping_criterion=sc,
                                              seeds=seeds,
                                              affine=np.eye(4),
                                              step_size=1.,
                                              return_all=True)
    streamlines_all = iter(streamlines_all_generator)

    # Check that the first streamline stops at 1 and 2 (ENDPOINT)
    y = 0
    sl = next(streamlines_not_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 2])
    npt.assert_equal(len(sl), 2)

    sl = next(streamlines_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 2])
    npt.assert_equal(len(sl), 2)

    # Check that the next streamline stops at 1 and 3 (ENDPOINT)
    y = 1
    sl = next(streamlines_not_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 3])
    npt.assert_equal(len(sl), 3)

    sl = next(streamlines_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 3])
    npt.assert_equal(len(sl), 3)

    # This streamline should be the same as above. This row does not have
    # ENDPOINTs, but the streamline should stop at the edge and not include
    # OUTSIDEIMAGE points.
    y = 2
    sl = next(streamlines_not_all)
    npt.assert_equal(sl[0], [0, y, 0])
    npt.assert_equal(sl[-1], [0, y, 4])
    npt.assert_equal(len(sl), 5)

    sl = next(streamlines_all)
    npt.assert_equal(sl[0], [0, y, 0])
    npt.assert_equal(sl[-1], [0, y, 4])
    npt.assert_equal(len(sl), 5)

    # If we seed on the edge, the first (or last) point in the streamline
    # should be the seed.
    y = 3
    sl = next(streamlines_not_all)
    npt.assert_equal(sl[0], seeds[y])

    sl = next(streamlines_all)
    npt.assert_equal(sl[0], seeds[y])

    # The last 3 seeds should not produce streamlines, since
    # INVALIDPOINT streamlines are rejected (return_all=False).
    npt.assert_equal(len(list(streamlines_not_all)), 0)

    # The last 3 seeds should produce invalid streamlines, since
    # INVALIDPOINT streamlines are kept (return_all=True).
    # The streamline stops at 1 (INVALIDPOINT) and 3 (ENDPOINT)
    y = 4
    sl = next(streamlines_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 3])
    npt.assert_equal(len(sl), 3)

    # The streamline stops at 0 (INVALIDPOINT) and 2 (INVALIDPOINT)
    y = 5
    sl = next(streamlines_all)
    npt.assert_equal(sl[0], [0, y, 1])
    npt.assert_equal(sl[-1], [0, y, 2])
    npt.assert_equal(len(sl), 2)

    # The streamline should contain only one point, the seed point,
    # because no valid initial direction was returned.
    y = 6
    sl = next(streamlines_all)
    npt.assert_equal(sl[0], seeds[y])
    npt.assert_equal(sl[-1], seeds[y])
    npt.assert_equal(len(sl), 1)

    # The streamline should contain only one point, the seed point,
    # because there is no valid neighboring voxel (ENDPOINT).
    y = 7
    sl = next(streamlines_all)
    npt.assert_equal(sl[0], seeds[y])
    npt.assert_equal(sl[-1], seeds[y])
    npt.assert_equal(len(sl), 1)
Example No. 22
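# NB: this snippet is truncated at the top; `csa_model`, `data`, `white_matter`,
# `labels` and `affine` are assumed to have been defined earlier in the
# original example (e.g. a CSA model fitted to a HARDI dataset).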
csa_peaks = peaks_from_model(csa_model,
                             data,
                             default_sphere,
                             relative_peak_threshold=.8,
                             min_separation_angle=45,
                             mask=white_matter)

stopping_criterion = ThresholdStoppingCriterion(csa_peaks.gfa, .25)

seed_mask = labels == 2
seeds = utils.seeds_from_mask(seed_mask, affine, density=[1, 1, 1])

# Initialization of LocalTracking. The computation happens in the next step.
streamlines = LocalTracking(csa_peaks,
                            stopping_criterion,
                            seeds,
                            affine,
                            step_size=2)

# Compute streamlines and store as a list.
streamlines = Streamlines(streamlines)
"""
We will create a streamline actor from the streamlines.
"""

streamlines_actor = actor.line(streamlines, cmap.line_colors(streamlines))
"""
Next, we create a surface actor from the corpus callosum seed ROI. We
provide the ROI data, the affine, the color in [R,G,B], and the opacity as
a decimal between zero and one. Here, we set the color as blue/green with
50% opacity.
"""
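
# The original example is truncated here. A minimal sketch of the surface
# actor described above, assuming `seed_mask` and `affine` from earlier in
# this snippet (cyan, i.e. blue/green, at 50% opacity):
seedroi_actor = actor.contour_from_roi(seed_mask, affine, [0., 1., 1.], 0.5)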
Example No. 23
def test_particle_filtering_tractography():
    """This tests that the ParticleFilteringTracking produces
    more streamlines connecting the gray matter than LocalTracking.
    """
    sphere = get_sphere('repulsion100')
    step_size = 0.2

    # Simple tissue masks
    simple_wm = np.array([[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0],
                          [0, 1, 1, 1, 0, 0], [0, 1, 1, 1, 0, 0],
                          [0, 0, 0, 0, 0, 0]])
    simple_wm = np.dstack([
        np.zeros(simple_wm.shape), simple_wm, simple_wm, simple_wm,
        np.zeros(simple_wm.shape)
    ])
    simple_gm = np.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                          [0, 1, 0, 0, 1, 0], [0, 0, 0, 0, 1, 0],
                          [0, 0, 0, 0, 0, 0]])
    simple_gm = np.dstack([
        np.zeros(simple_gm.shape), simple_gm, simple_gm, simple_gm,
        np.zeros(simple_gm.shape)
    ])
    simple_csf = np.ones(simple_wm.shape) - simple_wm - simple_gm

    sc = ActStoppingCriterion.from_pve(simple_wm, simple_gm, simple_csf)
    seeds = seeds_from_mask(simple_wm, np.eye(4), density=2)

    # Random pmf in every voxel
    shape_img = list(simple_wm.shape)
    shape_img.extend([sphere.vertices.shape[0]])
    np.random.seed(0)  # Random number generator initialization
    pmf = np.random.random(shape_img)

    # Test that PFT recovers at least as many streamlines as LocalTracking
    dg = ProbabilisticDirectionGetter.from_pmf(pmf, 60, sphere)
    local_streamlines_generator = LocalTracking(dg,
                                                sc,
                                                seeds,
                                                np.eye(4),
                                                step_size,
                                                max_cross=1,
                                                return_all=False)
    local_streamlines = Streamlines(local_streamlines_generator)

    pft_streamlines_generator = ParticleFilteringTracking(
        dg,
        sc,
        seeds,
        np.eye(4),
        step_size,
        max_cross=1,
        return_all=False,
        pft_back_tracking_dist=1,
        pft_front_tracking_dist=0.5)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    npt.assert_(np.array([len(pft_streamlines) > 0]))
    npt.assert_(np.array([len(pft_streamlines) >= len(local_streamlines)]))

    # Test that all points are equally spaced
    for l in [1, 2, 5, 10, 100]:
        pft_streamlines = ParticleFilteringTracking(dg,
                                                    sc,
                                                    seeds,
                                                    np.eye(4),
                                                    step_size,
                                                    max_cross=1,
                                                    return_all=True,
                                                    maxlen=l)
        for s in pft_streamlines:
            for i in range(len(s) - 1):
                npt.assert_almost_equal(np.linalg.norm(s[i] - s[i + 1]),
                                        step_size)
    # Test that all points are within the image volume
    seeds = seeds_from_mask(np.ones(simple_wm.shape), np.eye(4), density=1)
    pft_streamlines_generator = ParticleFilteringTracking(dg,
                                                          sc,
                                                          seeds,
                                                          np.eye(4),
                                                          step_size,
                                                          max_cross=1,
                                                          return_all=True)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    for s in pft_streamlines:
        npt.assert_(np.all((s + 0.5).astype(int) >= 0))
        npt.assert_(np.all((s + 0.5).astype(int) < simple_wm.shape))

    # Test that the number of streamlines returned with return_all=True
    # equals the number of seeds placed
    npt.assert_(np.array([len(pft_streamlines) == len(seeds)]))

    # Test non WM seed position
    seeds = [[0, 5, 4], [0, 0, 1], [50, 50, 50]]
    pft_streamlines_generator = ParticleFilteringTracking(dg,
                                                          sc,
                                                          seeds,
                                                          np.eye(4),
                                                          step_size,
                                                          max_cross=1,
                                                          return_all=True)
    pft_streamlines = Streamlines(pft_streamlines_generator)

    npt.assert_equal(len(pft_streamlines[0]), 3)  # INVALIDPOINT
    npt.assert_equal(len(pft_streamlines[1]), 3)  # ENDPOINT
    npt.assert_equal(len(pft_streamlines[2]), 1)  # OUTSIDEIMAGE

    # Test with wrong StoppingCriterion type
    sc_bin = BinaryStoppingCriterion(simple_wm)
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(dg, sc_bin, seeds,
                                                      np.eye(4), step_size))
    # Test with invalid back/front tracking distances
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg,
                                          sc,
                                          seeds,
                                          np.eye(4),
                                          step_size,
                                          pft_back_tracking_dist=0,
                                          pft_front_tracking_dist=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, pft_back_tracking_dist=-1))
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg,
                                          sc,
                                          seeds,
                                          np.eye(4),
                                          step_size,
                                          pft_back_tracking_dist=0,
                                          pft_front_tracking_dist=-2))

    # Test with invalid affine shape
    npt.assert_raises(
        ValueError,
        lambda: ParticleFilteringTracking(dg, sc, seeds, np.eye(3), step_size))

    # Test with invalid maxlen
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, maxlen=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, maxlen=-1))

    # Test with invalid particle count
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, particle_count=0))
    npt.assert_raises(
        ValueError, lambda: ParticleFilteringTracking(
            dg, sc, seeds, np.eye(4), step_size, particle_count=-1))

    # Test reproducibility
    tracking1 = Streamlines(
        ParticleFilteringTracking(dg,
                                  sc,
                                  seeds,
                                  np.eye(4),
                                  step_size,
                                  random_seed=0)).data
    tracking2 = Streamlines(
        ParticleFilteringTracking(dg,
                                  sc,
                                  seeds,
                                  np.eye(4),
                                  step_size,
                                  random_seed=0)).data
    npt.assert_equal(tracking1, tracking2)
Example No. 24
"""
These discrete fODFs can be used as a PMF in the `ProbabilisticDirectionGetter`
for sampling tracking directions. The PMF must be strictly non-negative;
RUMBA-SD already adheres to this constraint so no further manipulation of the
fODFs is necessary.
"""

from dipy.direction import ProbabilisticDirectionGetter
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_trk

prob_dg = ProbabilisticDirectionGetter.from_pmf(odf,
                                                max_angle=30.,
                                                sphere=sphere)
streamline_generator = LocalTracking(prob_dg,
                                     stopping_criterion,
                                     seeds,
                                     affine,
                                     step_size=.5)
streamlines = Streamlines(streamline_generator)

color = colormap.line_colors(streamlines)
streamlines_actor = actor.streamtube(list(
    transform_streamlines(streamlines, inv(t1_aff))),
                                     color,
                                     linewidth=0.1)

vol_actor = actor.slicer(t1_data)
vol_actor.display(x=40)
vol_actor2 = vol_actor.copy()
vol_actor2.display(z=35)
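
# The original example is truncated here. A minimal sketch of how the actors
# built above are typically assembled into a scene and rendered (the output
# file name is illustrative):
from dipy.viz import window

scene = window.Scene()
scene.add(vol_actor)
scene.add(vol_actor2)
scene.add(streamlines_actor)
window.record(scene, out_path='tractogram_sketch.png', size=(800, 800))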
Example No. 25
def test_maximum_deterministic_tracker():
    """This tests that the Maximum Deterministic Direction Getter plays nice
    LocalTracking and produces reasonable streamlines in a simple example.
    """
    sphere = HemiSphere.from_sphere(unit_octahedron)

    # A simple image with three possible configurations, a vertical tract,
    # a horizontal tract and a crossing
    pmf_lookup = np.array([[0., 0., 1.], [1., 0., 0.], [0., 1., 0.],
                           [.4, .6, 0.]])
    simple_image = np.array([
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 3, 2, 2, 2, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ])

    simple_image = simple_image[..., None]
    pmf = pmf_lookup[simple_image]

    seeds = [np.array([1., 1., 0.])] * 30

    mask = (simple_image > 0).astype(float)
    sc = ThresholdStoppingCriterion(mask, .5)

    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf,
                                                      90,
                                                      sphere,
                                                      pmf_threshold=0.1)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    expected = [
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.], [2., 2., 0.],
                  [2., 3., 0.], [2., 4., 0.]]),
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.], [3., 1., 0.],
                  [4., 1., 0.]]),
        np.array([[0., 1., 0.], [1., 1., 0.], [2., 1., 0.]])
    ]

    def allclose(x, y):
        return x.shape == y.shape and np.allclose(x, y)

    for sl in streamlines:
        if not allclose(sl, expected[0]):
            raise AssertionError()

    # The first path is not possible if 90 degree turns are excluded
    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf,
                                                      80,
                                                      sphere,
                                                      pmf_threshold=0.1)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    for sl in streamlines:
        npt.assert_(np.allclose(sl, expected[1]))

    # Both paths are impossible if 90 degree turns are excluded and
    # pmf_threshold is larger than 0.67. Streamlines should stop at
    # the crossing.
    # 0.4/0.6 < 2/3, multiplying the pmf should not change the ratio
    dg = DeterministicMaximumDirectionGetter.from_pmf(10 * pmf,
                                                      80,
                                                      sphere,
                                                      pmf_threshold=0.67)
    streamlines = LocalTracking(dg, sc, seeds, np.eye(4), 1.)

    for sl in streamlines:
        npt.assert_(np.allclose(sl, expected[2]))
Example No. 26
def track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, mod_fit, tiss_classifier, sphere, directget,
                   curv_thr_list, step_list, track_type, maxcrossing, roi_neighborhood_tol, min_length, waymask,
                   B0_mask, max_length=1000, n_seeds_per_iter=500, pft_back_tracking_dist=2, pft_front_tracking_dist=1,
                   particle_count=15, min_separation_angle=20):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a Nifti1Image in T1w-warped native
        diffusion space.
    mod_fit : obj
        Connectivity reconstruction model.
    tiss_classifier : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), clos (closest peak),
        and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain tractography.
    B0_mask : str
        File path to B0 brain mask.
    max_length : int
        Maximum number of steps to restrict tracking.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique ensemble combination.
        By default this is set to 500.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back tracking distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692

    """
    import gc
    import time
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter, ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)

    start = time.time()

    B0_mask_data = nib.load(B0_mask).get_fdata()

    if waymask:
        waymask_data = np.asarray(nib.load(waymask).dataobj).astype('bool')

    # Commence Ensemble Tractography
    parcel_vec = list(np.ones(len(parcels)).astype('bool'))
    streamlines = nib.streamlines.array_sequence.ArraySequence()

    circuit_ix = 0
    stream_counter = 0
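    # Iterate over every (curvature threshold, step size) combination,
    # accumulating filtered streamlines until target_samples is reached.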
    while int(stream_counter) < int(target_samples):
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg = ProbabilisticDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                               min_separation_angle=min_separation_angle)
            elif directget == 'clos':
                dg = ClosestPeakDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                             min_separation_angle=min_separation_angle)
            elif directget == 'det':
                dg = DeterministicMaximumDirectionGetter.from_shcoeff(mod_fit, max_angle=float(curv_thr), sphere=sphere,
                                                                      min_separation_angle=min_separation_angle)
            else:
                raise ValueError('ERROR: No valid direction getter(s) specified.')

            for step in step_list:
                print("%s%s" % ('Step: ', step))

                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(atlas_data_wm_gm_int > 0, seeds_count=n_seeds_per_iter,
                                                     seed_count_per_voxel=False, affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning('Warning: No valid seed points found in wm-gm interface...')

                # print(seeds)

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                         max_cross=int(maxcrossing), maxlen=int(max_length),
                                                         step_size=float(step), fixedstep=False, return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(dg, tiss_classifier, seeds, np.eye(4),
                                                                     max_cross=int(maxcrossing),
                                                                     step_size=float(step),
                                                                     maxlen=int(max_length),
                                                                     pft_back_tracking_dist=pft_back_tracking_dist,
                                                                     pft_front_tracking_dist=pft_front_tracking_dist,
                                                                     particle_count=particle_count,
                                                                     return_all=True)
                else:
                    raise ValueError('ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by those that stay entirely inside the brain
                roi_proximal_streamlines = utils.target(streamline_generator, np.eye(4), B0_mask_data,
                                                        include=True)

                # Filter resulting streamlines by roi-intersection characteristics
                roi_proximal_streamlines = Streamlines(select_by_rois(roi_proximal_streamlines, affine=np.eye(4),
                                                                      rois=parcels, include=parcel_vec,
                                                                      mode='both_end',
                                                                      tol=roi_neighborhood_tol))

                print("%s%s" % ('Filtering by: \nnode intersection: ', len(roi_proximal_streamlines)))

                if str(min_length) != '0':
                    roi_proximal_streamlines = nib.streamlines.array_sequence.ArraySequence([s for s in
                                                                                             roi_proximal_streamlines
                                                                                             if len(s) >=
                                                                                             float(min_length)])

                    print("%s%s" % ('Minimum length criterion: ', len(roi_proximal_streamlines)))

                if waymask:
                    roi_proximal_streamlines = roi_proximal_streamlines[utils.near_roi(roi_proximal_streamlines,
                                                                                       np.eye(4),
                                                                                       waymask_data,
                                                                                       tol=roi_neighborhood_tol,
                                                                                       mode='any')]
                    print("%s%s" % ('Waymask proximity: ', len(roi_proximal_streamlines)))

                out_streams = [s.astype('float32') for s in roi_proximal_streamlines]
                streamlines.extend(out_streams)
                stream_counter = stream_counter + len(out_streams)

                # Cleanup memory
                del seeds, roi_proximal_streamlines, streamline_generator, out_streams
                gc.collect()
            del dg

        circuit_ix = circuit_ix + 1
        print("%s%s%s%s%s%s" % ('Completed Hyperparameter Circuit: ', circuit_ix,
                                '\nCumulative Streamline Count: ', Fore.CYAN, stream_counter, "\n"))
        print(Style.RESET_ALL)

    print('Tracking Complete:\n', str(time.time() - start))

    return streamlines
Example No. 27
    def localTracking(self):
        if self.graddev is None:

            # No gradient deviation provided: use the identity Jacobian
            graddev = np.eye(3)

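            # Rotate every peak direction by the (here identity) Jacobian; the
            # result's axes are (xyz component, i, j, k, peak index).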
            new_peak_dirsp = np.einsum('ab,ijkvb->aijkv', graddev,
                                       self.peaks.peak_dirs)
            shape = new_peak_dirsp.shape
            new_peak_dirsp = new_peak_dirsp.reshape(3, -1)
            new_peak_dirs = copy.deepcopy(new_peak_dirsp)
            for i in range(0, new_peak_dirs.shape[-1]):
                norm = np.linalg.norm(new_peak_dirsp[:, i])
                if norm != 0:
                    new_peak_dirs[:, i] = new_peak_dirsp[:, i] / norm
            new_peak_dirs = new_peak_dirs.reshape(shape)
            new_peak_dirs = np.moveaxis(new_peak_dirs, 0, -1)
            new_peak_dirs = new_peak_dirs.reshape(
                [-1, self.peaks.peak_indices.shape[-1], 3])
            #update self.peaks.peak_indices
            peak_indices = np.zeros(self.peaks.peak_indices.shape)
            peak_indices = peak_indices.reshape(
                [-1, self.peaks.peak_indices.shape[-1]])

            for i in range(0, peak_indices.shape[0]):
                for k in range(0, self.peaks.peak_indices.shape[-1]):
                    peak_indices[i, k] = self.sphere.find_closest(
                        new_peak_dirs[i, k, :])

            self.peaks.peak_indices = peak_indices.reshape(
                self.peaks.peak_indices.shape)

            streamlines_generator = LocalTracking(self.peaks,
                                                  self.stopping_criterion,
                                                  self.seeds,
                                                  self.affine,
                                                  step_size=abs(
                                                      self.affine[0, 0] / 6))
            self.streamlines = Streamlines(streamlines_generator)

        else:
            shape = self.graddev.shape
            self.graddev = self.graddev.reshape(shape[0:3] + (3, 3), order='F')
            #self.graddev[:, :, :, :, 2] = 0
            #self.graddev[:, :, :, 2, :] = 0
            #self.graddev[:, :, :, 2, 2] = -1

            self.graddev = (self.graddev.reshape([-1, 3, 3]) + np.eye(3))
            self.graddev = self.graddev.reshape(shape[0:3] + (3, 3))

            #multiply by the jacobian
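            # Here the Jacobian (identity + gradient deviation) varies per
            # voxel, so each voxel's 3x3 matrix rotates that voxel's peaks.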
            new_peak_dirsp = np.einsum('ijkab,ijkvb->aijkv', self.graddev,
                                       self.peaks.peak_dirs)
            shape = new_peak_dirsp.shape
            new_peak_dirsp = new_peak_dirsp.reshape(3, -1)
            new_peak_dirs = copy.deepcopy(new_peak_dirsp)
            for i in range(0, new_peak_dirs.shape[-1]):
                norm = np.linalg.norm(new_peak_dirsp[:, i])
                if norm != 0:
                    new_peak_dirs[:, i] = new_peak_dirsp[:, i] / norm
            new_peak_dirs = new_peak_dirs.reshape(shape)
            new_peak_dirs = np.moveaxis(new_peak_dirs, 0, -1)
            new_peak_dirs = new_peak_dirs.reshape(
                [-1, self.peaks.peak_indices.shape[-1], 3])
            #update self.peaks.peak_indices
            peak_indices = np.zeros(self.peaks.peak_indices.shape)
            peak_indices = peak_indices.reshape(
                [-1, self.peaks.peak_indices.shape[-1]])

            for i in range(0, peak_indices.shape[0]):
                for k in range(0, self.peaks.peak_indices.shape[-1]):
                    peak_indices[i, k] = self.sphere.find_closest(
                        new_peak_dirs[i, k, :])

            self.peaks.peak_indices = peak_indices.reshape(
                self.peaks.peak_indices.shape)

            streamlines_generator = LocalTracking(self.peaks,
                                                  self.stopping_criterion,
                                                  self.seeds,
                                                  self.affine,
                                                  step_size=self.affine[0, 0] /
                                                  6)

            self.streamlines = Streamlines(streamlines_generator)
            self.NpointsPerLine = pointsPerLine(self.streamlines)