def main():
    """Command-line entry point: convert FSL gradient files with ``fsl2mrtrix``.

    Reads the FSL bval/bvec paths and the output MRtrix encoding path from
    the command line, validates them with the shared input/output helpers,
    then delegates the conversion to ``fsl2mrtrix``.
    """
    arg_parser = _build_args_parser()
    options = arg_parser.parse_args()

    in_bval = options.fsl_bval
    in_bvec = options.fsl_bvec
    out_encoding = options.mrtrix_enc

    # Both helpers receive the parser, presumably so they can report
    # problems through it (TODO confirm against their definitions).
    assert_inputs_exist(arg_parser, [in_bval, in_bvec])
    assert_outputs_exist(arg_parser, options, [out_encoding])

    fsl2mrtrix(in_bval, in_bvec, out_encoding)
def main():
    """Command-line entry point: convert FSL bval/bvec files to MRtrix format.

    Optionally enables INFO-level logging (``--verbose``), checks that the
    gradient filenames are valid for their respective formats ('fsl' for
    the inputs, 'mrtrix' for the output), checks input/output existence,
    and finally calls ``fsl2mrtrix``.
    """
    arg_parser = _build_arg_parser()
    options = arg_parser.parse_args()

    if options.verbose:
        logging.basicConfig(level=logging.INFO)

    fsl_pair = (options.fsl_bval, options.fsl_bvec)
    mrtrix_file = options.mrtrix_enc

    # Filename/format checks first, then existence checks.
    assert_gradients_filenames_valid(arg_parser, list(fsl_pair), 'fsl')
    assert_gradients_filenames_valid(arg_parser, mrtrix_file, 'mrtrix')
    assert_inputs_exist(arg_parser, list(fsl_pair))
    assert_outputs_exist(arg_parser, options, mrtrix_file)

    fsl2mrtrix(fsl_pair[0], fsl_pair[1], mrtrix_file)
def test_bvec_bval_tools():
    """Check round-trip consistency between the gradient-table converters.

    Two properties are checked:

    1. MRtrix -> FSL -> MRtrix (``mrtrix2fsl`` then ``fsl2mrtrix``)
       reproduces the original MRtrix encoding exactly.
    2. Converting the dmri ``b.txt``/``grad.txt`` files directly to MRtrix
       (``dmri2mrtrix``) gives the same encoding as going through FSL first
       (``dmri2fsl`` then ``fsl2mrtrix``).

    Reference files are read from ``data/`` relative to the working
    directory. Every generated file is written to a temporary directory, so
    the test neither litters ``data/`` nor depends on leftovers from an
    earlier run.
    """
    import os
    import tempfile

    data_dir = "data"
    f_original_mrtrix_encoding = os.path.join(data_dir, "encoding.b")
    f_original_dmri_bval = os.path.join(data_dir, "b.txt")
    f_original_dmri_bvec = os.path.join(data_dir, "grad.txt")

    original_mrtrix_encoding = np.loadtxt(f_original_mrtrix_encoding)

    with tempfile.TemporaryDirectory() as out_dir:
        # 1. MRtrix -> FSL -> MRtrix round trip.
        f_roundtrip_fsl_bval = os.path.join(out_dir, "gen-bval")
        f_roundtrip_fsl_bvec = os.path.join(out_dir, "gen-bvec")
        f_roundtrip_mrtrix_encoding = os.path.join(out_dir, "gen-encoding.b")

        mrtrix2fsl(f_original_mrtrix_encoding,
                   f_roundtrip_fsl_bval, f_roundtrip_fsl_bvec)
        fsl2mrtrix(f_roundtrip_fsl_bval, f_roundtrip_fsl_bvec,
                   f_roundtrip_mrtrix_encoding)
        generated_mrtrix_encoding = np.loadtxt(f_roundtrip_mrtrix_encoding)
        assert_array_equal(original_mrtrix_encoding, generated_mrtrix_encoding)

        # 2. dmri -> MRtrix directly must agree with dmri -> FSL -> MRtrix.
        f_dmri_fsl_bval = os.path.join(out_dir, "dmri-bval")
        f_dmri_fsl_bvec = os.path.join(out_dir, "dmri-bvec")
        f_dmri_mrtrix_encoding = os.path.join(out_dir, "dmri-encoding.b")
        f_dmri_fsl_mrtrix_encoding = os.path.join(out_dir,
                                                  "dmri-fsl-encoding.b")

        dmri2fsl(f_original_dmri_bval, f_original_dmri_bvec,
                 f_dmri_fsl_bval, f_dmri_fsl_bvec)
        dmri2mrtrix(f_original_dmri_bval, f_original_dmri_bvec,
                    f_dmri_mrtrix_encoding)
        fsl2mrtrix(f_dmri_fsl_bval, f_dmri_fsl_bvec,
                   f_dmri_fsl_mrtrix_encoding)

        dmri_mrtrix = np.loadtxt(f_dmri_mrtrix_encoding)
        dmri_fsl_mrtrix = np.loadtxt(f_dmri_fsl_mrtrix_encoding)
        assert_array_equal(dmri_fsl_mrtrix, dmri_mrtrix)
