def main():
    """Cluster a tractogram with ``qbx_and_merge`` and save the result.

    Each cluster is saved as ``cluster_<i>.trk`` in ``output_clusters_dir``
    (only when that directory is given) and the cluster centroids are saved
    to ``--output_centroids`` (only when given).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.output_centroids)
    if args.output_clusters_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.output_clusters_dir,
                                           create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    # Decreasing distance thresholds; the last one comes from --dist_thresh.
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines, thresholds,
                             nb_pts=args.nb_points, verbose=False)

    # Per-cluster files are only written when a directory was requested:
    # os.path.join(None, ...) would otherwise raise a TypeError.
    if args.output_clusters_dir:
        for i, cluster in enumerate(clusters):
            if len(cluster.indices) > 1:
                cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
            else:
                # itemgetter with a single index returns a lone item rather
                # than a sequence, so index the streamlines directly.
                cluster_streamlines = streamlines[cluster.indices]

            new_sft = StatefulTractogram(cluster_streamlines, sft, Space.RASMM)
            save_tractogram(new_sft,
                            os.path.join(args.output_clusters_dir,
                                         'cluster_{}.trk'.format(i)))

    if args.output_centroids:
        new_sft = StatefulTractogram(clusters.centroids, sft, Space.RASMM)
        save_tractogram(new_sft, args.output_centroids)
def main():
    """Run ``_average_wrapper`` once per unique connection key.

    Keys are collected from every input HDF5 file. They are processed
    sequentially, or with a multiprocessing pool when more than one process
    is requested.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    keys = []
    for filename in args.in_hdf5:
        with h5py.File(filename, 'r') as curr_file:
            keys.extend(curr_file.keys())
    # Deduplicate while keeping first-seen order. Files sharing keys would
    # otherwise have the same key processed several times, and with a pool
    # several workers could presumably write the same output concurrently.
    keys = list(dict.fromkeys(keys))

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)
    if nbr_cpu == 1:
        for key in keys:
            _average_wrapper([args.in_hdf5, key, args.binary, args.out_dir])
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        _ = pool.map(_average_wrapper,
                     zip(itertools.repeat(args.in_hdf5),
                         keys,
                         itertools.repeat(args.binary),
                         itertools.repeat(args.out_dir)))
        pool.close()
        pool.join()
def main():
    """Save every connection (group) of an HDF5 file as ``<key>.trk``.

    With ``--include_dps``, every dataset of a group other than
    data/offsets/lengths is attached as data_per_streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    # The context manager closes the file even if a save raises.
    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        # Spatial attributes are stored at file level: build the header once.
        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        for key in hdf5_file.keys():
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            sft = StatefulTractogram(streamlines, header, Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = \
                            hdf5_file[key][dps_key]
            save_tractogram(sft, '{}.trk'.format(
                os.path.join(args.out_dir, key)))
def main():
    """Save selected connections of an HDF5 file as ``<key>.trk`` files.

    Connections are selected by ``--edge_keys``, ``--node_keys`` or, by
    default, all keys. With ``--save_empty`` the candidate keys are every
    pair of labels from ``--labels_list`` (self-pairs included), and empty
    connections are written too.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    # Validate option combinations before creating any output directory.
    if args.save_empty and args.labels_list is None:
        parser.error("The option --save_empty requires --labels_list.")
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        if args.save_empty:
            all_labels = np.loadtxt(args.labels_list, dtype='str')
            comb_list = list(itertools.combinations(all_labels, r=2))
            comb_list.extend(zip(all_labels, all_labels))
            keys = [i[0] + '_' + i[1] for i in comb_list]
        else:
            keys = hdf5_file.keys()

        if args.edge_keys is not None:
            selected_keys = [key for key in keys if key in args.edge_keys]
        elif args.node_keys is not None:
            selected_keys = []
            for node in args.node_keys:
                selected_keys.extend([
                    key for key in keys
                    if key.startswith(node + '_') or key.endswith('_' + node)
                ])
            # A key such as "A_B" matches both node A and node B: keep it
            # once so the same file is not written twice.
            selected_keys = list(dict.fromkeys(selected_keys))
        else:
            selected_keys = keys

        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)
        for key in selected_keys:
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            if len(streamlines) == 0 and not args.save_empty:
                continue

            sft = StatefulTractogram(streamlines, header, Space.VOX,
                                     origin=Origin.TRACKVIS)
            # With --save_empty, keys are built from the labels list and may
            # be absent from the file: only read dps from existing groups.
            if args.include_dps and key in hdf5_file:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = \
                            hdf5_file[key][dps_key]

            save_tractogram(sft, '{}.trk'.format(
                os.path.join(args.out_dir, key)))
def main():
    """Reorder rows and columns of matrices using an ordering file.

    The first line of ``in_ordering`` gives the row order and the second
    line the column order. With ``--labels_list``, those values are label
    ids that are translated into positions in the labels list. Each result
    is saved as ``<out_dir>/<basename><out_suffix>.npy``.

    Raises:
        ValueError: if an index does not fit in a matrix.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_matrices,
                        [args.labels_list, args.in_ordering])
    assert_output_dirs_exist_and_empty(parser, args, [], args.out_dir)
    if args.out_dir is None:
        args.out_dir = './'

    if args.out_suffix is None:
        args.out_suffix = ""

    out_filenames = []
    for filename in args.in_matrices:
        basename, _ = os.path.splitext(filename)
        basename = os.path.basename(basename)
        out_filenames.append('{}/{}{}.npy'.format(args.out_dir, basename,
                                                 args.out_suffix))
    assert_outputs_exist(parser, args, out_filenames)

    with open(args.in_ordering, 'r') as my_file:
        lines = my_file.readlines()
    ordering = [[int(val) for val in lines[0].split()],
                [int(val) for val in lines[1].split()]]

    # The ordering does not depend on the matrix: translate it only once.
    if args.labels_list:
        labels_list = np.loadtxt(args.labels_list, dtype=np.int16).tolist()
        indices_1 = [labels_list.index(j) for j in ordering[0]]
        indices_2 = [labels_list.index(j) for j in ordering[1]]
    else:
        indices_1 = ordering[0]
        indices_2 = ordering[1]

    for filename, out_filename in zip(args.in_matrices, out_filenames):
        matrix = load_matrix_in_any_format(filename)

        # Valid indices are 0..size-1, so an index equal to the size is
        # already out of bounds (hence >=, not >).
        if (np.array(indices_1) >= matrix.shape[0]).any() \
                or (np.array(indices_2) >= matrix.shape[1]).any():
            raise ValueError('Indices from config higher than matrix size, '
                             'maybe you need a labels list?')

        tmp_matrix = matrix[tuple(indices_1), :]
        tmp_matrix = tmp_matrix[:, tuple(indices_2)]
        save_matrix_in_any_format(out_filename, tmp_matrix)
def main():
    """Map per-streamline metric values onto the streamline endpoints.

    For every metric map, the value returned by
    ``_compute_streamline_mean`` for each streamline is added to the voxels
    holding its first and last points. Voxels are then divided by how many
    endpoints they received. One ``<metric>_endpoints_metric<ext>`` file is
    written per metric in ``out_folder``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.in_metrics)
    assert_output_dirs_exist_and_empty(parser, args, args.out_folder,
                                       create_dir=True)

    assert_same_resolution(args.in_metrics)

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()

    if len(sft.streamlines) == 0:
        # The input option is args.in_bundle (args.bundle does not exist).
        logging.warning('Empty bundle file {}. Skipping'.format(
            args.in_bundle))
        return

    mins, maxs, indices = _process_streamlines(sft.streamlines)

    metrics = [nib.load(metric) for metric in args.in_metrics]
    for metric in metrics:
        data = metric.get_fdata(dtype=np.float32)
        endpoint_metric_map = np.zeros(metric.shape)
        count = np.zeros(metric.shape)
        for cur_min, cur_max, cur_ind, orig_s in zip(mins, maxs, indices,
                                                     sft.streamlines):
            streamline_mean = _compute_streamline_mean(cur_ind,
                                                       cur_min,
                                                       cur_max,
                                                       data)

            # After to_vox()/to_corner(), truncating a coordinate gives the
            # index of the voxel that contains it.
            for endpoint in (orig_s[0], orig_s[-1]):
                x, y, z = endpoint.astype(int)
                endpoint_metric_map[x, y, z] += streamline_mean
                count[x, y, z] += 1

        visited = count != 0
        endpoint_metric_map[visited] /= count[visited]

        metric_fname, ext = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        nib.save(nib.Nifti1Image(endpoint_metric_map, metric.affine,
                                 metric.header),
                 os.path.join(args.out_folder,
                              '{}_endpoints_metric{}'.format(metric_fname,
                                                             ext)))
