def main():
    """Recognize a bundle in a whole-brain tractogram from a model bundle.

    The model streamlines are moved with the matrix read from
    ``args.in_transfo`` (inverted when ``args.inverse`` is set), then
    ``RecoBundles`` is run on the whole-brain streamlines. The selected
    streamlines, with their per-streamline and per-point data, are saved to
    ``args.out_tractogram``. Saving is skipped only when ``args.no_empty`` is
    set and the selection is empty.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The model is loaded below, so it has to be validated like the others.
    assert_inputs_exist(parser, [args.in_tractogram, args.in_transfo,
                                 args.in_model])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    transfo = load_matrix_in_any_format(args.in_transfo)
    if args.inverse:
        # Invert the matrix already in memory instead of re-reading the file.
        transfo = np.linalg.inv(transfo)

    before, after = compute_distance_barycenters(wb_file, model_file, transfo)
    if after > before:
        logging.warning('The distance between volumes barycenter should be '
                        'lower after registration. Maybe try using/removing '
                        '--inverse.')
        logging.info('Distance before: {}, Distance after: {}'.format(
            np.round(before, 3), np.round(after, 3)))
    model_streamlines = transform_streamlines(model_file.streamlines, transfo)

    rng = np.random.RandomState(args.seed)
    if args.in_pickle:
        # NOTE(security): pickle.load can execute arbitrary code; only feed
        # this option with pickle files produced by a trusted run.
        with open(args.in_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.tractogram_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.out_pickle:
        # Keep the clustering so a later run can reuse it through in_pickle.
        with open(args.out_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(ArraySequence(model_streamlines),
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file.space_attributes,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.out_tractogram)
예제 #2
0
def main():
    """Register a moving tractogram on a static one and save the matrix.

    ``whole_brain_slr`` is run with a rigid or affine model depending on
    ``args.only_rigid``; the resulting transformation is written as text to
    a file named after ``args.out_name`` with a ``_rigid.txt`` or
    ``_affine.txt`` suffix.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.moving_tractogram,
                                 args.static_tractogram])

    # One keyword drives both the output suffix and the SLR model.
    transformation_type = 'rigid' if args.only_rigid else 'affine'
    out_base = os.path.splitext(args.out_name)[0]
    matrix_filename = '{}_{}.txt'.format(out_base, transformation_type)

    assert_outputs_exist(parser, args, matrix_filename, args.out_name)

    sft_moving = load_tractogram_with_reference(
        parser, args, args.moving_tractogram, arg_name='moving_tractogram')
    sft_static = load_tractogram_with_reference(
        parser, args, args.static_tractogram, arg_name='static_tractogram')

    # Only the second returned element (the transformation) is kept.
    _, transfo, _, _ = whole_brain_slr(sft_moving.streamlines,
                                       sft_static.streamlines,
                                       x0=transformation_type,
                                       maxiter=150,
                                       verbose=args.verbose)

    np.savetxt(matrix_filename, transfo)
