def main():
    """Run COMMIT, and optionally COMMIT2, on a tractogram and a DWI volume.

    Command-line entry point. In order, it:

    1. Parses the arguments and rejects inconsistent combinations through
       ``parser.error``.
    2. For a ``.h5`` input, rebuilds a single ``.trk`` tractogram inside a
       temporary directory and records how many streamlines each HDF5 key
       contributed.
    3. Writes a scheme file from ``shells_centroids[indices_shells]`` as
       returned by ``identify_shells``.
    4. Builds the COMMIT dictionary, zeroes very small DWI values,
       configures the model, generates the kernels, fits, and saves through
       ``_save_results_wrapper`` under ``'commit_1/'``.
    5. With ``--commit2``, fits again with a group-sparsity prior (one group
       per HDF5 key) and saves under ``'commit_2/'``.

    The temporary directory is removed and the HDF5 handle is closed on
    every exit path, including the early return taken with
    ``--compute_only`` and the ``SystemExit`` raised by ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser,
        [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        # COMMIT2 always runs with the Ball & Stick model.
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    # Only filled for .h5 inputs; COMMIT2 (which needs it) requires .h5.
    bundle_groups_len = []
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            # offsets_list[i]:offsets_list[i + 1] spans the streamlines of
            # the i-th HDF5 key in the concatenated tractogram.
            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        # Two centroids are treated as single-shell data here (presumably
        # b0 plus one shell -- confirm against identify_shells).
        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 * 10**np.floor(
                np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The ``finally`` clause below still cleans up.
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name,
                                use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
            # The first fit uses a looser tolerance when COMMIT2 follows.
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_1/', False)

            if args.commit2:
                # One index array per HDF5 key. Groups generally differ in
                # size, so they are stored in a 1-D object array: building
                # the ragged structure with a plain np.array([...]) raises
                # on recent NumPy versions.
                bounds = np.insert(np.cumsum(bundle_groups_len), 0, 0)
                nb_groups = len(bundle_groups_len)
                group_idx = np.empty(nb_groups, dtype=object)
                for k in range(nb_groups):
                    group_idx[k] = np.arange(bounds[k], bounds[k + 1])

                # Group weight: sqrt(group size) over the norm of the first
                # fit's mit.x restricted to the group (1e-12 avoids a
                # division by zero).
                group_w = np.empty(nb_groups, dtype=np.float64)
                for k, group_size in enumerate(bundle_groups_len):
                    group_w[k] = np.sqrt(group_size) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)

                prior_on_bundles = commit.solvers.init_regularisation(
                    mit, structureIC=group_idx, weightsIC=group_w,
                    regnorms=[commit.solvers.group_sparsity,
                              commit.solvers.non_negative,
                              commit.solvers.non_negative],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles, verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
def main():
    """Fit a COMMIT model ('StickZeppelinBall_Model' or 'LiFE_Model').

    Command-line entry point. It checks that the DWI image can be loaded,
    builds the COMMIT dictionary from the tractogram, configures the model
    selected by ``args.model_type``, fits it and saves the results with
    their coefficients.

    Any other ``args.model_type`` is rejected through ``parser.error``
    before the dictionary is built, rather than failing later on an
    unbound ``mit``.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Per-model configuration:
    # (doDemean, perpendicular diffusivities, isotropic diffusivities).
    # Both entries use the 'StickZeppelinBall' COMMIT model; the LiFE
    # variant passes empty lists for the last two model.set arguments.
    model_settings = {
        'StickZeppelinBall_Model': (False, [0.7], [1.7E-3, 3.0E-3]),
        'LiFE_Model': (True, [], []),
    }
    if args.model_type not in model_settings:
        parser.error('Unknown model type: {}'.format(args.model_type))
    do_demean, perp_diff, iso_diff = model_settings[args.model_type]

    # check for errors in input
    try:
        nib.load(args.dwi)
    except Exception:
        # Narrower than a bare ``except`` so that KeyboardInterrupt and
        # SystemExit still propagate.
        parser.error("Expecting DWI as first image")
    # and so on

    # Essential part of COMMIT filtering
    # Create a dictionary for a tractogram
    from commit import trk2dictionary
    trk2dictionary.run(
        filename_trk=args.tracks,  # 'LausanneTwoShell/fibers.trk',
        path_out=args.output,  # 'LausanneTwoShell/CommitOutput',
        filename_peaks=args.peaks,  # 'LausanneTwoShell/peaks.nii.gz',
        filename_mask=args.wm,  # 'LausanneTwoShell/WM.nii.gz',
        fiber_shift=0.5,
        peaks_use_affine=True)

    # Precompute the rotation matrices used internally by COMMIT to create
    # the lookup-tables for the response functions
    import commit
    commit.core.setup()

    # Load the data and configure the selected model
    mit = commit.Evaluation('.', args.output)
    mit.CONFIG['doNormalizeSignal'] = False
    mit.CONFIG['doDemean'] = do_demean
    mit.load_data(args.dwi, args.dwi_scheme)
    mit.set_model('StickZeppelinBall')
    mit.model.set(1.7E-3, perp_diff, iso_diff)
    mit.generate_kernels(regenerate=True)
    mit.load_kernels()

    # Load in memory the sparse data-structure previously created with
    # trk2dictionary.run():
    mit.load_dictionary('.')

    # Now it's time to build the linear operator A to compute the
    # matrix-vector multiplications for solving the linear system. This
    # operator uses information from the segments loaded in the previous
    # step and the lookup-tables for the response functions; it also needs
    # to know the workload to be assigned to each thread during the
    # multiplications.
    mit.set_threads(8)
    mit.build_operator()

    # Fit the selected model to the data
    mit.fit(tol_fun=1e-3, max_iter=500)

    # Save results
    # NOTE(review): the original carried a reminder to make sure the results
    # are saved in the proper location -- confirm the output path.
    mit.save_results(save_coeff=True)
