Ejemplo n.º 1
0
def main():
    """Compare connectivity matrices pairwise and save the measures as JSON.

    Each input matrix is min-shifted (and optionally max-normalized).
    With --single_compare, every input matrix is compared against that one
    reference matrix; otherwise all unordered pairs of inputs are compared.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_matrices)
    assert_outputs_exist(parser, args, args.out_json)

    def _load_scaled_matrix(filename):
        # Load a matrix, shift its minimum to 0 and, when --normalize is
        # set, rescale it to [0, 1]. (Was duplicated inline for both the
        # input list and the single-compare matrix.)
        matrix = load_matrix_in_any_format(filename).astype(float)
        matrix -= np.min(matrix)
        if args.normalize:
            return matrix / np.max(matrix)
        return matrix

    all_matrices = [_load_scaled_matrix(f) for f in args.in_matrices]

    if args.single_compare:
        all_matrices.append(_load_scaled_matrix(args.single_compare))

    output_measures_dict = {
        'SSD': [],
        'correlation': [],
        'w_dice_voxels': [],
        'dice_voxels': []
    }

    if args.single_compare:
        # If the reference matrix was also passed in in_matrices, drop the
        # duplicate so it is not compared against itself.
        if args.single_compare in args.in_matrices:
            duplicate_idx = args.in_matrices.index(args.single_compare)
            all_matrices.pop(duplicate_idx)
        # Pair every matrix with the single reference (the last element).
        pairs = list(itertools.product(all_matrices[:-1], [all_matrices[-1]]))
    else:
        # All unique unordered pairs of input matrices.
        pairs = list(itertools.combinations(all_matrices, r=2))

    for matrix_1, matrix_2 in pairs:
        ssd = np.sum((matrix_1 - matrix_2)**2)
        output_measures_dict['SSD'].append(ssd)
        corrcoef = np.corrcoef(matrix_1.ravel(), matrix_2.ravel())
        output_measures_dict['correlation'].append(corrcoef[0][1])
        dice, w_dice = compute_dice_voxel(matrix_1, matrix_2)
        output_measures_dict['dice_voxels'].append(dice)
        output_measures_dict['w_dice_voxels'].append(w_dice)

    with open(args.out_json, 'w') as outfile:
        json.dump(output_measures_dict,
                  outfile,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
Ejemplo n.º 2
0
def compute_dice_overlap_overreach(current_vb_voxels, gt_mask, dimensions):
    """
    Compare a bundle's voxel mask against a ground-truth mask.

    Parameters
    ------
    current_vb_voxels: 3D array
        The voxels touched by at least one streamlines for a given bundle.
    gt_mask: 3D array
        The ground truth mask.
    dimensions: array
        The nibabel dimensions of the data (3D).

    Returns
    -------
    dice: float
        The dice score
    overlap: int
        The overlap (in number of voxels, not as percentages).
    overreach: int
        The overreach (in number of voxels).
    lacking: int
        The number of voxels from the gt_mask that have not been recovered.
    """
    dice = compute_dice_voxel(gt_mask, current_vb_voxels)[0]

    # Voxels covered by both the bundle and the ground truth.
    overlap_voxels = gt_mask * current_vb_voxels

    # Bundle voxels falling outside the ground truth.
    overreach_voxels = np.zeros(dimensions)
    overreach_voxels[(gt_mask == 0) & (current_vb_voxels >= 1)] = 1

    # Ground-truth voxels the bundle failed to reach.
    missed_voxels = np.zeros(dimensions)
    missed_voxels[(gt_mask == 1) & (current_vb_voxels == 0)] = 1

    return (dice,
            np.count_nonzero(overlap_voxels),
            np.count_nonzero(overreach_voxels),
            np.count_nonzero(missed_voxels))
Ejemplo n.º 3
0
def main():
    """Score a tractogram against ground-truth bundles and endpoint ROIs.

    Streamlines are classified as true connections (tc), false
    connections (fc), optionally wrong-path connections (wpc) and
    no connections (nc). Each category is saved as a tractogram in
    ``args.out_dir`` and bundle-wise statistics are written to
    ``results.json``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    # Endpoint ROIs must be provided either as paired --gt_tails/--gt_heads
    # lists or as a single --gt_endpoints option, but not both.
    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # The input extension is reused below to name every output tractogram.
    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    # Remembered so we can verify at the end that every streamline was
    # assigned to exactly one category.
    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions,  = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    # Optional per-bundle criteria passed to extract_true_connections.
    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        # NOTE(review): `sft` is rebound to the returned tractogram --
        # presumably the streamlines not yet classified, so each iteration
        # works on the leftovers of the previous one. Confirm against
        # extract_true_connections.
        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config, length_dict,
            extract_prefix(args.gt_bundles[i]), gt_bundle_inv_masks[i],
            args.dilate_endpoints, args.wrong_path_as_separate)
        nc_streamlines.extend(nc)
        if len(tc_sft) > 0:
            save_tractogram(tc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc), prefix_1, prefix_2))

    # Build every possible head/tail pairing; the true-connection pairs are
    # removed just below, leaving only candidate false connections.
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)), r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for i, roi in enumerate(comb_filename):
        mask_1_filename, mask_2_filename = roi

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)
        _, ext = os.path.splitext(args.in_tractogram)

        fc_sft, sft = extract_false_connections(sft, mask_1_filename,
                                                mask_2_filename,
                                                args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    # Whatever is still unclassified counts as "no connection".
    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft,
                    os.path.join(args.out_dir, "nc{}".format(ext)),
                    bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistic that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)
    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    # Sanity check: the four categories must partition the input tractogram.
    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    # Number of bundles / combinations with at least one streamline.
    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    # Bundle-wise statistics for the true connections.
    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        # Voxel-wise overlap / overreach / lacking against the GT mask.
        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                  & (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1)
                                & (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                      & (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        # Fraction of the ground-truth mask that was recovered.
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] +
                tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0)
                                     & (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantom
    # NOTE(review): fc_streamlines_list also received one entry per bundle
    # in the true-connection loop above, so indexing it by the position of
    # comb_filename here may misalign filenames and streamlines -- verify.
    for i, filename in enumerate(comb_filename):
        current_fc_streamlines = fc_streamlines_list[i]
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}

        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    # Mean overlap percentage across all true-connection bundles.
    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results,
                  f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
