Example #1
0
def main():
    """Compute the Dice coefficient between two bundles and print it as JSON.

    Both bundles are loaded in the voxel space of ``args.reference`` and
    converted to tract-count maps with ``compute_robust_tract_counts_map``.
    Unless ``args.weighted`` is set, the maps are binarized (``> 0``) before
    being compared with ``_compute_dice``. The printed value is clamped to
    at most 1.0.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle1, args.bundle2, args.reference])

    ref_img = nib.load(args.reference)
    ref_shape = ref_img.header.get_data_shape()
    bundle1_streamlines_vox = load_in_voxel_space(args.bundle1, ref_img)
    bundle2_streamlines_vox = load_in_voxel_space(args.bundle2, ref_img)

    tract_count_map1 = compute_robust_tract_counts_map(bundle1_streamlines_vox,
                                                       ref_shape)
    tract_count_map2 = compute_robust_tract_counts_map(bundle2_streamlines_vox,
                                                       ref_shape)

    if not args.weighted:
        tract_count_map1 = tract_count_map1 > 0
        tract_count_map2 = tract_count_map2 > 0

    # There is no overlap when either map is empty, so report 0 directly
    # instead of only guarding on the second map.
    if np.any(tract_count_map1) and np.any(tract_count_map2):
        dice_coef = _compute_dice(tract_count_map1, tract_count_map2)
    else:
        dice_coef = 0.0

    # Clamp floating-point overshoot, and convert to a builtin float so that
    # json.dumps accepts it whatever numpy scalar type _compute_dice returns.
    dice_coef = min(float(dice_coef), 1.0)

    print(json.dumps({'dice': dice_coef}, indent=args.indent))
Example #2
0
def main():
    """Write a label map assigning each bundle voxel to its closest centroid point.

    The bundle and its single centroid streamline are loaded in the voxel
    space of ``args.reference`` and their coordinates are multiplied by
    ``args.upsample``. Every voxel of the upsampled bundle mask receives the
    label ``k + 1``, where ``k`` is the index returned by
    ``min_dist_to_centroid`` for that voxel (0 is the background value).
    The map is saved to ``args.output_map`` with affine and zooms rescaled by
    the upsampling factor.

    Empty bundles, and centroid files not holding exactly one streamline, are
    skipped with a warning. A centroid with more than 99 points raises.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.bundle, args.centroid_streamline, args.reference])
    assert_outputs_exists(parser, args, [args.output_map])

    bundle_tractogram_file = nib.streamlines.load(args.bundle)
    centroid_tractogram_file = nib.streamlines.load(args.centroid_streamline)
    if int(bundle_tractogram_file.header['nb_streamlines']) == 0:
        logger.warning('Empty bundle file {}. Skipping'.format(args.bundle))
        return

    if int(centroid_tractogram_file.header['nb_streamlines']) != 1:
        logger.warning('Centroid file {} should contain one streamline. '
                       'Skipping'.format(args.centroid_streamline))
        return

    # Validate the centroid before doing any voxel-space conversion work.
    number_of_centroid_points = len(centroid_tractogram_file.streamlines[0])
    if number_of_centroid_points > 99:
        raise Exception('Invalid number of points in the centroid. You should '
                        'have a maximum of 99 points in your centroid '
                        'streamline. '
                        'Current is {}'.format(number_of_centroid_points))

    ref_img = nib.load(args.reference)
    bundle_streamlines_vox = load_in_voxel_space(bundle_tractogram_file,
                                                 ref_img)
    bundle_streamlines_vox._data *= args.upsample

    centroid_streamlines_vox = load_in_voxel_space(centroid_tractogram_file,
                                                   ref_img)
    centroid_streamlines_vox._data *= args.upsample

    upsampled_shape = [s * args.upsample for s in ref_img.shape]
    tdi_mask = compute_robust_tract_counts_map(bundle_streamlines_vox,
                                               upsampled_shape) > 0

    tdi_mask_nzr = np.nonzero(tdi_mask)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)

    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           centroid_streamlines_vox[0])

    # Save the (upscaled) labels mask
    labels_mask = np.zeros(tdi_mask.shape)
    labels_mask[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value
    # Copy first: ``ref_img.affine`` may hand back the image's own array, and
    # the in-place division below must not alter the reference image.
    rescaled_affine = ref_img.affine.copy()
    rescaled_affine[:3, :3] /= args.upsample
    labels_img = nib.Nifti1Image(labels_mask, rescaled_affine)
    upsampled_spacing = ref_img.header['pixdim'][1:4] / args.upsample
    labels_img.header.set_zooms(upsampled_spacing)
    nib.save(labels_img, args.output_map)
Example #3
0
def main():
    """Save a per-voxel count of streamline endpoints and print a summary.

    The first and last points of every streamline are counted in the voxel
    space of ``args.reference``. The resulting map is saved to
    ``args.endpoints_map``. The number of voxels holding at least one
    endpoint is printed as JSON, keyed by the bundle file's base name.
    An empty bundle logs a warning and returns without writing anything.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle, args.reference])
    assert_outputs_exists(parser, args, [args.endpoints_map])

    tractogram_file = nib.streamlines.load(args.bundle)
    nb_streamlines = int(tractogram_file.header['nb_streamlines'])
    if not nb_streamlines:
        logging.warning('Empty bundle file {}. Skipping'.format(args.bundle))
        return

    ref_img = nib.load(args.reference)
    streamlines_vox = load_in_voxel_space(tractogram_file, ref_img)

    endpoints_map = np.zeros(ref_img.shape)
    for streamline in streamlines_vox:
        # Head first, then tail; coordinates are truncated to voxel indices.
        for endpoint in (streamline[0, :], streamline[-1, :]):
            voxel = endpoint.astype(int)
            endpoints_map[voxel[0], voxel[1], voxel[2]] += 1

    endpoints_img = nib.Nifti1Image(endpoints_map, ref_img.affine,
                                    ref_img.header)
    nib.save(endpoints_img, args.endpoints_map)

    bundle_name = os.path.splitext(os.path.basename(args.bundle))[0]
    nb_endpoint_voxels = np.count_nonzero(endpoints_map)
    summary = {bundle_name: {'count': nb_endpoint_voxels}}
    print(json.dumps(summary, indent=args.indent))
def main():
    """Print, as JSON, per-point metric profiles along a bundle.

    Streamlines are loaded in the voxel space of the first metric and
    resampled to ``args.num_points`` points (arc-length parametrization).
    Each is then oriented like the first streamline so that point ``k``
    corresponds across streamlines. For each metric, the mean and standard
    deviation over streamlines at every point are reported, together with
    the raw profile (``tractprofile``). An empty bundle prints ``null`` for
    the bundle entry.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle] + args.metrics)
    if args.num_points <= 1:
        parser.error('--num_points {} needs to be greater than '
                     '1'.format(args.num_points))

    metric_imgs = [nib.load(path) for path in args.metrics]
    assert_same_resolution(*metric_imgs)

    tractogram_file = nib.streamlines.load(args.bundle)
    bundle_name = os.path.splitext(os.path.basename(args.bundle))[0]

    if len(tractogram_file.streamlines) == 0:
        empty_stats = {bundle_name: None}
        print(json.dumps(empty_stats, indent=args.indent,
                         sort_keys=args.sort_keys))
        return

    streamlines_vox = load_in_voxel_space(tractogram_file, metric_imgs[0])
    resampled = subsample_streamlines(streamlines_vox,
                                      num_points=args.num_points,
                                      arc_length=True)

    # Orient every streamline like the first one: point #k of one streamline
    # must be matched with point #k of every other streamline.
    reference_streamline = resampled[0]
    for idx in range(len(resampled)):
        candidate = resampled[idx]
        direct_dist = average_euclidean(reference_streamline, candidate)
        flipped_dist = average_euclidean(reference_streamline,
                                         candidate[::-1])
        if flipped_dist < direct_dist:
            resampled[idx] = candidate[::-1]

    profiles = get_metrics_profile_over_streamlines(resampled, metric_imgs)
    tract_profiles = np.rollaxis(np.expand_dims(profiles, axis=1), 3, 2)

    per_metric = {}
    for img, profile, tract_profile in zip(metric_imgs, profiles,
                                           tract_profiles):
        metric_name = split_name_with_nii(
            os.path.basename(img.get_filename()))[0]
        per_metric[metric_name] = {
            'mean': np.mean(profile, axis=0).tolist(),
            'std': np.std(profile, axis=0).tolist(),
            'tractprofile': tract_profile.tolist()
        }

    print(json.dumps({bundle_name: per_metric}, indent=args.indent,
                     sort_keys=args.sort_keys))
