def test_track_ensemble(directget, target_samples):
    """
    Test for ensemble tractography functionality.

    Runs ``track.track_ensemble`` with local tracking on the bundled sub-003
    example data, using a CSA reconstruction and a binary ('bin') tissue
    classifier. ``directget`` and ``target_samples`` are supplied by the
    test parametrization.
    """
    from pynets.dmri import track
    from dipy.core.gradients import gradient_table
    from dipy.data import get_sphere

    examples = str(Path(__file__).parent / "examples")
    anat_dir = f"{examples}/003/anat"
    dmri_dir = f"{examples}/003/dmri"
    dwi_out_dir = f"{examples}/003/test_out/003/dwi"

    B0_mask = f"{anat_dir}/mean_B0_bet_mask_tmp.nii.gz"
    gm_in_dwi = f"{anat_dir}/t1w_gm_in_dwi.nii.gz"
    vent_csf_in_dwi = f"{anat_dir}/t1w_vent_csf_in_dwi.nii.gz"
    wm_in_dwi = f"{anat_dir}/t1w_wm_in_dwi.nii.gz"

    gtab = gradient_table(f"{dmri_dir}/sub-003_dwi.bval",
                          f"{dwi_out_dir}/bvecs_reor.bvec")
    dwi_file = f"{dwi_out_dir}/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    seeding_file = (f"{dmri_dir}/whole_brain_cluster_labels_PCA200_dwi_"
                    f"track_wmgm_int.nii.gz")
    labels_file = (f"{dmri_dir}/whole_brain_cluster_labels_PCA200_dwi_"
                   f"track.nii.gz")

    sphere = get_sphere('repulsion724')
    curvatures = [40, 30]
    steps = [0.1, 0.2, 0.3, 0.4, 0.5]

    # The full parcellation is used for ROI filtering; its wm-gm interface
    # subset is used for seeding.
    labels = np.array(nib.load(labels_file).dataobj).astype('uint16')
    seeding = np.asarray(nib.load(seeding_file).dataobj).astype('uint16')

    # One boolean mask per label, skipping the smallest unique value
    # (the background).
    parcels = [labels == label for label in np.unique(labels)[1:]]

    dwi_data = nib.load(dwi_file).get_fdata()
    model, _ = track.reconstruction('csa', gtab, dwi_data, wm_in_dwi)
    classifier = track.prep_tissues(B0_mask, gm_in_dwi, vent_csf_in_dwi,
                                    wm_in_dwi, 'bin', cmc_step_size=0.2)

    track.track_ensemble(
        target_samples, seeding, parcels, model, classifier, sphere,
        directget, curvatures, steps, 'local', 2, 6, 10, None, B0_mask,
        max_length=1000,
        n_seeds_per_iter=500,
        pft_back_tracking_dist=2,
        pft_front_tracking_dist=1,
        particle_count=15,
        min_separation_angle=20,
    )
def test_prep_tissues(tiss_class):
    """
    Test for prep_tissues functionality.

    Builds a tissue classifier from the sub-003 example dMRI masks, passed
    as file paths, and checks that one is returned.
    """
    from pynets.dmri import track

    examples_dir = str(Path(__file__).parent / "examples")
    mask_dir = f"{examples_dir}/003/dmri"

    classifier = track.prep_tissues(
        f"{mask_dir}/sub-003_b0_brain_mask.nii.gz",
        f"{mask_dir}/gm_mask_dmri.nii.gz",
        f"{mask_dir}/csf_mask_dmri.nii.gz",
        f"{mask_dir}/wm_mask_dmri.nii.gz",
        tiss_class,
        cmc_step_size=0.2,
    )
    assert classifier is not None
