# Example #1
# 0
def main():
    """Recognize a model bundle inside a whole-brain tractogram.

    The model tractogram is moved with the matrix read from
    ``args.transformation`` (inverted when ``args.inverse`` is set). RecoBundles
    is run on the whole-brain streamlines, and the recognized streamlines are
    saved to ``args.output_name`` with their per-streamline and per-point data.
    Nothing is written when ``args.no_empty`` is set and no streamline was
    recognized.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    # The model tractogram is loaded below, so validate it together with the
    # other inputs instead of failing later inside the loader.
    assert_inputs_exist(parser, [args.in_tractogram, args.transformation,
                                 args.in_model])
    assert_outputs_exist(parser, args, [args.output_name])

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    # Default transformation source is expected to be ANTs
    transfo = np.loadtxt(args.transformation)
    if args.inverse:
        # Invert the matrix already in memory rather than re-reading the file.
        transfo = np.linalg.inv(transfo)

    model_streamlines = ArraySequence(
        transform_streamlines(model_file.streamlines, transfo))

    rng = np.random.RandomState(args.seed)
    if args.input_pickle:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cluster maps that come from a trusted source.
        with open(args.input_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.wb_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.output_pickle:
        # Persist the whole-brain clustering so later runs can reuse it.
        with open(args.output_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(model_streamlines,
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    # Keep the auxiliary data aligned with the recognized streamlines.
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.output_name)
# Example #2
# 0
def load_data(parser, args, path):
    """Load the tractogram at ``path`` and unpack its contents.

    Returns a tuple ``(streamlines, data_per_streamline, data_per_point)``.
    ``streamlines`` is a new Python list built from the loaded streamlines.
    The two data containers are returned exactly as the loaded tractogram
    exposes them.
    """
    logging.info('Loading streamlines from {0}.'.format(path))
    tractogram = load_tractogram_with_reference(parser, args, path)
    return (list(tractogram.streamlines),
            tractogram.data_per_streamline,
            tractogram.data_per_point)
def main():
    """Convert a tractogram from one file format to another.

    The input and output paths must have different file extensions. The
    bounding-box validity check is disabled both when loading and when
    saving.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram], [args.reference])

    _, src_ext = os.path.splitext(args.in_tractogram)
    _, dst_ext = os.path.splitext(args.output_name)
    if src_ext == dst_ext:
        parser.error('Input and output cannot be of the same file format')

    assert_outputs_exist(parser, args, args.output_name)

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram, bbox_check=False)
    save_tractogram(tractogram, args.output_name, bbox_valid_check=False)
# Example #4
# 0
def _load_compatible_data(parser, path, sft):
    """Load the image at ``path`` and return its data array.

    Calls ``parser.error`` when the image header is not compatible with
    ``sft``.
    """
    img = nib.load(path)
    if not is_header_compatible(img, sft):
        parser.error('Headers from the tractogram and the mask are '
                     'not compatible.')
    # `get_data()` was removed in nibabel 5.0. `np.asanyarray(img.dataobj)` is
    # its documented replacement and, unlike get_fdata(), does not force a
    # float dtype.
    return np.asanyarray(img.dataobj)


def _plane_mask(parser, filter_type, filter_arg, dim):
    """Return an int16 volume of shape ``dim`` with a single plane set to 1.

    ``filter_type`` ('x_plane', 'y_plane' or 'z_plane') selects the axis.
    ``filter_arg`` is the slice index along that axis. An index outside
    ``[0, dim[axis])`` triggers ``parser.error``.
    """
    axis = ('x_plane', 'y_plane', 'z_plane').index(filter_type)
    index = int(filter_arg)
    if not 0 <= index < dim[axis]:
        parser.error('{} is not valid according to the '
                     'tractogram header.'.format(
                         '{} plane {}'.format('XYZ'[axis], index)))
    mask = np.zeros(dim, dtype=np.int16)
    # Select the whole slice at `index` along `axis`.
    slicer = [slice(None)] * 3
    slicer[axis] = index
    mask[tuple(slicer)] = 1
    return mask


