Ejemplo n.º 1
0
def main():
    """Split a bundle into inliers and outliers and save each part.

    Command-line arguments come from ``_build_arg_parser()``. The input
    bundle is loaded, ``remove_outliers`` is run on its streamlines with
    ``args.alpha``, and the resulting inlier / outlier selections are saved
    to ``args.out_bundle`` and (when provided) ``args.remaining_bundle``.

    Invalid arguments are reported through ``parser.error``. An empty input
    bundle only triggers a warning and nothing is written.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_bundle, args.remaining_bundle)
    # Chained comparison also rejects NaN, which the two separate
    # comparisons would silently let through.
    if not 0 < args.alpha <= 1:
        parser.error('--alpha should be ]0, 1]')

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    if len(sft) == 0:
        logging.warning("Bundle file contains no streamline")
        return

    check_tracts_same_format(
        parser, [args.in_bundle, args.out_bundle, args.remaining_bundle])
    # remove_outliers is used here as returning (outliers, inliers), both
    # usable to index the tractogram -- presumably arrays of streamline ids.
    outliers, inliers = remove_outliers(sft.streamlines, args.alpha)

    inliers_sft = sft[inliers]
    outliers_sfts = sft[outliers]

    if len(inliers) == 0:
        # Note the explicit space between the two sentences: adjacent string
        # literals are concatenated verbatim.
        logging.warning("All streamlines are considered outliers. "
                        "Please lower the --alpha parameter")
    else:
        save_tractogram(inliers_sft, args.out_bundle)

    if len(outliers) == 0:
        logging.warning("No outlier found. Please raise the --alpha parameter")
    elif args.remaining_bundle:
        # The outlier file is optional; only written when a path was given.
        save_tractogram(outliers_sfts, args.remaining_bundle)
Ejemplo n.º 2
0
def main():
    """Decompose a tractogram into per-connection groups stored in HDF5.

    Steps, in order:

    1. Validate inputs/outputs and configure logging.
    2. Load the label volume and the tractogram, and check that their
       headers are compatible.
    3. Compute, for every streamline, the voxels it traverses
       (``uncompress``) and the label-to-label connectivity information
       (``compute_connectivity``).
    4. For every pair of labels, extract the connecting segments and run
       them through the optional filtering stages (length pruning, loop
       removal, outlier removal, QuickBundles-based curvature removal),
       handing each intermediate/discarded/final set to ``_save_if_needed``.

    Argument errors are reported through ``parser.error``.

    Raises
    ------
    IOError
        If the tractogram and the label image headers are not compatible.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    # Any of the streamline-saving options needs a directory to write into.
    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # The smallest unique value is dropped.
    # NOTE(review): this presumably assumes a background value (0) is always
    # present in the volume -- confirm.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    # Loaded without bounding-box check, then invalid streamlines are dropped.
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        # Fixed message: a space was missing between the second file name
        # and "do not".
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    # All unordered pairs of distinct labels, followed by each label paired
    # with itself.
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        # Store the spatial reference of the tractogram as file attributes.
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSD/HD
        for in_label, out_label in comb_list:
            # Progress report every 100 pairs.
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            # Gather the connections recorded in both directions
            # (in -> out and out -> in). A pair is skipped as soon as either
            # label has no entry at all in con_info.
            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            # Cut each streamline down to the segment joining the two labels,
            # remembering which original streamline it came from.
            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            # Per-streamline data follows the original streamline ids;
            # per-point data is dropped (empty dict) for the cut segments.
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            # Stage 1: length pruning. Ids are relative to raw_sft.
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            # Stage 2: loop removal. Ids are relative to valid_length; the
            # rejected ids are the complement of the kept ones.
            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            # Stage 3: outlier removal. Ids are relative to no_loops.
            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            # Stage 4: same loop/sharp-turn filter as stage 2, this time with
            # the QuickBundles option enabled. Ids are relative to inliers.
            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            # Whatever survived every enabled stage is the final connection.
            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