def test_prep_tissues(tiss_class):
    """
    Test for prep_tissues functionality.

    Loads the sub-003 example dMRI masks as nibabel images from the installed
    package data and checks that ``track.prep_tissues`` returns a classifier.
    """
    from pynets.dmri import track

    examples_dir = os.path.abspath(
        pkg_resources.resource_filename("pynets", "../data/examples"))
    mask_dir = f"{examples_dir}/003/dmri"

    def _load(name):
        # Load one mask image from the example dMRI directory.
        return nib.load(f"{mask_dir}/{name}")

    classifier = track.prep_tissues(
        _load("gm_mask_dmri.nii.gz"),  # the GM mask stands in for the T1w mask
        _load("gm_mask_dmri.nii.gz"),
        _load("csf_mask_dmri.nii.gz"),
        _load("wm_mask_dmri.nii.gz"),
        tiss_class,
        _load("sub-003_b0_brain_mask.nii.gz"),
        cmc_step_size=0.2,
    )
    assert classifier is not None
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count,
                 roi_neighborhood_tol, waymask, min_length, track_type,
                 min_separation_angle, sphere, tiss_class, tissues4d,
                 cache_dir, min_seeds=100):
    """
    Run one tracking iteration for a single (step size, curvature) pair.

    Private copies of the reconstruction file, the 4D tissue stack and
    (optionally) the waymask are written to ``cache_dir`` under a
    timestamp+uuid suffix. They are always deleted before this function
    returns, including on early returns and exceptions.

    Parameters
    ----------
    step_curv_combinations : sequence
        ``(step_size, curvature_threshold)`` for this iteration.
    recon_path : str
        Path to an HDF5 file whose ``'reconstruction'`` dataset holds the
        coefficients passed to the direction getter's ``from_shcoeff``.
    n_seeds_per_iter : int
        Number of random seeds drawn from the seeding mask.
    directget : str
        'probabilistic'/'prob', 'closestpeaks'/'cp' or
        'deterministic'/'det' (case-insensitive).
    maxcrossing : int
        ``max_cross`` passed to the tracker (forced to 1 for deterministic).
    max_length : int
        ``maxlen`` passed to the tracker.
    pft_back_tracking_dist, pft_front_tracking_dist, particle_count
        Particle-filtering parameters, used when ``track_type='particle'``.
    roi_neighborhood_tol : float
        Tolerance used for parcel filtering; half of it (rounded) is used for
        waymask proximity.
    waymask : str or None
        Optional path to a waymask image.
    min_length : int
        Streamlines with fewer than ``min_length`` points are discarded.
    track_type : str
        'local' or 'particle'.
    min_separation_angle : float
        Passed to the direction getter.
    sphere : Sphere
        Sphere object passed to the direction getter.
    tiss_class : str
        Tissue classification method passed to ``prep_tissues``.
    tissues4d : str
        Path to a 4D image whose volumes are, in order: B0 mask, atlas,
        seeding mask, T1w-in-DWI, GM, ventricular CSF, WM.
    cache_dir : str
        Directory in which the temporary copies are written.
    min_seeds : int
        Minimum number of seeds required to attempt tracking.

    Returns
    -------
    ArraySequence or None
        Filtered float32 streamlines, or None if seeding or any filtering
        stage leaves nothing.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not recognised.
    """
    import gc
    import os
    import uuid
    from time import strftime

    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    # A timestamp + uuid makes the cached copies unique per call.
    run_uuid = f"{strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4()}"
    tmp_suffix = (f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
                  f"{run_uuid}")
    tmp_files = []

    def _cache_copy(src):
        # Copy src into cache_dir and register the copy for cleanup.
        dst = fname_presuffix(src, suffix=tmp_suffix, newpath=cache_dir)
        tmp_files.append(dst)
        copyfile(src, dst, copy=True, use_hardlink=False)
        return dst

    try:
        recon_path_tmp_path = _cache_copy(recon_path)
        tissues4d_tmp_path = _cache_copy(tissues4d)
        waymask_tmp_path = (_cache_copy(waymask) if waymask is not None
                            else None)

        tissue_img = nib.load(tissues4d_tmp_path)
        # Volume order of the 4D tissue stack:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        seeding_mask = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
        seeding_mask = np.asarray(
            seeding_mask.dataobj).astype("bool").astype("int16")

        # The reconstruction is only read, so open it read-only.
        with h5py.File(recon_path_tmp_path, 'r') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        directget_key = directget.lower()
        if directget_key in ["probabilistic", "prob"]:
            dg_class = ProbabilisticDirectionGetter
        elif directget_key in ["closestpeaks", "cp"]:
            dg_class = ClosestPeakDirectionGetter
        elif directget_key in ["deterministic", "det"]:
            # Only a single crossing is allowed for deterministic tracking.
            maxcrossing = 1
            dg_class = DeterministicMaximumDirectionGetter
        else:
            raise ValueError("ERROR: No valid direction getter(s) specified.")
        dg = dg_class.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            seeding_mask > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) < min_seeds:
            print(UserWarning(
                f"<{min_seeds} valid seed points found in wm-gm "
                f"interface..."))
            return None

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg, tiss_classifier, seeds, np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
                random_seed=42)
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg, tiss_classifier, seeds, np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                pft_max_trial=20,
                particle_count=particle_count,
                return_all=True,
                random_seed=42)
        else:
            raise ValueError("ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        try:
            roi_proximal_streamlines = utils.target(
                streamline_generator, np.eye(4),
                B0_mask_data.astype('bool'), include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        # Release intermediates before building the per-parcel masks.
        del mod_fit, seeds, tiss_classifier, streamline_generator, \
            B0_mask_data, seeding_mask, dg
        for img in (B0_mask, atlas_img, t1w2dwi, gm_in_dwi,
                    vent_csf_in_dwi, wm_in_dwi, tissue_img):
            img.uncache()
        gc.collect()

        # Filter resulting streamlines by roi-intersection characteristics.
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")
        parcels = [atlas_data == roi_val
                   for roi_val in np.unique(atlas_data) if roi_val != 0]
        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        try:
            roi_proximal_streamlines = ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="any",
                    tol=roi_neighborhood_tol,
                )
            )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = ArraySequence(
                [s for s in roi_proximal_streamlines
                 if len(s) >= float(min_length)]
            )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            waymask_data = np.asarray(
                nib.load(waymask_tmp_path).dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                        mode="all")]
                print("%s%s" % ("Waymask proximity: ",
                                len(roi_proximal_streamlines)))
                del waymask_data
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        del parcels, atlas_data

        if len(roi_proximal_streamlines) > 0:
            return ArraySequence(
                [s.astype("float32") for s in roi_proximal_streamlines])
        return None
    finally:
        # Always drop the cached copies, also on early returns and errors.
        # Cleanup is best-effort: a failed delete must not mask the result.
        for tmp_path in tmp_files:
            if os.path.isfile(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
def run_track(B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class,
              labels_im_file_wm_gm_int, labels_im_file, target_samples,
              curv_thr_list, step_list, track_type, max_length, maxcrossing,
              directget, conn_model, gtab_file, dwi_file, network, node_size,
              dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune,
              atlas, uatlas, labels, coords, norm, binary, atlas_mni,
              min_length, fa_path, waymask, roi_neighborhood_tol=8,
              sphere='repulsion724'):
    """
    Run all ensemble tractography and filtering routines.

    Parameters
    ----------
    B0_mask : str
        File path to B0 brain mask.
    gm_in_dwi : str
        File path to grey-matter tissue segmentation Nifti1Image.
    vent_csf_in_dwi : str
        File path to ventricular CSF tissue segmentation Nifti1Image.
    wm_in_dwi : str
        File path to white-matter tissue segmentation Nifti1Image.
    tiss_class : str
        Tissue classification method.
    labels_im_file_wm_gm_int : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space, restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    target_samples : int
        Total number of streamline samples specified to generate streams.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    gtab_file : str
        File path to pickled DiPy gradient table object.
    dwi_file : str
        File path to diffusion weighted image.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming
        (e.g. 'Default') used to filter nodes in the study of brain
        subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected
        nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    min_length : int
        Minimum fiber length threshold in mm.
    fa_path : str
        File path to FA Nifti1Image.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain
        tractography.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Default is 8.
    sphere : str
        Name of the triangulated sphere, resolved with
        ``dipy.data.get_sphere``. Default is repulsion724. Options are:
        `symmetric362`, `symmetric642`, `symmetric724`, `repulsion724`,
        `repulsion100`, or `repulsion200`.

    Returns
    -------
    streams : str
        File path to save streamline array sequence in .trk format.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming
        (e.g. 'Default') used to filter nodes in the study of brain
        subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected
        nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    fa_path : str
        File path to FA Nifti1Image.
    dm_path : str
        File path to fiber density map Nifti1Image.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space (passed through unchanged).
    roi_neighborhood_tol : float
        The ROI neighborhood tolerance (passed through unchanged).
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.

    Raises
    ------
    ValueError
        If the atlas parcellation contains no non-zero voxels.
    """
    import gc
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets.core import utils
    from pynets.dmri.track import prep_tissues, reconstruction, \
        create_density_map, track_ensemble

    # Load atlas parcellation (and its wm-gm interface reduced version for
    # seeding) first, so an empty atlas fails before the costly model fit.
    atlas_data = nib.load(labels_im_file).get_fdata().astype('uint16')
    if not atlas_data.any():
        raise ValueError(
            'ERROR: No non-zero voxels found in atlas. Check any roi masks '
            'and/or wm-gm interface images '
            'to verify overlap with dwi-registered atlas.')
    atlas_data_wm_gm_int = nib.load(
        labels_im_file_wm_gm_int).get_fdata().astype('uint16')

    # Build mask vector from atlas for later roi filtering; the first unique
    # value (the smallest label) is skipped as background.
    parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)[1:]]

    # Load diffusion data
    dwi_img = nib.load(dwi_file)
    dwi_data = dwi_img.get_fdata()

    # Fit diffusion model
    mod_fit = reconstruction(conn_model, load_pickle(gtab_file), dwi_data,
                             B0_mask)

    # Report the tracking configuration.
    for caption, value in (('Target number of samples: ', target_samples),
                           ('Using curvature threshold(s): ', curv_thr_list),
                           ('Using step size(s): ', step_list),
                           ('Tracking type: ', track_type)):
        print("%s%s%s%s" % (Fore.GREEN, caption, Fore.BLUE, value))
        print(Style.RESET_ALL)
    directget_names = {
        'prob': 'Probabilistic',
        'boot': 'Bootstrapped',
        'closest': 'Closest Peak',
        'det': 'Deterministic Maximum',
    }
    if directget in directget_names:
        print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ',
                            Fore.BLUE, directget_names[directget]))
    print(Style.RESET_ALL)

    # Commence Ensemble Tractography
    streamlines = track_ensemble(
        dwi_data, target_samples, atlas_data_wm_gm_int, parcels, mod_fit,
        prep_tissues(B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi,
                     tiss_class),
        get_sphere(sphere), directget, curv_thr_list, step_list, track_type,
        maxcrossing, max_length, roi_neighborhood_tol, min_length, waymask)
    print('Tracking Complete')

    # Create streamline density map
    [streams, dir_path, dm_path] = create_density_map(
        dwi_img, utils.do_dir_path(atlas, dwi_file), streamlines, conn_model,
        target_samples, node_size, curv_thr_list, step_list, network, roi,
        directget, max_length)

    del streamlines, dwi_data, atlas_data_wm_gm_int, atlas_data, mod_fit, \
        parcels
    dwi_img.uncache()
    gc.collect()

    return (streams, track_type, target_samples, conn_model, dir_path,
            network, node_size, dens_thresh, ID, roi, min_span_tree,
            disp_filt, parc, prune, atlas, uatlas, labels, coords, norm,
            binary, atlas_mni, curv_thr_list, step_list, fa_path, dm_path,
            directget, labels_im_file, roi_neighborhood_tol, max_length)
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count,
                 roi_neighborhood_tol, waymask, min_length, track_type,
                 min_separation_angle, sphere, tiss_class, tissues4d,
                 cache_dir):
    """
    Run one tracking iteration for a single (step size, curvature) pair.

    Copies of the reconstruction file and (optionally) the waymask are made
    in ``cache_dir`` with the combination appended to the file name. Both
    copies are always deleted before returning, including on early returns
    and exceptions.

    Parameters
    ----------
    step_curv_combinations : sequence
        ``(step_size, curvature_threshold)`` for this iteration.
    recon_path : str
        Path to an HDF5 file whose ``'reconstruction'`` dataset holds the
        coefficients passed to the direction getter's ``from_shcoeff``.
    n_seeds_per_iter : int
        Number of random seeds drawn from the wm-gm interface.
    directget : str
        'prob'/'probabilistic', 'clos'/'closest' or 'det'/'deterministic'.
    maxcrossing : int
        ``max_cross`` passed to the tracker (forced to 1 for deterministic).
    max_length : int
        ``maxlen`` passed to the tracker.
    pft_back_tracking_dist, pft_front_tracking_dist, particle_count
        Particle-filtering parameters, used when ``track_type='particle'``.
    roi_neighborhood_tol : float
        Tolerance used for parcel and waymask filtering.
    waymask : str or None
        Optional path to a waymask image (thresholded at 0.0075).
    min_length : int
        Streamlines with fewer than ``min_length`` points are discarded.
    track_type : str
        'local' or 'particle'.
    min_separation_angle : float
        Passed to the direction getter.
    sphere : Sphere
        Sphere object passed to the direction getter.
    tiss_class : str
        Tissue classification method passed to ``prep_tissues``.
    tissues4d : str
        Path to a 4D image whose volumes are, in order: B0 mask, atlas,
        wm-gm interface, T1w-in-DWI, GM, ventricular CSF, WM.
    cache_dir : str
        Directory in which the temporary copies are written.

    Returns
    -------
    ArraySequence or None
        Filtered float32 streamlines, or None if seeding or any filtering
        stage fails.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not recognised. An invalid
        ``track_type`` previously terminated the process via
        ``sys.exit(0)``.
    """
    import gc
    import os

    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    tmp_files = []

    def _cache_copy(src):
        # Copy src into cache_dir and register the copy for cleanup.
        dst = fname_presuffix(src, suffix=f"_{step_curv_combinations}",
                              newpath=cache_dir)
        tmp_files.append(dst)
        copyfile(src, dst, copy=True, use_hardlink=False)
        return dst

    try:
        recon_path_tmp_path = _cache_copy(recon_path)
        waymask_tmp_path = (_cache_copy(waymask) if waymask is not None
                            else None)

        tissue_img = nib.load(tissues4d)
        # Volume order of the 4D tissue stack:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        atlas_data_wm_gm_int = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")
        atlas_data_wm_gm_int_data = np.asarray(
            atlas_data_wm_gm_int.dataobj).astype("bool").astype("int16")

        # Build mask vector from atlas for later roi filtering
        parcels = [atlas_data == roi_val
                   for roi_val in np.unique(atlas_data) if roi_val != 0]
        del atlas_data
        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        # The reconstruction is only read, so open it read-only.
        with h5py.File(recon_path_tmp_path, 'r') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        if directget in ("prob", "probabilistic"):
            dg_class = ProbabilisticDirectionGetter
        elif directget in ("clos", "closest"):
            dg_class = ClosestPeakDirectionGetter
        elif directget in ("det", "deterministic"):
            # Only a single crossing is allowed for deterministic tracking.
            maxcrossing = 1
            dg_class = DeterministicMaximumDirectionGetter
        else:
            raise ValueError("ERROR: No valid direction getter(s) specified.")
        dg = dg_class.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            atlas_data_wm_gm_int_data > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) == 0:
            print(UserWarning("No valid seed points found in wm-gm "
                              "interface..."))
            return None

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg, tiss_classifier, seeds, np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
            )
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg, tiss_classifier, seeds, np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                particle_count=particle_count,
                return_all=True,
            )
        else:
            raise ValueError("ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        try:
            roi_proximal_streamlines = utils.target(
                streamline_generator, np.eye(4), B0_mask_data, include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        # Filter resulting streamlines by roi-intersection characteristics.
        # Without a waymask, both endpoints must fall near a parcel.
        try:
            roi_proximal_streamlines = ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="any" if waymask is not None else "both_end",
                    tol=roi_neighborhood_tol,
                )
            )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = ArraySequence(
                [s for s in roi_proximal_streamlines
                 if len(s) >= float(min_length)]
            )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            from nilearn.image import math_img
            mask = math_img("img > 0.0075", img=nib.load(waymask_tmp_path))
            waymask_data = np.asarray(mask.dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=roi_neighborhood_tol,
                        mode="all")]
                print("%s%s" % ("Waymask proximity: ",
                                len(roi_proximal_streamlines)))
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        out_streams = [s.astype("float32") for s in roi_proximal_streamlines]

        del dg, seeds, roi_proximal_streamlines, streamline_generator, \
            atlas_data_wm_gm_int_data, mod_fit, B0_mask_data
        gc.collect()

        try:
            return ArraySequence(out_streams)
        except Exception:
            return None
    finally:
        # Always drop the cached copies, also on early returns and errors.
        # Cleanup is best-effort: a failed delete must not mask the result.
        for tmp_path in tmp_files:
            if os.path.isfile(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
def _run_interface(self, runtime):
    """
    Run ensemble tractography for this interface's inputs.

    Fits the diffusion model, tracks with ``track_ensemble``, builds a
    streamline density map, and stores the outputs plus the pass-through
    inputs in ``self._results``.

    Returns
    -------
    runtime
        The ``runtime`` argument, returned unchanged.

    Raises
    ------
    ValueError
        If ``self.inputs.directget`` is not one of 'prob', 'boot', 'closest'
        or 'det', or if the atlas contains no non-zero voxels. Both checks
        run before the diffusion data are loaded.
    """
    import gc
    import numpy as np
    import nibabel as nib
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets.core import utils
    from pynets.dmri.track import prep_tissues, reconstruction, \
        create_density_map, track_ensemble

    directget_names = {
        'prob': 'Probabilistic',
        'boot': 'Bootstrapped',
        'closest': 'Closest Peak',
        'det': 'Deterministic Maximum',
    }
    # Validate cheap inputs before any expensive I/O or model fitting.
    if self.inputs.directget not in directget_names:
        raise ValueError('Direction-getting type not recognized!')

    # Load atlas parcellation (and its wm-gm interface reduced version for
    # seeding).
    atlas_data = np.array(
        nib.load(self.inputs.labels_im_file).dataobj).astype('uint16')
    if not atlas_data.any():
        raise ValueError(
            'ERROR: No non-zero voxels found in atlas. Check any roi masks '
            'and/or wm-gm interface images '
            'to verify overlap with dwi-registered atlas.')
    atlas_data_wm_gm_int = np.asarray(
        nib.load(self.inputs.labels_im_file_wm_gm_int).dataobj
    ).astype('uint16')

    # Build mask vector from atlas for later roi filtering; the first unique
    # value (the smallest label) is skipped as background.
    parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)[1:]]

    # Load the diffusion data once; it feeds both model fitting and tracking.
    dwi_img = nib.load(self.inputs.dwi_file)
    dwi_data = np.asarray(dwi_img.dataobj)

    # Fit diffusion model
    mod_fit = reconstruction(self.inputs.conn_model,
                             load_pickle(self.inputs.gtab_file), dwi_data,
                             self.inputs.B0_mask)

    # Report the tracking configuration.
    for caption, value in (
            ('Target number of samples: ', self.inputs.target_samples),
            ('Using curvature threshold(s): ', self.inputs.curv_thr_list),
            ('Using step size(s): ', self.inputs.step_list),
            ('Tracking type: ', self.inputs.track_type)):
        print("%s%s%s%s" % (Fore.GREEN, caption, Fore.BLUE, value))
        print(Style.RESET_ALL)
    print("%s%s%s%s" % (Fore.GREEN, 'Direction-getting type: ', Fore.BLUE,
                        directget_names[self.inputs.directget]))
    print(Style.RESET_ALL)

    # Commence Ensemble Tractography
    streamlines = track_ensemble(
        dwi_data, self.inputs.target_samples, atlas_data_wm_gm_int, parcels,
        mod_fit,
        prep_tissues(self.inputs.t1w2dwi, self.inputs.gm_in_dwi,
                     self.inputs.vent_csf_in_dwi, self.inputs.wm_in_dwi,
                     self.inputs.tiss_class),
        get_sphere(self.inputs.sphere), self.inputs.directget,
        self.inputs.curv_thr_list, self.inputs.step_list,
        self.inputs.track_type, self.inputs.maxcrossing,
        int(self.inputs.roi_neighborhood_tol), self.inputs.min_length,
        self.inputs.waymask)

    # Create streamline density map
    [streams, dir_path, dm_path] = create_density_map(
        dwi_img,
        utils.do_dir_path(self.inputs.atlas, self.inputs.dwi_file),
        streamlines, self.inputs.conn_model, self.inputs.target_samples,
        self.inputs.node_size, self.inputs.curv_thr_list,
        self.inputs.step_list, self.inputs.network, self.inputs.roi,
        self.inputs.directget, self.inputs.min_length)

    self._results['streams'] = streams
    self._results['track_type'] = self.inputs.track_type
    self._results['target_samples'] = self.inputs.target_samples
    self._results['conn_model'] = self.inputs.conn_model
    self._results['dir_path'] = dir_path
    self._results['network'] = self.inputs.network
    self._results['node_size'] = self.inputs.node_size
    self._results['dens_thresh'] = self.inputs.dens_thresh
    self._results['ID'] = self.inputs.ID
    self._results['roi'] = self.inputs.roi
    self._results['min_span_tree'] = self.inputs.min_span_tree
    self._results['disp_filt'] = self.inputs.disp_filt
    self._results['parc'] = self.inputs.parc
    self._results['prune'] = self.inputs.prune
    self._results['atlas'] = self.inputs.atlas
    self._results['uatlas'] = self.inputs.uatlas
    self._results['labels'] = self.inputs.labels
    self._results['coords'] = self.inputs.coords
    self._results['norm'] = self.inputs.norm
    self._results['binary'] = self.inputs.binary
    self._results['atlas_mni'] = self.inputs.atlas_mni
    self._results['curv_thr_list'] = self.inputs.curv_thr_list
    self._results['step_list'] = self.inputs.step_list
    self._results['fa_path'] = self.inputs.fa_path
    self._results['dm_path'] = dm_path
    self._results['directget'] = self.inputs.directget
    self._results['labels_im_file'] = self.inputs.labels_im_file
    self._results['roi_neighborhood_tol'] = self.inputs.roi_neighborhood_tol
    self._results['min_length'] = self.inputs.min_length

    del streamlines, dwi_data, atlas_data_wm_gm_int, atlas_data, mod_fit, \
        parcels
    dwi_img.uncache()
    gc.collect()

    return runtime
