def bench_load_trk():
    """Benchmark loading a TRK file with the old ``tv.read`` reader against
    ``nib.streamlines.load``, first with points only, then with per-point
    scalars, and check both readers return the same data.

    Note: ``measure`` evaluates its code string in this function's frame,
    so the local name ``trk_file`` (and the globals ``tv``/``nib``) must
    keep these exact names.
    """
    rng = np.random.RandomState(42)
    dtype = 'float32'
    NB_STREAMLINES = 5000
    NB_POINTS = 1000
    points = [rng.rand(NB_POINTS, 3).astype(dtype)
              for _ in range(NB_STREAMLINES)]
    scalars = [rng.rand(NB_POINTS, 10).astype(dtype)
               for _ in range(NB_STREAMLINES)]

    repeat = 10

    # Points only
    with InTemporaryDirectory():
        trk_file = "tmp.trk"
        tractogram = Tractogram(points, affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        # The old reader's points are offset by half a voxel relative to
        # the new one, hence the -0.5.
        streamlines_old = [d[0] - 0.5
                           for d in tv.read(trk_file, points_space="rasmm")[0]]
        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        print("Old: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES,
                                                              mtime_old))

        trk = nib.streamlines.load(trk_file, lazy_load=False)
        streamlines_new = trk.streamlines
        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        print("\nNew: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES,
                                                                mtime_new))
        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
        for s1, s2 in zip(streamlines_new, streamlines_old):
            assert_array_equal(s1, s2)

    # Points and scalars
    with InTemporaryDirectory():
        trk_file = "tmp.trk"
        tractogram = Tractogram(points,
                                data_per_point={'scalars': scalars},
                                affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        # Read once: only the scalars are compared in this section.
        old_data = tv.read(trk_file, points_space="rasmm")[0]
        scalars_old = [d[1] for d in old_data]
        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        msg = "Old: Loaded {:,} streamlines with scalars in {:6.2f}"
        print(msg.format(NB_STREAMLINES, mtime_old))

        trk = nib.streamlines.load(trk_file, lazy_load=False)
        scalars_new = trk.tractogram.data_per_point['scalars']
        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        msg = "New: Loaded {:,} streamlines with scalars in {:6.2f}"
        print(msg.format(NB_STREAMLINES, mtime_new))
        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
        for s1, s2 in zip(scalars_new, scalars_old):
            assert_array_equal(s1, s2)
def _core_run(self, stopping_path, stopping_thr, seeding_path,
              seed_density, step_size, direction_getter, out_tract,
              save_seeds):
    """Run local tracking and save the resulting tractogram.

    Parameters
    ----------
    stopping_path : str
        Image thresholded by ``ThresholdTissueClassifier``; its affine is
        also used for seeding and tracking.
    stopping_thr : float
        Threshold given to ``ThresholdTissueClassifier``.
    seeding_path : str
        Seed mask image.
    seed_density : int
        Number of seeds per dimension inside each voxel.
    step_size : float
        Tracking step size passed to ``LocalTracking``.
    direction_getter : object
        Direction getter passed to ``LocalTracking``.
    out_tract : str
        Output tractogram path.
    save_seeds : bool
        If True, each streamline's seed is stored in
        ``data_per_streamline['seeds']``.
    """
    stop, affine = load_nifti(stopping_path)
    classifier = ThresholdTissueClassifier(stop, stopping_thr)
    logging.info('classifier done')
    seed_mask, _ = load_nifti(seeding_path)
    seeds = utils.seeds_from_mask(
        seed_mask,
        density=[seed_density, seed_density, seed_density],
        affine=affine)
    logging.info('seeds done')

    tracking_result = LocalTracking(direction_getter, classifier,
                                    seeds, affine, step_size=step_size,
                                    save_seeds=save_seeds)

    logging.info('LocalTracking initiated')

    if save_seeds:
        # Items are (streamline, seed) pairs. Unpack with a loop instead of
        # zip(*...), which raises ValueError when no streamline is produced.
        streamlines, streamline_seeds = [], []
        for streamline, seed in tracking_result:
            streamlines.append(streamline)
            streamline_seeds.append(seed)
        tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
        tractogram.data_per_streamline['seeds'] = streamline_seeds
    else:
        tractogram = Tractogram(tracking_result, affine_to_rasmm=np.eye(4))

    save(tractogram, out_tract)

    logging.info('Saved {0}'.format(out_tract))
def harvest(
    self,
    states: np.ndarray,
    compress=False,
) -> Tuple[Tractogram, np.ndarray, np.ndarray]:
    """Return the streamlines that stopped and prune the environment.

    Internally keep only the streamlines and corresponding env. states that
    haven't stopped yet, and return the streamlines that triggered a
    stopping flag.

    Parameters
    ----------
    states: np.ndarray
        Environment states to be "pruned". NOTE(review): confirm whether
        callers pass a numpy array or a torch tensor here.
    compress: bool, optional
        If True, pass the harvested streamlines through
        ``compress_streamlines`` with a tolerance argument of 0.1.

    Returns
    -------
    tractogram : nib.streamlines.Tractogram
        Tractogram containing the streamlines that stopped tracking, along
        with the stopping_flags information and seeds in
        `tractogram.data_per_streamline`
    states: np.ndarray of size [n_streamlines, input_size]
        States of the continuing streamlines, as returned by ``self._keep``.
    continue_idx: np.ndarray
        Indexes of the trajectories that continue (``self.continue_idx``).
    """
    # Harvest stopped streamlines and associated data
    stopped_seeds = self.starting_points[self.stopping_idx]
    stopped_streamlines = self.streamlines[self.stopping_idx, :self.length]

    # Drop last point if it triggered a flag we don't want
    flags = is_flag_set(
        self.stopping_flags, StoppingFlags.STOPPING_CURVATURE)

    streamlines = [
        s[:-1] if f else s for f, s in zip(flags, stopped_streamlines)
    ]

    if compress:
        streamlines = compress_streamlines(streamlines, 0.1)

    # Harvested tractogram
    tractogram = Tractogram(
        streamlines=streamlines,
        data_per_streamline={
            "stopping_flags": self.stopping_flags,
            "seeds": stopped_seeds,
        },
        affine_to_rasmm=self.affine_vox2rasmm)

    # Keep only streamlines that should continue
    states = self._keep(self.continue_idx, states)

    return tractogram, states, self.continue_idx
def main():
    """Evaluate a model from an experiment folder on one subject and save
    its streamlines with a per-streamline ``loss`` entry."""
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    # ``args.name`` is either the experiment folder itself, or the name of
    # an experiment found under ./experiments/.
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        experiment_path = pjoin(".", "experiments", args.name)
        if not os.path.isdir(experiment_path):
            parser.error("Cannot find experiment: {0}!".format(args.name))

    # Hyperparameters sit in the experiment folder, or one level above it.
    local_hyperparams = pjoin(experiment_path, "hyperparams.json")
    try:
        hyperparams = smartutils.load_dict_from_json_file(local_hyperparams)
    except FileNotFoundError:
        parent_hyperparams = pjoin(experiment_path, "..", "hyperparams.json")
        hyperparams = smartutils.load_dict_from_json_file(parent_hyperparams)

    with Timer("Loading dataset", newline=True):
        volume_manager = VolumeManager()
        dataset = datasets.load_tractography_dataset(
            [args.subject],
            volume_manager,
            name="dataset",
            use_sh_coeffs=hyperparams["use_sh_coeffs"])
        print("Dataset size:", len(dataset))

    with Timer("Loading model"):
        model_name = hyperparams["model"]
        if model_name != "gru_regression":
            raise NameError("Unknown model: {}".format(model_name))
        from learn2track.models import GRU_Regression
        model = GRU_Regression.create(experiment_path,
                                      volume_manager=volume_manager)

    with Timer("Building evaluation function"):
        loss = loss_factory(hyperparams, model, dataset)
        scheduler_kwargs = dict(
            batch_size=1000,
            noisy_streamlines_sigma=None,
            # Otherwise it doubles the number of losses :-/
            use_data_augment=False,
            seed=1234,
            shuffle_streamlines=False,
            normalize_target=hyperparams["normalize"],
        )
        batch_scheduler = batch_schedulers.TractographyBatchScheduler(
            dataset, **scheduler_kwargs)
        loss_view = views.LossView(loss=loss, batch_scheduler=batch_scheduler)
        losses = loss_view.losses.view()

    with Timer("Saving streamlines"):
        signal_affine = dataset.subjects[0].signal.affine
        tractogram = Tractogram(dataset.streamlines,
                                affine_to_rasmm=signal_affine)
        tractogram.data_per_streamline["loss"] = losses
        nib.streamlines.save(tractogram, args.out)
def _core_run(self, stopping_path, stopping_thr, seeding_path,
              seed_density, use_sh, pam, out_tract, step_size=.5,
              max_angle=30.):
    """Run local tracking and save the resulting tractogram.

    Parameters
    ----------
    stopping_path : str
        Image thresholded by ``ThresholdTissueClassifier``; its affine is
        also used for seeding and tracking.
    stopping_thr : float
        Threshold given to ``ThresholdTissueClassifier``.
    seeding_path : str
        Seed mask image.
    seed_density : int
        Number of seeds per dimension inside each voxel.
    use_sh : bool
        If True, build a ``DeterministicMaximumDirectionGetter`` from
        ``pam.shm_coeff``; otherwise ``pam`` is used directly as the
        direction getter.
    pam : object
        Object exposing ``shm_coeff`` and ``sphere`` (presumably a
        peaks-and-metrics object), or a direction getter when ``use_sh``
        is False.
    out_tract : str
        Output tractogram path.
    step_size : float, optional
        Tracking step size passed to ``LocalTracking`` (default .5).
    max_angle : float, optional
        ``max_angle`` of the SH direction getter; only used when ``use_sh``
        is True (default 30.).
    """
    stop, affine = load_nifti(stopping_path)
    classifier = ThresholdTissueClassifier(stop, stopping_thr)
    logging.info('classifier done')
    seed_mask, _ = load_nifti(seeding_path)
    seeds = utils.seeds_from_mask(
        seed_mask,
        density=[seed_density, seed_density, seed_density],
        affine=affine)
    logging.info('seeds done')

    direction_getter = pam

    if use_sh:
        direction_getter = \
            DeterministicMaximumDirectionGetter.from_shcoeff(
                pam.shm_coeff,
                max_angle=max_angle,
                sphere=pam.sphere)

    streamlines = LocalTracking(direction_getter, classifier,
                                seeds, affine, step_size=step_size)
    logging.info('LocalTracking initiated')

    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
    save(tractogram, out_tract)

    logging.info('Saved {0}'.format(out_tract))
def _core_run(self, stopping_path, stopping_thr, seeding_path,
              seed_density, step_size, direction_getter, out_tract):
    """Track with ``LocalTracking`` and write the streamlines to disk.

    The stopping image is thresholded at ``stopping_thr`` and its affine is
    reused for seeding and tracking. Seeds are placed inside the mask
    from ``seeding_path`` with ``seed_density`` seeds per dimension per
    voxel. The tractogram is saved to ``out_tract``.
    """
    stopping_data, affine = load_nifti(stopping_path)
    classifier = ThresholdTissueClassifier(stopping_data, stopping_thr)
    logging.info('classifier done')

    seed_mask, _ = load_nifti(seeding_path)
    density = [seed_density] * 3
    seeds = utils.seeds_from_mask(seed_mask, density=density, affine=affine)
    logging.info('seeds done')

    generator = LocalTracking(direction_getter, classifier, seeds, affine,
                              step_size=step_size)
    logging.info('LocalTracking initiated')

    save(Tractogram(generator, affine_to_rasmm=np.eye(4)), out_tract)
    logging.info('Saved {0}'.format(out_tract))
def get_seeds_from_wm(wm_path, threshold=0):
    """Save the selected voxels of ``wm_path`` as one-point streamlines.

    Every voxel whose value is strictly greater than ``threshold`` becomes
    a single-point streamline placed at the voxel's coordinate mapped
    through the image affine. The result is written to
    ``seeds_from_wm.trk`` in the same directory as ``wm_path``.

    Parameters
    ----------
    wm_path : str
        Path to an image loadable with ``nib.load``.
    threshold : float, optional
        Voxels with a value strictly above this are kept (default 0).
    """
    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    # Indices of the selected voxels, in homogeneous coordinates.
    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])
    # Map through the affine, drop the homogeneous column and shape as
    # (N, 1, 3): N streamlines of one point each.
    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    header = TrkFile.create_empty_header()
    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    # Points are already in world space, hence the identity affine.
    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")
    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