Example #5
0
def main():
    """Print, as JSON, the mean and std of each metric along a bundle.

    The bundle is loaded in the voxel space of the first metric. The
    statistics come from ``get_metrics_stats_over_streamlines_robust``, with
    ``args.density_weighting`` forwarded. Output is keyed by bundle base name,
    then by metric base name.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle] + args.metrics)

    metric_imgs = [nib.load(path) for path in args.metrics]
    assert_same_resolution(*metric_imgs)
    streamlines_vox = load_in_voxel_space(args.bundle, metric_imgs[0])
    per_metric_stats = get_metrics_stats_over_streamlines_robust(
        streamlines_vox, metric_imgs, args.density_weighting)

    bundle_name = os.path.splitext(os.path.basename(args.bundle))[0]

    metric_entries = {}
    for img, (mean, std) in zip(metric_imgs, per_metric_stats):
        metric_name = split_name_with_nii(
            os.path.basename(img.get_filename()))[0]
        metric_entries[metric_name] = {'mean': mean, 'std': std}

    print(json.dumps({bundle_name: metric_entries}, indent=args.indent,
                     sort_keys=args.sort_keys))
def main():
    """Project per-streamline metric means onto the bundle's endpoint voxels.

    For every metric, each streamline's mean value (from
    ``_compute_streamline_mean``) is added to the voxels of its first and
    last points. Each voxel is then divided by the number of endpoints it
    received. One map per metric is written to ``args.output_folder`` as
    ``<metric>_endpoints_metric<ext>``. An empty bundle logs a warning and
    returns without writing anything.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle] + args.metrics)
    assert_outputs_dir_exists_and_empty(parser, args, args.output_folder)

    metrics = [nib.load(metric) for metric in args.metrics]
    assert_same_resolution(*metrics)

    bundle_tractogram_file = nib.streamlines.load(args.bundle)
    if int(bundle_tractogram_file.header['nb_streamlines']) == 0:
        logging.warning('Empty bundle file {}. Skipping'.format(args.bundle))
        return
    bundle_streamlines_vox = load_in_voxel_space(bundle_tractogram_file,
                                                 metrics[0])

    for metric in metrics:
        # img.get_data() is deprecated (expired in nibabel 5);
        # np.asanyarray(img.dataobj) is its documented replacement.
        data = np.asanyarray(metric.dataobj)
        endpoint_metric_map = np.zeros(metric.shape)
        count = np.zeros(metric.shape)
        for streamline in bundle_streamlines_vox:
            streamline_mean = _compute_streamline_mean(streamline, data)

            # Credit the voxel of the first point, then of the last point.
            for point in (streamline[0, :], streamline[-1, :]):
                xyz = point.astype(int)
                endpoint_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
                count[xyz[0], xyz[1], xyz[2]] += 1

        # Average only where endpoints landed; other voxels stay at 0.
        has_endpoints = count != 0
        endpoint_metric_map[has_endpoints] /= count[has_endpoints]
        metric_fname, ext = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        nib.save(
            nib.Nifti1Image(endpoint_metric_map, metric.affine, metric.header),
            os.path.join(args.output_folder,
                         '{}_endpoints_metric{}'.format(metric_fname, ext)))