def test_track_ensemble_particle():
    """
    Test for ensemble tractography functionality (particle-filtering tracking).

    Fits a CSD model on the bundled example subject, builds a CMC tissue
    classifier, runs ``track.track_ensemble`` with ``track_type='particle'``,
    and saves the resulting streamlines as a .trk file under
    ``examples/miscellaneous``.
    """
    from pynets.dmri import track
    from dipy.core.gradients import gradient_table
    from dipy.data import get_sphere
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram

    base_dir = str(Path(__file__).parent/"examples")
    B0_mask = f"{base_dir}/003/anat/mean_B0_bet_mask_tmp.nii.gz"
    gm_in_dwi = f"{base_dir}/003/anat/t1w_gm_in_dwi.nii.gz"
    vent_csf_in_dwi = f"{base_dir}/003/anat/t1w_vent_csf_in_dwi.nii.gz"
    wm_in_dwi = f"{base_dir}/003/anat/t1w_wm_in_dwi.nii.gz"
    dir_path = f"{base_dir}/003/dmri"
    bvals = f"{dir_path}/sub-003_dwi.bval"
    bvecs = f"{base_dir}/003/test_out/003/dwi/bvecs_reor.bvec"
    gtab = gradient_table(bvals, bvecs)
    dwi_file = f"{base_dir}/003/test_out/003/dwi/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    # Keep the path and the loaded array under separate names.
    labels_im_file_wm_gm_int = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track_wmgm_int.nii.gz"
    labels_im_file = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track.nii.gz"

    conn_model = 'csd'
    tiss_class = 'cmc'
    min_length = 10
    maxcrossing = 2
    roi_neighborhood_tol = 6
    waymask = None
    curv_thr_list = [40, 30]
    step_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    sphere = get_sphere('repulsion724')
    directget = 'prob'
    track_type = 'particle'
    target_samples = 1000

    # Load atlas parcellation (and its wm-gm interface reduced version for seeding)
    atlas_data = np.array(nib.load(labels_im_file).dataobj).astype('uint16')
    atlas_data_wm_gm_int = np.asarray(nib.load(labels_im_file_wm_gm_int).dataobj).astype('uint16')

    # Build one boolean mask per label for later roi filtering. np.unique
    # returns sorted values and [1:] drops the smallest one -- assumed to be
    # the background label 0 (TODO confirm for this atlas).
    parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)[1:]]

    dwi_img = nib.load(dwi_file)
    dwi_data = dwi_img.get_fdata()

    model, _ = track.reconstruction(conn_model, gtab, dwi_data, wm_in_dwi)

    # NOTE(review): gm_in_dwi is passed as both the first and second argument,
    # unlike test_track_ensemble, which passes B0_mask first. Confirm which
    # mask prep_tissues expects in the first position.
    tiss_classifier = track.prep_tissues(gm_in_dwi, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi,
                                         tiss_class, B0_mask, cmc_step_size=0.2)

    streamlines = track.track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, model,
                                       tiss_classifier, sphere, directget, curv_thr_list, step_list,
                                       track_type, maxcrossing, roi_neighborhood_tol, min_length,
                                       waymask, B0_mask, max_length=1000, n_seeds_per_iter=500,
                                       pft_back_tracking_dist=2, pft_front_tracking_dist=1,
                                       particle_count=15, min_separation_angle=20)

    # Derive the filename from the run parameters so it stays in sync with them.
    streams = (f"{base_dir}/miscellaneous/streamlines_model-{conn_model}_nodetype-parc_"
               f"samples-{target_samples}streams_tracktype-{track_type}_"
               f"directget-{directget}_minlength-{min_length}.trk")
    save_tractogram(StatefulTractogram(streamlines, reference=dwi_img, space=Space.VOXMM,
                                       origin=Origin.NIFTI),
                    streams, bbox_valid_check=False)

    assert Path(streams).is_file()
