def main():
    """Snap the b-values of a b-value file onto user-requested shell values.

    The b-values read from ``args.in_bval`` are grouped into shells by
    ``identify_shells`` (called with ``sort=True`` and ``args.tolerance``).
    The requested values (``args.bvals_to_extract``) are sorted and paired
    positionally with the shell centroids: every volume of shell ``i`` takes
    requested value ``i`` when the centroid lies within the tolerance of it.
    The result is written to ``args.out_bval`` as one row of integers.

    Exits through ``parser.error`` when more values are requested than
    shells were found, or when a centroid is further than the tolerance from
    its requested value.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_bval)
    assert_outputs_exist(parser, args, args.out_bval)

    bvals, _ = read_bvals_bvecs(args.in_bval, None)

    # Find the volume indices that correspond to the shells to extract.
    tol = args.tolerance
    sorted_centroids, sorted_indices = identify_shells(bvals, tol, sort=True)

    bvals_to_extract = np.sort(args.bvals_to_extract)
    n_shells = np.shape(bvals_to_extract)[0]

    logging.info("number of shells: {}".format(n_shells))
    logging.info("bvals to extract: {}".format(bvals_to_extract))
    logging.info("estimated centroids: {}".format(sorted_centroids))
    logging.info("original bvals: {}".format(bvals))
    logging.info("selected indices: {}".format(sorted_indices))

    # Requested value i is compared with centroid i, so asking for more
    # values than there are shells would otherwise fail with an IndexError.
    if n_shells > len(sorted_centroids):
        parser.error("{} bvals were requested but only {} shells were found "
                     "with a tolerance of {}.".format(
                         n_shells, len(sorted_centroids), tol))

    # Work on a copy so the array returned by the reader is left untouched.
    new_bvals = np.copy(bvals)
    for i, (centroid, target) in enumerate(zip(sorted_centroids,
                                               bvals_to_extract)):
        if np.abs(centroid - target) <= tol:
            new_bvals[np.where(sorted_indices == i)] = target
        else:
            parser.error("No bvals to resample: tolerance is too low.")

    logging.info("new bvals: {}".format(new_bvals))
    # savetxt writes one line per row: reshape to a single row.
    new_bvals.shape = (1, len(new_bvals))
    np.savetxt(args.out_bval, new_bvals, '%d')
def dwi_protocol(bvals, tol=20):
    """
    Return dwi protocol for each subject

    Parameters
    ----------
    bvals : List
        List of paths to the b-value files (each one is read with
        ``np.loadtxt``).
    tol: int
        Tolerance forwarded to ``identify_shells`` as its ``threshold``.

    Returns
    -------
    stats_per_subjects : dict
        Maps each b-value file to a one-row DataFrame holding the sorted
        centroids, the number of entries of the shell-index array, and the
        number of directions found for every centroid.
    stats : pd.DataFrame
        One row per b-value file with the columns "Nbr shells"
        (number of centroids minus one) and "Nbr directions".
    stats_across_subjects : pd.DataFrame
        Mean, std, min and max of ``stats``.
    shells : dict
        ``{shell: {filename: number of directions}}``, where ``shell`` is
        the integer value returned by ``get_nearest_bval`` for a centroid.
    """
    stats_per_subjects = {}
    values_stats = []
    column_names = ["Nbr shells", "Nbr directions"]
    shells = {}
    for filename in bvals:
        bval = np.loadtxt(filename)

        centroids, shells_indices = identify_shells(bval, threshold=tol)
        s_centroids = sorted(centroids)
        values = [', '.join(str(x) for x in s_centroids),
                  len(shells_indices)]
        columns = ["bvals", "Nbr directions"]
        for centroid in s_centroids:
            nearest_centroid = get_nearest_bval(list(shells.keys()), centroid)
            # ``np.int`` was removed from NumPy; the builtin is equivalent.
            shell_key = int(nearest_centroid)
            if shell_key not in shells:
                shells[shell_key] = {}

            # Count the volumes assigned to this centroid's shell index.
            nb_directions = len(shells_indices[shells_indices == np.where(
                centroids == centroid)[0]])
            print(centroid, nb_directions)
            if filename not in shells[shell_key]:
                shells[shell_key][filename] = 0
            shells[shell_key][filename] += nb_directions
            values.append(nb_directions)
            columns.append("Nbr bval {}".format(centroid))

        values_stats.append([len(centroids) - 1, len(shells_indices)])

        stats_per_subjects[filename] = pd.DataFrame([values],
                                                    index=[filename],
                                                    columns=columns)

    stats = pd.DataFrame(values_stats, index=[bvals],
                         columns=column_names)

    stats_across_subjects = pd.DataFrame(
        [stats.mean(), stats.std(), stats.min(), stats.max()],
        index=['mean', 'std', 'min', 'max'],
        columns=column_names)

    return stats_per_subjects, stats, stats_across_subjects, shells
def main():
    """Run the AMICO NODDI fit described by the command-line arguments.

    Steps, in order: validate inputs and the output directory, set up the
    stdout redirection, write a temporary scheme file built from the
    b-values rounded to their shell centroids, configure the AMICO "NODDI"
    model and its kernels, fit, and save the results under
    ``<out_dir>/<model id>``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    required_inputs = [args.in_dwi, args.in_bval, args.in_bvec]
    assert_inputs_exist(parser, required_inputs, args.in_mask)
    noddi_dir = os.path.join(args.out_dir, 'NODDI')
    assert_output_dirs_exist_and_empty(parser, args, noddi_dir,
                                       optional=args.save_kernels)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        sink = io.StringIO()
        redirected_stdout = redirect_stdout(sink)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    scheme_path = os.path.join(tmp_dir.name, 'gradients.scheme')
    bval_path = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    centroids, shell_of_volume = identify_shells(bvals, args.b_thr,
                                                 roundCentroids=True)
    # Every volume gets the centroid of the shell it belongs to.
    np.savetxt(bval_path, centroids[shell_of_volume],
               newline=' ', fmt='%i')
    fsl2mrtrix(bval_path, args.in_bvec, scheme_path)
    logging.debug(
        'Compute NODDI with AMICO on {} shells at found at {}.'.format(
            len(centroids), centroids))

    with redirected_stdout:
        # Load the data
        amico.core.setup()
        ae = amico.Evaluation('.', '.')
        ae.load_data(args.in_dwi, scheme_path, mask_filename=args.in_mask)

        # Compute the response functions
        ae.set_model("NODDI")
        volume_fractions = np.linspace(0.1, 0.99, 12)
        orientation_distr = np.concatenate(
            (np.array([0.03, 0.06]), np.linspace(0.09, 0.99, 10)))
        ae.model.set(args.para_diff, args.iso_diff,
                     volume_fractions, orientation_distr, False)
        ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir, regenerate_kernels = args.save_kernels, True
        elif args.load_kernels:
            kernels_dir, regenerate_kernels = args.load_kernels, False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
            regenerate_kernels = True

        ae.set_config('ATOMS_path', kernels_dir)
        ae.set_config('OUTPUT_path', os.path.join(args.out_dir, ae.model.id))
        ae.generate_kernels(regenerate=regenerate_kernels)
        ae.load_kernels()

        # Set number of processes
        solver_params = ae.get_config('solver_params')
        solver_params['numThreads'] = args.nbr_processes
        ae.set_config('solver_params', solver_params)

        # Model fit
        ae.fit()
        # Save the results
        ae.save_results()

    tmp_dir.cleanup()
