def _afd_rd_wrapper(args):
    """Compute the mean AFD of one hdf5 connection (multiprocessing worker).

    Parameters packed in ``args``:
        args[0]: path of the input hdf5 file.
        args[1]: connection key inside the hdf5 file.
        args[2]: fODF image handed to afd_map_along_streamlines.
        args[3]: spherical-harmonics basis name.
        args[4]: bool, length-weight the AFD computation.

    Returns
    -------
    tuple
        (key, mean_afd). mean_afd is 0 when the connection has no
        streamline or when the AFD map has no positive voxel.
    """
    in_hdf5_filename, key, fodf_img, sh_basis, length_weighting = args

    with h5py.File(in_hdf5_filename, 'r') as in_hdf5_file:
        affine = in_hdf5_file.attrs['affine']
        dimensions = in_hdf5_file.attrs['dimensions']
        voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
        streamlines = reconstruct_streamlines_from_hdf5(in_hdf5_file, key)
        if len(streamlines) == 0:
            return key, 0

    header = create_nifti_header(affine, dimensions, voxel_sizes)
    sft = StatefulTractogram(streamlines, header, Space.VOX,
                             origin=Origin.TRACKVIS)
    # rd map is computed alongside but only the AFD mean is reported here.
    afd_mean_map, _ = afd_map_along_streamlines(
        sft, fodf_img, sh_basis, length_weighting)

    # Guard against an all-zero AFD map: np.average of an empty selection
    # would return nan and emit a RuntimeWarning.
    positive_afd = afd_mean_map[afd_mean_map > 0]
    if positive_afd.size == 0:
        return key, 0

    return key, np.average(positive_afd)
def _average_wrapper(args):
    """Average one connection's density maps across subjects (pool worker).

    Parameters packed in ``args``:
        args[0]: list of hdf5 filenames, one per subject.
        args[1]: connection key to reconstruct in each file.
        args[2]: bool, accumulate binarized maps instead of normalized ones.
        args[3]: output directory for the resulting nifti file.

    Writes ``<out_dir>/<key>.nii.gz`` when at least one subject contributes
    a non-empty density map; otherwise writes nothing.
    """
    hdf5_filenames, key, binary, out_dir = args

    # Reference header from the first subject; every other file must match.
    # Context manager releases the handle (the original left it open).
    with h5py.File(hdf5_filenames[0], 'r') as hdf5_file_ref:
        affine = hdf5_file_ref.attrs['affine']
        dimensions = hdf5_file_ref.attrs['dimensions']

    density_data = np.zeros(dimensions, dtype=np.float32)
    for hdf5_filename in hdf5_filenames:
        # Context manager so the file is closed even when the header
        # check raises (the original leaked the handle on that path).
        with h5py.File(hdf5_filename, 'r') as hdf5_file:
            if not (np.allclose(hdf5_file.attrs['affine'], affine)
                    and np.allclose(hdf5_file.attrs['dimensions'],
                                    dimensions)):
                raise IOError('{} do not have a compatible header'.format(
                    hdf5_filename))
            # scil_decompose_connectivity.py saves the streamlines in
            # VOX/CORNER
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            density = compute_tract_counts_map(streamlines, dimensions)

        if binary:
            density_data[density > 0] += 1
        elif np.max(density) > 0:
            # Normalize each subject's contribution to [0, 1].
            density_data += density / np.max(density)

    if np.max(density_data) > 0:
        density_data /= len(hdf5_filenames)
        nib.save(nib.Nifti1Image(density_data, affine),
                 os.path.join(out_dir, '{}.nii.gz'.format(key)))
