Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite an existing file, so remove it first
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs as streamlines, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))
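    # e.g. with real_labels [1, 2] this yields the pair (1, 2) from
    # combinations, plus the self-connections (1, 1) and (2, 2)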

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDDs
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} node pairs out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
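            # e.g. if step 1 keeps raw ids [0, 2, 5], step 2's ids index into
            # that subset: a no_loop id of 1 refers to raw streamline 2, which
            # is why each stage slices the previous stage's SFT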
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
Example #2
print("Segmenting fiber groups...")

segmentation = seg.Segmentation(return_idx=True, filter_by_endpoints=False)
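# With return_idx=True, each entry of fiber_groups below maps a bundle name
# to a dict holding the selected streamlines ('sl') and their indices into
# the original tractogram ('idx')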

segmentation.segment(bundles,
                     tractogram,
                     fdata=hardi_fdata,
                     fbval=hardi_fbval,
                     fbvec=hardi_fbvec,
                     mapping=mapping,
                     reg_template=MNI_T2_img)

fiber_groups = segmentation.fiber_groups

for bundle in bundles:
    tractogram = StatefulTractogram(fiber_groups[bundle]['sl'].streamlines,
                                    img, Space.VOX)
    tractogram.to_rasmm()
    save_tractogram(tractogram,
                    op.join(working_dir, f'afq_{bundle}_seg.trk'),
                    bbox_valid_check=False)

    tractogram_img = density_map(tractogram, n_sls=1000, to_vox=True)
    nib.save(tractogram_img,
             op.join(working_dir, f'afq_{bundle}_seg_density_map.nii.gz'))
    show_anatomical_slices(tractogram_img.get_fdata(),
                           f'Segmented {bundle} Density Map')

##########################################################################
# Cleaning:
# ---------
# Each fiber group is cleaned to exclude streamlines that are outliers in terms
# of their trajectory and/or length.
Example #3
def load_tractogram(filename,
                    reference,
                    to_space=Space.RASMM,
                    shifted_origin=False,
                    bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (Space)
        Space in which the streamlines will be transformed after loading
        (vox, voxmm or rasmm)
    shifted_origin : bool
        Position of the origin:
        False is the Trackvis standard, default (center of the voxel)
        True is the NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        Verify that the streamlines' bounding box is valid in voxel space
    trk_header_check : bool
        For .trk files, verify that the file header matches the reference

    Returns
    -------
    output : StatefulTractogram
        The loaded tractogram (must have been saved in a valid state)
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Input filename is not one of the supported formats')
        return False

    if to_space not in Space:
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    if reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided; "same" is only ' +
                          'available for Trk files.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided ' +
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline

    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        streamlines = list(dpy_obj.read_tracks())
        dpy_obj.close()
    logging.debug('Loaded %s with %s streamlines in %s seconds', filename,
                  len(streamlines), round(time.time() - timer, 3))

    sft = StatefulTractogram(streamlines,
                             reference,
                             Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'load a valid file if some coordinates are invalid')

    return sft
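
# A minimal usage sketch for the function above ('bundle.trk' is a
# hypothetical file name); a .trk file can act as its own spatial reference:
sft_example = load_tractogram('bundle.trk', 'same', to_space=Space.VOX)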
Example #4
It is recommended to create a new StatefulTractogram object and explicitly
specify the space the streamlines are in. Be careful to follow the order of
operations.

If the tractogram came from a Trk file with metadata, that metadata will be
lost. If you wish to keep metadata while manipulating the number or order of
streamlines, see StatefulTractogram.remove_invalid_streamlines() for more
details.

It is important to mention that once the object is created in a consistent
state, the ``save_tractogram`` function will save a valid file, and
``load_tractogram`` will load it back in a valid state.
"""

cc_sft = StatefulTractogram(cc_streamlines_vox, reference_anatomy, Space.VOX)
laf_sft = StatefulTractogram(laf_streamlines_vox, reference_anatomy, Space.VOX)
raf_sft = StatefulTractogram(raf_streamlines_vox, reference_anatomy, Space.VOX)
lpt_sft = StatefulTractogram(lpt_streamlines_vox, reference_anatomy, Space.VOX)
rpt_sft = StatefulTractogram(rpt_streamlines_vox, reference_anatomy, Space.VOX)

print(len(cc_sft), len(laf_sft), len(raf_sft), len(lpt_sft), len(rpt_sft))
save_tractogram(cc_sft, 'cc_1000.trk')
save_tractogram(laf_sft, 'laf_1000.trk')
save_tractogram(raf_sft, 'raf_1000.trk')
save_tractogram(lpt_sft, 'lpt_1000.trk')
save_tractogram(rpt_sft, 'rpt_1000.trk')
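
# A hedged round-trip sketch (not part of the original snippet): one of the
# bundles saved above can be loaded back with load_tractogram, here directly
# into voxel space.
cc_loaded = load_tractogram('cc_1000.trk', reference_anatomy, to_space=Space.VOX)
print(len(cc_loaded), cc_loaded.space)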

nib.save(nib.Nifti1Image(cc_density, affine, nifti_header),
         'cc_density.nii.gz')
nib.save(nib.Nifti1Image(laf_density, affine, nifti_header),
         'laf_density.nii.gz')
Example #5
def load_data_tmp_saving(filename, reference, init_only=False,
                         disable_centroids=False):
    # Since data is often re-used when comparing multiple bundles, anything
    # that can be computed once is saved temporarily and simply loaded on demand
    if not os.path.isfile(filename):
        if init_only:
            logging.warning('%s does not exist', filename)
        return None

    hash_tmp = hashlib.md5(filename.encode()).hexdigest()
    tmp_density_filename = os.path.join('tmp_measures/',
                                        '{0}_density.nii.gz'.format(hash_tmp))
    tmp_endpoints_filename = os.path.join('tmp_measures/',
                                          '{0}_endpoints.nii.gz'.format(hash_tmp))
    tmp_centroids_filename = os.path.join('tmp_measures/',
                                          '{0}_centroids.trk'.format(hash_tmp))

    sft = load_tractogram(filename, reference)
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.get_streamlines_copy()
    if not streamlines:
        if init_only:
            logging.warning('%s is empty', filename)
        return None

    if os.path.isfile(tmp_density_filename) \
            and os.path.isfile(tmp_endpoints_filename) \
            and os.path.isfile(tmp_centroids_filename):
        # If only initializing, loading the data is useless
        if init_only:
            return None
        density = nib.load(tmp_density_filename).get_fdata()
        endpoints_density = nib.load(tmp_endpoints_filename).get_fdata()
        sft_centroids = load_tractogram(tmp_centroids_filename, reference)
        sft_centroids.to_vox()
        sft_centroids.to_corner()
        centroids = sft_centroids.get_streamlines_copy()
    else:
        transformation, dimensions, _, _ = sft.space_attributes
        density = compute_tract_counts_map(streamlines, dimensions)
        endpoints_density = get_endpoints_density_map(streamlines, dimensions,
                                                      point_to_select=3)
        thresholds = [32, 24, 12, 6]
        if disable_centroids:
            centroids = []
        else:
            centroids = qbx_and_merge(streamlines, thresholds,
                                      rng=RandomState(0),
                                      verbose=False).centroids

        # Saving tmp files to save on future computation
        nib.save(nib.Nifti1Image(density.astype(np.float32), transformation),
                 tmp_density_filename)
        nib.save(nib.Nifti1Image(endpoints_density.astype(np.int16),
                                 transformation),
                 tmp_endpoints_filename)

        # Saving in vox space and corner.
        centroids_sft = StatefulTractogram.from_sft(centroids, sft)
        save_tractogram(centroids_sft, tmp_centroids_filename)

    return density, endpoints_density, streamlines, centroids
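
# A hedged usage sketch (file names are hypothetical): a first call with
# init_only=True warms the tmp_measures/ cache; later calls load the cached
# maps instead of recomputing them.
load_data_tmp_saving('bundle.trk', 'anat.nii.gz', init_only=True)
density, endpoints_density, streamlines, centroids = load_data_tmp_saving(
    'bundle.trk', 'anat.nii.gz')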
Example #6
# Cleaning
# --------
# Each fiber group is cleaned to exclude streamlines that are outliers in terms
# of their trajectory and/or length.

print("Cleaning fiber groups...")
for bundle in bundles:
    print(f"Cleaning {bundle}")
    print(f"Before cleaning: {len(fiber_groups[bundle]['sl'])} streamlines")
    new_fibers, idx_in_bundle = seg.clean_bundle(fiber_groups[bundle]['sl'],
                                                 return_idx=True)
    print(f"Afer cleaning: {len(new_fibers)} streamlines")

    idx_in_global = fiber_groups[bundle]['idx'][idx_in_bundle]
    np.save(f'{bundle}_idx.npy', idx_in_global)
    sft = StatefulTractogram(new_fibers.streamlines, img, Space.VOX)
    sft.to_rasmm()
    save_tractogram(sft, f'./{bundle}_afq.trk', bbox_valid_check=False)

##########################################################################
# Bundle profiles
# ---------------
# Streamlines are represented in the original diffusion space (`Space.VOX`),
# and scalar properties are queried along the length of each bundle from the
# corresponding scalar data. Here, the contribution of each streamline is
# weighted according to how representative it is of the bundle overall.

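# A minimal sketch of the weighting described above, assuming an FA volume
# `fa_data` is available (hypothetical name); dipy's gaussian_weights
# down-weights streamlines that are far from the bundle core:
#     from dipy.stats.analysis import afq_profile, gaussian_weights
#     weights = gaussian_weights(sft.streamlines)
#     profile = afq_profile(fa_data, sft.streamlines, sft.affine,
#                           weights=weights)
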
print("Extracting tract profiles...")
for bundle in bundles:
    sft = load_tractogram(f'./{bundle}_afq.trk', img, to_space=Space.VOX)
    fig, ax = plt.subplots(1)
Example #7
def main():
    # Callback required for FURY
    def keypress_callback(obj, _):
        key = obj.GetKeySym().lower()
        nonlocal clusters_linewidth, background_linewidth
        nonlocal curr_streamlines_actor, concat_streamlines_actor, show_curr_actor
        iterator = len(accepted_streamlines) + len(rejected_streamlines)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and iterator < len(sft_accepted_on_size):
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if iterator < len(sft_accepted_on_size):
                logging.warning(
                    'Early exit, everything remaining will be rejected.')
            return

        if key in ['a', 'r'] and iterator < len(sft_accepted_on_size):
            if key == 'a':
                accepted_streamlines.append(iterator)
                choices.append('a')
                logging.info('Accepted file {}'.format(
                    filename_accepted_on_size[iterator]))
            elif key == 'r':
                rejected_streamlines.append(iterator)
                choices.append('r')
                logging.info('Rejected file {}'.format(
                    filename_accepted_on_size[iterator]))
            iterator += 1

        if key == 'z':
            if iterator > 0:
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewind one step.')

                iterator -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        if key in ['a', 'r', 'z'] and iterator < len(sft_accepted_on_size):
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[iterator].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if iterator == len(sft_accepted_on_size):
            print('No more clusters, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()
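    # Summary of the key bindings handled above: 'a' accept, 'r' reject,
    # 'z' undo the last choice, 'c' toggle the background streamlines,
    # 'q' quit (on early exit, everything remaining is rejected)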

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])

    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate the procedure, clusters can be discarded based on size.
    # The concatenation gives spatial context during selection.
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        if not is_header_compatible(args.in_bundles[0], filename):
            logging.error('Headers of {} and {} are not compatible.'.format(
                args.in_bundles[0], filename))
            return
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser,
                                             args,
                                             filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            logging.info('File {} has {} streamlines, '
                         'automatically rejected.'.format(filename, len(sft)))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('{} clusters to be classified.'.format(
        len(sft_accepted_on_size)))
    # The clusters are sorted by size for simplicity/efficiency
    tuple_accepted = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda x: len(x[0]),
                reverse=True))
    sft_accepted_on_size, filename_accepted_on_size = tuple_accepted
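    # The zip(*sorted(zip(a, b), ...)) idiom sorts two parallel lists
    # together; here it re-pairs the SFTs and their filenames as two tuples
    # ordered from largest to smallest cluster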

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene,
                                      size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected
    missing = len(args.in_bundles) - len(choices) - len(sft_rejected_on_size)
    len_accepted = len(sft_accepted_on_size)
    rejected_streamlines.extend(range(len_accepted - missing, len_accepted))
    if missing > 0:
        logging.info('{} clusters automatically rejected '
                     'from early exit'.format(missing))

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(
        save_clusters(sft_rejected_on_size, range(len(sft_rejected_on_size)),
                      args.out_rejected_dir, filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
Example #8
from dipy.data import small_sphere
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_trk

fod = csd_fit.odf(small_sphere)
pmf = fod.clip(min=0)
prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                                max_angle=30.,
                                                sphere=small_sphere)
streamline_generator = LocalTracking(prob_dg,
                                     stopping_criterion,
                                     seeds,
                                     affine,
                                     step_size=.5)
streamlines = Streamlines(streamline_generator)
sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
save_trk(sft, "tractogram_probabilistic_dg_pmf.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(scene,
                  out_path='tractogram_probabilistic_dg_pmf.png',
                  size=(800, 800))
    if interactive:
        window.show(scene)
"""
.. figure:: tractogram_probabilistic_dg_pmf.png
   :align: center

   **Corpus Callosum using probabilistic direction getter from PMF**
Example #9
def test_track_ensemble_particle():
    """
    Test for ensemble tractography functionality
    """
    import tempfile
    from pynets.dmri import track
    from dipy.core.gradients import gradient_table
    from dipy.data import get_sphere
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram
    from nibabel.streamlines.array_sequence import ArraySequence

    base_dir = str(Path(__file__).parent / "examples")
    B0_mask = f"{base_dir}/003/anat/mean_B0_bet_mask_tmp.nii.gz"
    gm_in_dwi = f"{base_dir}/003/anat/t1w_gm_in_dwi.nii.gz"
    vent_csf_in_dwi = f"{base_dir}/003/anat/t1w_vent_csf_in_dwi.nii.gz"
    wm_in_dwi = f"{base_dir}/003/anat/t1w_wm_in_dwi.nii.gz"
    dir_path = f"{base_dir}/003/dmri"
    bvals = f"{dir_path}/sub-003_dwi.bval"
    bvecs = f"{base_dir}/003/test_out/003/dwi/bvecs_reor.bvec"
    gtab = gradient_table(bvals, bvecs)
    dwi_file = f"{base_dir}/003/test_out/003/dwi/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    atlas_data_wm_gm_int = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track_wmgm_int.nii.gz"
    labels_im_file = f"{dir_path}/whole_brain_cluster_labels_PCA200_dwi_track.nii.gz"
    conn_model = 'csd'
    tiss_class = 'cmc'
    min_length = 10
    maxcrossing = 2
    roi_neighborhood_tol = 6
    waymask = None
    curv_thr_list = [40, 30]
    step_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    sphere = get_sphere('repulsion724')
    directget = 'prob'
    track_type = 'particle'
    target_samples = 1000

    dwi_img = nib.load(dwi_file)
    dwi_data = dwi_img.get_fdata()

    model, _ = track.reconstruction(conn_model, gtab, dwi_data, wm_in_dwi)
    temp_dir = tempfile.TemporaryDirectory()
    recon_path = temp_dir.name + '/model_file.hdf5'

    # The context manager closes the file on exit; no explicit close needed
    with h5py.File(recon_path, 'w') as hf:
        hf.create_dataset("reconstruction", data=model.astype('float32'))

    streamlines = track.track_ensemble(
        target_samples, atlas_data_wm_gm_int, labels_im_file, recon_path,
        sphere, directget, curv_thr_list, step_list, track_type, maxcrossing,
        roi_neighborhood_tol, min_length, waymask, B0_mask, gm_in_dwi,
        gm_in_dwi, vent_csf_in_dwi, wm_in_dwi, tiss_class, temp_dir.name)

    streams = f"{base_dir}/miscellaneous/streamlines_model-csd_nodetype-parc_samples-1000streams_tracktype-particle_directget-prob_minlength-10.trk"
    save_tractogram(StatefulTractogram(streamlines,
                                       reference=dwi_img,
                                       space=Space.VOXMM,
                                       origin=Origin.NIFTI),
                    streams,
                    bbox_valid_check=False)
    assert isinstance(streamlines, ArraySequence)
Example #10
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        [args.in_hdf5, args.in_target_file, args.in_transfo],
                        args.in_deformation)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite an existing file, so remove it first
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    with h5py.File(args.in_hdf5, 'r') as in_hdf5_file:
        with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
            transfo = load_matrix_in_any_format(args.in_transfo)

            deformation_data = None
            if args.in_deformation is not None:
                deformation_data = np.squeeze(
                    nib.load(args.in_deformation).get_fdata(dtype=np.float32))
            target_img = nib.load(args.in_target_file)

            for key in in_hdf5_file.keys():
                group = out_hdf5_file.create_group(key)
                affine = in_hdf5_file.attrs['affine']
                dimensions = in_hdf5_file.attrs['dimensions']
                voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
                streamlines = reconstruct_streamlines_from_hdf5(
                    in_hdf5_file, key)

                if len(streamlines) == 0:
                    continue
                header = create_nifti_header(affine, dimensions, voxel_sizes)
                moving_sft = StatefulTractogram(streamlines,
                                                header,
                                                Space.VOX,
                                                origin=Origin.TRACKVIS)
                for dps_key in in_hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        # A data_per_streamline entry has one value per
                        # streamline, so its shape must match 'offsets'
                        if in_hdf5_file[key][dps_key].shape \
                                == in_hdf5_file[key]['offsets'].shape:
                            moving_sft.data_per_streamline[dps_key] \
                                = in_hdf5_file[key][dps_key]

                new_sft = transform_warp_streamlines(
                    moving_sft,
                    transfo,
                    target_img,
                    inverse=args.inverse,
                    deformation_data=deformation_data,
                    remove_invalid=not args.cut_invalid,
                    cut_invalid=args.cut_invalid)
                new_sft.to_vox()
                new_sft.to_corner()

                affine, dimensions, voxel_sizes, voxel_order = get_reference_info(
                    target_img)
                out_hdf5_file.attrs['affine'] = affine
                out_hdf5_file.attrs['dimensions'] = dimensions
                out_hdf5_file.attrs['voxel_sizes'] = voxel_sizes
                out_hdf5_file.attrs['voxel_order'] = voxel_order

                group = out_hdf5_file[key]
                group.create_dataset('data',
                                     data=new_sft.streamlines._data.astype(
                                         np.float32))
                group.create_dataset('offsets',
                                     data=new_sft.streamlines._offsets)
                group.create_dataset('lengths',
                                     data=new_sft.streamlines._lengths)
                for dps_key in in_hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        if in_hdf5_file[key][dps_key].shape \
                                == in_hdf5_file[key]['offsets'].shape:
                            group.create_dataset(
                                dps_key,
                                data=new_sft.data_per_streamline[dps_key])
                        else:
                            group.create_dataset(
                                dps_key, data=in_hdf5_file[key][dps_key][()])