def main():
    """Compute and save the requested DKI / MSDKI metric maps.

    Unless ``--not_all`` is set, every output that was not named on the
    command line receives its default file name (``<metric>.nii.gz``). The
    DWI is optionally smoothed volume by volume with a Gaussian filter,
    fitted with dipy's ``DiffusionKurtosisModel`` (and
    ``MeanDiffusionKurtosisModel`` when MSK/MSD are requested), and each
    requested metric is saved as a float32 NIfTI image.

    Exits through ``parser.error`` when ``--not_all`` is used without any
    output, or when fewer than three shells are identified.
    """
    logger = logging.getLogger("Compute_DKI_Metrics")
    logger.setLevel(logging.INFO)

    parser = _build_args_parser()
    args = parser.parse_args()

    # Output attribute names; the default file name is '<name>.nii.gz'.
    metric_names = ('dki_fa', 'dki_md', 'dki_ad', 'dki_rd', 'mk', 'rk',
                    'ak', 'dki_residual', 'msk', 'msd')
    if not args.not_all:
        for metric in metric_names:
            setattr(args, metric,
                    getattr(args, metric) or '{}.nii.gz'.format(metric))

    outputs = [getattr(args, metric) for metric in metric_names]

    if args.not_all and not any(outputs):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one metric to output.')

    assert_inputs_exist(
        parser, [args.input, args.bvals, args.bvecs], args.mask)
    assert_outputs_exist(parser, args, outputs)

    img = nib.load(args.input)
    data = img.get_fdata()
    affine = img.affine
    if args.mask is None:
        mask = None
    else:
        # ``np.bool`` was removed from NumPy; the builtin is equivalent.
        mask = nib.load(args.mask).get_fdata().astype(bool)

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Find the volume indices that correspond to the shells to extract.
    tol = args.tolerance
    shells, _ = identify_shells(bvals, tol)
    if not len(shells) >= 3:
        parser.error('Data is not multi-shell. You need at least 2 non-zero' +
                     ' b-values')

    if (shells > 2500).any():
        logging.warning('You seem to be using b > 2500 s/mm2 DWI data. ' +
                        'In theory, this is beyond the optimal range for DKI')

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    fwhm = args.smooth
    if fwhm > 0:
        # converting fwhm to Gaussian std
        gauss_std = fwhm / np.sqrt(8 * np.log(2))
        data_smooth = np.zeros(data.shape)
        for v in range(data.shape[-1]):
            data_smooth[..., v] = gaussian_filter(data[..., v],
                                                  sigma=gauss_std)
        data = data_smooth

    # Compute DKI
    dkimodel = dki.DiffusionKurtosisModel(gtab)
    dkifit = dkimodel.fit(data, mask=mask)

    min_k = args.min_k
    max_k = args.max_k

    def save_metric(volume, filename):
        # Every metric is written as float32 with the DWI affine.
        nib.save(nib.Nifti1Image(volume.astype(np.float32), affine),
                 filename)

    if args.dki_fa:
        FA = dkifit.fa
        FA[np.isnan(FA)] = 0
        FA = np.clip(FA, 0, 1)
        save_metric(FA, args.dki_fa)

    if args.dki_md:
        save_metric(dkifit.md, args.dki_md)

    if args.dki_ad:
        save_metric(dkifit.ad, args.dki_ad)

    if args.dki_rd:
        save_metric(dkifit.rd, args.dki_rd)

    if args.mk:
        save_metric(dkifit.mk(min_k, max_k), args.mk)

    if args.ak:
        save_metric(dkifit.ak(min_k, max_k), args.ak)

    if args.rk:
        save_metric(dkifit.rk(min_k, max_k), args.rk)

    if args.msk or args.msd:
        # Compute MSDKI
        msdki_model = msdki.MeanDiffusionKurtosisModel(gtab)
        msdki_fit = msdki_model.fit(data, mask=mask)

        if args.msk:
            MSK = msdki_fit.msk
            MSK[np.isnan(MSK)] = 0
            MSK = np.clip(MSK, min_k, max_k)
            save_metric(MSK, args.msk)

        if args.msd:
            save_metric(msdki_fit.msd, args.msd)

    if args.dki_residual:
        # Mean absolute difference between predicted and measured signal
        # over the non-b0 volumes, scaled by the norm of the whole map.
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
        data_p = dkifit.predict(gtab, S0)
        R = np.mean(np.abs(data_p[..., ~gtab.b0s_mask] -
                           data[..., ~gtab.b0s_mask]), axis=-1)

        norm = np.linalg.norm(R)
        if norm != 0:
            R = R / norm

        if args.mask is not None:
            R *= mask

        save_metric(R, args.dki_residual)
