Exemplo n.º 1
0
def compute_dice_streamlines(bundle_1, bundle_2):
    """
    Compute the streamline-wise Dice coefficient between two bundles.

    Streamlines are matched through intersection_robust, so both bundles
    are expected to originate from the exact same tractogram.

    Parameters
    ----------
    bundle_1: list of ndarray
        First set of streamlines.
    bundle_2: list of ndarray
        Second set of streamlines.

    Returns
    -------
    A tuple containing
        float: Dice score, i.e. twice the number of shared streamlines
            divided by the total number of streamlines in both bundles.
            It is NaN when both bundles are empty.
        list of ndarray: streamlines returned by intersection_robust for
            both bundles.
        list of ndarray: streamlines returned by union_robust for both
            bundles.
    """
    shared_streamlines, _ = intersection_robust([bundle_1, bundle_2])
    merged_streamlines, _ = union_robust([bundle_1, bundle_2])

    total_count = len(bundle_1) + len(bundle_2)
    if total_count == 0:
        # Nothing to compare: the ratio is undefined.
        return np.nan, shared_streamlines, merged_streamlines

    dice = 2 * len(shared_streamlines) / float(total_count)
    return dice, shared_streamlines, merged_streamlines
Exemplo n.º 2
0
def compute_streamlines_measures(args):
    """
    Compute streamline-wise binary classification measures for one bundle.

    The arguments are packed in a single sequence so the function can be
    handed directly to a map-style call.

    Parameters
    ----------
    args: sequence of length 3
        args[0]: pair (bundle_filename, bundle_reference) forwarded to
            load_tractogram.
        args[1]: whole-brain streamlines against which the bundle is
            matched.
        args[2]: gold standard streamlines indices, forwarded to
            binary_classification.

    Returns
    -------
    dict or None
        Dictionary pairing each streamline measure name with the matching
        value returned by binary_classification. None is returned when
        the bundle file does not exist or contains no streamlines.
    """
    bundle_filename, bundle_reference = args[0]
    wb_streamlines = args[1]
    gs_streamlines_indices = args[2]

    if not os.path.isfile(bundle_filename):
        logging.info('{} does not exist'.format(bundle_filename))
        return None

    bundle_sft = load_tractogram(bundle_filename, bundle_reference)
    bundle_sft.to_vox()
    bundle_sft.to_corner()
    bundle_streamlines = bundle_sft.streamlines

    if not bundle_streamlines:
        logging.info('{} is empty'.format(bundle_filename))
        return None

    # Only the indices are needed to classify the streamlines.
    _, streamlines_indices = intersection_robust(
        [wb_streamlines, bundle_streamlines])

    streamlines_binary = binary_classification(streamlines_indices,
                                               gs_streamlines_indices,
                                               len(wb_streamlines))

    measure_names = [
        'sensitivity_streamlines', 'specificity_streamlines',
        'precision_streamlines', 'accuracy_streamlines',
        'dice_streamlines', 'kappa_streamlines', 'youden_streamlines'
    ]
    return dict(zip(measure_names, streamlines_binary))