Example #7
0
def main():
    """Print, as JSON, weighted per-label mean/std of each metric in a bundle.

    Each streamline point has a label (``arr_0`` of ``args.label_map``) and a
    distance to the centroid (``arr_0`` of ``args.distance_map``), in the same
    order as the points of ``streamlines_vox.data``. For every metric and
    every label, the metric values at the points' voxels are averaged with
    these weights:
      * track density (if ``args.density_weighting``), else 1, and
      * times ``1 / distance`` (if ``args.distance_weighting``).
    If the weights sum to 0, an unweighted average is used with a warning.
    An empty bundle prints ``null`` for the bundle entry.

    Raises
    ------
    Exception
        If the label map and distance map do not have the same length.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.bundle, args.label_map, args.distance_map] +
                        args.metrics)

    bundle_tractogram_file = nib.streamlines.load(args.bundle)

    stats = {}
    bundle_name, _ = os.path.splitext(os.path.basename(args.bundle))
    if len(bundle_tractogram_file.streamlines) == 0:
        stats[bundle_name] = None
        print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
        return

    metrics = [nib.load(m) for m in args.metrics]
    assert_same_resolution(*metrics)
    streamlines_vox = load_in_voxel_space(bundle_tractogram_file, metrics[0])

    if args.density_weighting:
        track_count = compute_robust_tract_counts_map(
            streamlines_vox, metrics[0].shape).astype(np.float64)
    else:
        track_count = np.ones(metrics[0].shape)

    # Context managers close the .npz archives once the arrays are read.
    with np.load(args.label_map) as label_file:
        labels = label_file['arr_0']

    with np.load(args.distance_map) as distance_file:
        distances_to_centroid_streamline = distance_file['arr_0']

    if len(labels) != len(distances_to_centroid_streamline):
        raise Exception('Label map doesn\'t contain the same number of '
                        'entries as the distance map. {} != {}'.format(
                            len(labels),
                            len(distances_to_centroid_streamline)))

    # Bigger weight near the centroid streamline.
    # NOTE(review): a zero distance yields an inf weight here — confirm the
    # distance map never contains 0.
    distances_to_centroid_streamline = 1.0 / distances_to_centroid_streamline

    # np.int was an alias of the builtin int and was removed in NumPy 1.24.
    bundle_data_int = streamlines_vox.data.astype(int)
    unique_labels = np.unique(labels)
    stats[bundle_name] = {}
    for metric in metrics:
        # img.get_data() is deprecated (expired in nibabel 5);
        # np.asanyarray(img.dataobj) is its documented replacement.
        metric_data = np.asanyarray(metric.dataobj)
        current_metric_fname, _ = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        stats[bundle_name][current_metric_fname] = {}

        for i in unique_labels:
            number_key = '{:02}'.format(i)
            label_stats = {}
            stats[bundle_name][current_metric_fname][number_key] = label_stats

            in_label = labels == i
            label_indices = bundle_data_int[in_label]
            voxels = (label_indices[:, 0], label_indices[:, 1],
                      label_indices[:, 2])
            label_metric = metric_data[voxels]
            # Fancy indexing returns a copy, so the in-place product below
            # never touches track_count.
            label_weight = track_count[voxels]
            if args.distance_weighting:
                label_weight *= distances_to_centroid_streamline[in_label]
            if np.sum(label_weight) == 0:
                logger.warning('Weights sum to zero, can\'t be normalized. '
                               'Disabling weighting')
                label_weight = None

            label_mean = np.average(label_metric, weights=label_weight)
            label_std = np.sqrt(
                np.average((label_metric - label_mean)**2,
                           weights=label_weight))
            label_stats['mean'] = float(label_mean)
            label_stats['std'] = float(label_std)

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
Example #8
0
def load_tracts_over_grid_transition(tract_fname,
                                     ref_anat_fname,
                                     start_at_corner=True,
                                     tract_producer=None):
    """Load streamlines expressed in voxel coordinates of a reference anatomy.

    Parameters
    ----------
    tract_fname : str
        Path of the tractogram. TCK and TRK files are supported; VTK is not.
    ref_anat_fname : str
        Path of the reference anatomy. For TCK files its best affine is used
        to map world coordinates to voxel indices. For TRK files it is
        forwarded to ``load_in_voxel_space``.
    start_at_corner : bool
        If True, coordinates are shifted so that (0, 0, 0) is the corner of
        the first voxel; otherwise no half-voxel corner shift is applied.
    tract_producer : str or None
        Required for TRK files: ``"scilpy"`` (streamlines respect the NIfTI
        standard) or ``"trackvis"`` ((0, 0, 0) is the corner of the voxel).

    Returns
    -------
    list of ndarray
        For TCK files, one transformed array per streamline.
    streamlines object
        For TRK files, the object returned by ``load_in_voxel_space``, shifted
        in place.

    Raises
    ------
    IOError
        For VTK files or any format other than TCK / TRK.
    ValueError
        For TRK files when ``tract_producer`` is missing or unknown.
    """
    tracts_format = tc.detect_format(tract_fname)
    tracts_file = tracts_format(tract_fname)

    # TODO
    if isinstance(tracts_file, tc.formats.vtk.VTK):
        raise IOError('VTK tracts not currently supported')

    # Load tracts
    if isinstance(tracts_file, tc.formats.tck.TCK):
        # Get information on the supporting anatomy.
        # ``get_header()`` was removed from nibabel; ``.header`` replaces it.
        ref_img = nb.load(ref_anat_fname)
        index_to_world_affine = ref_img.header.get_best_affine()

        # Transposed for efficient computations later on.
        index_to_world_affine = index_to_world_affine.T.astype('<f4')
        world_to_index_affine = linalg.inv(index_to_world_affine)

        shift = 0.5 if start_at_corner else 0.0

        strls = []
        for s in tracts_file:
            # We use c_ to easily transform the 3D streamline to a
            # 4D object to allow using the dot product with uniform
            # coordinates. Basically, this adds a 1 at the end of each point,
            # to be able to directly perform the dot product.
            transformed_s = np.dot(
                c_[s, np.ones([s.shape[0], 1], dtype='<f4')],
                world_to_index_affine)[:, :-1] + shift
            strls.append(transformed_s)

        return strls

    if isinstance(tracts_file, tc.formats.trk.TRK):
        if tract_producer is None:
            raise ValueError('Cannot robustly load TRKs without the '
                             'tract_producer argument.')

        # Shift to apply when the caller wants corner-aligned coordinates.
        # Producer: scilpy means that streamlines respect the nifti standard
        # Producer: trackvis means that (0,0,0) is the corner of the voxel
        corner_shifts = {'scilpy': 0.5, 'trackvis': 0.0}
        if tract_producer not in corner_shifts:
            # Previously this fell through with ``shift`` unbound and crashed
            # with UnboundLocalError after loading the whole file.
            raise ValueError("Unknown tract_producer {!r}; expected 'scilpy' "
                             "or 'trackvis'.".format(tract_producer))

        streamlines = load_in_voxel_space(tract_fname, ref_anat_fname)

        # The previous call returns the streamlines in voxel space,
        # corner-aligned. Check if we need to shift them back.
        shift = corner_shifts[tract_producer]
        if not start_at_corner:
            shift -= 0.5

        streamlines._data += shift

        return streamlines

    # Any other format used to return None silently.
    raise IOError('Unsupported tract format for {}'.format(tract_fname))