def cut_streamlines(streamlines, roi_anat_1, roi_anat_2):
    """Keep only the portion of each streamline running between two ROIs.

    Parameters
    ----------
    streamlines : sequence of streamlines
        Streamlines to cut. They are passed to ``uncompress`` to obtain the
        voxel indices each one traverses (presumably already expressed in the
        ROIs' voxel space - TODO confirm against callers).
    roi_anat_1, roi_anat_2 : str
        Paths of the two ROI images, loaded with nibabel.

    Returns
    -------
    list
        One segment per streamline for which ``intersects_two_rois`` returned
        both an entry and an exit index, as built by
        ``compute_streamline_segment``. Other streamlines are dropped.

    Raises
    ------
    ValueError
        If the affines of the two ROI images do not match (``np.allclose``).
    """
    roi_img_1 = nb.load(roi_anat_1)
    roi_img_2 = nb.load(roi_anat_2)

    # ``get_affine()`` / ``get_data()`` are deprecated (removed in recent
    # nibabel releases). ``.affine`` and ``np.asanyarray(img.dataobj)`` are
    # the documented replacements; the latter keeps the on-disk dtype.
    if not np.allclose(roi_img_1.affine, roi_img_2.affine):
        raise ValueError("The affines of both ROIs do not match.")

    roi_data_1 = np.asanyarray(roi_img_1.dataobj)
    roi_data_2 = np.asanyarray(roi_img_2.dataobj)

    # Compare voxel coordinates (not masks) so differently-shaped volumes
    # still work. Only "is there any shared voxel" matters, so isdisjoint
    # avoids materializing the whole intersection.
    voxels_1 = set(map(tuple, np.transpose(np.nonzero(roi_data_1))))
    voxels_2 = map(tuple, np.transpose(np.nonzero(roi_data_2)))
    if not voxels_1.isdisjoint(voxels_2):
        logging.warning('Parts of the ROIs may overlap.\n'
                        'Behavior might be unexpected.')

    final_streamlines = []
    indices, points_to_idx = uncompress(streamlines, return_mapping=True)

    for strl_idx, strl in enumerate(streamlines):
        logging.debug("Starting streamline")

        strl_indices = indices[strl_idx]
        logging.debug(strl_indices)

        in_strl_idx, out_strl_idx = intersects_two_rois(
            roi_data_1, roi_data_2, strl_indices)

        # A segment can only be cut when both an entry and an exit exist.
        if in_strl_idx is None or out_strl_idx is None:
            continue

        points_to_indices = points_to_idx[strl_idx]
        logging.debug(points_to_indices)

        final_streamlines.append(
            compute_streamline_segment(strl, strl_indices, in_strl_idx,
                                       out_strl_idx, points_to_indices))

    return final_streamlines
