def _filter_from_roi_trackvis_style(tracts_fname, roi_fname, out_tracts_fname,
                                    include, mode, tol):
    """Filter a TrackVis tractogram with an ROI, mimicking TrackVis.

    The streamlines are read in voxel space, processed in chunks of 100000,
    tested against the ROI mask and the survivors are written to
    ``out_tracts_fname`` with an updated ``n_count`` header field.

    Parameters
    ----------
    tracts_fname : str
        Path of the TrackVis file to read.
    roi_fname : str
        Path of the ROI image; its data array is used as the mask.
    out_tracts_fname : str
        Path of the TrackVis file to write.
    include : bool
        Forwarded to ``target``. Only used when ``mode == 'any_part'`` and
        ``tol == 0``; the ``near_roi`` code path does not receive it.
    mode : str
        ``'any_part'`` (translated to ``near_roi``'s ``'any'``) or any other
        mode string, which is handed to ``near_roi`` unchanged.
    tol : float
        Tolerance, considered to be in mm. It is rescaled by the smallest
        voxel dimension of the ROI image and raised to the voxel
        centre-to-corner distance when smaller than that.
    """
    tracts, hdr = nb.trackvis.read(tracts_fname, as_generator=True,
                                   points_space='voxel')
    img = nb.load(roi_fname)
    mask = img.get_data()
    aff = np.eye(4)

    # Interpretation from @mdesco code:
    # target is used for this case because we want to support
    # include=False.
    use_target = mode == 'any_part' and tol == 0

    if not use_target:
        # Tolerance is considered to be in mm. Rescale wrt voxel size.
        # This must happen exactly once: doing it inside the chunk loop
        # would divide the tolerance again for every chunk processed.
        tol = tol / min(img.get_header().get_zooms())

        # Check this to avoid Dipy's warning and scare users.
        corner_dist = dist_to_corner(aff)
        if tol < corner_dist:
            logging.debug(
                "Setting tolerance to minimal distance to corner")
            tol = corner_dist

        roi_mode = 'any' if mode == 'any_part' else mode

    filtered_strl = []

    for chunk in ichunk(tracts, 100000):
        logging.debug("Starting new loop")

        # Need to "unshift" because Dipy shifts them back to fit in voxel grid.
        # However, to mimic the way TrackVis filters, we need to "unshift" it.
        # TODO check no negative coordinates
        strls = [s[0] - 0.5 for s in chunk]

        if use_target:
            filtered_strl.extend(target(strls, mask, aff, include=include))
        else:
            strls_ind = near_roi(strls, mask, aff, tol=tol, mode=roi_mode)
            filtered_strl.extend(strls[i] for i in np.where(strls_ind)[0])

    logging.info("Original nb: {0}, kept: {1}".format(hdr['n_count'],
                                                      len(filtered_strl)))

    # Remove the unshift
    final_strls = [(s + 0.5, None, None) for s in filtered_strl]

    new_hdr = np.copy(hdr)
    new_hdr['n_count'] = len(filtered_strl)
    nb.trackvis.write(out_tracts_fname, final_strls, new_hdr,
                      points_space="voxel")
# Compute streamlines and store as a list.
streamlines = list(streamlines)

"""
In order to select only the fibers that enter into the LGN, another ROI is
created from a cube of size 10x10x10 voxels (``rad = 5`` voxels on either side
of the cube centre). The near_roi command is used to find the fibers that
traverse through this ROI.
"""

# Set a mask for the lateral geniculate nucleus (LGN)
mask_lgn = np.zeros(data.shape[:-1], 'bool')
rad = 5
mask_lgn[35-rad:35+rad, 42-rad:42+rad, 28-rad:28+rad] = True

# Select all the fibers that enter the LGN and discard all others
filtered_fibers2 = utils.near_roi(streamlines, mask_lgn, tol=1.8,
                                  affine=affine)

# near_roi yields one boolean per streamline; keep the flagged ones.
sfil = [sl for sl, keep in zip(streamlines, filtered_fibers2) if keep]
streamlines = list(sfil)

"""
Inspired by [Paulo_Eurographics]_, a lookup-table is created, containing
rotated versions of the fiber propagation kernel :math:`P_t`
[DuitsAndFranken_IJCV]_ rotated over a discrete set of orientations. See the
`Contextual enhancement example <http://nipy.org/dipy/examples_built/contextual_enhancement.html>`_
for more details regarding the kernel. In order to ensure rotationally
invariant processing, the discrete orientations are required to be equally
distributed over a sphere. By default, a sphere with 100 directions is used
obtained from electrostatic repulsion in DIPY.
"""
def test_near_roi():
    """Check ``near_roi`` for each selection mode, with and without an
    affine, for list input and for generator input."""
    fibers = [
        np.array([[0., 0., 0.9], [1.9, 0., 0.], [3, 2., 2.]]),
        np.array([[0.1, 0., 0], [0, 1., 1.], [0, 2., 2.]]),
        np.array([[2, 2, 2], [3, 3, 3]]),
    ]

    def expect(flags, sls, aff, roi_mask, **options):
        # One comparison: near_roi's answer against the expected flags.
        npt.assert_array_equal(near_roi(sls, aff, roi_mask, **options),
                               np.array(flags))

    roi = np.zeros((4, 4, 4), dtype=bool)
    roi[0, 0, 0] = True
    roi[1, 0, 0] = True

    expect([True, True, False], fibers, np.eye(4), roi, tol=1)
    expect([False, True, False], fibers, np.eye(4), roi)

    # If there is an affine, we need to use it: the streamlines are
    # translated by the same offset that the affine carries.
    shift = np.eye(4)
    shift[:, 3] = [-1, 100, -20, 1]
    moved = [sl + shift[:3, 3] for sl in fibers]

    expect([True, True, False], moved, shift, roi, tol=1)
    expect([False, True, False], moved, shift, roi, tol=None)

    # The 'all' mode:
    expect([False, False, False], moved, shift, roi, tol=None, mode='all')

    roi[0, 1, 1] = True
    roi[0, 2, 2] = True
    # The 'all' mode again, also testing that setting the tolerance to a
    # very small number gets overridden (warnings are silenced here):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        expect([False, True, False], moved, shift, roi, tol=0.1, mode='all')

    roi[2, 2, 2] = True
    roi[3, 3, 3] = True
    expect([False, True, True], moved, shift, roi, tol=None, mode='all')

    # Endpoints as selection criteria, starting from a fresh mask:
    roi = np.zeros((4, 4, 4), dtype=bool)
    roi[0, 1, 1] = True
    roi[3, 2, 2] = True
    expect([True, False, False], fibers, np.eye(4), roi,
           tol=0.87, mode="either_end")
    expect([False, False, False], fibers, np.eye(4), roi,
           tol=0.87, mode="both_end")

    roi[0, 0, 0] = True
    roi[0, 2, 2] = True
    expect([False, True, False], fibers, np.eye(4), roi, mode="both_end")

    # Generator input must give the same answer as list input:
    expect([False, True, False], (sl for sl in fibers), np.eye(4), roi,
           mode="both_end")