def main():
    """Keep a subset of the input streamlines (with their per-point and
    per-streamline data) and save it with the input header."""
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exists(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)
    source = tractogram_file.tractogram

    subset_streamlines, subset_dpp, subset_dps = get_subset_streamlines(
        list(tractogram_file.streamlines),
        source.data_per_point,
        source.data_per_streamline,
        args.max_num_streamlines,
        args.seed)

    subset = Tractogram(subset_streamlines,
                        data_per_point=subset_dpp,
                        data_per_streamline=subset_dps,
                        affine_to_rasmm=np.eye(4))
    save(subset, args.out_tractogram, header=tractogram_file.header)
def harvest(self):
    """Collect the sprouts that stopped into a Tractogram and keep only the
    sprouts that are still growing.

    Returns
    -------
    Tractogram
        Stopped sprouts (last point removed), with their stopping flags in
        ``data_per_streamline['stopping_flags']``.
    """
    undone, done, stopping_flags = self.is_stopping(
        self.sprouts, self.sprouts_stop)

    # The final point almost surely raised the stopping flag: trim it.
    finished = list(self.sprouts[done, :-1])
    if self.compress_streamlines:
        finished = compress_streamlines(finished)

    harvested = Tractogram(
        streamlines=finished,
        data_per_streamline={"stopping_flags": stopping_flags})

    # Only sprouts that have not stopped keep growing.
    self._keep(undone)
    return harvested
def save_from_voxel_space(streamlines, anat, ref_tracts, out_name):
    """Save voxel-space streamlines using the header of a reference file.

    Each streamline is multiplied by the voxel sizes of ``anat`` and saved
    with the affine derived from the reference tractogram's header.

    Parameters
    ----------
    streamlines : iterable of (N, 3) arrays
        Streamlines in voxel coordinates. They are NOT modified.
    anat : str or nibabel image
        Anatomy whose ``pixdim`` provides the voxel sizes.
    ref_tracts : str or loaded nibabel streamlines file
        Reference whose header is used for the affine and the output.
    out_name : str
        Output filename.
    """
    if isinstance(ref_tracts, six.string_types):
        nib_object = nib.streamlines.load(ref_tracts, lazy_load=True)
    else:
        nib_object = ref_tracts

    if isinstance(anat, six.string_types):
        anat = nib.load(anat)

    affine_to_rasmm = get_affine_trackvis_to_rasmm(nib_object.header)

    spacing = anat.header['pixdim'][1:4]
    # Scale fresh copies instead of the Tractogram's private buffer: when
    # an ArraySequence is passed, the Tractogram shares its data, so an
    # in-place update would silently modify the caller's streamlines. It
    # also keeps empty inputs working.
    scaled_streamlines = [s * spacing for s in streamlines]

    tracto = Tractogram(streamlines=scaled_streamlines,
                        affine_to_rasmm=affine_to_rasmm)

    nib.streamlines.save(tracto, out_name, header=nib_object.header)
def main():
    """Resample every streamline of the input tractogram and save the
    result, keeping the input header and per-streamline data."""
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exist(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)
    source = tractogram_file.tractogram

    resampled = resample_streamlines(list(tractogram_file.streamlines),
                                     args.nb_pts_per_streamline,
                                     args.arclength)

    # Only per-streamline data is carried over; per-point data is not
    # passed (presumably because resampling changes the point count).
    output = Tractogram(resampled,
                        data_per_streamline=source.data_per_streamline,
                        affine_to_rasmm=np.eye(4))
    save(output, args.out_tractogram, header=tractogram_file.header)
