def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, directget, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, cache_dir):
    """
    Perform native-space ensemble tractography, restricted to a vector of
    ROI masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to a tractography constraint mask in native diffusion
        space.
    B0_mask : str
        File path to B0 brain mask image.
    t1w2dwi, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi : str
        File paths to tissue maps in native diffusion space used for
        anatomically-constrained tracking.
    tiss_class : str
        Tissue classification method.
    cache_dir : str
        Scratch directory for temporary files and joblib memmapping.

    Returns
    -------
    streamlines : ArraySequence or None
        DiPy list/array-like object of streamline points from tractography,
        or None when tracking fails or produces no streamlines.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F.
      (2016). Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import os
    import gc
    import time
    import warnings
    import itertools
    import nibabel as nib  # BUGFIX: nib.load/nib.save used below but nib
    # was never imported in this function's local-import block.
    from joblib import Parallel, delayed
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    warnings.filterwarnings("ignore")

    tmp_files_dir = f"{cache_dir}/tmp_files"
    joblib_dir = f"{cache_dir}/joblib_tracking"
    os.makedirs(tmp_files_dir, exist_ok=True)
    os.makedirs(joblib_dir, exist_ok=True)

    # Tracking hyperparameters are sourced from the package run config.
    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["nthreads"][0]
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    # Every (step-size, curvature-threshold) pair forms one ensemble member.
    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{tmp_files_dir}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img("img > 0.0075", img=nib.load(waymask))
        # NOTE: intentionally overwrites the input waymask file with its
        # binarized version.
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001",
                         img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001",
                         img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001",
                         img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001",
                         img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    # Stack all tissue maps into a single 4D image shared by the workers.
    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0
    all_streams = []
    ix = 0
    try:
        # Keep sampling until the target count is reached or too many
        # consecutive iterations yield too few streamlines (ix grows by 2
        # per failed iteration, shrinks by 1 per successful one).
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50 * float(len(all_combs)):
            with Parallel(n_jobs=nthreads, backend='loky',
                          mmap_mode='r+', temp_folder=joblib_dir,
                          verbose=0, timeout=timeout) as parallel:
                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_path, n_seeds_per_iter, directget,
                     maxcrossing, max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, waymask, min_length,
                     track_type, min_separation_angle, sphere,
                     tiss_class, tissues4d, tmp_files_dir)
                    for i in all_combs)

                # BUGFIX: dropped the dead `i is not ArraySequence()` test
                # (identity comparison against a fresh instance is always
                # True); `len(i) > 0` already excludes empty sequences.
                out_streams = [
                    i for i in out_streams
                    if i is not None and len(i) > 0
                ]

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)

                if len(out_streams) < min_streams:
                    ix += 2
                    print(f"Fewer than {min_streams} streamlines tracked "
                          f"on last iteration with cache directory: "
                          f"{cache_dir}. Loosening tolerance and "
                          f"anatomical constraints. Check {tissues4d} or "
                          f"{recon_path} for errors...")
                    # Loosen the ROI tolerance before retrying.
                    roi_neighborhood_tol = float(roi_neighborhood_tol) * 1.25
                    continue
                else:
                    ix -= 1

                # Append streamline generators to prevent exponential growth
                # in memory consumption
                all_streams.extend([generate_sl(i) for i in out_streams])
                stream_counter += len(out_streams)
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)

            # Flush per-iteration joblib scratch files.
            os.system(f"rm -rf {joblib_dir}/*")
    except BaseException:
        # Best-effort cleanup; tracking failure is reported via None.
        os.system(f"rm -rf {tmp_files_dir} &")
        return None

    if ix >= 0.75 * len(all_combs) and \
            float(stream_counter) < float(target_samples):
        print(f"Tractography failed. >{len(all_combs)} consecutive sampling "
              f"iterations with few streamlines.")
        os.system(f"rm -rf {tmp_files_dir} &")
        return None
    else:
        os.system(f"rm -rf {tmp_files_dir} &")
        print("Tracking Complete: ", str(time.time() - start))
        # NOTE(review): `parallel` is unbound if the while loop never ran
        # (i.e. target_samples <= 0) — confirm callers never pass that.
        del parallel, all_combs
        gc.collect()

        if stream_counter != 0:
            print('Generating final ArraySequence...')
            return ArraySequence([ArraySequence(i) for i in all_streams])
        else:
            print('No streamlines generated!')
            return None
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, traversal, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, BACKEND='threading'):
    """
    Perform native-space ensemble tractography, restricted to a vector of
    ROI masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model (HDF5 with a
        'reconstruction' dataset).
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    traversal : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to a tractography constraint mask in native diffusion
        space.
    B0_mask : str
        File path to B0 brain mask image.
    t1w2dwi, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi : str
        File paths to tissue maps in native diffusion space used for
        anatomically-constrained tracking.
    tiss_class : str
        Tissue classification method.
    BACKEND : str
        Joblib parallel backend to use. Default is 'threading'.

    Returns
    -------
    streamlines : ArraySequence or None
        DiPy list/array-like object of streamline points from tractography,
        or None on failure/timeout.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F.
      (2016). Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import os
    import gc
    import time  # BUGFIX: was imported twice; duplicate removed.
    import warnings
    import tempfile
    import itertools
    import nibabel as nib  # BUGFIX: nib.load used below but never imported.
    import numpy as np     # BUGFIX: np.eye/np.asarray used but not imported.
    import pickle5 as pickle
    from joblib import Parallel, delayed, Memory
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    from dipy.tracking import utils
    warnings.filterwarnings("ignore")

    pickle.HIGHEST_PROTOCOL = 5
    joblib_dir = tempfile.mkdtemp()
    os.makedirs(joblib_dir, exist_ok=True)

    # Tracking hyperparameters are sourced from the package run config.
    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["omp_threads"][0]
    # Keep BLAS thread pools consistent with the configured thread count.
    os.environ['MKL_NUM_THREADS'] = str(nthreads)
    os.environ['OPENBLAS_NUM_THREADS'] = str(nthreads)
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    seeding_mask_thr = hardcoded_params['tracking']["seeding_mask_thr"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    # Every (step-size, curvature-threshold) pair forms one ensemble member.
    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{os.path.dirname(labels_im_file)}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img(f"img > {seeding_mask_thr}",
                               img=nib.load(waymask))
        # NOTE: intentionally overwrites the input waymask file with its
        # binarized version.
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001",
                         img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001",
                         img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001",
                         img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001",
                         img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    # Stack all tissue maps into a single 4D image shared by the workers.
    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0
    all_streams = []
    ix = 0

    # Shelve the (large) reconstruction and tissue data via joblib.Memory
    # so workers receive lightweight references rather than full copies.
    memory = Memory(location=joblib_dir, mmap_mode='r+', verbose=0)
    # NOTE(review): this changes the process working directory as a side
    # effect — confirm downstream code does not rely on the previous cwd.
    os.chdir(f"{memory.location}/joblib")

    @memory.cache
    def load_recon_data(recon_path):
        # Load the diffusion reconstruction array from HDF5 as float32.
        import h5py
        with h5py.File(recon_path, 'r') as hf:
            recon_data = hf['reconstruction'][:].astype('float32')
        return recon_data

    recon_shelved = load_recon_data.call_and_shelve(recon_path)

    @memory.cache
    def load_tissue_data(tissues4d):
        return nib.load(tissues4d)

    tissue_shelved = load_tissue_data.call_and_shelve(tissues4d)

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50 * float(len(all_combs)):
            with Parallel(n_jobs=nthreads, backend=BACKEND,
                          mmap_mode='r+', verbose=0) as parallel:
                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_shelved, n_seeds_per_iter, traversal,
                     maxcrossing, max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, min_length, track_type,
                     min_separation_angle, sphere, tiss_class,
                     tissue_shelved) for i in all_combs)

                out_streams = list(filter(None, out_streams))

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)
                else:
                    # Not enough results to concatenate; resample.
                    continue

                if waymask is not None and os.path.isfile(waymask):
                    try:
                        # Retain only streamlines passing near the waymask,
                        # at half the ROI neighborhood tolerance.
                        out_streams = out_streams[utils.near_roi(
                            out_streams,
                            np.eye(4),
                            np.asarray(
                                nib.load(waymask).dataobj).astype("bool"),
                            tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                            mode="all")]
                    except BaseException:
                        # BUGFIX: corrected "vacinity" typo in message.
                        print(f"\n{Fore.RED}No streamlines generated in "
                              f"waymask vicinity\n")
                        print(Style.RESET_ALL)
                        return None

                if len(out_streams) < min_streams:
                    ix += 1
                    print(f"\n{Fore.YELLOW}Fewer than {min_streams} "
                          f"streamlines tracked "
                          f"on last iteration...\n")
                    print(Style.RESET_ALL)
                    if ix > 5:
                        # Too many consecutive sparse iterations: give up.
                        print(f"\n{Fore.RED}No streamlines generated\n")
                        print(Style.RESET_ALL)
                        return None
                    continue
                else:
                    ix -= 1

                stream_counter += len(out_streams)
                # Append streamline generators to prevent exponential
                # growth in memory consumption.
                all_streams.extend([generate_sl(i) for i in out_streams])
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)

            if time.time() - start > timeout:
                print(f"\n{Fore.RED}Warning: Tractography timed "
                      f"out: {time.time() - start}")
                print(Style.RESET_ALL)
                memory.clear(warn=False)
                return None

    except RuntimeError as e:
        print(f"\n{Fore.RED}Error: Tracking failed due to:\n{e}\n")
        print(Style.RESET_ALL)
        memory.clear(warn=False)
        return None

    print("Tracking Complete: ", str(time.time() - start))
    memory.clear(warn=False)
    # NOTE(review): `parallel` is unbound if the while loop never ran
    # (i.e. target_samples <= 0) — confirm callers never pass that.
    del parallel, all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print(f"\n{Fore.RED}No streamlines generated!")
        print(Style.RESET_ALL)
        return None
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, directget, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, cache_dir):
    """
    Perform native-space ensemble tractography, restricted to a vector of
    ROI masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from
        Nifti1Image in T1w-warped native diffusion space, restricted to
        wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number if diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to a tractography constraint mask in native diffusion
        space.
    B0_mask : str
        File path to B0 brain mask image.
    t1w2dwi, gm_in_dwi, vent_csf_in_dwi, wm_in_dwi : str
        File paths to tissue maps in native diffusion space used for
        anatomically-constrained tracking.
    tiss_class : str
        Tissue classification method.
    cache_dir : str
        Scratch directory used for joblib memmapping.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography
        (empty when tracking fails or produces no streamlines).

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F.
      (2016). Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import os
    import gc
    import time
    import pkg_resources
    import yaml
    import shutil
    import itertools
    from joblib import Parallel, delayed
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.dmri_utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d

    cache_dir = f"{cache_dir}/joblib_tracking"
    os.makedirs(cache_dir, exist_ok=True)

    # Tracking hyperparameters are sourced from the package run config.
    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              "r") as stream:
        # BUGFIX: yaml.load() without an explicit Loader is deprecated
        # (and unsafe with the legacy default); FullLoader preserves the
        # previous parsing behavior for trusted package config.
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        nthreads = hardcoded_params["nthreads"][0]
        n_seeds_per_iter = \
            hardcoded_params['tracking']["n_seeds_per_iter"][0]
        max_length = \
            hardcoded_params['tracking']["max_length"][0]
        pft_back_tracking_dist = \
            hardcoded_params['tracking']["pft_back_tracking_dist"][0]
        pft_front_tracking_dist = \
            hardcoded_params['tracking']["pft_front_tracking_dist"][0]
        particle_count = \
            hardcoded_params['tracking']["particle_count"][0]
        min_separation_angle = \
            hardcoded_params['tracking']["min_separation_angle"][0]
    # BUGFIX: removed redundant stream.close(); the `with` block already
    # closes the file.

    # Every (step-size, curvature-threshold) pair forms one ensemble member.
    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Stack all tissue maps into a single 4D image shared by the workers.
    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, atlas_data_wm_gm_int, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0
    all_streams = []
    ix = 0
    while float(stream_counter) < float(target_samples) and \
            float(ix) < 0.75 * float(len(all_combs)):
        with Parallel(n_jobs=nthreads, backend='loky', mmap_mode='r+',
                      temp_folder=cache_dir, verbose=10) as parallel:
            out_streams = parallel(
                delayed(run_tracking)
                (i, recon_path, n_seeds_per_iter, directget, maxcrossing,
                 max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count,
                 roi_neighborhood_tol, waymask, min_length, track_type,
                 min_separation_angle, sphere, tiss_class, tissues4d,
                 cache_dir) for i in all_combs)

            # BUGFIX: dropped the dead `i is not ArraySequence()` test
            # (identity comparison against a fresh instance is always True);
            # `len(i) > 0` already excludes empty sequences.
            out_streams = [
                i for i in out_streams
                if i is not None and len(i) > 0
            ]

            if len(out_streams) > 1:
                out_streams = concatenate(out_streams, axis=0)

            if len(out_streams) < 50:
                ix += 1
                # BUGFIX: message previously said "Fewer than 100" while
                # the threshold tested above is 50.
                print("Fewer than 50 streamlines tracked on last iteration."
                      " loosening tolerance and anatomical constraints...")
                if track_type != 'particle':
                    tiss_class = 'wb'
                # Loosen constraints slightly before retrying.
                roi_neighborhood_tol = float(roi_neighborhood_tol) * 1.05
                min_length = float(min_length) * 0.95
                continue
            else:
                ix -= 1

            # Append streamline generators to prevent exponential growth
            # in memory consumption
            all_streams.extend([generate_sl(i) for i in out_streams])
            stream_counter += len(out_streams)
            del out_streams

            print("%s%s%s%s" % (
                "\nCumulative Streamline Count: ",
                Fore.CYAN,
                stream_counter,
                "\n",
            ))
            gc.collect()
            print(Style.RESET_ALL)

    if ix >= 0.75 * len(all_combs) and \
            float(stream_counter) < float(target_samples):
        print(f"Tractography failed. >{len(all_combs)} consecutive sampling "
              f"iterations with <50 streamlines. Are you using a waymask? "
              f"If so, it may be too restrictive.")
        return ArraySequence()
    else:
        print("Tracking Complete: ", str(time.time() - start))
        # NOTE(review): `parallel` is unbound if the while loop never ran
        # (i.e. target_samples <= 0) — confirm callers never pass that.
        del parallel, all_combs

        shutil.rmtree(cache_dir, ignore_errors=True)

        if stream_counter != 0:
            print('Generating final ArraySequence...')
            return ArraySequence([ArraySequence(i) for i in all_streams])
        else:
            print('No streamlines generated!')
            return ArraySequence()