def main():
    """Split a label volume into one file per entry of a JSON lookup table.

    The LUT is either a bundled scilpy LUT (``--scilpy_lut``) or a custom
    JSON file (``--custom_lut``). Its keys are the label ids and its values
    are the names used in output filenames; label 0 is skipped. Each output
    keeps the label value in matching voxels and 0 elsewhere.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_label)

    label_img = nib.load(args.in_label)
    label_img_data = get_data_as_label(label_img)

    if args.scilpy_lut:
        lut_filename = os.path.join(get_lut_dir(), args.scilpy_lut + '.json')
    else:
        lut_filename = args.custom_lut
    with open(lut_filename) as f:
        label_dict = json.load(f)
    (label_indices, label_names) = zip(*label_dict.items())

    # Pair each non-background label with its output file, so that writing
    # cannot drift out of sync with the filename list. JSON keys are
    # strings: convert them to int once, here.
    outputs = []
    for label, name in zip(label_indices, label_names):
        label = int(label)
        if label == 0:
            continue
        if args.out_prefix:
            basename = '{0}_{1}.nii.gz'.format(args.out_prefix, name)
        else:
            basename = '{0}.nii.gz'.format(name)
        outputs.append((label, os.path.join(args.out_dir, basename)))

    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.out_dir)
    assert_outputs_exist(parser, args,
                         [out_filename for _, out_filename in outputs])

    # Extract the voxels that match the label and save them to a file.
    for label, out_filename in outputs:
        split_label = np.zeros(label_img.shape, dtype=np.uint16)
        split_label[label_img_data == label] = label

        split_image = nib.Nifti1Image(split_label,
                                      label_img.affine,
                                      header=label_img.header)
        nib.save(split_image, out_filename)
def main():
    """Run ``VotingScheme`` bundle recognition on a tractogram.

    Logs are written to ``<output>/logfile.txt``. The transformation
    matrix is inverted first when ``--inverse`` is given. When no
    ``--seeds`` are provided, one random seed in [1, 1000] is drawn.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram,
                                 args.config_file,
                                 args.transformation])

    for directory in args.models_directories:
        if not os.path.isdir(directory):
            parser.error('Input folder {0} does not exist'.format(directory))

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    logging.basicConfig(filename=os.path.join(args.output, 'logfile.txt'),
                        filemode='w',
                        format='%(asctime)s, %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=args.log_level)

    coloredlogs.install(level=args.log_level)

    transfo = np.loadtxt(args.transformation)
    if args.inverse:
        # Invert the matrix already in memory instead of re-reading the file.
        transfo = np.linalg.inv(transfo)

    with open(args.config_file) as json_data:
        config = json.load(json_data)

    voting = VotingScheme(config, args.models_directories,
                          transfo, args.output,
                          tractogram_clustering_thr=args.tractogram_clustering_thr,
                          minimal_vote_ratio=args.minimal_vote_ratio,
                          multi_parameters=args.multi_parameters)

    if args.seeds is None:
        seeds = [random.randint(1, 1000)]
    else:
        seeds = args.seeds

    voting(args.in_tractogram, nbr_processes=args.processes, seeds=seeds)
def main():
    """Write one volume per label id found in a label image.

    Ids come from ``--range`` (the given groups, flattened) or, by default,
    from every distinct value in the volume; id 0 is ignored. Each output
    file keeps the id value in its matching voxels and 0 elsewhere.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_labels)

    label_img = nib.load(args.in_labels)
    label_img_data = get_data_as_label(label_img)

    if args.range:
        # NOTE(review): the --range groups are only flattened, so "-r 3 5"
        # yields ids 3 and 5 but not 4. If --range is meant as an inclusive
        # min/max interval, each group should expand with range(lo, hi + 1);
        # confirm against the argument definition.
        label_indices = []
        for group in args.range:
            label_indices.extend(group)
    else:
        label_indices = np.unique(label_img_data)

    kept_labels = []
    output_filenames = []
    for label in label_indices:
        if int(label) == 0:
            continue
        name = str(label)
        if args.out_prefix:
            out_basename = '{0}_{1}.nii.gz'.format(args.out_prefix, name)
        else:
            out_basename = '{0}.nii.gz'.format(name)
        kept_labels.append(label)
        output_filenames.append(os.path.join(args.out_dir, out_basename))

    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.out_dir)
    assert_outputs_exist(parser, args, output_filenames)

    # Extract the voxels that match the label and save them to a file.
    for label, out_filename in zip(kept_labels, output_filenames):
        split_label = np.zeros(label_img.shape, dtype=np.uint16)
        split_label[np.where(label_img_data == int(label))] = label

        split_image = nib.Nifti1Image(split_label,
                                      label_img.affine,
                                      header=label_img.header)
        nib.save(split_image, out_filename)
def main():
    """Split a tractogram into chunks saved as ``<out_prefix>_<i><ext>``.

    With ``--nb_chunk``, the tractogram is cut into that many chunks;
    otherwise into chunks of ``--chunk_size`` streamlines. All chunks have
    the same size except the last one, which holds the remainder.
    data_per_point and data_per_streamline are split along with the
    streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    _, out_extension = os.path.splitext(args.in_tractogram)

    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.out_dir)
    # Check only the first potential output filename
    assert_outputs_exist(parser, args, os.path.join(
        args.out_dir, '{}_0{}'.format(args.out_prefix, out_extension)))

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines_count = len(sft.streamlines)

    if args.nb_chunk:
        nb_chunk = args.nb_chunk
        chunk_size = streamlines_count // nb_chunk
    else:
        chunk_size = args.chunk_size
        # Ceiling division: no trailing empty chunk when the count is an
        # exact multiple of chunk_size. At least one chunk is always written.
        nb_chunk = max(1, -(-streamlines_count // chunk_size))

    # All chunks will be equal except the last one. int64 avoids the
    # overflow the former int16 array hit above 32767 streamlines/chunk.
    chunk_sizes = np.full((nb_chunk,), chunk_size, dtype=np.int64)
    chunk_sizes[-1] += streamlines_count - chunk_size * nb_chunk

    curr_count = 0
    for i, size in enumerate(chunk_sizes):
        end = curr_count + int(size)
        streamlines = sft.streamlines[curr_count:end]
        data_per_streamline = sft.data_per_streamline[curr_count:end]
        data_per_point = sft.data_per_point[curr_count:end]
        curr_count = end

        new_sft = StatefulTractogram.from_sft(
            streamlines, sft,
            data_per_point=data_per_point,
            data_per_streamline=data_per_streamline)

        out_name = os.path.join(args.out_dir, '{0}_{1}{2}'.format(
            args.out_prefix, i, out_extension))
        save_tractogram(new_sft, out_name)
def main():
    """Plot the mean/std profile of every metric of every bundle.

    The input JSON maps bundle -> metric -> point label -> {'mean', 'std'}.
    One ``<bundle>_<metric>.png`` figure is saved per bundle/metric pair.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_json)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if args.fill_color and len(args.fill_color) != 8:
        parser.error('Hexadecimal RGB color should be formatted as 0xRRGGBB')

    # Read-only access is enough ('r+' failed on read-only inputs), and the
    # file does not need to stay open while plotting.
    with open(args.in_json, 'r') as f:
        mean_std_per_point = json.load(f)

    # Same fill color for every plot: convert "0xRRGGBB" to "#RRGGBB" once.
    fill_color = (args.fill_color.replace("0x", "#")
                  if args.fill_color else None)

    for bundle_name, bundle_stats in mean_std_per_point.items():
        for metric, metric_stats in bundle_stats.items():
            nb_points = len(metric_stats)
            num_digits_labels = len(str(nb_points))
            means = []
            stds = []
            for label_int in range(1, nb_points + 1):
                label = str(label_int).zfill(num_digits_labels)
                mean = metric_stats.get(label, {'mean': np.nan})['mean']
                # NOTE(review): every falsy value (None, but also 0.0)
                # becomes NaN here; confirm that zeros should be hidden.
                mean = mean if mean else np.nan
                std = metric_stats.get(label, {'std': np.nan})['std']
                std = std if std else np.nan
                means += [mean]
                stds += [std]

            fig = plot_metrics_stats(
                np.array(means), np.array(stds),
                title=bundle_name,
                xlabel='Location along the streamline',
                ylabel=metric,
                fill_color=fill_color)
            fig.savefig(
                os.path.join(args.out_dir,
                             '{}_{}.png'.format(bundle_name, metric)),
                bbox_inches='tight')
def main():
    """Fit COMMIT to a tractogram and save the streamline weights.

    The input is either a tractogram file or an .h5 connectivity file. For
    .h5 input, all connections are concatenated into a temporary .trk for
    COMMIT, and ``decompose_commit.h5`` is written with the weights stored
    per connection (and the streamlines thresholded when
    ``--threshold_weights`` is set). All files of COMMIT's results folder
    are moved to ``out_dir``. Optionally, the whole tractogram with weights
    and/or its essential/nonessential split are also saved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)
        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi, Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(
            np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    # This pickle was just written by COMMIT in our own temporary directory.
    with open(os.path.join(commit_results_dir, 'results.pickle'),
              'rb') as pk_file:
        commit_output_dict = pickle.load(pk_file)

    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = \
                    commit_weights[offsets_list[i]:offsets_list[i+1]]
                thresholding = args.threshold_weights is not None
                if thresholding:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(
                        old_group['data'],
                        old_group['offsets'],
                        old_group['lengths'],
                        indices=essential_ind)

                    # Keep only the streamlines above the threshold
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)
                    tmp_commit_weights = tmp_commit_weights[essential_ind]
                else:
                    # No threshold: the new file is created from scratch, so
                    # the streamlines must be copied unchanged.
                    for name in ['data', 'offsets', 'lengths']:
                        new_group.create_dataset(name,
                                                 data=old_group[name][()])

                # Each per-streamline dataset keeps its own name (naming it
                # after the connection key collided on the second one) and
                # follows the streamlines that were kept.
                for dps_key in old_group.keys():
                    if dps_key in ['data', 'offsets', 'lengths']:
                        continue
                    dps = old_group[dps_key][()]
                    if thresholding:
                        dps = dps[essential_ind]
                    new_group.create_dataset(dps_key, data=dps)
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)
        hdf5_file.close()

    for result_name in os.listdir(commit_results_dir):
        shutil.move(os.path.join(commit_results_dir, result_name),
                    args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(essential_streamlines,
                                              data_per_point=essential_dpp,
                                              data_per_streamline=essential_dps,
                                              affine_to_rasmm=np.eye(4))

            nonessential_streamlines = tractogram.streamlines[nonessential_ind]
            nonessential_dps = tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(nonessential_streamlines,
                                                 data_per_point=nonessential_dpp,
                                                 data_per_streamline=nonessential_dps,
                                                 affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header,)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()
def main():
    """Render a mosaic of bundle views over a background volume.

    Column 0 holds the view names. Each following column shows one bundle
    (streamlines or a NIfTI ROI contour) in six views: axial
    superior/inferior, coronal posterior/anterior and sagittal left/right.
    A last row holds the bundle file name and its element count. Missing
    bundle files are reported in their column instead of being rendered.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_volume)
    assert_outputs_exist(parser, args, args.out_image)

    output_names = ['axial_superior', 'axial_inferior',
                    'coronal_posterior', 'coronal_anterior',
                    'sagittal_left', 'sagittal_right']

    for filename in args.in_bundles:
        # Missing bundles are reported in the mosaic further down; do not
        # let the header check crash on them first.
        if not os.path.isfile(filename):
            continue
        _, ext = os.path.splitext(filename)
        if ext == '.tck':
            tractogram = load_tractogram_with_reference(parser, args, filename)
        else:
            tractogram = filename
        if not is_header_compatible(args.in_volume, tractogram):
            parser.error('{} does not have a compatible header with {}'.format(
                filename, args.in_volume))
        # Delete temporary tractogram
        del tractogram

    output_dir = os.path.dirname(args.out_image)
    if output_dir:
        assert_output_dirs_exist_and_empty(parser, args, output_dir,
                                           create_dir=True)

    # ----------------------------------------------------------------------- #
    # Mosaic, column 0: orientation names and data description
    # ----------------------------------------------------------------------- #
    width = args.resolution_of_thumbnails
    height = args.resolution_of_thumbnails
    rows = 6
    cols = len(args.in_bundles)
    text_pos_x = 50
    text_pos_y = 50

    # Creates a new empty image, RGB mode
    mosaic = Image.new('RGB', ((cols + 1) * width, (rows + 1) * height))

    # Prepare draw and font objects to render text
    draw = ImageDraw.Draw(mosaic)
    font = get_font(args)

    # Data of the volume used as background
    ref_img = nib.load(args.in_volume)
    data = ref_img.get_fdata(dtype=np.float32)
    affine = ref_img.affine
    mean, std = data[data > 0].mean(), data[data > 0].std()
    value_range = (mean - 0.5 * std, mean + 1.5 * std)

    # First column with rows description
    draw_column_with_names(draw, output_names, text_pos_x,
                           text_pos_y, height, font)

    # ----------------------------------------------------------------------- #
    # Columns with bundles
    # ----------------------------------------------------------------------- #
    random.seed(args.random_coloring)
    for idx_bundle, bundle_file in enumerate(args.in_bundles):

        bundle_file_name = os.path.basename(bundle_file)
        _, bundle_ext = split_name_with_nii(bundle_file_name)

        i = (idx_bundle + 1) * width

        if not os.path.isfile(bundle_file):
            print('\nInput file {} doesn\'t exist.'.format(bundle_file))

            number_streamlines = 0

            view_number = 6
            j = height * view_number

            draw_bundle_information(draw, bundle_file_name, number_streamlines,
                                    i + text_pos_x, j + text_pos_y, font)

        else:
            if args.uniform_coloring:
                colors = args.uniform_coloring
            elif args.random_coloring is not None:
                colors = random_rgb()

            # Select the streamlines to plot
            if bundle_ext in ['.tck', '.trk']:
                if (args.random_coloring is None
                        and args.uniform_coloring is None):
                    colors = None
                bundle_tractogram_file = nib.streamlines.load(bundle_file)
                streamlines = bundle_tractogram_file.streamlines
                bundle_actor = actor.line(streamlines, colors)
                nbr_of_elem = len(streamlines)
            # Select the volume to plot
            elif bundle_ext in ['.nii.gz', '.nii']:
                # "is None" as in the streamline branch: a --random_coloring
                # seed of 0 must not fall back to white.
                if args.random_coloring is None and not args.uniform_coloring:
                    colors = [1.0, 1.0, 1.0]
                bundle_img_file = nib.load(bundle_file)
                roi = get_data_as_mask(bundle_img_file)
                bundle_actor = actor.contour_from_roi(roi,
                                                      bundle_img_file.affine,
                                                      colors)
                nbr_of_elem = np.count_nonzero(roi)

            # Render
            ren = window.Scene()
            zoom = args.zoom
            opacity = args.opacity_background

            # Structural data
            slice_actor = actor.slicer(data, affine, value_range)
            slice_actor.opacity(opacity)
            ren.add(slice_actor)

            # Streamlines
            ren.add(bundle_actor)
            ren.reset_camera()

            ren.zoom(zoom)
            view_number = 0
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            ren.pitch(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 1
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            ren.rm(slice_actor)
            slice_actor2 = slice_actor.copy()
            slice_actor2.display(None, slice_actor2.shape[1] // 2, None)
            slice_actor2.opacity(opacity)
            ren.add(slice_actor2)

            ren.pitch(90)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 2
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            ren.pitch(180)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 3
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            ren.rm(slice_actor2)
            slice_actor3 = slice_actor.copy()
            slice_actor3.display(slice_actor3.shape[0] // 2, None, None)
            slice_actor3.opacity(opacity)
            ren.add(slice_actor3)

            ren.yaw(90)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 4
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            ren.yaw(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 5
            set_img_in_cell(mosaic, ren, view_number,
                            width, height, i)

            view_number = 6
            j = height * view_number
            draw_bundle_information(draw, bundle_file_name, nbr_of_elem,
                                    i + text_pos_x, j + text_pos_y, font)

    # Save image to file
    mosaic.save(args.out_image)
def main():
    """Estimate, per label section, the diameter of each bundle.

    For every bundle and its label map, points and local directions are
    grouped by label. ``fit_circle_in_space`` is run per label, and the
    results are smoothed with a Gaussian filter. Diameters (2 x radius) and
    the fitting error are printed as JSON. Optionally, the fitted tubes and
    streamlines are shown or saved as snapshots.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The number of labels maps must be equal to the number of bundles
    tmp = args.in_bundles + args.in_labels
    args.in_labels = args.in_bundles[(len(tmp) // 2):] + args.in_labels
    args.in_bundles = args.in_bundles[0:len(tmp) // 2]
    assert_inputs_exist(parser, args.in_bundles + args.in_labels)
    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.save_rendering)

    stats = {}
    num_digits_labels = 3
    scene = window.Scene()
    scene.background(tuple(map(int, args.background)))
    for bundle_idx, filename in enumerate(args.in_bundles):
        sft = load_tractogram_with_reference(parser, args, filename)
        sft.to_vox()
        sft.to_corner()
        img_labels = nib.load(args.in_labels[bundle_idx])

        # same subject: same header or coregistered subjects: same header
        if not is_header_compatible(sft, args.in_bundles[0]) \
                or not is_header_compatible(img_labels, args.in_bundles[0]):
            parser.error('All headers must be identical.')

        data_labels = img_labels.get_fdata()
        bundle_name, _ = os.path.splitext(os.path.basename(filename))
        # Keep only positive labels (same criterion as the point loop below)
        # instead of assuming the first unique value is the background.
        unique_labels = np.unique(data_labels)
        unique_labels = unique_labels[unique_labels > 0].astype(int)

        # Empty bundle should at least return a json
        if not len(sft):
            tmp_dict = {}
            for label in unique_labels:
                tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                    = {'mean': 0.0, 'std': 0.0}
                stats[bundle_name] = {'diameter': tmp_dict}
            continue

        counter = 0
        labels_dict = {label: ([], []) for label in unique_labels}
        pts_labels = map_coordinates(data_labels,
                                     sft.streamlines._data.T - 0.5,
                                     order=0)
        # For each label, all positions and directions are needed to get
        # a tube estimation per label.
        for streamline in sft.streamlines:
            direction = np.gradient(streamline, axis=0).tolist()
            curr_labels = pts_labels[counter:counter +
                                     len(streamline)].tolist()

            for pt_idx, label in enumerate(curr_labels):
                if label > 0:
                    labels_dict[label][0].append(streamline[pt_idx])
                    labels_dict[label][1].append(direction[pt_idx])

            counter += len(streamline)

        # One row per label, in sorted label order: labels do not have to
        # be contiguous nor start at 1 (the former "label - 1" indexing did).
        centroid = np.zeros((len(unique_labels), 3))
        radius = np.zeros((len(unique_labels), 1))
        error = np.zeros((len(unique_labels), 1))
        for row, key in enumerate(unique_labels):
            key = int(key)
            c, d, e = fit_circle_in_space(labels_dict[key][0],
                                          labels_dict[key][1],
                                          args.fitting_func)
            centroid[row], radius[row], error[row] = c, d, e

        # Spatial smoothing to avoid degenerate estimation
        centroid_smooth = gaussian_filter(centroid, sigma=[1, 0],
                                          mode='nearest')
        # Keep the first and last centroids unsmoothed. Explicit indices
        # also work for a single label, where a slice step of 0 raised.
        if len(centroid):
            centroid_smooth[[0, -1]] = centroid[[0, -1]]
        radius = gaussian_filter(radius, sigma=1, mode='nearest')
        error = gaussian_filter(error, sigma=1, mode='nearest')

        tmp_dict = {}
        for row, label in enumerate(unique_labels):
            tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                = {'mean': float(radius[row, 0])*2,
                   'std': float(error[row, 0])}
        stats[bundle_name] = {'diameter': tmp_dict}

        if args.show_rendering or args.save_rendering:
            tube_actor = create_tube_with_radii(
                centroid_smooth, radius, error,
                wireframe=args.wireframe,
                error_coloring=args.error_coloring)
            scene.add(tube_actor)
            cmap = plt.get_cmap('jet')
            coloring = cmap(pts_labels / np.max(pts_labels))[:, 0:3]
            streamlines_actor = actor.streamtube(sft.streamlines,
                                                 linewidth=args.width,
                                                 opacity=args.opacity,
                                                 colors=coloring)
            scene.add(streamlines_actor)

            slice_actor = actor.slicer(data_labels, np.eye(4))
            slice_actor.opacity(0.0)
            scene.add(slice_actor)

    # If there's actually streamlines to display
    if args.show_rendering:
        showm = window.ShowManager(scene, reset_camera=True)
        showm.initialize()
        showm.start()
    elif args.save_rendering:
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'superior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(180)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'inferior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(90)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'posterior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(180)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'anterior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.yaw(90)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'right.png'),
                 size=(1920, 1080), offscreen=True)

        scene.yaw(180)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'left.png'),
                 size=(1920, 1080), offscreen=True)

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
def main():
    """Plot the mean/std profile of every (bundle, metric) pair of a JSON.

    The JSON maps bundle -> metric -> label -> {'mean': ..., 'std': ...}.
    Without --stats_over_population, that mapping is read from the first
    top-level value of the file. Labels are the zero-padded strings
    '1'..'nb_points'; the padding width is taken from the first label key.
    A missing label is plotted with mean = std = 0. One PNG named
    '<bundle>_<metric>.png' is written per pair in --out_dir.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_json)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if args.fill_color and len(args.fill_color) != 8:
        parser.error('Hexadecimal RGB color should be formatted as 0xRRGGBB')

    # The file is only read. Opening with 'r+' would needlessly request
    # write permission and fail on read-only inputs.
    with open(args.in_json, 'r') as f:
        if args.stats_over_population:
            mean_std_per_point = json.load(f)
        else:
            mean_std_per_point = list(json.load(f).values())[0]

    # The color dictionary is the same for every plot, so read it once
    # instead of once per (bundle, metric).
    dict_colors = None
    if args.dict_colors:
        with open(args.dict_colors, 'r') as data:
            dict_colors = json.load(data)

    for bundle_name, bundle_stats in mean_std_per_point.items():
        # The color only depends on the bundle name.
        color = None
        if args.dict_colors:
            # Supports variation from rbx-flow: a key matches when it is
            # contained in the bundle name. The last matching key wins.
            for key, key_color in dict_colors.items():
                if key in bundle_name:
                    color = key_color
        elif args.fill_color is not None:
            color = args.fill_color
        if color is None:
            color = '0x000000'

        for metric, metric_stats in bundle_stats.items():
            nb_points = args.nb_pts if args.nb_pts is not None \
                else len(metric_stats)
            num_digits_labels = len(list(metric_stats.keys())[0])

            means = []
            stds = []
            for label_int in range(1, nb_points + 1):
                label = str(label_int).zfill(num_digits_labels)
                mean = metric_stats.get(label, {'mean': 0})['mean']
                std = metric_stats.get(label, {'std': 0})['std']
                # Wrap scalars so that every label contributes a list.
                if not isinstance(mean, list):
                    mean = [mean]
                    std = [std]
                means.append(mean)
                stds.append(std)

            # Robustify for missing data: labels may hold a different number
            # of values, so pad to a rectangular (label, value) array with
            # NaN. Then replace NaNs by the row average, or by -1 when the
            # whole row is missing.
            means = np.array(
                list(itertools.zip_longest(*means, fillvalue=np.nan))).T
            stds = np.array(
                list(itertools.zip_longest(*stds, fillvalue=np.nan))).T
            # Rows of a 2D ndarray are views, so assignments update in place.
            for row_means, row_stds in zip(means, stds):
                nan_mask = np.isnan(row_means)
                if not nan_mask.any():
                    continue
                if nan_mask.all():
                    row_means[nan_mask] = -1
                    row_stds[nan_mask] = -1
                else:
                    row_means[nan_mask] = np.average(row_means[~nan_mask])
                    row_stds[nan_mask] = np.average(row_stds[~nan_mask])

            if not args.stats_over_population:
                means = np.squeeze(means)
                stds = np.squeeze(stds)

            fig = plot_metrics_stats(means, stds,
                                     title=bundle_name,
                                     xlabel='Location along the streamline',
                                     ylabel=metric,
                                     fill_color=(color.replace("0x", "#")),
                                     display_means=args.display_means)
            fig.savefig(os.path.join(args.out_dir,
                                     '{}_{}.png'.format(bundle_name, metric)),
                        bbox_inches='tight')
def load_and_verify_everything(parser, args):
    """Prepare the output tree, read the GT config and load/verify inputs.

    Steps, in order:
      1. Check that the config exists. Create ``out_dir`` and its
         ``segmented_VB`` sub-directory, plus ``segmented_IB`` and
         ``segmented_WPC`` when requested.
      2. Parse the config file. Check that every ROI mask and the input
         tractogram exist; gt and limits masks are passed as optional files.
      3. Load the tractogram, optionally dropping invalid streamlines, and
         check it against the gt and limits masks.
      4. Compute the gt and limits masks, then the head/tail endpoint masks.
         Check the tractogram header against every endpoint ROI.

    Returns
    -------
    tuple
        (gt_tails, gt_heads, sft, bundle_names, all_rois, lengths, angles,
        orientation_lengths, abs_orientation_lengths, limits_inv_masks,
        gt_masks, dimensions)
    """
    assert_inputs_exist(parser, args.gt_config)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    sub_dirs = ['segmented_VB']
    if args.compute_ic:
        sub_dirs.append('segmented_IB')
    if args.save_wpc_separately:
        sub_dirs.append('segmented_WPC')
    for sub_dir in sub_dirs:
        os.makedirs(os.path.join(args.out_dir, sub_dir))

    # Read the config file
    (bundle_names, gt_masks_files, limits_masks_files, roi_options,
     lengths, angles, orientation_lengths,
     abs_orientation_lengths) = read_config_file(args)

    # Every mask referenced by the ROI options, duplicates dropped and first
    # occurrence order kept.
    roi_mask_files = list(dict.fromkeys(
        mask_file
        for option in roi_options
        for mask_file in option.values()))

    # Verify options
    assert_inputs_exist(parser, roi_mask_files + [args.in_tractogram],
                        gt_masks_files + limits_masks_files)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    logging.info("Loading tractogram.")
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    logging.info("Verifying compatibility of tractogram with gt_masks and "
                 "limits_masks.")
    unique_masks = list(dict.fromkeys(gt_masks_files + limits_masks_files))
    verify_compatibility_with_reference_sft(sft, unique_masks, parser, args)

    logging.info("Loading and/or computing ground-truth masks and limits "
                 "masks.")
    gt_masks, _, affine, dimensions = compute_masks(gt_masks_files,
                                                    parser, args)
    _, limits_inv_masks, _, _ = compute_masks(limits_masks_files,
                                              parser, args)

    logging.info("Extracting ground-truth head and tail masks.")
    gt_tails, gt_heads = compute_endpoint_masks(roi_options, affine,
                                                dimensions, args.out_dir)

    # Unique endpoint ROIs (tails first, then heads), order preserved.
    all_rois = list(dict.fromkeys(gt_tails + gt_heads))

    logging.info("Verifying tractogram compatibility with endpoint ROIs.")
    for roi_file in all_rois:
        if not is_header_compatible(sft, roi_file):
            parser.error("Input tractogram incompatible with {}".format(
                roi_file))

    return (gt_tails, gt_heads, sft, bundle_names, all_rois,
            lengths, angles, orientation_lengths, abs_orientation_lengths,
            limits_inv_masks, gt_masks, dimensions)
def main():
    """Score a tractogram against ground-truth bundles and endpoint masks.

    Streamlines are classified into four categories:
      - true connections (tc);
      - wrong-path connections (wpc, counted separately only with
        --wrong_path_as_separate);
      - false connections (fc);
      - no connections (nc).
    Each non-empty category is saved per endpoint pair in ``out_dir``.
    Leftover streamlines are saved as ``nc<ext>``. Global and bundle-wise
    statistics are written to ``out_dir/results.json``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    # Every streamline must end up in exactly one category (checked below).
    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        if not is_header_compatible(sft, gt):
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important.
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(args.gt_endpoints,
                                               args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        if not is_header_compatible(sft, gt):
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    def _save_if_not_empty(tract, prefix_1, prefix_2, tag):
        # Q/C output named <prefix_1>_<prefix_2>_<tag><ext> in out_dir.
        if len(tract) > 0:
            save_tractogram(tract, os.path.join(
                args.out_dir,
                "{}_{}_{}{}".format(prefix_1, prefix_2, tag, ext)),
                bbox_valid_check=False)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):
        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        # sft is replaced by the returned tractogram, which is scored next.
        # Whatever remains at the very end is counted as nc.
        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config,
            length_dict, extract_prefix(args.gt_bundles[i]),
            gt_bundle_inv_masks[i], args.dilate_endpoints,
            args.wrong_path_as_separate)

        nc_streamlines.extend(nc)
        _save_if_not_empty(tc_sft, prefix_1, prefix_2, "tc")
        _save_if_not_empty(wpc_sft, prefix_1, prefix_2, "wpc")
        _save_if_not_empty(fc_sft, prefix_1, prefix_2, "fc")

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc),
            prefix_1, prefix_2))

    # All pairs of endpoint masks (tails and heads interleaved).
    comb_filename = list(itertools.combinations(
        itertools.chain(*zip(gt_tails, gt_heads)), r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for mask_1_filename, mask_2_filename in comb_filename:
        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        fc_sft, sft = extract_false_connections(
            sft, mask_1_filename, mask_2_filename, args.dilate_endpoints)

        _save_if_not_empty(fc_sft, prefix_1, prefix_2, "fc")

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    # Streamlines left after every pair was scored are no connections.
    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft, os.path.join(
        args.out_dir, "nc{}".format(ext)), bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistic that are not "bundle-wise"
    tc_streamlines_count = sum(len(s) for s in tc_streamlines_list)
    fc_streamlines_count = sum(len(s) for s in fc_streamlines_list)
    if args.wrong_path_as_separate:
        wpc_streamlines_count = sum(len(s) for s in wpc_streamlines_list)
    else:
        wpc_streamlines_count = 0
    nc_streamlines_count = len(nc_streamlines)

    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    # Internal invariant: no streamline lost or counted twice.
    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)
        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0) &
                                  (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1) &
                                (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0) &
                                      (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] + tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0) &
                                     (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantom.
    # fc_streamlines_list first holds one entry per true-connection pair
    # (fc found while scoring that pair), then one entry per element of
    # comb_filename. Skip the former so that every combination is paired
    # with its own false connections.
    comb_fc_streamlines = fc_streamlines_list[len(tc_filenames):]
    for filename, current_fc_streamlines in zip(comb_filename,
                                                comb_fc_streamlines):
        if len(current_fc_streamlines):
            current_fc_voxels, _ = get_binary_maps(current_fc_streamlines,
                                                   sft)
            tmp_dict = {}
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)
            final_results["bundle_wise"]["false_connections"][
                str(filename)] = tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results, f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