def main():
    """Keep streamlines whose length lies within the requested bounds and
    save them (with their data) using the input header."""
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)
    source = tractogram_file.tractogram

    kept_streamlines, kept_dpp, kept_dps = filter_streamlines_by_length(
        list(tractogram_file.streamlines),
        source.data_per_point,
        source.data_per_streamline,
        args.minL,
        args.maxL)

    filtered = Tractogram(kept_streamlines,
                          data_per_streamline=kept_dps,
                          data_per_point=kept_dpp,
                          affine_to_rasmm=np.eye(4))
    save(filtered, args.out_tractogram, header=tractogram_file.header)
def render(
    self,
    tractogram: Tractogram = None,
    filename: str = None
):
    """
    Render the streamlines, either directly or through a file.
    Might render from "outside" the environment, like for comet.

    Parameters
    ----------
    tractogram: Tractogram, optional
        Object containing the streamlines and, in
        ``data_per_streamline['seeds']``, their seeds. If None, it is built
        from the environment's current streamlines and starting points.
    filename: str, optional
        If set, save the image under this name instead of displaying
        it directly.
    """
    from fury import window, actor

    # Might be rendering from outside the environment
    if tractogram is None:
        tractogram = Tractogram(
            streamlines=self.streamlines[:, :self.length],
            data_per_streamline={
                'seeds': self.starting_points
            })

    # Reshape peaks for displaying: the last axis (size M) is split into
    # 5 peaks of M // 5 components each.
    nb_peaks = 5
    X, Y, Z, M = self.peaks.data.shape
    peaks = np.reshape(self.peaks.data, (X, Y, Z, nb_peaks, M // nb_peaks))

    # Setup scene and actors
    scene = window.Scene()

    stream_actor = actor.streamtube(tractogram.streamlines)
    # fury's peak_slicer expects one value per peak, shape
    # (X, Y, Z, npeaks), matching the peak axis of ``peaks``.
    peak_actor = actor.peak_slicer(peaks,
                                   np.ones((X, Y, Z, nb_peaks)),
                                   colors=(0.2, 0.2, 1.),
                                   opacity=0.5)
    dot_actor = actor.dots(tractogram.data_per_streamline['seeds'],
                           color=(1, 1, 1),
                           opacity=1,
                           dot_size=2.5)
    scene.add(stream_actor)
    scene.add(peak_actor)
    scene.add(dot_actor)
    scene.reset_camera_tight(0.95)

    # Save or display scene
    if filename is not None:
        # NOTE(review): dirname(pjoin(experiment_path, 'render')) is
        # experiment_path itself (modulo a trailing separator), so images
        # land directly in experiment_path, not in a 'render' subfolder.
        # Kept as is; confirm which location is intended.
        directory = os.path.dirname(pjoin(self.experiment_path, 'render'))
        if not os.path.exists(directory):
            os.makedirs(directory)
        dest = pjoin(directory, filename)
        window.snapshot(
            scene,
            fname=dest,
            offscreen=True,
            size=(800, 800))
    else:
        showm = window.ShowManager(scene, reset_camera=True)
        showm.initialize()
        showm.start()
def run(self, pam_files, wm_files, gm_files, csf_files, seeding_files,
        step_size=0.2,
        seed_density=1,
        pmf_threshold=0.1,
        max_angle=20.,
        pft_back=2,
        pft_front=1,
        pft_count=15,
        out_dir='',
        out_tractogram='tractogram.trk',
        save_seeds=False):
    """Workflow for Particle Filtering Tracking.

    This workflow uses a saved peaks and metrics (PAM) file as input.

    Parameters
    ----------
    pam_files : string
        Path to the peaks and metrics files. This path may contain
        wildcards to use multiple masks at once.
    wm_files : string
        Path to white matter partial volume estimate for tracking (CMC).
    gm_files : string
        Path to grey matter partial volume estimate for tracking (CMC).
    csf_files : string
        Path to cerebrospinal fluid partial volume estimate for tracking
        (CMC).
    seeding_files : string
        A binary image showing where we need to seed for tracking.
    step_size : float, optional
        Step size used for tracking (default 0.2mm).
    seed_density : int, optional
        Number of seeds per dimension inside voxel (default 1).
        For example, seed_density of 2 means 8 regularly distributed
        points in the voxel. And seed density of 1 means 1 point at the
        center of the voxel.
    pmf_threshold : float, optional
        Threshold for ODF functions (default 0.1).
    max_angle : float, optional
        Maximum angle between streamline segments (range [0, 90],
        default 20).
    pft_back : float, optional
        Distance in mm to back track before starting the particle
        filtering tractography (default 2mm). The total particle filtering
        tractography distance is equal to back_tracking_dist +
        front_tracking_dist.
    pft_front : float, optional
        Distance in mm to run the particle filtering tractography after
        the back track distance (default 1mm). The total particle
        filtering tractography distance is equal to back_tracking_dist +
        front_tracking_dist.
    pft_count : int, optional
        Number of particles to use in the particle filter (default 15).
    out_dir : string, optional
        Output directory (default input file directory)
    out_tractogram : string, optional
        Name of the tractogram file to be saved (default 'tractogram.trk')
    save_seeds : bool, optional
        If true, save the seeds associated to their streamline
        in the 'data_per_streamline' Tractogram dictionary using
        'seeds' as the key

    References
    ----------
    Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M. Towards
    quantitative connectivity analysis: reducing tractography biases.
    NeuroImage, 98, 266-278, 2014.

    """
    io_it = self.get_io_iterator()

    for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract \
            in io_it:

        logging.info(
            'Particle Filtering tracking on {0}'.format(pams_path))

        pam = load_peaks(pams_path, verbose=False)

        wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
        gm, _ = load_nifti(gm_path)
        csf, _ = load_nifti(csf_path)
        avs = sum(voxel_size) / len(voxel_size)  # average_voxel_size

        classifier = CmcTissueClassifier.from_pve(
            wm, gm, csf, step_size=step_size, average_voxel_size=avs)
        logging.info('classifier done')
        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(
            seed_mask,
            density=[seed_density, seed_density, seed_density],
            affine=affine)
        logging.info('seeds done')

        dg = ProbabilisticDirectionGetter
        direction_getter = dg.from_shcoeff(pam.shm_coeff,
                                           max_angle=max_angle,
                                           sphere=pam.sphere,
                                           pmf_threshold=pmf_threshold)

        tracking_result = ParticleFilteringTracking(
            direction_getter,
            classifier,
            seeds,
            affine,
            step_size=step_size,
            pft_back_tracking_dist=pft_back,
            pft_front_tracking_dist=pft_front,
            pft_max_trial=20,
            particle_count=pft_count,
            save_seeds=save_seeds)

        logging.info('ParticleFilteringTracking initiated')

        if save_seeds:
            # Items are (streamline, seed) pairs. Unpack with a loop
            # instead of zip(*...), which raises ValueError when tracking
            # produces no streamline.
            streamlines, streamline_seeds = [], []
            for streamline, seed in tracking_result:
                streamlines.append(streamline)
                streamline_seeds.append(seed)
            tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
            tractogram.data_per_streamline['seeds'] = streamline_seeds
        else:
            tractogram = Tractogram(tracking_result,
                                    affine_to_rasmm=np.eye(4))

        save(tractogram, out_tract)
        logging.info('Saved {0}'.format(out_tract))
from nibabel.orientations import aff2axcodes
import numpy as np

# allow_pickle=True is required when the .npy file stores an object array
# (e.g. streamlines of different lengths); recent numpy versions refuse to
# load those otherwise. It has no effect on plain numeric arrays.
# Unpickling can execute arbitrary code: only use it on trusted files such
# as this local one.
streamlines = np.load(
    '/media/localadmin/HagmannHDD/Seb/testPFT/diffusion_preproc_resampled_streamlines.npy',
    allow_pickle=True,
)
imref = nb.load('/media/localadmin/HagmannHDD/Seb/testPFT/shore_gfa.nii.gz')
affine = imref.affine.copy()
print(imref.affine.copy())
print(affine)

header = {}
header[Field.ORIGIN] = affine[:3, 3]
header[Field.VOXEL_TO_RASMM] = affine
header[Field.VOXEL_SIZES] = imref.header.get_zooms()[:3]
header[Field.DIMENSIONS] = imref.shape[:3]
header[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine))

# Shift every point by the reference image's translation. The offset is
# loop-invariant, so compute it once and subtract it from each whole
# streamline in place rather than point by point.
origin = imref.affine[:3, 3].copy()
for streamline in streamlines:
    streamline -= origin

print(header[Field.VOXEL_ORDER])

tractogram = Tractogram(streamlines=streamlines, affine_to_rasmm=affine)
out_fname = '/media/localadmin/HagmannHDD/Seb/testPFT/track_nib1.trk'
nb.streamlines.save(tractogram, out_fname, header=header)
def main():
    """Apply a set operation (from ``OPERATIONS``) to the streamlines of
    several input files and save the result, optionally with the selected
    indices in the metadata and/or in a JSON file."""
    parser = build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if os.path.isfile(args.output):
        if args.force:
            logging.info('Overwriting {0}.'.format(args.output))
        else:
            parser.error('{0} already exist! Use -f to overwrite it.'.format(
                args.output))

    # Load all input streamlines.
    data = [load_data(f) for f in args.inputs]
    streamlines, data_per_streamline, data_per_point = zip(*data)
    nb_streamlines = [len(s) for s in streamlines]

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    new_streamlines, indices = perform_streamlines_operation(
        OPERATIONS[args.operation], streamlines, args.precision)

    # Get the meta data of the streamlines.
    new_data_per_streamline = {}
    new_data_per_point = {}
    if not args.no_data:
        for key in data_per_streamline[0].keys():
            all_data = np.vstack([s[key] for s in data_per_streamline])
            new_data_per_streamline[key] = all_data[indices, :]

        # Add the indices to the metadata if requested.
        if args.save_meta_indices:
            new_data_per_streamline['ids'] = indices

        for key in data_per_point[0].keys():
            all_data = list(chain(*[s[key] for s in data_per_point]))
            new_data_per_point[key] = [all_data[i] for i in indices]

    # Save the indices to a file if requested.
    if args.save_indices is not None:
        start = 0
        indices_dict = {'filenames': args.inputs}
        for name, nb in zip(args.inputs, nb_streamlines):
            end = start + nb
            # Indices are global across inputs: keep those belonging to this
            # file, made relative to it. Cast to int because numpy integer
            # types are not JSON serializable.
            indices_dict[name] = [int(i - start) for i in indices
                                  if start <= i < end]
            start = end
        with open(args.save_indices, 'wt') as f:
            json.dump(indices_dict, f)

    # Save the new streamlines.
    logging.info('Saving streamlines to {0}.'.format(args.output))
    reference_file = load(args.inputs[0], True)
    new_tractogram = Tractogram(
        new_streamlines, data_per_streamline=new_data_per_streamline,
        data_per_point=new_data_per_point)

    # If the reference is a .tck, the affine will be None.
    affine = reference_file.tractogram.affine_to_rasmm
    if affine is None:
        affine = np.eye(4)
    new_tractogram.affine_to_rasmm = affine

    new_header = reference_file.header.copy()
    new_header['nb_streamlines'] = len(new_streamlines)
    save(new_tractogram, args.output, header=new_header)
def main():
    """Run COMMIT on a tractogram (.trk or .h5) and save its outputs.

    The COMMIT results (including ``commit_weights.txt``) are moved to
    ``args.out_dir``. For .h5 inputs, a ``decompose_commit.h5`` carrying the
    weights is also written. When requested, essential/nonessential
    tractograms and/or the whole tractogram with weights are saved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        # Kept open until the end: its groups are read again when writing
        # the decomposed output.
        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to
        # the tractogram as well as the number of streamlines for each
        # connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)
        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi,
                                 Space.VOX, origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                               'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            # Kernels were saved to args.save_kernels; release resources
            # before leaving early.
            if ext == '.h5':
                hdf5_file.close()
            tmp_dir.cleanup()
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    # This pickle is written by COMMIT itself in our temporary directory.
    with open(os.path.join(commit_results_dir, 'results.pickle'),
              'rb') as pk_file:
        commit_output_dict = pickle.load(pk_file)

    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the
            # ordering of connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = \
                    commit_weights[offsets_list[i]:offsets_list[i+1]]
                if args.threshold_weights is not None:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(
                        old_group['data'],
                        old_group['offsets'],
                        old_group['lengths'],
                        indices=essential_ind)

                    # Replacing the data with the one above the threshold
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)
                else:
                    # No threshold: keep every streamline of the connection.
                    # Ellipsis selects everything regardless of rank.
                    essential_ind = Ellipsis
                    for name in ['data', 'offsets', 'lengths']:
                        new_group.create_dataset(name,
                                                 data=old_group[name][()])

                # Copy the other per-streamline datasets under their own
                # name, restricted to the streamlines that were kept.
                for dps_key in old_group.keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        dps_values = old_group[dps_key][()]
                        new_group.create_dataset(
                            dps_key, data=dps_values[essential_ind])
                new_group.create_dataset(
                    'commit_weights',
                    data=tmp_commit_weights[essential_ind])
        hdf5_file.close()

    for fname in os.listdir(commit_results_dir):
        shutil.move(os.path.join(commit_results_dir, fname), args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(
                essential_streamlines,
                data_per_point=essential_dpp,
                data_per_streamline=essential_dps,
                affine_to_rasmm=np.eye(4))

            nonessential_streamlines = \
                tractogram.streamlines[nonessential_ind]
            nonessential_dps = \
                tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(
                nonessential_streamlines,
                data_per_point=nonessential_dpp,
                data_per_streamline=nonessential_dps,
                affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()
def auto_extract_VCs(streamlines, ref_bundles, chunk_size=5000):
    """Identify valid connections (VCs) by matching streamlines to reference bundles.

    The submission is processed in chunks of ``chunk_size`` streamlines to
    keep memory bounded on big datasets. Each chunk is resampled to
    ``NB_POINTS_RESAMPLE`` points, clustered once with QuickBundles, and then
    matched against every reference bundle with ``auto_extract``. Within a
    chunk a streamline is assigned to at most one bundle: the first bundle,
    in ``ref_bundles`` order, that selects it. Finally, coverage scores are
    computed for each bundle against its mask.

    Parameters
    ----------
    streamlines : sequence of ndarray
        All submitted streamlines; must support slicing and integer indexing.
    ref_bundles : list of dict
        Each dict must provide the keys ``'name'``, ``'cluster_map'``,
        ``'threshold'`` and ``'mask'`` (an image object exposing ``.affine``).
    chunk_size : int, optional
        Number of streamlines processed per chunk (default 5000).

    Returns
    -------
    VC_idx : set of int
        Global indices of every streamline assigned to some bundle.
    found_vbs_info : dict
        Maps each bundle name to a dict with ``'nb_streamlines'``,
        ``'streamlines_indices'``, ``'overlap'``, ``'overreach'``,
        ``'overreach_norm'`` and ``'f1_score'``.
    bundles_found : list of bool
        ``bundles_found[i]`` is True if ``ref_bundles[i]`` received at least
        one streamline.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer.")

    VC_idx = set()
    found_vbs_info = {
        bundle['name']: {'nb_streamlines': 0, 'streamlines_indices': set()}
        for bundle in ref_bundles
    }
    bundles_found = [False] * len(ref_bundles)

    logging.debug("Starting scoring VCs")
    qb = QuickBundles(threshold=20, metric=AveragePointwiseEuclideanMetric())

    # Chunk the submission for big datasets; ``start`` is the global index of
    # the chunk's first streamline.
    for chunk_it, start in enumerate(range(0, len(streamlines), chunk_size)):
        logging.debug("Starting chunk: {0}".format(chunk_it))
        strl_chunk = streamlines[start:start + chunk_size]

        # Chunk-local indices already assigned to a bundle.
        cur_chunk_VC_idx = set()

        # Resample and cluster the chunk once, rather than at every call of
        # auto_extract.
        rstreamlines = set_number_of_points(strl_chunk, NB_POINTS_RESAMPLE)
        # qb.cluster had problem with f8
        rstreamlines = [s.astype('f4') for s in rstreamlines]
        chunk_cluster_map = qb.cluster(rstreamlines)
        chunk_cluster_map.refdata = strl_chunk

        logging.debug("Starting VC identification through auto_extract")
        for bundle_idx, ref_bundle in enumerate(ref_bundles):
            # The selected indices are in [0, len(strl_chunk)).
            selected = auto_extract(ref_bundle['cluster_map'],
                                    chunk_cluster_map,
                                    clean_thr=ref_bundle['threshold'])

            # Remove duplicates, when streamlines are assigned to multiple VBs.
            selected = set(selected) - cur_chunk_VC_idx
            if not selected:
                continue
            cur_chunk_VC_idx |= selected
            bundles_found[bundle_idx] = True

            # Shift indices to match the real number of streamlines.
            global_idx = {v + start for v in selected}
            vb_info = found_vbs_info[ref_bundle['name']]
            vb_info['nb_streamlines'] += len(selected)
            vb_info['streamlines_indices'] |= global_idx
            VC_idx |= global_idx

    # Compute bundle overlap, overreach and f1_scores and update found_vbs_info
    for ref_bundle in ref_bundles:
        bundle_mask = ref_bundle["mask"]
        vb_info = found_vbs_info[ref_bundle["name"]]

        # Streamlines are in voxel space since that's how they were
        # loaded in the scoring function.
        tractogram = Tractogram(
            streamlines=(streamlines[i]
                         for i in vb_info['streamlines_indices']),
            affine_to_rasmm=bundle_mask.affine)

        scores = {}
        if len(tractogram) > 0:
            scores = compute_bundle_coverage_scores(tractogram, bundle_mask)

        vb_info['overlap'] = scores.get("OL", 0)
        vb_info['overreach'] = scores.get("OR", 0)
        vb_info['overreach_norm'] = scores.get("ORn", 0)
        vb_info['f1_score'] = scores.get("F1", 0)

    return VC_idx, found_vbs_info, bundles_found
def dwi_dipy_run(dwi_dir, node_size, dir_path, conn_model, parc,
                 atlas_select, network, wm_mask=None):
    """Run ACT-constrained tractography and return a label connectivity matrix.

    NOTE(review): this function still contains hard-coded debugging inputs:
    ``dwi_dir`` is overwritten below, and the tissue maps and atlas are
    loaded from fixed paths. ``node_size`` is only reassigned; ``dir_path``,
    ``atlas_select``, ``network`` and ``wm_mask`` are unused. They are kept
    for interface compatibility.

    Parameters
    ----------
    dwi_dir : str
        Directory expected to contain ``dwi.nii.gz``,
        ``nodif_brain_mask.nii.gz``, ``bval`` and ``bvec`` (currently
        overridden by a hard-coded path).
    node_size, dir_path, atlas_select, network, wm_mask
        Unused.
    conn_model : str
        ``'tensor'`` to track with EuDX on a tensor fit, or ``'csd'`` to
        track with LocalTracking on a CSD model.
    parc : bool
        If True, ``node_size`` is set to ``'parc'``.

    Returns
    -------
    conn_matrix : ndarray
        Symmetric matrix returned by
        ``dipy.tracking.utils.connectivity_matrix``, with rows and columns
        0-2 zeroed.

    Raises
    ------
    RuntimeError
        If ``conn_model`` is neither ``'csd'`` nor ``'tensor'``.

    Side effects: writes ``prob_streamlines.trk`` to the current directory.
    """
    from dipy.reconst.dti import TensorModel, quantize_evecs
    from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
                                       recursive_response)
    from dipy.tracking.local import LocalTracking, ActTissueClassifier
    from dipy.tracking import utils
    from dipy.direction import peaks_from_model
    from dipy.tracking.eudx import EuDX
    from dipy.data import get_sphere, default_sphere
    from dipy.core.gradients import gradient_table
    from dipy.io import read_bvals_bvecs
    from dipy.tracking.streamline import Streamlines
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                BootDirectionGetter)
    from nibabel.streamlines import save as save_trk
    from nibabel.streamlines import Tractogram

    # NOTE(review): hard-coded debugging inputs (this overrides ``dwi_dir``).
    dwi_dir = '/Users/PSYC-dap3463/Downloads/bedpostx_s002'
    img_pve_csf = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_vent_csf_diff_dwi.nii.gz'
    )
    img_pve_wm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_wm_in_dwi_bin.nii.gz'
    )
    img_pve_gm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_gm_mask_dwi.nii.gz'
    )
    labels_img = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/dwi_aligned_atlas.nii.gz'
    )
    num_total_samples = 10000
    tracking_method = 'boot'  # Options are 'boot', 'prob', 'peaks', 'closest'
    procmem = [2, 4]

    # Fail fast, before the expensive model fits.
    if conn_model not in ('csd', 'tensor'):
        raise RuntimeError("%s%s" % (conn_model, ' is not a valid model.'))

    if parc is True:
        node_size = 'parc'

    dwi_img = nib.load("%s%s" % (dwi_dir, '/dwi.nii.gz'))
    nodif_brain_mask_path = "%s%s" % (dwi_dir, '/nodif_brain_mask.nii.gz')
    bvals, bvecs = read_bvals_bvecs("%s%s" % (dwi_dir, '/bval'),
                                    "%s%s" % (dwi_dir, '/bvec'))
    data = dwi_img.get_data()
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = min(bvals)
    sphere = get_sphere('symmetric724')

    # Loads mask and ensures it's a true binary mask
    mask = nib.load(nodif_brain_mask_path).get_data() > 0

    # Fit a basic tensor model first
    model = TensorModel(gtab)
    ten = model.fit(data, mask)
    fa = ten.fa

    # Tractography
    if conn_model == 'csd':
        print('Tracking with csd model...')
    else:
        print('Tracking with tensor model...')

    # Seeds: every WM voxel (the first slice along each axis excluded) plus
    # random seeds where FA > 0.02. Both sets use the DWI affine so they are
    # in the same (world) coordinates before being stacked.
    # Copy so that zeroing the borders does not alter img_pve_wm's cached
    # data, which is reused below for the ACT background and the CSD masks.
    wm_mask_data = img_pve_wm.get_data().copy()
    wm_mask_data[0, :, :] = False
    wm_mask_data[:, 0, :] = False
    wm_mask_data[:, :, 0] = False
    seeds = utils.seeds_from_mask(wm_mask_data, density=1,
                                  affine=dwi_img.affine)
    seeds_rnd = utils.random_seeds_from_mask(ten.fa > 0.02,
                                             seeds_count=num_total_samples,
                                             seed_count_per_voxel=True,
                                             affine=dwi_img.affine)
    seeds_all = np.vstack([seeds, seeds_rnd])

    # Load tissue maps and prepare tissue classifier
    # (Anatomically-Constrained Tractography (ACT)). Voxels outside all three
    # tissue maps are background and are added to the include map.
    background = np.ones(img_pve_gm.shape)
    background[(img_pve_gm.get_data() + img_pve_wm.get_data() +
                img_pve_csf.get_data()) > 0] = 0
    include_map = img_pve_gm.get_data()
    include_map[background > 0] = 1
    exclude_map = img_pve_csf.get_data()
    act_classifier = ActTissueClassifier(include_map, exclude_map)

    if conn_model == 'tensor':
        # NOTE(review): EuDX is given world-space seeds and no affine here;
        # confirm the coordinate system it expects.
        ind = quantize_evecs(ten.evecs, sphere.vertices)
        streamline_generator = EuDX(a=fa, ind=ind, seeds=seeds_all,
                                    odf_vertices=sphere.vertices,
                                    a_low=0.05, step_sz=.5)
        # Kept as before for EuDX output.
        streamlines_affine = dwi_img.affine
    else:
        wm_bool = img_pve_wm.get_data().astype('bool')
        response = recursive_response(gtab, data, mask=wm_bool, sh_order=8,
                                      peak_thr=0.01, init_fa=0.05,
                                      init_trace=0.0021, iter=8,
                                      convergence=0.001, parallel=True)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response)
        csd_fit = None  # only fitted by the 'prob' and 'closest' methods

        if tracking_method == 'boot':
            dg = BootDirectionGetter.from_data(data, csd_model, max_angle=30.,
                                               sphere=default_sphere)
        elif tracking_method == 'prob':
            try:
                print(
                    'First attempting to build the direction getter directly from the spherical harmonic representation of the FOD...'
                )
                csd_fit = csd_model.fit(data, mask=wm_bool)
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
            except Exception:
                print(
                    'Sphereical harmonic not available for this model. Using peaks_from_model to represent the ODF of the model on a spherical harmonic basis instead...'
                )
                peaks = peaks_from_model(csd_model, data, default_sphere, .5,
                                         25, mask=wm_bool, return_sh=True,
                                         parallel=True,
                                         nbr_processes=procmem[0])
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    peaks.shm_coeff, max_angle=30., sphere=default_sphere)
        elif tracking_method == 'peaks':
            dg = peaks_from_model(model=csd_model, data=data,
                                  sphere=default_sphere,
                                  relative_peak_threshold=.5,
                                  min_separation_angle=25, mask=wm_bool,
                                  parallel=True, nbr_processes=procmem[0])
        elif tracking_method == 'closest':
            csd_fit = csd_model.fit(data, mask=wm_bool)
            pmf = csd_fit.odf(default_sphere).clip(min=0)
            dg = ClosestPeakDirectionGetter.from_pmf(pmf, max_angle=30.,
                                                     sphere=default_sphere)
        else:
            raise ValueError(
                "Unknown tracking method: {}".format(tracking_method))

        streamline_generator = LocalTracking(dg, act_classifier, seeds_all,
                                             affine=dwi_img.affine,
                                             step_size=0.5)
        # LocalTracking was given the DWI affine, so its streamlines are
        # already in RAS+ mm; identity avoids applying the affine twice.
        streamlines_affine = np.eye(4)
        # Release the large model objects before streamlines are generated.
        del dg, csd_fit, response, csd_model

    streamlines = Streamlines(streamline_generator, buffer_size=512)
    save_trk(Tractogram(streamlines, affine_to_rasmm=streamlines_affine),
             'prob_streamlines.trk')

    tracks = [sl for sl in streamlines if len(sl) > 1]
    labels_data = labels_img.get_data().astype('int')
    conn_matrix, _ = utils.connectivity_matrix(tracks, labels_data,
                                               affine=labels_img.affine,
                                               return_mapping=True,
                                               mapping_as_streamlines=True,
                                               symmetric=True)
    # Drop rows/columns 0-2 (presumably background/unwanted labels -- confirm).
    conn_matrix[:3, :] = 0
    conn_matrix[:, :3] = 0
    return conn_matrix
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):
    """Build a seed tractogram for the ISMRM data and save it as ``.trk``.

    Seeds come either from both endpoints of every ground-truth bundle
    streamline (``source="trk"``), or from every voxel of a resized mask
    whose value exceeds ``threshold`` (``source="wm"`` or ``"brain"``).
    When ``keep < 1`` the seeds are subsampled without replacement with a
    fixed random seed (42), so results are reproducible.

    Parameters
    ----------
    data_dir : str
        Dataset root containing ``bundles/`` and ``masks/``; the output goes
        to ``<data_dir>/seeds/``.
    source : {"trk", "wm", "brain"}
        Where seeds are taken from.
    keep : float
        Fraction of seeds to keep; values >= 1 keep all seeds.
    weighted : bool
        Only used with ``source="trk"``. If True, each bundle receives the
        same total sampling probability regardless of its size.
    threshold : float
        Mask threshold, used with ``source`` "wm"/"brain".
    voxel : int
        Target voxel size in hundredths of mm (passed to ``mrresize`` as
        ``voxel / 100``); also used in the output file name.

    Raises
    ------
    ValueError
        If ``source`` is not supported, or if no ``.trk`` bundle is found
        for ``source="trk"``.

    Side effects: runs ``mrresize`` (and the tract converter for "trk"),
    writes the seed file, and deletes the resized mask and, for "trk", the
    converted ``.trk`` bundles.
    """
    if source not in ("trk", "wm", "brain"):
        raise ValueError(
            "Unknown seed source {!r}: expected 'trk', 'wm' or 'brain'."
            .format(source))

    trk_dir = os.path.join(data_dir, "bundles")
    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks",
                                    "wm_{}.nii.gz".format(voxel))
    else:  # source == "brain"
        # NOTE(review): these paths ignore data_dir, and the output name
        # ignores `voxel`; confirm this is intended.
        anat_path = os.path.join("subjects", "ismrm_gt",
                                 "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt",
                                    "dwi_brain_mask_125.nii.gz")
    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path,
        resized_path
    ])

    if source == "trk":
        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py", "-i",
            trk_dir, "-a", resized_path, "-vtk2trk"
        ])
        print("Loading seed bundles...")
        seed_bundles = []
        header = None
        for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            if header is None:
                # The first bundle's header is reused for the output file.
                header = trk_file.header
        if not seed_bundles:
            raise ValueError("No .trk bundles found in {}".format(trk_dir))

        n_seeds = sum(len(b) for b in seed_bundles)
        n_bundles = len(seed_bundles)
        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))
        seeds = np.array([[seed] for bundle in seed_bundles for seed in bundle])
        if keep < 1:
            if weighted:
                # Each bundle gets total probability 1 / n_bundles, shared
                # evenly between its seeds.
                p = np.zeros(n_seeds)
                offset = 0
                for bundle in seed_bundles:
                    size = len(bundle)
                    p[offset:offset + size] = 1 / (size * n_bundles)
                    offset += size
            else:
                p = np.ones(n_seeds) / n_seeds
    else:
        weighted = False
        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()
        # Voxel indices above threshold -> world coordinates, shape (N, 1, 3).
        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])
        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)
        n_seeds = len(seeds)
        if keep < 1:
            p = np.ones(n_seeds) / n_seeds
        header = TrkFile.create_empty_header()
        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        # A local generator gives the same draws as np.random.seed(42)
        # without touching NumPy's global random state.
        rng = np.random.RandomState(42)
        keep_idx = rng.choice(len(seeds), size=keep_n, replace=False, p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))
    save_dir = os.path.join(data_dir, "seeds")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seeds_from_{}_{}_vox{:03d}.trk")
    save_path = save_path.format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)
    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    os.remove(resized_path)
    if source == "trk":
        # Remove the .trk bundles written by the converter above; other
        # sources never created them.
        for converted in glob.glob(os.path.join(trk_dir, "*.trk")):
            os.remove(converted)