def run_track(nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class, labels_im_file_wm_gm_int,
              labels_im_file, target_samples, curv_thr_list, step_list, track_type, max_length, maxcrossing,
              directget, conn_model, gtab_file, dwi_file, network, node_size, dens_thresh, ID, roi,
              min_span_tree, disp_filt, parc, prune, atlas_select, uatlas_select, label_names, coords, norm,
              binary, atlas_mni, life_run, min_length, fa_path):
    """
    Run ensemble tractography on a diffusion image and filter the streamlines.

    Inputs are validated first: the direction getter and a non-empty atlas.
    Then a diffusion model is fitted with ``reconstruction``, a tissue
    classifier is built with ``prep_tissues``, ``track_ensemble`` is run, and
    the result is passed through ``filter_streamlines``.

    Parameters
    ----------
    nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi : str
        Image paths forwarded to ``prep_tissues``. ``wm_in_dwi`` is also
        passed to ``reconstruction``. Presumably masks already registered to
        DWI space (TODO confirm).
    tiss_class : str
        Tissue-classifier type forwarded to ``prep_tissues``.
    labels_im_file_wm_gm_int : str
        Path to the wm-gm interface label image; its data is passed to
        ``track_ensemble`` (used for seeding).
    labels_im_file : str
        Path to the atlas label image; one boolean parcel mask is built per
        label.
    target_samples, curv_thr_list, step_list, track_type, max_length, maxcrossing :
        Tracking parameters forwarded to ``track_ensemble``.
    directget : str
        Direction getter: one of 'prob', 'boot', 'closest', 'det'.
    conn_model : str
        Model name forwarded to ``reconstruction`` and ``filter_streamlines``.
    gtab_file : str
        Path to a pickled gradient table, loaded with ``dipy.io.load_pickle``.
    dwi_file : str
        DWI image path forwarded to ``reconstruction``, ``utils.do_dir_path``
        and ``filter_streamlines``.
    atlas_select : str
        Atlas name used by ``utils.do_dir_path`` to derive the output dir.
    life_run, min_length, node_size :
        Forwarded to ``filter_streamlines``.
    network, dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune,
    uatlas_select, label_names, coords, norm, binary, atlas_mni, fa_path :
        Not used here; returned unchanged.

    Returns
    -------
    tuple
        ``(streams, track_type, target_samples, conn_model, dir_path, network,
        node_size, dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune,
        atlas_select, uatlas_select, label_names, coords, norm, binary,
        atlas_mni, curv_thr_list, step_list, fa_path)``. ``streams`` and
        ``dir_path`` come from ``filter_streamlines``; every other element is
        the corresponding input, unchanged.

    Raises
    ------
    ValueError
        If ``directget`` is not recognized, or if the atlas contains no
        non-zero voxels.
    """
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets import utils
    from pynets.dmri.track import prep_tissues, reconstruction, filter_streamlines, track_ensemble

    # Console label for each supported direction getter. It doubles as the
    # whitelist, so an unknown value fails before any expensive work starts.
    # This matches the class-based tracking interface, which raises the same
    # error.
    directget_labels = {
        'prob': 'Probabilistic Direction...',
        'boot': 'Bootstrapped Direction...',
        'closest': 'Closest Peak Direction...',
        'det': 'Deterministic Maximum Direction...',
    }
    if directget not in directget_labels:
        raise ValueError('Direction-getting type not recognized!')

    # Load atlas parcellation and validate it before fitting the diffusion
    # model, so an empty atlas fails fast.
    atlas_data = nib.load(labels_im_file).get_fdata().astype('int')
    if not np.any(atlas_data):
        raise ValueError('ERROR: No non-zero voxels found in atlas. Check any roi masks and/or wm-gm '
                         'interface images to verify overlap with dwi-registered atlas.')

    # wm-gm interface reduced version of the atlas, used for seeding
    atlas_data_wm_gm_int = nib.load(labels_im_file_wm_gm_int).get_fdata().astype('int')

    # Build one boolean mask per label for later roi filtering. np.unique
    # returns sorted values and [1:] drops the smallest one -- assumed to be
    # the background label 0 (TODO confirm).
    parcels = [atlas_data == roi_val for roi_val in np.unique(atlas_data)[1:]]
    parcel_vec = np.ones(len(parcels))

    # Load gradient table. NOTE: load_pickle unpickles gtab_file, so the file
    # must come from a trusted source.
    gtab = load_pickle(gtab_file)

    # Fit diffusion model
    mod_fit = reconstruction(conn_model, gtab, dwi_file, wm_in_dwi)

    sphere = get_sphere('repulsion724')

    # Instantiate tissue classifier
    tiss_classifier = prep_tissues(nodif_B0_mask, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class)

    # Report tracking parameters (same console output as before)
    for label, value in (('Target number of samples: ', target_samples),
                         ('Using curvature threshold(s): ', curv_thr_list),
                         ('Using step size(s): ', step_list),
                         ('Tracking type: ', track_type)):
        print("%s%s%s%s" % (Fore.GREEN, label, Fore.BLUE, value))
        print(Style.RESET_ALL)
    print("%s%s%s" % ('Using ', Fore.MAGENTA, directget_labels[directget]))
    print(Style.RESET_ALL)

    # Commence Ensemble Tractography
    streamlines = track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, parcel_vec, mod_fit,
                                 tiss_classifier, sphere, directget, curv_thr_list, step_list, track_type,
                                 maxcrossing, max_length)
    print('Tracking Complete')

    # Perform streamline filtering routines
    dir_path = utils.do_dir_path(atlas_select, dwi_file)
    streams, dir_path = filter_streamlines(dwi_file, dir_path, gtab, streamlines, life_run, min_length,
                                           conn_model, target_samples, node_size, curv_thr_list, step_list)

    return (streams, track_type, target_samples, conn_model, dir_path, network, node_size, dens_thresh, ID,
            roi, min_span_tree, disp_filt, parc, prune, atlas_select, uatlas_select, label_names, coords,
            norm, binary, atlas_mni, curv_thr_list, step_list, fa_path)