def test_near_roi():
    """Check ``near_roi`` (mask-first signature, keyword ``affine``) for each
    selection mode, for list input and for generator input."""
    sls = [np.array([[0., 0., 0.9], [1.9, 0., 0.], [3, 2., 2.]]),
           np.array([[0.1, 0., 0], [0, 1., 1.], [0, 2., 2.]]),
           np.array([[2, 2, 2], [3, 3, 3]])]

    xform = np.eye(4)
    roi_mask = np.zeros((4, 4, 4), dtype=bool)
    roi_mask[0, 0, 0] = True
    roi_mask[1, 0, 0] = True

    found = near_roi(sls, roi_mask, tol=1)
    assert_array_equal(found, np.array([True, True, False]))
    found = near_roi(sls, roi_mask)
    assert_array_equal(found, np.array([False, True, False]))

    # If there is an affine, we need to use it. The streamlines are moved
    # by the translation stored in the affine's last column.
    xform[:, 3] = [-1, 100, -20, 1]
    moved_sls = [sl + xform[:3, 3] for sl in sls]

    found = near_roi(moved_sls, roi_mask, affine=xform, tol=1)
    assert_array_equal(found, np.array([True, True, False]))
    found = near_roi(moved_sls, roi_mask, affine=xform, tol=None)
    assert_array_equal(found, np.array([False, True, False]))

    # The 'all' mode:
    found = near_roi(moved_sls, roi_mask, affine=xform, tol=None, mode='all')
    assert_array_equal(found, np.array([False, False, False]))

    roi_mask[0, 1, 1] = True
    roi_mask[0, 2, 2] = True
    # The 'all' mode again, also testing that setting the tolerance
    # to a very small number gets overridden:
    found = near_roi(moved_sls, roi_mask, affine=xform, tol=0.1, mode='all')
    assert_array_equal(found, np.array([False, True, False]))

    roi_mask[2, 2, 2] = True
    roi_mask[3, 3, 3] = True
    found = near_roi(moved_sls, roi_mask, affine=xform, tol=None, mode='all')
    assert_array_equal(found, np.array([False, True, True]))

    # Endpoints as selection criteria, starting from a fresh mask:
    roi_mask = np.zeros((4, 4, 4), dtype=bool)
    roi_mask[0, 1, 1] = True
    roi_mask[3, 2, 2] = True

    found = near_roi(sls, roi_mask, tol=0.87, mode="either_end")
    assert_array_equal(found, np.array([True, False, False]))
    found = near_roi(sls, roi_mask, tol=0.87, mode="both_end")
    assert_array_equal(found, np.array([False, False, False]))

    roi_mask[0, 0, 0] = True
    roi_mask[0, 2, 2] = True

    found = near_roi(sls, roi_mask, mode="both_end")
    assert_array_equal(found, np.array([False, True, False]))

    # Generator input must give the same answer as list input:
    found = near_roi((sl for sl in sls), roi_mask, mode="both_end")
    assert_array_equal(found, np.array([False, True, False]))
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, traversal, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, BACKEND='threading'):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Every (step size, curvature threshold) pair is tracked in parallel with
    ``run_tracking``; rounds are repeated until ``target_samples`` streamlines
    have been accumulated, too many rounds come back nearly empty, or the
    configured timeout expires.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to the atlas label image. It is intersected with the wm-gm
        interface image (and the waymask, if any) to build the seeding mask,
        which is saved as ``seeding_mask.nii.gz`` in the same directory.
    recon_path : str
        File path to diffusion reconstruction model (an HDF5 file with a
        ``reconstruction`` dataset).
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    traversal : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Half of this value (rounded, then
        truncated to an integer) is the tolerance of the waymask filter.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str or None
        File path to a tractography constraint mask in native diffusion
        space. When it names an existing file, that file is OVERWRITTEN with
        its thresholded (``> seeding_mask_thr``) version.
    B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi : str
        Tissue/brain-mask inputs stacked, together with ``labels_im_file``
        and the seeding mask, by ``save_3d_to_4d``.
        NOTE(review): presumably file paths -- confirm against the caller.
    tiss_class : str
        Tissue classification method.
    BACKEND : str
        joblib ``Parallel`` backend. Defaults to ``'threading'``.

    Returns
    -------
    streamlines : ArraySequence or None
        DiPy list/array-like object of streamline points from tractography,
        or None when tracking failed, timed out or produced no streamlines.

    Notes
    -----
    The following settings are read from the run configuration
    (``load_runconfig``) rather than passed in: ``omp_threads`` and, under
    ``tracking``: ``n_seeds_per_iter``, ``max_length``,
    ``pft_back_tracking_dist``, ``pft_front_tracking_dist``,
    ``particle_count``, ``min_separation_angle``, ``min_streams``,
    ``seeding_mask_thr`` and ``track_timeout``.

    Side effects: all warnings are silenced, ``MKL_NUM_THREADS`` and
    ``OPENBLAS_NUM_THREADS`` are set, and the working directory is changed
    to the temporary joblib cache directory.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import gc
    import itertools
    import os
    import tempfile
    import time
    import warnings
    import pickle5 as pickle
    from colorama import Fore, Style
    from dipy.tracking import utils
    from joblib import Parallel, delayed, Memory
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks
    from pynets.core.utils import load_runconfig
    from pynets.core.utils import save_3d_to_4d
    from pynets.dmri.track import run_tracking
    from pynets.dmri.utils import generate_sl

    warnings.filterwarnings("ignore")

    pickle.HIGHEST_PROTOCOL = 5
    joblib_dir = tempfile.mkdtemp()
    os.makedirs(joblib_dir, exist_ok=True)

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["omp_threads"][0]
    os.environ['MKL_NUM_THREADS'] = str(nthreads)
    os.environ['OPENBLAS_NUM_THREADS'] = str(nthreads)

    tracking_params = hardcoded_params['tracking']
    n_seeds_per_iter = tracking_params["n_seeds_per_iter"][0]
    max_length = tracking_params["max_length"][0]
    pft_back_tracking_dist = tracking_params["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = tracking_params["pft_front_tracking_dist"][0]
    particle_count = tracking_params["particle_count"][0]
    min_separation_angle = tracking_params["min_separation_angle"][0]
    min_streams = tracking_params["min_streams"][0]
    seeding_mask_thr = tracking_params["seeding_mask_thr"][0]
    timeout = tracking_params["track_timeout"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask: the intersection of the (optional) waymask,
    # the wm-gm interface and the atlas labels.
    seeding_mask = f"{os.path.dirname(labels_im_file)}/seeding_mask.nii.gz"
    mask_imgs = []
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img(f"img > {seeding_mask_thr}",
                               img=nib.load(waymask))
        # The thresholded waymask replaces the input file on disk.
        waymask_img.to_filename(waymask)
        mask_imgs.append(waymask_img)
    mask_imgs.extend([
        math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
        math_img("img > 0.001", img=nib.load(labels_im_file)),
    ])
    atlas_data_wm_gm_int_img = intersect_masks(
        mask_imgs,
        threshold=1,
        connected=False,
    )
    nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    tissues4d = save_3d_to_4d([
        B0_mask,
        labels_im_file,
        seeding_mask,
        t1w2dwi,
        gm_in_dwi,
        vent_csf_in_dwi,
        wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0

    memory = Memory(location=joblib_dir, mmap_mode='r+', verbose=0)
    os.chdir(f"{memory.location}/joblib")

    @memory.cache
    def load_recon_data(recon_path):
        import h5py
        # The context manager closes the file; no explicit close is needed.
        with h5py.File(recon_path, 'r') as hf:
            recon_data = hf['reconstruction'][:].astype('float32')
        return recon_data

    recon_shelved = load_recon_data.call_and_shelve(recon_path)

    @memory.cache
    def load_tissue_data(tissues4d):
        return nib.load(tissues4d)

    tissue_shelved = load_tissue_data.call_and_shelve(tissues4d)

    def _timed_out():
        """Report a timeout and return True once the time budget is spent."""
        if time.time() - start > timeout:
            print(f"\n{Fore.RED}Warning: Tractography timed "
                  f"out: {time.time() - start}")
            print(Style.RESET_ALL)
            return True
        return False

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50*float(len(all_combs)):
            with Parallel(n_jobs=nthreads, backend=BACKEND,
                          mmap_mode='r+', verbose=0) as parallel:
                out_streams = parallel(
                    delayed(run_tracking)(
                        i, recon_shelved, n_seeds_per_iter, traversal,
                        maxcrossing, max_length, pft_back_tracking_dist,
                        pft_front_tracking_dist, particle_count,
                        roi_neighborhood_tol, min_length,
                        track_type, min_separation_angle, sphere, tiss_class,
                        tissue_shelved) for i in all_combs)

                # Drop the jobs that returned nothing.
                out_streams = list(filter(None, out_streams))

                if not out_streams:
                    # Nothing was tracked this round. The timeout has to be
                    # honoured here as well, otherwise persistent failure
                    # would retry forever.
                    if _timed_out():
                        return None
                    continue

                # A single successful job is a valid result too (this is
                # the only possible outcome with one hyperparameter pair).
                out_streams = concatenate(out_streams, axis=0)

                if waymask is not None and os.path.isfile(waymask):
                    try:
                        out_streams = out_streams[utils.near_roi(
                            out_streams,
                            np.eye(4),
                            np.asarray(
                                nib.load(waymask).dataobj).astype("bool"),
                            tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                            mode="all")]
                    except Exception:
                        print(f"\n{Fore.RED}No streamlines generated in "
                              f"waymask vacinity\n")
                        print(Style.RESET_ALL)
                        return None

                if len(out_streams) < min_streams:
                    ix += 1
                    print(f"\n{Fore.YELLOW}Fewer than {min_streams} "
                          f"streamlines tracked "
                          f"on last iteration...\n")
                    print(Style.RESET_ALL)
                    if ix > 5:
                        print(f"\n{Fore.RED}No streamlines generated\n")
                        print(Style.RESET_ALL)
                        return None
                    continue
                else:
                    ix -= 1

                stream_counter += len(out_streams)
                all_streams.extend([generate_sl(i) for i in out_streams])
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)

                if _timed_out():
                    return None

    except RuntimeError as e:
        print(f"\n{Fore.RED}Error: Tracking failed due to:\n{e}\n")
        print(Style.RESET_ALL)
        return None
    finally:
        # Empty the joblib cache on every exit path, early returns included.
        memory.clear(warn=False)

    print("Tracking Complete: ", str(time.time() - start))

    # ``parallel`` only exists if the loop body ran and is already closed by
    # its ``with`` block, so only the combination list is released here.
    del all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ...')
        return ArraySequence([ArraySequence(i) for i in all_streams])

    print(f"\n{Fore.RED}No streamlines generated!")
    print(Style.RESET_ALL)
    return None
def run_tracking(step_curv_combinations, recon_path,
                 n_seeds_per_iter, directget, maxcrossing, max_length,
                 pft_back_tracking_dist, pft_front_tracking_dist,
                 particle_count, roi_neighborhood_tol, waymask, min_length,
                 track_type, min_separation_angle, sphere, tiss_class,
                 tissues4d, cache_dir, min_seeds=100):
    """
    Track and filter streamlines for one (step size, curvature) pair.

    The inputs are first copied into ``cache_dir`` under unique names; the
    copies are removed again on every exit path. Streamlines are seeded from
    the seeding mask, tracked, then filtered by the B0 brain mask, by
    proximity to any atlas parcel, by minimum length and, optionally, by
    proximity to the waymask.

    Parameters
    ----------
    step_curv_combinations : sequence
        Two values: element 0 is the step size, element 1 the maximum angle
        (curvature threshold) given to the direction getter.
    recon_path : str
        File path to an HDF5 file holding a ``reconstruction`` dataset.
    n_seeds_per_iter : int
        Number of seeds drawn from the seeding mask.
    directget : str
        Direction getter, case-insensitive: 'probabilistic'/'prob',
        'closestpeaks'/'cp' or 'deterministic'/'det'.
    maxcrossing : int
        ``max_cross`` for the tracker. Forced to 1 for deterministic
        tracking.
    max_length : int
        ``maxlen`` for the tracker.
    pft_back_tracking_dist : float
        Forwarded to ``ParticleFilteringTracking``.
    pft_front_tracking_dist : float
        Forwarded to ``ParticleFilteringTracking``.
    particle_count : int
        Forwarded to ``ParticleFilteringTracking``.
    roi_neighborhood_tol : float
        Tolerance for the parcel-intersection filter. Half of this value
        (rounded, then truncated to an integer) is the tolerance of the
        waymask filter.
    waymask : str or None
        File path to a constraint mask, or None to skip that filter.
    min_length : int
        Streamlines with fewer than this many points are dropped.
        NOTE(review): the comparison is against ``len(s)`` (point count)
        although the log message labels it as mm -- confirm the unit.
    track_type : str
        'local' or 'particle'.
    min_separation_angle : float
        Forwarded to the direction getter.
    sphere : obj
        DiPy sphere forwarded to the direction getter.
    tiss_class : str
        Tissue classification method, forwarded to ``prep_tissues``.
    tissues4d : str
        File path to a 4D image whose volumes are, in order: B0 mask, atlas
        labels, seeding mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
        wm_in_dwi.
    cache_dir : str
        Directory receiving the temporary copies of the inputs.
    min_seeds : int
        Minimum number of seeds required to attempt tracking.

    Returns
    -------
    ArraySequence or None
        The surviving streamlines as float32, or None when too few seeds
        were found, a filtering stage failed, or nothing survived.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not a supported value.
    """
    import gc
    import os
    import uuid
    from time import strftime
    import h5py
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from dipy.tracking import utils
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.tracking.streamline import select_by_rois
    from nibabel.streamlines.array_sequence import ArraySequence
    from nilearn.image import index_img
    from nipype.utils.filemanip import copyfile, fname_presuffix
    from pynets.dmri.track import prep_tissues

    run_uuid = f"{strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4()}"
    tmp_suffix = (f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
                  f"{run_uuid}")

    # Each temporary path is registered before the copy is attempted, so the
    # ``finally`` clause below can remove it whatever happens afterwards.
    tmp_files = []

    def _stage(src_path):
        """Copy ``src_path`` into ``cache_dir`` under a unique name."""
        tmp_path = fname_presuffix(src_path, suffix=tmp_suffix,
                                   newpath=cache_dir)
        tmp_files.append(tmp_path)
        copyfile(src_path, tmp_path, copy=True, use_hardlink=False)
        return tmp_path

    try:
        recon_path_tmp_path = _stage(recon_path)
        tissues4d_tmp_path = _stage(tissues4d)
        waymask_tmp_path = _stage(waymask) if waymask is not None else None

        tissue_img = nib.load(tissues4d_tmp_path)

        # Order:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        seeding_mask = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")

        seeding_mask = np.asarray(
            seeding_mask.dataobj).astype("bool").astype("int16")

        # The dataset is fully read into memory, so the file can be closed
        # as soon as the block ends.
        with h5py.File(recon_path_tmp_path, 'r+') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        if directget.lower() in ["probabilistic", "prob"]:
            dg = ProbabilisticDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget.lower() in ["closestpeaks", "cp"]:
            dg = ClosestPeakDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        elif directget.lower() in ["deterministic", "det"]:
            maxcrossing = 1
            dg = DeterministicMaximumDirectionGetter.from_shcoeff(
                mod_fit,
                max_angle=float(step_curv_combinations[1]),
                sphere=sphere,
                min_separation_angle=min_separation_angle,
            )
        else:
            raise ValueError(
                "ERROR: No valid direction getter(s) specified.")

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            seeding_mask > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) < min_seeds:
            print(
                UserWarning(
                    f"<{min_seeds} valid seed points found in wm-gm "
                    f"interface..."))
            return None

        # print(seeds)

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
                random_seed=42)
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                pft_max_trial=20,
                particle_count=particle_count,
                return_all=True,
                random_seed=42)
        else:
            raise ValueError(
                "ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        # NOTE(review): utils.target appears to be lazy, in which case
        # failures would only surface when the result is consumed by the
        # parcel filter below -- confirm.
        try:
            roi_proximal_streamlines = utils.target(
                streamline_generator, np.eye(4),
                B0_mask_data.astype('bool'), include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        del mod_fit, seeds, tiss_classifier, streamline_generator, \
            B0_mask_data, seeding_mask, dg

        B0_mask.uncache()
        atlas_img.uncache()
        t1w2dwi.uncache()
        gm_in_dwi.uncache()
        vent_csf_in_dwi.uncache()
        wm_in_dwi.uncache()
        tissue_img.uncache()
        gc.collect()

        # Filter resulting streamlines by roi-intersection
        # characteristics
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")

        # Build mask vector from atlas for later roi filtering: one boolean
        # mask per non-zero label.
        parcels = [atlas_data == roi_val
                   for roi_val in np.unique(atlas_data) if roi_val != 0]

        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    select_by_rois(
                        roi_proximal_streamlines,
                        affine=np.eye(4),
                        rois=parcels,
                        include=parcel_vec,
                        mode="any",
                        tol=roi_neighborhood_tol,
                    )
                )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = nib.streamlines. \
                array_sequence.ArraySequence(
                    [
                        s for s in roi_proximal_streamlines
                        if len(s) >= float(min_length)
                    ]
                )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            waymask_data = np.asarray(
                nib.load(waymask_tmp_path).dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                        mode="all")]
                print("%s%s" % ("Waymask proximity: ",
                                len(roi_proximal_streamlines)))
                del waymask_data
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        del parcels, atlas_data

        if len(roi_proximal_streamlines) > 0:
            return ArraySequence(
                [s.astype("float32") for s in roi_proximal_streamlines])
        return None
    finally:
        # Remove the temporary copies on success, early return and error
        # alike, without going through a shell.
        for tmp_path in tmp_files:
            if os.path.isfile(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Best-effort cleanup: a file that cannot be removed
                    # must not mask the tracking result.
                    pass
def track_ensemble(dwi_data, target_samples, atlas_data_wm_gm_int, parcels,
                   mod_fit, tiss_classifier, sphere, directget, curv_thr_list,
                   step_list, track_type, maxcrossing, max_length,
                   roi_neighborhood_tol, min_length, waymask,
                   n_seeds_per_iter=100, pft_back_tracking_dist=2,
                   pft_front_tracking_dist=1, particle_count=15):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Full circuits over every (curvature threshold, step size) pair are
    repeated until the running sample count reaches ``target_samples``.

    Parameters
    ----------
    dwi_data : array
        4D array of dwi data. Only used by the 'boot' direction getter.
    target_samples : int
        Total number of streamline samples specified to generate streams.
        NOTE(review): the running count advances by ``len(s)`` for every
        kept streamline, i.e. by its number of points -- confirm that this
        is the intended unit.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from
        Nifti1Image in T1w-warped native diffusion space, restricted to wm-gm
        interface.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a
        Nifti1Image in T1w-warped native diffusion space.
    mod_fit : obj
        Connectivity reconstruction model, handed to the direction getter.
    tiss_classifier : obj
        Tissue classifier handed to the tracking algorithm.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: 'det'
        (deterministic), 'clos' (closest peak), 'boot' (bootstrapped), and
        'prob' (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain
        tractography. A falsy value disables the waymask filter.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 100.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        the back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    Raises
    ------
    ValueError
        If ``directget`` or ``track_type`` is not supported, or if tracking
        is requested with an empty ``curv_thr_list`` or ``step_list``.
    RuntimeWarning
        If no seed point could be drawn from the wm-gm interface.
    """
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import ProbabilisticDirectionGetter, \
        BootDirectionGetter, ClosestPeakDirectionGetter, \
        DeterministicMaximumDirectionGetter

    # Without at least one curvature threshold and one step size no
    # streamline can ever be produced: the loop below would spin forever
    # (no thresholds) or fail on unbound names during cleanup (no steps).
    if int(target_samples) > 0 and \
            (len(curv_thr_list) == 0 or len(step_list) == 0):
        raise ValueError(
            'ERROR: curv_thr_list and step_list must each contain at least '
            'one value.')

    if waymask:
        waymask_data = nib.load(waymask).get_fdata().astype('bool')

    # Commence Ensemble Tractography
    parcel_vec = list(np.ones(len(parcels)).astype('bool'))
    streamlines = nib.streamlines.array_sequence.ArraySequence()
    circuit_ix = 0
    stream_counter = 0
    while int(stream_counter) < int(target_samples):
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            elif directget == 'boot':
                dg = BootDirectionGetter.from_data(
                    dwi_data, mod_fit, max_angle=float(curv_thr),
                    sphere=sphere)
            elif directget == 'clos':
                dg = ClosestPeakDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            elif directget == 'det':
                dg = DeterministicMaximumDirectionGetter.from_shcoeff(
                    mod_fit, max_angle=float(curv_thr), sphere=sphere)
            else:
                raise ValueError(
                    'ERROR: No valid direction getter(s) specified.')

            for step in step_list:
                print("%s%s" % ('Step: ', step))

                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(
                    atlas_data_wm_gm_int > 0,
                    seeds_count=n_seeds_per_iter,
                    seed_count_per_voxel=False,
                    affine=np.eye(4))
                if len(seeds) == 0:
                    raise RuntimeWarning(
                        'Warning: No valid seed points found in wm-gm interface...'
                    )

                print(seeds)

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        maxlen=int(max_length),
                        step_size=float(step),
                        return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(
                        dg,
                        tiss_classifier,
                        seeds,
                        np.eye(4),
                        max_cross=int(maxcrossing),
                        step_size=float(step),
                        maxlen=int(max_length),
                        pft_back_tracking_dist=pft_back_tracking_dist,
                        pft_front_tracking_dist=pft_front_tracking_dist,
                        particle_count=particle_count,
                        return_all=True)
                else:
                    raise ValueError(
                        'ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by roi-intersection
                # characteristics
                roi_proximal_streamlines = Streamlines(
                    select_by_rois(streamline_generator,
                                   affine=np.eye(4),
                                   rois=parcels,
                                   include=parcel_vec,
                                   mode='any',
                                   tol=roi_neighborhood_tol))

                print("%s%s" %
                      ('Qualifying Streamlines by node intersection: ',
                       len(roi_proximal_streamlines)))

                roi_proximal_streamlines = \
                    nib.streamlines.array_sequence.ArraySequence([
                        s for s in roi_proximal_streamlines
                        if len(s) > float(min_length)
                    ])

                print("%s%s" %
                      ('Qualifying Streamlines by minimum length criterion: ',
                       len(roi_proximal_streamlines)))

                if waymask:
                    roi_proximal_streamlines = roi_proximal_streamlines[
                        utils.near_roi(roi_proximal_streamlines,
                                       np.eye(4),
                                       waymask_data,
                                       tol=roi_neighborhood_tol,
                                       mode='any')]
                    print("%s%s" %
                          ('Qualifying Streamlines by waymask proximity: ',
                           len(roi_proximal_streamlines)))

                # Accumulate until the target is met. The break only ends
                # this batch; the remaining pairs of the circuit still run.
                for s in roi_proximal_streamlines:
                    stream_counter = stream_counter + len(s)
                    streamlines.append(s)
                    if int(stream_counter) >= int(target_samples):
                        break

            # Cleanup memory
            del seeds, roi_proximal_streamlines, streamline_generator
            del dg

        circuit_ix = circuit_ix + 1
        print(
            "%s%s%s%s%s" %
            ('Completed hyperparameter circuit: ', circuit_ix,
             '...\nCumulative Streamline Count: ', Fore.CYAN, stream_counter))
        print(Style.RESET_ALL)

    print('\n')

    return streamlines
0, len(subjectSpaceTractCoords))] # obtain that data array as bool sphereNifti = WMA_pyFuncs.createSphere(testRadius, testCentroid, testT1) # add that and a True to the list vector for each roisData.append(sphereNifti.get_fdata().astype(bool)) roisNifti.append(sphereNifti) # randomly select include or exclude include.append(bool(random.getrandbits(1))) operations.append('any') # start timing t1_start = time.process_time() # specify segmentation dipySegmented1 = ut.near_roi(testTractogram.streamlines, testT1.affine, roisData[0], mode='any') dipySegmented2 = ut.near_roi(testTractogram.streamlines, testT1.affine, roisData[1], mode='any') #now we have to manually match the implicit logic of the wma Seg function #both true if np.all(include): netDipySegmented = np.logical_and(dipySegmented1, dipySegmented2) # 1 true 2 false elif not include[0]: netDipySegmented = np.logical_and(np.logical_not(dipySegmented1), dipySegmented2) #1 false and 2 True elif not include[1]:
def track_ensemble(target_samples, atlas_data_wm_gm_int, parcels, mod_fit,
                   tiss_classifier, sphere, directget, curv_thr_list,
                   step_list, track_type, maxcrossing, roi_neighborhood_tol,
                   min_length, waymask, B0_mask, max_length=1000,
                   n_seeds_per_iter=500, pft_back_tracking_dist=2,
                   pft_front_tracking_dist=1, particle_count=15,
                   min_separation_angle=20):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Tracking is repeated over every (curvature threshold, step size)
    combination -- one full sweep is a "circuit" -- and circuits are repeated
    until at least `target_samples` streamlines have been accumulated. The
    check is made once per circuit, so the result may exceed
    `target_samples`.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from
        Nifti1Image in T1w-warped native diffusion space, restricted to wm-gm
        interface. Voxels > 0 are used as the seeding mask.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a
        Nifti1Image in T1w-warped native diffusion space.
    mod_fit : obj
        Connectivity reconstruction model fit, passed as the first argument
        to the direction getter's `from_shcoeff` constructor.
    tiss_classifier : obj
        Tissue classifier handed to the DiPy tracking class.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), clos (closest peak), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking. Must not be empty.
    step_list : list
        List of float step-sizes used to perform ensemble tracking. Must not
        be empty.
    track_type : str
        Tracking algorithm used ('local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum streamline length threshold. The value is compared against
        the number of points in each streamline; a value whose string form is
        '0' disables the filter.
        NOTE(review): earlier documentation described this as mm -- confirm
        the intended units.
    waymask : str
        Path to a Nifti1Image in native diffusion space to constrain
        tractography. A falsy value disables waymask filtering.
    B0_mask : str
        File path to B0 brain mask.
    max_length : int
        Maximum number of steps to restrict tracking.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 500.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        the back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    Raises
    ------
    ValueError
        If `directget` or `track_type` is not recognized, or if
        `curv_thr_list` / `step_list` is empty while streamlines are still
        required.
    RuntimeWarning
        If no seed points can be drawn from the wm-gm interface mask.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import gc
    import time
    from colorama import Fore, Style
    from dipy.tracking import utils
    from dipy.tracking.streamline import Streamlines, select_by_rois
    from dipy.tracking.local_tracking import (LocalTracking,
                                              ParticleFilteringTracking)
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)

    start = time.time()

    # With an empty hyperparameter list no tracking pass can ever run, so the
    # accumulation loop below would spin forever. Fail loudly instead.
    if int(target_samples) > 0 and (len(curv_thr_list) == 0 or
                                    len(step_list) == 0):
        raise ValueError('ERROR: curv_thr_list and step_list must each '
                         'contain at least one value.')

    # Tracking is carried out with an identity affine throughout.
    identity_affine = np.eye(4)

    # Only used as a mask, so load it as boolean rather than float64.
    B0_mask_data = np.asarray(nib.load(B0_mask).dataobj).astype('bool')

    if waymask:
        waymask_data = np.asarray(nib.load(waymask).dataobj).astype('bool')

    # Commence Ensemble Tractography
    parcel_vec = list(np.ones(len(parcels)).astype('bool'))
    streamlines = nib.streamlines.array_sequence.ArraySequence()

    circuit_ix = 0
    stream_counter = 0
    while int(stream_counter) < int(target_samples):
        for curv_thr in curv_thr_list:
            print("%s%s" % ('Curvature: ', curv_thr))

            # Instantiate DirectionGetter
            if directget == 'prob':
                dg_class = ProbabilisticDirectionGetter
            elif directget == 'clos':
                dg_class = ClosestPeakDirectionGetter
            elif directget == 'det':
                dg_class = DeterministicMaximumDirectionGetter
            else:
                raise ValueError(
                    'ERROR: No valid direction getter(s) specified.')
            dg = dg_class.from_shcoeff(
                mod_fit, max_angle=float(curv_thr), sphere=sphere,
                min_separation_angle=min_separation_angle)

            for step in step_list:
                print("%s%s" % ('Step: ', step))

                # Perform wm-gm interface seeding, using n_seeds at a time
                seeds = utils.random_seeds_from_mask(
                    atlas_data_wm_gm_int > 0,
                    seeds_count=n_seeds_per_iter,
                    seed_count_per_voxel=False,
                    affine=identity_affine)
                if len(seeds) == 0:
                    raise RuntimeWarning(
                        'Warning: No valid seed points found in wm-gm '
                        'interface...')

                # Perform tracking
                if track_type == 'local':
                    streamline_generator = LocalTracking(
                        dg, tiss_classifier, seeds, identity_affine,
                        max_cross=int(maxcrossing),
                        maxlen=int(max_length),
                        step_size=float(step),
                        fixedstep=False,
                        return_all=True)
                elif track_type == 'particle':
                    streamline_generator = ParticleFilteringTracking(
                        dg, tiss_classifier, seeds, identity_affine,
                        max_cross=int(maxcrossing),
                        step_size=float(step),
                        maxlen=int(max_length),
                        pft_back_tracking_dist=pft_back_tracking_dist,
                        pft_front_tracking_dist=pft_front_tracking_dist,
                        particle_count=particle_count,
                        return_all=True)
                else:
                    raise ValueError(
                        'ERROR: No valid tracking method(s) specified.')

                # Filter resulting streamlines by those that stay entirely
                # inside the brain
                roi_proximal_streamlines = utils.target(
                    streamline_generator, identity_affine, B0_mask_data,
                    include=True)

                # Filter resulting streamlines by roi-intersection
                # characteristics
                roi_proximal_streamlines = Streamlines(
                    select_by_rois(roi_proximal_streamlines,
                                   affine=identity_affine,
                                   rois=parcels,
                                   include=parcel_vec,
                                   mode='both_end',
                                   tol=roi_neighborhood_tol))

                print("%s%s" % ('Filtering by: \nnode intersection: ',
                                len(roi_proximal_streamlines)))

                if str(min_length) != '0':
                    # Threshold is applied to the number of points.
                    roi_proximal_streamlines = \
                        nib.streamlines.array_sequence.ArraySequence(
                            [s for s in roi_proximal_streamlines
                             if len(s) >= float(min_length)])

                    print("%s%s" % ('Minimum length criterion: ',
                                    len(roi_proximal_streamlines)))

                if waymask:
                    roi_proximal_streamlines = roi_proximal_streamlines[
                        utils.near_roi(roi_proximal_streamlines,
                                       identity_affine,
                                       waymask_data,
                                       tol=roi_neighborhood_tol,
                                       mode='any')]
                    print("%s%s" % ('Waymask proximity: ',
                                    len(roi_proximal_streamlines)))

                out_streams = [s.astype('float32')
                               for s in roi_proximal_streamlines]
                streamlines.extend(out_streams)
                stream_counter = stream_counter + len(out_streams)

                # Cleanup memory
                del seeds, roi_proximal_streamlines, streamline_generator, \
                    out_streams
                gc.collect()
            del dg

        circuit_ix = circuit_ix + 1
        print("%s%s%s%s%s%s" % ('Completed Hyperparameter Circuit: ',
                                circuit_ix,
                                '\nCumulative Streamline Count: ',
                                Fore.CYAN, stream_counter, "\n"))
        print(Style.RESET_ALL)

    print('Tracking Complete:\n', str(time.time() - start))

    return streamlines
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count,
                 roi_neighborhood_tol, waymask, min_length, track_type,
                 min_separation_angle, sphere, tiss_class, tissues4d,
                 cache_dir):
    """
    Run one tractography pass for a single (step size, curvature) combination.

    The reconstruction file (and the waymask, if given) is first copied into
    `cache_dir` under a name suffixed with `step_curv_combinations`, and the
    pass works from those copies. Streamlines are seeded from the wm-gm
    interface, tracked, and then filtered in turn by brain mask, parcel
    intersection, minimum length and (optionally) waymask proximity. The
    temporary reconstruction copy is removed on every exit path.

    Parameters
    ----------
    step_curv_combinations : sequence
        Indexable pair whose element 0 is the step size and element 1 is the
        curvature threshold (maximum angle) for this pass.
    recon_path : str
        File path to an HDF5 file holding a 'reconstruction' dataset.
    n_seeds_per_iter : int
        Number of seeds drawn from the wm-gm interface mask.
    directget : str
        Direction getter. Options are: prob / probabilistic, clos / closest,
        det / deterministic.
    maxcrossing : int
        Maximum number of crossing directions per voxel. Forced to 1 when a
        deterministic direction getter is selected.
    max_length : int
        Maximum number of steps to restrict tracking (passed as `maxlen`).
    pft_back_tracking_dist : float
        Back-tracking distance forwarded to ParticleFilteringTracking.
    pft_front_tracking_dist : float
        Front-tracking distance forwarded to ParticleFilteringTracking.
    particle_count : int
        Number of particles forwarded to ParticleFilteringTracking.
    roi_neighborhood_tol : float
        Tolerance forwarded to the ROI and waymask proximity filters.
    waymask : str or None
        File path to a waymask image, or None to skip waymask filtering.
        When given, the waymask is thresholded at > 0.0075 and parcel
        filtering uses mode 'any' instead of 'both_end'.
    min_length : int
        Minimum streamline length threshold, compared against the number of
        points in each streamline.
        NOTE(review): the log message reports this value as mm -- confirm
        the intended units.
    track_type : str
        Tracking algorithm used ('local' or 'particle').
    min_separation_angle : float
        Minimum separation angle forwarded to the direction getter.
    sphere : obj
        DiPy sphere object forwarded to the direction getter.
    tiss_class : str
        Tissue classification method forwarded to `prep_tissues`.
    tissues4d : str
        File path to a 4D image whose volumes are read in this order:
        B0 mask, atlas, wm-gm interface, T1w-in-dwi, GM, ventricular CSF, WM.
    cache_dir : str
        Directory that receives the temporary file copies.

    Returns
    -------
    streamlines : ArraySequence or None
        Filtered float32 streamlines, or None if no seeds were found or any
        filtering stage failed.

    Raises
    ------
    ValueError
        If `directget` is not recognized.
    SystemExit
        If `track_type` is not recognized (exits with a non-zero status).
    """
    import gc
    import os
    import sys
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    recon_path_tmp_path = fname_presuffix(
        recon_path, suffix=f"_{step_curv_combinations}", newpath=cache_dir)
    copyfile(recon_path, recon_path_tmp_path, copy=True, use_hardlink=False)

    try:
        if waymask is not None:
            waymask_tmp_path = fname_presuffix(
                waymask, suffix=f"_{step_curv_combinations}",
                newpath=cache_dir)
            copyfile(waymask, waymask_tmp_path, copy=True,
                     use_hardlink=False)
        else:
            waymask_tmp_path = None

        tissue_img = nib.load(tissues4d)

        # Order:
        B0_mask = index_img(tissue_img, 0)
        atlas_img = index_img(tissue_img, 1)
        atlas_data_wm_gm_int = index_img(tissue_img, 2)
        t1w2dwi = index_img(tissue_img, 3)
        gm_in_dwi = index_img(tissue_img, 4)
        vent_csf_in_dwi = index_img(tissue_img, 5)
        wm_in_dwi = index_img(tissue_img, 6)

        tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                       wm_in_dwi, tiss_class, B0_mask)

        B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
        atlas_data = np.array(atlas_img.dataobj).astype("uint16")
        atlas_data_wm_gm_int_data = np.asarray(
            atlas_data_wm_gm_int.dataobj).astype("bool").astype("int16")

        # Build mask vector from atlas for later roi filtering: one boolean
        # mask per non-zero atlas intensity.
        parcels = [atlas_data == roi_val
                   for roi_val in np.unique(atlas_data) if roi_val != 0]
        del atlas_data

        parcel_vec = list(np.ones(len(parcels)).astype("bool"))

        # Read-only access is sufficient; the context manager closes the
        # file.
        with h5py.File(recon_path_tmp_path, 'r') as hf:
            mod_fit = hf['reconstruction'][:].astype('float32')

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

        # Instantiate DirectionGetter
        if directget == "prob" or directget == "probabilistic":
            dg_class = ProbabilisticDirectionGetter
        elif directget == "clos" or directget == "closest":
            dg_class = ClosestPeakDirectionGetter
        elif directget == "det" or directget == "deterministic":
            maxcrossing = 1
            dg_class = DeterministicMaximumDirectionGetter
        else:
            raise ValueError(
                "ERROR: No valid direction getter(s) specified.")
        dg = dg_class.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )

        print("%s%s" % ("Step: ", step_curv_combinations[0]))

        # Perform wm-gm interface seeding, using n_seeds at a time
        seeds = utils.random_seeds_from_mask(
            atlas_data_wm_gm_int_data > 0,
            seeds_count=n_seeds_per_iter,
            seed_count_per_voxel=False,
            affine=np.eye(4),
        )
        if len(seeds) == 0:
            print("No valid seed points found in wm-gm interface...")
            return None

        # Perform tracking
        if track_type == "local":
            streamline_generator = LocalTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                maxlen=int(max_length),
                step_size=float(step_curv_combinations[0]),
                fixedstep=False,
                return_all=True,
            )
        elif track_type == "particle":
            streamline_generator = ParticleFilteringTracking(
                dg,
                tiss_classifier,
                seeds,
                np.eye(4),
                max_cross=int(maxcrossing),
                step_size=float(step_curv_combinations[0]),
                maxlen=int(max_length),
                pft_back_tracking_dist=pft_back_tracking_dist,
                pft_front_tracking_dist=pft_front_tracking_dist,
                particle_count=particle_count,
                return_all=True,
            )
        else:
            # Still terminates via SystemExit, but with a non-zero status
            # and the reason written to stderr instead of a silent exit(0).
            sys.exit("ERROR: No valid tracking method(s) specified.")

        # Filter resulting streamlines by those that stay entirely
        # inside the brain
        try:
            roi_proximal_streamlines = utils.target(
                streamline_generator, np.eye(4), B0_mask_data, include=True)
        except Exception:
            print('No streamlines found inside the brain! '
                  'Check registrations.')
            return None

        # Filter resulting streamlines by roi-intersection
        # characteristics
        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    select_by_rois(
                        roi_proximal_streamlines,
                        affine=np.eye(4),
                        rois=parcels,
                        include=parcel_vec,
                        mode="any" if waymask is not None else "both_end",
                        tol=roi_neighborhood_tol,
                    )
                )
            print("%s%s" % ("Filtering by: \nNode intersection: ",
                            len(roi_proximal_streamlines)))
        except Exception:
            print('No streamlines found to connect any parcels! '
                  'Check registrations.')
            return None

        try:
            roi_proximal_streamlines = \
                nib.streamlines.array_sequence.ArraySequence(
                    [s for s in roi_proximal_streamlines
                     if len(s) >= float(min_length)]
                )
            print(f"Minimum fiber length >{min_length}mm: "
                  f"{len(roi_proximal_streamlines)}")
        except Exception:
            print('No streamlines remaining after minimal length criterion.')
            return None

        if waymask is not None and os.path.isfile(waymask_tmp_path):
            from nilearn.image import math_img
            mask = math_img("img > 0.0075", img=nib.load(waymask_tmp_path))
            waymask_data = np.asarray(mask.dataobj).astype("bool")
            try:
                roi_proximal_streamlines = roi_proximal_streamlines[
                    utils.near_roi(
                        roi_proximal_streamlines,
                        np.eye(4),
                        waymask_data,
                        tol=roi_neighborhood_tol,
                        mode="all")]
                print("%s%s" % ("Waymask proximity: ",
                                len(roi_proximal_streamlines)))
            except Exception:
                print('No streamlines remaining in waymask\'s vacinity.')
                return None

        out_streams = [s.astype("float32") for s in roi_proximal_streamlines]

        # Drop the large intermediates before building the return value.
        del dg, seeds, roi_proximal_streamlines, streamline_generator, \
            atlas_data_wm_gm_int_data, mod_fit, B0_mask_data
        gc.collect()

        try:
            return ArraySequence(out_streams)
        except Exception:
            return None
    finally:
        # Remove the temporary reconstruction copy on every exit path, not
        # only after a successful run.
        if os.path.isfile(recon_path_tmp_path):
            os.remove(recon_path_tmp_path)