def main():
    """Split a tractogram into label-to-label connections and save them.

    Parses the command line, loads the label volume and the tractogram,
    computes which labels each streamline connects, then, for every pair of
    labels (including each label with itself), extracts the connecting
    segments and runs the optional filtering stages (length pruning, loop
    removal, outlier removal, QuickBundles-based curvature check). Each
    stage is handed to ``_save_if_needed`` together with the open HDF5 file.

    Exits through ``parser.error`` on invalid arguments or non-isotropic
    labels.

    Raises
    ------
    IOError
        If the tractogram and the label volume headers are not compatible.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # The first (smallest) unique value is dropped.
    # NOTE(review): presumably the background label 0 - confirm.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        # Fixed: the two filenames were glued to "do not" (missing space).
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()
    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)
    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD
        for iteration_counter, (in_label, out_label) in enumerate(comb_list):
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))

            # Gather the connections stored in both directions.
            # NOTE(review): when in_label == out_label both branches read the
            # same con_info entry, so self-connections appear to be added
            # twice - confirm whether that is intended.
            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines, sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines,
                    args.min_length,
                    args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                # Everything not kept by the loop filter is a loop.
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue

            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold,
                    nb_samplings=10, fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
# Segment the tractogram into the requested bundles. `bundles`, `tractogram`,
# `mapping`, the HARDI file names and `MNI_T2_img` are defined earlier in the
# script (outside this excerpt).
print("Segmenting fiber groups...")
segmentation = seg.Segmentation(return_idx=True,
                                filter_by_endpoints=False)
segmentation.segment(bundles,
                     tractogram,
                     fdata=hardi_fdata,
                     fbval=hardi_fbval,
                     fbvec=hardi_fbvec,
                     mapping=mapping,
                     reg_template=MNI_T2_img)

fiber_groups = segmentation.fiber_groups

# For every bundle: save the segmented streamlines as a .trk file, then save
# and display a density map of them.
for bundle in bundles:
    # NOTE: this rebinds the module-level name `tractogram` on each pass.
    tractogram = StatefulTractogram(fiber_groups[bundle]['sl'].streamlines,
                                    img,
                                    Space.VOX)
    tractogram.to_rasmm()
    save_tractogram(tractogram, op.join(working_dir, f'afq_{bundle}_seg.trk'),
                    bbox_valid_check=False)

    tractogram_img = density_map(tractogram, n_sls=1000, to_vox=True)
    nib.save(tractogram_img, op.join(working_dir,
                                     f'afq_{bundle}_seg_density_map.nii.gz'))
    show_anatomical_slices(tractogram_img.get_fdata(),
                           f'Segmented {bundle} Density Map')

##########################################################################
# Cleaning:
# ---------
# Each fiber group is cleaned to exclude streamlines that are outliers in terms
def load_tractogram(filename, reference, to_space=Space.RASMM,
                    shifted_origin=False, bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile,
        Nifti1Header or trk.header (dict), or 'same' if the input is a trk
        file. Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Space (Enum)
        Space in which the streamlines will be transformed after loading
        (Space.VOX, Space.VOXMM or Space.RASMM)
    shifted_origin : bool
        Information on the position of the origin,
        False is Trackvis standard, default (center of the voxel)
        True is NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        When True, the loaded tractogram is checked with
        ``is_bbox_in_vox_valid()`` and a ValueError is raised on failure
    trk_header_check : bool
        When True and the file is a .trk, its header is compared with the
        reference using ``is_header_compatible``

    Returns
    -------
    output : StatefulTractogram or bool
        The tractogram to load (must have been saved properly), or False
        when the arguments are rejected (an error is logged in that case)

    Raises
    ------
    ValueError
        If ``bbox_valid_check`` is True and the bounding box is not valid.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Output filename is not one of the supported format')
        return False

    # isinstance() instead of `in Space`: the membership test on an Enum
    # class behaves differently across Python versions for non-member
    # values (TypeError on 3.8-3.11, value match on 3.12+).
    if not isinstance(to_space, Space):
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    if reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only ' +
                          'available for Trk file.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided ' +
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        # Only the .trk branch forwards per-point / per-streamline data.
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline

    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        try:
            streamlines = list(dpy_obj.read_tracks())
        finally:
            # Release the file handle even if reading fails.
            dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds',
                  filename, len(streamlines), round(time.time() - timer, 3))

    # The object is always built as RASMM and converted afterwards if needed.
    sft = StatefulTractogram(streamlines, reference, Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'load a valid file if some coordinates are invalid')

    return sft
It is recommended to re-create a new StatefulTractogram object and explicitly specify in which space the streamlines are. Be careful to follow the order of operations. If the tractogram was from a Trk file with metadata, this will be lost. If you wish to keep metadata while manipulating the number or the order look at the function StatefulTractogram.remove_invalid_streamlines() for more details It is important to mention that once the object is created in a consistent state the ``save_tractogram`` function will save a valid file. And then the function ``load_tractogram`` will load them in a valid state. """ cc_sft = StatefulTractogram(cc_streamlines_vox, reference_anatomy, Space.VOX) laf_sft = StatefulTractogram(laf_streamlines_vox, reference_anatomy, Space.VOX) raf_sft = StatefulTractogram(raf_streamlines_vox, reference_anatomy, Space.VOX) lpt_sft = StatefulTractogram(lpt_streamlines_vox, reference_anatomy, Space.VOX) rpt_sft = StatefulTractogram(rpt_streamlines_vox, reference_anatomy, Space.VOX) print(len(cc_sft), len(laf_sft), len(raf_sft), len(lpt_sft), len(rpt_sft)) save_tractogram(cc_sft, 'cc_1000.trk') save_tractogram(laf_sft, 'laf_1000.trk') save_tractogram(raf_sft, 'raf_1000.trk') save_tractogram(lpt_sft, 'lpt_1000.trk') save_tractogram(rpt_sft, 'rpt_1000.trk') nib.save(nib.Nifti1Image(cc_density, affine, nifti_header), 'cc_density.nii.gz') nib.save(nib.Nifti1Image(laf_density, affine, nifti_header),
def load_data_tmp_saving(filename, reference, init_only=False,
                         disable_centroids=False):
    """Load a bundle and its cached density/endpoints/centroids measures.

    Since data is often re-used when comparing multiple bundles, anything
    that can be computed once is saved temporarily under ``tmp_measures/``
    (file names derived from the MD5 of ``filename``) and simply loaded on
    demand.

    Parameters
    ----------
    filename : str
        Path of the tractogram to load.
    reference : object
        Reference forwarded to ``load_tractogram``.
    init_only : bool
        When True, warnings are logged for missing/empty inputs and, if the
        three cache files already exist, nothing is loaded (None is
        returned).
    disable_centroids : bool
        When True, no clustering is run and an empty centroid list is used
        (and saved) on a cache miss.

    Returns
    -------
    tuple or None
        ``(density, endpoints_density, streamlines, centroids)``, or None
        when the file does not exist, contains no streamlines, or when
        ``init_only`` is set and the cache is already present.
    """
    if not os.path.isfile(filename):
        if init_only:
            logging.warning('%s does not exist', filename)
        return None

    hash_tmp = hashlib.md5(filename.encode()).hexdigest()
    tmp_density_filename = os.path.join('tmp_measures/',
                                        '{0}_density.nii.gz'.format(hash_tmp))
    tmp_endpoints_filename = os.path.join(
        'tmp_measures/', '{0}_endpoints.nii.gz'.format(hash_tmp))
    tmp_centroids_filename = os.path.join(
        'tmp_measures/', '{0}_centroids.trk'.format(hash_tmp))

    sft = load_tractogram(filename, reference)
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.get_streamlines_copy()
    if not streamlines:
        if init_only:
            logging.warning('%s is empty', filename)
        return None

    if os.path.isfile(tmp_density_filename) \
            and os.path.isfile(tmp_endpoints_filename) \
            and os.path.isfile(tmp_centroids_filename):
        # If initialization, loading the data is useless
        if init_only:
            return None
        # np.asanyarray(img.dataobj) is the documented replacement for the
        # removed nibabel `get_data()` accessor.
        density = np.asanyarray(nib.load(tmp_density_filename).dataobj)
        endpoints_density = np.asanyarray(
            nib.load(tmp_endpoints_filename).dataobj)
        sft_centroids = load_tractogram(tmp_centroids_filename, reference)
        sft_centroids.to_vox()
        sft_centroids.to_corner()
        centroids = sft_centroids.get_streamlines_copy()
    else:
        transformation, dimensions, _, _ = sft.space_attributes
        density = compute_tract_counts_map(streamlines, dimensions)
        endpoints_density = get_endpoints_density_map(streamlines, dimensions,
                                                      point_to_select=3)
        thresholds = [32, 24, 12, 6]
        if disable_centroids:
            centroids = []
        else:
            centroids = qbx_and_merge(streamlines, thresholds,
                                      rng=RandomState(0),
                                      verbose=False).centroids

        # Saving tmp files to save on future computation
        nib.save(nib.Nifti1Image(density.astype(np.float32), transformation),
                 tmp_density_filename)
        nib.save(nib.Nifti1Image(endpoints_density.astype(np.int16),
                                 transformation),
                 tmp_endpoints_filename)

        # Saving in vox space and corner.
        centroids_sft = StatefulTractogram.from_sft(centroids, sft)
        save_tractogram(centroids_sft, tmp_centroids_filename)

    return density, endpoints_density, streamlines, centroids
# Cleaning
# --------
# Each fiber group is cleaned to exclude streamlines that are outliers in terms
# of their trajectory and/or length.

print("Cleaning fiber groups...")
for bundle in bundles:
    print(f"Cleaning {bundle}")
    print(f"Before cleaning: {len(fiber_groups[bundle]['sl'])} streamlines")
    new_fibers, idx_in_bundle = seg.clean_bundle(fiber_groups[bundle]['sl'],
                                                 return_idx=True)
    # Fixed typo in the user-facing message ("Afer" -> "After").
    print(f"After cleaning: {len(new_fibers)} streamlines")

    # Map the indices kept within the bundle back through the bundle's own
    # 'idx' array before saving them.
    idx_in_global = fiber_groups[bundle]['idx'][idx_in_bundle]
    np.save(f'{bundle}_idx.npy', idx_in_global)
    sft = StatefulTractogram(new_fibers.streamlines,
                             img,
                             Space.VOX)
    sft.to_rasmm()
    save_tractogram(sft, f'./{bundle}_afq.trk',
                    bbox_valid_check=False)


##########################################################################
# Bundle profiles
# ---------------
# Streamlines are represented in the original diffusion space (`Space.VOX`) and
# scalar properties along the length of each bundle are queried from this scalar
# data. Here, the contribution of each streamline is weighted according to how
# representative this streamline is of the bundle overall.

print("Extracting tract profiles...")
for bundle in bundles:
    sft = load_tractogram(f'./{bundle}_afq.trk', img, to_space=Space.VOX)
    fig, ax = plt.subplots(1)
def main():
    """Interactively accept or reject clusters in a FURY window.

    Input bundles with fewer streamlines than ``args.min_cluster_size`` are
    rejected automatically. The remaining ones are sorted by decreasing
    size and shown one at a time over the concatenation of all of them.
    Keys: 'a' accept, 'r' reject, 'z' undo the last choice, 'c' toggle the
    background streamlines, 'q' quit (everything not yet classified is
    rejected). Accepted and rejected streamlines are then saved to
    ``args.out_accepted`` and ``args.out_rejected``.
    """
    # Callback required for FURY
    def keypress_callback(obj, _):
        key = obj.GetKeySym().lower()
        nonlocal clusters_linewidth, background_linewidth
        nonlocal curr_streamlines_actor, concat_streamlines_actor, \
            show_curr_actor

        # Index of the cluster currently displayed == number of choices made.
        iterator = len(accepted_streamlines) + len(rejected_streamlines)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and iterator < len(sft_accepted_on_size):
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                # Remove then re-add the current cluster after the
                # background actor.
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if iterator < len(sft_accepted_on_size):
                logging.warning(
                    'Early exit, everything remaining to be rejected.')
            return

        if key in ['a', 'r'] and iterator < len(sft_accepted_on_size):
            if key == 'a':
                accepted_streamlines.append(iterator)
                choices.append('a')
                logging.info('Accepted file {}'.format(
                    filename_accepted_on_size[iterator]))
            elif key == 'r':
                rejected_streamlines.append(iterator)
                choices.append('r')
                logging.info('Rejected file {}'.format(
                    filename_accepted_on_size[iterator]))
            iterator += 1

        if key == 'z':
            if iterator > 0:
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewind on step.')

                iterator -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        if key in ['a', 'r', 'z'] and iterator < len(sft_accepted_on_size):
            # Swap the displayed cluster for the one at the new position.
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[iterator].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if iterator == len(sft_accepted_on_size):
            print('No more cluster, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])

    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate procedure, clusters can be discarded based on size
    # Concatenation is to give spatial context
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        if not is_header_compatible(args.in_bundles[0], filename):
            # Previously a silent return: say why nothing is being done.
            logging.error('{} and {} do not have a compatible header'.format(
                args.in_bundles[0], filename))
            return
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser, args, filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            # Fixed: a space was missing between the two literals.
            logging.info('File {} has {} streamlines, '
                         'automatically rejected.'.format(filename, len(sft)))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('{} clusters to be classified.'.format(
        len(sft_accepted_on_size)))

    # The clusters are sorted by size for simplicity/efficiency
    tuple_accepted = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda x: len(x[0]), reverse=True))
    sft_accepted_on_size, filename_accepted_on_size = tuple_accepted

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene, size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected
    missing = len(args.in_bundles) - len(choices) - len(sft_rejected_on_size)
    len_accepted = len(sft_accepted_on_size)
    rejected_streamlines.extend(range(len_accepted - missing,
                                      len_accepted))
    if missing > 0:
        # Fixed: a space was missing between the two literals.
        logging.info('{} clusters automatically rejected '
                     'from early exit'.format(missing))

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(
        save_clusters(sft_rejected_on_size,
                      range(len(sft_rejected_on_size)),
                      args.out_rejected_dir,
                      filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
from dipy.data import small_sphere from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import save_trk fod = csd_fit.odf(small_sphere) pmf = fod.clip(min=0) prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf, max_angle=30., sphere=small_sphere) streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds, affine, step_size=.5) streamlines = Streamlines(streamline_generator) sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM) save_trk(sft, "tractogram_probabilistic_dg_pmf.trk") if has_fury: scene = window.Scene() scene.add(actor.line(streamlines, colormap.line_colors(streamlines))) window.record(scene, out_path='tractogram_probabilistic_dg_pmf.png', size=(800, 800)) if interactive: window.show(scene) """ .. figure:: tractogram_probabilistic_dg_pmf.png :align: center **Corpus Callosum using probabilistic direction getter from PMF**
def test_track_ensemble_particle():
    """
    Test for ensemble tractography functionality

    Builds a CSD reconstruction from the example subject, stores it in a
    temporary HDF5 file, runs ``track.track_ensemble`` with particle
    tracking, saves the result as a .trk file and checks that the returned
    streamlines are an ``ArraySequence``.
    """
    import tempfile
    from pynets.dmri import track
    from dipy.core.gradients import gradient_table
    from dipy.data import get_sphere
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram
    from nibabel.streamlines.array_sequence import ArraySequence

    base_dir = str(Path(__file__).parent / "examples")
    B0_mask = f"{base_dir}/003/anat/mean_B0_bet_mask_tmp.nii.gz"
    gm_in_dwi = f"{base_dir}/003/anat/t1w_gm_in_dwi.nii.gz"
    vent_csf_in_dwi = f"{base_dir}/003/anat/t1w_vent_csf_in_dwi.nii.gz"
    wm_in_dwi = f"{base_dir}/003/anat/t1w_wm_in_dwi.nii.gz"
    dir_path = f"{base_dir}/003/dmri"
    bvals = f"{dir_path}/sub-003_dwi.bval"
    bvecs = f"{base_dir}/003/test_out/003/dwi/bvecs_reor.bvec"
    gtab = gradient_table(bvals, bvecs)
    dwi_file = f"{base_dir}/003/test_out/003/dwi/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    atlas_data_wm_gm_int = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track_wmgm_int.nii.gz"
    labels_im_file = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track.nii.gz"
    conn_model = 'csd'
    tiss_class = 'cmc'
    min_length = 10
    maxcrossing = 2
    roi_neighborhood_tol = 6
    waymask = None
    curv_thr_list = [40, 30]
    step_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    sphere = get_sphere('repulsion724')
    directget = 'prob'
    track_type = 'particle'
    target_samples = 1000

    dwi_img = nib.load(dwi_file)
    dwi_data = dwi_img.get_fdata()

    model, _ = track.reconstruction(conn_model, gtab, dwi_data, wm_in_dwi)

    # The context manager removes the scratch directory even when the test
    # fails; the original never cleaned it up.
    with tempfile.TemporaryDirectory() as temp_dir_name:
        recon_path = temp_dir_name + '/model_file.hdf5'

        # The `with` block closes the file; no explicit close() is needed.
        with h5py.File(recon_path, 'w') as hf:
            hf.create_dataset("reconstruction",
                              data=model.astype('float32'))

        # NOTE(review): `gm_in_dwi` is passed twice in a row below, exactly
        # as in the original call - confirm against the
        # `track.track_ensemble` signature.
        streamlines = track.track_ensemble(
            target_samples, atlas_data_wm_gm_int, labels_im_file,
            recon_path, sphere, directget, curv_thr_list, step_list,
            track_type, maxcrossing, roi_neighborhood_tol, min_length,
            waymask, B0_mask, gm_in_dwi, gm_in_dwi, vent_csf_in_dwi,
            wm_in_dwi, tiss_class, temp_dir_name)

        streams = f"{base_dir}/miscellaneous/streamlines_model-csd_nodetype-parc_samples-1000streams_tracktype-particle_directget-prob_minlength-10.trk"
        save_tractogram(StatefulTractogram(streamlines, reference=dwi_img,
                                           space=Space.VOXMM,
                                           origin=Origin.NIFTI),
                        streams, bbox_valid_check=False)

        assert isinstance(streamlines, ArraySequence)
def main():
    """Transform every connection stored in an HDF5 file to a target space.

    For each group of the input HDF5 the streamlines are rebuilt, wrapped
    in a StatefulTractogram using the spatial attributes stored in the input
    file, transformed with ``transform_warp_streamlines`` (linear transform
    plus optional deformation field) and written to the output HDF5 in
    voxel space / corner origin, together with the spatial attributes of
    the target image. Additional datasets of each group are copied over;
    those with one entry per streamline travel through the tractogram so
    they stay aligned with the streamlines that are kept.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_target_file,
                                 args.in_transfo], args.in_deformation)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    streamline_keys = ['data', 'offsets', 'lengths']

    with h5py.File(args.in_hdf5, 'r') as in_hdf5_file:
        with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
            transfo = load_matrix_in_any_format(args.in_transfo)

            deformation_data = None
            if args.in_deformation is not None:
                deformation_data = np.squeeze(
                    nib.load(args.in_deformation).get_fdata(
                        dtype=np.float32))
            target_img = nib.load(args.in_target_file)

            for key in in_hdf5_file.keys():
                # The group is created even if it ends up empty.
                group = out_hdf5_file.create_group(key)
                affine = in_hdf5_file.attrs['affine']
                dimensions = in_hdf5_file.attrs['dimensions']
                voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
                streamlines = reconstruct_streamlines_from_hdf5(
                    in_hdf5_file, key)

                if len(streamlines) == 0:
                    continue

                header = create_nifti_header(affine, dimensions, voxel_sizes)
                moving_sft = StatefulTractogram(streamlines, header,
                                                Space.VOX,
                                                origin=Origin.TRACKVIS)

                in_group = in_hdf5_file[key]
                extra_keys = [dps_key for dps_key in in_group.keys()
                              if dps_key not in streamline_keys]
                # A dataset is per-streamline data when it has the same
                # shape as 'offsets'. The original compared the shape tuple
                # with the Dataset object itself, which never matched.
                per_streamline_shape = in_group['offsets'].shape
                dps_keys = [dps_key for dps_key in extra_keys
                            if in_group[dps_key].shape
                            == per_streamline_shape]

                for dps_key in dps_keys:
                    # `[()]` reads the whole dataset (replaces the removed
                    # h5py `.value` attribute).
                    moving_sft.data_per_streamline[dps_key] \
                        = in_group[dps_key][()]

                new_sft = transform_warp_streamlines(
                    moving_sft, transfo, target_img,
                    inverse=args.inverse,
                    deformation_data=deformation_data,
                    remove_invalid=not args.cut_invalid,
                    cut_invalid=args.cut_invalid)
                new_sft.to_vox()
                new_sft.to_corner()

                # The output file carries the target image's attributes.
                affine, dimensions, voxel_sizes, voxel_order = \
                    get_reference_info(target_img)
                out_hdf5_file.attrs['affine'] = affine
                out_hdf5_file.attrs['dimensions'] = dimensions
                out_hdf5_file.attrs['voxel_sizes'] = voxel_sizes
                out_hdf5_file.attrs['voxel_order'] = voxel_order

                group = out_hdf5_file[key]
                group.create_dataset(
                    'data',
                    data=new_sft.streamlines._data.astype(np.float32))
                group.create_dataset('offsets',
                                     data=new_sft.streamlines._offsets)
                group.create_dataset('lengths',
                                     data=new_sft.streamlines._lengths)

                for dps_key in extra_keys:
                    if dps_key in dps_keys:
                        group.create_dataset(
                            dps_key,
                            data=new_sft.data_per_streamline[dps_key])
                    else:
                        group.create_dataset(
                            dps_key,
                            data=in_group[dps_key][()])