def main():
    """Load two tractograms and delegate their registration.

    The matrix filename is derived from ``args.out_name`` with a
    ``_rigid.npy`` or ``_affine.npy`` suffix and handed, together with both
    tractograms, to ``register_tractogram``.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.moving_tractogram, args.static_tractogram])

    suffix = '_rigid.npy' if args.only_rigid else '_affine.npy'
    matrix_filename = os.path.splitext(args.out_name)[0] + suffix

    assert_outputs_exist(parser, args, matrix_filename, args.out_name)

    sft_moving = load_tractogram_with_reference(
        parser, args, args.moving_tractogram,
        bbox_check=True, argName='moving_tractogram')
    sft_static = load_tractogram_with_reference(
        parser, args, args.static_tractogram,
        bbox_check=True, argName='static_tractogram')

    register_tractogram(sft_moving, sft_static, args.only_rigid,
                        args.amount_to_load, matrix_filename, args.verbose)
예제 #4
0
def main():
    """Recognize a bundle in a whole-brain tractogram from a model bundle.

    The model streamlines are moved with the matrix read from
    ``args.transformation`` (inverted when ``args.inverse`` is set), then
    ``RecoBundles`` is run on the whole-brain streamlines. The selected
    streamlines and their attached data are saved to ``args.output_name``
    unless ``args.no_empty`` is set and the selection is empty.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    # The model is loaded below, so it has to be validated like the others.
    assert_inputs_exist(parser, [args.in_tractogram, args.transformation,
                                 args.in_model])
    assert_outputs_exist(parser, args, args.output_name)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    # Default transformation source is expected to be ANTs
    transfo = np.loadtxt(args.transformation)
    if args.inverse:
        # Invert the matrix already in memory instead of re-reading the file.
        transfo = np.linalg.inv(transfo)

    model_streamlines = ArraySequence(
        transform_streamlines(model_file.streamlines, transfo))

    rng = np.random.RandomState(args.seed)
    if args.input_pickle:
        # NOTE(security): pickle.load can execute arbitrary code; only feed
        # this option with pickle files produced by a trusted run.
        with open(args.input_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.wb_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.output_pickle:
        # Keep the clustering so a later run can reuse it via input_pickle.
        with open(args.output_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(model_streamlines,
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.output_name)
def main():
    """Write an upsampled label map of a bundle relative to its centroid.

    Both tractograms are moved to voxel space and scaled by
    ``args.upsample``. Every voxel crossed by the bundle receives the index
    returned by ``min_dist_to_centroid`` for the first centroid streamline,
    plus one (0 is kept for the background). The map is saved to
    ``args.out_map`` with an affine and zooms divided by the upsampling
    factor.

    Raises
    ------
    ValueError
        If the bundle or the centroid file contains no streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.in_bundle, args.in_centroid],
                        optional=args.reference)
    assert_outputs_exist(parser, args, args.out_map)

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        msg = 'Empty bundle file {}. Skipping'.format(args.in_bundle)
        logging.error(msg)
        raise ValueError(msg)

    if not len(sft_centroid.streamlines):
        msg = ('Centroid file {} should contain one streamline. '
               'Skipping'.format(args.in_centroid))
        logging.error(msg)
        raise ValueError(msg)

    sft_bundle.to_vox()
    bundle_streamlines_vox = sft_bundle.streamlines
    # Scale the coordinates in place to index the upsampled grid.
    bundle_streamlines_vox._data *= args.upsample

    sft_centroid.to_vox()
    centroid_streamlines_vox = sft_centroid.streamlines
    centroid_streamlines_vox._data *= args.upsample

    upsampled_shape = [s * args.upsample for s in sft_bundle.dimensions]
    tdi_mask = compute_tract_counts_map(bundle_streamlines_vox,
                                        upsampled_shape) > 0

    tdi_mask_nzr = np.nonzero(tdi_mask)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)

    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           centroid_streamlines_vox[0])

    # Save the (upscaled) labels mask
    labels_mask = np.zeros(tdi_mask.shape)
    labels_mask[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value
    # Work on a copy: dividing in place would otherwise alter the array
    # handed back by the tractogram object itself.
    rescaled_affine = np.array(sft_bundle.affine)
    rescaled_affine[:3, :3] /= args.upsample
    labels_img = nib.Nifti1Image(labels_mask, rescaled_affine)
    upsampled_spacing = sft_bundle.voxel_sizes / args.upsample
    labels_img.header.set_zooms(upsampled_spacing)
    nib.save(labels_img, args.out_map)
예제 #6
0
def main():
    """Fuse several bundles by voxel and (optionally) streamline voting.

    Each input bundle votes for the voxels it crosses; voxels whose vote
    count reaches ``int(args.ratio_voxels * number of bundles)`` are kept in
    the binary mask written to ``<output_prefix>voxels.nii.gz``. When
    ``args.same_tractogram`` is set, streamlines are voted on the same way
    with ``args.ratio_streamlines`` and saved to
    ``<output_prefix>streamlines.trk``.

    Raises
    ------
    ValueError
        If an input bundle header is not compatible with the reference.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not (0 <= args.ratio_voxels <= 1 and 0 <= args.ratio_streamlines <= 1):
        parser.error('Ratios must be between 0 and 1.')

    # Fall back on the first bundle when no explicit reference is given.
    reference_file = args.reference if args.reference else args.in_bundles[0]
    nb_bundles = len(args.in_bundles)

    sft_list = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)

    fusion_streamlines, _ = union_robust(
        [loaded.streamlines for loaded in sft_list])

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    # One row per fused streamline, one column per input bundle.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, sft in enumerate(sft_list):
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * nb_bundles)
        vote_totals = np.sum(streamlines_vote, axis=1)
        real_indices = np.where(vote_totals >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    min_voxel_votes = int(args.ratio_voxels * nb_bundles)
    volume[volume < min_voxel_votes] = 0
    volume[volume > 0] = 1
    out_img = nib.Nifti1Image(volume.astype(np.uint8), transformation)
    nib.save(out_img, output_voxels_filename)
def main():
    """Remove invalid streamlines from a tractogram and save the result.

    The tractogram is loaded without bounding-box check and cleaned with
    ``remove_invalid_streamlines``. Optionally, streamlines with at most one
    point (``args.remove_single_point``) and streamlines holding a step
    whose gradient norm is below 0.001 (``args.remove_overlapping_points``)
    are dropped as well. The number of removed streamlines is logged.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram, args.reference)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    ori_len = len(sft)
    sft.remove_invalid_streamlines()

    indices = []
    if args.remove_single_point:
        # Will try to do a PR in Dipy
        indices = [i for i, streamline in enumerate(sft.streamlines)
                   if len(streamline) <= 1]

    if args.remove_overlapping_points:
        for i, streamline in enumerate(sft.streamlines):
            if len(streamline) < 2:
                # np.gradient needs at least two points; such streamlines
                # are only handled by the single-point option above.
                continue
            norm = np.linalg.norm(np.gradient(streamline, axis=0), axis=1)
            if (norm < 0.001).any():
                indices.append(i)

    # np.arange keeps an integer dtype even when the tractogram is empty,
    # which a plain range() converted by numpy does not.
    indices = np.setdiff1d(np.arange(len(sft)), indices)
    new_sft = StatefulTractogram.from_sft(
        sft.streamlines[indices],
        sft,
        data_per_point=sft.data_per_point[indices],
        data_per_streamline=sft.data_per_streamline[indices])
    logging.warning('Removed {} invalid streamlines.'.format(ori_len -
                                                             len(new_sft)))
    save_tractogram(new_sft, args.out_tractogram)
예제 #8
0
def main():
    """Print the mean and std of each metric over a bundle as JSON.

    The bundle is moved to voxel space with a corner origin before
    ``get_bundle_metrics_mean_std`` is called. The printed dictionary is
    keyed by the bundle file name (extension removed), then by the metric
    file name as returned by ``split_name_with_nii``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.metrics,
                        optional=args.reference)

    assert_same_resolution(args.metrics)
    metrics = [nib.load(metric) for metric in args.metrics]

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()

    bundle_stats = get_bundle_metrics_mean_std(sft.streamlines,
                                               metrics,
                                               args.density_weighting)

    bundle_name = os.path.splitext(os.path.basename(args.in_bundle))[0]

    # bundle_stats is consumed in the same order as the loaded metrics.
    per_metric = {
        split_name_with_nii(os.path.basename(metric.get_filename()))[0]: {
            'mean': mean,
            'std': std
        }
        for metric, (mean, std) in zip(metrics, bundle_stats)
    }
    stats = {bundle_name: per_metric}

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
예제 #9
0
def main():
    """Compute a lobe-specific metric map along a bundle and save it.

    The Bingham coefficients and the lobe metric images are passed to
    ``lobe_specific_metric_map_along_streamlines``; its result is saved as
    float32 to ``args.out_mean_map`` with the Bingham image affine.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.in_bundle, args.in_bingham, args.in_lobe_metric])
    assert_outputs_exist(parser, args, [args.out_mean_map])

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    bingham_img = nib.load(args.in_bingham)
    metric_img = nib.load(args.in_lobe_metric)

    # Second-to-last Bingham axis must match the last metric axis.
    if bingham_img.shape[-2] != metric_img.shape[-1]:
        parser.error('Dimension mismatch between Bingham coefficients '
                     'and lobe-specific metric image.')

    bingham_data = bingham_img.get_fdata()
    metric_data = metric_img.get_fdata()
    metric_mean_map = lobe_specific_metric_map_along_streamlines(
        sft, bingham_data, metric_data, args.max_theta, args.length_weighting)

    out_img = nib.Nifti1Image(metric_mean_map.astype(np.float32),
                              bingham_img.affine)
    out_img.to_filename(args.out_mean_map)
예제 #10
0
def main():
    """Cluster a tractogram with ``qbx_and_merge`` and save the results.

    The thresholds ``[40, 30, 20, args.dist_thresh]`` are passed to
    ``qbx_and_merge``. When ``args.output_clusters_dir`` is given, each
    cluster is written there as ``cluster_<i>.trk``; when
    ``args.output_centroids`` is given, the cluster centroids are saved to
    that file.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.output_centroids)
    if args.output_clusters_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.output_clusters_dir,
                                           create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines,
                             thresholds,
                             nb_pts=args.nb_points,
                             verbose=False)

    # Without a clusters directory there is nowhere to write the per-cluster
    # files (os.path.join would fail on None), so only write them on request.
    if args.output_clusters_dir:
        for i, cluster in enumerate(clusters):
            if len(cluster.indices) > 1:
                cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
            else:
                # itemgetter with a single index returns a bare item rather
                # than a tuple, hence the dedicated indexing.
                cluster_streamlines = streamlines[cluster.indices]

            new_sft = StatefulTractogram(cluster_streamlines, sft,
                                         Space.RASMM)
            save_tractogram(
                new_sft,
                os.path.join(args.output_clusters_dir,
                             'cluster_{}.trk'.format(i)))

    if args.output_centroids:
        new_sft = StatefulTractogram(clusters.centroids, sft, Space.RASMM)
        save_tractogram(new_sft, args.output_centroids)