def main():
    """Command-line entry point: convert FSL bval/bvec files with ``fsl2mrtrix``.

    Both input files must exist. If the output encoding file already exists,
    it is only overwritten when ``isForce`` is set (a notice is printed);
    otherwise the parser reports an error.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    # Inputs are checked in order: bval first, then bvec.
    for required in (args.fsl_bval, args.fsl_bvec):
        if not os.path.exists(required):
            parser.error('"{0}"'.format(required) +
                         " doesn't exist. Please enter an existing file.")

    out_path = args.mrtrix_enc
    output_exists = os.path.exists(out_path)
    if output_exists and not args.isForce:
        parser.error('"{0}" already exist! Use -f to overwrite it.'.format(
            out_path))
    elif output_exists:
        print('Overwriting "{0}".'.format(out_path))

    fsl2mrtrix(args.fsl_bval, args.fsl_bvec, out_path)
def main():
    """Fit the NODDI model with AMICO.

    The b-values are replaced by the centroid of their shell
    (``identify_shells`` with ``args.b_thr``) and written, together with the
    input bvecs, to a temporary scheme file via ``fsl2mrtrix``. AMICO's
    ``NODDI`` model is then fitted and results are written under
    ``<out_dir>/<model id>``.

    Kernels are saved to ``--save_kernels`` (regenerated), loaded from
    ``--load_kernels`` (not regenerated), or otherwise generated in a
    temporary directory. The temporary directory is removed on every exit
    path.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.in_mask)
    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir, 'NODDI'),
                                       optional=args.save_kernels)

    # Validate the kernel options up front (same checks as the COMMIT
    # script). Without them, --load_kernels is silently ignored when
    # --save_kernels is also given, and a missing kernel directory would
    # only surface later, from inside AMICO.
    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    # AMICO (like COMMIT) has some C-level stdout and non-logging prints
    # that cannot easily be stopped, so all printed output is redirected
    # manually unless --verbose is given.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    try:
        # Generate a scheme file from the bvals and bvecs files, with each
        # b-value replaced by the centroid of its shell.
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                           roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)

        logging.debug('Compute NODDI with AMICO on {} shells at found '
                      'at {}.'.format(len(shells_centroids), shells_centroids))

        with redirected_stdout:
            # Load the data
            amico.core.setup()
            ae = amico.Evaluation('.', '.')
            ae.load_data(args.in_dwi,
                         tmp_scheme_filename, mask_filename=args.in_mask)
            # Compute the response functions
            ae.set_model("NODDI")

            intra_vol_frac = np.linspace(0.1, 0.99, 12)
            intra_orient_distr = np.hstack((np.array([0.03, 0.06]),
                                            np.linspace(0.09, 0.99, 10)))

            ae.model.set(args.para_diff, args.iso_diff,
                         intra_vol_frac, intra_orient_distr,
                         False)
            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # By default the kernels would go to the current directory;
            # choose their location explicitly.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True
            ae.set_config('ATOMS_path', kernels_dir)
            out_model_dir = os.path.join(args.out_dir, ae.model.id)
            ae.set_config('OUTPUT_path', out_model_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            # Model fit
            ae.fit()
            # Save the results
            ae.save_results()
    finally:
        tmp_dir.cleanup()
def main():
    """Fit the FreeWater model with AMICO.

    The b-values are replaced by the centroid of their shell
    (``identify_shells`` with ``args.b_thr``) and written, together with the
    input bvecs, to a temporary scheme file via ``fsl2mrtrix``. AMICO's
    ``FreeWater`` model (``'Human'``, or ``'Mouse'`` with ``--mouse``) is
    then fitted and results are written to ``out_dir``.

    With ``--compute_only`` (which requires ``--save_kernels``), only the
    kernels are generated and the script returns before fitting. The
    temporary directory is removed on every exit path, including that
    early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    # Validate the kernel options up front (same checks as the COMMIT
    # script). Without them, --load_kernels is silently ignored when
    # --save_kernels is also given, and a missing kernel directory would
    # only surface later, from inside AMICO.
    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    # AMICO (like COMMIT) has some C-level stdout and non-logging prints
    # that cannot easily be stopped, so all printed output is redirected
    # manually unless --verbose is given.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    try:
        # Generate a scheme file from the bvals and bvecs files, with each
        # b-value replaced by the centroid of its shell.
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                           roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)

        logging.debug(
            'Compute FreeWater with AMICO on {} shells at found at {}.'.format(
                len(shells_centroids), shells_centroids))

        with redirected_stdout:
            amico.core.setup()
            ae = amico.Evaluation('.', '.')
            # Load the data
            ae.load_data(args.in_dwi,
                         scheme_filename=tmp_scheme_filename,
                         mask_filename=args.mask)

            # Compute the response functions
            ae.set_model("FreeWater")
            model_type = 'Mouse' if args.mouse else 'Human'
            ae.model.set(args.para_diff,
                         np.linspace(args.perp_diff_min, args.perp_diff_max,
                                     10),
                         [args.iso_diff], model_type)
            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # By default the kernels would go to the current directory;
            # choose their location explicitly.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True
            ae.set_config('ATOMS_path', kernels_dir)
            ae.set_config('OUTPUT_path', args.out_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            if args.compute_only:
                # The ``finally`` clause below still removes tmp_dir.
                return
            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            ae.set_config('doNormalizeSignal', True)
            ae.set_config('doKeepb0Intact', False)
            ae.set_config('doComputeNRMSE', True)
            ae.set_config('doSaveCorrectedDWI', True)

            # Model fit
            ae.fit()
            # Save the results
            ae.save_results()
    finally:
        tmp_dir.cleanup()
def main():
    """Run COMMIT (and optionally COMMIT2) on a tractogram.

    Inputs are a tractogram (``.trk`` or a connectivity ``.h5``), the DWI
    with its bval/bvec and, depending on the model, peaks and a tracking
    mask. ``.h5`` input is reconstructed into a temporary ``.trk`` for
    COMMIT. The b-values are rounded to their shell centroids before the
    scheme file is written.

    The model is Stick-Zeppelin-Ball, or Ball & Stick with
    ``--ball_stick``. ``--commit2`` forces Ball & Stick and requires
    ``.h5`` input. After the first fit, results are saved through
    ``_save_results_wrapper`` into ``commit_1/``. With ``--commit2``, a
    second fit with group-sparsity regularisation over the ``.h5`` bundles
    is saved into ``commit_2/``.

    The HDF5 input is closed and the temporary directory is removed on
    every exit path, including the ``--compute_only`` early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')
    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')
    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            bundle_groups_len = []
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            # offsets_list[i]:offsets_list[i + 1] are the streamline indices
            # of the i-th connection in the concatenated tractogram.
            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                           roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 *
                         10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # By default the kernels would go to the current directory;
            # choose their location explicitly.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The ``finally`` clause below still cleans up.
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name,
                                use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)
            mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                  offsets_list, 'commit_1/', False)

            if args.commit2:
                # --commit2 requires .h5 input (checked above), so
                # bundle_groups_len and offsets_list are defined here.
                nbr_groups = len(bundle_groups_len)

                # 1-D object array whose k-th entry holds the streamline
                # indices of bundle k. Passing the ragged list straight to
                # np.array fails on NumPy >= 1.24 when bundle sizes differ,
                # and silently builds a 2-D int array when they are equal.
                group_idx = np.empty(nbr_groups, dtype=object)
                for k in range(nbr_groups):
                    group_idx[k] = np.arange(offsets_list[k],
                                             offsets_list[k + 1])

                # Group weights: sqrt(bundle size) divided by the norm of
                # the bundle's COMMIT-1 coefficients (epsilon avoids a
                # division by zero).
                group_w = np.empty(nbr_groups, dtype=np.float64)
                for k, group_len in enumerate(bundle_groups_len):
                    group_w[k] = np.sqrt(group_len) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)

                prior_on_bundles = commit.solvers.init_regularisation(
                    mit, structureIC=group_idx, weightsIC=group_w,
                    regnorms=[commit.solvers.group_sparsity,
                              commit.solvers.non_negative,
                              commit.solvers.non_negative],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles, verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
def main():
    """Run COMMIT on a tractogram and export per-streamline weights.

    Inputs are a tractogram (``.trk`` or a connectivity ``.h5``), the DWI
    with its bval/bvec and, depending on the model, peaks and a tracking
    mask. ``.h5`` input is reconstructed into a temporary ``.trk`` for
    COMMIT. The model is Stick-Zeppelin-Ball, or Ball & Stick with
    ``--ball_stick``.

    Outputs moved to ``out_dir``:

    - COMMIT's result files plus ``commit_weights.txt``.
    - For ``.h5`` input, ``decompose_commit.h5``: one group per connection
      with its streamlines, data_per_streamline and ``commit_weights``.
      When a threshold is given, only streamlines whose weight is above it
      are kept.
    - With a threshold, ``essential_tractogram.trk`` and
      ``nonessential_tractogram.trk``.
    - With ``--keep_whole_tractogram``, ``tractogram.trk`` carrying the
      weights as data_per_streamline.

    The HDF5 input is closed and the temporary directory is removed on
    every exit path, including the ``--compute_only`` early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')
    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')
    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')
    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')
    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')
    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # The threshold arrives as a string; 'None'/'none' disables it.
    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            offsets_list = [0]
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                offsets_list.append(len(tmp_streamlines))
                streamlines.extend(tmp_streamlines)
            # offsets_list[i]:offsets_list[i + 1] are the streamline indices
            # of the i-th connection in the concatenated tractogram.
            offsets_list = np.cumsum(offsets_list)

            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                           roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               gen_trk=False,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 *
                         10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # By default the kernels would go to the current directory;
            # choose their location explicitly.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            # Use the same number of directions as setup() and
            # trk2dictionary above (was hard-coded to 500).
            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The ``finally`` clause below still cleans up.
                return
            mit.load_kernels()
            mit.load_dictionary(tmp_dir.name,
                                use_mask=args.in_tracking_mask is not None)
            mit.set_threads(args.nbr_processes)
            mit.build_operator(build_dir=tmp_dir.name)
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
            mit.save_results()

        # Simplifying output for streamlines and cleaning output directory
        commit_results_dir = os.path.join(tmp_dir.name,
                                          'Results_StickZeppelinBall')
        # NOTE(review): pickle.load can execute arbitrary code; this file was
        # just written by COMMIT into our own temporary directory.
        with open(os.path.join(commit_results_dir, 'results.pickle'),
                  'rb') as pk_file:
            commit_output_dict = pickle.load(pk_file)

        nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
        commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
        np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
                   commit_weights)

        if ext == '.h5':
            new_filename = os.path.join(commit_results_dir,
                                        'decompose_commit.h5')
            streamline_keys = ('data', 'offsets', 'lengths')
            with h5py.File(new_filename, 'w') as new_hdf5_file:
                new_hdf5_file.attrs['affine'] = sft.affine
                new_hdf5_file.attrs['dimensions'] = sft.dimensions
                new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
                new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
                # Assign the weights into the hdf5, while respecting the
                # ordering of connections/streamlines
                logging.debug('Adding commit weights to {}.'.format(
                    new_filename))
                for i, key in enumerate(hdf5_keys):
                    new_group = new_hdf5_file.create_group(key)
                    old_group = hdf5_file[key]
                    tmp_commit_weights = \
                        commit_weights[offsets_list[i]:offsets_list[i + 1]]

                    if args.threshold_weights is not None:
                        # Keep only streamlines above the threshold.
                        kept = np.where(
                            tmp_commit_weights > args.threshold_weights)[0]
                        tmp_streamlines = reconstruct_streamlines(
                            old_group['data'], old_group['offsets'],
                            old_group['lengths'], indices=kept)
                        new_group.create_dataset(
                            'data', data=tmp_streamlines.get_data(),
                            dtype=np.float32)
                        new_group.create_dataset(
                            'offsets', data=tmp_streamlines._offsets,
                            dtype=np.int64)
                        new_group.create_dataset(
                            'lengths', data=tmp_streamlines._lengths,
                            dtype=np.int32)
                    else:
                        # No threshold: copy the streamlines unchanged
                        # (previously the group was left without them).
                        kept = slice(None)
                        for name in streamline_keys:
                            new_group.create_dataset(name,
                                                     data=old_group[name])

                    # Copy data_per_streamline under its own name (was
                    # wrongly created under the group key). Entries are
                    # assumed to be indexed by streamline along axis 0
                    # (TODO confirm), so they are filtered like the
                    # streamlines. Any previous 'commit_weights' is replaced.
                    for dps_key in old_group.keys():
                        if dps_key in streamline_keys or \
                                dps_key == 'commit_weights':
                            continue
                        new_group.create_dataset(
                            dps_key,
                            data=np.asarray(old_group[dps_key])[kept])
                    new_group.create_dataset('commit_weights',
                                             data=tmp_commit_weights[kept])

        for filename in os.listdir(commit_results_dir):
            shutil.move(os.path.join(commit_results_dir, filename),
                        args.out_dir)

        # Save split tractogram (essential/nonessential) and/or saving the
        # tractogram with data_per_streamline updated
        if args.keep_whole_tractogram or args.threshold_weights is not None:
            # Reload is needed because of COMMIT handling its file by itself
            tractogram_file = nib.streamlines.load(args.in_tractogram)
            tractogram = tractogram_file.tractogram
            tractogram.data_per_streamline['commit_weights'] = commit_weights

            if args.threshold_weights is not None:
                essential_ind = np.where(
                    commit_weights > args.threshold_weights)[0]
                nonessential_ind = np.where(
                    commit_weights <= args.threshold_weights)[0]
                logging.debug('{} essential streamlines were kept at '
                              'threshold {}'.format(len(essential_ind),
                                                    args.threshold_weights))
                logging.debug('{} nonessential streamlines were kept at '
                              'threshold {}'.format(len(nonessential_ind),
                                                    args.threshold_weights))

                # TODO PR when Dipy 1.2 is out with sft slicing
                essential_streamlines = tractogram.streamlines[essential_ind]
                essential_dps = tractogram.data_per_streamline[essential_ind]
                essential_dpp = tractogram.data_per_point[essential_ind]
                essential_tractogram = Tractogram(
                    essential_streamlines,
                    data_per_point=essential_dpp,
                    data_per_streamline=essential_dps,
                    affine_to_rasmm=np.eye(4))

                nonessential_streamlines = \
                    tractogram.streamlines[nonessential_ind]
                nonessential_dps = \
                    tractogram.data_per_streamline[nonessential_ind]
                nonessential_dpp = tractogram.data_per_point[nonessential_ind]
                nonessential_tractogram = Tractogram(
                    nonessential_streamlines,
                    data_per_point=nonessential_dpp,
                    data_per_streamline=nonessential_dps,
                    affine_to_rasmm=np.eye(4))

                nib.streamlines.save(essential_tractogram,
                                     os.path.join(args.out_dir,
                                                  'essential_tractogram.trk'),
                                     header=tractogram_file.header)
                nib.streamlines.save(nonessential_tractogram,
                                     os.path.join(
                                         args.out_dir,
                                         'nonessential_tractogram.trk'),
                                     header=tractogram_file.header,)
            if args.keep_whole_tractogram:
                output_filename = os.path.join(args.out_dir, 'tractogram.trk')
                logging.debug('Saving tractogram with weights as {}'.format(
                    output_filename))
                nib.streamlines.save(tractogram_file, output_filename)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()