def main():
    """Split a tractogram into clean streamlines and looping streamlines.

    Clean streamlines (ids kept by ``remove_loops_and_sharp_turns``) are saved
    to ``args.out_tractogram``; the rejected ones are saved to
    ``args.looping_tractogram`` when that option is given. With
    ``--display_counts`` the before/after counts are printed as JSON.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.looping_tractogram)
    check_tracts_same_format(parser, [args.in_tractogram,
                                      args.out_tractogram,
                                      args.looping_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram)
    streamlines = tractogram.streamlines

    if len(streamlines) <= 1:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    # Ids of the streamlines to keep; everything else is considered looping.
    ids_c = remove_loops_and_sharp_turns(
        streamlines, args.angle, use_qb=args.qb,
        qb_threshold=args.threshold)
    ids_l = np.setdiff1d(np.arange(len(streamlines)), ids_c)

    if len(ids_c) > 0:
        sft_c = filter_tractogram_data(tractogram, ids_c)
        save_tractogram(sft_c, args.out_tractogram)
    else:
        logging.warning(
            'No clean streamlines in {}'.format(args.in_tractogram))

    if args.display_counts:
        # Count from the kept ids rather than from ``sft_c``: ``sft_c`` does
        # not exist when no streamline survived the filtering.
        sc_bf = len(streamlines)
        sc_af = len(ids_c)
        print(json.dumps({'streamline_count_before_filtering': int(sc_bf),
                          'streamline_count_after_filtering': int(sc_af)},
                         indent=args.indent))

    if len(ids_l) == 0:
        logging.warning('No loops in {}'.format(args.in_tractogram))
    elif args.looping_tractogram:
        sft_l = filter_tractogram_data(tractogram, ids_l)
        save_tractogram(sft_l, args.looping_tractogram)
def main():
    """Separate loop-free streamlines from looping ones (nibabel API).

    Loop-free streamlines are written to ``args.out_tractogram``. The
    looping ones are written to ``args.remaining_tractogram`` when that
    option is set. Both outputs reuse the header of the input file.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    in_path = args.in_tractogram

    assert_inputs_exist(parser, in_path)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser, [in_path, args.out_tractogram, args.remaining_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = nib.streamlines.load(in_path)
    all_streamlines = tractogram.streamlines

    def _write(strl_seq, out_path):
        # An identity ``affine_to_rasmm`` declares the streamlines as already
        # being in RAS+mm; the input header is reused for the output file.
        out_tractogram = nib.streamlines.Tractogram(
            strl_seq, affine_to_rasmm=np.eye(4))
        nib.streamlines.save(out_tractogram, out_path,
                             header=tractogram.header)

    kept, looping = [], []
    if len(all_streamlines) <= 1:
        parser.error(
            'Zero or one streamline in {}'.format(in_path) +
            '. The file must have more than one streamline.')
    else:
        kept, looping = remove_loops_and_sharp_turns(
            all_streamlines, args.angle, args.qb, args.threshold)

    if not len(kept):
        logging.warning('No clean streamlines in {}'.format(in_path))
    else:
        _write(kept, args.out_tractogram)

    if not len(looping):
        logging.warning('No loops in {}'.format(in_path))
    elif args.remaining_tractogram:
        _write(looping, args.remaining_tractogram)
def main():
    """Decompose a tractogram into label-to-label connections saved in HDF5.

    Streamlines are cut into segments between labels of ``args.in_labels``.
    Each label pair then goes through optional length pruning, loop removal,
    outlier removal and QuickBundles curvature filtering. Every stage is
    forwarded to ``_save_if_needed`` under its group/name.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # Drop the background label explicitly instead of slicing off the first
    # unique value, which would discard a real label if 0 is absent.
    real_labels = np.unique(data_labels)
    real_labels = real_labels[real_labels != 0]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()
    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)
    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))
    nb_combinations = len(comb_list)

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD.
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, nb_combinations))
            iteration_counter += 1

            # Gather segments in both directions. A missing entry in one
            # direction must not hide the other one. For a self-connection
            # both directions are the same list, so it is read only once
            # (otherwise every segment would be duplicated).
            pair_info = list(con_info.get(in_label, {}).get(out_label, []))
            if in_label != out_label:
                pair_info.extend(
                    con_info.get(out_label, {}).get(in_label, []))

            if not pair_info:
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(
                connecting_streamlines, sft,
                data_per_streamline=raw_dps,
                data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing.
            # Ids are kept as integer ndarrays: StatefulTractogram /
            # ArraySequence indexing does not accept ``range`` objects.
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines,
                    args.min_length,
                    args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = np.arange(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = np.arange(len(valid_length))

            if not len(no_loops):
                continue

            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold, nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = np.arange(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers, args.loop_max_angle, use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = np.arange(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
def extract_true_connections(
    sft, mask_1_filename, mask_2_filename, gt_config, length_dict,
    gt_bundle, gt_bundle_inv_mask, dilate_endpoints, wrong_path_as_separate
):
    """
    Extract true connections based on two regions from a tractogram.
    May extract false and no connections if the config is passed.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    gt_config: dict or None
        Dictionary containing the bundle's parameters. When falsy, no length
        or loop filtering is done.
    length_dict: dict or None
        Dictionary containing the bundle's length parameters. Read as
        ``length_dict[gt_bundle]['length']`` (min, max) and
        ``length_dict[gt_bundle]['angle']`` when ``gt_config`` is set.
    gt_bundle: str
        Bundle's name.
    gt_bundle_inv_mask: np.ndarray
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.
    wrong_path_as_separate: bool
        If true, save the WPCs as separate from TCs.

    Returns
    -------
    tc_sft: StatefulTractogram
        SFT of true connections.
    wpc_sft: StatefulTractogram
        SFT of wrong-path-connections (empty unless
        ``wrong_path_as_separate`` is set).
    fc_sft: StatefulTractogram
        SFT of false connections: streamlines whose length is not strictly
        between the bundle's min and max lengths (empty unless
        ``gt_config`` is set).
    nc_streamlines: list or ArraySequence
        Streamlines (not an SFT) of no connections, i.e. streamlines that
        loop or turn too sharply. Empty list when ``gt_config`` is not set
        or when no loop was found.
    sft: StatefulTractogram
        SFT of remaining streamlines, as returned by ``extract_streamlines``.
    """
    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    # TODO: Handle streamline IDs instead of streamlines
    tmp_sft, sft = extract_streamlines(mask_1, mask_2, sft)

    streamlines = tmp_sft.streamlines
    tc_streamlines = streamlines
    wpc_streamlines = []
    fc_streamlines = []
    nc_streamlines = []

    # Config file for each 'bundle'
    # Loops => no connection (nc) # TODO Is this legit ?
    # Length => false connection (fc) # TODO Is this legit ?
    if gt_config:
        min_len, max_len = length_dict[gt_bundle]['length']

        # Bring streamlines to world coordinates so proper length
        # is calculated
        tmp_sft.to_rasmm()
        streamlines = tmp_sft.streamlines
        lengths = np.array(list(length(streamlines)))
        tmp_sft.to_vox()
        streamlines = tmp_sft.streamlines

        valid_min_length_mask = lengths > min_len
        valid_max_length_mask = lengths < max_len
        valid_length_mask = np.logical_and(valid_min_length_mask,
                                           valid_max_length_mask)
        streamlines = ArraySequence(streamlines)

        val_len_streamlines = streamlines[valid_length_mask]
        fc_streamlines = streamlines[~valid_length_mask]

        angle = length_dict[gt_bundle]['angle']
        tc_streamlines_ids = remove_loops_and_sharp_turns(
            val_len_streamlines, angle)

        # np.arange keeps an integer dtype even when the sequence is empty.
        loop_ids = np.setdiff1d(
            np.arange(len(val_len_streamlines)), tc_streamlines_ids)

        loops = val_len_streamlines[list(loop_ids)]
        tc_streamlines = val_len_streamlines[list(tc_streamlines_ids)]

        if len(loops):
            nc_streamlines = loops

    # Streamlines getting out of the bundle mask can be considered
    # separately as wrong path connection (wpc)
    # TODO: Maybe only consider if they cross another GT bundle ?
    if wrong_path_as_separate:
        tmp_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
        _, wp_ids = filter_grid_roi(
            tmp_sft, gt_bundle_inv_mask, 'any', False)
        wpc_streamlines = tmp_sft.streamlines[list(wp_ids)]
        tc_ids = np.setdiff1d(np.arange(len(tmp_sft)), wp_ids)
        tc_streamlines = tmp_sft.streamlines[list(tc_ids)]

    tc_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
    wpc_sft = StatefulTractogram.from_sft([], sft)
    fc_sft = StatefulTractogram.from_sft(fc_streamlines, sft)
    if wrong_path_as_separate and len(wpc_streamlines):
        wpc_sft = StatefulTractogram.from_sft(wpc_streamlines, sft)

    return tc_sft, wpc_sft, fc_sft, nc_streamlines, sft
def extract_vb_vs(sft, head_filename, tail_filename, limits_length, angle,
                  orientation_length, abs_orientation_length,
                  inclusion_inv_mask, dilate_endpoints):
    """
    Extract valid bundle (and valid streamline ids) from a tractogram, based
    on two regions of interest for the endpoints, one region of interest for
    the inclusion of streamlines, and maximum length, maximum angle,
    maximum length per orientation.

    Streamlines going from head to tail are successively checked against the
    inclusion mask, the length limits, the per-orientation length limits
    (signed, then absolute) and the maximum angle. Each step only runs when
    its parameter is not None and some streamlines are still valid.
    Rejected streamlines are wrong path connections (WPC).

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted. Note: when
        ``limits_length`` is used, ``sft`` is moved to rasmm and then to vox
        space in place.
    head_filename: str
        Filename of the "head" of the bundle.
    tail_filename: str
        Filename of the "tail" of the bundle.
    limits_length: list or None
        Bundle's length parameters: [min max].
    angle: int or None
        Bundle's max angle.
    orientation_length: list or None
        Bundle's length parameters in each direction:
        [[min_x, max_x], [min_y, max_y], [min_z, max_z]]
    abs_orientation_length: list or None
        Idem, computed in absolute values.
    inclusion_inv_mask: np.ndarray or None
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    vs_ids: list
        Indices (into ``sft``) of the valid streamlines.
    wpc_ids: list
        Indices of head-to-tail streamlines rejected by one of the criteria
        (wrong path connections).
    bundle_stats: dict
        Counts. "Initial count head to tail" and "VS" are always present.
        "WPC_out_of_mask", "WPC_invalid_length", "WPC_invalid_orientation",
        "WPC_invalid_orientation_abs" and "WPC_invalid_angle" are present
        only when the corresponding step ran.
    """
    mask_1_img = nib.load(head_filename)
    mask_2_img = nib.load(tail_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    _, vs_ids = filter_grid_roi_both(sft, mask_1, mask_2)

    wpc_ids = []
    bundle_stats = {"Initial count head to tail": len(vs_ids)}

    # Remove out of inclusion mask (limits_mask)
    if len(vs_ids) > 0 and inclusion_inv_mask is not None:
        tmp_sft = StatefulTractogram.from_sft(sft.streamlines[vs_ids], sft)
        _, out_of_mask_ids_from_vs = filter_grid_roi(
            tmp_sft, inclusion_inv_mask, 'any', False)
        out_of_mask_ids = vs_ids[out_of_mask_ids_from_vs]

        bundle_stats.update({"WPC_out_of_mask": len(out_of_mask_ids)})

        # Update ids
        wpc_ids.extend(out_of_mask_ids)
        vs_ids = np.setdiff1d(vs_ids, wpc_ids)

    # Remove invalid lengths
    if len(vs_ids) > 0 and limits_length is not None:
        min_len, max_len = limits_length

        # Bring streamlines to world coordinates so proper length
        # is calculated
        sft.to_rasmm()
        lengths = np.array(list(length(sft.streamlines[vs_ids])))
        sft.to_vox()

        # Compute valid lengths
        valid_length_ids_mask_from_vs = np.logical_and(lengths > min_len,
                                                       lengths < max_len)

        bundle_stats.update(
            {"WPC_invalid_length": int(sum(~valid_length_ids_mask_from_vs))})

        # Update ids
        wpc_ids.extend(vs_ids[~valid_length_ids_mask_from_vs])
        vs_ids = vs_ids[valid_length_ids_mask_from_vs]

    # Remove invalid lengths per orientation
    if len(vs_ids) > 0 and orientation_length is not None:
        # Compute valid lengths
        limits_x, limits_y, limits_z = orientation_length

        _, valid_orientation_ids_from_vs, _ = \
            filter_streamlines_by_total_length_per_dim(
                sft[vs_ids], limits_x, limits_y, limits_z,
                use_abs=False, save_rejected=False)

        # Update ids
        valid_orientation_ids = vs_ids[valid_orientation_ids_from_vs]
        invalid_orientation_ids = np.setdiff1d(vs_ids, valid_orientation_ids)

        bundle_stats.update(
            {"WPC_invalid_orientation": len(invalid_orientation_ids)})

        wpc_ids.extend(invalid_orientation_ids)
        vs_ids = valid_orientation_ids

    # Idem in abs
    if len(vs_ids) > 0 and abs_orientation_length is not None:
        # Compute valid lengths
        limits_x, limits_y, limits_z = abs_orientation_length

        _, valid_orientation_ids_from_vs, _ = \
            filter_streamlines_by_total_length_per_dim(
                sft[vs_ids], limits_x, limits_y, limits_z,
                use_abs=True, save_rejected=False)

        # Update ids
        valid_orientation_ids = vs_ids[valid_orientation_ids_from_vs]
        invalid_orientation_ids = np.setdiff1d(vs_ids, valid_orientation_ids)

        bundle_stats.update(
            {"WPC_invalid_orientation_abs": len(invalid_orientation_ids)})

        wpc_ids.extend(invalid_orientation_ids)
        vs_ids = valid_orientation_ids

    # Remove loops from tc
    if len(vs_ids) > 0 and angle is not None:
        # Compute valid angles
        valid_angle_ids_from_vs = remove_loops_and_sharp_turns(
            sft.streamlines[vs_ids], angle)

        # Update ids
        valid_angle_ids = vs_ids[valid_angle_ids_from_vs]
        invalid_angle_ids = np.setdiff1d(vs_ids, valid_angle_ids)

        # Own key: reusing "WPC_invalid_length" here would overwrite the
        # count of the length step.
        bundle_stats.update({"WPC_invalid_angle": len(invalid_angle_ids)})

        wpc_ids.extend(invalid_angle_ids)
        vs_ids = valid_angle_ids

    bundle_stats.update({"VS": len(vs_ids)})

    return list(vs_ids), list(wpc_ids), bundle_stats
def main():
    """Decompose a tractogram into label-to-label connections on disk.

    Each label pair goes through optional length pruning, loop removal,
    outlier removal and QuickBundles loop removal. The intermediate and
    final streamlines are passed to ``_save_if_needed``, and the final
    streamline count per pair is saved as ``final_matrix.npy`` in
    ``args.output``.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    # NOTE(review): the inputs were previously checked through
    # ``args.tracks`` while the tractogram was loaded from
    # ``args.in_tractogram``; both now use ``args.in_tractogram``.
    # Confirm against build_args_parser().
    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed
    # considerations, we take the length in voxel space and multiply by the
    # voxel size. For this to work correctly, voxel size must be isotropic.
    # Comparing every zoom to the mean (rather than only the first one)
    # rejects anisotropic sizes such as (1, 1.5, 0.5).
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    # Load the label array once; ``get_data()`` is deprecated in nibabel.
    data_labels = np.asanyarray(img_labels.dataobj)
    if np.min(data_labels) < 0 or np.max(data_labels) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info(' Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info(' Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info(' Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column
    # before saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()
    for in_label, out_dict in final_con_info.items():
        for out_label, pair_info in out_dict.items():
            if not len(pair_info):
                continue

            final_strl = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths, 'raw',
                            'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(final_strl,
                                                            args.min_length,
                                                            args.max_length,
                                                            vox_sizes[0])

                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length',
                                in_label, out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers',
                            in_label, out_label)

            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers, args.loop_max_angle, True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (eg: mean FA in the connection).
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info(' Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
def _report_counts(o_dict, args):
    """Print (``--verbose``) and/or save (``--save_counts``) the counts."""
    if args.verbose:
        display_count(o_dict, args.indent, args.sort_keys)
    if args.save_counts:
        save_count(o_dict, args.out_path, args.indent, args.sort_keys)


def _save_step_intermediate(args, sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext):
    """Save the kept and rejected streamlines of one step, if requested.

    Also records the number of rejected streamlines in ``o_dict``.
    """
    if not args.save_intermediate_tractograms:
        return
    outliers_sft = compute_outliers(sft, new_sft)
    new_path = create_dir(args.out_path, step)
    save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                          step, steps_combined, ext, args.no_empty)
    o_dict[in_sft_name + '_' + step + '_outliers' + ext] = \
        {'streamline_count': len(outliers_sft.streamlines)}


