Example #1
    def streamlines(self):
        """ Return the streamlines as a nibabel ArraySequence rebuilt from
        the flat position (_zpos) and offset (_zoff) arrays. """
        streamlines = ArraySequence()
        streamlines._data = np.array(self._zpos)
        streamlines._offsets = np.array(self._zoff)
        streamlines._lengths = compute_lengths(streamlines._offsets,
                                               self.nb_points)
        return streamlines
    def __init__(self, nb_vertices=None, nb_streamlines=None, init_as=None,
                 reference=None):
        """ Initialize an empty TrxFile, support preallocation """
        if init_as is not None:
            affine = init_as.header['VOXEL_TO_RASMM']
            dimensions = init_as.header['DIMENSIONS']
        elif reference is not None:
            affine, dimensions, _, _ = get_reference_info(reference)
        else:
            logging.debug('No reference provided, using blank space '
                          'attributes, please update them later.')
            affine = np.eye(4).astype(np.float32)
            dimensions = np.array([1, 1, 1], dtype=np.uint16)

        if nb_vertices is None and nb_streamlines is None:
            if init_as is not None:
                raise ValueError('Can\'t use init_as without declaring '
                                 'nb_vertices AND nb_streamlines')
            logging.debug('Initializing empty TrxFile.')
            self.header = {}
            # Using the new format default type
            tmp_strs = ArraySequence()
            tmp_strs._data = tmp_strs._data.astype(np.float16)
            tmp_strs._offsets = tmp_strs._offsets.astype(np.uint64)
            tmp_strs._lengths = tmp_strs._lengths.astype(np.uint32)
            self.streamlines = tmp_strs
            self.groups = {}
            self.data_per_streamline = {}
            self.data_per_vertex = {}
            self.data_per_group = {}
            self._uncompressed_folder_handle = None

            nb_vertices = 0
            nb_streamlines = 0

        elif nb_vertices is not None and nb_streamlines is not None:
            logging.debug('Preallocating TrxFile with size {} streamlines '
                          'and {} vertices.'.format(nb_streamlines, nb_vertices))
            trx = self._initialize_empty_trx(nb_streamlines, nb_vertices,
                                             init_as=init_as)
            self.__dict__ = trx.__dict__
        else:
            raise ValueError('You must declare both nb_vertices AND '
                             'nb_streamlines')

        self.header['VOXEL_TO_RASMM'] = affine
        self.header['DIMENSIONS'] = dimensions
        self.header['NB_VERTICES'] = nb_vertices
        self.header['NB_STREAMLINES'] = nb_streamlines
        self._copy_safe = True
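
A minimal usage sketch of the constructor above (assuming the full TrxFile implementation, including its _initialize_empty_trx helper, is importable; 'anat.nii.gz' is a hypothetical reference image):

import numpy as np

# Empty file with blank space attributes (header can be updated later).
empty_trx = TrxFile()

# Preallocated file whose affine/dimensions come from a reference image.
prealloc_trx = TrxFile(nb_vertices=10000, nb_streamlines=500,
                       reference='anat.nii.gz')
assert prealloc_trx.header['NB_STREAMLINES'] == 500
assert prealloc_trx.header['VOXEL_TO_RASMM'].shape == (4, 4)
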
Example #3
    def select(self, indices, keep_group=True):
        """ Get a subset of items, always points to the same memmaps """
        indices = np.array(indices, np.uint32)
        if len(indices) and (np.max(indices) > self.nb_streamlines - 1
                             or np.min(indices) < 0):
            raise ValueError('Invalid indices.')

        new_trx = TrxFile(init_as=self)
        if len(indices) == 0:
            new_trx.prune_metadata()
            return new_trx

        tmp_streamlines = self.streamlines
        new_trx._zpos.append(tmp_streamlines[indices].get_data())
        new_offsets = np.cumsum(tmp_streamlines[indices]._lengths[:-1])
        new_trx._zoff.append(np.concatenate(([0], new_offsets)))

        for dpp_key in self._zdpp.array_keys():
            tmp_dpp = ArraySequence()
            tmp_dpp._data = np.array(self._zdpp[dpp_key])
            tmp_dpp._offsets = tmp_streamlines._offsets
            tmp_dpp._lengths = tmp_streamlines._lengths

            new_trx._zdpp[dpp_key].append(tmp_dpp[indices].get_data())

        for dps_key in self._zdps.array_keys():
            new_trx._zdps[dps_key].append(
                np.array(self._zdps[dps_key])[indices])

        if keep_group:
            for grp_key in self._zgrp.array_keys():
                new_group = intersect_groups(self._zgrp[grp_key], indices)

                if len(new_group):
                    new_trx._zgrp[grp_key].append(new_group)
                else:
                    del new_trx._zcontainer['groups'][grp_key]

        for grp_key in self._zdpg.group_keys():
            if grp_key in new_trx._zgrp:
                for dpg_key in self._zdpg[grp_key].array_keys():
                    new_trx._zdpg[grp_key][dpg_key].append(
                        self._zdpg[grp_key][dpg_key])

        new_trx.nb_streamlines = len(new_trx._zoff)
        new_trx.nb_points = len(new_trx._zpos)
        new_trx.prune_metadata()

        return new_trx
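
A hedged usage sketch for select() (the 'trx' variable is assumed to already hold a populated TrxFile):

import numpy as np

# Keep only the first two streamlines; groups intersecting the selection are
# preserved because keep_group=True.
subset = trx.select(np.array([0, 1], dtype=np.uint32), keep_group=True)

# An empty selection returns an empty, pruned copy.
empty = trx.select([])
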
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_transfo])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        log_level = logging.INFO
        logging.basicConfig(level=log_level)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    transfo = load_matrix_in_any_format(args.in_transfo)
    if args.inverse:
        transfo = np.linalg.inv(transfo)

    before, after = compute_distance_barycenters(wb_file, model_file, transfo)
    if after > before:
        logging.warning('The distance between volume barycenters should be '
                        'lower after registration. Maybe try using/removing '
                        '--inverse.')
        logging.info('Distance before: {}, Distance after: {}'.format(
            np.round(before, 3), np.round(after, 3)))
    model_streamlines = transform_streamlines(model_file.streamlines, transfo)

    rng = np.random.RandomState(args.seed)
    if args.in_pickle:
        with open(args.in_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.tractogram_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.out_pickle:
        with open(args.out_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(ArraySequence(model_streamlines),
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file.space_attributes,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.out_tractogram)
Example #5
def reconstruct_streamlines(data, offsets, lengths, indices=None):
    """
    Function to reconstruct streamlines from their data, offsets and lengths
    (from the nibabel tractogram object).

    Parameters
    ----------
    data : np.ndarray
        Nx3 array representing all points of the streamlines.
    offsets : np.ndarray
        Nx1 array representing the cumsum of length array.
    lengths : np.ndarray
        Nx1 array representing the length of each streamline.
    indices : list
        List of int representing the indices to reconstruct.

    Returns
    -------
    streamlines : ArraySequence
        Reconstructed streamlines.
    """

    if data.ndim == 2:
        data = np.array(data).flatten()

    if indices is None:
        indices = np.arange(len(offsets))

    streamlines = []
    for i in indices:
        streamline = data[offsets[i] * 3:offsets[i] * 3 + lengths[i] * 3]
        streamlines.append(streamline.reshape((lengths[i], 3)))

    return ArraySequence(streamlines)
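
A self-contained sketch of the data/offsets/lengths layout the function above expects (values are made up; reconstruct_streamlines is assumed to be in scope):

import numpy as np

lengths = np.array([2, 3], dtype=np.uint32)                # two streamlines
offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))   # [0, 2]
data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3)    # 5 points (x, y, z)

streamlines = reconstruct_streamlines(data, offsets, lengths)
assert len(streamlines) == 2
assert streamlines[1].shape == (3, 3)
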
Example #6
def main():
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.transformation])
    assert_outputs_exist(parser, args, args.output_name)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    # Default transformation source is expected to be ANTs
    transfo = np.loadtxt(args.transformation)
    if args.inverse:
        transfo = np.linalg.inv(transfo)

    model_streamlines = ArraySequence(
        transform_streamlines(model_file.streamlines, transfo))

    rng = np.random.RandomState(args.seed)
    if args.input_pickle:
        with open(args.input_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.wb_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.output_pickle:
        with open(args.output_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(model_streamlines,
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.output_name)
def register_tractogram(moving_tractogram, static_tractogram, only_rigid,
                        amount_to_load, matrix_filename, verbose):

    amount_to_load = max(250000, amount_to_load)

    moving_streamlines = next(
        ichunk(moving_tractogram.streamlines, amount_to_load))

    static_streamlines = next(
        ichunk(static_tractogram.streamlines, amount_to_load))

    if only_rigid:
        transformation_type = 'rigid'
    else:
        transformation_type = 'affine'

    ret = whole_brain_slr(ArraySequence(static_streamlines),
                          ArraySequence(moving_streamlines),
                          x0=transformation_type,
                          maxiter=150,
                          verbose=verbose)
    _, transfo, _, _ = ret
    np.savetxt(matrix_filename, transfo)
Example #8
    def consolidate_data_per_point(self):
        """ Convert the zarr representation of data_per_point to
        memory PerArraySequenceDict (nibabel)"""
        dpp_arr_seq_dict = PerArraySequenceDict()
        for dpp_key in self._zdpp.array_keys():
            arr_seq = ArraySequence()
            arr_seq._data = self._zdpp[dpp_key]
            arr_seq._offsets = self._zoff
            arr_seq._lengths = compute_lengths(arr_seq._offsets,
                                               self.nb_points)
            if arr_seq._data.ndim == 1:
                arr_seq._data = np.expand_dims(arr_seq._data, axis=-1)
            dpp_arr_seq_dict[dpp_key] = arr_seq

        return dpp_arr_seq_dict
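
The same private-attribute plumbing can be shown with a small, self-contained ArraySequence (a sketch with made-up values, independent of any TrxFile instance):

import numpy as np
from nibabel.streamlines.array_sequence import ArraySequence

arr_seq = ArraySequence()
arr_seq._data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3)
arr_seq._offsets = np.array([0, 2], dtype=np.uint64)
arr_seq._lengths = np.array([2, 3], dtype=np.uint32)
# Iteration slices _data using _offsets/_lengths.
assert [item.shape for item in arr_seq] == [(2, 3), (3, 3)]
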
Example #9
def transform_warp_streamlines(sft,
                               linear_transfo,
                               target,
                               inverse=False,
                               deformation_data=None,
                               remove_invalid=True,
                               cut_invalid=False):
    # TODO rename transform_warp_sft
    """ Transform tractogram using a affine Subsequently apply a warp from
    antsRegistration (optional).
    Remove/Cut invalid streamlines to preserve sft validity.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel.

    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them. Keep the longest
        segment only.

    Returns
    -------
    new_sft : StatefulTractogram

    """
    sft.to_rasmm()
    sft.to_center()
    if inverse:
        linear_transfo = np.linalg.inv(linear_transfo)

    streamlines = transform_streamlines(sft.streamlines, linear_transfo)

    if deformation_data is not None:
        affine, _, _, _ = get_reference_info(target)

        # Because of duplication, an iteration over chunks of points is
        # necessary for a big dataset (especially if not compressed)
        streamlines = ArraySequence(streamlines)
        nb_points = len(streamlines._data)
        cur_position = 0
        chunk_size = 1000000
        nb_iteration = int(np.ceil(nb_points / chunk_size))
        inv_affine = np.linalg.inv(affine)

        while nb_iteration > 0:
            max_position = min(cur_position + chunk_size, nb_points)
            points = streamlines._data[cur_position:max_position]

            # To access the deformation information, we need to go in VOX space
            # No need for corner shift since we are doing interpolation
            cur_points_vox = np.array(transform_streamlines(
                points, inv_affine)).T

            x_def = map_coordinates(deformation_data[..., 0],
                                    cur_points_vox.tolist(),
                                    order=1)
            y_def = map_coordinates(deformation_data[..., 1],
                                    cur_points_vox.tolist(),
                                    order=1)
            z_def = map_coordinates(deformation_data[..., 2],
                                    cur_points_vox.tolist(),
                                    order=1)

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
            final_points = np.array([-1 * x_def, -1 * y_def, z_def])
            final_points += np.array(points).T

            streamlines._data[cur_position:max_position] = final_points.T
            cur_position = max_position
            nb_iteration -= 1

    new_sft = StatefulTractogram(streamlines,
                                 target,
                                 Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    return new_sft
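
A hedged usage sketch of the function above, applying only the linear part (no deformation field); 'sft' is assumed to be an existing StatefulTractogram and 'target.nii.gz' a hypothetical reference image:

import numpy as np

linear_transfo = np.eye(4)          # identity affine, for illustration only
new_sft = transform_warp_streamlines(sft, linear_transfo, 'target.nii.gz',
                                     inverse=False, deformation_data=None,
                                     remove_invalid=True)
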
Example #10
def concatenate_sft(sft_list, erase_metadata=False, metadata_fake_init=False):
    """ Concatenate a list of StatefulTractogram together """
    if erase_metadata:
        sft_list[0].data_per_point = {}
        sft_list[0].data_per_streamline = {}

    for sft in sft_list[1:]:
        if erase_metadata:
            sft.data_per_point = {}
            sft.data_per_streamline = {}
        elif metadata_fake_init:
            for dps_key in list(sft.data_per_streamline.keys()):
                if dps_key not in sft_list[0].data_per_streamline.keys():
                    del sft.data_per_streamline[dps_key]
            for dpp_key in list(sft.data_per_point.keys()):
                if dpp_key not in sft_list[0].data_per_point.keys():
                    del sft.data_per_point[dpp_key]

            for dps_key in sft_list[0].data_per_streamline.keys():
                if dps_key not in sft.data_per_streamline:
                    arr_shape = list(
                        sft_list[0].data_per_streamline[dps_key].shape)
                    arr_shape[0] = len(sft)
                    sft.data_per_streamline[dps_key] = np.zeros(arr_shape)
            for dpp_key in sft_list[0].data_per_point.keys():
                if dpp_key not in sft.data_per_point:
                    arr_seq = ArraySequence()
                    arr_seq_shape = list(
                        sft_list[0].data_per_point[dpp_key]._data.shape)
                    arr_seq_shape[0] = len(sft.streamlines._data)
                    arr_seq._data = np.zeros(arr_seq_shape)
                    arr_seq._offsets = sft.streamlines._offsets
                    arr_seq._lengths = sft.streamlines._lengths
                    sft.data_per_point[dpp_key] = arr_seq

        if not metadata_fake_init and \
                not StatefulTractogram.are_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes and '
                             'data_per_point/streamlines.')
        elif not is_header_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes.')

    total_streamlines = 0
    total_points = 0
    lengths = []
    for sft in sft_list:
        total_streamlines += len(sft.streamlines._offsets)
        total_points += len(sft.streamlines._data)
        lengths.extend(sft.streamlines._lengths)
    lengths = np.array(lengths, dtype=np.uint32)
    offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))).astype(np.uint64)

    dpp = {}
    for dpp_key in sft_list[0].data_per_point.keys():
        arr_seq_shape = list(sft_list[0].data_per_point[dpp_key]._data.shape)
        arr_seq_shape[0] = total_points
        dpp[dpp_key] = ArraySequence()
        dpp[dpp_key]._data = np.zeros(arr_seq_shape)
        dpp[dpp_key]._lengths = lengths
        dpp[dpp_key]._offsets = offsets

    dps = {}
    for dps_key in sft_list[0].data_per_streamline.keys():
        arr_seq_shape = list(sft_list[0].data_per_streamline[dps_key].shape)
        arr_seq_shape[0] = total_streamlines
        dps[dps_key] = np.zeros(arr_seq_shape)

    streamlines = ArraySequence()
    streamlines._data = np.zeros((total_points, 3))
    streamlines._lengths = lengths
    streamlines._offsets = offsets

    pts_counter = 0
    strs_counter = 0
    for sft in sft_list:
        pts_curr_len = len(sft.streamlines._data)
        strs_curr_len = len(sft.streamlines._offsets)

        if strs_curr_len == 0 or pts_curr_len == 0:
            continue

        streamlines._data[pts_counter:pts_counter+pts_curr_len] = \
            sft.streamlines._data

        for dpp_key in sft_list[0].data_per_point.keys():
            dpp[dpp_key]._data[pts_counter:pts_counter+pts_curr_len] = \
                sft.data_per_point[dpp_key]._data
        for dps_key in sft_list[0].data_per_streamline.keys():
            dps[dps_key][strs_counter:strs_counter+strs_curr_len] = \
                sft.data_per_streamline[dps_key]
        pts_counter += pts_curr_len
        strs_counter += strs_curr_len

    fused_sft = StatefulTractogram.from_sft(streamlines,
                                            sft_list[0],
                                            data_per_point=dpp,
                                            data_per_streamline=dps)
    return fused_sft
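
A minimal sketch of merging two compatible tractograms with the function above ('sft_a' and 'sft_b' are assumed to be StatefulTractogram objects sharing the same space attributes):

fused = concatenate_sft([sft_a, sft_b], erase_metadata=True)
assert len(fused) == len(sft_a) + len(sft_b)
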
Example #11
def find_identical_streamlines(streamlines_list,
                               epsilon=0.001,
                               union_mode=False,
                               difference_mode=False):
    """ Return the intersection/union/difference from a list of list of
    streamlines. Allows for a maximum distance for matching.

    Parameters:
    -----------
    streamlines_list: list
        List of lists of streamlines or list of ArraySequences
    epsilon: float
        Maximum allowed distance (should not go above 1.0)
    union_mode: bool
        Perform the union of streamlines
    difference_mode
        Perform the difference of streamlines (from the first element)
    Returns:
    --------
    Tuple, ArraySequence, np.ndarray
        Returns the concatenated streamlines and the indices to pick from it
    """
    streamlines = ArraySequence(itertools.chain(*streamlines_list))
    nb_streamlines = np.cumsum([len(sft) for sft in streamlines_list])
    nb_streamlines = np.insert(nb_streamlines, 0, 0)

    if union_mode and difference_mode:
        raise ValueError('Cannot use union_mode and difference_mode at the '
                         'same time.')

    all_tree = {}
    all_tree_mapping = {}
    first_points = np.array(streamlines.get_data()[streamlines._offsets])
    # Use the number of points to speed up the search in the cKDTree
    for point_count in np.unique(streamlines._lengths):
        same_length_ind = np.where(streamlines._lengths == point_count)[0]
        all_tree[point_count] = cKDTree(first_points[same_length_ind])
        all_tree_mapping[point_count] = same_length_ind

    inversion_val = 1 if union_mode or difference_mode else 0
    streamlines_to_keep = np.ones((len(streamlines), )) * inversion_val
    average_match_distance = []

    # Difference by design will never select streamlines that are not from the
    # first set
    if difference_mode:
        streamlines_to_keep[nb_streamlines[1]:] = 0
    for i, streamline in enumerate(streamlines):
        # Unless doing a union, there is no point in looking past the first set
        if not union_mode and i >= nb_streamlines[1]:
            break

        # Find the closest (first) points
        distance_ind = all_tree[len(streamline)].query_ball_point(
            streamline[0], r=2 * epsilon)
        actual_ind = np.sort(all_tree_mapping[len(streamline)][distance_ind])

        # Intersection requires finding matches in all sets
        if not union_mode or not difference_mode:
            intersect_test = np.zeros((len(nb_streamlines) - 1, ))

        for j in actual_ind:
            # Actual check of the whole streamline
            sub_vector = streamline - streamlines[j]
            norm = np.linalg.norm(sub_vector, axis=1)

            if union_mode:
                # 1) Yourself is not a match
                # 2) If the streamline hasn't been selected (by another match)
                # 3) The streamline is 'identical'
                if i != j and streamlines_to_keep[i] == inversion_val \
                        and (norm < 2*epsilon).all():
                    streamlines_to_keep[j] = not inversion_val
                    average_match_distance.append(
                        np.average(sub_vector, axis=0))
            elif difference_mode:
                # 1) Yourself is not a match
                # 2) The streamline is 'identical'
                if i != j and (norm < 2 * epsilon).all():
                    pos_in_list_j = np.max(np.where(nb_streamlines <= j)[0])

                    # If it is an identical streamline, but from the same set
                    # it needs to be removed, otherwise remove all instances
                    if pos_in_list_j == 0:
                        # If it is the first 'encounter' add it
                        if streamlines_to_keep[actual_ind].all():
                            streamlines_to_keep[j] = not inversion_val
                            average_match_distance.append(
                                np.average(sub_vector, axis=0))
                    else:
                        streamlines_to_keep[actual_ind] = not inversion_val
                        average_match_distance.append(
                            np.average(sub_vector, axis=0))
            else:
                # 1) The streamline is 'identical'
                if (norm < 2 * epsilon).all():
                    pos_in_list_i = np.max(np.where(nb_streamlines <= i)[0])
                    pos_in_list_j = np.max(np.where(nb_streamlines <= j)[0])
                    # If it is an identical streamline, but from the same set
                    # it needs to be removed
                    if i == j or pos_in_list_i != pos_in_list_j:
                        intersect_test[pos_in_list_j] = True
                    if i != j:
                        average_match_distance.append(
                            np.average(sub_vector, axis=0))

        # Verify that you actually found a match in each set
        if (not union_mode or not difference_mode) and intersect_test.all():
            streamlines_to_keep[i] = not inversion_val

    # To facilitate debugging and discovering shifts in data
    if average_match_distance:
        logging.info('Average matches distance: {}mm'.format(
            np.round(np.average(average_match_distance, axis=0), 5)))
    else:
        logging.info('No matches found.')

    return streamlines, np.where(streamlines_to_keep > 0)[0].astype(np.uint32)
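
A self-contained sketch of the intersection mode (neither union_mode nor difference_mode): two sets sharing one identical streamline; find_identical_streamlines and its imports are assumed to be in scope:

import numpy as np

s1 = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]], dtype=np.float32)
s2 = np.array([[0, 5, 0], [1, 5, 0], [2, 5, 0]], dtype=np.float32)

all_streamlines, kept = find_identical_streamlines([[s1, s2], [s1]],
                                                   epsilon=0.001)
# Only the streamline present in both sets is kept (index 0 of the
# concatenated sequence).
assert list(kept) == [0]
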
Example #12
File: track.py Project: dPys/PyNets
def run_tracking(step_curv_combinations,
                 recon_shelved,
                 n_seeds_per_iter,
                 traversal,
                 maxcrossing,
                 max_length,
                 pft_back_tracking_dist,
                 pft_front_tracking_dist,
                 particle_count,
                 roi_neighborhood_tol,
                 min_length,
                 track_type,
                 min_separation_angle,
                 sphere,
                 tiss_class,
                 tissue_shelved,
                 verbose=False):
    """
    Perform tractography for a single step-size/curvature-threshold
    combination.

    Parameters
    ----------
    step_curv_combinations : list
        List of tuples representing all pair combinations of step sizes and
        curvature thresholds from which to sample streamlines.
    recon_shelved : obj
        Joblib-shelved diffusion reconstruction model data.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    min_separation_angle : float
        The minimum angle between directions [0, 90].
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    tiss_class : str
        Tissue classification method.
    tissue_shelved : obj
        Joblib-shelved 4D T1w tissue segmentations in native
        diffusion space.

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.
    """
    import gc
    import time
    import numpy as np
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img, math_img
    from pynets.dmri.utils import generate_seeds, random_seeds_from_mask
    from nibabel.streamlines.array_sequence import ArraySequence

    start_time = time.time()

    if verbose is True:
        print("%s%s%s" % ('Preparing tissue constraints:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()

    tissue_img = tissue_shelved.get()

    # Order:
    B0_mask = index_img(tissue_img, 0)
    atlas_img = index_img(tissue_img, 1)
    t1w2dwi = index_img(tissue_img, 3)
    gm_in_dwi = index_img(tissue_img, 4)
    vent_csf_in_dwi = index_img(tissue_img, 5)
    wm_in_dwi = index_img(tissue_img, 6)
    tissue_img.uncache()

    tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                   wm_in_dwi, tiss_class, B0_mask)

    # if verbose is True:
    #     print("%s%s%s" % (
    #     'Fitting tissue classifier:',
    #     np.round(time.time() - start_time, 1), 's'))
    #     start_time = time.time()

    if verbose is True:
        print("%s%s%s" % ('Loading reconstruction:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()

        print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

    # Instantiate DirectionGetter
    if traversal.lower() in ["probabilistic", "prob"]:
        dg = ProbabilisticDirectionGetter.from_shcoeff(
            recon_shelved.get(),
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif traversal.lower() in ["closestpeaks", "cp"]:
        dg = ClosestPeakDirectionGetter.from_shcoeff(
            recon_shelved.get(),
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif traversal.lower() in ["deterministic", "det"]:
        maxcrossing = 1
        dg = DeterministicMaximumDirectionGetter.from_shcoeff(
            recon_shelved.get(),
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    else:
        raise ValueError("ERROR: No valid direction getter(s) specified.")

    if verbose is True:
        print("%s%s%s" % ('Extracting directions:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()
        print("%s%s" % ("Step: ", step_curv_combinations[0]))

    # Perform wm-gm interface seeding, using n_seeds at a time
    seeds = generate_seeds(
        random_seeds_from_mask(np.asarray(
            math_img("img > 0.01", img=index_img(
                tissue_img, 2)).dataobj).astype("bool").astype("int16") > 0,
                               seeds_count=n_seeds_per_iter,
                               random_seed=42))

    if verbose is True:
        print("%s%s%s" % ('Drawing random seeds:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()
        # print(seeds)

    # Perform tracking
    if track_type == "local":
        streamline_generator = LocalTracking(dg,
                                             tiss_classifier,
                                             np.stack([i for i in seeds]),
                                             np.eye(4),
                                             max_cross=int(maxcrossing),
                                             maxlen=int(max_length),
                                             step_size=float(
                                                 step_curv_combinations[0]),
                                             fixedstep=False,
                                             return_all=True,
                                             random_seed=42)
    elif track_type == "particle":
        streamline_generator = ParticleFilteringTracking(
            dg,
            tiss_classifier,
            np.stack([i for i in seeds]),
            np.eye(4),
            max_cross=int(maxcrossing),
            step_size=float(step_curv_combinations[0]),
            maxlen=int(max_length),
            pft_back_tracking_dist=pft_back_tracking_dist,
            pft_front_tracking_dist=pft_front_tracking_dist,
            pft_max_trial=20,
            particle_count=particle_count,
            return_all=True,
            random_seed=42)
    else:
        raise ValueError("ERROR: No valid tracking method(s) specified.")

    if verbose is True:
        print("%s%s%s" % ('Instantiating tracking:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()
        # print(seeds)

    del dg

    # Filter resulting streamlines by those that stay entirely
    # inside the brain
    try:
        roi_proximal_streamlines = utils.target(
            streamline_generator,
            np.eye(4),
            np.asarray(B0_mask.dataobj).astype('bool'),
            include=True)
    except BaseException:
        print('No streamlines found inside the brain! Check registrations.')
        #return None

    if verbose is True:
        print("%s%s%s" % ('Drawing streamlines:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()

    del seeds, tiss_classifier, streamline_generator

    B0_mask.uncache()
    atlas_img.uncache()
    t1w2dwi.uncache()
    gm_in_dwi.uncache()
    vent_csf_in_dwi.uncache()
    wm_in_dwi.uncache()
    gc.collect()

    # Filter resulting streamlines by roi-intersection
    # characteristics
    atlas_data = np.array(atlas_img.dataobj).astype("uint16")

    # Build mask vector from atlas for later roi filtering
    parcels = [
        atlas_data == roi_val
        for roi_val in [i for i in np.unique(atlas_data) if i != 0]
    ]

    try:
        roi_proximal_streamlines = \
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=list(np.ones(len(parcels)).astype("bool")),
                    mode="any",
                    tol=roi_neighborhood_tol,
                )
    except BaseException:
        print('No streamlines found to connect any parcels! '
              'Check registrations.')
        #return None

    del atlas_data

    if verbose is True:
        print("%s%s%s" % ('Selecting by parcellation:',
                          np.round(time.time() - start_time, 1), 's'))
        start_time = time.time()

    del parcels

    gc.collect()

    if verbose is True:
        print("%s%s%s" % ('Selecting by minimum length criterion:',
                          np.round(time.time() - start_time, 1), 's'))

    gc.collect()

    return ArraySequence([
        s.astype("float32") for s in roi_proximal_streamlines
        if len(s) > float(min_length)
    ])
Example #13
File: track.py Project: dPys/PyNets
def track_ensemble(target_samples,
                   atlas_data_wm_gm_int,
                   labels_im_file,
                   recon_path,
                   sphere,
                   traversal,
                   curv_thr_list,
                   step_list,
                   track_type,
                   maxcrossing,
                   roi_neighborhood_tol,
                   min_length,
                   waymask,
                   B0_mask,
                   t1w2dwi,
                   gm_in_dwi,
                   vent_csf_in_dwi,
                   wm_in_dwi,
                   tiss_class,
                   BACKEND='threading'):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    parcels : list
        List of 3D boolean numpy arrays of atlas parcellation ROI masks from a
        Nifti1Image in T1w-warped native diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    tiss_class : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to tractography constraint mask in native diffusion space.
    B0_mask : str
        File path to B0 brain mask.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        the back track distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692
    """
    import os
    import gc
    import time
    import warnings
    import tempfile
    from joblib import Parallel, delayed, Memory
    import itertools
    import pickle5 as pickle
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    from dipy.tracking import utils

    warnings.filterwarnings("ignore")

    pickle.HIGHEST_PROTOCOL = 5
    joblib_dir = tempfile.mkdtemp()
    os.makedirs(joblib_dir, exist_ok=True)

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["omp_threads"][0]
    os.environ['MKL_NUM_THREADS'] = str(nthreads)
    os.environ['OPENBLAS_NUM_THREADS'] = str(nthreads)
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    seeding_mask_thr = hardcoded_params['tracking']["seeding_mask_thr"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{os.path.dirname(labels_im_file)}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img(f"img > {seeding_mask_thr}",
                               img=nib.load(waymask))
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0

    memory = Memory(location=joblib_dir, mmap_mode='r+', verbose=0)
    os.chdir(f"{memory.location}/joblib")

    @memory.cache
    def load_recon_data(recon_path):
        import h5py
        with h5py.File(recon_path, 'r') as hf:
            recon_data = hf['reconstruction'][:].astype('float32')
        hf.close()
        return recon_data

    recon_shelved = load_recon_data.call_and_shelve(recon_path)

    @memory.cache
    def load_tissue_data(tissues4d):
        return nib.load(tissues4d)

    tissue_shelved = load_tissue_data.call_and_shelve(tissues4d)

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50*float(len(all_combs)):
            with Parallel(n_jobs=nthreads,
                          backend=BACKEND,
                          mmap_mode='r+',
                          verbose=0) as parallel:

                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_shelved, n_seeds_per_iter, traversal,
                     maxcrossing, max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, min_length, track_type,
                     min_separation_angle, sphere, tiss_class, tissue_shelved)
                    for i in all_combs)

                out_streams = list(filter(None, out_streams))

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)
                else:
                    continue

                if waymask is not None and os.path.isfile(waymask):
                    try:
                        out_streams = out_streams[utils.near_roi(
                            out_streams,
                            np.eye(4),
                            np.asarray(
                                nib.load(waymask).dataobj).astype("bool"),
                            tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                            mode="all")]
                    except BaseException:
                        print(f"\n{Fore.RED}No streamlines generated in "
                              f"waymask vacinity\n")
                        print(Style.RESET_ALL)
                        return None

                if len(out_streams) < min_streams:
                    ix += 1
                    print(f"\n{Fore.YELLOW}Fewer than {min_streams} "
                          f"streamlines tracked "
                          f"on last iteration...\n")
                    print(Style.RESET_ALL)
                    if ix > 5:
                        print(f"\n{Fore.RED}No streamlines generated\n")
                        print(Style.RESET_ALL)
                        return None
                    continue
                else:
                    ix -= 1

                stream_counter += len(out_streams)
                all_streams.extend([generate_sl(i) for i in out_streams])
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)

                if time.time() - start > timeout:
                    print(f"\n{Fore.RED}Warning: Tractography timed "
                          f"out: {time.time() - start}")
                    print(Style.RESET_ALL)
                    memory.clear(warn=False)
                    return None

    except RuntimeError as e:
        print(f"\n{Fore.RED}Error: Tracking failed due to:\n{e}\n")
        print(Style.RESET_ALL)
        memory.clear(warn=False)
        return None

    print("Tracking Complete: ", str(time.time() - start))

    memory.clear(warn=False)

    del parallel, all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print(f"\n{Fore.RED}No streamlines generated!")
        print(Style.RESET_ALL)
        return None
Example #14
def run_tracking(step_curv_combinations, recon_path, n_seeds_per_iter,
                 directget, maxcrossing, max_length, pft_back_tracking_dist,
                 pft_front_tracking_dist, particle_count, roi_neighborhood_tol,
                 waymask, min_length, track_type, min_separation_angle, sphere,
                 tiss_class, tissues4d, cache_dir):

    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix

    recon_path_tmp_path = fname_presuffix(recon_path,
                                          suffix=f"_{step_curv_combinations}",
                                          newpath=cache_dir)
    copyfile(recon_path, recon_path_tmp_path, copy=True, use_hardlink=False)

    if waymask is not None:
        waymask_tmp_path = fname_presuffix(waymask,
                                           suffix=f"_{step_curv_combinations}",
                                           newpath=cache_dir)
        copyfile(waymask, waymask_tmp_path, copy=True, use_hardlink=False)
    else:
        waymask_tmp_path = None

    tissue_img = nib.load(tissues4d)

    # Order:
    B0_mask = index_img(tissue_img, 0)
    atlas_img = index_img(tissue_img, 1)
    atlas_data_wm_gm_int = index_img(tissue_img, 2)
    t1w2dwi = index_img(tissue_img, 3)
    gm_in_dwi = index_img(tissue_img, 4)
    vent_csf_in_dwi = index_img(tissue_img, 5)
    wm_in_dwi = index_img(tissue_img, 6)

    tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                   wm_in_dwi, tiss_class, B0_mask)

    B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")
    atlas_data = np.array(atlas_img.dataobj).astype("uint16")
    atlas_data_wm_gm_int_data = np.asarray(
        atlas_data_wm_gm_int.dataobj).astype("bool").astype("int16")

    # Build mask vector from atlas for later roi filtering
    parcels = []
    i = 0
    intensities = [i for i in np.unique(atlas_data) if i != 0]
    for roi_val in intensities:
        parcels.append(atlas_data == roi_val)
        i += 1

    del atlas_data

    parcel_vec = list(np.ones(len(parcels)).astype("bool"))

    with h5py.File(recon_path_tmp_path, 'r+') as hf:
        mod_fit = hf['reconstruction'][:].astype('float32')
    hf.close()

    print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

    # Instantiate DirectionGetter
    if directget == "prob" or directget == "probabilistic":
        dg = ProbabilisticDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget == "clos" or directget == "closest":
        dg = ClosestPeakDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget == "det" or directget == "deterministic":
        maxcrossing = 1
        dg = DeterministicMaximumDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    else:
        raise ValueError("ERROR: No valid direction getter(s) specified.")

    print("%s%s" % ("Step: ", step_curv_combinations[0]))

    # Perform wm-gm interface seeding, using n_seeds at a time
    seeds = utils.random_seeds_from_mask(
        atlas_data_wm_gm_int_data > 0,
        seeds_count=n_seeds_per_iter,
        seed_count_per_voxel=False,
        affine=np.eye(4),
    )
    if len(seeds) == 0:
        print(
            UserWarning("No valid seed points found in wm-gm "
                        "interface..."))
        return None

    # print(seeds)

    # Perform tracking
    if track_type == "local":
        streamline_generator = LocalTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            maxlen=int(max_length),
            step_size=float(step_curv_combinations[0]),
            fixedstep=False,
            return_all=True,
        )
    elif track_type == "particle":
        streamline_generator = ParticleFilteringTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            step_size=float(step_curv_combinations[0]),
            maxlen=int(max_length),
            pft_back_tracking_dist=pft_back_tracking_dist,
            pft_front_tracking_dist=pft_front_tracking_dist,
            particle_count=particle_count,
            return_all=True,
        )
    else:
        try:
            raise ValueError("ERROR: No valid tracking method(s) specified.")
        except ValueError:
            import sys
            sys.exit(0)

    # Filter resulting streamlines by those that stay entirely
    # inside the brain
    try:
        roi_proximal_streamlines = utils.target(streamline_generator,
                                                np.eye(4),
                                                B0_mask_data,
                                                include=True)
    except BaseException:
        print('No streamlines found inside the brain! Check registrations.')
        return None

    # Filter resulting streamlines by roi-intersection
    # characteristics

    try:
        roi_proximal_streamlines = \
            nib.streamlines.array_sequence.ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="%s" % ("any" if waymask is not None else
                                 "both_end"),
                    tol=roi_neighborhood_tol,
                )
            )
        print("%s%s" % ("Filtering by: \nNode intersection: ",
                        len(roi_proximal_streamlines)))
    except BaseException:
        print('No streamlines found to connect any parcels! '
              'Check registrations.')
        return None

    try:
        roi_proximal_streamlines = nib.streamlines. \
            array_sequence.ArraySequence(
            [
                s for s in roi_proximal_streamlines
                if len(s) >= float(min_length)
            ]
        )
        print(f"Minimum fiber length >{min_length}mm: "
              f"{len(roi_proximal_streamlines)}")
    except BaseException:
        print('No streamlines remaining after minimal length criterion.')
        return None

    if waymask is not None and os.path.isfile(waymask_tmp_path):
        from nilearn.image import math_img
        mask = math_img("img > 0.0075", img=nib.load(waymask_tmp_path))
        waymask_data = np.asarray(mask.dataobj).astype("bool")
        try:
            roi_proximal_streamlines = roi_proximal_streamlines[utils.near_roi(
                roi_proximal_streamlines,
                np.eye(4),
                waymask_data,
                tol=roi_neighborhood_tol,
                mode="all")]
            print("%s%s" %
                  ("Waymask proximity: ", len(roi_proximal_streamlines)))
        except BaseException:
            print('No streamlines remaining in waymask\'s vicinity.')
            return None

    out_streams = [s.astype("float32") for s in roi_proximal_streamlines]

    del dg, seeds, roi_proximal_streamlines, streamline_generator, \
        atlas_data_wm_gm_int_data, mod_fit, B0_mask_data

    os.remove(recon_path_tmp_path)
    gc.collect()

    try:
        return ArraySequence(out_streams)
    except BaseException:
        return None
Example #15
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    fusion_streamlines = ArraySequence([])
    if args.reference:
        reference_file = args.reference
    else:
        reference_file = args.in_bundles[0]
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        fusion_streamlines.extend(tmp_sft.streamlines)

    fusion_streamlines, _ = union_robust([fusion_streamlines])

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    streamlines_vote = dok_matrix(
        (len(fusion_streamlines), len(args.in_bundles)))

    for i, name in enumerate(args.in_bundles):
        sft = load_tractogram_with_reference(parser, args, name)

        # Needed for streamline-wise representation
        bundle = sft.get_streamlines_copy()
        sft.to_vox()
        sft.to_corner()

        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust([fusion_streamlines, bundle])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        real_indices = []
        for i in range(len(fusion_streamlines)):
            ratio_value = int(args.ratio_streamlines * len(args.in_bundles))
            if np.sum(streamlines_vote[i]) >= ratio_value:
                real_indices.append(i)

        new_streamlines = fusion_streamlines[real_indices]

        new_sft = StatefulTractogram(list(new_streamlines), reference_file,
                                     Space.RASMM)
        save_tractogram(new_sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * len(args.in_bundles))] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
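The streamline-voting step above thresholds a sparse vote matrix (streamline x bundle) against a ratio; a hedged, standalone sketch of that logic with made-up votes:

import numpy as np
from scipy.sparse import dok_matrix

nb_streamlines, nb_bundles, ratio = 5, 3, 0.7
votes = dok_matrix((nb_streamlines, nb_bundles))
votes[0, 0] = votes[0, 1] = 1      # streamline 0 found in two bundles
votes[1, 2] = 1                    # streamline 1 found in one bundle

threshold = int(ratio * nb_bundles)
kept = [i for i in range(nb_streamlines) if votes[i, :].sum() >= threshold]
print(kept)                        # -> [0]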
Exemplo n.º 16
0
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, directget, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, cache_dir):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : str
        File path to Nifti1Image in T1w-warped native diffusion space,
        restricted to wm-gm interface.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    tiss_class : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        File path to a tractography constraint mask in native diffusion space.
    B0_mask : str
        File path to the B0 brain mask.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back-tracking distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692

    """
    import os
    import gc
    import time
    import warnings
    from joblib import Parallel, delayed
    import itertools
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d
    from nilearn.masking import intersect_masks
    from nilearn.image import math_img
    from pynets.core.utils import load_runconfig
    warnings.filterwarnings("ignore")

    tmp_files_dir = f"{cache_dir}/tmp_files"
    joblib_dir = f"{cache_dir}/joblib_tracking"
    os.makedirs(tmp_files_dir, exist_ok=True)
    os.makedirs(joblib_dir, exist_ok=True)

    hardcoded_params = load_runconfig()
    nthreads = hardcoded_params["nthreads"][0]
    n_seeds_per_iter = \
        hardcoded_params['tracking']["n_seeds_per_iter"][0]
    max_length = \
        hardcoded_params['tracking']["max_length"][0]
    pft_back_tracking_dist = \
        hardcoded_params['tracking']["pft_back_tracking_dist"][0]
    pft_front_tracking_dist = \
        hardcoded_params['tracking']["pft_front_tracking_dist"][0]
    particle_count = \
        hardcoded_params['tracking']["particle_count"][0]
    min_separation_angle = \
        hardcoded_params['tracking']["min_separation_angle"][0]
    min_streams = \
        hardcoded_params['tracking']["min_streams"][0]
    timeout = hardcoded_params['tracking']["track_timeout"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    # Construct seeding mask
    seeding_mask = f"{tmp_files_dir}/seeding_mask.nii.gz"
    if waymask is not None and os.path.isfile(waymask):
        waymask_img = math_img("img > 0.0075", img=nib.load(waymask))
        waymask_img.to_filename(waymask)
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                waymask_img,
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)
    else:
        atlas_data_wm_gm_int_img = intersect_masks(
            [
                math_img("img > 0.001", img=nib.load(atlas_data_wm_gm_int)),
                math_img("img > 0.001", img=nib.load(labels_im_file))
            ],
            threshold=1,
            connected=False,
        )
        nib.save(atlas_data_wm_gm_int_img, seeding_mask)

    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, seeding_mask, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0

    try:
        while float(stream_counter) < float(target_samples) and \
                float(ix) < 0.50*float(len(all_combs)):
            with Parallel(n_jobs=nthreads,
                          backend='loky',
                          mmap_mode='r+',
                          temp_folder=joblib_dir,
                          verbose=0,
                          timeout=timeout) as parallel:
                out_streams = parallel(
                    delayed(run_tracking)
                    (i, recon_path, n_seeds_per_iter, directget, maxcrossing,
                     max_length, pft_back_tracking_dist,
                     pft_front_tracking_dist, particle_count,
                     roi_neighborhood_tol, waymask, min_length, track_type,
                     min_separation_angle, sphere, tiss_class, tissues4d,
                     tmp_files_dir) for i in all_combs)

                out_streams = [
                    i for i in out_streams if i is not None and len(i) > 0
                ]

                if len(out_streams) > 1:
                    out_streams = concatenate(out_streams, axis=0)

                if len(out_streams) < min_streams:
                    ix += 2
                    print(f"Fewer than {min_streams} streamlines tracked "
                          f"on last iteration with cache directory: "
                          f"{cache_dir}. Loosening tolerance and "
                          f"anatomical constraints. Check {tissues4d} or "
                          f"{recon_path} for errors...")
                    # if track_type != 'particle':
                    #     tiss_class = 'wb'
                    roi_neighborhood_tol = float(roi_neighborhood_tol) * 1.25
                    # min_length = float(min_length) * 0.9875
                    continue
                else:
                    ix -= 1

                # Append streamline generators to prevent exponential growth
                # in memory consumption
                all_streams.extend([generate_sl(i) for i in out_streams])
                stream_counter += len(out_streams)
                del out_streams

                print("%s%s%s%s" % (
                    "\nCumulative Streamline Count: ",
                    Fore.CYAN,
                    stream_counter,
                    "\n",
                ))
                gc.collect()
                print(Style.RESET_ALL)
        os.system(f"rm -rf {joblib_dir}/*")
    except BaseException:
        os.system(f"rm -rf {tmp_files_dir} &")
        return None

    if ix >= 0.75*len(all_combs) and \
            float(stream_counter) < float(target_samples):
        print(f"Tractography failed. >{len(all_combs)} consecutive sampling "
              f"iterations with few streamlines.")
        os.system(f"rm -rf {tmp_files_dir} &")
        return None
    else:
        os.system(f"rm -rf {tmp_files_dir} &")
        print("Tracking Complete: ", str(time.time() - start))

    del parallel, all_combs
    gc.collect()

    if stream_counter != 0:
        print('Generating final ArraySequence...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print('No streamlines generated!')
        return None
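The ensemble loop above tracks every (step size, curvature threshold) combination; a short sketch of how such a grid is enumerated with itertools.product (the values below are illustrative, not the defaults used by track_ensemble):

import itertools

step_list = [0.2, 0.4, 0.6]     # illustrative step sizes
curv_thr_list = [60, 40, 20]    # illustrative curvature thresholds (degrees)

# Same pattern as `all_combs` above: each pair defines one tracking run
all_combs = list(itertools.product(step_list, curv_thr_list))
for step, curv in all_combs:
    print(f"step={step}, max curvature={curv} deg")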
Exemplo n.º 17
0
def warp_streamlines(sft, deformation_data, source='ants'):
    """ Warp tractogram using a deformation map. Apply warp in-place.
    Support Ants and Dipy deformation map.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram to warp; its streamlines (RASMM) and affine (RASMM to
        voxel space) are used.
    deformation_data: numpy.ndarray
        4D numpy array containing a 3D displacement vector in each voxel
    source: str
        Source of the deformation map [ants, dipy]
    """
    sft.to_rasmm()
    sft.to_center()
    streamlines = sft.streamlines
    transfo = sft.affine
    if source == 'ants':
        flip = [-1, -1, 1]
    elif source == 'dipy':
        flip = [1, 1, 1]

    # Because of duplication, an iteration over chunks of points is necessary
    # for a big dataset (especially if not compressed)
    streamlines = ArraySequence(streamlines)
    nb_points = len(streamlines._data)
    cur_position = 0
    chunk_size = 1000000
    nb_iteration = int(np.ceil(nb_points / chunk_size))
    inv_transfo = np.linalg.inv(transfo)

    while nb_iteration > 0:
        max_position = min(cur_position + chunk_size, nb_points)
        points = streamlines._data[cur_position:max_position]

        # To access the deformation information, we need to go in voxel space
        # No need for corner shift since we are doing interpolation
        cur_points_vox = np.array(transform_streamlines(points, inv_transfo)).T

        x_def = map_coordinates(deformation_data[..., 0],
                                cur_points_vox.tolist(),
                                order=1)
        y_def = map_coordinates(deformation_data[..., 1],
                                cur_points_vox.tolist(),
                                order=1)
        z_def = map_coordinates(deformation_data[..., 2],
                                cur_points_vox.tolist(),
                                order=1)

        # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
        final_points = np.array(
            [flip[0] * x_def, flip[1] * y_def, flip[2] * z_def])

        # The Ants deformation is relative to world space
        if source == 'ants':
            final_points += np.array(points).T
        # Dipy transformation is relative to vox space
        elif source == 'dipy':
            final_points += cur_points_vox
            transform_streamlines(final_points, transfo, in_place=True)
        streamlines._data[cur_position:max_position] = final_points.T
        cur_position = max_position
        nb_iteration -= 1

    return streamlines
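A hedged usage sketch of warp_streamlines, assuming an ANTs displacement field saved as a 4D NIfTI and a tractogram loaded with DiPy's StatefulTractogram API (all file names below are placeholders):

import numpy as np
import nibabel as nib
from dipy.io.streamline import load_tractogram

sft = load_tractogram('bundle.trk', 'reference.nii.gz')   # hypothetical inputs
# ANTs warps are often stored as (x, y, z, 1, 3); squeeze to (x, y, z, 3)
deformation_data = np.squeeze(np.asarray(nib.load('warp.nii.gz').dataobj))

warped_streamlines = warp_streamlines(sft, deformation_data, source='ants')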
Exemplo n.º 18
0
def single_clusterize_and_rbx_init(args):
    """
    Wrapper function to multiprocess clustering executions and recobundles
    initialisation.

    Parameters
    ----------
    tmp_memmap_filename: tuple (3)
        Temporary filename for the data, offsets and lengths.

    parameters_list : tuple (2)
        clustering_thr : int
            Distance in mm (for QBx) to cluster the input tractogram.
        seed : int
            Value to initialize the RandomState of numpy.
    nb_points : int
        Number of points used for all resampling of streamlines.

    Returns
    -------
    rbx : dict
        Initialisation of the recobundles class using specific parameters.
    """
    tmp_memmap_filename = args[0]
    wb_streamlines = reconstruct_streamlines_from_memmap(tmp_memmap_filename)
    clustering_thr = args[1][0]
    seed = args[1][1]
    nb_points = args[2]

    rbx = {}
    base_thresholds = [45, 35, 25]
    rng = np.random.RandomState(seed)
    cluster_timer = time()
    # If necessary, add an extra clustering layer (more efficient)
    if clustering_thr < 15:
        current_thr_list = base_thresholds + [15, clustering_thr]
    else:
        current_thr_list = base_thresholds + [clustering_thr]

    cluster_map = qbx_and_merge(wb_streamlines,
                                current_thr_list,
                                nb_pts=nb_points,
                                rng=rng,
                                verbose=False)
    clusters_indices = []
    for cluster in cluster_map.clusters:
        clusters_indices.append(cluster.indices)
    centroids = ArraySequence(cluster_map.centroids)
    clusters_indices = ArraySequence(clusters_indices)
    clusters_indices._data = clusters_indices._data.astype(np.int32)

    rbx[(seed, clustering_thr)] = RecobundlesX(tmp_memmap_filename,
                                               clusters_indices,
                                               centroids,
                                               nb_points=nb_points,
                                               rng=rng)
    logging.info('QBx with seed {0} and thresholds {1}mm took {2}sec and '
                 'gave {3} centroids'.format(seed, current_thr_list,
                                             round(time() - cluster_timer, 2),
                                             len(cluster_map.centroids)))
    return rbx
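# A hedged sketch of how the argument tuples for this wrapper are typically
# assembled and dispatched to a multiprocessing pool. The memmap file names
# are placeholders for files produced elsewhere (e.g. by dumping the
# whole-brain tractogram to memmaps).
import itertools
from multiprocessing import Pool

tmp_memmap_filename = ('data.dat', 'offsets.dat', 'lengths.dat')  # hypothetical
clustering_thr_list = [12, 10, 8]
seeds = [0, 1]
nb_points = 20

# args[0]: memmap filenames, args[1]: (clustering_thr, seed), args[2]: nb_points
args_list = [(tmp_memmap_filename, (thr, seed), nb_points)
             for thr, seed in itertools.product(clustering_thr_list, seeds)]

with Pool(processes=4) as pool:
    rbx_dict_list = pool.map(single_clusterize_and_rbx_init, args_list)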
    def _create_trx_from_pointer(header, dict_pointer_size,
                                 root_zip=None, root=None):
        """ After reading the structure of a zip/folder, create a TrxFile """
        # TODO support empty positions, using optional tag?
        trx = TrxFile()
        trx.header = header
        positions, offsets = None, None
        for elem_filename in dict_pointer_size.keys():
            if root_zip:
                filename = root_zip
            else:
                filename = elem_filename

            folder = os.path.dirname(elem_filename)
            base, dim, ext = _split_ext_with_dimensionality(elem_filename)
            if ext == '.bit':
                ext = '.bool'
            mem_adress, size = dict_pointer_size[elem_filename]

            if root is not None and folder.startswith(root.rstrip('/')):
                folder = folder.replace(root, '').lstrip('/')

            # Parse/walk the directory tree
            if base == 'positions' and folder == '':
                if size != trx.header['NB_VERTICES']*3 or dim != 3:
                    raise ValueError('Wrong data size/dimensionality.')
                positions = _create_memmap(filename, mode='r+',
                                           offset=mem_adress,
                                           shape=(
                                               trx.header['NB_VERTICES'], 3),
                                           dtype=ext[1:])
            elif base == 'offsets' and folder == '':
                if size != trx.header['NB_STREAMLINES'] or dim != 1:
                    raise ValueError('Wrong offsets size/dimensionality.')
                offsets = _create_memmap(filename, mode='r+',
                                         offset=mem_adress,
                                         shape=(trx.header['NB_STREAMLINES'],),
                                         dtype=ext[1:])
                lengths = _compute_lengths(offsets, trx.header['NB_VERTICES'])
            elif folder == 'dps':
                nb_scalar = size / trx.header['NB_STREAMLINES']
                if not nb_scalar.is_integer() or nb_scalar != dim:
                    raise ValueError('Wrong dps size/dimensionality.')
                else:
                    shape = (trx.header['NB_STREAMLINES'], int(nb_scalar))

                trx.data_per_streamline[base] = _create_memmap(
                    filename, mode='r+', offset=mem_adress,
                    shape=shape, dtype=ext[1:])
            elif folder == 'dpv':
                nb_scalar = size / trx.header['NB_VERTICES']
                if not nb_scalar.is_integer() or nb_scalar != dim:
                    raise ValueError('Wrong dpv size/dimensionality.')
                else:
                    shape = (trx.header['NB_VERTICES'], int(nb_scalar))

                trx.data_per_vertex[base] = _create_memmap(
                    filename, mode='r+', offset=mem_adress,
                    shape=shape, dtype=ext[1:])
            elif folder.startswith('dpg'):
                if int(size) != dim:
                    raise ValueError('Wrong dpg size/dimensionality.')
                else:
                    shape = (1, int(size))

                # Handle the two-layers architecture
                data_name = os.path.basename(base)
                sub_folder = os.path.basename(folder)
                if sub_folder not in trx.data_per_group:
                    trx.data_per_group[sub_folder] = {}
                trx.data_per_group[sub_folder][data_name] = _create_memmap(
                    filename, mode='r+', offset=mem_adress,
                    shape=shape, dtype=ext[1:])
            elif folder == 'groups':
                # Groups are simply indices, nothing else
                # TODO Crash if not uint?
                if dim != 1:
                    raise ValueError('Wrong group dimensionality.')
                else:
                    shape = (int(size),)
                trx.groups[base] = _create_memmap(filename, mode='r+',
                                                  offset=mem_adress,
                                                  shape=shape,
                                                  dtype=ext[1:])
            else:
                logging.error('{} is not part of a valid structure.'.format(
                    elem_filename))

        # All essential arrays must be declared
        if positions is not None and offsets is not None:
            trx.streamlines._data = positions
            trx.streamlines._offsets = offsets
            trx.streamlines._lengths = lengths
        else:
            raise ValueError('Missing essential data.')

        for dpv_key in trx.data_per_vertex:
            tmp = trx.data_per_vertex[dpv_key]
            trx.data_per_vertex[dpv_key] = ArraySequence()
            trx.data_per_vertex[dpv_key]._data = tmp
            trx.data_per_vertex[dpv_key]._offsets = offsets
            trx.data_per_vertex[dpv_key]._lengths = lengths
        return trx
    def _initialize_empty_trx(nb_streamlines, nb_vertices, init_as=None):
        """ Create on-disk memmaps of a certain size (preallocation) """
        trx = TrxFile()
        tmp_dir = tempfile.TemporaryDirectory()
        logging.info('Temporary folder for memmaps: {}'.format(tmp_dir.name))

        trx.header['NB_VERTICES'] = nb_vertices
        trx.header['NB_STREAMLINES'] = nb_streamlines

        if init_as is not None:
            trx.header['VOXEL_TO_RASMM'] = init_as.header['VOXEL_TO_RASMM']
            trx.header['DIMENSIONS'] = init_as.header['DIMENSIONS']
            positions_dtype = init_as.streamlines._data.dtype
            offsets_dtype = init_as.streamlines._offsets.dtype
            lengths_dtype = init_as.streamlines._lengths.dtype
        else:
            positions_dtype = np.dtype(np.float16)
            offsets_dtype = np.dtype(np.uint64)
            lengths_dtype = np.dtype(np.uint32)

        logging.debug('Initializing positions with dtype:    {}'.format(
            positions_dtype.name))
        logging.debug('Initializing offsets with dtype: {}'.format(
            offsets_dtype.name))
        logging.debug('Initializing lengths with dtype: {}'.format(
            lengths_dtype.name))

        # A TrxFile without init_as only contains the essential arrays
        positions_filename = os.path.join(tmp_dir.name,
                                          'positions.3.{}'.format(positions_dtype.name))
        trx.streamlines._data = _create_memmap(positions_filename, mode='w+',
                                               shape=(nb_vertices, 3),
                                               dtype=positions_dtype)

        offsets_filename = os.path.join(tmp_dir.name,
                                        'offsets.{}'.format(offsets_dtype.name))
        trx.streamlines._offsets = _create_memmap(offsets_filename, mode='w+',
                                                  shape=(nb_streamlines,),
                                                  dtype=offsets_dtype)
        trx.streamlines._lengths = np.zeros(shape=(nb_streamlines,),
                                            dtype=lengths_dtype)

        # Only the structure of fixed-size arrays is copied
        if init_as is not None:
            if len(init_as.data_per_vertex.keys()) > 0:
                os.mkdir(os.path.join(tmp_dir.name, 'dpv/'))
            if len(init_as.data_per_streamline.keys()) > 0:
                os.mkdir(os.path.join(tmp_dir.name, 'dps/'))

            for dpv_key in init_as.data_per_vertex.keys():
                dtype = init_as.data_per_vertex[dpv_key]._data.dtype
                tmp_as = init_as.data_per_vertex[dpv_key]._data
                if tmp_as.ndim == 1:
                    dpv_filename = os.path.join(tmp_dir.name, 'dpv/'
                                                '{}.{}'.format(dpv_key,
                                                               dtype.name))
                    shape = (nb_vertices, 1)
                elif tmp_as.ndim == 2:
                    dim = tmp_as.shape[-1]
                    shape = (nb_vertices, dim)
                    dpv_filename = os.path.join(tmp_dir.name, 'dpv/'
                                                '{}.{}.{}'.format(dpv_key,
                                                                  dim,
                                                                  dtype.name))
                else:
                    raise ValueError('Invalid dimensionality.')

                logging.debug('Initializing {} (dpv) with dtype: '
                              '{}'.format(dpv_key, dtype.name))
                trx.data_per_vertex[dpv_key] = ArraySequence()
                trx.data_per_vertex[dpv_key]._data = _create_memmap(dpv_filename,
                                                                    mode='w+',
                                                                    shape=shape,
                                                                    dtype=dtype)
                trx.data_per_vertex[dpv_key]._offsets = trx.streamlines._offsets
                trx.data_per_vertex[dpv_key]._lengths = trx.streamlines._lengths

            for dps_key in init_as.data_per_streamline.keys():
                dtype = init_as.data_per_streamline[dps_key].dtype
                tmp_as = init_as.data_per_streamline[dps_key]
                if tmp_as.ndim == 1:
                    dps_filename = os.path.join(tmp_dir.name, 'dps/'
                                                '{}.{}'.format(dps_key,
                                                               dtype.name))
                    shape = (nb_streamlines,)
                elif tmp_as.ndim == 2:
                    dim = tmp_as.shape[-1]
                    shape = (nb_streamlines, dim)
                    dps_filename = os.path.join(tmp_dir.name, 'dps/'
                                                '{}.{}.{}'.format(dps_key,
                                                                  dim,
                                                                  dtype.name))
                else:
                    raise ValueError('Invalid dimensionality.')

                logging.debug('Initializing {} (dps) with dtype: '
                              '{}'.format(dps_key, dtype.name))
                trx.data_per_streamline[dps_key] = _create_memmap(dps_filename,
                                                                  mode='w+',
                                                                  shape=shape,
                                                                  dtype=dtype)

        trx._uncompressed_folder_handle = tmp_dir

        return trx
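Both helpers above rely on memory-mapped arrays for the essential TRX data (positions and offsets), plus optional dpv/dps/dpg/groups entries; a hedged illustration of the same preallocation pattern, using plain numpy.memmap instead of the module's _create_memmap helper:

import os
import tempfile
import numpy as np

tmp_dir = tempfile.mkdtemp()
nb_vertices, nb_streamlines = 1000, 50

# File names mirror the naming used above (name[.dim].dtype)
positions = np.memmap(os.path.join(tmp_dir, 'positions.3.float16'),
                      mode='w+', shape=(nb_vertices, 3), dtype=np.float16)
offsets = np.memmap(os.path.join(tmp_dir, 'offsets.uint64'),
                    mode='w+', shape=(nb_streamlines,), dtype=np.uint64)
lengths = np.zeros((nb_streamlines,), dtype=np.uint32)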
Exemplo n.º 21
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    set_sft_logger_level('ERROR')
    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid],
                        optional=args.reference)
    assert_outputs_exist(parser,
                         args,
                         args.out_labels_map,
                         optional=[
                             args.out_labels_npz, args.out_distances_npz,
                             args.labels_color_dpp, args.distances_color_dpp
                         ])

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. '
                      'Skipping'.format(args.in_bundle))
        raise ValueError

    if len(sft_centroid.streamlines) != 1:
        logging.error('Centroid file {} should contain one streamline. '
                      'Skipping'.format(args.in_centroid))
        raise ValueError

    if not is_header_compatible(sft_centroid, sft_bundle):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_centroid, args.in_bundle))

    sft_bundle.to_vox()
    sft_bundle.to_corner()

    # Slightly cut the bundle at the edges to clean up single-streamline voxels
    # with no neighbors. Remove isolated voxels to keep a single 'blob'
    binary_bundle = compute_tract_counts_map(
        sft_bundle.streamlines, sft_bundle.dimensions).astype(bool)

    structure = ndi.generate_binary_structure(3, 1)
    if np.count_nonzero(binary_bundle) > args.min_voxel_count \
            and len(sft_bundle) > args.min_streamline_count:
        binary_bundle = ndi.binary_dilation(binary_bundle,
                                            structure=np.ones((3, 3, 3)))
        binary_bundle = ndi.binary_erosion(binary_bundle,
                                           structure=structure,
                                           iterations=2)

        bundle_disjoint, _ = ndi.label(binary_bundle)
        unique, count = np.unique(bundle_disjoint, return_counts=True)
        val = unique[np.argmax(count[1:]) + 1]
        binary_bundle[bundle_disjoint != val] = 0

        # Chop off some streamlines
        cut_sft = cut_outside_of_mask_streamlines(sft_bundle, binary_bundle)
    else:
        cut_sft = sft_bundle

    if args.nb_pts is not None:
        sft_centroid = resample_streamlines_num_points(sft_centroid,
                                                       args.nb_pts)
    else:
        args.nb_pts = len(sft_centroid.streamlines[0])

    # Generate a centroids labels mask for the centroid alone
    sft_centroid.to_vox()
    sft_centroid.to_corner()
    sft_centroid = _affine_slr(sft_bundle, sft_centroid)

    # Map every streamline's points to the centroid
    binary_centroid = compute_tract_counts_map(
        sft_centroid.streamlines, sft_centroid.dimensions).astype(bool)
    # TODO N^2 growth in RAM, should split it if we want to do nb_pts = 100
    min_dist_label, min_dist = min_dist_to_centroid(
        cut_sft.streamlines._data, sft_centroid.streamlines._data)
    min_dist_label += 1  # 0 means no labels

    # Labels are not allowed to jump, for consistency:
    # streamlines should have continuous labels
    curr_ind = 0
    final_streamlines = []
    final_label = []
    final_dist = []
    for i, streamline in enumerate(cut_sft.streamlines):
        next_ind = curr_ind + len(streamline)
        curr_labels = min_dist_label[curr_ind:next_ind]
        curr_dist = min_dist[curr_ind:next_ind]
        curr_ind = next_ind

        # Flip streamlines so the labels increase (facilitate if/else)
        # Should always be ordered in nextflow pipeline
        gradient = np.gradient(curr_labels)
        if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
            streamline = streamline[::-1]
            curr_labels = curr_labels[::-1]
            curr_dist = curr_dist[::-1]

        # Find jumps, cut them and find the longest
        gradient = np.ediff1d(curr_labels)
        max_jump = max(args.nb_pts // 5, 1)
        if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
            pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
            split_chunk = np.split(curr_labels, pos_jump)
            max_len = 0
            max_pos = 0
            for j, chunk in enumerate(split_chunk):
                if len(chunk) > max_len:
                    max_len = len(chunk)
                    max_pos = j

            curr_labels = split_chunk[max_pos]
            gradient_chunk = np.ediff1d(curr_labels)
            if len(np.unique(np.sign(gradient_chunk))) > 1:
                continue
            streamline = np.split(streamline, pos_jump)[max_pos]
            curr_dist = np.split(curr_dist, pos_jump)[max_pos]

        final_streamlines.append(streamline)
        final_label.append(curr_labels)
        final_dist.append(curr_dist)

    # Re-arrange the new cut streamlines and their metadata
    # Compute the voxels equivalent of the labels maps
    new_sft = StatefulTractogram.from_sft(final_streamlines, sft_bundle)

    tdi_mask_nzr = np.nonzero(binary_bundle)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)
    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           sft_centroid.streamlines[0])
    img_labels = np.zeros(binary_centroid.shape, dtype=np.int16)
    img_labels[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value

    nib.save(nib.Nifti1Image(img_labels, sft_bundle.affine),
             args.out_labels_map)

    if args.labels_color_dpp or args.distances_color_dpp \
            or args.out_labels_npz or args.out_distances_npz:
        labels_array = ArraySequence(final_label)
        dist_array = ArraySequence(final_dist)
        # WARNING: WILL NOT WORK WITH THE INPUT TRK !
        # These will fit only with the TRK saved below.
        if args.out_labels_npz:
            np.savez_compressed(args.out_labels_npz, labels_array._data)
        if args.out_distances_npz:
            np.savez_compressed(args.out_distances_npz, dist_array._data)

        cmap = plt.get_cmap(args.colormap)
        new_sft.data_per_point['color'] = ArraySequence(new_sft.streamlines)

        # Nicer visualisation for MI-Brain
        if args.labels_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                labels_array._data / np.max(labels_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.labels_color_dpp)

        if args.distances_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                dist_array._data / np.max(dist_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.distances_color_dpp)
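The label assignment above maps every bundle point to its nearest centroid point; a hedged, standalone sketch of that idea using scipy's cKDTree instead of the min_dist_to_centroid helper:

import numpy as np
from scipy.spatial import cKDTree

centroid = np.linspace([0., 0., 0.], [0., 0., 50.], 20)   # 20 labels along z
points = np.random.rand(200, 3) * [5., 5., 50.]           # synthetic bundle points

min_dist, min_dist_label = cKDTree(centroid).query(points, k=1)
min_dist_label += 1   # 0 is reserved for background / no label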
Exemplo n.º 22
0
def run_tracking(step_curv_combinations,
                 recon_path,
                 n_seeds_per_iter,
                 directget,
                 maxcrossing,
                 max_length,
                 pft_back_tracking_dist,
                 pft_front_tracking_dist,
                 particle_count,
                 roi_neighborhood_tol,
                 waymask,
                 min_length,
                 track_type,
                 min_separation_angle,
                 sphere,
                 tiss_class,
                 tissues4d,
                 cache_dir,
                 min_seeds=100):

    import gc
    import os
    import h5py
    from dipy.tracking import utils
    from dipy.tracking.streamline import select_by_rois
    from dipy.tracking.local_tracking import LocalTracking, \
        ParticleFilteringTracking
    from dipy.direction import (ProbabilisticDirectionGetter,
                                ClosestPeakDirectionGetter,
                                DeterministicMaximumDirectionGetter)
    from nilearn.image import index_img
    from pynets.dmri.track import prep_tissues
    from nibabel.streamlines.array_sequence import ArraySequence
    from nipype.utils.filemanip import copyfile, fname_presuffix
    import uuid
    from time import strftime

    run_uuid = f"{strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4()}"

    recon_path_tmp_path = fname_presuffix(
        recon_path,
        suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
        f"{run_uuid}",
        newpath=cache_dir)
    copyfile(recon_path, recon_path_tmp_path, copy=True, use_hardlink=False)

    tissues4d_tmp_path = fname_presuffix(
        tissues4d,
        suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
        f"{run_uuid}",
        newpath=cache_dir)
    copyfile(tissues4d, tissues4d_tmp_path, copy=True, use_hardlink=False)

    if waymask is not None:
        waymask_tmp_path = fname_presuffix(
            waymask,
            suffix=f"_{'_'.join([str(i) for i in step_curv_combinations])}_"
            f"{run_uuid}",
            newpath=cache_dir)
        copyfile(waymask, waymask_tmp_path, copy=True, use_hardlink=False)
    else:
        waymask_tmp_path = None

    tissue_img = nib.load(tissues4d_tmp_path)

    # Order of volumes in the 4D tissue image:
    B0_mask = index_img(tissue_img, 0)
    atlas_img = index_img(tissue_img, 1)
    seeding_mask = index_img(tissue_img, 2)
    t1w2dwi = index_img(tissue_img, 3)
    gm_in_dwi = index_img(tissue_img, 4)
    vent_csf_in_dwi = index_img(tissue_img, 5)
    wm_in_dwi = index_img(tissue_img, 6)

    tiss_classifier = prep_tissues(t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                                   wm_in_dwi, tiss_class, B0_mask)

    B0_mask_data = np.asarray(B0_mask.dataobj).astype("bool")

    seeding_mask = np.asarray(
        seeding_mask.dataobj).astype("bool").astype("int16")

    with h5py.File(recon_path_tmp_path, 'r+') as hf:
        mod_fit = hf['reconstruction'][:].astype('float32')

    print("%s%s" % ("Curvature: ", step_curv_combinations[1]))

    # Instantiate DirectionGetter
    if directget.lower() in ["probabilistic", "prob"]:
        dg = ProbabilisticDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget.lower() in ["closestpeaks", "cp"]:
        dg = ClosestPeakDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    elif directget.lower() in ["deterministic", "det"]:
        maxcrossing = 1
        dg = DeterministicMaximumDirectionGetter.from_shcoeff(
            mod_fit,
            max_angle=float(step_curv_combinations[1]),
            sphere=sphere,
            min_separation_angle=min_separation_angle,
        )
    else:
        raise ValueError("ERROR: No valid direction getter(s) specified.")

    print("%s%s" % ("Step: ", step_curv_combinations[0]))

    # Perform wm-gm interface seeding, using n_seeds at a time
    seeds = utils.random_seeds_from_mask(
        seeding_mask > 0,
        seeds_count=n_seeds_per_iter,
        seed_count_per_voxel=False,
        affine=np.eye(4),
    )
    if len(seeds) < min_seeds:
        print(
            UserWarning(
                f"<{min_seeds} valid seed points found in wm-gm interface..."))
        return None

    # print(seeds)

    # Perform tracking
    if track_type == "local":
        streamline_generator = LocalTracking(dg,
                                             tiss_classifier,
                                             seeds,
                                             np.eye(4),
                                             max_cross=int(maxcrossing),
                                             maxlen=int(max_length),
                                             step_size=float(
                                                 step_curv_combinations[0]),
                                             fixedstep=False,
                                             return_all=True,
                                             random_seed=42)
    elif track_type == "particle":
        streamline_generator = ParticleFilteringTracking(
            dg,
            tiss_classifier,
            seeds,
            np.eye(4),
            max_cross=int(maxcrossing),
            step_size=float(step_curv_combinations[0]),
            maxlen=int(max_length),
            pft_back_tracking_dist=pft_back_tracking_dist,
            pft_front_tracking_dist=pft_front_tracking_dist,
            pft_max_trial=20,
            particle_count=particle_count,
            return_all=True,
            random_seed=42)
    else:
        raise ValueError("ERROR: No valid tracking method(s) specified.")

    # Filter resulting streamlines by those that stay entirely
    # inside the brain
    try:
        roi_proximal_streamlines = utils.target(streamline_generator,
                                                np.eye(4),
                                                B0_mask_data.astype('bool'),
                                                include=True)
    except BaseException:
        print('No streamlines found inside the brain! Check registrations.')
        return None

    del mod_fit, seeds, tiss_classifier, streamline_generator, \
        B0_mask_data, seeding_mask, dg

    B0_mask.uncache()
    atlas_img.uncache()
    t1w2dwi.uncache()
    gm_in_dwi.uncache()
    vent_csf_in_dwi.uncache()
    wm_in_dwi.uncache()
    tissue_img.uncache()
    gc.collect()

    # Filter resulting streamlines by roi-intersection
    # characteristics
    atlas_data = np.array(atlas_img.dataobj).astype("uint16")

    # Build mask vector from atlas for later roi filtering
    parcels = []
    intensities = [i for i in np.unique(atlas_data) if i != 0]
    for roi_val in intensities:
        parcels.append(atlas_data == roi_val)

    parcel_vec = list(np.ones(len(parcels)).astype("bool"))

    try:
        roi_proximal_streamlines = \
            nib.streamlines.array_sequence.ArraySequence(
                select_by_rois(
                    roi_proximal_streamlines,
                    affine=np.eye(4),
                    rois=parcels,
                    include=parcel_vec,
                    mode="any",
                    tol=roi_neighborhood_tol,
                )
            )
        print("%s%s" % ("Filtering by: \nNode intersection: ",
                        len(roi_proximal_streamlines)))
    except BaseException:
        print('No streamlines found to connect any parcels! '
              'Check registrations.')
        return None

    try:
        roi_proximal_streamlines = nib.streamlines. \
            array_sequence.ArraySequence(
                [
                    s for s in roi_proximal_streamlines
                    if len(s) >= float(min_length)
                ]
            )
        print(f"Minimum fiber length >{min_length}mm: "
              f"{len(roi_proximal_streamlines)}")
    except BaseException:
        print('No streamlines remaining after minimal length criterion.')
        return None

    if waymask is not None and os.path.isfile(waymask_tmp_path):
        waymask_data = np.asarray(
            nib.load(waymask_tmp_path).dataobj).astype("bool")
        try:
            roi_proximal_streamlines = roi_proximal_streamlines[utils.near_roi(
                roi_proximal_streamlines,
                np.eye(4),
                waymask_data,
                tol=int(round(roi_neighborhood_tol * 0.50, 1)),
                mode="all")]
            print("%s%s" %
                  ("Waymask proximity: ", len(roi_proximal_streamlines)))
            del waymask_data
        except BaseException:
            print('No streamlines remaining in waymask\'s vicinity.')
            return None

    # The reconstruction file was already closed by the `with` block above
    del parcels, atlas_data

    tmp_files = [tissues4d_tmp_path, waymask_tmp_path, recon_path_tmp_path]
    for j in tmp_files:
        if j is not None:
            if os.path.isfile(j):
                os.system(f"rm -f {j} &")

    if len(roi_proximal_streamlines) > 0:
        return ArraySequence(
            [s.astype("float32") for s in roi_proximal_streamlines])
    else:
        return None
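A hypothetical call of run_tracking for a single (step, curvature) combination; every path below is a placeholder, and the keyword values only mirror the kinds of settings read from the run configuration in track_ensemble:

from dipy.data import get_sphere

streamlines = run_tracking(
    (0.4, 60),                     # (step size, curvature threshold)
    'recon.hdf5',                  # hypothetical SH-coefficient reconstruction
    n_seeds_per_iter=250,
    directget='prob',
    maxcrossing=2,
    max_length=1000,
    pft_back_tracking_dist=2,
    pft_front_tracking_dist=1,
    particle_count=15,
    roi_neighborhood_tol=6,
    waymask=None,
    min_length=20,
    track_type='local',
    min_separation_angle=20,
    sphere=get_sphere('repulsion724'),
    tiss_class='wb',
    tissues4d='tissues4d.nii.gz',  # hypothetical 4D tissue stack
    cache_dir='/tmp/cache')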
Exemplo n.º 23
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDDs
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
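A hedged sketch of reading back the HDF5 file written above; the file name is a placeholder, and the exact datasets stored per connection depend on _save_if_needed:

import h5py

with h5py.File('decomposed.h5', 'r') as hdf5_file:
    print(hdf5_file.attrs['affine'])
    print(hdf5_file.attrs['dimensions'])
    for key in hdf5_file.keys():       # connection groups, e.g. '1_2'
        print(key, list(hdf5_file[key].keys()))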
Exemplo n.º 24
0
def run_rf_inference(config=None, gpu_queue=None):
    """"""
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))

    print(
        "Loading DWI...")  ####################################################

    dwi_img = nib.load(config['dwi_path'])
    dwi_img = nib.funcs.as_closest_canonical(dwi_img)
    dwi_aff = dwi_img.affine
    dwi_affi = np.linalg.inv(dwi_aff)
    dwi = dwi_img.get_data()

    def xyz2ijk(coords, snap=False):
        ijk = (coords.T).copy()
        dwi_affi.dot(ijk, out=ijk)
        if snap:
            return np.round(ijk, out=ijk).astype(int, copy=False).T
        else:
            return ijk.T

    with open(os.path.join(config['model_dir'], 'model'), 'rb') as f:
        model = pickle.load(f)

    train_config_file = os.path.join(config['model_dir'], 'config.yml')
    bvec_path = configs.load(train_config_file, 'bvecs')
    _, bvecs = read_bvals_bvecs(None, bvec_path)

    terminator = Terminator(config['term_path'], config['thresh'])

    prior = Prior(config['prior_path'])

    print(
        "Initializing Fibers...")  ############################################

    seed_file = nib.streamlines.load(config['seed_path'])
    xyz = seed_file.tractogram.streamlines.data
    n_seeds = 2 * len(xyz)
    xyz = np.vstack([xyz, xyz])  # Duplicate seeds for both directions
    xyz = np.hstack([xyz, np.ones([n_seeds, 1])])  # add affine dimension
    xyz = xyz.reshape(-1, 1, 4)  # (fiber, segment, coord)

    fiber_idx = np.hstack([
        np.arange(n_seeds // 2, dtype="int32"),
        np.arange(n_seeds // 2, dtype="int32")
    ])
    fibers = [[] for _ in range(n_seeds // 2)]

    print(
        "Start Iteration...")  ################################################

    input_shape = model.n_features_
    block_size = int(np.cbrt(input_shape / dwi.shape[-1]))

    d = np.zeros([n_seeds, dwi.shape[-1] * block_size**3])
    dnorm = np.zeros([n_seeds, 1])
    vout = np.zeros([n_seeds, 3])
    for i in range(config['max_steps']):
        t0 = time()

        # Get coords of latest segment for each fiber
        ijk = xyz2ijk(xyz[:, -1, :], snap=True)

        n_ongoing = len(ijk)

        for ii, idx in enumerate(ijk):
            d[ii] = dwi[idx[0] - (block_size // 2):idx[0] + (block_size // 2) +
                        1, idx[1] - (block_size // 2):idx[1] +
                        (block_size // 2) + 1,
                        idx[2] - (block_size // 2):idx[2] + (block_size // 2) +
                        1, :].flatten()  # returns copy
            dnorm[ii] = np.linalg.norm(d[ii])
            d[ii] /= dnorm[ii]

        if i == 0:
            inputs = np.hstack(
                [prior(xyz[:, 0, :]), d[:n_ongoing], dnorm[:n_ongoing]])
        else:
            inputs = np.hstack(
                [vout[:n_ongoing], d[:n_ongoing], dnorm[:n_ongoing]])

        chunk = 2**15  # 32768
        n_chunks = np.ceil(n_ongoing / chunk).astype(int)
        for c in range(n_chunks):

            outputs = model.predict(inputs[c * chunk:(c + 1) * chunk])
            v = bvecs[outputs, ...]
            vout[c * chunk:(c + 1) * chunk] = v

        rout = xyz[:, -1, :3] + config['step_size'] * vout
        rout = np.hstack([rout, np.ones((n_ongoing, 1))]).reshape(-1, 1, 4)

        xyz = np.concatenate([xyz, rout], axis=1)

        terminal_indices = terminator(xyz[:, -1, :])

        for idx in terminal_indices:
            gidx = fiber_idx[idx]
            # Other end not yet added
            if not fibers[gidx]:
                fibers[gidx].append(np.copy(xyz[idx, :, :3]))
            # Other end already added
            else:
                this_end = xyz[idx, :, :3]
                other_end = fibers[gidx][0]
                merged_fiber = np.vstack(
                    [np.flip(this_end[1:], axis=0),
                     other_end])  # stitch ends together
                fibers[gidx] = [merged_fiber]

        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)

        print(
            "Iter {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds with"
            " {:6.0f} steps/sec".format(
                (i + 1), config['max_steps'], n_seeds - n_ongoing, n_seeds,
                100 * (1 - n_ongoing / n_seeds), n_ongoing / (time() - t0)),
            end="\r")

        if n_ongoing == 0:
            break

        gc.collect()

    # Keep only fibers for which both directions have terminated; indices
    # still present in fiber_idx belong to unfinished fibers and are dropped.

    fibers = [
        fibers[gidx] for gidx in range(len(fibers)) if gidx not in fiber_idx
    ]
    # Each finished entry holds a single merged array; unwrap it.
    fibers = [f[0] for f in fibers]

    # Save result

    tractogram = Tractogram(streamlines=ArraySequence(fibers),
                            affine_to_rasmm=np.eye(4))

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "predicted_fibers", timestamp)

    configs.deep_update(config, {"out_dir": out_dir})

    os.makedirs(out_dir, exist_ok=True)

    fiber_path = os.path.join(out_dir, timestamp + ".trk")
    print("\nSaving {}".format(fiber_path))
    TrkFile(tractogram, seed_file.header).save(fiber_path)

    config_path = os.path.join(out_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    if config["score"]:
        score_on_tm(fiber_path)

    return tractogram
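The merge step inside the tracking loop above stitches the two half-tracks grown
from a single seed into one fiber by flipping one half (minus the shared seed
point) and prepending it to the other. A minimal, self-contained sketch of that
operation on toy coordinates (hypothetical values, purely illustrative):

import numpy as np

# Two halves tracked from the same seed [0, 0, 0], one per direction.
forward_half = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
backward_half = np.array([[0., 0., 0.], [-1., 0., 0.], [-2., 0., 0.]])

# Drop the duplicated seed from one half, flip it, and prepend it so the
# merged fiber runs tip-to-tip through the seed.
merged = np.vstack([np.flip(forward_half[1:], axis=0), backward_half])
print(merged)
# [[ 2.  0.  0.]
#  [ 1.  0.  0.]
#  [ 0.  0.  0.]
#  [-1.  0.  0.]
#  [-2.  0.  0.]]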
Exemplo n.º 25
0
def resample_tractogram(tractogram,
                        npts,
                        smoothing,
                        min_length=0,
                        max_length=1000):
    """Resample and smooth each streamline, returning a Tractogram of
    positions with per-point tangents; fibers outside [min_length, max_length]
    or with failed spline fits are dropped."""

    streamlines = tractogram.streamlines

    position = ArraySequence()
    tangent = ArraySequence()
    rows = 0

    def max_dist_from_mean(path):
        return np.linalg.norm(path - np.mean(path, axis=0, keepdims=True),
                              axis=1).max()
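    # Quality check used below: if the resampled/smoothed path strays more
    # than 20% further from its centroid than the original fiber did, the
    # spline fit is assumed to have overshot and the fiber is dropped.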

    n_fails = 0
    n_length = 0
    for i, f in enumerate(streamlines):

        flen = np.linalg.norm(f[1:] - f[:-1], axis=1).sum()
        if (flen < min_length) or (flen > max_length):
            n_length += 1
            continue

        r, t, cnt = fiber_geometry(f, npts=npts, smoothing=smoothing)

        if max_dist_from_mean(r) > 1.2 * max_dist_from_mean(f):
            n_fails += 1
            continue

        position.append(r, cache_build=True)
        tangent.append(t, cache_build=True)
        rows += cnt

        print("Finished {:3.0f}%".format(100 * (i + 1) / len(streamlines)),
              end="\r")

    if n_fails > 0:
        print("Failed to resample {} out of {} fibers; they were not "
              "included.".format(n_fails, len(streamlines)))

    if n_length > 0:
        print("{} out of {} fibers were excluded by length.".format(
            n_length, len(streamlines)))

    position.finalize_append()
    tangent.finalize_append()

    other_data = {}
    if npts == "same":
        for key in list(tractogram.data_per_point.keys()):
            if key != "t":
                other_data[key] = tractogram.data_per_point[key]

    data_per_point = PerArraySequenceDict(n_rows=rows, t=tangent, **other_data)

    # Fiber coordinates are already in RASmm space, hence the identity affine.
    return Tractogram(streamlines=position,
                      data_per_point=data_per_point,
                      affine_to_rasmm=np.eye(4))
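The position and tangent sequences above are filled incrementally and
consolidated once at the end. A minimal sketch of that append pattern, using
the same nibabel ArraySequence API as the function above (toy data, purely
illustrative):

import numpy as np
from nibabel.streamlines.array_sequence import ArraySequence

seq = ArraySequence()
for n in (3, 5, 4):
    pts = np.random.rand(n, 3)         # a hypothetical resampled fiber
    seq.append(pts, cache_build=True)   # cheap, buffered append
seq.finalize_append()                   # consolidate the internal buffers once
print(len(seq), sum(len(s) for s in seq))  # 3 fibers, 12 points in total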
Exemplo n.º 26
0
def track_ensemble(target_samples, atlas_data_wm_gm_int, labels_im_file,
                   recon_path, sphere, directget, curv_thr_list, step_list,
                   track_type, maxcrossing, roi_neighborhood_tol, min_length,
                   waymask, B0_mask, t1w2dwi, gm_in_dwi, vent_csf_in_dwi,
                   wm_in_dwi, tiss_class, cache_dir):
    """
    Perform native-space ensemble tractography, restricted to a vector of ROI
    masks.

    Parameters
    ----------
    target_samples : int
        Total number of streamline samples specified to generate streams.
    atlas_data_wm_gm_int : array
        3D int32 numpy array of atlas parcellation intensities from Nifti1Image
        in T1w-warped native diffusion space, restricted to wm-gm interface.
    labels_im_file : str
        File path to the atlas parcellation image in T1w-warped native
        diffusion space.
    recon_path : str
        File path to diffusion reconstruction model.
    tiss_class : str
        Tissue classification method.
    sphere : obj
        DiPy object for modeling diffusion directions on a sphere.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), and prob (probabilistic).
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    maxcrossing : int
        Maximum number of diffusion directions that can be assumed per voxel
        while tracking.
    roi_neighborhood_tol : float
        Distance (in the units of the streamlines, usually mm). If any
        coordinate in the streamline is within this distance from the center
        of any voxel in the ROI, the filtering criterion is set to True for
        this streamline, otherwise False. Defaults to the distance between
        the center of each voxel and the corner of the voxel.
    min_length : int
        Minimum fiber length threshold in mm.
    waymask : str
        Tractography constraint mask in native diffusion space.
    B0_mask : str
        B0 brain mask.
    n_seeds_per_iter : int
        Number of seeds from which to initiate tracking for each unique
        ensemble combination. By default this is set to 250.
    max_length : int
        Maximum number of steps to restrict tracking.
    pft_back_tracking_dist : float
        Distance in mm to back track before starting the particle filtering
        tractography. The total particle filtering tractography distance is
        equal to back_tracking_dist + front_tracking_dist. By default this is
        set to 2 mm.
    pft_front_tracking_dist : float
        Distance in mm to run the particle filtering tractography after the
        back-tracking distance. The total particle filtering tractography
        distance is equal to back_tracking_dist + front_tracking_dist. By
        default this is set to 1 mm.
    particle_count : int
        Number of particles to use in the particle filter.
    min_separation_angle : float
        The minimum angle between directions [0, 90].

    Returns
    -------
    streamlines : ArraySequence
        DiPy list/array-like object of streamline points from tractography.

    References
    ----------
    .. [1] Takemura, H., Caiafa, C. F., Wandell, B. A., & Pestilli, F. (2016).
      Ensemble Tractography. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.1004692

    """
    import os
    import gc
    import time
    import pkg_resources
    import yaml
    import shutil
    from joblib import Parallel, delayed
    import itertools
    from pynets.dmri.track import run_tracking
    from colorama import Fore, Style
    from pynets.dmri.dmri_utils import generate_sl
    from nibabel.streamlines.array_sequence import concatenate, ArraySequence
    from pynets.core.utils import save_3d_to_4d

    cache_dir = f"{cache_dir}/joblib_tracking"
    os.makedirs(cache_dir, exist_ok=True)

    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              "r") as stream:
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        nthreads = hardcoded_params["nthreads"][0]
        n_seeds_per_iter = \
            hardcoded_params['tracking']["n_seeds_per_iter"][0]
        max_length = \
            hardcoded_params['tracking']["max_length"][0]
        pft_back_tracking_dist = \
            hardcoded_params['tracking']["pft_back_tracking_dist"][0]
        pft_front_tracking_dist = \
            hardcoded_params['tracking']["pft_front_tracking_dist"][0]
        particle_count = \
            hardcoded_params['tracking']["particle_count"][0]
        min_separation_angle = \
            hardcoded_params['tracking']["min_separation_angle"][0]

    all_combs = list(itertools.product(step_list, curv_thr_list))

    tissues4d = save_3d_to_4d([
        B0_mask, labels_im_file, atlas_data_wm_gm_int, t1w2dwi, gm_in_dwi,
        vent_csf_in_dwi, wm_in_dwi
    ])
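    # Stack the B0 mask, parcellation, and tissue volumes into a single 4D
    # image so that one reference can be handed to every parallel tracking
    # job below.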

    # Commence Ensemble Tractography
    start = time.time()
    stream_counter = 0

    all_streams = []
    ix = 0
    while float(stream_counter) < float(target_samples) and \
        float(ix) < 0.75*float(len(all_combs)):
        with Parallel(n_jobs=nthreads,
                      backend='loky',
                      mmap_mode='r+',
                      temp_folder=cache_dir,
                      verbose=10) as parallel:
            out_streams = parallel(
                delayed(run_tracking)
                (i, recon_path, n_seeds_per_iter, directget, maxcrossing,
                 max_length, pft_back_tracking_dist, pft_front_tracking_dist,
                 particle_count, roi_neighborhood_tol, waymask, min_length,
                 track_type, min_separation_angle, sphere, tiss_class,
                 tissues4d, cache_dir) for i in all_combs)

            out_streams = [
                i for i in out_streams if i is not None and len(i) > 0
            ]

            if len(out_streams) > 1:
                out_streams = concatenate(out_streams, axis=0)

            if len(out_streams) < 50:
                ix += 1
                print("Fewer than 50 streamlines tracked on the last "
                      "iteration. Loosening tolerance and anatomical "
                      "constraints...")
                if track_type != 'particle':
                    tiss_class = 'wb'
                roi_neighborhood_tol = float(roi_neighborhood_tol) * 1.05
                min_length = float(min_length) * 0.95
                continue
            else:
                ix -= 1

            # Append streamline generators to prevent exponential growth
            # in memory consumption
            all_streams.extend([generate_sl(i) for i in out_streams])
            stream_counter += len(out_streams)
            del out_streams

            print("%s%s%s%s" % (
                "\nCumulative Streamline Count: ",
                Fore.CYAN,
                stream_counter,
                "\n",
            ))
            gc.collect()
            print(Style.RESET_ALL)

    if ix >= 0.75*len(all_combs) and \
        float(stream_counter) < float(target_samples):
        print(f"Tractography failed. >{len(all_combs)} consecutive sampling "
              f"iterations with <50 streamlines. Are you using a waymask? "
              f"If so, it may be too restrictive.")
        return ArraySequence()
    else:
        print("Tracking Complete: ", str(time.time() - start))

    del parallel, all_combs
    shutil.rmtree(cache_dir, ignore_errors=True)

    if stream_counter != 0:
        print('Generating final ArraySequence...')
        return ArraySequence([ArraySequence(i) for i in all_streams])
    else:
        print('No streamlines generated!')
        return ArraySequence()
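The ensemble in track_ensemble is the Cartesian product of step sizes and
curvature thresholds, sampled repeatedly until the target streamline count is
reached or too many low-yield iterations accumulate. A small illustration of
how that grid and the 75% cutoff interact (toy values, not the pynets
defaults):

import itertools

step_list = [0.2, 0.4, 0.6]   # hypothetical step sizes
curv_thr_list = [30, 60]      # hypothetical curvature thresholds (degrees)

all_combs = list(itertools.product(step_list, curv_thr_list))
print(all_combs)
# [(0.2, 30), (0.2, 60), (0.4, 30), (0.4, 60), (0.6, 30), (0.6, 60)]

# Each low-yield pass (<50 streamlines) increments a failure counter and
# slightly loosens the constraints (roi_neighborhood_tol * 1.05,
# min_length * 0.95); once the counter reaches 75% of the grid size,
# tracking is declared failed.
print(len(all_combs), int(0.75 * len(all_combs)))  # 6 combinations, cutoff 4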