def _apply_roi_filter(parser, sft, roi_opt):
    """Apply one filtering option to ``sft``.

    Returns ``(filtered_streamlines, indexes)`` as produced by the matching
    filter function. An unsupported filter type or BDO geometry is reported
    through ``parser.error`` instead of silently reusing a previous result.
    """
    # Atlas needs an extra argument (value in the LUT)
    if roi_opt[0] == 'atlas_roi':
        filter_type, filter_arg_1, filter_arg_2, \
            filter_mode, filter_criteria = roi_opt
    else:
        filter_type, filter_arg, filter_mode, filter_criteria = roi_opt
    is_not = filter_criteria != 'include'

    if filter_type == 'drawn_roi':
        mask = _load_compatible_data(parser, filter_arg, sft)
        return filter_grid_roi(sft, mask, filter_mode, is_not)

    if filter_type == 'atlas_roi':
        atlas = _load_compatible_data(parser, filter_arg_1,
                                      sft).astype(np.uint16)
        mask = np.zeros(atlas.shape, dtype=np.uint16)
        mask[atlas == int(filter_arg_2)] = 1
        return filter_grid_roi(sft, mask, filter_mode, is_not)

    # For every case, the input number must be greater or equal to 0 and
    # below the dimension, since this is a voxel space operation
    if filter_type in ('x_plane', 'y_plane', 'z_plane'):
        _, dim, _, _ = sft.space_attribute
        mask = _plane_mask(parser, filter_type, filter_arg, dim)
        return filter_grid_roi(sft, mask, filter_mode, is_not)

    if filter_type == 'bdo':
        geometry, radius, center = read_info_from_mb_bdo(filter_arg)
        if geometry == 'Ellipsoid':
            return filter_ellipsoid(sft, radius, center, filter_mode, is_not)
        if geometry == 'Cuboid':
            return filter_cuboid(sft, radius, center, filter_mode, is_not)
        parser.error('Unsupported BDO geometry: {}'.format(geometry))

    # argparse's parser.error() exits, so nothing runs past this point.
    parser.error('Unknown filter type: {}'.format(filter_type))


def main():
    """Filter a tractogram with a sequence of ROI options and save it.

    Each option returned by ``prepare_filtering_list`` is applied to the
    output of the previous one. Per-streamline and per-point data are kept
    aligned with the surviving streamlines. When the result is empty and
    ``args.no_empty`` is set, nothing is written.
    """
    parser = _buildArgsParser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    # NOTE(review): the other scripts call `assert_outputs_exist`; confirm
    # which spelling this script imports before renaming.
    assert_outputs_exists(parser, args, [args.out_tractogram])
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list = prepare_filtering_list(parser, args)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    for roi_opt in roi_opt_list:
        filtered_streamlines, indexes = _apply_roi_filter(parser, sft, roi_opt)

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt,
                                               len(filtered_streamlines)))

        # Keep the auxiliary data aligned with the surviving streamlines.
        data_per_streamline = sft.data_per_streamline[indexes]
        data_per_point = sft.data_per_point[indexes]

        sft = StatefulTractogram(filtered_streamlines,
                                 sft,
                                 Space.RASMM,
                                 data_per_streamline=data_per_streamline,
                                 data_per_point=data_per_point)

    # Inspect the final tractogram rather than the loop variable, which is
    # unbound when no filtering option was provided.
    if len(sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)
# Example #5
# 0
def main():
    """Compute length-weighted TODI-derived maps from a tractogram.

    The streamlines are converted to voxel space. A length-weighted TODI is
    computed and can optionally be smoothed and masked. Each requested output
    (mask, TDI, TODI, TODI spherical harmonics) is then written as a NIfTI
    image using the tractogram affine. At least one output must be requested.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser,
                        required=[args.tract_filename],
                        optional=[args.mask, args.reference])
    output_file_list = [path for path in (args.out_mask, args.out_lw_tdi,
                                          args.out_lw_todi,
                                          args.out_lw_todi_sh)
                        if path]

    if not output_file_list:
        parser.error('No output to be done')
    assert_outputs_exist(parser, args, output_file_list)

    sft = load_tractogram_with_reference(parser, args, args.tract_filename)
    affine, data_shape, _, _ = sft.space_attribute
    sft.to_vox()

    logging.info('Computing length-weighted TODI ...')
    todi_obj = TrackOrientationDensityImaging(tuple(data_shape), args.sphere)
    todi_obj.compute_todi(sft.streamlines, length_weights=True)

    if args.smooth:
        logging.info('Smoothing ...')
        todi_obj.smooth_todi_dir()
        todi_obj.smooth_todi_spatial()

    if args.mask:
        # `get_data()` was removed in nibabel 5.0; `np.asanyarray(dataobj)` is
        # its documented replacement and keeps the image's native dtype.
        mask = np.asanyarray(nib.load(args.mask).dataobj)
        todi_obj.mask_todi(mask)

    def _save_volume(data, path):
        # Reshape with `todi_obj.reshape_to_3d` and save with the tractogram
        # affine.
        nib.Nifti1Image(todi_obj.reshape_to_3d(data), affine).to_filename(path)

    logging.info('Saving Outputs ...')
    if args.out_mask:
        _save_volume(todi_obj.get_mask().astype(np.int16), args.out_mask)

    if args.out_lw_todi_sh:
        # Each output fetches a fresh TODI array, so get_sh cannot affect
        # the plain TODI written below.
        todi = todi_obj.get_todi().astype(np.float32)
        sh = todi_obj.get_sh(todi, args.sh_basis, args.sh_order,
                             args.sh_normed)
        _save_volume(sh, args.out_lw_todi_sh)

    if args.out_lw_tdi:
        _save_volume(todi_obj.get_tdi().astype(np.float32), args.out_lw_tdi)

    if args.out_lw_todi:
        _save_volume(todi_obj.get_todi().astype(np.float32), args.out_lw_todi)