def compute_sh_coefficients(dwi, gradient_table, sh_order=4,
                            basis_type='descoteaux07', smooth=0.006,
                            use_attenuation=False, force_b0_threshold=False,
                            mask=None, sphere=None):
    """Fit a single-shell diffusion signal with spherical harmonics.

    Parameters
    ----------
    dwi : array-like
        Diffusion-weighted data. It is indexed as ``dwi[..., volumes]`` and
        the b0 volumes are averaged along axis 3, so a 4D array whose last
        axis is the volume axis is expected.
    gradient_table : GradientTable
        Object providing ``b0s_mask``, ``bvecs`` and ``bvals``.
    sh_order : int, optional
        SH order to fit, by default 4.
    basis_type : str, optional
        SH basis name forwarded to ``sf_to_sh``, by default 'descoteaux07'.
    smooth : float, optional
        Regularization coefficient forwarded to ``sf_to_sh``, by default
        0.006.
    use_attenuation : bool, optional
        If true, the signal is converted with ``compute_dwi_attenuation``
        using the mean b0 before the fit. [False]
    force_b0_threshold : bool, optional
        Forwarded to ``check_b0_threshold``. When set, the minimum b-value
        is also used as the b0 threshold of the single-shell check instead
        of ``DEFAULT_B0_THRESHOLD``.
    mask : array-like, optional
        Multiplied with the coefficients as ``mask[..., None]``.
    sphere : Sphere, optional
        Sampling sphere of the signal. If not provided, ``Sphere(xyz=bvecs)``
        is built from the non-b0 b-vectors.

    Returns
    -------
    sh_coeffs : np.ndarray
        Output of ``sf_to_sh`` (masked when ``mask`` is given).

    Raises
    ------
    ValueError
        If the b-values do not form exactly two shells, or the lowest shell
        is above the b0 threshold.
    """
    # Gradient information
    b0_mask = gradient_table.b0s_mask
    bvecs = gradient_table.bvecs
    bvals = gradient_table.bvals

    # Sanity checks on the gradients
    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    # The fit only makes sense on one b0 shell plus one weighted shell.
    shell_values, _ = identify_shells(bvals)
    shell_values = np.sort(shell_values)
    b0_threshold = bvals.min() if force_b0_threshold else DEFAULT_B0_THRESHOLD
    has_two_shells = shell_values.shape[0] == 2
    if not has_two_shells or shell_values[0] > b0_threshold:
        raise ValueError("Can only work on single shell signals.")

    # Drop the b0 volumes from the directions and from the signal.
    weighted = np.logical_not(b0_mask)
    bvecs = bvecs[weighted]
    weights = dwi[..., weighted]

    # Compute attenuation using the b0.
    if use_attenuation:
        mean_b0 = dwi[..., b0_mask].mean(axis=3)
        weights = compute_dwi_attenuation(weights, mean_b0)

    # Get cartesian coords from bvecs
    if sphere is None:
        sphere = Sphere(xyz=bvecs)

    # Fit SH
    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth)

    # Apply mask
    if mask is not None:
        sh *= mask[..., None]

    return sh