def _read_gradients(dwi_path):
    """Return ``(bvals, bvecs)`` read from the files next to ``dwi_path``.

    ``<name>.bvals``/``<name>.bvecs`` are tried first, then
    ``<name>.bval``/``<name>.bvec``, where ``<name>`` is ``dwi_path`` with
    its ``.gz`` and ``.nii`` suffixes stripped.

    Raises
    ------
    FileNotFoundError
        If neither pair of files can be read.
    """
    dwi_name = dwi_path
    if dwi_name.endswith(".gz"):
        dwi_name = dwi_name[:-3]
    if dwi_name.endswith(".nii"):
        dwi_name = dwi_name[:-4]

    try:
        return dipy.io.gradients.read_bvals_bvecs(dwi_name + ".bvals",
                                                  dwi_name + ".bvecs")
    except FileNotFoundError:
        try:
            return dipy.io.gradients.read_bvals_bvecs(dwi_name + ".bval",
                                                      dwi_name + ".bvec")
        except FileNotFoundError:
            print("Could not find .bvals/.bvecs or .bval/.bvec files...")
            raise


def _get_model_class(model_name):
    """Import and return the learn2track model class named ``model_name``.

    Raises
    ------
    ValueError
        If ``model_name`` is not a known model.
    """
    if model_name == "gru_regression":
        from learn2track.models import GRU_Regression
        return GRU_Regression
    if model_name == 'gru_gaussian':
        from learn2track.models import GRU_Gaussian
        return GRU_Gaussian
    if model_name == 'gru_mixture':
        from learn2track.models import GRU_Mixture
        return GRU_Mixture
    if model_name == 'gru_multistep':
        from learn2track.models import GRU_Multistep_Gaussian
        return GRU_Multistep_Gaussian
    if model_name == 'ffnn_regression':
        from learn2track.models import FFNN_Regression
        return FFNN_Regression
    raise ValueError("Unknown model!")


