def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.looping_tractogram)
    check_tracts_same_format(parser, [args.in_tractogram, args.out_tractogram,
                                      args.looping_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram)

    streamlines = tractogram.streamlines

    ids_c = []
    ids_l = []
    if len(streamlines) > 1:
        ids_c = remove_loops_and_sharp_turns(
            streamlines, args.angle, use_qb=args.qb,
            qb_threshold=args.threshold)
        ids_l = np.setdiff1d(np.arange(len(streamlines)), ids_c)
    else:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    if len(ids_c) > 0:
        sft_c = filter_tractogram_data(tractogram, ids_c)
        save_tractogram(sft_c, args.out_tractogram)
    else:
        logging.warning(
            'No clean streamlines in {}'.format(args.in_tractogram))

    if args.display_counts:
        sc_bf = len(tractogram.streamlines)
        sc_af = len(ids_c)  # sft_c may be undefined if no streamline is clean
        print(json.dumps({'streamline_count_before_filtering': int(sc_bf),
                          'streamline_count_after_filtering': int(sc_af)},
                         indent=args.indent))

    if len(ids_l) == 0:
        logging.warning('No loops in {}'.format(args.in_tractogram))
    elif args.looping_tractogram:
        sft_l = filter_tractogram_data(tractogram, ids_l)
        save_tractogram(sft_l, args.looping_tractogram)
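
This variant of remove_loops_and_sharp_turns returns the indices of the
streamlines it keeps, so the loops are recovered as the complement of those
ids. A minimal sketch of that partition (partition_by_ids is a hypothetical
helper, named here for illustration):

import numpy as np

def partition_by_ids(n_streamlines, kept_ids):
    # The complement of the kept indices gives the rejected (looping) ones.
    kept_ids = np.asarray(kept_ids)
    rejected_ids = np.setdiff1d(np.arange(n_streamlines), kept_ids)
    return kept_ids, rejected_ids

# e.g. 5 streamlines where ids 0, 2 and 3 survived the angle filter:
# partition_by_ids(5, [0, 2, 3]) -> (array([0, 2, 3]), array([1, 4]))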
Example 2
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser,
                         args,
                         args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser,
        [args.in_tractogram, args.out_tractogram, args.remaining_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = nib.streamlines.load(args.in_tractogram)
    streamlines = tractogram.streamlines

    streamlines_c = []
    loops = []
    if len(streamlines) > 1:
        streamlines_c, loops = remove_loops_and_sharp_turns(
            streamlines, args.angle, args.qb, args.threshold)
    else:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    if len(streamlines_c) > 0:
        tractogram_c = nib.streamlines.Tractogram(streamlines_c,
                                                  affine_to_rasmm=np.eye(4))
        nib.streamlines.save(tractogram_c,
                             args.out_tractogram,
                             header=tractogram.header)
    else:
        logging.warning('No clean streamlines in {}'.format(
            args.in_tractogram))

    if len(loops) == 0:
        logging.warning('No loops in {}'.format(args.in_tractogram))
    elif args.remaining_tractogram:
        tractogram_l = nib.streamlines.Tractogram(loops,
                                                  affine_to_rasmm=np.eye(4))
        nib.streamlines.save(tractogram_l,
                             args.remaining_tractogram,
                             header=tractogram.header)
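
Unlike Example 1, this older API returns the clean streamlines and the loops
themselves rather than indices, and saving goes through nibabel directly.
A minimal sketch of that save pattern, assuming a hypothetical input file
bundle.trk: streamlines from a loaded file are already in world (RASmm)
coordinates, hence the identity affine_to_rasmm, and reusing the input header
preserves the spatial metadata.

import nibabel as nib
import numpy as np

trk = nib.streamlines.load('bundle.trk')   # hypothetical input file
subset = trk.streamlines[:10]

# Already in RASmm, so the tractogram-to-RASmm transform is identity.
out = nib.streamlines.Tractogram(subset, affine_to_rasmm=np.eye(4))
nib.streamlines.save(out, 'subset.trk', header=trk.header)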
Example 3
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSDs/HDDs.
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
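
The resulting HDF5 file carries the reference information as root attributes,
with one group per label pair. A minimal sketch of reading it back with h5py
(the file name is hypothetical; the group naming depends on _save_if_needed):

import h5py

with h5py.File('decomposed.h5', 'r') as hdf5_file:
    # Reference info written as attributes by the script above.
    affine = hdf5_file.attrs['affine']
    dimensions = hdf5_file.attrs['dimensions']
    voxel_sizes = hdf5_file.attrs['voxel_sizes']
    voxel_order = hdf5_file.attrs['voxel_order']
    print(list(hdf5_file.keys()))  # one group per connection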
Example 4
def extract_true_connections(
    sft, mask_1_filename, mask_2_filename, gt_config, length_dict,
    gt_bundle, gt_bundle_inv_mask, dilate_endpoints, wrong_path_as_separate
):
    """
    Extract true connections based on two regions from a tractogram.
    May also extract false connections and no-connections if a config is
    passed.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    gt_config: dict or None
        Dictionary containing the bundle's parameters.
    length_dict: dict or None
        Dictionary containing the bundle's length parameters.
    gt_bundle: str
        Bundle's name.
    gt_bundle_inv_mask: np.ndarray
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.
    wrong_path_as_separate: bool
        If true, save the WPCs as separate from TCs.

    Returns
    -------
    tc_sft: StatefulTractogram
        SFT of true connections.
    wpc_sft: StatefulTractogram
        SFT of wrong-path-connections.
    fc_sft: StatefulTractogram
        SFT of false connections (streamlines outside the length range).
    nc_streamlines: ArraySequence
        Streamlines of no connections (streamlines that loop).
    sft: StatefulTractogram
        SFT of remaining streamlines.
    """

    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    # TODO: Handle streamline IDs instead of streamlines
    tmp_sft, sft = extract_streamlines(mask_1, mask_2, sft)

    streamlines = tmp_sft.streamlines
    tc_streamlines = streamlines
    wpc_streamlines = []
    fc_streamlines = []
    nc_streamlines = []

    # Config file for each 'bundle'
    # Loops => no connection (nc) # TODO Is this legit ?
    # Length => false connection (fc) # TODO Is this legit ?
    if gt_config:
        min_len, max_len = \
            length_dict[gt_bundle]['length']

        # Bring streamlines to world coordinates so proper length
        # is calculated
        tmp_sft.to_rasmm()
        streamlines = tmp_sft.streamlines
        lengths = np.array(list(length(streamlines)))
        tmp_sft.to_vox()
        streamlines = tmp_sft.streamlines

        valid_min_length_mask = lengths > min_len
        valid_max_length_mask = lengths < max_len
        valid_length_mask = np.logical_and(valid_min_length_mask,
                                           valid_max_length_mask)
        streamlines = ArraySequence(streamlines)

        val_len_streamlines = streamlines[valid_length_mask]
        fc_streamlines = streamlines[~valid_length_mask]

        angle = length_dict[gt_bundle]['angle']
        tc_streamlines_ids = remove_loops_and_sharp_turns(
            val_len_streamlines, angle)

        loop_ids = np.setdiff1d(
            range(len(val_len_streamlines)), tc_streamlines_ids)

        loops = val_len_streamlines[list(loop_ids)]
        tc_streamlines = val_len_streamlines[list(tc_streamlines_ids)]

        if loops:
            nc_streamlines = loops

    # Streamlines getting out of the bundle mask can be considered
    # separately as wrong path connection (wpc)
    # TODO: Maybe only consider if they cross another GT bundle ?
    if wrong_path_as_separate:
        tmp_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
        _, wp_ids = filter_grid_roi(
            tmp_sft, gt_bundle_inv_mask, 'any', False)
        wpc_streamlines = tmp_sft.streamlines[list(wp_ids)]
        tc_ids = np.setdiff1d(range(len(tmp_sft)), wp_ids)
        tc_streamlines = tmp_sft.streamlines[list(tc_ids)]

    tc_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
    wpc_sft = StatefulTractogram.from_sft([], sft)
    fc_sft = StatefulTractogram.from_sft(fc_streamlines, sft)
    if wrong_path_as_separate and len(wpc_streamlines):
        wpc_sft = StatefulTractogram.from_sft(wpc_streamlines, sft)

    return tc_sft, wpc_sft, fc_sft, nc_streamlines, sft
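
The length gate above is only meaningful in world coordinates, which is why
the SFT is moved to RASmm before measuring and back to voxel space afterwards.
A minimal sketch of that round-trip, assuming dipy's length function and a
StatefulTractogram whose caller works in voxel space:

import numpy as np
from dipy.tracking.streamline import length

def valid_length_mask(sft, min_len, max_len):
    # Lengths must be computed in mm, not in voxels.
    sft.to_rasmm()
    lengths = np.array(list(length(sft.streamlines)))
    sft.to_vox()  # restore the space the caller works in
    return np.logical_and(lengths > min_len, lengths < max_len)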
Example 5
def extract_vb_vs(sft, head_filename, tail_filename, limits_length, angle,
                  orientation_length, abs_orientation_length,
                  inclusion_inv_mask, dilate_endpoints):
    """
    Extract valid bundle (and valid streamline ids) from a tractogram, based
    on two regions of interest for the endpoints, one region of interest for
    the inclusion of streamlines, and maximum length, maximum angle,
    maximum length per orientation.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    head_filename: str
        Filename of the "head" of the bundle.
    tail_filename: str
        Filename of the "tail" of the bundle.
    limits_length: list or None
        Bundle's length parameters: [min max].
    angle: int or None
        Bundle's max angle.
    orientation_length: list or None
        Bundle's length parameters in each direction:
        [[min_x, max_x], [min_y, max_y], [min_z, max_z]]
    abs_orientation_length: list or None
        Idem, computed on absolute values.
    inclusion_inv_mask: np.ndarray or None
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    vs_ids: list
        Ids of the valid streamlines (the valid bundle).
    wpc_ids: list
        Ids of the wrong-path-connection streamlines.
    bundle_stats: dict
        Streamline counts rejected at each filtering step.
    """
    mask_1_img = nib.load(head_filename)
    mask_2_img = nib.load(tail_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    _, vs_ids = filter_grid_roi_both(sft, mask_1, mask_2)

    wpc_ids = []
    bundle_stats = {"Initial count head to tail": len(vs_ids)}

    # Remove out of inclusion mask (limits_mask)
    if len(vs_ids) > 0 and inclusion_inv_mask is not None:
        tmp_sft = StatefulTractogram.from_sft(sft.streamlines[vs_ids], sft)
        _, out_of_mask_ids_from_vs = filter_grid_roi(tmp_sft,
                                                     inclusion_inv_mask, 'any',
                                                     False)
        out_of_mask_ids = vs_ids[out_of_mask_ids_from_vs]

        bundle_stats.update({"WPC_out_of_mask": len(out_of_mask_ids)})

        # Update ids
        wpc_ids.extend(out_of_mask_ids)
        vs_ids = np.setdiff1d(vs_ids, wpc_ids)

    # Remove invalid lengths
    if len(vs_ids) > 0 and limits_length is not None:
        min_len, max_len = limits_length

        # Bring streamlines to world coordinates so proper length
        # is calculated
        sft.to_rasmm()
        lengths = np.array(list(length(sft.streamlines[vs_ids])))
        sft.to_vox()

        # Compute valid lengths
        valid_length_ids_mask_from_vs = np.logical_and(lengths > min_len,
                                                       lengths < max_len)

        bundle_stats.update(
            {"WPC_invalid_length": int(sum(~valid_length_ids_mask_from_vs))})

        # Update ids
        wpc_ids.extend(vs_ids[~valid_length_ids_mask_from_vs])
        vs_ids = vs_ids[valid_length_ids_mask_from_vs]

    # Remove invalid lengths per orientation
    if len(vs_ids) > 0 and orientation_length is not None:
        # Compute valid lengths
        limits_x, limits_y, limits_z = orientation_length

        _, valid_orientation_ids_from_vs, _ = \
            filter_streamlines_by_total_length_per_dim(
                sft[vs_ids], limits_x, limits_y, limits_z,
                use_abs=False, save_rejected=False)

        # Update ids
        valid_orientation_ids = vs_ids[valid_orientation_ids_from_vs]
        invalid_orientation_ids = np.setdiff1d(vs_ids, valid_orientation_ids)

        bundle_stats.update(
            {"WPC_invalid_orientation": len(invalid_orientation_ids)})

        wpc_ids.extend(invalid_orientation_ids)
        vs_ids = valid_orientation_ids

    # Idem in abs
    if len(vs_ids) > 0 and abs_orientation_length is not None:
        # Compute valid lengths
        limits_x, limits_y, limits_z = abs_orientation_length

        _, valid_orientation_ids_from_vs, _ = \
            filter_streamlines_by_total_length_per_dim(
                sft[vs_ids], limits_x, limits_y,
                limits_z,
                use_abs=True, save_rejected=False)

        # Update ids
        valid_orientation_ids = vs_ids[valid_orientation_ids_from_vs]
        invalid_orientation_ids = np.setdiff1d(vs_ids, valid_orientation_ids)

        bundle_stats.update(
            {"WPC_invalid_orientation_abs": len(invalid_orientation_ids)})

        wpc_ids.extend(invalid_orientation_ids)
        vs_ids = valid_orientation_ids

    # Remove loops from vs
    if len(vs_ids) > 0 and angle is not None:
        # Compute valid angles
        valid_angle_ids_from_vs = remove_loops_and_sharp_turns(
            sft.streamlines[vs_ids], angle)

        # Update ids
        valid_angle_ids = vs_ids[valid_angle_ids_from_vs]
        invalid_angle_ids = np.setdiff1d(vs_ids, valid_angle_ids)

        bundle_stats.update({"WPC_invalid_length": len(invalid_angle_ids)})

        wpc_ids.extend(invalid_angle_ids)
        vs_ids = valid_angle_ids

    bundle_stats.update({"VS": len(vs_ids)})

    return list(vs_ids), list(wpc_ids), bundle_stats
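
Every step above uses the same id bookkeeping: a mask computed on the current
vs_ids splits them into ids that stay valid and ids moved to the wrong-path
bucket. A minimal sketch of the pattern with made-up values:

import numpy as np

vs_ids = np.array([2, 5, 7, 9])
mask = np.array([True, False, True, True])  # e.g. a valid-length mask

wpc_new = vs_ids[~mask]   # -> array([5]), appended to wpc_ids
vs_ids = vs_ids[mask]     # -> array([2, 7, 9]), stays valid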
Example 6
def main():
    parser = build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.tracks, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed considerations,
    # we take the length in voxel space and multiply by the voxel size. For
    # this to work correctly, voxel size must be isotropic.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    if np.min(img_labels.get_data()) < 0 or \
            np.max(img_labels.get_data()) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.tracks)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info('    Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info('    Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, img_labels.get_data(),
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column before
    # saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()
    for in_label in list(final_con_info.keys()):
        for out_label in list(final_con_info[in_label].keys()):
            pair_info = final_con_info[in_label][out_label]

            if not len(pair_info):
                continue

            final_strl = []

            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(
                    compute_streamline_segment(sft.streamlines[strl_idx],
                                               indices[strl_idx],
                                               connection['in_idx'],
                                               connection['out_idx'],
                                               points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths, 'raw',
                            'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(
                    final_strl, args.min_length, args.max_length, vox_sizes[0])

                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length', in_label,
                                out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers', in_label, out_label)

            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers, args.loop_max_angle, True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (e.g. mean FA in the connection).
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info('    Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
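
The nb_labels + 1 allocation is the trick that lets image labels index the
matrix directly (row/column 0 stands for background) before the first row and
column are dropped. A small sketch:

import numpy as np

nb_labels = 3
con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)
con_mat[1, 3] += 5        # image labels used as indices, no -1 offset
final = con_mat[1:, 1:]   # labels 1..3 now map to indices 0..2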
Example 7
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_wmparc])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_path,
                                       create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')
    if args.ctx_dilation_radius < 0:
        parser.error(
            'Cortex dilation radius "{}" '.format(args.ctx_dilation_radius) +
            'must be greater than or equal to 0')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    img_wmparc = nib.load(args.in_wmparc)
    if not is_header_compatible(img_wmparc, sft):
        parser.error('Headers from the tractogram and the wmparc are '
                     'not compatible.')
    if args.csf_bin:
        img_csf = nib.load(args.csf_bin)
        if not is_header_compatible(img_csf, sft):
            parser.error('Headers from the tractogram and the CSF mask are '
                         'not compatible.')

    if args.minL == 0 and np.isinf(args.maxL):
        logging.debug("You have not specified minL nor maxL. Output will "
                      "not be filtered according to length!")
    if np.isinf(args.angle):
        logging.debug("You have not specified the angle. Loops will "
                      "not be filtered!")
    if args.ctx_dilation_radius == 0:
        logging.debug("You have not specified the cortex dilation radius. "
                      "The wmparc atlas will not be dilated!")

    o_dict = {}
    step_dict = ['length', 'no_loops', 'no_end_csf', 'end_in_atlas']
    wm_labels = load_wmparc_labels()

    in_sft_name = os.path.splitext(os.path.basename(args.in_tractogram))[0]
    out_sft_rootname = in_sft_name + "_filtered"
    _, ext = os.path.splitext(args.in_tractogram)
    out_sft_name = os.path.join(args.out_path,
                                out_sft_rootname + ext)

    # STEP 1 - Filter length
    step = step_dict[0]
    steps_combined = step

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    # Streamline count before and after filtering lengths
    o_dict[in_sft_name + ext] =\
        dict({'streamline_count': len(sft.streamlines)})
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')
        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 2 - Filter loops
    step = step_dict[1]
    steps_combined += "_" + step

    ids_c = remove_loops_and_sharp_turns(sft.streamlines, args.angle)
    new_sft = filter_tractogram_data(sft, ids_c)

    # Streamline count after filtering loops
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 3 - Filter CSF endings
    step = step_dict[2]
    steps_combined += "_" + step

    # Mask creation
    if args.csf_bin:
        mask = get_data_as_mask(img_csf)
    else:
        atlas = get_data_as_label(img_wmparc)
        mask = binarize_labels(atlas, wm_labels["csf_labels"])

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, mask, 'any', True)

    # Streamline count after filtering CSF endings
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        if not args.csf_bin:
            nib.save(
                nib.Nifti1Image(mask, img_wmparc.affine, img_wmparc.header),
                os.path.join(new_path, 'csf_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 4 - Filter WM endings
    step = step_dict[3]
    steps_combined += "_" + step

    # Mask creation
    ctx_fs_labels = wm_labels["ctx_lh_fs_labels"] + \
        wm_labels["ctx_rh_fs_labels"]
    vox_size = np.reshape(img_wmparc.header.get_zooms(), (1, 3))
    atlas_wm = get_data_as_label(img_wmparc)
    atlas_shape = atlas_wm.shape
    wmparc_ctx = binarize_labels(atlas_wm, ctx_fs_labels)
    wmparc_nuclei = binarize_labels(atlas_wm, wm_labels["nuclei_fs_labels"])

    # Dilation of cortex
    if args.ctx_dilation_radius:
        ctx_mask = dilate_mask(wmparc_ctx, atlas_shape, vox_size,
                               args.ctx_dilation_radius)
    else:
        ctx_mask = wmparc_ctx

    freesurfer_mask = np.zeros(atlas_shape, dtype=np.uint16)
    freesurfer_mask[np.logical_or(wmparc_nuclei, ctx_mask)] = 1

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, freesurfer_mask, 'both_ends', False)

    # Streamline count after final filtering
    o_dict[out_sft_rootname + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        nib.save(
            nib.Nifti1Image(freesurfer_mask, img_wmparc.affine,
                            img_wmparc.header),
            os.path.join(new_path, 'atlas_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    # Finish filtering
    if args.verbose:
        display_count(o_dict, args.indent, args.sort_keys)

    if args.save_counts:
        save_count(o_dict, args.out_path, args.indent, args.sort_keys)

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")
            return
        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

    sft = new_sft
    save_tractogram(sft, out_sft_name)
Example 8
def main():
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)
    data_labels = img_labels.get_fdata().astype(np.int16)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    data_mask = np.zeros(data_labels.shape)
    data_mask[data_labels > 0] = 1

    original_len = len(sft)
    time1 = time.time()

    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        '    Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info('    Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    for in_label, out_label in comb_list:
        if iteration_counter > 0 and iteration_counter % 100 == 0:
            logging.info('Split {} nodes out of {}'.format(iteration_counter,
                                                           len(comb_list)))
        iteration_counter += 1

        pair_info = []
        if in_label not in con_info:
            continue
        elif out_label in con_info[in_label]:
            pair_info.extend(con_info[in_label][out_label])

        if out_label not in con_info:
            continue
        elif in_label in con_info[out_label]:
            pair_info.extend(con_info[out_label][in_label])

        if not len(pair_info):
            continue

        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            curr_streamlines = compute_streamline_segment(
                sft.streamlines[strl_idx],
                indices[strl_idx],
                connection['in_idx'],
                connection['out_idx'],
                points_to_idx[strl_idx])
            connecting_streamlines.append(curr_streamlines)

        _save_if_needed(connecting_streamlines, args, 'raw',
                        'raw', in_label, out_label)

        # Doing all post-processing
        if not args.no_pruning:
            valid_length, invalid_length = _prune_segments(
                connecting_streamlines,
                args.min_length,
                args.max_length,
                vox_sizes[0])

            _save_if_needed(invalid_length, args,
                            'discarded', 'invalid_length',
                            in_label, out_label)
        else:
            valid_length = connecting_streamlines

        if not len(valid_length):
            continue

        _save_if_needed(valid_length, args,
                        'intermediate', 'valid_length', in_label, out_label)

        if not args.no_remove_loops:
            no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                       args.loop_max_angle)
            no_loops = [valid_length[i] for i in no_loop_ids]

            loop_ids = np.setdiff1d(np.arange(len(valid_length)), no_loop_ids)
            loops = [valid_length[i] for i in loop_ids]

            _save_if_needed(loops, args,
                            'discarded', 'loops', in_label, out_label)
        else:
            no_loops = valid_length

        if not len(no_loops):
            continue

        _save_if_needed(no_loops, args,
                        'intermediate', 'no_loops', in_label, out_label)

        if not args.no_remove_outliers:
            inliers, outliers = remove_outliers(no_loops,
                                                args.outlier_threshold)
            _save_if_needed(outliers, args,
                            'discarded', 'outliers', in_label, out_label)
        else:
            inliers = no_loops

        if not len(inliers):
            continue

        _save_if_needed(inliers, args,
                        'intermediate', 'inliers', in_label, out_label)

        if not args.no_remove_curv_dev:
            no_qb_curv_ids = remove_loops_and_sharp_turns(
                inliers,
                args.loop_max_angle,
                use_qb=True,
                qb_threshold=args.curv_qb_distance)
            no_qb_curv = [inliers[i] for i in no_qb_curv_ids]

            qb_curv_ids = np.setdiff1d(
                np.arange(len(inliers)), no_qb_curv_ids)
            qb_curv = [inliers[i] for i in qb_curv_ids]

            _save_if_needed(qb_curv, args,
                            'discarded', 'qb_curv', in_label, out_label)
        else:
            no_qb_curv = inliers

        _save_if_needed(no_qb_curv, args,
                        'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
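
The comb_list construction above enumerates every unordered label pair once
and then appends the diagonal so self-connections are processed too. A small
sketch of the enumeration:

import itertools
import numpy as np

real_labels = np.array([1, 2, 3])
comb_list = list(itertools.combinations(real_labels, r=2))
comb_list.extend(zip(real_labels, real_labels))
# -> [(1, 2), (1, 3), (2, 3), (1, 1), (2, 2), (3, 3)]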