def main():
    """Run COMMIT on a tractogram and export weights and filtered outputs.

    Command-line entry point. In order, it:

    1. Parses and validates the arguments (``parser.error`` on conflicts)
       and converts ``--threshold_weights`` to ``None`` or a float.
    2. For a ``.h5`` input, rebuilds a single ``.trk`` tractogram inside a
       temporary directory and records, per HDF5 key, the offsets of its
       streamlines in that tractogram.
    3. Writes the scheme file, builds the COMMIT dictionary, configures the
       model, generates the kernels, fits and saves COMMIT's results.
    4. Writes ``commit_weights.txt`` and, for ``.h5`` inputs,
       ``decompose_commit.h5`` carrying the weights per connection; then
       moves everything from COMMIT's result directory to ``args.out_dir``.
    5. Optionally saves essential / nonessential tractograms (split at the
       weight threshold) and the whole tractogram with its weights.

    The temporary directory is removed and the input HDF5 handle is closed
    on every exit path, including the ``--compute_only`` early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # The threshold arrives as a string: 'None'/'none' disables it.
    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            offsets_list = [0]
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                offsets_list.append(len(tmp_streamlines))
                streamlines.extend(tmp_streamlines)

            # offsets_list[i]:offsets_list[i + 1] spans the streamlines of
            # the i-th HDF5 key in the concatenated tractogram.
            offsets_list = np.cumsum(offsets_list)

            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        # Two centroids are treated as single-shell data here (presumably
        # b0 plus one shell -- confirm against identify_shells).
        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               gen_trk=False,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001*10**np.floor(
                np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            # Use the same number of directions as commit.core.setup() and
            # trk2dictionary.run() above instead of a hard-coded 500.
            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The ``finally`` clause below still cleans up.
                return
            mit.load_kernels()
            mit.load_dictionary(tmp_dir.name,
                                use_mask=args.in_tracking_mask is not None)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=tmp_dir.name)
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
            mit.save_results()

        # Simplifying output for streamlines and cleaning output directory
        commit_results_dir = os.path.join(tmp_dir.name,
                                          'Results_StickZeppelinBall')
        # This pickle was just written by COMMIT itself (not external data).
        with open(os.path.join(commit_results_dir, 'results.pickle'),
                  'rb') as pk_file:
            commit_output_dict = pickle.load(pk_file)
        nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
        # Only the first nbr_streamlines entries of element 2 are kept as
        # streamline weights.
        commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
        np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
                   commit_weights)

        if ext == '.h5':
            new_filename = os.path.join(commit_results_dir,
                                        'decompose_commit.h5')
            with h5py.File(new_filename, 'w') as new_hdf5_file:
                new_hdf5_file.attrs['affine'] = sft.affine
                new_hdf5_file.attrs['dimensions'] = sft.dimensions
                new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
                new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
                # Assign the weights into the hdf5, while respecting the
                # ordering of connections/streamlines
                logging.debug('Adding commit weights to {}.'.format(
                    new_filename))
                for i, key in enumerate(hdf5_keys):
                    new_group = new_hdf5_file.create_group(key)
                    old_group = hdf5_file[key]
                    tmp_commit_weights = \
                        commit_weights[offsets_list[i]:offsets_list[i+1]]
                    if args.threshold_weights is not None:
                        essential_ind = np.where(
                            tmp_commit_weights > args.threshold_weights)[0]
                        tmp_streamlines = reconstruct_streamlines(
                            old_group['data'],
                            old_group['offsets'],
                            old_group['lengths'],
                            indices=essential_ind)

                        # Replacing the data with the one above the threshold
                        # Safe since this hdf5 was a copy in the first place
                        new_group.create_dataset(
                            'data', data=tmp_streamlines.get_data(),
                            dtype=np.float32)
                        new_group.create_dataset(
                            'offsets', data=tmp_streamlines._offsets,
                            dtype=np.int64)
                        new_group.create_dataset(
                            'lengths', data=tmp_streamlines._lengths,
                            dtype=np.int32)

                    # Copy every extra dataset under its own name; using the
                    # connection key here would mislabel the dataset.
                    for dps_key in old_group.keys():
                        if dps_key not in ['data', 'offsets', 'lengths']:
                            new_group.create_dataset(
                                dps_key, data=old_group[dps_key])
                    new_group.create_dataset('commit_weights',
                                             data=tmp_commit_weights)

        for result_name in os.listdir(commit_results_dir):
            shutil.move(os.path.join(commit_results_dir, result_name),
                        args.out_dir)

        # Save split tractogram (essential/nonessential) and/or saving the
        # tractogram with data_per_streamline updated
        if args.keep_whole_tractogram or args.threshold_weights is not None:
            # Reload is needed because of COMMIT handling its file by itself
            tractogram_file = nib.streamlines.load(args.in_tractogram)
            tractogram = tractogram_file.tractogram
            tractogram.data_per_streamline['commit_weights'] = commit_weights

            if args.threshold_weights is not None:
                essential_ind = np.where(
                    commit_weights > args.threshold_weights)[0]
                nonessential_ind = np.where(
                    commit_weights <= args.threshold_weights)[0]
                logging.debug('{} essential streamlines were kept at '
                              'threshold {}'.format(len(essential_ind),
                                                    args.threshold_weights))
                logging.debug('{} nonessential streamlines were kept at '
                              'threshold {}'.format(len(nonessential_ind),
                                                    args.threshold_weights))

                # TODO PR when Dipy 1.2 is out with sft slicing
                subsets = ((essential_ind, 'essential_tractogram.trk'),
                           (nonessential_ind, 'nonessential_tractogram.trk'))
                for indices, basename in subsets:
                    subset = Tractogram(
                        tractogram.streamlines[indices],
                        data_per_point=tractogram.data_per_point[indices],
                        data_per_streamline=tractogram.data_per_streamline[
                            indices],
                        affine_to_rasmm=np.eye(4))
                    nib.streamlines.save(
                        subset, os.path.join(args.out_dir, basename),
                        header=tractogram_file.header)

            if args.keep_whole_tractogram:
                output_filename = os.path.join(args.out_dir, 'tractogram.trk')
                logging.debug('Saving tractogram with weights as {}'.format(
                    output_filename))
                nib.streamlines.save(tractogram_file, output_filename)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()