def main():
    """Apply a transformation (and optional deformation) to a tractogram.

    The moving tractogram is loaded without bounding-box check and handed to
    ``transform_warp_streamlines`` with the matrix from ``args.in_transfo``.
    With ``args.keep_invalid`` the result is saved without bounding-box
    validation (a warning is logged if it is invalid); otherwise it is saved
    with the default checks.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser,
        [args.in_moving_tractogram, args.in_target_file, args.in_transfo],
        args.in_deformation)
    assert_outputs_exist(parser, args, args.out_tractogram)

    moving_sft = load_tractogram_with_reference(parser,
                                                args,
                                                args.in_moving_tractogram,
                                                bbox_check=False)

    transfo = load_matrix_in_any_format(args.in_transfo)

    if args.in_deformation is None:
        deformation_data = None
    else:
        deformation_img = nib.load(args.in_deformation)
        deformation_data = np.squeeze(
            deformation_img.get_fdata(dtype=np.float32))

    new_sft = transform_warp_streamlines(moving_sft,
                                         transfo,
                                         args.in_target_file,
                                         inverse=args.inverse,
                                         deformation_data=deformation_data,
                                         remove_invalid=args.remove_invalid,
                                         cut_invalid=args.cut_invalid)

    if not args.keep_invalid:
        save_tractogram(new_sft, args.out_tractogram)
        return

    # The user asked to keep everything: warn, then bypass validation.
    if not new_sft.is_bbox_in_vox_valid():
        logging.warning('Saving tractogram with invalid streamlines.')
    save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
def main():
    """Filter a tractogram by streamline length and save the result.

    ``filter_streamlines_by_length`` is called with ``args.minL`` and
    ``args.maxL``. When nothing is left and ``args.no_empty`` is set, no
    file is written.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    filtered = filter_streamlines_by_length(sft, args.minL, args.maxL)
    new_streamlines, new_per_point, new_per_streamline = filtered

    new_sft = StatefulTractogram(new_streamlines,
                                 sft,
                                 Space.RASMM,
                                 data_per_streamline=new_per_streamline,
                                 data_per_point=new_per_point)

    nothing_left = not new_streamlines
    if nothing_left and args.no_empty:
        logging.debug("The file {} won't be written "
                      "(0 streamline).".format(args.out_tractogram))
        return

    if nothing_left:
        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    save_tractogram(new_sft, args.out_tractogram)