def main():
    """Interactively accept or reject clusters (bundles) in a FURY window.

    Keys:
      - 'a' accepts the displayed cluster;
      - 'r' rejects it;
      - 'z' undoes the last choice;
      - 'c' toggles the rendering of the concatenated (context) streamlines;
      - 'q' quits.
    Clusters with fewer than --min_cluster_size streamlines are rejected
    automatically. Clusters still unclassified when the window closes are
    rejected as well. Accepted/rejected clusters are saved to --out_accepted
    and --out_rejected, and optionally to per-cluster directories.
    """
    # Callback required for FURY
    def keypress_callback(obj, _):
        nonlocal curr_streamlines_actor, show_curr_actor

        key = obj.GetKeySym().lower()
        # Index of the displayed cluster: one choice has been made for each
        # cluster before it.
        iterator = len(accepted_streamlines) + len(rejected_streamlines)
        nb_clusters = len(sft_accepted_on_size)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and iterator < nb_clusters:
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                # Re-add the current cluster so it is drawn on top.
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if iterator < nb_clusters:
                logging.warning(
                    'Early exit, everything remaining to be rejected.')
            return

        if key in ['a', 'r'] and iterator < nb_clusters:
            if key == 'a':
                accepted_streamlines.append(iterator)
                choices.append('a')
                logging.info('Accepted file %s',
                             filename_accepted_on_size[iterator])
            else:
                rejected_streamlines.append(iterator)
                choices.append('r')
                logging.info('Rejected file %s',
                             filename_accepted_on_size[iterator])
            iterator += 1

        if key == 'z':
            if iterator > 0:
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewind on step.')
                iterator -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        if key in ['a', 'r', 'z'] and iterator < nb_clusters:
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[iterator].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if iterator == nb_clusters:
            print('No more cluster, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])
    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate procedure, clusters can be discarded based on size
    # Concatenation is to give spatial context
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        # Previously a bare `return`: the script silently exited with
        # success and wrote nothing. Report the problem instead.
        if not is_header_compatible(args.in_bundles[0], filename):
            parser.error('{} does not have a header compatible with '
                         '{}.'.format(filename, args.in_bundles[0]))
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser, args, filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            logging.info('File %s has %s streamlines, '
                         'automatically rejected.', filename, len(sft))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('%s clusters to be classified.', len(sft_accepted_on_size))
    # The clusters are sorted by size for simplicity/efficiency
    sft_accepted_on_size, filename_accepted_on_size = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda pair: len(pair[0]), reverse=True))

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene, size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected: every cluster after the
    # last choice made is appended to the rejected indices.
    len_accepted = len(sft_accepted_on_size)
    missing = len_accepted - len(choices)
    rejected_streamlines.extend(range(len(choices), len_accepted))
    if missing > 0:
        logging.info('%s clusters automatically rejected from early exit',
                     missing)

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(save_clusters(sft_rejected_on_size,
                                              range(len(sft_rejected_on_size)),
                                              args.out_rejected_dir,
                                              filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
def main():
    """Compare groups of participants for every bundle/metric/value triplet.

    For each combination, the script:
      1. tests the normality of each group;
      2. optionally plots the distributions;
      3. tests homoscedasticity;
      4. runs the main group-difference test and, when it reports a
         difference, the matching pairwise post-hoc test.
    Results are written to --out_json (inside --out_dir) or printed to
    stdout.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    required_args = [args.in_json, args.in_participants]
    assert_inputs_exist(parser, required_args)

    req_folder = os.path.join(args.out_dir, 'Graph')
    assert_output_dirs_exist_and_empty(parser, args, req_folder)

    # We generated the stats object
    my_data = data_for_stat(args.in_json, args.in_participants)

    bundles = args.bundles
    metrics = args.metrics
    values = args.values

    if args.bundles == 'all':
        bundles = my_data.get_bundles_list()
    if args.metrics == 'all':
        metrics = my_data.get_metrics_list()
    if args.values == 'all':
        values = my_data.get_values_list()

    alpha_error = float(args.alpha_error)

    my_group_dict = my_data.get_groups_dictionnary(args.group_by)
    groups_list = my_data.get_groups_list(args.group_by)

    # A comparison needs at least two groups. This does not depend on the
    # measure, so check it once, before any test or graph. Report it through
    # the parser: a bare BaseException escapes `except Exception` handlers.
    if len(groups_list) < 2:
        logging.error('There is only %s group generated. '
                      'We cannot continue the groups comparison',
                      len(groups_list))
        parser.error('Only {} group generated from '
                     '{}'.format(len(groups_list), args.group_by))

    # Pairwise post-hoc test to run when the main test found a difference.
    post_hoc_tests = {'ANOVA': 'Student',
                      'Kruskalwallis': 'Mannwhitneyu',
                      'Friedmann': 'Wilcoxon'}

    # Initialise the result dictionary
    result_dict = {}

    # We do the comparison for every single combination of
    # metric-bundle-value
    for b, m, v in product(bundles, metrics, values):
        # First we extract the basic information on that comparison
        curr_comparison_measure = '_'.join([b, m, v])

        logging.debug('______________________')
        logging.debug('Measure to compare: %s', curr_comparison_measure)

        # Check normality of that metric across all groups
        current_normality = {}
        overall_normality = True
        groups_array = []
        for group in my_group_dict:
            curr_sample = get_group_data_sample(my_group_dict, group, b, m, v)
            logging.debug('Group %s', group)
            current_normality[group] = verify_normality(curr_sample,
                                                        alpha_error)
            if not current_normality[group][0]:
                overall_normality = False
            groups_array.append(curr_sample)
        logging.debug('Normality result:')
        logging.debug(current_normality)
        logging.debug('Overall Normality:')
        logging.debug(overall_normality)
        logging.debug('Groups array:')
        logging.debug(groups_array)

        # Generate graph of the metric
        if args.generate_graph:
            visualise_distribution(groups_array,
                                   my_data.get_participants_list(),
                                   b, m, v,
                                   args.out_dir,
                                   groups_list)

        # Check homoscedasticity
        variance_equality = verify_homoscedasticity(
            groups_array,
            normality=overall_normality,
            alpha=alpha_error)
        logging.debug('Equality of variance result:')
        logging.debug(variance_equality)

        # Now we compare the groups population
        difference = verify_group_difference(
            groups_array,
            normality=overall_normality,
            homoscedasticity=variance_equality[1],
            alpha=alpha_error)
        logging.debug('Main test result:')
        logging.debug(difference)

        # If the main test found a difference, a post hoc analysis explores
        # which pairs of groups differ.
        diff_2_by_2 = []
        if difference[1] and difference[0] in post_hoc_tests:
            diff_2_by_2 = verify_post_hoc(groups_array,
                                          groups_list,
                                          test=post_hoc_tests[difference[0]],
                                          alpha=alpha_error)
        logging.debug('Summary of difference 2 by 2:')
        logging.debug(diff_2_by_2)

        # Write the current metric in the report
        result_dict[curr_comparison_measure] = write_current_dictionnary(
            curr_comparison_measure,
            current_normality,
            variance_equality,
            difference,
            diff_2_by_2)

    # Saving the result dictionary into a json file, or printing it
    if args.out_json:
        with open(os.path.join(args.out_dir, args.out_json), 'w') as outfile:
            json.dump(result_dict, outfile, indent=args.indent,
                      sort_keys=args.sort_keys)
    else:
        print(json.dumps(result_dict, indent=args.indent,
                         sort_keys=args.sort_keys))
def main():
    """Decompose a tractogram into connections between label pairs.

    Connectivity is computed from the voxels each streamline crosses in the
    label image. For every (in_label, out_label) pair, the matching
    streamline segments are extracted and then go through optional
    post-processing:
      - length pruning;
      - loop/sharp-turn removal;
      - outlier removal;
      - a second loop-removal pass using --loop_qb_distance.
    Intermediate and discarded streamlines are handed to _save_if_needed.
    The final per-pair streamline counts are saved as ``final_matrix.npy``
    in the output directory.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.tracks, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed
    # considerations, we take the length in voxel space and multiply by the
    # voxel size. For this to work correctly, voxel size must be isotropic.
    # Every zoom is compared to the first one: only checking that the mean
    # equals the first zoom wrongly accepts e.g. (1.0, 0.5, 1.5).
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(vox_sizes, vox_sizes[0]):
        parser.error('Labels must be isotropic')

    # Read the labels once. img.get_data() is deprecated and removed in
    # nibabel >= 5.0. np.asanyarray(img.dataobj) gives the same array
    # (on-disk dtype).
    data_labels = np.asanyarray(img_labels.dataobj)
    if np.min(data_labels) < 0 or np.max(data_labels) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    # Load the tractogram whose existence was checked above.
    sft = load_tractogram_with_reference(parser, args, args.tracks)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info(' Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info(' Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info(' Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column before
    # saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()
    for in_label, out_label_info in final_con_info.items():
        for out_label, pair_info in out_label_info.items():
            if not len(pair_info):
                continue

            final_strl = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(final_strl,
                                                            args.min_length,
                                                            args.max_length,
                                                            vox_sizes[0])
                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length',
                                in_label, out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers',
                            in_label, out_label)

            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers,
                    args.loop_max_angle,
                    True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (eg: mean FA in the connection).
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info(' Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
def main():
    """Run COMMIT (optionally followed by COMMIT2) on a tractogram.

    Validates the command-line arguments, converts an HDF5 connectivity
    file into a temporary .trk when needed, writes a gradient scheme file
    from the bval/bvec files, builds the COMMIT dictionary and kernels, and
    fits the model. With --commit2, a second fit is run with a group
    sparsity prior in which each group is the set of streamlines coming
    from one HDF5 key (connection).

    The temporary directory and the HDF5 file (if one was opened) are
    released on every exit path, including the --compute_only early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            bundle_groups_len = []
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT
            # internal use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 *
                         10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # The finally clause below still releases tmp_dir/hdf5_file.
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name,
                                use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=os.path.join(tmp_dir.name,
                                                      'build/'))
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                  offsets_list, 'commit_1/', False)

            if args.commit2:
                # Start/end streamline index of each bundle (HDF5 key order)
                tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
                # Bundles usually differ in size, so np.array() on the list
                # of index ranges would be ragged (ValueError with
                # NumPy >= 1.24). Build a 1D object array of index arrays.
                group_idx = np.empty(len(bundle_groups_len), dtype=object)
                for k in range(len(bundle_groups_len)):
                    group_idx[k] = np.arange(tmp[k], tmp[k + 1])

                # Group weights: sqrt(group size) / norm of the COMMIT-1
                # weights of that group (epsilon avoids division by zero).
                group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
                for k in range(len(bundle_groups_len)):
                    group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)

                prior_on_bundles = commit.solvers.init_regularisation(
                    mit, structureIC=group_idx, weightsIC=group_w,
                    regnorms=[commit.solvers.group_sparsity,
                              commit.solvers.non_negative,
                              commit.solvers.non_negative],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles, verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
def _report_counts(o_dict, args):
    """Display and/or save the streamline counts collected in `o_dict`."""
    if args.verbose:
        display_count(o_dict, args.indent, args.sort_keys)
    if args.save_counts:
        save_count(o_dict, args.out_path, args.indent, args.sort_keys)


def _save_step_outliers(sft, new_sft, o_dict, in_sft_name, step,
                        steps_combined, ext, args):
    """Save kept/removed streamlines of one step; record the removed count."""
    outliers_sft = compute_outliers(sft, new_sft)
    new_path = create_dir(args.out_path, step)
    save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                          step, steps_combined, ext, args.no_empty)
    o_dict[in_sft_name + '_' + step + '_outliers' + ext] = \
        {'streamline_count': len(outliers_sft.streamlines)}


def _finish_if_empty(new_sft, out_sft_name, step, o_dict, args):
    """End the pipeline early when `step` removed every streamline.

    Returns False (nothing done) when `new_sft` still has streamlines.
    Otherwise the empty tractogram is saved unless --no_empty is set, the
    counts are reported, and True is returned so the caller stops.
    """
    if len(new_sft.streamlines) != 0:
        return False

    if args.no_empty:
        logging.debug("The file {} won't be written ".format(out_sft_name) +
                      "(0 streamlines after " + step + " filtering).")
    else:
        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')
        save_tractogram(new_sft, out_sft_name)

    _report_counts(o_dict, args)
    return True


def main():
    """Filter a tractogram with anatomical criteria from a wmparc atlas.

    Four successive steps are applied, each on the output of the previous:
    1. length filtering (minL/maxL),
    2. removal of loops and sharp turns (angle),
    3. removal of streamlines touching CSF (CSF mask or wmparc CSF labels),
    4. keeping streamlines with both ends in cortex (optionally dilated)
       or nuclei labels.
    The pipeline stops early, possibly writing an empty tractogram, as soon
    as a step leaves no streamline. Streamline counts per step are
    optionally displayed and/or saved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_wmparc],
                        args.csf_bin)
    assert_output_dirs_exist_and_empty(parser, args, args.out_path,
                                       create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # An angle of 0 is rejected, so the message says "greater than 0".
    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    # A radius of 0 is allowed (no dilation), hence "greater than or equal".
    if args.ctx_dilation_radius < 0:
        parser.error(
            'Cortex dilation radius "{}" '.format(args.ctx_dilation_radius) +
            'must be greater than or equal to 0')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    img_wmparc = nib.load(args.in_wmparc)
    if not is_header_compatible(img_wmparc, sft):
        parser.error('Headers from the tractogram and the wmparc are '
                     'not compatible.')
    if args.csf_bin:
        img_csf = nib.load(args.csf_bin)
        if not is_header_compatible(img_csf, sft):
            parser.error('Headers from the tractogram and the CSF mask are '
                         'not compatible.')

    if args.minL == 0 and np.isinf(args.maxL):
        logging.debug("You have not specified minL nor maxL. Output will "
                      "not be filtered according to length!")
    if np.isinf(args.angle):
        logging.debug("You have not specified the angle. Loops will "
                      "not be filtered!")
    if args.ctx_dilation_radius == 0:
        logging.debug("You have not specified the cortex dilation radius. "
                      "The wmparc atlas will not be dilated!")

    o_dict = {}
    step_dict = ['length', 'no_loops', 'no_end_csf', 'end_in_atlas']
    wm_labels = load_wmparc_labels()

    in_sft_name = os.path.splitext(os.path.basename(args.in_tractogram))[0]
    out_sft_rootname = in_sft_name + "_filtered"
    _, ext = os.path.splitext(args.in_tractogram)
    # out_sft_rootname already carries the "_filtered" suffix; the final
    # count key below (out_sft_rootname + ext) refers to this same file.
    out_sft_name = os.path.join(args.out_path, out_sft_rootname + ext)

    # STEP 1 - Filter length
    step = step_dict[0]
    steps_combined = step

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    # Streamline count before and after filtering lengths
    o_dict[in_sft_name + ext] = {'streamline_count': len(sft.streamlines)}
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_intermediate_tractograms:
        _save_step_outliers(sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext, args)

    if _finish_if_empty(new_sft, out_sft_name, step, o_dict, args):
        return

    sft = new_sft

    # STEP 2 - Filter loops
    step = step_dict[1]
    steps_combined += "_" + step

    ids_c = remove_loops_and_sharp_turns(sft.streamlines, args.angle)
    new_sft = filter_tractogram_data(sft, ids_c)

    # Streamline count after filtering loops
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_intermediate_tractograms:
        _save_step_outliers(sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext, args)

    if _finish_if_empty(new_sft, out_sft_name, step, o_dict, args):
        return

    sft = new_sft

    # STEP 3 - Filter CSF endings
    step = step_dict[2]
    steps_combined += "_" + step

    # Mask creation: user CSF mask if given, else CSF labels of the wmparc
    if args.csf_bin:
        mask = get_data_as_mask(img_csf)
    else:
        atlas = get_data_as_label(img_wmparc)
        mask = binarize_labels(atlas, wm_labels["csf_labels"])

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, mask, 'any', True)

    # Streamline count after filtering CSF endings
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        if not args.csf_bin:
            nib.save(
                nib.Nifti1Image(mask, img_wmparc.affine, img_wmparc.header),
                os.path.join(new_path, 'csf_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        _save_step_outliers(sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext, args)

    if _finish_if_empty(new_sft, out_sft_name, step, o_dict, args):
        return

    sft = new_sft

    # STEP 4 - Filter WM endings
    step = step_dict[3]
    steps_combined += "_" + step

    # Mask creation
    ctx_fs_labels = wm_labels["ctx_lh_fs_labels"] + \
        wm_labels["ctx_rh_fs_labels"]
    vox_size = np.reshape(img_wmparc.header.get_zooms(), (1, 3))
    atlas_wm = get_data_as_label(img_wmparc)
    atlas_shape = atlas_wm.shape
    wmparc_ctx = binarize_labels(atlas_wm, ctx_fs_labels)
    wmparc_nuclei = binarize_labels(atlas_wm, wm_labels["nuclei_fs_labels"])

    # Dilation of cortex
    if args.ctx_dilation_radius:
        ctx_mask = dilate_mask(wmparc_ctx, atlas_shape, vox_size,
                               args.ctx_dilation_radius)
    else:
        ctx_mask = wmparc_ctx

    freesurfer_mask = np.zeros(atlas_shape, dtype=np.uint16)
    freesurfer_mask[np.logical_or(wmparc_nuclei, ctx_mask)] = 1

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, freesurfer_mask, 'both_ends', False)

    # Streamline count after final filtering
    o_dict[out_sft_rootname + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        nib.save(nib.Nifti1Image(freesurfer_mask, img_wmparc.affine,
                                 img_wmparc.header),
                 os.path.join(new_path, 'atlas_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        _save_step_outliers(sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext, args)

    # Finish filtering
    _report_counts(o_dict, args)

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug(
                "The file {} won't be written ".format(out_sft_name) +
                "(0 streamlines after " + step + " filtering).")
            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

    save_tractogram(new_sft, out_sft_name)
def main():
    """Split a tractogram into per-connection bundles saved in an HDF5 file.

    For every pair of labels (including each label with itself), the
    streamline segments connecting them are extracted and then
    post-processed step by step: length pruning, loop removal, outlier
    removal and QuickBundles curvature filtering. Each step can be skipped
    from the command line. The raw, intermediate, discarded and final
    results are handed to `_save_if_needed`.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # Label 0 (the first unique value) is not a real region
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            # Gather both orientations (in->out and out->in) of the pair.
            # A missing entry means "no streamline" for that orientation
            # only. For a self-connection both orientations are the same
            # list, so it is added once to avoid duplicating streamlines.
            pair_info = list(con_info.get(in_label, {}).get(out_label, []))
            if out_label != in_label:
                pair_info.extend(
                    con_info.get(out_label, {}).get(in_label, []))

            if not pair_info:
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(
                connecting_streamlines, sft,
                data_per_streamline=raw_dps,
                data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines,
                    args.min_length,
                    args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops',
                                in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops',
                            in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold,
                    nb_samplings=10, fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers',
                                in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers',
                            in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers, args.loop_max_angle, use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv',
                                in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
def main():
    """Fit the AMICO FreeWater model on a DWI and save the results.

    A gradient scheme file is written in a temporary directory from the
    bval/bvec files. The kernels are then generated (into --save_kernels,
    or a temporary location) or reused from --load_kernels, and the model
    is fitted with outputs written to --out_dir. With --compute_only the
    script stops right after kernel generation. The temporary directory is
    removed on every exit path, including that early return.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    # Same validation as the COMMIT script: fail early instead of handing
    # AMICO a kernels location that does not exist.
    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)

    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    # AMICO has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    try:
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)

        logging.debug(
            'Compute FreeWater with AMICO on {} shells at found at {}.'.format(
                len(shells_centroids), shells_centroids))

        with redirected_stdout:
            amico.core.setup()

            # Load the data
            ae = amico.Evaluation('.', '.')
            ae.load_data(args.in_dwi,
                         scheme_filename=tmp_scheme_filename,
                         mask_filename=args.mask)

            # Compute the response functions
            ae.set_model("FreeWater")

            model_type = 'Human'
            if args.mouse:
                model_type = 'Mouse'

            ae.model.set(args.para_diff,
                         np.linspace(args.perp_diff_min, args.perp_diff_max,
                                     10),
                         [args.iso_diff], model_type)

            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True

            ae.set_config('ATOMS_path', kernels_dir)
            ae.set_config('OUTPUT_path', args.out_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            if args.compute_only:
                # The finally clause below still removes tmp_dir.
                return

            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            ae.set_config('doNormalizeSignal', True)
            ae.set_config('doKeepb0Intact', False)
            ae.set_config('doComputeNRMSE', True)
            ae.set_config('doSaveCorrectedDWI', True)

            # Model fit
            ae.fit()
            # Save the results
            ae.save_results()
    finally:
        tmp_dir.cleanup()
def main():
    """Compute MT contrast maps, then MTR and MTsat maps, and save them.

    Each contrast (MT-off, MT-on, T1w) is given as a list of echo images,
    each with a sidecar .json next to it. The contrast maps are saved in
    'Contrasts_MT_maps'. MTR/MTsat are thresholded to [0, 100] inside the
    mask, optionally B1-corrected, and saved in 'MT_native_maps'.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir,
                                                    'Contrasts_MT_maps'),
                                       os.path.join(args.out_dir,
                                                    'MT_native_maps'),
                                       create_dir=True)

    # Merge all echoes paths into a list (one list of echoes per contrast)
    maps = [args.in_mtoff, args.in_mton, args.in_t1w]

    maps_flat = (args.in_mtoff + args.in_mton + args.in_t1w)

    jsons = [curr_map.replace('.nii.gz', '.json')
             for curr_map in maps_flat]

    # check data (mask and B1 map are checked only when given)
    assert_inputs_exist(parser, jsons + maps_flat,
                        [args.in_mask, args.in_B1_map])
    for curr_map in maps[1:]:
        if len(curr_map) != len(maps[0]):
            parser.error('Not the same number of echoes per contrast')

    # Set TR and FlipAngle parameters for MT (mtoff contrast)
    # and T1w images
    parameters = [set_acq_parameters(maps[0][0].replace('.nii.gz', '.json')),
                  set_acq_parameters(maps[2][0].replace('.nii.gz', '.json'))]

    # Fix issue from the presence of invalid value and division by zero
    np.seterr(divide='ignore', invalid='ignore')

    # Define reference image for saving maps
    ref_img = nib.load(maps[0][0])

    # Define contrasts maps names
    contrasts_name = ['mt_off', 'mt_on', 'T1w']
    if args.out_prefix:
        contrasts_name = [args.out_prefix + '_' + curr_name
                          for curr_name in contrasts_name]

    # Compute contrasts maps
    computed_contrasts = []
    for idx, curr_map in enumerate(maps):
        computed_contrasts.append(compute_contrasts_maps(curr_map))

        nib.save(nib.Nifti1Image(computed_contrasts[idx].astype(np.float32),
                                 ref_img.affine),
                 os.path.join(args.out_dir, 'Contrasts_MT_maps',
                              contrasts_name[idx] + '.nii.gz'))

    # Compute and threshold MT maps. The processed maps returned by the
    # thresholding / B1 correction are stored back into MTR and MTsat so
    # that the saved images are the processed ones (rebinding only a loop
    # variable would discard them).
    MTR, MTsat = compute_MT_maps(computed_contrasts, parameters)
    processed_maps = []
    for curr_map in (MTR, MTsat):
        curr_map = threshold_MT_maps(curr_map, args.in_mask, 0, 100)
        if args.in_B1_map:
            curr_map = apply_B1_correction(curr_map, args.in_B1_map)
        processed_maps.append(curr_map)
    MTR, MTsat = processed_maps

    # Save MT maps
    img_name = ['MTR', 'MTsat']

    if args.in_B1_map:
        img_name = [curr_name + '_B1_corrected' for curr_name in img_name]

    if args.out_prefix:
        img_name = [args.out_prefix + '_' + curr_name
                    for curr_name in img_name]

    img_data = MTR, MTsat
    for img_to_save, name in zip(img_data, img_name):
        nib.save(nib.Nifti1Image(img_to_save.astype(np.float32),
                                 ref_img.affine),
                 os.path.join(args.out_dir, 'MT_native_maps',
                              name + '.nii.gz'))
def main():
    """Compute ihMT contrast maps, then ihMT and MT maps, and save them.

    Each of the six contrasts (altnp, altpn, MT-off, negative, positive,
    T1w) is given as a list of echo images, each with a sidecar .json next
    to it. Contrast maps (optionally filtered) are saved in
    'Contrasts_ihMT_maps'. ihMTR/ihMTsat/MTR/MTsat are thresholded and
    saved in 'ihMT_native_maps'.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir,
                                                    'Contrasts_ihMT_maps'),
                                       os.path.join(args.out_dir,
                                                    'ihMT_native_maps'),
                                       create_dir=True)

    # Merge all echoes paths into a list (one list of echoes per contrast)
    maps = [args.in_altnp, args.in_altpn, args.in_mtoff, args.in_negative,
            args.in_positive, args.in_t1w]

    maps_flat = (args.in_altnp + args.in_altpn + args.in_mtoff +
                 args.in_negative + args.in_positive + args.in_t1w)

    jsons = [curr_map.replace('.nii.gz', '.json')
             for curr_map in maps_flat]

    # check echoes number and jsons (the mask is checked only when given)
    assert_inputs_exist(parser, jsons + maps_flat, args.in_mask)
    for curr_map in maps[1:]:
        if len(curr_map) != len(maps[0]):
            parser.error('Not the same number of echoes per contrast')

    # Set TR and FlipAngle parameters for ihMT (positive contrast)
    # and T1w images
    parameters = [set_acq_parameters(maps[4][0].replace('.nii.gz', '.json')),
                  set_acq_parameters(maps[5][0].replace('.nii.gz', '.json'))]

    # Fix issue from the presence of invalid value and division by zero
    np.seterr(divide='ignore', invalid='ignore')

    # Define reference image for saving maps
    ref_img = nib.load(maps[4][0])

    # Define contrasts maps names
    contrasts_name = ['altnp', 'altpn', 'reference', 'negative', 'positive',
                      'T1w']
    if args.filtering:
        contrasts_name = [curr_name + '_filter'
                          for curr_name in contrasts_name]
    if args.out_prefix:
        contrasts_name = [args.out_prefix + '_' + curr_name
                          for curr_name in contrasts_name]

    # Compute contrasts maps
    computed_contrasts = []
    for idx, curr_map in enumerate(maps):
        computed_contrasts.append(
            compute_contrasts_maps(curr_map, filtering=args.filtering))

        nib.save(nib.Nifti1Image(computed_contrasts[idx].astype(np.float32),
                                 ref_img.affine),
                 os.path.join(args.out_dir, 'Contrasts_ihMT_maps',
                              contrasts_name[idx] + '.nii.gz'))

    # Compute and threshold ihMT maps
    ihMTR, ihMTsat = compute_ihMT_maps(computed_contrasts, parameters)
    ihMTR = threshold_ihMT_maps(ihMTR, computed_contrasts, args.in_mask,
                                0, 100, [4, 3, 1, 0, 2])
    ihMTsat = threshold_ihMT_maps(ihMTsat, computed_contrasts, args.in_mask,
                                  0, 10, [4, 3, 1, 0])

    # Compute and threshold non-ihMT maps. The thresholded maps are stored
    # back into MTR and MTsat so the saved images are the processed ones
    # (rebinding only a loop variable would discard them).
    MTR, MTsat = compute_MT_maps(computed_contrasts, parameters)
    MTR, MTsat = [threshold_ihMT_maps(curr_map, computed_contrasts,
                                      args.in_mask, 0, 100, [4, 2])
                  for curr_map in (MTR, MTsat)]

    # Save ihMT and MT images
    img_name = ['ihMTR', 'ihMTsat', 'MTR', 'MTsat']

    if args.filtering:
        img_name = [curr_name + '_filter' for curr_name in img_name]

    if args.out_prefix:
        img_name = [args.out_prefix + '_' + curr_name
                    for curr_name in img_name]

    img_data = ihMTR, ihMTsat, MTR, MTsat
    for img_to_save, name in zip(img_data, img_name):
        nib.save(nib.Nifti1Image(img_to_save.astype(np.float32),
                                 ref_img.affine),
                 os.path.join(args.out_dir, 'ihMT_native_maps',
                              name + '.nii.gz'))
def main():
    """Fit the AMICO NODDI model on a DWI and save its outputs.

    The gradient scheme is derived from the bval/bvec files (through
    fsl2mrtrix) inside a temporary directory. Kernels are generated into
    --save_kernels, reused from --load_kernels, or generated in the
    temporary directory. Fit results go to <out_dir>/<model id>.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.in_mask)
    noddi_out_dir = os.path.join(args.out_dir, 'NODDI')
    assert_output_dirs_exist_and_empty(parser, args, noddi_out_dir,
                                       optional=args.save_kernels)

    # AMICO prints from C code and through plain print() calls; unless
    # verbose, swallow everything written to stdout.
    if not args.verbose:
        swallowed = io.StringIO()
        redirected_stdout = redirect_stdout(swallowed)
        redirect_stdout_c()
    else:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)

    # Write the bval/scheme files needed by AMICO in a scratch directory
    tmp_dir = tempfile.TemporaryDirectory()
    scheme_path = os.path.join(tmp_dir.name, 'gradients.scheme')
    bval_path = os.path.join(tmp_dir.name, 'bval')

    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    centroids, shell_of_volume = identify_shells(bvals, args.b_thr,
                                                 roundCentroids=True)
    rounded_bvals = centroids[shell_of_volume]
    np.savetxt(bval_path, rounded_bvals, newline=' ', fmt='%i')
    fsl2mrtrix(bval_path, args.in_bvec, scheme_path)

    nb_shells = len(centroids)
    logging.debug('Compute NODDI with AMICO on {} shells at found '
                  'at {}.'.format(nb_shells, centroids))

    with redirected_stdout:
        amico.core.setup()
        evaluation = amico.Evaluation('.', '.')
        evaluation.load_data(args.in_dwi, scheme_path,
                             mask_filename=args.in_mask)

        # Model and discretization grids of its parameters
        evaluation.set_model("NODDI")
        vol_frac_grid = np.linspace(0.1, 0.99, 12)
        orient_grid = np.hstack((np.array([0.03, 0.06]),
                                 np.linspace(0.09, 0.99, 10)))
        evaluation.model.set(args.para_diff, args.iso_diff, vol_frac_grid,
                             orient_grid, False)
        evaluation.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # Kernel location: explicit save dir, then explicit load dir, else
        # a scratch location. Only the load-only case reuses kernels.
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                       evaluation.model.id)
        regenerate_kernels = bool(args.save_kernels) or not args.load_kernels

        evaluation.set_config('ATOMS_path', kernels_dir)
        evaluation.set_config('OUTPUT_path',
                              os.path.join(args.out_dir,
                                           evaluation.model.id))
        evaluation.generate_kernels(regenerate=regenerate_kernels)
        evaluation.load_kernels()

        # Number of threads used by the solver
        params = evaluation.get_config('solver_params')
        params['numThreads'] = args.nbr_processes
        evaluation.set_config('solver_params', params)

        evaluation.fit()
        evaluation.save_results()

    tmp_dir.cleanup()
def main(): parser = build_arg_parser() args = parser.parse_args() assert_inputs_exist(parser, [args.in_tractogram, args.labels]) if os.path.abspath(args.output_dir) == os.getcwd(): parser.error('Do not use the current path as output directory.') assert_output_dirs_exist_and_empty(parser, args, args.output_dir, create_dir=True) log_level = logging.WARNING if args.verbose: log_level = logging.INFO logging.basicConfig(level=log_level) coloredlogs.install(level=log_level) set_sft_logger_level('WARNING') img_labels = nib.load(args.labels) data_labels = img_labels.get_fdata().astype(np.int16) real_labels = np.unique(data_labels)[1:] if args.out_labels_list: np.savetxt(args.out_labels_list, real_labels, fmt='%i') if not np.issubdtype(img_labels.get_data_dtype().type, np.integer): parser.error("Label image should contain integers for labels.") # Voxel size must be isotropic, for speed/performance considerations vox_sizes = img_labels.header.get_zooms() if not np.mean(vox_sizes) == vox_sizes[0]: parser.error('Labels must be isotropic') logging.info('*** Loading streamlines ***') time1 = time.time() sft = load_tractogram_with_reference(parser, args, args.in_tractogram) time2 = time.time() logging.info(' Loading {} streamlines took {} sec.'.format( len(sft), round(time2 - time1, 2))) logging.info('*** Filtering streamlines ***') data_mask = np.zeros(data_labels.shape) data_mask[data_labels > 0] = 1 original_len = len(sft) time1 = time.time() sft.to_vox() sft.to_corner() sft.remove_invalid_streamlines() time2 = time.time() logging.info( ' Discarded {} streamlines from filtering in {} sec.'.format( original_len - len(sft), round(time2 - time1, 2))) logging.info(' Number of streamlines to process: {}'.format(len(sft))) # Get all streamlines intersection indices logging.info('*** Computing streamlines intersection ***') time1 = time.time() indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True) time2 = time.time() logging.info(' Streamlines intersection took {} 
sec.'.format( round(time2 - time1, 2))) # Compute the connectivity mapping logging.info('*** Computing connectivity information ***') time1 = time.time() con_info = compute_connectivity(indices, data_labels, real_labels, extract_longest_segments_from_profile) time2 = time.time() logging.info(' Connectivity computation took {} sec.'.format( round(time2 - time1, 2))) # Prepare directories and information needed to save. _create_required_output_dirs(args) logging.info('*** Starting connection post-processing and saving. ***') logging.info(' This can be long, be patient.') time1 = time.time() # Saving will be done from streamlines already in the right space comb_list = list(itertools.combinations(real_labels, r=2)) comb_list.extend(zip(real_labels, real_labels)) iteration_counter = 0 for in_label, out_label in comb_list: if iteration_counter > 0 and iteration_counter % 100 == 0: logging.info('Split {} nodes out of {}'.format(iteration_counter, len(comb_list))) iteration_counter += 1 pair_info = [] if in_label not in con_info: continue elif out_label in con_info[in_label]: pair_info.extend(con_info[in_label][out_label]) if out_label not in con_info: continue elif in_label in con_info[out_label]: pair_info.extend(con_info[out_label][in_label]) if not len(pair_info): continue connecting_streamlines = [] for connection in pair_info: strl_idx = connection['strl_idx'] curr_streamlines = compute_streamline_segment( sft.streamlines[strl_idx], indices[strl_idx], connection['in_idx'], connection['out_idx'], points_to_idx[strl_idx]) connecting_streamlines.append(curr_streamlines) _save_if_needed(connecting_streamlines, args, 'raw', 'raw', in_label, out_label) # Doing all post-processing if not args.no_pruning: valid_length, invalid_length = _prune_segments( connecting_streamlines, args.min_length, args.max_length, vox_sizes[0]) _save_if_needed(invalid_length, args, 'discarded', 'invalid_length', in_label, out_label) else: valid_length = connecting_streamlines if not 
len(valid_length): continue _save_if_needed(valid_length, args, 'intermediate', 'valid_length', in_label, out_label) if not args.no_remove_loops: no_loop_ids = remove_loops_and_sharp_turns(valid_length, args.loop_max_angle) no_loops = [valid_length[i] for i in no_loop_ids] loop_ids = np.setdiff1d(np.arange(len(valid_length)), no_loop_ids) loops = [valid_length[i] for i in loop_ids] _save_if_needed(loops, args, 'discarded', 'loops', in_label, out_label) else: no_loops = valid_length if not len(no_loops): continue _save_if_needed(no_loops, args, 'intermediate', 'no_loops', in_label, out_label) if not args.no_remove_outliers: inliers, outliers = remove_outliers(no_loops, args.outlier_threshold) _save_if_needed(outliers, args, 'discarded', 'outliers', in_label, out_label) else: inliers = no_loops if not len(inliers): continue _save_if_needed(inliers, args, 'intermediate', 'inliers', in_label, out_label) if not args.no_remove_curv_dev: no_qb_curv_ids = remove_loops_and_sharp_turns( inliers, args.loop_max_angle, use_qb=True, qb_threshold=args.curv_qb_distance) no_qb_curv = [inliers[i] for i in no_qb_curv_ids] qb_curv_ids = np.setdiff1d( np.arange(len(inliers)), no_qb_curv_ids) qb_curv = [inliers[i] for i in qb_curv_ids] _save_if_needed(qb_curv, args, 'discarded', 'qb_curv', in_label, out_label) else: no_qb_curv = inliers _save_if_needed(no_qb_curv, args, 'final', 'final', in_label, out_label) time2 = time.time() logging.info( ' Connections post-processing and saving took {} sec.'.format( round(time2 - time1, 2)))