def main():
    """Split a tractogram into per-label-pair connections saved to HDF5.

    For every pair of labels (including self-connections), the streamline
    segments linking them are extracted and then run through optional
    length pruning, loop removal, outlier removal and QuickBundles curvature
    filtering. The final result of each stage is written with
    ``_save_if_needed``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # NOTE(review): drops the smallest unique value, presumably background 0.
    real_labels = np.unique(data_labels)[1:]

    # Voxel size must be isotropic, for speed/performance considerations.
    # Checked before writing --out_labels_list so a rejected input leaves no
    # partial output behind.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()
    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)
    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    # Any previous output is removed only now, after every argument and input
    # check has passed, so a rejected run does not destroy an earlier result.
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD.
        for iteration_counter, (in_label, out_label) in enumerate(comb_list):
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))

            # Gather both directions of the connection. .get() keeps one
            # direction even if the other label is absent from con_info.
            pair_info = list(con_info.get(in_label, {}).get(out_label, []))
            # For a self-connection both directions are the same entry;
            # adding it twice would duplicate every streamline.
            if out_label != in_label:
                pair_info.extend(
                    con_info.get(out_label, {}).get(in_label, []))

            if not pair_info:
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                connecting_streamlines.append(compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx]))
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(
                connecting_streamlines, sft,
                data_per_streamline=raw_dps, data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing. "Keep everything" id lists are
            # np.arange (not range) so both ArraySequence and the
            # StatefulTractogram per-streamline data accept them as indices.
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])
                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = np.arange(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)
                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = np.arange(len(valid_length))

            if not len(no_loops):
                continue

            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold, nb_samplings=10,
                    fast_approx=True)
                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = np.arange(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers, args.loop_max_angle, use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)
                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = np.arange(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
def main():
    """Decompose a tractogram by label pairs and save a streamline-count matrix.

    Each connection's streamlines go through optional length pruning, loop
    removal, outlier removal and a second QuickBundles-based loop removal;
    the count of surviving streamlines per (in_label, out_label) is stored in
    ``final_matrix.npy`` inside the output directory.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    # NOTE(review): inputs are checked via ``args.tracks`` while the
    # tractogram is loaded from ``args.in_tractogram``; one of these names is
    # probably stale - confirm against build_args_parser.
    assert_inputs_exist(parser, [args.tracks, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')
    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed
    # considerations, we take the length in voxel space and multiply by the
    # voxel size, which is only correct for isotropic voxels. Comparing every
    # size to the mean with a tolerance rejects anisotropic sizes whose mean
    # happens to equal the first one, and accepts isotropic sizes whose mean
    # differs only by float rounding.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    # get_data() is deprecated; np.asanyarray(dataobj) is its replacement
    # and keeps the on-disk dtype. Read the array once and reuse it.
    data_labels = np.asanyarray(img_labels.dataobj)
    if np.min(data_labels) < 0 or np.max(data_labels) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info(' Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info(' Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()
    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)
    time2 = time.time()
    logging.info(' Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column before
    # saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()
    for in_label, out_connections in final_con_info.items():
        for out_label, pair_info in out_connections.items():
            if not len(pair_info):
                continue

            final_strl = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(
                    compute_streamline_segment(sft.streamlines[strl_idx],
                                               indices[strl_idx],
                                               connection['in_idx'],
                                               connection['out_idx'],
                                               points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(
                    final_strl, args.min_length, args.max_length,
                    vox_sizes[0])
                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length',
                                in_label, out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers',
                            in_label, out_label)

            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers, args.loop_max_angle, True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (eg: mean FA in the connection).
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info(' Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
def main():
    """Split a tractogram into per-label-pair connections saved in a directory.

    For every pair of labels (including self-connections), the streamline
    segments linking them are extracted and run through optional length
    pruning, loop removal, outlier removal and QuickBundles curvature
    filtering. Each stage's output is written with ``_save_if_needed``.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')
    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)
    # Validate the label image before writing --out_labels_list so a
    # rejected input leaves no partial output behind.
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Voxel size must be isotropic, for speed/performance considerations.
    # A tolerance-based comparison against every size rejects anisotropic
    # sizes whose mean equals the first one and tolerates float rounding.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    # int32 rather than int16 so label values above 32767 are not wrapped.
    data_labels = img_labels.get_fdata().astype(np.int32)
    # NOTE(review): drops the smallest unique value, presumably background 0.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    original_len = len(sft)
    time1 = time.time()
    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        ' Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info(' Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()
    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)
    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    for iteration_counter, (in_label, out_label) in enumerate(comb_list):
        if iteration_counter > 0 and iteration_counter % 100 == 0:
            logging.info('Split {} nodes out of {}'.format(iteration_counter,
                                                           len(comb_list)))

        # Gather both directions of the connection. .get() keeps one
        # direction even if the other label is absent from con_info.
        pair_info = list(con_info.get(in_label, {}).get(out_label, []))
        # For a self-connection both directions are the same entry; adding
        # it twice would duplicate every streamline.
        if out_label != in_label:
            pair_info.extend(con_info.get(out_label, {}).get(in_label, []))

        if not pair_info:
            continue

        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            connecting_streamlines.append(compute_streamline_segment(
                sft.streamlines[strl_idx], indices[strl_idx],
                connection['in_idx'], connection['out_idx'],
                points_to_idx[strl_idx]))

        _save_if_needed(connecting_streamlines, args,
                        'raw', 'raw', in_label, out_label)

        # Doing all post-processing
        if not args.no_pruning:
            valid_length, invalid_length = _prune_segments(
                connecting_streamlines, args.min_length, args.max_length,
                vox_sizes[0])
            _save_if_needed(invalid_length, args, 'discarded',
                            'invalid_length', in_label, out_label)
        else:
            valid_length = connecting_streamlines

        if not len(valid_length):
            continue

        _save_if_needed(valid_length, args, 'intermediate', 'valid_length',
                        in_label, out_label)

        if not args.no_remove_loops:
            no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                       args.loop_max_angle)
            no_loops = [valid_length[i] for i in no_loop_ids]

            loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                    no_loop_ids)
            loops = [valid_length[i] for i in loop_ids]

            _save_if_needed(loops, args, 'discarded', 'loops',
                            in_label, out_label)
        else:
            no_loops = valid_length

        if not len(no_loops):
            continue

        _save_if_needed(no_loops, args, 'intermediate', 'no_loops',
                        in_label, out_label)

        if not args.no_remove_outliers:
            inliers, outliers = remove_outliers(no_loops,
                                                args.outlier_threshold)
            _save_if_needed(outliers, args, 'discarded', 'outliers',
                            in_label, out_label)
        else:
            inliers = no_loops

        if not len(inliers):
            continue

        _save_if_needed(inliers, args, 'intermediate', 'inliers',
                        in_label, out_label)

        if not args.no_remove_curv_dev:
            no_qb_curv_ids = remove_loops_and_sharp_turns(
                inliers, args.loop_max_angle, use_qb=True,
                qb_threshold=args.curv_qb_distance)
            no_qb_curv = [inliers[i] for i in no_qb_curv_ids]

            qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                       no_qb_curv_ids)
            qb_curv = [inliers[i] for i in qb_curv_ids]

            _save_if_needed(qb_curv, args, 'discarded', 'qb_curv',
                            in_label, out_label)
        else:
            no_qb_curv = inliers

        _save_if_needed(no_qb_curv, args, 'final', 'final',
                        in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))