def main():
    """Assign a single RGB color to every point of a tractogram.

    ``args.color`` is expected as ``#RRGGBB`` or ``0xRRGGBB``. The color is
    stored in ``data_per_point["color"]`` and the tractogram is saved to
    ``args.out_tractogram``, which must end with ``.trk``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not args.out_tractogram.endswith('.trk'):
        parser.error('Output file needs to end with .trk.')

    color_error = ('Hexadecimal RGB color should be formatted as "#RRGGBB"'
                   ' or 0xRRGGBB.')

    if len(args.color) == 7:
        args.color = '0x' + args.color.lstrip('#')

    # Require the 0x prefix explicitly: int(x, 0) would otherwise accept
    # decimal or binary literals that happen to be 8 characters long.
    if len(args.color) == 8 and args.color[:2].lower() == '0x':
        try:
            color_int = int(args.color, 16)
        except ValueError:
            # Non-hexadecimal digits: report a usage error, not a traceback.
            parser.error(color_error)
        red = color_int >> 16
        green = (color_int & 0x00FF00) >> 8
        blue = color_int & 0x0000FF
    else:
        parser.error(color_error)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # One (nb_points, 3) array per streamline, every row being the color.
    sft.data_per_point["color"] = [np.tile([red, green, blue],
                                           (len(i), 1)) for i in sft.streamlines]

    sft = StatefulTractogram.from_sft(sft.streamlines, sft,
                                      data_per_point=sft.data_per_point)

    save_tractogram(sft, args.out_tractogram)
예제 #14
0
def main():
    """Smooth every streamline of a tractogram and save the result.

    Each streamline is smoothed with ``smooth_line_gaussian`` when
    ``args.gaussian`` is set, otherwise with ``smooth_line_spline`` using
    the two values of ``args.spline``. When ``args.error_rate`` is set, the
    smoothed streamline is additionally passed to ``compress_streamlines``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    smoothed_streamlines = []
    for streamline in sft.streamlines:
        if args.gaussian:
            tmp_streamlines = smooth_line_gaussian(streamline, args.gaussian)
        else:
            tmp_streamlines = smooth_line_spline(streamline, args.spline[0],
                                                 args.spline[1])

        if args.error_rate:
            tmp_streamlines = compress_streamlines(tmp_streamlines,
                                                   args.error_rate)

        # Always keep the streamline: compression is optional, and dropping
        # it would desynchronize the output from data_per_streamline.
        smoothed_streamlines.append(tmp_streamlines)

    smoothed_sft = StatefulTractogram.from_sft(
        smoothed_streamlines, sft, data_per_streamline=sft.data_per_streamline)
    save_tractogram(smoothed_sft, args.out_tractogram)
def main():
    """Compute a streamline count map of a bundle and save it as int16.

    The bundle is moved to voxel space with a corner origin and counted
    with ``compute_tract_counts_map``. When ``args.binary`` is given, every
    non-zero voxel is set to that value instead of its count.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle, optional=args.reference)
    assert_outputs_exist(parser, args, args.out_img)

    # The map is stored as int16, so the binary value must fit in it.
    max_ = np.iinfo(np.int16).max
    binary_value = args.binary
    if binary_value is not None and (binary_value <= 0
                                     or binary_value > max_):
        parser.error(
            'The value of --binary ({}) '
            'must be greater than 0 and smaller or equal to {}'.format(
                args.binary, max_))

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()
    transformation, dimensions, _, _ = sft.space_attributes

    streamline_count = compute_tract_counts_map(sft.streamlines, dimensions)

    if args.binary is not None:
        crossed = streamline_count > 0
        streamline_count[crossed] = args.binary

    count_img = nib.Nifti1Image(streamline_count.astype(np.int16),
                                transformation)
    nib.save(count_img, args.out_img)
예제 #16
0
def main():
    """Split a bundle into inliers and outliers with ``remove_outliers``.

    Inliers are saved to ``args.out_bundle``; outliers are saved to
    ``args.remaining_bundle`` when that option is given. A warning is
    logged when either group is empty, and nothing is done on an empty
    input bundle.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_bundle, args.remaining_bundle)
    if args.alpha <= 0 or args.alpha > 1:
        parser.error('--alpha should be ]0, 1]')

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    if not len(sft):
        logging.warning("Bundle file contains no streamline")
        return

    check_tracts_same_format(
        parser, [args.in_bundle, args.out_bundle, args.remaining_bundle])
    outliers, inliers = remove_outliers(sft.streamlines, args.alpha)

    inliers_sft = sft[inliers]
    outliers_sfts = sft[outliers]

    if len(inliers) > 0:
        save_tractogram(inliers_sft, args.out_bundle)
    else:
        logging.warning("All streamlines are considered outliers."
                        "Please lower the --alpha parameter")

    if len(outliers) > 0:
        if args.remaining_bundle:
            save_tractogram(outliers_sfts, args.remaining_bundle)
    else:
        logging.warning("No outlier found. Please raise the --alpha parameter")