def _default_output_filename(args, theta, rejected=False):
    """Build the default ``.tck`` file name from the DWI path and settings.

    The name encodes the prefix (``args.prefix``; if None, the DWI's parent
    folder name + file stem with dots replaced by underscores),
    ``use_max_component``, the step size, the number of seeds per voxel and
    the maximum angle in degrees. With ``rejected=True`` the tag
    ``rejected`` is inserted right after the prefix.
    """
    prefix = args.prefix
    if prefix is None:
        dwi_name = os.path.basename(args.dwi)
        if dwi_name.endswith(".nii.gz"):
            dwi_name = dwi_name[:-7]
        else:  # .nii
            dwi_name = dwi_name[:-4]
        prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
        prefix = prefix.replace(".", "_")

    filename_items = [
        "{}",
        "useMaxComponent-{}",
        "step-{:.2f}mm",
        "nbSeeds-{}",
        "maxAngleDeg-{:.1f}",
    ]
    if rejected:
        filename_items.insert(1, "rejected")

    return ('_'.join(filename_items) + ".tck").format(
        prefix, args.use_max_component, args.step_size,
        args.nb_seeds_per_voxel, np.rad2deg(theta))


def main():
    """Track streamlines with a trained learn2track model and save them.

    Loads the experiment hyperparameters, the DWI (as SH coefficients or as
    a resampled signal), the model, an optional tracking mask and the seeds.
    It then tracks in DWI voxel space and moves the result to RAS+ mm.
    Empty and short streamlines are removed, optionally along with
    curvature-stopped and high-loss ones. The result is saved, plus the
    rejected streamlines when ``--save_rejected`` is given.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        bvals, bvecs = _read_gradients(args.dwi)

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

    affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        model_class = _get_model_class(hyperparams["model"])

        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)

        # Load the actual model: create a new instance and restore it.
        model = model_class.create(pjoin(experiment_path),
                                   volume_manager=volume_manager)
        model.drop_prob = 0.
        print(str(model))

    mask = None
    # Bound even without a mask: make_is_outside_mask always receives it.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.
            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox,
                                           mask_nii.affine)
            if args.dilate_mask:
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox,
                                                nii_seeds.affine)

                nii_seeds_data = nii_seeds.get_data()

                if args.dilate_seeding_mask:
                    import scipy.ndimage
                    nii_seeds_data = scipy.ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    seeds_in_voxel = idx + rng.uniform(
                        -0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(
                        affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(
                tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Also convert max length (in mm) to voxel.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox,
                                               threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds, step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size, args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram._affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        has_points = np.array(list(map(len, tractogram))) > 0
        if args.save_rejected:
            rejected_tractogram += tractogram[~has_points]
        tractogram = tractogram[has_points]
        print("Removed {:,} empty streamlines".format(nb_streamlines -
                                                      len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]
        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(
                lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(
            nb_streamlines - len(tractogram), args.min_length))

        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0],
                STOPPING_CURVATURE)
            if args.save_rejected:
                rejected_tractogram += tractogram[
                    stopping_curvature_flag_is_set]
            tractogram = tractogram[np.logical_not(
                stopping_curvature_flag_is_set)]
            print(
                "Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree"
                .format(nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produces a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model,
                                         hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(
                np.mean(losses),
                np.std(losses, ddof=1) / np.sqrt(len(losses))))
            if args.save_rejected:
                rejected_tractogram += tractogram[
                    losses > args.filter_threshold]
            tractogram = tractogram[losses <= args.filter_threshold]
            print("Removed {:,} streamlines producing a loss higher than {:.2f}"
                  .format(nb_streamlines - len(tractogram),
                          args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        if args.out is None:
            filename = _default_output_filename(args, theta)
        else:
            filename = args.out

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            # Create dirs, if needed.
            os.makedirs(save_dir, exist_ok=True)

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(
                len(rejected_tractogram))):
            if args.out is None:
                rejected_filename = _default_output_filename(args, theta,
                                                             rejected=True)
            else:
                # Derive the rejected file name from the user-given output.
                root, ext = os.path.splitext(args.out)
                rejected_filename = root + "_rejected" + ext

            rejected_save_path = pjoin(experiment_path, rejected_filename)
            rejected_dir = os.path.dirname(rejected_save_path)
            if rejected_dir:
                # Create dirs, if needed.
                os.makedirs(rejected_dir, exist_ok=True)

            print("Saving rejected streamlines to {}".format(
                rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
def run(self, pam_files, wm_files, gm_files, csf_files, seeding_files,
        step_size=0.2, back_tracking_dist=2, front_tracking_dist=1,
        max_trial=20, particle_count=15, seed_density=1,
        pmf_threshold=0.1, max_angle=30., out_dir='',
        out_tractogram='tractogram.trk'):
    """Workflow for Particle Filtering Tracking.

    This workflow uses a saved peaks and metrics (PAM) file as input.

    Parameters
    ----------
    pam_files : string
       Path to the peaks and metrics files. This path may contain
        wildcards to use multiple files at once.
    wm_files : string
        Path of the white matter image used as a stopping criterion for
        tracking.
    gm_files : string
        Path of the grey matter image used as a stopping criterion for
        tracking.
    csf_files : string
        Path of the cerebrospinal fluid image used as a stopping criterion
        for tracking.
    seeding_files : string
        A binary image showing where we need to seed for tracking.
    step_size : float, optional
        Step size used for tracking.
    back_tracking_dist : float, optional
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist.
        By default this is set to 2 mm.
    front_tracking_dist : float, optional
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    max_trial : int, optional
        Maximum number of trials for the particle filtering tractography
        (Prevents infinite loops, default=20).
    particle_count : int, optional
        Number of particles to use in the particle filter. (default 15)
    seed_density : int, optional
        Number of seeds per dimension inside voxel (default 1).
         For example, seed_density of 2 means 8 regularly distributed points
         in the voxel. And seed density of 1 means 1 point at the center
         of the voxel.
    pmf_threshold : float, optional
        Threshold for ODF functions. (default 0.1)
    max_angle : float, optional
        Maximum angle between tract segments. This angle can be more
        generous (larger) than values typically used with probabilistic
        direction getters. The angle range is (0, 90)
    out_dir : string, optional
       Output directory (default input file directory)
    out_tractogram : string, optional
       Name of the tractogram file to be saved (default 'tractogram.trk')

    References
    ----------
    Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M.
    Towards quantitative connectivity analysis: reducing tractography
    biases. NeuroImage, 98, 266-278, 2014.
    """
    io_iterator = self.get_io_iterator()

    for (pam_path, wm_path, gm_path, csf_path, seeding_path,
         out_tract) in io_iterator:
        logging.info(
            'Particle Filtering tracking on {0}'.format(pam_path))

        pam = load_peaks(pam_path, verbose=False)

        # The WM map also provides the affine and voxel size for the run.
        wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
        gm, _ = load_nifti(gm_path)
        csf, _ = load_nifti(csf_path)

        average_voxel_size = sum(voxel_size) / len(voxel_size)
        classifier = CmcTissueClassifier.from_pve(
            wm, gm, csf, step_size=step_size,
            average_voxel_size=average_voxel_size)
        logging.info('classifier done')

        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(seed_mask,
                                      density=[seed_density] * 3,
                                      affine=affine)
        logging.info('seeds done')

        direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
            pam.shm_coeff, max_angle=max_angle, sphere=pam.sphere,
            pmf_threshold=pmf_threshold)

        streamlines = ParticleFilteringTracking(
            direction_getter, classifier, seeds, affine,
            step_size=step_size,
            pft_back_tracking_dist=back_tracking_dist,
            pft_front_tracking_dist=front_tracking_dist,
            pft_max_trial=max_trial,
            particle_count=particle_count)
        logging.info('ParticleFilteringTracking initiated')

        # Tracking was run with `affine`, so an identity affine_to_rasmm is used.
        save(Tractogram(streamlines, affine_to_rasmm=np.eye(4)), out_tract)
        logging.info('Saved {0}'.format(out_tract))
"""
.. figure:: det_streamlines.png
   :align: center

   **Deterministic streamlines using EuDX (new framework)**