def main():
    """Run the AMICO FreeWater fit described by the command-line arguments.

    Steps, in order: validate the arguments, set up the stdout redirection,
    write a temporary scheme file built from the b-values rounded to their
    shell centroids, configure the AMICO "FreeWater" model and its kernels,
    then fit and save the results in ``args.out_dir``. With
    ``--compute_only`` the function stops right after the kernels are
    generated.

    The temporary directory is removed on every exit path (normal end,
    ``--compute_only`` early return, or error).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)

    assert_output_dirs_exist_and_empty(parser, args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    try:
        # Generate a scheme file from the bvals and bvecs files
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        # Every volume gets the centroid of the shell it belongs to.
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug(
            'Compute FreeWater with AMICO on {} shells at found at {}.'.format(
                len(shells_centroids), shells_centroids))

        with redirected_stdout:
            amico.core.setup()
            # Load the data
            ae = amico.Evaluation('.', '.')
            ae.load_data(args.in_dwi,
                         scheme_filename=tmp_scheme_filename,
                         mask_filename=args.mask)

            # Compute the response functions
            ae.set_model("FreeWater")
            model_type = 'Mouse' if args.mouse else 'Human'

            ae.model.set(args.para_diff,
                         np.linspace(args.perp_diff_min,
                                     args.perp_diff_max,
                                     10),
                         [args.iso_diff],
                         model_type)

            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True

            ae.set_config('ATOMS_path', kernels_dir)
            ae.set_config('OUTPUT_path', args.out_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            if args.compute_only:
                return
            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            ae.set_config('doNormalizeSignal', True)
            ae.set_config('doKeepb0Intact', False)
            ae.set_config('doComputeNRMSE', True)
            ae.set_config('doSaveCorrectedDWI', True)

            # Model fit
            ae.fit()

            # Save the results
            ae.save_results()
    finally:
        # Also reached on the --compute_only return and on errors.
        tmp_dir.cleanup()
def main():
    """Render the gradient sampling scheme given on the command line.

    Accepts either two files (FSL format, b-values and b-vectors) or one
    file (MRtrix format, rows of ``X Y Z b``). The b-values are grouped into
    shells with ``identify_shells``; shells whose centroid is below 40 are
    treated as b0 and discarded (their volumes get shell index -1 and the
    remaining shells are renumbered from 0). The shells are then drawn with
    ``plot_proj_shell`` and/or ``plot_each_shell`` depending on the options.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.gradient_sampling_file)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    n_files = len(args.gradient_sampling_file)
    if n_files == 2:
        assert_gradients_filenames_valid(parser, args.gradient_sampling_file,
                                         'fsl')
    elif n_files == 1:
        _, ext = os.path.splitext(args.gradient_sampling_file[0])
        if ext in ['.bvec', '.bvecs', '.bvals', '.bval']:
            parser.error('You should input two files for fsl format (.bvec '
                         'and .bval).')
        else:
            assert_gradients_filenames_valid(parser,
                                             args.gradient_sampling_file,
                                             'mrtrix')
    else:
        parser.error('Depending on the gradient format you should have '
                     'two files for FSL format and one file for MRtrix')

    out_basename = None

    proj = args.enable_proj
    each = args.plot_shells

    if not (proj or each):
        parser.error('Select at least one type of rendering (proj or each).')

    if n_files == 2:
        gradient_sampling_files = args.gradient_sampling_file
        gradient_sampling_files.sort()  # [bval, bvec]
        # bvecs/bvals (FSL) format, X Y Z AND b (or transpose)
        points = np.genfromtxt(gradient_sampling_files[1])
        if points.shape[0] == 3:
            points = points.T
        bvals = np.genfromtxt(gradient_sampling_files[0])
    else:
        # MRtrix format X, Y, Z, b
        gradient_sampling_file = args.gradient_sampling_file[0]
        tmp = np.genfromtxt(gradient_sampling_file, delimiter=' ')
        points = tmp[:, :3]
        bvals = tmp[:, 3]
    centroids, shell_idx = identify_shells(bvals)

    if args.out_basename:
        out_basename, ext = os.path.splitext(args.out_basename)
        possible_output_paths = [
            out_basename + '_shell_' + str(i) + '.png' for i in centroids
        ]
        possible_output_paths.append(out_basename + '.png')
        assert_outputs_exist(parser, args, possible_output_paths)

    # Discard the b0 shell(s). The remap table is indexed by the ORIGINAL
    # shell index, so the result is right wherever the b0 centroid sits in
    # ``centroids`` (it is not necessarily the first one).
    is_b0 = centroids < 40
    if is_b0.any():
        remap = np.full(centroids.shape[0], -1, dtype=int)
        remap[~is_b0] = np.arange(np.count_nonzero(~is_b0))
        shell_idx = remap[np.asarray(shell_idx, dtype=int)]
        centroids = centroids[~is_b0]

    sym = args.enable_sym
    sph = args.enable_sph
    same = args.same_color

    ms = build_ms_from_shell_idx(points, shell_idx)
    if proj:
        plot_proj_shell(ms, use_sym=sym, use_sphere=sph, same_color=same,
                        rad=0.025, opacity=args.opacity,
                        ofile=out_basename, ores=(args.res, args.res))
    if each:
        plot_each_shell(ms, centroids, plot_sym_vecs=sym, use_sphere=sph,
                        same_color=same, rad=0.025, opacity=args.opacity,
                        ofile=out_basename, ores=(args.res, args.res))