예제 #17
0
def load_data(parser, args, path):
    """Load a tractogram and return its streamlines with their data.

    Parameters
    ----------
    parser, args
        Forwarded to ``load_tractogram_with_reference``.
    path
        Path of the tractogram to load.

    Returns
    -------
    tuple
        The streamlines as a list, followed by the tractogram's
        ``data_per_streamline`` and ``data_per_point``.
    """
    logging.info('Loading streamlines from {0}.'.format(path))
    tractogram = load_tractogram_with_reference(parser, args, path)

    return (list(tractogram.streamlines),
            tractogram.data_per_streamline,
            tractogram.data_per_point)
def main():
    """Split a tractogram into clean and looping streamlines.

    ``remove_loops_and_sharp_turns`` returns the indices of the streamlines
    to keep; they are saved to ``args.out_tractogram``. The other
    streamlines are saved to ``args.looping_tractogram`` when that option
    is given. With ``args.display_counts``, the streamline counts before
    and after filtering are printed as JSON.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.looping_tractogram)
    check_tracts_same_format(parser, [args.in_tractogram, args.out_tractogram,
                                      args.looping_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram)

    streamlines = tractogram.streamlines

    ids_c = []

    ids_l = []

    if len(streamlines) > 1:
        ids_c = remove_loops_and_sharp_turns(
            streamlines, args.angle, use_qb=args.qb,
            qb_threshold=args.threshold)
        # Everything that was not kept is considered a loop.
        ids_l = np.setdiff1d(np.arange(len(streamlines)), ids_c)
    else:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    if len(ids_c) > 0:
        sft_c = filter_tractogram_data(tractogram, ids_c)
        save_tractogram(sft_c, args.out_tractogram)
    else:
        logging.warning(
            'No clean streamlines in {}'.format(args.in_tractogram))

    if args.display_counts:
        sc_bf = len(tractogram.streamlines)
        # Count from the kept indices: the filtered tractogram is only
        # created when at least one clean streamline exists.
        sc_af = len(ids_c)
        print(json.dumps({'streamline_count_before_filtering': int(sc_bf),
                         'streamline_count_after_filtering': int(sc_af)},
                         indent=args.indent))

    if len(ids_l) == 0:
        logging.warning('No loops in {}'.format(args.in_tractogram))
    elif args.looping_tractogram:
        sft_l = filter_tractogram_data(tractogram, ids_l)
        save_tractogram(sft_l, args.looping_tractogram)
예제 #19
0
def main():
    """Build head and tail endpoint count maps of a bundle.

    The first point of each streamline increments the head map and the last
    point increments the tail map. Both maps are saved as NIfTI images and
    the number of non-zero voxels of each is printed as JSON. With
    ``args.swap``, the output files and the JSON keys of head and tail are
    exchanged.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    swap = args.swap

    assert_inputs_exist(parser, args.in_bundle, args.reference)
    assert_outputs_exist(parser, args,
                         [args.endpoints_map_head, args.endpoints_map_tail])

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    # Coordinates are truncated to voxel indices below, which is only
    # correct when the origin is the voxel corner.
    sft.to_corner()
    if len(sft.streamlines) == 0:
        # The argument is named in_bundle; args.bundle does not exist.
        logging.warning(
            'Empty bundle file {}. Skipping'.format(args.in_bundle))
        return

    transfo, dim, _, _ = sft.space_attributes

    endpoints_map_head = np.zeros(dim)
    endpoints_map_tail = np.zeros(dim)

    head_name = args.endpoints_map_head
    tail_name = args.endpoints_map_tail

    if swap:
        head_name = args.endpoints_map_tail
        tail_name = args.endpoints_map_head

    for streamline in sft.streamlines:
        xyz = streamline[0, :].astype(int)
        endpoints_map_head[xyz[0], xyz[1], xyz[2]] += 1

        xyz = streamline[-1, :].astype(int)
        endpoints_map_tail[xyz[0], xyz[1], xyz[2]] += 1

    nib.save(nib.Nifti1Image(endpoints_map_head, transfo), head_name)
    nib.save(nib.Nifti1Image(endpoints_map_tail, transfo), tail_name)

    bundle_name, _ = os.path.splitext(os.path.basename(args.in_bundle))
    bundle_name_head = bundle_name + '_head'
    bundle_name_tail = bundle_name + '_tail'

    if swap:
        bundle_name_head = bundle_name + '_tail'
        bundle_name_tail = bundle_name + '_head'

    stats = {
        bundle_name_head: {
            'count': np.count_nonzero(endpoints_map_head)
        },
        bundle_name_tail: {
            'count': np.count_nonzero(endpoints_map_tail)
        }
    }

    print(json.dumps(stats, indent=args.indent))