Exemplo n.º 3
0
def main():
    """
    Fuse the input bundles with a vote across bundles.

    Every input bundle is loaded, moved to voxel space with corner origin
    and checked against the reference header. A binary voxel mask is then
    saved, keeping the voxels with a positive tract count in at least
    int(ratio_voxels * number of bundles) bundles. When same_tractogram
    is set, a tractogram is also saved, keeping the streamlines found in
    at least int(ratio_streamlines * number of bundles) bundles.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    ratios_are_valid = (0 <= args.ratio_voxels <= 1
                        and 0 <= args.ratio_streamlines <= 1)
    if not ratios_are_valid:
        parser.error('Ratios must be between 0 and 1.')

    # Without an explicit reference, the first bundle acts as reference.
    reference_file = args.reference or args.in_bundles[0]
    nb_bundles = len(args.in_bundles)

    sft_list = []
    streamlines_per_bundle = []
    for name in args.in_bundles:
        loaded_sft = load_tractogram_with_reference(parser, args, name)
        loaded_sft.to_vox()
        loaded_sft.to_corner()

        if not is_header_compatible(reference_file, loaded_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(loaded_sft)
        streamlines_per_bundle.append(loaded_sft.streamlines)

    fusion_streamlines, _ = union_robust(streamlines_per_bundle)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    # One row per fused streamline, one column per input bundle.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for column, sft in enumerate(sft_list):
        tract_counts = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[tract_counts > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [column]] += 1

    if args.same_tractogram:
        min_streamline_votes = int(args.ratio_streamlines * nb_bundles)
        votes_per_streamline = np.sum(streamlines_vote, axis=1)
        real_indices = np.where(
            votes_per_streamline >= min_streamline_votes)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    # Binarize the vote volume: voxels under the threshold are discarded.
    min_voxel_votes = int(args.ratio_voxels * nb_bundles)
    volume[volume < min_voxel_votes] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Exemplo n.º 4
0
def _run_measures(measure_func, bundles_references, first_shared,
                  second_shared, nbr_cpu):
    """
    Apply measure_func to every bundle/reference pair.

    Each call receives a 3-item sequence: the bundle/reference pair
    followed by the two shared arguments. With nbr_cpu == 1 the calls are
    made sequentially, otherwise they are distributed on a process pool.

    Returns
    -------
    list
        One result per bundle/reference pair, in input order.
    """
    if nbr_cpu == 1:
        return [measure_func([pair, first_shared, second_shared])
                for pair in bundles_references]

    pool = multiprocessing.Pool(nbr_cpu)
    try:
        return pool.map(
            measure_func,
            zip(bundles_references,
                itertools.repeat(first_shared),
                itertools.repeat(second_shared)))
    finally:
        # Release the workers even when one of the tasks raised.
        pool.close()
        pool.join()


def main():
    """
    Compute binary classification measures for each input bundle and
    write them to a json file.

    Streamline measures are computed when streamlines_measures is given.
    Voxel measures are always computed: the gold standard and tracking
    masks are read from voxels_measures when given, otherwise they are
    derived from the gold standard and whole-brain streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, args.out_json)

    if (not args.streamlines_measures) and (not args.voxels_measures):
        parser.error('At least one of the two modes is needed')

    nbr_cpu = validate_nbr_processes(parser, args)

    all_binary_metrics = []
    bundles_references_tuple_extended = link_bundles_and_reference(
        parser, args, args.in_bundles)

    if args.streamlines_measures:
        # Gold standard related indices are computed once
        wb_sft = load_tractogram_with_reference(parser, args,
                                                args.streamlines_measures[1])
        wb_sft.to_vox()
        wb_sft.to_corner()
        wb_streamlines = wb_sft.streamlines

        gs_sft = load_tractogram_with_reference(parser, args,
                                                args.streamlines_measures[0])
        gs_sft.to_vox()
        gs_sft.to_corner()
        gs_streamlines = gs_sft.streamlines
        _, gs_dimensions, _, _ = gs_sft.space_attributes

        # Prepare the gold standard only once
        _, gs_streamlines_indices = intersection_robust(
            [wb_streamlines, gs_streamlines])

        all_binary_metrics.extend(
            _run_measures(compute_streamlines_measures,
                          bundles_references_tuple_extended,
                          wb_streamlines, gs_streamlines_indices, nbr_cpu))

    if not args.voxels_measures:
        # No mask was provided: the check above guarantees the streamlines
        # mode ran, so the masks are derived from its streamlines.
        gs_binary_3d = compute_tract_counts_map(gs_streamlines, gs_dimensions)
        gs_binary_3d[gs_binary_3d > 0] = 1

        tracking_mask_data = compute_tract_counts_map(wb_streamlines,
                                                      gs_dimensions)
        tracking_mask_data[tracking_mask_data > 0] = 1
    else:
        gs_binary_3d = get_data_as_mask(nib.load(args.voxels_measures[0]))
        gs_binary_3d[gs_binary_3d > 0] = 1
        tracking_mask_data = get_data_as_mask(nib.load(
            args.voxels_measures[1]))
        tracking_mask_data[tracking_mask_data > 0] = 1

    all_binary_metrics.extend(
        _run_measures(compute_voxel_measures,
                      bundles_references_tuple_extended,
                      tracking_mask_data, gs_binary_3d, nbr_cpu))

    # After all processing, write the json file and skip None value
    output_binary_dict = {}
    for binary_dict in all_binary_metrics:
        if binary_dict is None:
            continue
        for measure_name, value in binary_dict.items():
            output_binary_dict.setdefault(measure_name, []).append(
                float(value))

    with open(args.out_json, 'w') as outfile:
        json.dump(output_binary_dict,
                  outfile,
                  indent=args.indent,
                  sort_keys=args.sort_keys)