Ejemplo n.º 3
0
def main():
    """Build a streamline-count connectivity matrix from a tractogram.

    Steps, in order:

    1. Validate inputs, the output directory and the label image (integer
       dtype, isotropic voxels, label values within ``[0, args.max_labels]``).
    2. Load the tractogram and compute, per streamline, the traversed voxels
       (``uncompress``) and the connectivity information
       (``compute_connectivity``), then symmetrize it.
    3. For every pair of labels, extract the connecting segments, run the
       optional filtering stages (length pruning, loop removal, outlier
       removal, second loop removal using QuickBundles) and hand every
       stage's result to ``_save_if_needed``.
    4. Save the matrix of final streamline counts as ``final_matrix.npy``
       in ``args.output``.

    Argument and data errors are reported through ``parser.error``.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.tracks, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed considerations,
    # we take the length in voxel space and multiply by the voxel size. For
    # this to work correctly, voxel size must be isotropic.
    # Every axis is compared to the mean (with a small tolerance): comparing
    # only the first axis to the mean would accept anisotropic grids such as
    # (2, 1, 3) and reject isotropic ones on float rounding.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    # Fetch the label volume once and reuse it below.
    # NOTE(review): get_data() is deprecated in recent nibabel releases;
    # consider np.asanyarray(img_labels.dataobj) when upgrading.
    data_labels = img_labels.get_data()
    if np.min(data_labels) < 0 or np.max(data_labels) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    # NOTE(review): inputs are validated with args.tracks above but loaded
    # from args.in_tractogram here -- confirm both attributes exist on the
    # parser, otherwise one of the two names is wrong.
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info('    Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info('    Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column before
    # saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()
    for in_label in list(final_con_info.keys()):
        for out_label in list(final_con_info[in_label].keys()):
            pair_info = final_con_info[in_label][out_label]

            if not len(pair_info):
                continue

            # Cut each streamline down to the segment joining the two labels.
            final_strl = []

            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(
                    compute_streamline_segment(sft.streamlines[strl_idx],
                                               indices[strl_idx],
                                               connection['in_idx'],
                                               connection['out_idx'],
                                               points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths, 'raw',
                            'raw', in_label, out_label)

            # Doing all post-processing
            # Each stage consumes the survivors of the previous one; a pair
            # is abandoned as soon as nothing survives.
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(
                    final_strl, args.min_length, args.max_length, vox_sizes[0])

                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length', in_label,
                                out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers', in_label, out_label)

            # Second loop-removal pass, with two extra positional arguments
            # (True and the QuickBundles distance).
            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers, args.loop_max_angle, True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (eg: mean FA in the connection.
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info('    Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
Ejemplo n.º 4
0
def main():
    """Decompose a tractogram into per-connection streamline groups.

    Steps, in order:

    1. Validate inputs, the output directory and the label image (integer
       dtype, isotropic voxels), and configure logging.
    2. Load the tractogram, move it to voxel/corner space and drop invalid
       streamlines.
    3. Compute, per streamline, the traversed voxels (``uncompress``) and
       the label-to-label connectivity information
       (``compute_connectivity``).
    4. For every pair of labels, extract the connecting segments and run
       them through the optional filtering stages (length pruning, loop
       removal, outlier removal, QuickBundles-based curvature removal),
       handing each intermediate/discarded/final set to ``_save_if_needed``.

    Argument and data errors are reported through ``parser.error``.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)
    # Validate the on-disk dtype before the volume is cast to integers or
    # anything is written to disk, so invalid input leaves no partial output.
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # NOTE(review): the cast to int16 presumably assumes label values fit in
    # 16 bits -- confirm for large atlases.
    data_labels = img_labels.get_fdata().astype(np.int16)
    # The smallest unique value is dropped.
    # NOTE(review): this presumably assumes a background value (0) is always
    # present in the volume -- confirm.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    # Every axis is compared to the mean (with a small tolerance): comparing
    # only the first axis to the mean would accept anisotropic grids such as
    # (2, 1, 3) and reject isotropic ones on float rounding.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    original_len = len(sft)
    time1 = time.time()

    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        '    Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info('    Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    # All unordered pairs of distinct labels, followed by each label paired
    # with itself.
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    for in_label, out_label in comb_list:
        # Progress report every 100 pairs.
        if iteration_counter > 0 and iteration_counter % 100 == 0:
            logging.info('Split {} nodes out of {}'.format(iteration_counter,
                                                           len(comb_list)))
        iteration_counter += 1

        # Gather the connections recorded in both directions
        # (in -> out and out -> in). A pair is skipped as soon as either
        # label has no entry at all in con_info.
        pair_info = []
        if in_label not in con_info:
            continue
        elif out_label in con_info[in_label]:
            pair_info.extend(con_info[in_label][out_label])

        if out_label not in con_info:
            continue
        elif in_label in con_info[out_label]:
            pair_info.extend(con_info[out_label][in_label])

        if not len(pair_info):
            continue

        # Cut each streamline down to the segment joining the two labels.
        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            curr_streamlines = compute_streamline_segment(
                sft.streamlines[strl_idx],
                indices[strl_idx],
                connection['in_idx'],
                connection['out_idx'],
                points_to_idx[strl_idx])
            connecting_streamlines.append(curr_streamlines)

        _save_if_needed(connecting_streamlines, args, 'raw',
                        'raw', in_label, out_label)

        # Doing all post-processing
        # Each stage consumes the survivors of the previous one; a pair is
        # abandoned as soon as nothing survives.
        if not args.no_pruning:
            valid_length, invalid_length = _prune_segments(
                connecting_streamlines,
                args.min_length,
                args.max_length,
                vox_sizes[0])

            _save_if_needed(invalid_length, args,
                            'discarded', 'invalid_length',
                            in_label, out_label)
        else:
            valid_length = connecting_streamlines

        if not len(valid_length):
            continue

        _save_if_needed(valid_length, args,
                        'intermediate', 'valid_length', in_label, out_label)

        if not args.no_remove_loops:
            no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                       args.loop_max_angle)
            no_loops = [valid_length[i] for i in no_loop_ids]

            # Rejected ids are the complement of the kept ones.
            loop_ids = np.setdiff1d(np.arange(len(valid_length)), no_loop_ids)
            loops = [valid_length[i] for i in loop_ids]

            _save_if_needed(loops, args,
                            'discarded', 'loops', in_label, out_label)
        else:
            no_loops = valid_length

        if not len(no_loops):
            continue

        _save_if_needed(no_loops, args,
                        'intermediate', 'no_loops', in_label, out_label)

        if not args.no_remove_outliers:
            inliers, outliers = remove_outliers(no_loops,
                                                args.outlier_threshold)
            _save_if_needed(outliers, args,
                            'discarded', 'outliers', in_label, out_label)
        else:
            inliers = no_loops

        if not len(inliers):
            continue

        _save_if_needed(inliers, args,
                        'intermediate', 'inliers', in_label, out_label)

        # Same loop/sharp-turn filter as above, this time with the
        # QuickBundles option enabled.
        if not args.no_remove_curv_dev:
            no_qb_curv_ids = remove_loops_and_sharp_turns(
                inliers,
                args.loop_max_angle,
                use_qb=True,
                qb_threshold=args.curv_qb_distance)
            no_qb_curv = [inliers[i] for i in no_qb_curv_ids]

            qb_curv_ids = np.setdiff1d(
                np.arange(len(inliers)), no_qb_curv_ids)
            qb_curv = [inliers[i] for i in qb_curv_ids]

            _save_if_needed(qb_curv, args,
                            'discarded', 'qb_curv', in_label, out_label)
        else:
            no_qb_curv = inliers

        # Whatever survived every enabled stage is the final connection.
        _save_if_needed(no_qb_curv, args,
                        'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))