예제 #20
0
def main():
    """Extract the streamlines selected by ``detect_ushape``.

    Selected streamlines are saved to ``args.out_tractogram`` and the other
    ones to ``args.remaining_tractogram`` when that option is given. With
    ``args.no_empty``, an empty selection stops the script before anything
    is written. With ``args.display_counts``, the streamline counts before
    and after filtering are printed as JSON.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser,
                         args,
                         args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser,
        [args.in_tractogram, args.out_tractogram, args.remaining_tractogram])

    if not (-1 <= args.minU <= 1 and -1 <= args.maxU <= 1):
        parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
                     'must be between -1 and 1.')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    ids_c = detect_ushape(sft, args.minU, args.maxU)
    # Everything that was not selected goes to the remaining tractogram.
    ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c)

    nothing_selected = len(ids_c) == 0
    if nothing_selected and args.no_empty:
        logging.debug("The file {} won't be written "
                      "(0 streamline).".format(args.out_tractogram))
        return
    if nothing_selected:
        logging.debug('The file {} contains 0 streamline.'.format(
            args.out_tractogram))

    save_tractogram(sft[ids_c], args.out_tractogram)

    if args.display_counts:
        sc_bf = len(sft.streamlines)
        sc_af = len(ids_c)
        counts = {
            'streamline_count_before_filtering': int(sc_bf),
            'streamline_count_after_filtering': int(sc_af)
        }
        print(json.dumps(counts, indent=args.indent))

    if not args.remaining_tractogram:
        return

    if len(ids_l) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline"
                          ").".format(args.remaining_tractogram))
            return

        logging.warning('No remaining streamlines.')

    save_tractogram(sft[ids_l], args.remaining_tractogram)
def main():
    """Build head and tail endpoint maps of a bundle.

    The bundle is moved to voxel space with a corner origin and the two
    maps are obtained from ``get_head_tail_density_maps``. With
    ``args.binary``, the maps are converted to int16 masks of their non-zero
    voxels. Both maps are saved as NIfTI images and the number of non-zero
    voxels of each is printed as JSON. With ``args.swap``, the output files
    and the JSON keys of head and tail are exchanged.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    swap = args.swap

    assert_inputs_exist(parser, args.in_bundle, args.reference)
    assert_outputs_exist(parser, args,
                         [args.endpoints_map_head, args.endpoints_map_tail])

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()
    if len(sft.streamlines) == 0:
        # The argument is named in_bundle; args.bundle does not exist.
        logging.warning(
            'Empty bundle file {}. Skipping'.format(args.in_bundle))
        return

    transfo, dim, _, _ = sft.space_attributes

    head_name = args.endpoints_map_head
    tail_name = args.endpoints_map_tail

    if swap:
        head_name = args.endpoints_map_tail
        tail_name = args.endpoints_map_head

    endpoints_map_head, endpoints_map_tail = \
        get_head_tail_density_maps(sft.streamlines, dim)

    if args.binary:
        endpoints_map_head = (endpoints_map_head > 0).astype(np.int16)
        endpoints_map_tail = (endpoints_map_tail > 0).astype(np.int16)

    nib.save(nib.Nifti1Image(endpoints_map_head, transfo), head_name)
    nib.save(nib.Nifti1Image(endpoints_map_tail, transfo), tail_name)

    bundle_name, _ = os.path.splitext(os.path.basename(args.in_bundle))
    bundle_name_head = bundle_name + '_head'
    bundle_name_tail = bundle_name + '_tail'

    if swap:
        bundle_name_head = bundle_name + '_tail'
        bundle_name_tail = bundle_name + '_head'

    stats = {
        bundle_name_head: {
            'count': np.count_nonzero(endpoints_map_head)
        },
        bundle_name_tail: {
            'count': np.count_nonzero(endpoints_map_tail)
        }
    }

    print(json.dumps(stats, indent=args.indent))