To learn more about this process, you could start by varying the number of
seed points or, even better, by placing seeds in specific regions of interest
in the brain.

Save the resulting streamlines in TrackVis (.trk) format and the FA map as
NIfTI-1 (.nii.gz).
"""

save_trk(Tractogram(streamlines, affine_to_rasmm=img.affine), 'det_streamlines.trk')

save_nifti('fa_map.nii.gz', fa, img.affine)

"""
On Windows, if you get a runtime error about a frozen executable, start your
script by putting the code above in a ``main`` function and using::

    if __name__ == '__main__':
        import multiprocessing
        multiprocessing.freeze_support()
        main()
def run_test(self):
    """Check that ``uncompress`` reproduces ``self.gt_indices`` for the first streamline.

    NOTE(review): unittest's default loader only collects methods whose
    names start with ``test``; this one may need to be invoked explicitly.
    """
    # Two copies of the same streamline; only the first result is checked.
    tractogram = Tractogram(streamlines=[self.strl_data] * 2)
    recovered = uncompress(tractogram.streamlines)
    first_indices = recovered[0]
    self.assertTrue(np.allclose(first_indices, self.gt_indices))
def main():
    """Run SHORE/CSD probabilistic tractography and save it as ``track.tck``.

    Reads ``config.json`` (keys ``data_file``, ``data_bval`` and
    ``data_bvec``) and the parcellation ``volume.nii.gz`` from the current
    directory. It builds a white-matter mask from the parcellation labels
    and seeds in it. A SHORE GFA map drives the stopping criterion and a
    CSD fit drives a probabilistic direction getter. Elapsed times are
    printed along the way.
    """
    start = time.time()
    with open('config.json') as config_json:
        config = json.load(config_json)

    # Load the data
    dmri_image = nib.load(config['data_file'])
    dmri = dmri_image.get_data()
    affine = dmri_image.affine
    # NOTE(review): config['freesurfer'] is ignored in favour of this fixed
    # file name; confirm which one is intended.
    aparc_im = nib.load('volume.nii.gz')
    aparc = aparc_im.get_data()
    end = time.time()
    print('Loaded Files: ' + str((end - start)))
    print(dmri.shape)
    print(aparc.shape)

    # Create the white matter and callosal masks
    start = time.time()
    wm_regions = [
        2, 41, 16, 17, 28, 60, 51, 53, 12, 52, 13, 18, 54, 50, 11, 251, 252,
        253, 254, 255, 10, 49, 46, 7
    ]
    # 1.0 where the parcellation label is one of wm_regions, 0.0 elsewhere.
    wm_mask = np.isin(aparc, wm_regions).astype(np.float64)

    # Create the gradient table from the bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(config['data_bval'], config['data_bvec'])
    gtab = gradient_table(bvals, bvecs, b0_threshold=100)
    end = time.time()
    print('Created Gradient Table: ' + str((end - start)))

    # The probabilistic model
    # Use the SHORE model to find Orientation Dist. Function
    start = time.time()
    shore_model = ShoreModel(gtab)
    shore_peaks = peaks_from_model(shore_model, dmri, default_sphere,
                                   relative_peak_threshold=.8,
                                   min_separation_angle=45, mask=wm_mask)
    print('Creating Shore Model: ' + str(time.time() - start))

    # Begins the seed in the wm tracts
    seeds = utils.seeds_from_mask(wm_mask, density=[1, 1, 1], affine=affine)
    print('Created White Matter seeds: ' + str(time.time() - start))

    # Create a CSD model to measure Fiber Orientation Dist
    print('Begin the probabilistic model')
    response, _ = auto_response(gtab, dmri, roi_radius=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data=dmri, mask=wm_mask)
    print('Created the CSD model: ' + str(time.time() - start))

    # Set the Direction Getter to randomly choose directions
    prob_dg = ProbabilisticDirectionGetter.from_shcoeff(csd_fit.shm_coeff,
                                                        max_angle=30.,
                                                        sphere=default_sphere)
    print('Created the Direction Getter: ' + str(time.time() - start))

    # Restrict the white matter tracking
    classifier = ThresholdTissueClassifier(shore_peaks.gfa, .25)
    print('Created the Tissue Classifier: ' + str(time.time() - start))

    # Create the probabilistic model
    streamlines = LocalTracking(prob_dg, tissue_classifier=classifier,
                                seeds=seeds, step_size=.5, max_cross=1,
                                affine=affine)
    print('Created the probabilistic model: ' + str(time.time() - start))

    # Compute streamlines and store as a list.
    streamlines = list(streamlines)
    print('Computed streamlines: ' + str(time.time() - start))

    # Seeds and tracking both used `affine`, so the streamlines are already
    # in RAS+ mm: an identity affine_to_rasmm avoids applying it twice.
    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
    save(tractogram, 'track.tck')
    end = time.time()
    print("Created the tck file: " + str((end - start)))