def compute_snr(dwi, bval, bvec, b0_thr, mask,
                noise_mask=None, noise_map=None,
                split_shells=False,
                basename=None, verbose=False):
    """
    Compute snr

    The noise standard deviation comes from, in order of precedence: the
    noise map (inside ``mask``), the given noise mask, or a noise mask
    generated here from ``median_otsu`` when neither is provided.

    Parameters
    ----------
    dwi: string
        Path to the dwi file
    bvec: string
        Path to the bvec file
    bval: string
        Path to the bval file
    b0_thr: int
        Threshold to define b0 minimum value
    mask: string
        Path to the mask
    noise_mask: string
        Path to the noise mask
    noise_map: string
        Path to the noise map. Used instead of ``noise_mask`` when both are
        given.
    split_shells: boolean
        If set, every b-value is replaced by the centroid of its shell
        (``identify_shells`` with a threshold of 40) before use.
    basename: string
        Basename used for naming all output files. It is needed when
        neither ``noise_mask`` nor ``noise_map`` is given, since the
        generated noise mask is saved to ``basename + '_noise_mask.nii.gz'``.
    verbose: boolean
        Set to use logging

    Returns
    -------
    val: dict
        One entry per volume index, each a dict with the keys 'bvec',
        'bval', 'mean', 'std' and 'snr'.

    Raises
    ------
    ValueError
        If no b-value is below or equal to ``b0_thr``, or if the noise mask
        yields a null standard deviation.
    """
    if verbose:
        logging.basicConfig(level=logging.INFO)

    img = nib.load(dwi)
    data = img.get_fdata(dtype=np.float32)
    affine = img.affine
    mask = get_data_as_mask(nib.load(mask), dtype=bool)

    bvals, bvecs = read_bvals_bvecs(bval, bvec)
    if split_shells:
        centroids, shell_indices = identify_shells(bvals, threshold=40.0,
                                                   roundCentroids=False,
                                                   sort=False)
        bvals = centroids[shell_indices]

    b0s_location = bvals <= b0_thr

    if not np.any(b0s_location):
        raise ValueError('You should ajust --b0_thr={} '
                         'since no b0s where find.'.format(b0_thr))

    data_noisemap = None
    if noise_mask is None and noise_map is None:
        _, noise_mask = median_otsu(data, vol_idx=b0s_location)

        # we inflate the mask, then invert it to recover only the noise
        noise_mask = binary_dilation(noise_mask, iterations=10).squeeze()

        # Add the upper half in order to delete the neck and shoulder
        # when inverting the mask
        noise_mask[..., :noise_mask.shape[-1] // 2] = 1

        # Reverse the mask to get only noise
        noise_mask = (~noise_mask).astype('float32')

        logging.info('Number of voxels found '
                     'in noise mask : {}'.format(np.count_nonzero(noise_mask)))
        logging.info('Total number of voxel '
                     'in volume : {}'.format(np.size(noise_mask)))

        nib.save(nib.Nifti1Image(noise_mask, affine),
                 basename + '_noise_mask.nii.gz')
    elif noise_map:
        # Checked before the noise mask: the loop below uses the noise map
        # whenever one is given, so it must be loaded even when a noise
        # mask was passed as well.
        img_noisemap = nib.load(noise_map)
        data_noisemap = img_noisemap.get_fdata(dtype=np.float32)
    elif noise_mask:
        noise_mask = get_data_as_mask(nib.load(noise_mask),
                                      dtype=bool).squeeze()

    # Val = np array (mean_signal, std_noise)
    val = {0: {'bvec': [0, 0, 0], 'bval': 0, 'mean': 0, 'std': 0}}
    for idx in range(data.shape[-1]):
        volume = data[..., idx:idx + 1]
        val[idx] = {}
        val[idx]['bvec'] = bvecs[idx]
        val[idx]['bval'] = bvals[idx]
        val[idx]['mean'] = np.mean(volume[mask > 0])
        if noise_map:
            val[idx]['std'] = np.std(data_noisemap[mask > 0])
        else:
            val[idx]['std'] = np.std(volume[noise_mask > 0])
            if val[idx]['std'] == 0:
                raise ValueError('Your noise mask does not capture any data'
                                 '(std=0). Please check your noise mask.')

        val[idx]['snr'] = val[idx]['mean'] / val[idx]['std']

    return val
def main():
    """Build the "DWI protocol" quality-assurance report.

    Checks the inputs and outputs, recreates the local ``data`` directory
    (and removes ``libs`` if present), gathers the protocol statistics with
    ``dwi_protocol``, builds the warnings, summary table and graphs, renders
    one gradient screenshot per subject into ``data``, and generates the
    report at ``args.output_report``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if len(args.bval) != len(args.bvec):
        parser.error("Not the same number of images in input.")

    assert_inputs_exist(parser, np.concatenate([args.bval, args.bvec]))
    assert_outputs_exist(parser, args, [args.output_report, "data", "libs"])

    # Start from an empty "data" folder and no leftover "libs" folder.
    if os.path.exists("data"):
        shutil.rmtree("data")
    os.makedirs("data")

    if os.path.exists("libs"):
        shutil.rmtree("libs")

    name = "DWI Protocol"
    metric_columns = ["Nbr shells", "Nbr directions"]

    summary, stats_for_graph, stats_all, shells = dwi_protocol(args.bval)

    # Warnings: count each flagged file once across all metrics.
    warnings_for_metric = analyse_qa(stats_for_graph, stats_all,
                                     ["Nbr shells", "Nbr directions"])
    warning_dict = {name: warnings_for_metric}
    flagged = np.concatenate(list(warnings_for_metric.values()))
    warnings_for_metric['nb_warnings'] = len(np.unique(flagged))

    summary_dict = {name: dataframe_to_html(stats_all)}

    graphs = [
        graph_directions_per_shells("Nbr directions per shell", shells),
        graph_subjects_per_shells("Nbr subjects per shell", shells),
    ]
    for column in metric_columns:
        graphs.append(graph_dwi_protocol(column, column, stats_for_graph))

    # One screenshot of the gradient scheme per subject.
    subjects_dict = {}
    for bval_path, bvec_path in zip(args.bval, args.bvec):
        screenshot_base = os.path.join("data",
                                       name + os.path.basename(bval_path))
        subjects_dict[bval_path] = {}

        points = np.genfromtxt(bvec_path)
        if points.shape[0] == 3:
            points = points.T
        centroids, shell_idx = identify_shells(np.genfromtxt(bval_path))
        ms = build_ms_from_shell_idx(points, shell_idx)
        plot_proj_shell(ms, centroids, use_sym=True, use_sphere=True,
                        same_color=False, rad=0.025, opacity=0.2,
                        ofile=screenshot_base, ores=(800, 800))
        subjects_dict[bval_path]['screenshot'] = screenshot_base + '.png'

    for subj in args.bval:
        subjects_dict[subj]['stats'] = dataframe_to_html(summary[subj])
    metrics_dict = {name: subjects_dict}

    report = Report(args.output_report)
    report.generate(title="Quality Assurance DWI protocol",
                    nb_subjects=len(args.bval),
                    metrics_dict=metrics_dict,
                    summary_dict=summary_dict,
                    graph_array=graphs,
                    warning_dict=warning_dict)
def main():
    """Run COMMIT (and optionally COMMIT2) from the command-line arguments.

    Steps, in order: validate the arguments, set up the stdout redirection,
    convert an ``.h5`` input into a temporary ``.trk`` while recording the
    number of streamlines per connection, write a temporary scheme file from
    the b-values rounded to their shell centroids, build the COMMIT
    dictionary, configure the model and its kernels, fit and save the
    results (``commit_1/``). With ``--commit2`` a second, group-sparsity
    regularised fit is run and saved (``commit_2/``). With
    ``--compute_only`` the function stops after generating the kernels.

    The temporary directory is removed and the HDF5 file is closed on every
    exit path (normal end, ``--compute_only`` early return, or error).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser,
        [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            bundle_groups_len = []
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        # Every volume gets the centroid of the shell it belongs to.
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 *
                         10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
            zero_fix_filename = os.path.join(tmp_dir.name,
                                             'dwi_zero_fix.nii.gz')
            nib.save(nib.Nifti1Image(data, img.affine), zero_fix_filename)

            mit.load_data(zero_fix_filename, tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            para_diff = args.para_diff or 1.7E-3
            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name,
                                use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_1/', False)

            if args.commit2:
                # One group of streamline indices per connection of the
                # input .h5 (COMMIT2 is only allowed with .h5 inputs, so
                # bundle_groups_len is defined here).
                tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
                # NOTE(review): groups usually differ in length, which makes
                # this a ragged array; recent NumPy versions reject that
                # unless dtype=object is requested - confirm what
                # commit.solvers.init_regularisation expects before changing.
                group_idx = np.array(
                    [np.arange(tmp[i], tmp[i + 1])
                     for i in range(len(tmp) - 1)])
                group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
                for k in range(len(bundle_groups_len)):
                    group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
                prior_on_bundles = commit.solvers.init_regularisation(
                    mit, structureIC=group_idx, weightsIC=group_w,
                    regnorms=[
                        commit.solvers.group_sparsity,
                        commit.solvers.non_negative,
                        commit.solvers.non_negative
                    ],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles, verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        # Also reached on the --compute_only return and on errors.
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
def main():
    """Run COMMIT on a tractogram and write the weights and filtered outputs.

    Steps performed, in order:

    1. Parse and validate the command-line arguments.
    2. For a ``.h5`` input, rebuild a single ``.trk`` in a temporary
       directory while remembering, per connection, which slice of
       streamlines belongs to it.
    3. Write a gradient scheme file from the b-values snapped to their shell
       centroids.
    4. Build the COMMIT dictionary, generate (or load) the kernels and fit
       the model.
    5. Read back the weights saved by COMMIT, write ``commit_weights.txt``
       (and ``decompose_commit.h5`` for a ``.h5`` input) and move every
       result file to ``args.out_dir``.
    6. Optionally save the essential / nonessential tractograms and the
       whole tractogram carrying ``commit_weights`` as data_per_streamline.

    With ``--compute_only`` the function returns right after the kernels are
    generated. The temporary directory is removed and the input hdf5 file is
    closed on every exit path.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # The threshold arrives either as the literal 'None'/'none' (disabled)
    # or as something convertible to float.
    if args.threshold_weights in ('None', 'none'):
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        redirected_stdout = redirect_stdout(io.StringIO())
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            same_affine = np.allclose(hdf5_file.attrs['affine'],
                                      dwi_img.affine, atol=1e-03)
            same_dimensions = np.array_equal(hdf5_file.attrs['dimensions'],
                                             dwi_img.shape[0:3])
            if not (same_affine and same_dimensions):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            offsets_list = [0]
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                offsets_list.append(len(tmp_streamlines))
                streamlines.extend(tmp_streamlines)

            # After the cumulative sum, connection i owns the weights in
            # [offsets_list[i], offsets_list[i + 1]).
            offsets_list = np.cumsum(offsets_list)

            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               gen_trk=False,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            magnitude = 10**np.floor(np.log10(np.mean(data[data > 0])))
            data[data < (0.001 * magnitude)] = 0
            fixed_dwi_filename = os.path.join(tmp_dir.name,
                                              'dwi_zero_fix.nii.gz')
            nib.save(nib.Nifti1Image(data, img.affine), fixed_dwi_filename)

            mit.load_data(fixed_dwi_filename, tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            para_diff = args.para_diff or 1.7E-3
            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            # Use the same number of directions as the dictionary and the
            # COMMIT setup above instead of a hard-coded value.
            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The finally clause below takes care of the cleanup.
                return
            mit.load_kernels()
            mit.load_dictionary(tmp_dir.name,
                                use_mask=args.in_tracking_mask is not None)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=tmp_dir.name)
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
            mit.save_results()

        # Simplifying output for streamlines and cleaning output directory
        commit_results_dir = os.path.join(tmp_dir.name,
                                          'Results_StickZeppelinBall')
        # This pickle is read from our own temporary directory, right after
        # mit.save_results(); it is not user-supplied input.
        with open(os.path.join(commit_results_dir, 'results.pickle'),
                  'rb') as pk_file:
            commit_output_dict = pickle.load(pk_file)
        nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
        commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
        np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
                   commit_weights)

        if ext == '.h5':
            new_filename = os.path.join(commit_results_dir,
                                        'decompose_commit.h5')
            with h5py.File(new_filename, 'w') as new_hdf5_file:
                new_hdf5_file.attrs['affine'] = sft.affine
                new_hdf5_file.attrs['dimensions'] = sft.dimensions
                new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
                new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
                # Assign the weights into the hdf5, while respecting the
                # ordering of connections/streamlines
                logging.debug(
                    'Adding commit weights to {}.'.format(new_filename))
                for i, key in enumerate(hdf5_keys):
                    new_group = new_hdf5_file.create_group(key)
                    old_group = hdf5_file[key]
                    tmp_commit_weights = \
                        commit_weights[offsets_list[i]:offsets_list[i + 1]]
                    # NOTE(review): the streamline datasets are only written
                    # to the new group when a threshold is set - confirm
                    # that this is intended.
                    if args.threshold_weights is not None:
                        essential_ind = np.where(
                            tmp_commit_weights > args.threshold_weights)[0]
                        tmp_streamlines = reconstruct_streamlines(
                            old_group['data'],
                            old_group['offsets'],
                            old_group['lengths'],
                            indices=essential_ind)

                        # Keep only the streamlines above the threshold
                        new_group.create_dataset(
                            'data', data=tmp_streamlines.get_data(),
                            dtype=np.float32)
                        new_group.create_dataset(
                            'offsets', data=tmp_streamlines._offsets,
                            dtype=np.int64)
                        new_group.create_dataset(
                            'lengths', data=tmp_streamlines._lengths,
                            dtype=np.int32)

                    # Copy every other dataset of the connection under its
                    # own name. 'commit_weights' is skipped because it is
                    # written from the new fit right below, and creating the
                    # same dataset twice would fail.
                    for dps_key in old_group.keys():
                        if dps_key not in ['data', 'offsets', 'lengths',
                                           'commit_weights']:
                            new_group.create_dataset(
                                dps_key, data=old_group[dps_key])
                    new_group.create_dataset('commit_weights',
                                             data=tmp_commit_weights)

        for result_name in os.listdir(commit_results_dir):
            shutil.move(os.path.join(commit_results_dir, result_name),
                        args.out_dir)

        # Save split tractogram (essential/nonessential) and/or saving the
        # tractogram with data_per_streamline updated
        if args.keep_whole_tractogram or args.threshold_weights is not None:
            # Reload is needed because of COMMIT handling its file by itself
            tractogram_file = nib.streamlines.load(args.in_tractogram)
            tractogram = tractogram_file.tractogram
            tractogram.data_per_streamline['commit_weights'] = commit_weights

            if args.threshold_weights is not None:
                essential_ind = np.where(
                    commit_weights > args.threshold_weights)[0]
                nonessential_ind = np.where(
                    commit_weights <= args.threshold_weights)[0]
                logging.debug('{} essential streamlines were kept at '
                              'threshold {}'.format(len(essential_ind),
                                                    args.threshold_weights))
                logging.debug('{} nonessential streamlines were kept at '
                              'threshold {}'.format(len(nonessential_ind),
                                                    args.threshold_weights))

                # TODO PR when Dipy 1.2 is out with sft slicing
                subsets = (('essential_tractogram.trk', essential_ind),
                           ('nonessential_tractogram.trk', nonessential_ind))
                for out_name, indices in subsets:
                    dps = tractogram.data_per_streamline[indices]
                    dpp = tractogram.data_per_point[indices]
                    subset_tractogram = Tractogram(
                        tractogram.streamlines[indices],
                        data_per_point=dpp,
                        data_per_streamline=dps,
                        affine_to_rasmm=np.eye(4))
                    nib.streamlines.save(
                        subset_tractogram,
                        os.path.join(args.out_dir, out_name),
                        header=tractogram_file.header)

            if args.keep_whole_tractogram:
                output_filename = os.path.join(args.out_dir, 'tractogram.trk')
                logging.debug('Saving tractogram with weights as {}'.format(
                    output_filename))
                nib.streamlines.save(tractogram_file, output_filename)
    finally:
        # Runs on normal completion, on the --compute_only early return and
        # on parser.error() exits alike.
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()