Example #1
def main():
    """Parse arguments, generate hdf5 dataset and save it on disk."""
    p = _parse_args()
    args = p.parse_args()

    # Initialize logger
    logging.basicConfig(level=str(args.logging).upper())

    # Silence SFT's logger; it typically produces a lot of output.
    set_sft_logger_level('WARNING')

    # Verify that dwi_ml_ready folder is found
    if not Path(args.dwi_ml_ready_folder).is_dir():
        raise ValueError('The dwi_ml_ready folder was not found: {}'.format(
            args.dwi_ml_ready_folder))
    assert_inputs_exist(
        p, [args.config_file],
        [args.training_subjs, args.validation_subjs, args.testing_subjs])
    # check hdf extension
    _, ext = os.path.splitext(args.out_hdf5_file)
    if ext == '':
        args.out_hdf5_file += '.hdf5'
    elif ext != '.hdf5':
        p.error("The hdf5 file's extension should be .hdf5, but "
                "received {}".format(ext))
    assert_outputs_exist(p, args, args.out_hdf5_file)

    # Prepare creator and load config file.
    creator = prepare_hdf5_creator(args)

    # Create dataset from config and save
    with Timer("\nCreating database...", newline=True, color='green'):
        creator.create_database()
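
The extension check above follows a reusable pattern: append the default extension when none is given, and reject anything else. A minimal standalone sketch (the helper name ensure_hdf5_extension is hypothetical):

import os


def ensure_hdf5_extension(filename):
    # Append '.hdf5' when no extension was given; reject other extensions.
    _, ext = os.path.splitext(filename)
    if ext == '':
        return filename + '.hdf5'
    if ext != '.hdf5':
        raise ValueError("The hdf5 file's extension should be .hdf5, but "
                         "received {}".format(ext))
    return filename


assert ensure_hdf5_extension('dataset') == 'dataset.hdf5'
assert ensure_hdf5_extension('dataset.hdf5') == 'dataset.hdf5'
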
Example #2
def main():

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram, args.save_rejected)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        # Silencing SFT's logger if our logging is in DEBUG mode, because it
        # typically produces a lot of output!
        set_sft_logger_level('WARNING')

    if args.min_x == 0 and np.isinf(args.max_x) and \
       args.min_y == 0 and np.isinf(args.max_y) and \
       args.min_z == 0 and np.isinf(args.max_z):
        logging.warning("You have not specified min or max in any direction. "
                        "Output will simply be a copy of your input!")

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    computed_rejected_sft = args.save_rejected is not None
    new_sft, indices, rejected_sft = \
        filter_streamlines_by_total_length_per_dim(
            sft, [args.min_x, args.max_x], [args.min_y, args.max_y],
            [args.min_z, args.max_z], args.use_abs, computed_rejected_sft)

    if args.display_counts:
        sc_bf = len(sft.streamlines)
        sc_af = len(new_sft.streamlines)
        print(
            json.dumps(
                {
                    'streamline_count_before_filtering': int(sc_bf),
                    'streamline_count_after_filtering': int(sc_af)
                },
                indent=args.indent))

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written "
                          "(0 streamline).".format(args.out_tractogram))
            return
        else:
            logging.debug('The file {} contains 0 streamline'.format(
                args.out_tractogram))

    save_tractogram(new_sft, args.out_tractogram)

    if computed_rejected_sft:
        save_tractogram(rejected_sft, args.save_rejected)
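
filter_streamlines_by_total_length_per_dim comes from scilpy and its implementation is not shown here. A plausible sketch of the per-axis quantity such a filter would compare against the [min, max] bounds (pure numpy; the helper name total_length_per_dim is hypothetical):

import numpy as np


def total_length_per_dim(streamline, use_abs=True):
    # streamline: (N, 3) array of points. With use_abs, sum the absolute
    # per-step displacements; otherwise sum the signed displacements.
    steps = np.diff(streamline, axis=0)
    if use_abs:
        return np.abs(steps).sum(axis=0)
    return steps.sum(axis=0)


# A straight 3-point line along x: total length 2 in x, 0 in y and z.
line = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
print(total_length_per_dim(line))  # [2. 0. 0.]
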
Example #3
    def as_sft(self, streamline_ids=None):
        """
        Returns chosen streamlines in a StatefulTractogram format.

        Parameters
        ----------
        streamline_ids: list[int]
            List of chosen ids. If None, use all streamlines.
        """
        set_sft_logger_level('WARNING')
        streamlines = self._subset_streamlines(streamline_ids)

        sft = StatefulTractogram(streamlines, self.space_attributes,
                                 self.space)
        return sft
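
_subset_streamlines is not shown above. A minimal sketch of the behavior the docstring implies (None selects everything); list-based for simplicity, whereas the real container may differ:

def subset_streamlines(streamlines, streamline_ids=None):
    # Return the chosen streamlines; None means all of them.
    if streamline_ids is None:
        return streamlines
    return [streamlines[i] for i in streamline_ids]


print(subset_streamlines(['s0', 's1', 's2'], [0, 2]))  # ['s0', 's2']
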
Example #4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # h5py will not overwrite an existing file; remove it first
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header.'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDDs.
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
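
The comb_list construction above enumerates every undirected label pair exactly once, then appends the self-connections. A small worked example:

import itertools

labels = [1, 2, 3]
pairs = list(itertools.combinations(labels, r=2))  # (1, 2), (1, 3), (2, 3)
pairs.extend(zip(labels, labels))                  # (1, 1), (2, 2), (3, 3)
print(pairs)  # [(1, 2), (1, 3), (2, 3), (1, 1), (2, 2), (3, 3)]
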
Example #5
def process_streamlines(bundles_dir: Path, bundles, header: nib.Nifti1Header,
                        step_size: float, space: Space):
    """Load and process a group of bundles and merge all streamlines
    together.

    Parameters
    ----------
    bundles_dir : Path
        Path to bundles folder.
    bundles: List[str]
        List of the bundles filenames to load. If none, all bundles will be
        used.
    header : nib.Nifti1Header
        Reference used to load and send the streamlines in voxel space and to
        create final merged SFT. If the file is a .trk, 'same' is used instead.
    step_size: float
        Step size to resample streamlines. If none, compress streamlines.
    space: Space
        Space to place the tractograms.

    Returns
    -------
    final_tractogram : StatefulTractogram
        All streamlines in voxel space.
    output_lengths : List[float]
        The euclidean length of each streamline
    """
    # Silence SFT's logger; it typically produces a lot of output.
    set_sft_logger_level('WARNING')

    # Initialize
    final_tractogram = None
    output_lengths = []

    # If no bundles were given in the config, take everything found in the
    # subject's bundle folder
    if bundles is None:
        bundles = [str(p) for p in bundles_dir.glob('*')]

    for bundle_name in bundles:
        # Find bundle name
        logging.debug('      *Loading bundle {}'.format(bundle_name))

        # Complete the name, e.g. if no extension was given or to allow suffixes
        bundle_name = str(bundles_dir.joinpath(bundle_name + '*'))
        bundle_complete_name = glob.glob(bundle_name)
        if len(bundle_complete_name) == 0:
            logging.debug("      Skipping bundle {} because it was not found "
                          "in this subject's folder".format(bundle_name))
            # Note: if args.enforce_bundles_presence was set to True, this
            # case cannot happen; it was already checked in
            # create_hdf5_dataset.py.
        else:
            bundle_complete_name = bundle_complete_name[0]

            # Check bundle extension
            _, file_extension = os.path.splitext(bundle_complete_name)
            if file_extension not in ['.trk', '.tck']:
                raise ValueError("We do not support bundle's type: {}. We "
                                 "only support .trk and .tck files.".format(
                                     bundle_complete_name))
            if file_extension == '.trk':
                header = 'same'

            # Loading bundle and sending to wanted space
            bundle = load_tractogram(bundle_complete_name, header)
            bundle.to_center()

            # Resample or compress streamlines
            # Note. No matter the chosen space, resampling is done in mm.
            if step_size:
                logging.debug('      *Resampling')
                bundle = resample_streamlines_step_size(bundle, step_size)
                logging.debug(
                    "      *Resampled streamlines' step size to {}mm".format(
                        step_size))
            else:
                logging.debug('      *Compressing')
                bundle = compress_sft(bundle)

            # Compute euclidean lengths (rasmm space)
            bundle.to_space(Space.RASMM)
            output_lengths.extend(length(bundle.streamlines))

            # Sending to wanted space
            bundle.to_space(space)

            # Add processed bundle to output tractogram
            if final_tractogram is None:
                final_tractogram = bundle
            else:
                final_tractogram = concatenate_sft([final_tractogram, bundle],
                                                   erase_metadata=False)

    # Removing invalid streamlines
    logging.debug('      *Total: {:,.0f} streamlines. Now removing invalid '
                  'streamlines.'.format(len(final_tractogram)))
    final_tractogram.remove_invalid_streamlines()
    logging.debug("      *Remaining: {:,.0f} streamlines."
                  "".format(len(final_tractogram)))

    return final_tractogram, output_lengths
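
The wildcard lookup above (bundle_name + '*') both completes a missing extension and tolerates suffixes, keeping only the first match. A standalone sketch of that lookup (find_bundle is a hypothetical name):

import glob
from pathlib import Path


def find_bundle(bundles_dir, bundle_name):
    # Complete the name with a wildcard; return the first match or None.
    pattern = str(Path(bundles_dir).joinpath(bundle_name + '*'))
    matches = glob.glob(pattern)
    return matches[0] if matches else None

Note that when several files match (e.g. AF_left.trk and AF_left_filtered.trk), only the first glob result is used, exactly as in the loop above.
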
Example #6
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    set_sft_logger_level('ERROR')
    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid],
                        optional=args.reference)
    assert_outputs_exist(parser,
                         args,
                         args.out_labels_map,
                         optional=[
                             args.out_labels_npz, args.out_distances_npz,
                             args.labels_color_dpp, args.distances_color_dpp
                         ])

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        raise ValueError('Empty bundle file {}. Skipping.'.format(
            args.in_bundle))

    if len(sft_centroid.streamlines) != 1:
        raise ValueError('Centroid file {} should contain exactly one '
                         'streamline. Skipping.'.format(args.in_centroid))

    if not is_header_compatible(sft_centroid, sft_bundle):
        raise IOError('{} and {} do not have a compatible header.'.format(
            args.in_centroid, args.in_bundle))

    sft_bundle.to_vox()
    sft_bundle.to_corner()

    # Slightly cut the bundle at the edges to clean up single-streamline
    # voxels with no neighbors. Remove isolated voxels to keep a single 'blob'.
    binary_bundle = compute_tract_counts_map(
        sft_bundle.streamlines, sft_bundle.dimensions).astype(bool)

    structure = ndi.generate_binary_structure(3, 1)
    if np.count_nonzero(binary_bundle) > args.min_voxel_count \
            and len(sft_bundle) > args.min_streamline_count:
        binary_bundle = ndi.binary_dilation(binary_bundle,
                                            structure=np.ones((3, 3, 3)))
        binary_bundle = ndi.binary_erosion(binary_bundle,
                                           structure=structure,
                                           iterations=2)

        bundle_disjoint, _ = ndi.label(binary_bundle)
        unique, count = np.unique(bundle_disjoint, return_counts=True)
        val = unique[np.argmax(count[1:]) + 1]
        binary_bundle[bundle_disjoint != val] = 0

        # Chop off some streamlines
        cut_sft = cut_outside_of_mask_streamlines(sft_bundle, binary_bundle)
    else:
        cut_sft = sft_bundle

    if args.nb_pts is not None:
        sft_centroid = resample_streamlines_num_points(sft_centroid,
                                                       args.nb_pts)
    else:
        args.nb_pts = len(sft_centroid.streamlines[0])

    # Generate a centroids labels mask for the centroid alone
    sft_centroid.to_vox()
    sft_centroid.to_corner()
    sft_centroid = _affine_slr(sft_bundle, sft_centroid)

    # Map every streamline's points to the centroid's points
    binary_centroid = compute_tract_counts_map(
        sft_centroid.streamlines, sft_centroid.dimensions).astype(bool)
    # TODO N^2 growth in RAM, should split it if we want to do nb_pts = 100
    min_dist_label, min_dist = min_dist_to_centroid(
        cut_sft.streamlines._data, sft_centroid.streamlines._data)
    min_dist_label += 1  # 0 means no labels

    # For consistency, labels are not allowed to jump:
    # streamlines should have continuous labels
    curr_ind = 0
    final_streamlines = []
    final_label = []
    final_dist = []
    for i, streamline in enumerate(cut_sft.streamlines):
        next_ind = curr_ind + len(streamline)
        curr_labels = min_dist_label[curr_ind:next_ind]
        curr_dist = min_dist[curr_ind:next_ind]
        curr_ind = next_ind

        # Flip streamlines so the labels increase (simplifies the if/else
        # logic below). Should already be ordered in the Nextflow pipeline.
        gradient = np.gradient(curr_labels)
        if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
            streamline = streamline[::-1]
            curr_labels = curr_labels[::-1]
            curr_dist = curr_dist[::-1]

        # Find jumps, cut them and find the longest
        gradient = np.ediff1d(curr_labels)
        max_jump = max(args.nb_pts // 5, 1)
        if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
            pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
            split_chunk = np.split(curr_labels, pos_jump)
            max_len = 0
            max_pos = 0
            for j, chunk in enumerate(split_chunk):
                if len(chunk) > max_len:
                    max_len = len(chunk)
                    max_pos = j

            curr_labels = split_chunk[max_pos]
            # Check the longest chunk (curr_labels), not the loop's last chunk
            gradient_chunk = np.ediff1d(curr_labels)
            if len(np.unique(np.sign(gradient_chunk))) > 1:
                continue
            streamline = np.split(streamline, pos_jump)[max_pos]
            curr_dist = np.split(curr_dist, pos_jump)[max_pos]

        final_streamlines.append(streamline)
        final_label.append(curr_labels)
        final_dist.append(curr_dist)

    # Re-arrange the new cut streamlines and their metadata
    # Compute the voxels equivalent of the labels maps
    new_sft = StatefulTractogram.from_sft(final_streamlines, sft_bundle)

    tdi_mask_nzr = np.nonzero(binary_bundle)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)
    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           sft_centroid.streamlines[0])
    img_labels = np.zeros(binary_centroid.shape, dtype=np.int16)
    img_labels[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value

    nib.save(nib.Nifti1Image(img_labels, sft_bundle.affine),
             args.out_labels_map)

    if args.labels_color_dpp or args.distances_color_dpp \
            or args.out_labels_npz or args.out_distances_npz:
        labels_array = ArraySequence(final_label)
        dist_array = ArraySequence(final_dist)
        # WARNING: WILL NOT WORK WITH THE INPUT TRK !
        # These will fit only with the TRK saved below.
        if args.out_labels_npz:
            np.savez_compressed(args.out_labels_npz, labels_array._data)
        if args.out_distances_npz:
            np.savez_compressed(args.out_distances_npz, dist_array._data)

        cmap = plt.get_cmap(args.colormap)
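        # Build a placeholder dpp entry with the same shape as the
        # streamlines; its _data is overwritten with RGB values below.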
        new_sft.data_per_point['color'] = ArraySequence(new_sft.streamlines)

        # Nicer visualisation for MI-Brain
        if args.labels_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                labels_array._data / np.max(labels_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.labels_color_dpp)

        if args.distances_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                dist_array._data / np.max(dist_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.distances_color_dpp)
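
min_dist_to_centroid is not shown above. Given that it maps each query point to the closest centroid point and its distance, a plausible KD-tree sketch (an assumption, not necessarily scilpy's implementation):

import numpy as np
from scipy.spatial import cKDTree


def min_dist_to_centroid(points, centroid_points):
    # For each point: index of, and distance to, the nearest centroid point.
    tree = cKDTree(centroid_points)
    dist, idx = tree.query(points, k=1)
    return idx, dist


centroid = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
pts = np.array([[0.1, 0., 0.], [1.9, 0., 0.]])
print(min_dist_to_centroid(pts, centroid))  # (array([0, 2]), array([0.1, 0.1]))
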
Example #7
    def _process_one_streamline_group(
            self, subj_dir: Path, group: str, subj_id: str,
            header: nib.Nifti1Header):
        """
        Loads and processes a group of tractograms and merges all streamlines
        together.

        Note. Wildcards will be replaced by the subject id. If the list is
        folder/ALL, all tractograms in the folder will be used.

        Parameters
        ----------
        subj_dir : Path
            Path to tractograms folder.
        group: str
            group name
        subj_id: str
            The subject's id.
        header : nib.Nifti1Header
            Reference used to load and send the streamlines in voxel space and
            to create final merged SFT. If the file is a .trk, 'same' is used
            instead.

        Returns
        -------
        final_tractogram : StatefulTractogram
            All streamlines in voxel space.
        output_lengths : List[float]
            The euclidean length of each streamline
        """
        tractograms = self.groups_config[group]['files']

        if self.step_size and self.compress:
            raise ValueError(
                "Only one option can be chosen: either resampling to "
                "step_size or compressing, not both.")

        # Silence SFT's logger; it typically produces a lot of output.
        set_sft_logger_level('WARNING')

        # Initialize
        final_sft = None
        output_lengths = []

        for instructions in tractograms:
            if instructions.endswith('/ALL'):
                # instructions is to get all tractograms in given folder.
                tractograms_dir = instructions.split('/ALL')
                tractograms_dir = ''.join(tractograms_dir[:-1])
                tractograms_sublist = [
                    instructions.replace('/ALL', '/' + os.path.basename(p))
                    for p in subj_dir.glob(tractograms_dir + '/*')]
            else:
                # instruction is to get one specific tractogram
                tractograms_sublist = [instructions]

            # Either a loop on "ALL" or a loop on only one file.
            for tractogram_name in tractograms_sublist:
                tractogram_name = tractogram_name.replace('*', subj_id)
                tractogram_file = subj_dir.joinpath(tractogram_name)

                sft = self._load_and_process_sft(
                    tractogram_file, tractogram_name, header)

                if sft is not None:
                    # Compute euclidean lengths (rasmm space)
                    sft.to_space(Space.RASMM)
                    output_lengths.extend(length(sft.streamlines))

                    # Sending to wanted space
                    sft.to_space(self.space)

                    # Add processed tractogram to final big tractogram
                    if final_sft is None:
                        final_sft = sft
                    else:
                        final_sft = concatenate_sft([final_sft, sft],
                                                    erase_metadata=False)

        if self.save_intermediate:
            output_fname = self.intermediate_folder.joinpath(
                subj_id + '_' + group + '.trk')
            logging.debug('      *Saving intermediate streamline group {} '
                          'into {}.'.format(group, output_fname))
            # Note: keep the str() below; save_tractogram does not work well
            # with a Path object.
            save_tractogram(final_sft, str(output_fname))

        # Removing invalid streamlines
        logging.debug('      *Total: {:,.0f} streamlines. Now removing '
                      'invalid streamlines.'.format(len(final_sft)))
        final_sft.remove_invalid_streamlines()
        logging.debug("      *Remaining: {:,.0f} streamlines."
                      "".format(len(final_sft)))

        return final_sft, output_lengths
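
The '/ALL' handling above expands a single instruction into one entry per file found in the folder. A standalone sketch of that expansion (expand_instructions is a hypothetical name):

import os
from pathlib import Path


def expand_instructions(subj_dir, instructions):
    # 'folder/ALL' becomes one instruction per file in folder; anything
    # else is returned as-is.
    if not instructions.endswith('/ALL'):
        return [instructions]
    tractograms_dir = instructions[:-len('/ALL')]
    return [instructions.replace('/ALL', '/' + os.path.basename(p))
            for p in Path(subj_dir).glob(tractograms_dir + '/*')]
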
Example #8
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    inputs = [args.in_odf, args.in_seed, args.in_mask]
    assert_inputs_exist(parser, inputs)
    assert_outputs_exist(parser, args, args.out_tractogram)

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    theta = math.radians(get_theta(args.theta, args.algo))

    max_nbr_pts = int(args.max_length / args.step_size)
    min_nbr_pts = int(args.min_length / args.step_size) + 1
    max_invalid_dirs = int(math.ceil(args.max_invalid_length / args.step_size))

    logging.debug("Loading seeding mask.")
    seed_img = nib.load(args.in_seed)
    seed_data = seed_img.get_fdata(caching='unchanged', dtype=float)
    seed_res = seed_img.header.get_zooms()[:3]
    seed_generator = SeedGenerator(seed_data, seed_res)
    if args.npv:
        # TODO: this will not truly produce n seeds per voxel; it only
        #  holds on average.
        nbr_seeds = len(seed_generator.seeds) * args.npv
    elif args.nt:
        nbr_seeds = args.nt
    else:
        # Setting npv = 1.
        nbr_seeds = len(seed_generator.seeds)
    if len(seed_generator.seeds) == 0:
        parser.error(
            'Seed mask "{}" does not have any voxel with value > 0.'.format(
                args.in_seed))

    logging.debug("Loading tracking mask.")
    mask_img = nib.load(args.in_mask)
    mask_data = mask_img.get_fdata(caching='unchanged', dtype=float)
    mask_res = mask_img.header.get_zooms()[:3]
    mask = DataVolume(mask_data, mask_res, args.mask_interp)

    logging.debug("Loading ODF SH data.")
    odf_sh_img = nib.load(args.in_odf)
    odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float)
    odf_sh_res = odf_sh_img.header.get_zooms()[:3]
    dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp)

    logging.debug("Instantiating propagator.")
    propagator = ODFPropagator(dataset, args.step_size, args.rk_order,
                               args.algo, args.sh_basis, args.sf_threshold,
                               args.sf_threshold_init, theta, args.sphere)

    logging.debug("Instantiating tracker.")
    tracker = Tracker(propagator,
                      mask,
                      seed_generator,
                      nbr_seeds,
                      min_nbr_pts,
                      max_nbr_pts,
                      max_invalid_dirs,
                      compression_th=args.compress,
                      nbr_processes=args.nbr_processes,
                      save_seeds=args.save_seeds,
                      mmap_mode='r+',
                      rng_seed=args.rng_seed,
                      track_forward_only=args.forward_only,
                      skip=args.skip)

    start = time.time()
    logging.debug("Tracking...")
    streamlines, seeds = tracker.track()

    str_time = "%.2f" % (time.time() - start)
    logging.debug("Tracked {} streamlines (out of {} seeds), in {} seconds.\n"
                  "Now saving...".format(len(streamlines), nbr_seeds,
                                         str_time))

    # save seeds if args.save_seeds is given
    data_per_streamline = {'seeds': seeds} if args.save_seeds else {}

    # Silence SFT's logger; it typically produces a lot of output.
    set_sft_logger_level('WARNING')

    # Compared with scil_compute_local_tracking, using sft rather than
    # LazyTractogram to deal with space.
    # Contrary to scilpy or dipy, where space after tracking is vox, here
    # space after tracking is voxmm.
    # Smallest possible streamline coordinate is (0,0,0), equivalent of
    # corner origin (TrackVis)
    sft = StatefulTractogram(streamlines,
                             mask_img,
                             Space.VOXMM,
                             Origin.TRACKVIS,
                             data_per_streamline=data_per_streamline)
    save_tractogram(sft, args.out_tractogram)
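
The point-count bounds above convert length limits given in mm into numbers of points, for a fixed step size. Worked with illustrative values:

import math

step_size = 0.5                     # mm
min_length, max_length = 20, 200    # mm
max_invalid_length = 1              # mm

max_nbr_pts = int(max_length / step_size)                           # 400
min_nbr_pts = int(min_length / step_size) + 1                       # 41
max_invalid_dirs = int(math.ceil(max_invalid_length / step_size))   # 2
print(min_nbr_pts, max_nbr_pts, max_invalid_dirs)
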
Example #9
def main():
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)
    data_labels = img_labels.get_fdata().astype(np.int16)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    data_mask = np.zeros(data_labels.shape)
    data_mask[data_labels > 0] = 1

    original_len = len(sft)
    time1 = time.time()

    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        '    Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info('    Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    for in_label, out_label in comb_list:
        if iteration_counter > 0 and iteration_counter % 100 == 0:
            logging.info('Split {} nodes out of {}'.format(iteration_counter,
                                                           len(comb_list)))
        iteration_counter += 1

        pair_info = []
        if in_label not in con_info:
            continue
        elif out_label in con_info[in_label]:
            pair_info.extend(con_info[in_label][out_label])

        if out_label not in con_info:
            continue
        elif in_label in con_info[out_label]:
            pair_info.extend(con_info[out_label][in_label])

        if not len(pair_info):
            continue

        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            curr_streamlines = compute_streamline_segment(
                sft.streamlines[strl_idx],
                indices[strl_idx],
                connection['in_idx'],
                connection['out_idx'],
                points_to_idx[strl_idx])
            connecting_streamlines.append(curr_streamlines)

        _save_if_needed(connecting_streamlines, args, 'raw',
                        'raw', in_label, out_label)

        # Doing all post-processing
        if not args.no_pruning:
            valid_length, invalid_length = _prune_segments(
                connecting_streamlines,
                args.min_length,
                args.max_length,
                vox_sizes[0])

            _save_if_needed(invalid_length, args,
                            'discarded', 'invalid_length',
                            in_label, out_label)
        else:
            valid_length = connecting_streamlines

        if not len(valid_length):
            continue

        _save_if_needed(valid_length, args,
                        'intermediate', 'valid_length', in_label, out_label)

        if not args.no_remove_loops:
            no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                       args.loop_max_angle)
            no_loops = [valid_length[i] for i in no_loop_ids]

            loop_ids = np.setdiff1d(np.arange(len(valid_length)), no_loop_ids)
            loops = [valid_length[i] for i in loop_ids]

            _save_if_needed(loops, args,
                            'discarded', 'loops', in_label, out_label)
        else:
            no_loops = valid_length

        if not len(no_loops):
            continue

        _save_if_needed(no_loops, args,
                        'intermediate', 'no_loops', in_label, out_label)

        if not args.no_remove_outliers:
            inliers, outliers = remove_outliers(no_loops,
                                                args.outlier_threshold)
            _save_if_needed(outliers, args,
                            'discarded', 'outliers', in_label, out_label)
        else:
            inliers = no_loops

        if not len(inliers):
            continue

        _save_if_needed(inliers, args,
                        'intermediate', 'inliers', in_label, out_label)

        if not args.no_remove_curv_dev:
            no_qb_curv_ids = remove_loops_and_sharp_turns(
                inliers,
                args.loop_max_angle,
                use_qb=True,
                qb_threshold=args.curv_qb_distance)
            no_qb_curv = [inliers[i] for i in no_qb_curv_ids]

            qb_curv_ids = np.setdiff1d(
                np.arange(len(inliers)), no_qb_curv_ids)
            qb_curv = [inliers[i] for i in qb_curv_ids]

            _save_if_needed(qb_curv, args,
                            'discarded', 'qb_curv', in_label, out_label)
        else:
            no_qb_curv = inliers

        _save_if_needed(no_qb_curv, args,
                        'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
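
_prune_segments is not shown above. Since it receives vox_sizes[0] and labels are required to be isotropic, it presumably measures streamline lengths in mm from voxel-space coordinates. A minimal sketch under that assumption (prune_segments is a hypothetical name):

import numpy as np


def prune_segments(segments, min_length, max_length, vox_size):
    # Split segments into (valid, invalid) by Euclidean length in mm,
    # assuming coordinates on an isotropic voxel grid of size vox_size.
    valid, invalid = [], []
    for strl in segments:
        strl = np.asarray(strl)
        length_mm = vox_size * np.linalg.norm(
            np.diff(strl, axis=0), axis=1).sum()
        if min_length <= length_mm <= max_length:
            valid.append(strl)
        else:
            invalid.append(strl)
    return valid, invalid
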
Example #10
def main():
    parser = build_argparser()
    args = parser.parse_args()

    logging.basicConfig(level=args.logging.upper())

    # ----- Checks
    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or '
                     'tck): {0}'.format(args.out_tractogram))

    assert_inputs_exist(parser, args.hdf5_file)
    assert_outputs_exist(parser, args, args.out_tractogram)

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    # ----- Prepare values

    max_nbr_pts = int(args.max_length / args.step_size)
    min_nbr_pts = int(args.min_length / args.step_size) + 1
    max_invalid_dirs = int(math.ceil(args.max_invalid_len / args.step_size))

    # 'r+' is necessary for the cython interpolation function, which needs
    # read/write access.
    mmap_mode = None if args.set_mmap_to_none else 'r+'

    device = torch.device('cpu')
    if args.use_gpu:
        if args.nbr_processes > 1:
            logging.warning("Number of processes was set to {} but you "
                            "are using GPU. Parameter ignored.".format(
                                args.nbr_processes))
        if torch.cuda.is_available():
            device = torch.device('cuda')

    hdf_handle = h5py.File(args.hdf5_file, 'r')

    tracker, ref = prepare_tracker(parser, args, hdf_handle, device,
                                   min_nbr_pts, max_nbr_pts, max_invalid_dirs,
                                   mmap_mode)

    # ----- Track

    with Timer("\nTracking...", newline=True, color='blue'):
        streamlines, seeds = tracker.track()

        logging.debug(
            "Tracked {} streamlines (out of {} seeds). Now saving...".format(
                len(streamlines), tracker.nbr_seeds))

    # save seeds if args.save_seeds is given
    data_per_streamline = {'seed': seeds} if args.save_seeds else {}

    # Silence SFT's logger; it typically produces a lot of output.
    set_sft_logger_level('WARNING')

    sft = StatefulTractogram(streamlines,
                             ref,
                             Space.VOXMM,
                             data_per_streamline=data_per_streamline)
    save_tractogram(sft, args.out_tractogram, bbox_valid_check=False)
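
The device selection above follows the usual PyTorch idiom: request the GPU only when it is actually available, and fall back to the CPU otherwise. Condensed to one line:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
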