예제 #22
0
def main():
    """Split a tractogram into u-shaped streamlines and the remaining ones.

    The ids returned by detect_ushape are saved to the output tractogram;
    every other id goes to the optional "remaining" tractogram. The
    streamline counts can be printed as JSON with --display_counts.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser,
        [args.in_tractogram, args.out_tractogram, args.remaining_tractogram])

    # Both ufactor bounds must lie in [-1, 1].
    ufactor_in_range = -1 <= args.minU <= 1 and -1 <= args.maxU <= 1
    if not ufactor_in_range:
        parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
                     'must be between -1 and 1.')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    ushape_ids = []
    leftover_ids = []

    if len(sft.streamlines) <= 1:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')
    else:
        ushape_ids = detect_ushape(sft, args.minU, args.maxU)
        every_id = np.arange(len(sft.streamlines))
        leftover_ids = np.setdiff1d(every_id, ushape_ids)

    if len(ushape_ids) == 0:
        logging.warning('No u-shape streamlines in {}'.format(
            args.in_tractogram))
    else:
        save_tractogram(sft[ushape_ids], args.out_tractogram)

    if args.display_counts:
        counts = {
            'streamline_count_before_filtering': int(len(sft.streamlines)),
            'streamline_count_after_filtering': int(len(ushape_ids)),
        }
        print(json.dumps(counts, indent=args.indent))

    if len(leftover_ids) == 0:
        logging.warning('No remaining streamlines in {}'.format(
            args.remaining_tractogram))
    elif args.remaining_tractogram:
        save_tractogram(sft[leftover_ids], args.remaining_tractogram)
예제 #23
0
def compute_gt_masks(gt_bundles, parser, args):
    """
    Compute the binary ground-truth masks and their inverse masks.

    A ground-truth file whose extension is reported as '.gz' or '.nii.gz'
    by split_name_with_nii is loaded as an image and converted with
    get_data_as_mask. Any other file is loaded as a tractogram and its mask
    is computed from the tract counts map of its streamlines.

    Parameters
    ----------
    gt_bundles: list
        Paths of the ground-truth files (images or tractograms).
    parser: ArgumentParser
        Argument parser which handles the script's arguments.
    args: Namespace
        List of arguments passed to the script.

    Returns
    -------
    gt_bundle_masks: list of numpy.ndarray
        One binary mask per ground-truth file (1 where the ground truth is
        non-zero).
    gt_bundle_inv_masks: list of numpy.ndarray
        One inverse mask per ground-truth file (1 where the ground truth
        is zero).
    affine: numpy.ndarray or None
        Affine of the last processed ground-truth file; None when
        gt_bundles is empty.
    dimensions: tuple or None
        Dimensions of the last processed ground-truth file; None when
        gt_bundles is empty.
    """

    gt_bundle_masks = []
    gt_bundle_inv_masks = []
    # Defined up-front so that an empty input list returns cleanly instead
    # of raising on unbound names.
    affine = None
    dimensions = None

    # Iterate over the function's own parameter (the previous version read
    # args.gt_bundles and silently ignored gt_bundles).
    for gt_bundle in gt_bundles:
        # Support ground truth as streamlines or masks
        # Will be converted to binary masks immediately
        _, ext = split_name_with_nii(gt_bundle)
        if ext in ['.gz', '.nii.gz']:
            gt_img = nib.load(gt_bundle)
            gt_mask = get_data_as_mask(gt_img)
            affine = gt_img.affine
            dimensions = gt_mask.shape
        else:
            gt_sft = load_tractogram_with_reference(parser,
                                                    args,
                                                    gt_bundle,
                                                    bbox_check=False)
            gt_sft.to_vox()
            gt_sft.to_corner()
            affine, dimensions, _, _ = gt_sft.space_attributes
            gt_mask = compute_tract_counts_map(gt_sft.streamlines,
                                               dimensions).astype(np.int16)

        # The inverse mask must be built before the mask is binarized
        # in place.
        gt_inv_mask = np.zeros(dimensions, dtype=np.int16)
        gt_inv_mask[gt_mask == 0] = 1
        gt_mask[gt_mask > 0] = 1
        gt_bundle_masks.append(gt_mask)
        gt_bundle_inv_masks.append(gt_inv_mask)

    return gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions
예제 #24
0
def main():
    """Save a subset of the input tractogram's streamlines.

    The subset is produced by get_subset_streamlines from the requested
    maximum number of streamlines and the seed given on the command line.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    full_sft = load_tractogram_with_reference(parser, args,
                                              args.in_tractogram)
    subset_sft = get_subset_streamlines(full_sft,
                                        args.max_num_streamlines,
                                        args.seed)

    save_tractogram(subset_sft, args.out_tractogram)
예제 #25
0
def main():
    """Flip the input tractogram along the requested axes and save it.

    The tractogram is moved to voxel space with a corner origin before
    being handed to flip_sft.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    tractogram = load_tractogram_with_reference(parser, args,
                                                args.in_tractogram)
    tractogram.to_vox()
    tractogram.to_corner()

    flipped = flip_sft(tractogram, args.axes)
    save_tractogram(flipped, args.out_tractogram)
def main():
    """Remove the invalid streamlines of a tractogram and save the result.

    The tractogram is loaded with the bounding-box check disabled so that
    a file containing invalid streamlines can still be opened.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram, args.reference)
    assert_outputs_exist(parser, args, args.output_name)

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram, bbox_check=False)
    tractogram.remove_invalid_streamlines()

    save_tractogram(tractogram, args.output_name)