def main():
    """Export every connection of an hdf5 file as a .trk tractogram.

    Each hdf5 key becomes ``<out_dir>/<key>.trk``; with --include_dps the
    extra per-streamline datasets are copied into data_per_streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    # Context manager guarantees the file is closed even if saving raises
    # (the original relied on an explicit close() after the loop).
    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        # Header attributes are file-level invariants: read them and build
        # the nifti header once instead of once per key.
        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        for key in hdf5_file.keys():
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            sft = StatefulTractogram(streamlines, header, Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                # Copy every auxiliary dataset (anything that is not the
                # streamline geometry itself) as data_per_streamline.
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = \
                            hdf5_file[key][dps_key]
            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
def main():
    """Export selected hdf5 connections as .trk tractograms.

    Keys can be filtered with --edge_keys / --node_keys; --save_empty
    (requires --labels_list) also writes connections with no streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if args.save_empty and args.labels_list is None:
        parser.error("The option --save_empty requires --labels_list.")

    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        if args.save_empty:
            # Build every possible key (all unordered label pairs plus
            # self-connections) so empty connections get a file too.
            all_labels = np.loadtxt(args.labels_list, dtype='str')
            comb_list = list(itertools.combinations(all_labels, r=2))
            comb_list.extend(zip(all_labels, all_labels))
            keys = [i[0] + '_' + i[1] for i in comb_list]
        else:
            keys = hdf5_file.keys()

        # Narrow down the key list according to the user's selection.
        if args.edge_keys is not None:
            selected_keys = [key for key in keys if key in args.edge_keys]
        elif args.node_keys is not None:
            selected_keys = []
            for node in args.node_keys:
                # A node matches a key when it is either endpoint.
                selected_keys.extend([
                    key for key in keys
                    if key.startswith(node + '_')
                    or key.endswith('_' + node)
                ])
        else:
            selected_keys = keys

        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        for key in selected_keys:
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            if len(streamlines) == 0 and not args.save_empty:
                continue
            sft = StatefulTractogram(streamlines, header, Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                # Copy auxiliary per-streamline datasets (everything that
                # is not the raw geometry) into the tractogram.
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = hdf5_file[key][
                            dps_key]
            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
def main():
    """Apply a linear transform (and optional deformation) to an hdf5 of
    connections, writing a transformed copy of the file.

    The input hdf5 is first copied to the output path, then each group's
    streamline datasets are replaced in place with the warped versions.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_target_file,
                                 args.in_transfo], args.in_deformation)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    with h5py.File(args.in_hdf5, 'r') as in_hdf5_file:
        # Work on a copy so the original file stays untouched.
        shutil.copy(args.in_hdf5, args.out_hdf5)
        with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
            transfo = load_matrix_in_any_format(args.in_transfo)

            deformation_data = None
            if args.in_deformation is not None:
                deformation_data = np.squeeze(nib.load(
                    args.in_deformation).get_fdata(dtype=np.float32))
            target_img = nib.load(args.in_target_file)

            for key in in_hdf5_file.keys():
                affine = in_hdf5_file.attrs['affine']
                dimensions = in_hdf5_file.attrs['dimensions']
                voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
                streamlines = reconstruct_streamlines_from_hdf5(
                    in_hdf5_file, key)

                if len(streamlines) == 0:
                    # Empty groups keep their (copied) original content.
                    continue

                header = create_nifti_header(affine, dimensions, voxel_sizes)
                moving_sft = StatefulTractogram(streamlines, header,
                                                Space.VOX,
                                                origin=Origin.TRACKVIS)

                # cut_invalid and remove_invalid are mutually exclusive:
                # cutting keeps the valid portion, removing drops the line.
                new_sft = transform_warp_streamlines(
                    moving_sft, transfo, target_img,
                    inverse=args.inverse,
                    deformation_data=deformation_data,
                    remove_invalid=not args.cut_invalid,
                    cut_invalid=args.cut_invalid)
                new_sft.to_vox()
                new_sft.to_corner()

                # NOTE(review): the file-level attrs are rewritten on every
                # iteration with the same values — harmless, but could be
                # hoisted out of the loop.
                affine, dimensions, voxel_sizes, voxel_order = \
                    get_reference_info(target_img)
                out_hdf5_file.attrs['affine'] = affine
                out_hdf5_file.attrs['dimensions'] = dimensions
                out_hdf5_file.attrs['voxel_sizes'] = voxel_sizes
                out_hdf5_file.attrs['voxel_order'] = voxel_order

                # Replace geometry datasets in place (h5py datasets cannot
                # be resized here, so delete + recreate).
                # NOTE(review): _offsets/_lengths are private ArraySequence
                # fields — presumably stable in the pinned nibabel version;
                # confirm before upgrading nibabel.
                group = out_hdf5_file[key]
                del group['data']
                group.create_dataset('data',
                                     data=new_sft.streamlines.get_data())
                del group['offsets']
                group.create_dataset('offsets',
                                     data=new_sft.streamlines._offsets)
                del group['lengths']
                group.create_dataset('lengths',
                                     data=new_sft.streamlines._lengths)
def _processing_wrapper(args):
    """Compute the requested connectivity measures for one node pair
    (multiprocessing worker).

    Parameters packed in ``args``:
        args[0]: hdf5 filename holding the decomposed connections.
        args[1]: labels image (nibabel), reference space for all maps.
        args[2]: (in_label, out_label) tuple identifying the connection.
        args[3]: list of measures; strings ('volume', 'length', ...),
                 map directories, (metric_filename, img) tuples or
                 ((lesion_filename, labels), img) tuples.
        args[4]: optional (similarity_directory, ...) tuple.
        args[5]: bool, density-weight the metric averages.
        args[6]: falsy, or the output directory for data_per_streamline.
        args[7]: minimum lesion volume for lesion stats.

    Returns a {(in_label, out_label): {measure: value}} dict, or None when
    the connection is absent or empty.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    # Shallow copy: measures are remove()d as they are consumed below.
    measures_to_compute = copy.copy(args[3])
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    # NOTE(review): the handle is never closed; presumably acceptable for a
    # short-lived worker process — confirm.
    hdf5_file = h5py.File(hdf5_filename, 'r')
    key = '{}_{}'.format(in_label, out_label)
    if key not in hdf5_file:
        return
    streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
    if len(streamlines) == 0:
        return

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
            and np.array_equal(hdf5_file.attrs['dimensions'], dimensions)):
        raise ValueError('Provided hdf5 have incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        streamlines_copy = list(streamlines)
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(streamlines_copy))*voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute or
              'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute and
              'streamline_count' in measures_to_compute))):
        density = compute_tract_counts_map(streamlines, dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        # Missing node map in the similarity directory -> similarity of 0.
        density_sim = load_node_nifti(similarity_directory,
                                      in_label, out_label, labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    # What remains are per-measure inputs: map directories, metric images
    # and lesion specifications.
    for measure in measures_to_compute:
        # Maps
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname,
                                       in_label, out_label, labels_img)
            measures_to_return[map_dirname] = np.average(
                map_data[map_data > 0])
        elif isinstance(measure, tuple):
            # (filename, img) -> metric; ((filename, labels), img) -> lesion.
            if not isinstance(measure[0], tuple) \
                    and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    logging.error('{} do not have a compatible header'.format(
                        metric_filename))
                    raise IOError

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    avg_value = np.average(metric_data, weights=density)
                else:
                    avg_value = np.average(metric_data[density > 0])
                measures_to_return[metric_filename] = avg_value
            # lesion
            else:
                lesion_filename = measure[0][0]
                computed_lesion_labels = measure[0][1]
                lesion_img = measure[1]
                if not is_header_compatible(lesion_img, labels_img):
                    logging.error('{} do not have a compatible header'.format(
                        lesion_filename))
                    raise IOError

                voxel_sizes = lesion_img.header.get_zooms()[0:3]
                lesion_img.set_filename('tmp.nii.gz')
                lesion_atlas = get_data_as_label(lesion_img)
                tmp_dict = compute_lesion_stats(
                    density.astype(bool), lesion_atlas,
                    voxel_sizes=voxel_sizes, single_label=True,
                    min_lesion_vol=min_lesion_vol,
                    precomputed_lesion_labels=computed_lesion_labels)

                tmp_ind = _streamlines_in_mask(list(streamlines),
                                               lesion_atlas.astype(np.uint8),
                                               np.eye(3), [0, 0, 0])
                # [0, 1][True] evaluates to 1: count streamlines flagged
                # as passing through the lesion mask.
                streamlines_count = len(
                    np.where(tmp_ind == [0, 1][True])[0].tolist())

                if tmp_dict:
                    measures_to_return[lesion_filename+'vol'] = \
                        tmp_dict['lesion_total_volume']
                    measures_to_return[lesion_filename+'count'] = \
                        tmp_dict['lesion_count']
                    measures_to_return[lesion_filename+'sc'] = \
                        streamlines_count
                else:
                    measures_to_return[lesion_filename+'vol'] = 0
                    measures_to_return[lesion_filename+'count'] = 0
                    measures_to_return[lesion_filename+'sc'] = 0

    if include_dps:
        # COMMIT weights are summed (total contribution); every other
        # per-streamline dataset is averaged.
        for dps_key in hdf5_file[key].keys():
            if dps_key not in ['data', 'offsets', 'lengths']:
                out_file = os.path.join(include_dps, dps_key)
                if 'commit' in dps_key:
                    measures_to_return[out_file] = np.sum(
                        hdf5_file[key][dps_key])
                else:
                    measures_to_return[out_file] = np.average(
                        hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
def _processing_wrapper(args):
    """Compute the requested connectivity measures for one node pair
    (multiprocessing worker) — variant without lesion support.

    Parameters packed in ``args``:
        args[0]: hdf5 filename holding the decomposed connections.
        args[1]: labels image (nibabel), reference space for all maps.
        args[2]: (in_label, out_label) tuple identifying the connection.
        args[3]: list of measures; strings, map directories or
                 (metric_filename, img) tuples.
        args[4]: optional (similarity_directory, ...) tuple.
        args[5]: bool, density-weight the metric averages.
        args[6]: falsy, or the output directory for data_per_streamline.

    Returns a {(in_label, out_label): {measure: value}} dict, or None when
    the connection key is absent.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    # Shallow copy: measures are remove()d as they are consumed below.
    measures_to_compute = copy.copy(args[3])
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]

    # NOTE(review): the handle is never closed; presumably acceptable for a
    # short-lived worker process — confirm.
    hdf5_file = h5py.File(hdf5_filename, 'r')
    key = '{}_{}'.format(in_label, out_label)
    if key not in hdf5_file:
        return
    # NOTE(review): no guard for an empty streamlines list here — the
    # downstream averages could fail on empty connections; confirm that
    # upstream filtering guarantees non-empty groups.
    streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
            and np.array_equal(hdf5_file.attrs['dimensions'], dimensions)):
        raise ValueError('Provided hdf5 have incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        streamlines_copy = list(streamlines)
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(streamlines_copy)) * voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute or
              'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute and
              'streamline_count' in measures_to_compute))):
        density = compute_tract_counts_map(streamlines, dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        # Missing node map in the similarity directory -> similarity of 0.
        density_sim = load_node_nifti(similarity_directory,
                                      in_label, out_label, labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    # What remains are map directories and metric images.
    for measure in measures_to_compute:
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname,
                                       in_label, out_label, labels_img)
            measures_to_return[map_dirname] = np.average(
                map_data[map_data > 0])
        elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
            metric_filename = measure[0]
            metric_img = measure[1]
            if not is_header_compatible(metric_img, labels_img):
                logging.error('{} do not have a compatible header'.format(
                    metric_filename))
                raise IOError

            metric_data = metric_img.get_fdata(dtype=np.float64)
            if weighted:
                # NOTE(review): density is renormalized in place, so this
                # only stays correct because each metric re-divides an
                # already-normalized map only once per measure on the
                # first pass; confirm when adding measures.
                density = density / np.max(density)
                voxels_value = metric_data * density
                voxels_value = voxels_value[voxels_value > 0]
            else:
                voxels_value = metric_data[density > 0]

            measures_to_return[metric_filename] = np.average(voxels_value)

    if include_dps:
        # Average each auxiliary per-streamline dataset.
        for dps_key in hdf5_file[key].keys():
            if dps_key not in ['data', 'offsets', 'lengths']:
                out_file = os.path.join(include_dps, dps_key)
                measures_to_return[out_file] = np.average(
                    hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
def main():
    """Run COMMIT (and optionally COMMIT2) on a tractogram.

    Accepts a .trk tractogram or a .h5 decomposed connectome; the hdf5 is
    flattened into a temporary .trk for COMMIT while the per-connection
    grouping is remembered for the COMMIT2 group-sparsity prior.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.in_tractogram, args.in_dwi,
                         args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        # COMMIT2 is defined on the ball & stick forward model.
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')
    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')
    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))
        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        bundle_groups_len = []
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            streamlines.extend(tmp_streamlines)
            bundle_groups_len.append(len(tmp_streamlines))

        offsets_list = np.cumsum([0] + bundle_groups_len)
        sft = StatefulTractogram(streamlines, args.in_dwi, Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    # (sic: typos 'Lauching' / 'at found at' are in the original message)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    # Two centroids = b0 + one shell, i.e. single-shell data.
    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data <
             (0.001 * 10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=args.nbr_dir,
                             regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        use_mask = args.in_tracking_mask is not None
        mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
        # Looser tolerance for the first fit when a COMMIT2 refit follows.
        tol_fun = 1e-2 if args.commit2 else 1e-3
        mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
        mit.save_results()
        _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                              'commit_1/', False)

        if args.commit2:
            # Per-connection index groups and weights for the
            # group-sparsity prior (weight ~ sqrt(size) / ||x_group||).
            tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
            # NOTE(review): ragged list of aranges — newer NumPy requires
            # dtype=object for this construction; confirm pinned version.
            group_idx = np.array([np.arange(tmp[i], tmp[i + 1])
                                  for i in range(len(tmp) - 1)])
            group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
            for k in range(len(bundle_groups_len)):
                group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                    (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
            prior_on_bundles = commit.solvers.init_regularisation(
                mit, structureIC=group_idx, weightsIC=group_w,
                regnorms=[commit.solvers.group_sparsity,
                          commit.solvers.non_negative,
                          commit.solvers.non_negative],
                lambdas=[args.lambda_commit_2, 0.0, 0.0])
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter,
                    regularisation=prior_on_bundles, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                  offsets_list, 'commit_2/', True)

    tmp_dir.cleanup()
def main():
    """Run COMMIT on a tractogram and optionally threshold streamlines by
    their COMMIT weight.

    Accepts a .trk tractogram or a .h5 decomposed connectome; weights are
    written back (as data_per_streamline, thresholded hdf5 and/or
    essential/nonessential split tractograms).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')
    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')
    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # 'None'/'none' on the command line disables thresholding.
    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            # (sic: 'weigth' typo is in the original message)
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))
        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)

        # offsets_list[i]:offsets_list[i+1] slices connection i's weights.
        offsets_list = np.cumsum(offsets_list)
        sft = StatefulTractogram(streamlines, args.in_dwi, Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    # (sic: typos 'Lauching' / 'at found at' are in the original message)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids), shells_centroids))

    # Two centroids = b0 + one shell, i.e. single-shell data.
    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data <
             (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    pk_file = open(os.path.join(commit_results_dir, 'results.pickle'), 'rb')
    commit_output_dict = pickle.load(pk_file)
    # commit_output_dict[2] holds the per-streamline weights; truncate to
    # the actual streamline count of the input tractogram.
    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = \
                    commit_weights[offsets_list[i]:offsets_list[i+1]]
                if args.threshold_weights is not None:
                    # Keep only streamlines whose weight passes the
                    # threshold ("essential" streamlines).
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(
                        old_group['data'],
                        old_group['offsets'],
                        old_group['lengths'],
                        indices=essential_ind)

                    # Replacing the data with the one above the threshold
                    # Safe since this hdf5 was a copy in the first place
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        # NOTE(review): the dataset is created under `key`
                        # (the connection name), not `dps_key` — looks like
                        # a bug (multiple dps entries would collide);
                        # confirm intent before changing.
                        new_group.create_dataset(
                            key,
                            data=hdf5_file[key][dps_key])
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)

    files = os.listdir(commit_results_dir)
    for f in files:
        shutil.move(os.path.join(commit_results_dir, f), args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(
                essential_streamlines,
                data_per_point=essential_dpp,
                data_per_streamline=essential_dps,
                affine_to_rasmm=np.eye(4))

            nonessential_streamlines = \
                tractogram.streamlines[nonessential_ind]
            nonessential_dps = \
                tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(
                nonessential_streamlines,
                data_per_point=nonessential_dpp,
                data_per_streamline=nonessential_dps,
                affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header,)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()