def clean_tractogram(self, tractogram, affine_vox2mask, max_length=200.,
                     max_winding=330.):
    """ Remove potential "non-connections" by filtering according to
    length and winding

    A streamline is removed if its length is <= ``self.min_length`` or
    > ``max_length``, or if its ``winding`` exceeds ``max_winding``.
    Curvature- and mask-based filtering are not performed.

    Parameters:
    -----------
    tractogram: Tractogram
        Full tractogram
    affine_vox2mask: ndarray
        Currently unused (kept for interface compatibility).
    max_length: float, optional
        Upper length bound (default 200.), in the units returned by
        ``slength`` (presumably mm -- confirm).
    max_winding: float, optional
        Upper bound on ``winding`` (default 330.), presumably degrees of
        total turning; catches e.g. streamlines starting and ending in the
        same ROI.

    Returns:
    --------
    tractogram: Tractogram
        Filtered tractogram holding only the kept streamlines and their
        data_per_streamline.
    """
    print('Cleaning tractogram ... ', end='', flush=True)

    streamlines = tractogram.streamlines

    # Filter by length
    # Example case: Bundle 3 of fibercup tends to be shorter than 35
    lengths = np.asarray([slength(s) for s in streamlines])
    dirty_mask = (lengths <= self.min_length) | (lengths > max_length)

    # Filter by loops
    # For example: A streamline ending and starting in the same roi
    looping_mask = np.array([winding(s) for s in streamlines]) > max_winding
    dirty_mask = np.logical_or(dirty_mask, looping_mask)

    # Only keep valid streamlines
    valid_indices = np.flatnonzero(np.logical_not(dirty_mask))
    clean_streamlines = streamlines[valid_indices]
    clean_dps = tractogram.data_per_streamline[valid_indices]
    print('Done !')

    print('Kept {}/{} streamlines'.format(len(valid_indices),
                                          len(streamlines)))

    return Tractogram(clean_streamlines, data_per_streamline=clean_dps)