Ejemplo n.º 4
0
def compute_all_measures(args):
    """Compute all pairwise similarity measures between two bundles.

    Parameters
    ----------
    args: tuple
        args[0]: pair of (filename, reference) tuples, one per bundle.
        args[1] (streamline_dice): bool, also compute the streamline-wise
            dice (only meaningful when both files contain the exact same
            streamlines).
        args[2] (disable_streamline_distance): bool, skip the
            streamline-based adjacency measure.

    Returns
    -------
    dict or None
        Mapping of measure name to value, or None if either bundle could
        not be loaded.

    Raises
    ------
    ValueError
        If the two references have incompatible headers.
    """
    tuple_1, tuple_2 = args[0]
    filename_1, reference_1 = tuple_1
    filename_2, reference_2 = tuple_2
    streamline_dice = args[1]
    disable_streamline_distance = args[2]

    if not is_header_compatible(reference_1, reference_2):
        raise ValueError('{} and {} have incompatible headers'.format(
            filename_1, filename_2))

    data_tuple_1 = load_data_tmp_saving([filename_1, reference_1, False,
                                         disable_streamline_distance])
    if data_tuple_1 is None:
        return None
    density_1, endpoints_density_1, bundle_1, centroids_1 = data_tuple_1

    data_tuple_2 = load_data_tmp_saving([filename_2, reference_2, False,
                                         disable_streamline_distance])
    if data_tuple_2 is None:
        return None
    density_2, endpoints_density_2, bundle_2, centroids_2 = data_tuple_2

    _, _, voxel_size, _ = get_reference_info(reference_1)
    # Volume of one voxel in mm^3. np.product was a deprecated alias
    # removed in NumPy 2.0; np.prod is the supported spelling.
    voxel_size = np.prod(voxel_size)

    # These measures are in mm^3
    binary_1 = copy.copy(density_1)
    binary_1[binary_1 > 0] = 1
    binary_2 = copy.copy(density_2)
    binary_2[binary_2 > 0] = 1
    volume_overlap = np.count_nonzero(binary_1 * binary_2)
    volume_overlap_endpoints = np.count_nonzero(
        endpoints_density_1 * endpoints_density_2)
    volume_overreach = np.abs(np.count_nonzero(
        binary_1 + binary_2) - volume_overlap)
    volume_overreach_endpoints = np.abs(np.count_nonzero(
        endpoints_density_1 + endpoints_density_2) - volume_overlap_endpoints)

    # These measures are in mm
    bundle_adjacency_voxel = compute_bundle_adjacency_voxel(density_1,
                                                            density_2,
                                                            non_overlap=True)
    if streamline_dice and not disable_streamline_distance:
        # Identical tractograms: no centroids needed.
        bundle_adjacency_streamlines = \
            compute_bundle_adjacency_streamlines(bundle_1,
                                                 bundle_2,
                                                 non_overlap=True)
    elif not disable_streamline_distance:
        bundle_adjacency_streamlines = \
            compute_bundle_adjacency_streamlines(bundle_1,
                                                 bundle_2,
                                                 centroids_1=centroids_1,
                                                 centroids_2=centroids_2,
                                                 non_overlap=True)
    # These measures are between 0 and 1
    dice_vox, w_dice_vox = compute_dice_voxel(density_1, density_2)
    dice_vox_endpoints, w_dice_vox_endpoints = compute_dice_voxel(
        endpoints_density_1,
        endpoints_density_2)
    density_correlation = compute_correlation(density_1, density_2)
    density_correlation_endpoints = compute_correlation(endpoints_density_1,
                                                        endpoints_density_2)

    measures_name = ['bundle_adjacency_voxels',
                     'dice_voxels', 'w_dice_voxels',
                     'volume_overlap',
                     'volume_overreach',
                     'dice_voxels_endpoints',
                     'w_dice_voxels_endpoints',
                     'volume_overlap_endpoints',
                     'volume_overreach_endpoints',
                     'density_correlation',
                     'density_correlation_endpoints']
    measures = [bundle_adjacency_voxel,
                dice_vox, w_dice_vox,
                volume_overlap * voxel_size,
                volume_overreach * voxel_size,
                dice_vox_endpoints,
                w_dice_vox_endpoints,
                volume_overlap_endpoints * voxel_size,
                volume_overreach_endpoints * voxel_size,
                density_correlation,
                density_correlation_endpoints]

    if not disable_streamline_distance:
        # Safe: exactly one of the two branches above ran in this case.
        measures_name += ['bundle_adjacency_streamlines']
        measures += [bundle_adjacency_streamlines]

    # Only when the tractograms are exactly the same
    if streamline_dice:
        dice_streamlines, streamlines_intersect, streamlines_union = \
            compute_dice_streamlines(bundle_1, bundle_2)
        streamlines_count_overlap = len(streamlines_intersect)
        streamlines_count_overreach = len(
            streamlines_union) - len(streamlines_intersect)
        measures_name += ['dice_streamlines',
                          'streamlines_count_overlap',
                          'streamlines_count_overreach']
        measures += [dice_streamlines,
                     streamlines_count_overlap,
                     streamlines_count_overreach]

    return dict(zip(measures_name, measures))