예제 #27
0
def main():
    """Label the points of a bundle against a centroid and save the result.

    min_dist_to_centroid is called with the point data of the bundle and
    of the centroid; the returned labels (shifted by one) and distances
    are each written to a compressed numpy file.

    Raises
    ------
    ValueError
        If the bundle or the centroid file contains no streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid])
    assert_outputs_exist(parser, args,
                         [args.output_label, args.output_distance])

    # NOTE(review): the result of this compatibility check is not used, so
    # an incompatibility is not reported here - confirm this is intended.
    is_header_compatible(args.in_bundle, args.in_centroid)

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)

    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        msg = 'Empty bundle file {}. Skipping'.format(args.in_bundle)
        logging.error(msg)
        raise ValueError(msg)

    if not len(sft_centroid.streamlines):
        # The input argument is `in_centroid`; the previous
        # `args.centroid_streamline` raised an AttributeError before the
        # intended ValueError could be reported.
        msg = 'Empty centroid streamline file {}. Skipping'.format(
            args.in_centroid)
        logging.error(msg)
        raise ValueError(msg)

    min_dist_label, min_dist = min_dist_to_centroid(
        sft_bundle.streamlines.data, sft_centroid.streamlines.data)
    # Saved labels are offset by one relative to what min_dist_to_centroid
    # returns (presumably so that labels start at 1 - TODO confirm).
    min_dist_label += 1

    # Save assignment in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # assignment = f['arr_0']
    np.savez_compressed(args.output_label, min_dist_label)

    # Save distance in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # distance = f['arr_0']
    np.savez_compressed(args.output_distance, min_dist)
예제 #28
0
def main():
    """Cut the streamlines of a tractogram using a binary mask.

    The connected components ("entities") of the mask are labelled with
    ndi.label:

    - one entity: cut_outside_of_mask_streamlines is used;
    - two entities: cut_between_masks_streamlines is used;
    - no entity or more than two: an error is logged and nothing is saved.

    With --biggest_blob only the largest entity of the mask is kept, so the
    single-entity path is always taken. The streamlines can optionally be
    resampled before cutting (--step_size) and compressed after cutting
    (--error_rate).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    # Use the feature count returned by ndi.label rather than the number of
    # unique label values: the latter is off by one when the mask has no
    # background voxel, and cannot tell an empty mask from a full one.
    bundle_disjoint, nb_entities = ndi.label(binary_mask)
    if nb_entities == 0:
        logging.error('The provided mask is empty. Nothing to cut.')
        return

    if args.biggest_blob:
        # Only foreground labels take part in the size comparison.
        labels, sizes = np.unique(bundle_disjoint[bundle_disjoint > 0],
                                  return_counts=True)
        val = labels[np.argmax(sizes)]
        binary_mask[bundle_disjoint != val] = 0
        nb_entities = 1

    if nb_entities == 1:
        logging.info('The provided mask has 1 entity '
                     'cut_outside_of_mask_streamlines function selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    elif nb_entities == 2:
        logging.info('The provided mask has 2 entity '
                     'cut_between_masks_streamlines function selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)
    else:
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between >2.')
        return

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = [
            compress_streamlines(s, args.error_rate)
            for s in new_sft.streamlines
        ]
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
예제 #29
0
def main():
    """Filter streamlines by their total length in each dimension.

    The [min, max] ranges given for x, y and z are passed to
    filter_streamlines_by_total_length_per_dim. The kept streamlines are
    saved to the output tractogram and, when --save_rejected is given, the
    rejected ones are saved as well. Streamline counts can be printed as
    JSON with --display_counts.

    When the filtered tractogram is empty and --no_empty is set, the output
    tractogram is not written.
    """

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram, args.save_rejected)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        # Silencing SFT's logger if our logging is in DEBUG mode, because it
        # typically produces a lot of outputs!
        set_sft_logger_level('WARNING')

    if args.min_x == 0 and np.isinf(args.max_x) and \
       args.min_y == 0 and np.isinf(args.max_y) and \
       args.min_z == 0 and np.isinf(args.max_z):
        logging.warning("You have not specified min or max in any direction. "
                        "Output will simply be a copy of your input!")

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    computed_rejected_sft = args.save_rejected is not None
    new_sft, _, rejected_sft = \
        filter_streamlines_by_total_length_per_dim(
            sft, [args.min_x, args.max_x], [args.min_y, args.max_y],
            [args.min_z, args.max_z], args.use_abs, computed_rejected_sft)

    if args.display_counts:
        sc_bf = len(sft.streamlines)
        sc_af = len(new_sft.streamlines)
        print(
            json.dumps(
                {
                    'streamline_count_before_filtering': int(sc_bf),
                    'streamline_count_after_filtering': int(sc_af)
                },
                indent=args.indent))

    is_empty = len(new_sft.streamlines) == 0
    if is_empty and args.no_empty:
        # Honour --no_empty: the previous version logged this message but
        # still wrote the empty file.
        logging.debug("The file {} won't be written "
                      "(0 streamline).".format(args.out_tractogram))
    else:
        if is_empty:
            logging.debug('The file {} contains 0 streamline'.format(
                args.out_tractogram))
        save_tractogram(new_sft, args.out_tractogram)

    # The rejected streamlines are a separate output and are saved whether
    # or not the filtered tractogram was written.
    if computed_rejected_sft:
        save_tractogram(rejected_sft, args.save_rejected)
예제 #30
0
def main():
    """Project a per-streamline metric value onto the bundle's endpoints.

    For each input metric image, the value returned by
    _compute_streamline_mean for a streamline is accumulated in the voxels
    holding the first and the last point of that streamline. Each voxel is
    then divided by the number of endpoints it received, and the map is
    saved in the output folder as '<metric name>_endpoints_metric<ext>'.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.in_metrics)
    assert_output_dirs_exist_and_empty(parser, args,
                                       args.out_folder,
                                       create_dir=True)

    assert_same_resolution(args.in_metrics)

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()

    if len(sft.streamlines) == 0:
        # The input argument is `in_bundle`; the previous `args.bundle`
        # made this early-exit path fail with an AttributeError.
        logging.warning('Empty bundle file {}. Skipping'.format(
            args.in_bundle))
        return

    mins, maxs, indices = _process_streamlines(sft.streamlines)

    metrics = [nib.load(metric) for metric in args.in_metrics]
    for metric in metrics:
        data = metric.get_fdata(dtype=np.float32)
        endpoint_metric_map = np.zeros(metric.shape)
        count = np.zeros(metric.shape)
        for cur_min, cur_max, cur_ind, orig_s in zip(mins, maxs,
                                                     indices,
                                                     sft.streamlines):
            streamline_mean = _compute_streamline_mean(cur_ind,
                                                       cur_min,
                                                       cur_max,
                                                       data)

            # The streamlines were moved to voxel space with a corner
            # origin above, so the truncated coordinates of a point are
            # used directly as a voxel index.
            xyz = orig_s[0, :].astype(int)
            endpoint_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
            count[xyz[0], xyz[1], xyz[2]] += 1

            xyz = orig_s[-1, :].astype(int)
            endpoint_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
            count[xyz[0], xyz[1], xyz[2]] += 1

        # Average only where at least one endpoint landed, which avoids a
        # division by zero elsewhere.
        endpoint_metric_map[count != 0] /= count[count != 0]
        metric_fname, ext = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        nib.save(nib.Nifti1Image(endpoint_metric_map, metric.affine,
                                 metric.header),
                 os.path.join(args.out_folder,
                              '{}_endpoints_metric{}'.format(metric_fname,
                                                             ext)))