def _stop_on_empty(args, new_sft, o_dict, out_sft_name, step):
    """Handle a step that removed every streamline.

    Returns False when ``new_sft`` still has streamlines. Otherwise, the
    empty tractogram is written to ``out_sft_name`` unless ``--no_empty`` is
    set, the counts are reported, and True is returned so the caller stops.
    """
    if len(new_sft.streamlines) != 0:
        return False

    if args.no_empty:
        logging.debug("The file {} won't be written ".format(out_sft_name) +
                      "(0 streamlines after " + step + " filtering).")
        _report_counts(o_dict, args)
        return True

    logging.debug(
        'The file {} contains 0 streamlines after '.format(out_sft_name) +
        step + ' filtering')
    save_tractogram(new_sft, out_sft_name)
    _report_counts(o_dict, args)
    return True


def main():
    """Filter a tractogram with anatomical constraints from a wmparc atlas.

    Steps, in order: length, loops / sharp turns, endings in the CSF, and
    both endings in the cortex or nuclei (optionally dilated cortex). Each
    step can save its kept/rejected streamlines. Streamline counts are
    collected in a dict that can be displayed and/or saved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The optional CSF mask is checked too, so a bad path fails early.
    assert_inputs_exist(parser, [args.in_tractogram, args.in_wmparc],
                        args.csf_bin)
    assert_output_dirs_exist_and_empty(parser, args, args.out_path,
                                       create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Messages now match their conditions (they were swapped): the angle must
    # be strictly positive, the dilation radius may be 0.
    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    if args.ctx_dilation_radius < 0:
        parser.error(
            'Cortex dilation radius "{}" '.format(args.ctx_dilation_radius) +
            'must be greater than or equal to 0')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    img_wmparc = nib.load(args.in_wmparc)
    if not is_header_compatible(img_wmparc, sft):
        parser.error('Headers from the tractogram and the wmparc are '
                     'not compatible.')
    if args.csf_bin:
        img_csf = nib.load(args.csf_bin)
        if not is_header_compatible(img_csf, sft):
            parser.error('Headers from the tractogram and the CSF mask are '
                         'not compatible.')

    if args.minL == 0 and np.isinf(args.maxL):
        logging.debug("You have not specified minL nor maxL. Output will "
                      "not be filtered according to length!")
    if np.isinf(args.angle):
        logging.debug("You have not specified the angle. Loops will "
                      "not be filtered!")
    if args.ctx_dilation_radius == 0:
        logging.debug("You have not specified the cortex dilation radius. "
                      "The wmparc atlas will not be dilated!")

    o_dict = {}
    steps = ['length', 'no_loops', 'no_end_csf', 'end_in_atlas']
    wm_labels = load_wmparc_labels()

    in_sft_name = os.path.splitext(os.path.basename(args.in_tractogram))[0]
    out_sft_rootname = in_sft_name + "_filtered"
    _, ext = os.path.splitext(args.in_tractogram)
    # The root name already ends with "_filtered"; appending it again gave
    # "<name>_filtered_filtered<ext>", which disagreed with the count key.
    out_sft_name = os.path.join(args.out_path, out_sft_rootname + ext)

    # STEP 1 - Filter length
    step = steps[0]
    steps_combined = step

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    # Streamline count before and after filtering lengths
    o_dict[in_sft_name + ext] = {'streamline_count': len(sft.streamlines)}
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    _save_step_intermediate(args, sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext)
    if _stop_on_empty(args, new_sft, o_dict, out_sft_name, step):
        return
    sft = new_sft

    # STEP 2 - Filter loops
    step = steps[1]
    steps_combined += "_" + step

    ids_c = remove_loops_and_sharp_turns(sft.streamlines, args.angle)
    new_sft = filter_tractogram_data(sft, ids_c)

    # Streamline count after filtering loops
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    _save_step_intermediate(args, sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext)
    if _stop_on_empty(args, new_sft, o_dict, out_sft_name, step):
        return
    sft = new_sft

    # STEP 3 - Filter CSF endings
    step = steps[2]
    steps_combined += "_" + step

    # Mask creation
    if args.csf_bin:
        mask = get_data_as_mask(img_csf)
    else:
        atlas = get_data_as_label(img_wmparc)
        mask = binarize_labels(atlas, wm_labels["csf_labels"])

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, mask, 'any', True)

    # Streamline count after filtering CSF endings
    o_dict[in_sft_name + '_' + steps_combined + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        if not args.csf_bin:
            nib.save(
                nib.Nifti1Image(mask, img_wmparc.affine, img_wmparc.header),
                os.path.join(new_path, 'csf_bin' + '.nii.gz'))

    _save_step_intermediate(args, sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext)
    if _stop_on_empty(args, new_sft, o_dict, out_sft_name, step):
        return
    sft = new_sft

    # STEP 4 - Filter WM endings
    step = steps[3]
    steps_combined += "_" + step

    # Mask creation
    ctx_fs_labels = wm_labels["ctx_lh_fs_labels"] + \
        wm_labels["ctx_rh_fs_labels"]
    vox_size = np.reshape(img_wmparc.header.get_zooms(), (1, 3))
    atlas_wm = get_data_as_label(img_wmparc)
    atlas_shape = atlas_wm.shape
    wmparc_ctx = binarize_labels(atlas_wm, ctx_fs_labels)
    wmparc_nuclei = binarize_labels(atlas_wm, wm_labels["nuclei_fs_labels"])

    # Dilation of cortex
    if args.ctx_dilation_radius:
        ctx_mask = dilate_mask(wmparc_ctx, atlas_shape, vox_size,
                               args.ctx_dilation_radius)
    else:
        ctx_mask = wmparc_ctx

    freesurfer_mask = np.zeros(atlas_shape, dtype=np.uint16)
    freesurfer_mask[np.logical_or(wmparc_nuclei, ctx_mask)] = 1

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, freesurfer_mask, 'both_ends', False)

    # Streamline count after final filtering
    o_dict[out_sft_rootname + ext] = \
        {'streamline_count': len(new_sft.streamlines)}

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        nib.save(nib.Nifti1Image(freesurfer_mask, img_wmparc.affine,
                                 img_wmparc.header),
                 os.path.join(new_path, 'atlas_bin' + '.nii.gz'))

    _save_step_intermediate(args, sft, new_sft, o_dict, in_sft_name, step,
                            steps_combined, ext)

    # Finish filtering
    _report_counts(o_dict, args)

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug(
                "The file {} won't be written ".format(out_sft_name) +
                "(0 streamlines after " + step + " filtering).")
            return
        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name)
            + step + ' filtering')

    sft = new_sft
    save_tractogram(sft, out_sft_name)
def _extract_real_labels(data_labels):
    """Return the sorted unique non-zero labels found in ``data_labels``.

    Label 0 is treated as background. Unlike ``np.unique(...)[1:]``, this
    does not drop a genuine label when the volume has no background voxel.
    """
    labels = np.unique(data_labels)
    return labels[labels != 0]


def _is_isotropic(vox_sizes):
    """Return True when every voxel dimension matches the first one.

    Comparing the mean to the first size is not sufficient: (1, 0.5, 1.5)
    has a mean of 1 but is anisotropic. ``np.allclose`` also tolerates tiny
    floating-point differences between dimensions.
    """
    return bool(np.allclose(vox_sizes, vox_sizes[0]))


def _gather_pair_connections(con_info, in_label, out_label):
    """Collect the connection records linking ``in_label`` and ``out_label``.

    Records are looked up in both directions (``con_info[in][out]`` and
    ``con_info[out][in]``). A key missing in one direction does not discard
    the records found in the other. For a self-pair
    (``in_label == out_label``) the single list is used only once, so no
    streamline segment is processed twice.

    Returns a new list; the lists stored in ``con_info`` are not modified.
    """
    pair_info = []
    if in_label in con_info and out_label in con_info[in_label]:
        pair_info.extend(con_info[in_label][out_label])
    if in_label != out_label and out_label in con_info \
            and in_label in con_info[out_label]:
        pair_info.extend(con_info[out_label][in_label])
    return pair_info


def _split_by_kept_ids(streamlines, kept_ids):
    """Split ``streamlines`` into (kept, rejected) lists from kept indices.

    Kept streamlines follow the order of ``kept_ids``; rejected ones are in
    increasing index order (``np.setdiff1d`` returns sorted indices).
    """
    kept = [streamlines[i] for i in kept_ids]
    rejected_ids = np.setdiff1d(np.arange(len(streamlines)), kept_ids)
    rejected = [streamlines[i] for i in rejected_ids]
    return kept, rejected


def _post_process_connection(connecting_streamlines, args, in_label,
                             out_label, vox_size):
    """Run the filtering pipeline on one connection, saving each stage.

    The stages run in order, each skippable through its ``args.no_*`` flag:

    1. length pruning (``_prune_segments``);
    2. loop removal (``remove_loops_and_sharp_turns``);
    3. outlier removal (``remove_outliers``);
    4. a curvature-deviation pass (``remove_loops_and_sharp_turns`` with
       ``use_qb=True``).

    Rejected streamlines go to 'discarded' and survivors of each stage to
    'intermediate', all through ``_save_if_needed``. The output of the last
    stage goes to 'final'. If stage 1, 2 or 3 leaves no streamline, the
    function returns early and nothing further is saved.
    """
    if not args.no_pruning:
        valid_length, invalid_length = _prune_segments(
            connecting_streamlines,
            args.min_length,
            args.max_length,
            vox_size)
        _save_if_needed(invalid_length, args,
                        'discarded', 'invalid_length', in_label, out_label)
    else:
        valid_length = connecting_streamlines

    if not len(valid_length):
        return
    _save_if_needed(valid_length, args,
                    'intermediate', 'valid_length', in_label, out_label)

    if not args.no_remove_loops:
        no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                   args.loop_max_angle)
        no_loops, loops = _split_by_kept_ids(valid_length, no_loop_ids)
        _save_if_needed(loops, args,
                        'discarded', 'loops', in_label, out_label)
    else:
        no_loops = valid_length

    if not len(no_loops):
        return
    _save_if_needed(no_loops, args,
                    'intermediate', 'no_loops', in_label, out_label)

    if not args.no_remove_outliers:
        inliers, outliers = remove_outliers(no_loops,
                                            args.outlier_threshold)
        _save_if_needed(outliers, args,
                        'discarded', 'outliers', in_label, out_label)
    else:
        inliers = no_loops

    if not len(inliers):
        return
    _save_if_needed(inliers, args,
                    'intermediate', 'inliers', in_label, out_label)

    if not args.no_remove_curv_dev:
        no_qb_curv_ids = remove_loops_and_sharp_turns(
            inliers,
            args.loop_max_angle,
            use_qb=True,
            qb_threshold=args.curv_qb_distance)
        no_qb_curv, qb_curv = _split_by_kept_ids(inliers, no_qb_curv_ids)
        _save_if_needed(qb_curv, args,
                        'discarded', 'qb_curv', in_label, out_label)
    else:
        no_qb_curv = inliers

    _save_if_needed(no_qb_curv, args,
                    'final', 'final', in_label, out_label)


def main():
    """Decompose a tractogram into per-label-pair connections.

    The script first validates its inputs, then:

    1. Loads the label image and the tractogram.
    2. Moves the streamlines to voxel/corner space and removes invalid
       streamlines.
    3. Computes which label pairs every streamline connects
       (``compute_connectivity`` with
       ``extract_longest_segments_from_profile``).
    4. For every unordered label pair, including self-pairs, extracts the
       connecting segments, saves them as 'raw' and runs
       ``_post_process_connection`` on them.

    Invalid inputs are reported through ``parser.error``: the output
    directory is the current directory, labels are not integers, or voxels
    are anisotropic.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)

    # Validate the label image before casting its data or writing any
    # output file, so a rejected input leaves nothing behind.
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not _is_isotropic(vox_sizes):
        parser.error('Labels must be isotropic')

    data_labels = img_labels.get_fdata().astype(np.int16)
    real_labels = _extract_real_labels(data_labels)
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    original_len = len(sft)
    time1 = time.time()

    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        '    Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info('    Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))
    nb_pairs = len(comb_list)

    for pair_idx, (in_label, out_label) in enumerate(comb_list):
        if pair_idx > 0 and pair_idx % 100 == 0:
            logging.info('Split {} nodes out of {}'.format(pair_idx,
                                                           nb_pairs))

        pair_info = _gather_pair_connections(con_info, in_label, out_label)
        if not pair_info:
            continue

        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            connecting_streamlines.append(compute_streamline_segment(
                sft.streamlines[strl_idx],
                indices[strl_idx],
                connection['in_idx'],
                connection['out_idx'],
                points_to_idx[strl_idx]))

        _save_if_needed(connecting_streamlines, args,
                        'raw', 'raw', in_label, out_label)

        _post_process_connection(connecting_streamlines, args,
                                 in_label, out_label, vox_sizes[0])

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))