Example #1
File: bundles.py Project: grlee77/dipy
    def _cluster_streamlines(self, clust_thr, nb_pts):

        if self.verbose:
            t = time()
            print('# Cluster streamlines using QBx')
            print(' Tractogram has %d streamlines'
                  % (len(self.streamlines), ))
            print(' Size is %0.3f MB' % (nbytes(self.streamlines),))
            print(' Distance threshold %0.3f' % (clust_thr,))

        # TODO this needs to become a default parameter
        thresholds = self.start_thr + [clust_thr]

        merged_cluster_map = qbx_and_merge(self.streamlines, thresholds,
                                           nb_pts, None, self.rng,
                                           self.verbose)

        self.cluster_map = merged_cluster_map
        self.centroids = merged_cluster_map.centroids
        self.nb_centroids = len(self.centroids)
        self.indices = [cluster.indices for cluster in self.cluster_map]

        if self.verbose:
            print(' Streamlines have %d centroids'
                  % (self.nb_centroids,))
            print(' Total duration %0.3f sec. \n' % (time() - t,))
Example #2
    def _cluster_model_bundle(self,
                              model_bundle,
                              model_clust_thr,
                              nb_pts=20,
                              select_randomly=500000):

        if self.verbose:
            t = time()
            print('# Cluster model bundle using QBX')
            print(' Model bundle has %d streamlines' % (len(model_bundle), ))
            print(' Distance threshold %0.3f' % (model_clust_thr, ))
        thresholds = self.start_thr + [model_clust_thr]

        model_cluster_map = qbx_and_merge(model_bundle,
                                          thresholds,
                                          nb_pts=nb_pts,
                                          select_randomly=select_randomly,
                                          rng=self.rng,
                                          verbose=self.verbose)
        model_centroids = model_cluster_map.centroids
        nb_model_centroids = len(model_centroids)

        if self.verbose:
            print(' Model bundle has %d centroids' % (nb_model_centroids, ))
            print(' Duration %0.3f sec. \n' % (time() - t, ))
        return model_centroids
Example #3
def test_rb_clustermap():

    cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10])

    rb = RecoBundles(f, greater_than=0, less_than=1000000,
                     cluster_map=cluster_map, clust_thr=10)
    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10)

    D = bundles_distances_mam(f2, f[rec_labels])

    # check if the bundle is recognized correctly
    if len(f2) == len(rec_labels):
        for row in D:
            assert_equal(row.min(), 0)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # check if the bundle is recognized correctly
    for row in D:
        assert_equal(row.min(), 0)
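In the test above, f and f2 are streamline fixtures provided by dipy's test module. A minimal, hedged sketch of a comparable setup using dipy's bundled fornix demo tractogram follows; the import paths mirror dipy's documented examples, and the "model" bundle below is just an arbitrary subset chosen for illustration.

import numpy as np
from dipy.data import get_fnames
from dipy.io.streamline import load_tractogram
from dipy.segment.bundles import RecoBundles, qbx_and_merge

# Load the small fornix demo tractogram shipped with dipy.
fname = get_fnames('fornix')
f = load_tractogram(fname, 'same', bbox_valid_check=False).streamlines
f2 = f[:100].copy()  # hypothetical model bundle: an arbitrary subset of f

cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10],
                            rng=np.random.RandomState(42))
rb = RecoBundles(f, greater_than=0, less_than=1000000,
                 cluster_map=cluster_map, clust_thr=10)
# recognize() returns the recognized bundle plus labels; the exact tuple
# differs between dipy versions (compare Example #3 with Example #11),
# so the result is kept whole here.
result = rb.recognize(model_bundle=f2, model_clust_thr=5., reduction_thr=10)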
Example #4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.output_centroids)
    if args.output_clusters_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.output_clusters_dir,
                                           create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines,
                             thresholds,
                             nb_pts=args.nb_points,
                             verbose=False)

    for i, cluster in enumerate(clusters):
        if len(cluster.indices) > 1:
            cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
        else:
            cluster_streamlines = streamlines[cluster.indices]

        new_sft = StatefulTractogram(cluster_streamlines, sft, Space.RASMM)
        save_tractogram(
            new_sft,
            os.path.join(args.output_clusters_dir, 'cluster_{}.trk'.format(i)))

    if args.output_centroids:
        new_sft = StatefulTractogram(clusters.centroids, sft, Space.RASMM)
        save_tractogram(new_sft, args.output_centroids)
Example #6
def cluster_bundle(bundle, clust_thr, rng, nb_pts=20, select_randomly=500000):
    """ Clusters bundles

    Parameters
    ----------
    bundle : Streamlines
        White matter tract
    clust_thr : float
        clustering threshold used in quickbundlesX
    rng : RandomState
    nb_pts: integer (default 20)
        Discretizing streamlines to have nb_points number of points
    select_randomly: integer (default 500000)
        Randomly select streamlines from the input bundle

    Returns
    -------
    centroids : Streamlines
        clustered centroids of the input bundle

    References
    ----------
    .. [Garyfallidis12] Garyfallidis E. et al., QuickBundles a method for
                        tractography simplification, Frontiers in Neuroscience,
                        vol 6, no 175, 2012.
   """

    model_cluster_map = qbx_and_merge(bundle, clust_thr,
                                      nb_pts=nb_pts,
                                      select_randomly=select_randomly,
                                      rng=rng)
    centroids = model_cluster_map.centroids

    return centroids
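A minimal, hedged usage sketch for cluster_bundle above: the toy bundle and threshold values are illustrative, and the Streamlines import follows dipy's usual layout. Since clust_thr is forwarded unchanged to qbx_and_merge as its thresholds argument, a coarse-to-fine sequence of distances (in mm) is passed.

import numpy as np
from dipy.tracking.streamline import Streamlines
# cluster_bundle as defined above (in dipy it is expected to live in
# dipy.segment.bundles; that import path is an assumption of this sketch).

rng = np.random.RandomState(0)
# Toy bundle: 30 roughly parallel streamlines of 20 points each.
bundle = Streamlines(
    [np.column_stack([np.linspace(0., 100., 20),
                      np.full(20, y),
                      np.zeros(20)]).astype(np.float32)
     for y in rng.uniform(-5., 5., size=30)])

centroids = cluster_bundle(bundle, clust_thr=[20, 10, 5], rng=rng, nb_pts=20)
print('Number of centroids:', len(centroids))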
Example #7
    def _cluster_streamlines(self, clust_thr, nb_pts):

        if self.verbose:
            t = time()
            logger.info('# Cluster streamlines using QBx')
            logger.info(' Tractogram has %d streamlines'
                        % (len(self.streamlines), ))
            logger.info(' Size is %0.3f MB' % (nbytes(self.streamlines),))
            logger.info(' Distance threshold %0.3f' % (clust_thr,))

        # TODO this needs to become a default parameter
        thresholds = self.start_thr + [clust_thr]

        merged_cluster_map = qbx_and_merge(self.streamlines, thresholds,
                                           nb_pts, None, self.rng,
                                           self.verbose)

        self.cluster_map = merged_cluster_map
        self.centroids = merged_cluster_map.centroids
        self.nb_centroids = len(self.centroids)
        self.indices = [cluster.indices for cluster in self.cluster_map]

        if self.verbose:
            logger.info(' Streamlines have %d centroids'
                        % (self.nb_centroids,))
            logger.info(' Total duration %0.3f sec. \n' % (time() - t,))
Example #8
    def prune_far_from_model(self,
                             bundle_pruning_thr=10,
                             neighbors_cluster_thr=8):
        """
        Wrapper function to prune clusters from the tractogram too far from
        the model.
        :param neighbors_to_prune, list or arraySequence, streamlines to prune.
        :param bundle_pruning_thr, float, distance in mm for pruning.
        :param neighbors_cluster_thr, float, distance in mm for clustering.
        """
        # Neighbors can be refined since the search space is smaller
        thresholds = [32, 16, neighbors_cluster_thr]

        neighb_cluster_map = qbx_and_merge(self.neighb_streamlines,
                                           thresholds,
                                           nb_pts=self.nb_points,
                                           rng=self.rng,
                                           verbose=False)

        dist_matrix = bundles_distances_mdf(self.model_centroids,
                                            neighb_cluster_map.centroids)
        dist_matrix[np.isnan(dist_matrix)] = np.inf
        dist_matrix[dist_matrix > bundle_pruning_thr] = np.inf
        mins = np.min(dist_matrix, axis=0)

        pruned_indices = np.fromiter(chain(*[
            neighb_cluster_map[i].indices for i in np.where(mins != np.inf)[0]
        ]),
                                     dtype=np.int32)

        # Since the neighbors were clustered, a mapping of indices is necessary
        self.final_pruned_indices = self.neighb_indices[pruned_indices]

        return self.final_pruned_indices
Example #9
def remove_loops_and_sharp_turns(streamlines,
                                 max_angle,
                                 use_qb=False,
                                 qb_threshold=15.,
                                 qb_seed=0):
    """
    Remove loops and sharp turns from a list of streamlines.
    Parameters
    ----------
    streamlines: list of ndarray
        The list of streamlines from which to remove loops and sharp turns.
    max_angle: float
        Maximal winding angle a streamline can have before
        being classified as a loop.
    use_qb: bool
        Set to True if the additional QuickBundles pass is done.
        This will help remove sharp turns. Should only be used on
        bundled streamlines, not on whole-brain tractograms.
    qb_threshold: float
        Quickbundles distance threshold, only used if use_qb is True.
    qb_seed: int
        Seed to initialize randomness in QuickBundles

    Returns
    -------
    list: the ids of clean streamlines
        Only the ids are returned so proper filtering can be done afterwards
    """

    streamlines_clean = []
    ids = []
    for i, s in enumerate(streamlines):
        if tm.winding(s) < max_angle:
            ids.append(i)
            streamlines_clean.append(s)

    if use_qb:
        ids = []
        if len(streamlines_clean) > 1:
            curvature = []

            rng = np.random.RandomState(qb_seed)
            clusters = qbx_and_merge(streamlines_clean,
                                     [40, 30, 20, qb_threshold],
                                     rng=rng,
                                     verbose=False)

            for cc in clusters.centroids:
                curvature.append(tm.mean_curvature(cc))
            mean_curvature = sum(curvature) / len(curvature)

            for i in range(len(clusters.centroids)):
                if tm.mean_curvature(clusters.centroids[i]) <= mean_curvature:
                    ids.extend(clusters[i].indices)
        else:
            logging.debug("Impossible to use the use_qb option because " +
                          "not more than one streamline left from the\n" +
                          "input file.")
    return ids
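A hedged usage sketch of the winding filter above, on two synthetic streamlines; it assumes the function and its module imports (numpy as np, dipy.tracking.metrics as tm) are in scope, and the 330-degree threshold is illustrative.

import numpy as np

t = np.linspace(0., 4. * np.pi, 100)
# One gently wavy streamline (low winding) and one that loops twice around
# (winding close to 720 degrees).
wavy = np.column_stack([10. * t, np.sin(t), np.zeros_like(t)]).astype(np.float32)
loop = (10. * np.column_stack([np.cos(t), np.sin(t), 0.01 * t])).astype(np.float32)

ids = remove_loops_and_sharp_turns([wavy, loop], max_angle=330.)
print(ids)  # only the index of the wavy streamline is expected to remain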
Example #10
def single_clusterize_and_rbx_init(args):
    """
    Wrapper function to multiprocess clustering executions and recobundles
    initialisation.

    Parameters
    ----------
    wb_streamlines : list or ArraySequence
        All streamlines of the tractogram to segment.
    tmp_memmap_filename: tuple (3)
        Temporary filename for the data, offsets and lengths.

    parameters_list : tuple (2)
        clustering_thr : int
            Distance in mm (for QBx) to cluster the input tractogram.
        seed : int
            Value to initialize the RandomState of numpy.
    nb_points : int
        Number of points used for all resampling of streamlines.

    Returns
    -------
    rbx : dict
        Initialisation of the recobundles class using specific parameters.
    """
    wb_streamlines = args[0]
    tmp_memmap_filename = args[1]
    clustering_thr = args[2][0]
    seed = args[2][1]
    nb_points = args[3]

    rbx = {}
    base_thresholds = [45, 35, 25]
    rng = np.random.RandomState(seed)
    cluster_timer = time()
    # If necessary, add an extra layer (more optimal)
    if clustering_thr < 15:
        current_thr_list = base_thresholds + [15, clustering_thr]
    else:
        current_thr_list = base_thresholds + [clustering_thr]

    cluster_map = qbx_and_merge(wb_streamlines,
                                current_thr_list,
                                nb_pts=nb_points, rng=rng,
                                verbose=False)
    clusters_indices = []
    for cluster in cluster_map.clusters:
        clusters_indices.append(cluster.indices)
    centroids = list(cluster_map.centroids)

    rbx[(seed, clustering_thr)] = RecobundlesX(tmp_memmap_filename,
                                               clusters_indices, centroids,
                                               nb_points=nb_points,
                                               rng=rng)
    logging.info('QBx with seed {0} at {1}mm took {2}sec. gave '
                 '{3} centroids'.format(seed, current_thr_list,
                                        round(time() - cluster_timer, 2),
                                        len(cluster_map.centroids)))
    return rbx
Example #11
def test_rb_clustermap():

    cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10])

    rb = RecoBundles(f, cluster_map=cluster_map, clust_thr=10)
    rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2,
                                                     model_clust_thr=5.,
                                                     reduction_thr=10)
    D = bundles_distances_mam(f2, recognized)

    # check if the bundle is recognized correctly
    for row in D:
        assert_equal(row.min(), 0)
Example #12
def run_rb(template, bucket, cluster_map=None, pruning_thr=10):
    # try a pruning threshold of 10; if not sufficient, drop to 5
    if cluster_map is None:
        cluster_map = qbx_and_merge(bucket, thresholds=[40, 25, 20, 10])
    else:
        print("Loading provided cluster map")

    rb = RecoBundles(bucket, cluster_map=cluster_map, clust_thr=5)
    bundle_tsp, labels, bundle_bsp = rb.recognize(model_bundle=template,
                                                  model_clust_thr=5.,
                                                  reduction_thr=10,
                                                  pruning_thr=pruning_thr)
    return bundle_bsp, cluster_map
Example #13
def multiprocess_subsampling(args):
    streamlines = args[0]
    min_distance = args[1]
    cluster_thr = args[2]
    min_cluster_size = args[3]
    average_streamlines = args[4]

    min_cluster_size = max(min_cluster_size, 1)
    thresholds = [40, 30, 20, cluster_thr]
    cluster_map = qbx_and_merge(ArraySequence(streamlines),
                                thresholds,
                                nb_pts=20,
                                verbose=False)

    return subsample_clusters(cluster_map, streamlines, min_distance,
                              min_cluster_size, average_streamlines)
Example #14
def test_qbx_and_merge():

    # Generate synthetic streamlines
    bundles = bearing_bundles(4, 2)
    bundles.append(straight_bundle(1))

    streamlines = Streamlines(list(itertools.chain(*bundles)))

    thresholds = [10, 2, 1]

    rng = np.random.RandomState(seed=42)
    qbxm_centroids = qbx_and_merge(streamlines, thresholds, rng=rng).centroids

    qbx = QuickBundlesX(thresholds)
    tree = qbx.cluster(streamlines)
    qbx_centroids = tree.get_clusters(3).centroids

    assert_equal(len(qbx_centroids) > len(qbxm_centroids), True)
Example #15
    def _cluster_model_bundle(self, model, model_clust_thr, identifier=None):
        """
        Wrapper function to compute QBx for the model and log information.
        :param model, list or arraySequence, streamlines to be used as model.
        :param model_clust_thr, float, distance in mm for clustering.
        :param identifier, str, name of the bundle for logging.
        """
        thresholds = [30, 20, 15, model_clust_thr]
        model_cluster_map = qbx_and_merge(model, thresholds,
                                          nb_pts=self.nb_points,
                                          rng=self.rng,
                                          verbose=False)
        self.model_centroids = model_cluster_map.centroids
        len_centroids = len(self.model_centroids)
        if len_centroids > 1000:
            logging.warning('Model {0} simplified at threshold '
                            '{1}mm with {2} centroids'.format(identifier,
                                                              str(model_clust_thr),
                                                              str(len_centroids)))
Example #16
def run_rb(templatesls, bucketosls, cluster_map=None, pruning_thr=10):
    # try a pruning threshold of 10; if not sufficient, drop to 5
    if cluster_map is None:
        cluster_map = qbx_and_merge(bucketosls, thresholds=[40, 25, 20, 10])
    else:
        print("Loading provided cluster map")

    rb = RecoBundles(bucketosls, cluster_map=cluster_map, clust_thr=5)
    recognized_atlassp, rec_labels, recognized_ptsp = rb.recognize(
        model_bundle=templatesls,
        model_clust_thr=5.,
        reduction_thr=10,
        pruning_thr=pruning_thr)
    '''rb = RecoBundles(bucketosls, cluster_map=cluster_map, clust_thr=10)
    recognized, rec_labels, rec_trans = rb.recognize(model_bundle=templatesls,
                                                         model_clust_thr=1.)'''
    #D = bundles_distances_mam(templatesls, recognized)

    return recognized_ptsp, cluster_map
Example #17
    def _prune_what_not_in_model(self,
                                 neighbors_to_prune,
                                 bundle_pruning_thr=10,
                                 neighbors_cluster_thr=8):
        """
        Wrapper function to prune clusters from the tractogram too far from
        the model
        :param neighbors_to_prune, list or arraySequence, streamlines to prune
        :param bundle_pruning_thr, float, distance in mm for pruning
        :param neighbors_cluster_thr, float, distance in mm for clustering
        """
        # Neighbors can be refined since the search space is smaller
        thresholds = [40, 30, 20, neighbors_cluster_thr]
        self.rtransf_cluster_map = qbx_and_merge(neighbors_to_prune,
                                                 thresholds,
                                                 nb_pts=self.nb_points,
                                                 rng=self.rng,
                                                 verbose=False)

        dist_matrix = bundles_distances_mdf(self.model_centroids,
                                            self.rtransf_cluster_map.centroids)
        dist_matrix[np.isnan(dist_matrix)] = np.inf
        dist_matrix[dist_matrix > bundle_pruning_thr] = np.inf
        mins = np.min(dist_matrix, axis=0)

        pruned_clusters = [
            self.rtransf_cluster_map[i].indices
            for i in np.where(mins != np.inf)[0]
        ]
        pruned_indices = list(chain(*pruned_clusters))
        pruned_streamlines = [neighbors_to_prune[i] for i in pruned_indices]

        self.pruned_streamlines = pruned_streamlines
        initial_indices = list(chain(*self.neighb_indices))

        # Since the neighbors were clustered, a mapping of indices is necessary
        final_indices = []
        for i in range(len(pruned_clusters)):
            final_indices.extend(
                [initial_indices[i] for i in pruned_clusters[i]])

        return final_indices
Example #18
File: bundles.py Project: grlee77/dipy
    def _cluster_model_bundle(self, model_bundle, model_clust_thr, nb_pts=20,
                              select_randomly=500000):

        if self.verbose:
            t = time()
            print('# Cluster model bundle using QBX')
            print(' Model bundle has %d streamlines'
                  % (len(model_bundle), ))
            print(' Distance threshold %0.3f' % (model_clust_thr,))
        thresholds = self.start_thr + [model_clust_thr]

        model_cluster_map = qbx_and_merge(model_bundle, thresholds,
                                          nb_pts=nb_pts,
                                          select_randomly=select_randomly,
                                          rng=self.rng,
                                          verbose=self.verbose)
        model_centroids = model_cluster_map.centroids
        nb_model_centroids = len(model_centroids)

        if self.verbose:
            print(' Model bundle has %d centroids'
                  % (nb_model_centroids,))
            print(' Duration %0.3f sec. \n' % (time() - t, ))
        return model_centroids
Example #19
def load_data_tmp_saving(args):
    filename = args[0]
    reference = args[1]
    init_only = args[2]
    disable_centroids = args[3]

    # Since the data is often re-used when comparing multiple bundles, anything
    # that can be computed once is saved temporarily and simply loaded on demand
    hash_tmp = hashlib.md5(filename.encode()).hexdigest()
    tmp_density_filename = os.path.join('tmp_measures/',
                                        '{}_density.nii.gz'.format(hash_tmp))
    tmp_endpoints_filename = os.path.join('tmp_measures/',
                                          '{}_endpoints.nii.gz'.format(hash_tmp))
    tmp_centroids_filename = os.path.join('tmp_measures/',
                                          '{}_centroids.trk'.format(hash_tmp))

    sft = load_tractogram(filename, reference)
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.get_streamlines_copy()
    if not streamlines:
        if init_only:
            logging.warning('{} is empty'.format(filename))
        return None

    if os.path.isfile(tmp_density_filename) \
            and os.path.isfile(tmp_endpoints_filename) \
            and os.path.isfile(tmp_centroids_filename):
        # If initialization, loading the data is useless
        if init_only:
            return None
        density = nib.load(tmp_density_filename).get_fdata(dtype=np.float32)
        endpoints_density = nib.load(tmp_endpoints_filename).get_fdata(dtype=np.float32)
        sft_centroids = load_tractogram(tmp_centroids_filename, reference)
        sft_centroids.to_vox()
        sft_centroids.to_corner()
        centroids = sft_centroids.get_streamlines_copy()
    else:
        transformation, dimensions, _, _ = sft.space_attributes
        density = compute_tract_counts_map(streamlines, dimensions)
        endpoints_density = get_endpoints_density_map(streamlines, dimensions,
                                                      point_to_select=3)
        thresholds = [32, 24, 12, 6]
        if disable_centroids:
            centroids = []
        else:
            centroids = qbx_and_merge(streamlines, thresholds,
                                      rng=RandomState(0),
                                      verbose=False).centroids

        # Saving tmp files to save on future computation
        nib.save(nib.Nifti1Image(density.astype(np.float32), transformation),
                 tmp_density_filename)
        nib.save(nib.Nifti1Image(endpoints_density.astype(np.int16),
                                 transformation),
                 tmp_endpoints_filename)

        # Saving in vox space and corner.
        centroids_sft = StatefulTractogram.from_sft(centroids, sft)
        save_tractogram(centroids_sft, tmp_centroids_filename)

    return density, endpoints_density, streamlines, centroids
Example #20
def compute_bundle_adjacency_streamlines(bundle_1,
                                         bundle_2,
                                         non_overlap=False,
                                         centroids_1=None,
                                         centroids_2=None):
    """
    Compute the distance in millimeters between two bundles. Uses centroids
    to limit computation time. Each centroid of the first bundle is matched
    to the nearest centroid of the second bundle and vice-versa.
    The distance between matched pairs is averaged for the final result.
    References
    ----------
    .. [Garyfallidis15] Garyfallidis et al. Robust and efficient linear
        registration of white-matter fascicles in the space of streamlines,
        Neuroimage, 2015.
    Parameters
    ----------
    bundle_1: list of ndarray
        First set of streamlines.
    bundle_2: list of ndarray
        Second set of streamlines.
    non_overlap: bool
        Exclude overlapping streamlines from the computation.
    centroids_1: list of ndarray
        Pre-computed centroids for the first bundle.
    centroids_2: list of ndarray
        Pre-computed centroids for the second bundle.
    Returns
    -------
    float: Distance in millimeters between both bundles.
    """
    if not bundle_1 or not bundle_2:
        return -1
    thresholds = [32, 24, 12, 6]
    # Initialize the clusters
    if centroids_1 is None:
        centroids_1 = qbx_and_merge(bundle_1,
                                    thresholds,
                                    rng=RandomState(0),
                                    verbose=False).centroids
    if centroids_2 is None:
        centroids_2 = qbx_and_merge(bundle_2,
                                    thresholds,
                                    rng=RandomState(0),
                                    verbose=False).centroids
    if non_overlap:
        non_overlap_1, _ = difference_robust([bundle_1, bundle_2])
        non_overlap_2, _ = difference_robust([bundle_2, bundle_1])

        if non_overlap_1:
            non_overlap_centroids_1 = qbx_and_merge(non_overlap_1,
                                                    thresholds,
                                                    rng=RandomState(0),
                                                    verbose=False).centroids
            distance_matrix_1 = bundles_distances_mdf(non_overlap_centroids_1,
                                                      centroids_2)

            min_b1 = np.min(distance_matrix_1, axis=0)
            distance_b1 = np.average(min_b1)
        else:
            distance_b1 = 0

        if non_overlap_2:
            non_overlap_centroids_2 = qbx_and_merge(non_overlap_2,
                                                    thresholds,
                                                    rng=RandomState(0),
                                                    verbose=False).centroids
            distance_matrix_2 = bundles_distances_mdf(centroids_1,
                                                      non_overlap_centroids_2)
            min_b2 = np.min(distance_matrix_2, axis=1)
            distance_b2 = np.average(min_b2)
        else:
            distance_b2 = 0

    else:
        distance_matrix = bundles_distances_mdf(centroids_1, centroids_2)
        min_b1 = np.min(distance_matrix, axis=0)
        min_b2 = np.min(distance_matrix, axis=1)
        distance_b1 = np.average(min_b1)
        distance_b2 = np.average(min_b2)

    return (distance_b1 + distance_b2) / 2.0
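A hedged usage sketch for compute_bundle_adjacency_streamlines above, on synthetic bundles; the function's module imports (numpy as np, RandomState, qbx_and_merge, bundles_distances_mdf) are assumed to be in scope, and the 3 mm offset is purely illustrative.

import numpy as np

rng = np.random.RandomState(0)
# Two toy bundles of 20-point streamlines; bundle_2 is bundle_1 shifted by 3 mm.
bundle_1 = [np.column_stack([np.linspace(0., 100., 20),
                             np.full(20, y),
                             np.zeros(20)]).astype(np.float32)
            for y in rng.uniform(-1., 1., size=30)]
bundle_2 = [s + np.array([0., 3., 0.], dtype=np.float32) for s in bundle_1]

# The reported distance should land close to the 3 mm offset.
print(compute_bundle_adjacency_streamlines(bundle_1, bundle_2))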
Example #21
File: app.py Project: theNaavik/dipy
    def add_cluster_actors(self, scene, tractograms,
                           threshold, enable_callbacks=True):
        """ Add streamline actors to the scene

        Parameters
        ----------
        scene : Scene
        tractograms : list
            list of tractograms
        threshold : float
            Cluster threshold
        enable_callbacks : bool
            Enable callbacks for selecting clusters
        """
        color_gen = distinguishable_colormap()
        for (t, sft) in enumerate(tractograms):
            streamlines = sft.streamlines

            if self.random_colors:
                colors = next(color_gen)
            else:
                colors = None

            if not self.world_coords:
                # TODO we need to read the affine of a tractogram
                # from a StatefulTractogram
                msg = 'Currently native coordinates are not supported'
                msg += ' for streamlines'
                raise ValueError(msg)

            if self.cluster:

                print(' Clustering threshold {} \n'.format(threshold))
                clusters = qbx_and_merge(streamlines,
                                         [40, 30, 25, 20, threshold])
                self.tractogram_clusters[t] = clusters
                centroids = clusters.centroids
                print(' Number of centroids is {}'.format(len(centroids)))
                sizes = np.array([len(c) for c in clusters])
                linewidths = np.interp(sizes,
                                       [sizes.min(), sizes.max()], [0.1, 2.])
                centroid_lengths = np.array([length(c) for c in centroids])

                print(' Minimum number of streamlines in cluster {}'
                      .format(sizes.min()))

                print(' Maximum number of streamlines in cluster {}'
                      .format(sizes.max()))

                print(' Construct cluster actors')
                for (i, c) in enumerate(centroids):

                    centroid_actor = actor.streamtube([c], colors,
                                                      linewidth=linewidths[i],
                                                      lod=False)
                    scene.add(centroid_actor)
                    self.mem.centroid_actors.append(centroid_actor)

                    cluster_actor = actor.line(clusters[i],
                                               lod=False)
                    cluster_actor.GetProperty().SetRenderLinesAsTubes(1)
                    cluster_actor.GetProperty().SetLineWidth(6)
                    cluster_actor.GetProperty().SetOpacity(1)
                    cluster_actor.VisibilityOff()

                    scene.add(cluster_actor)
                    self.mem.cluster_actors.append(cluster_actor)

                    # Every centroid actor (cea) is paired to a cluster actor
                    # (cla).

                    self.cea[centroid_actor] = {
                        'cluster_actor': cluster_actor,
                        'cluster': i, 'tractogram': t,
                        'size': sizes[i], 'length': centroid_lengths[i],
                        'selected': 0, 'expanded': 0}

                    self.cla[cluster_actor] = {
                        'centroid_actor': centroid_actor,
                        'cluster': i, 'tractogram': t,
                        'size': sizes[i], 'length': centroid_lengths[i],
                        'selected': 0, 'highlighted': 0}
                    apply_shader(self, cluster_actor)
                    apply_shader(self, centroid_actor)

            else:

                streamline_actor = actor.line(streamlines, colors=colors)
                streamline_actor.GetProperty().SetEdgeVisibility(1)
                streamline_actor.GetProperty().SetRenderLinesAsTubes(1)
                streamline_actor.GetProperty().SetLineWidth(6)
                streamline_actor.GetProperty().SetOpacity(1)
                scene.add(streamline_actor)
                self.mem.streamline_actors.append(streamline_actor)

        if not enable_callbacks:
            return

        def left_click_centroid_callback(obj, event):

            self.cea[obj]['selected'] = not self.cea[obj]['selected']
            self.cla[self.cea[obj]['cluster_actor']]['selected'] = \
                self.cea[obj]['selected']
            self.show_m.render()

        def left_click_cluster_callback(obj, event):

            if self.cla[obj]['selected']:
                self.cla[obj]['centroid_actor'].VisibilityOn()
                ca = self.cla[obj]['centroid_actor']
                self.cea[ca]['selected'] = 0
                obj.VisibilityOff()
                self.cea[ca]['expanded'] = 0

            self.show_m.render()

        for cl in self.cla:
            cl.AddObserver('LeftButtonPressEvent', left_click_cluster_callback,
                           1.0)
            self.cla[cl]['centroid_actor'].AddObserver(
                'LeftButtonPressEvent', left_click_centroid_callback, 1.0)
Example #22
    def build_scene(self):

        scene = window.Renderer()
        for (t, streamlines) in enumerate(self.tractograms):
            if self.random_colors:
                colors = self.prng.random_sample(3)
            else:
                colors = None

            if self.cluster:

                print(' Clustering threshold {} \n'.format(self.cluster_thr))
                clusters = qbx_and_merge(streamlines,
                                         [40, 30, 25, 20, self.cluster_thr])
                self.tractogram_clusters[t] = clusters
                centroids = clusters.centroids
                print(' Number of centroids is {}'.format(len(centroids)))
                sizes = np.array([len(c) for c in clusters])
                linewidths = np.interp(sizes,
                                       [sizes.min(), sizes.max()], [0.1, 2.])
                centroid_lengths = np.array([length(c) for c in centroids])

                print(' Minimum number of streamlines in cluster {}'.format(
                    sizes.min()))

                print(' Maximum number of streamlines in cluster {}'.format(
                    sizes.max()))

                print(' Construct cluster actors')
                for (i, c) in enumerate(centroids):

                    centroid_actor = actor.streamtube([c],
                                                      colors,
                                                      linewidth=linewidths[i],
                                                      lod=False)
                    scene.add(centroid_actor)

                    cluster_actor = actor.line(clusters[i], lod=False)
                    cluster_actor.GetProperty().SetRenderLinesAsTubes(1)
                    cluster_actor.GetProperty().SetLineWidth(6)
                    cluster_actor.GetProperty().SetOpacity(1)
                    cluster_actor.VisibilityOff()

                    scene.add(cluster_actor)

                    # Every centroid actor (cea) is paired to a cluster actor
                    # (cla).

                    self.cea[centroid_actor] = {
                        'cluster_actor': cluster_actor,
                        'cluster': i,
                        'tractogram': t,
                        'size': sizes[i],
                        'length': centroid_lengths[i],
                        'selected': 0,
                        'expanded': 0
                    }

                    self.cla[cluster_actor] = {
                        'centroid_actor': centroid_actor,
                        'cluster': i,
                        'tractogram': t,
                        'size': sizes[i],
                        'length': centroid_lengths[i],
                        'selected': 0
                    }
                    apply_shader(self, cluster_actor)
                    apply_shader(self, centroid_actor)

            else:

                streamline_actor = actor.line(streamlines, colors=colors)
                streamline_actor.GetProperty().SetEdgeVisibility(1)
                streamline_actor.GetProperty().SetRenderLinesAsTubes(1)
                streamline_actor.GetProperty().SetLineWidth(6)
                streamline_actor.GetProperty().SetOpacity(1)
                scene.add(streamline_actor)
        return scene
Example #23
File: bundles.py Project: grlee77/dipy
    def _prune_what_not_in_model(self, model_centroids,
                                 transf_streamlines,
                                 neighb_indices,
                                 mdf_thr=5,
                                 pruning_thr=10,
                                 pruning_distance='mdf'):

        if pruning_thr < 0:
            print('Pruning_thr has to be greater or equal to 0')

        if self.verbose:
            print('# Prune streamlines using the MDF distance')
            print(' Pruning threshold %0.3f' % (pruning_thr,))
            print(' Pruning distance {}'.format(pruning_distance))
            t = time()

        thresholds = [40, 30, 20, 10, mdf_thr]
        rtransf_cluster_map = qbx_and_merge(transf_streamlines,
                                            thresholds, nb_pts=20,
                                            select_randomly=500000,
                                            rng=self.rng,
                                            verbose=self.verbose)

        if self.verbose:
            print(' QB Duration %0.3f sec. \n' % (time() - t, ))

        rtransf_centroids = rtransf_cluster_map.centroids

        if pruning_distance.lower() == 'mdf':
            if self.verbose:
                print(' Using MDF')
            dist_matrix = bundles_distances_mdf(model_centroids,
                                                rtransf_centroids)
        elif pruning_distance.lower() == 'mam':
            if self.verbose:
                print(' Using MAM')
            dist_matrix = bundles_distances_mam(model_centroids,
                                                rtransf_centroids)
        else:
            raise ValueError('Given pruning distance is not available')
        dist_matrix[np.isnan(dist_matrix)] = np.inf
        dist_matrix[dist_matrix > pruning_thr] = np.inf

        pruning_matrix = dist_matrix.copy()

        if self.verbose:
            print(' Pruning matrix size is (%d, %d)'
                  % pruning_matrix.shape)

        mins = np.min(pruning_matrix, axis=0)
        pruned_indices = [rtransf_cluster_map[i].indices
                          for i in np.where(mins != np.inf)[0]]
        pruned_indices = list(chain(*pruned_indices))
        pruned_streamlines = transf_streamlines[np.array(pruned_indices)]

        initial_indices = list(chain(*neighb_indices))
        final_indices = [initial_indices[i] for i in pruned_indices]
        labels = final_indices

        if self.verbose:
            msg = ' Number of centroids: %d'
            print(msg % (len(rtransf_centroids),))
            msg = ' Number of streamlines after pruning: %d'
            print(msg % (len(pruned_streamlines),))

        if len(pruned_streamlines) == 0:
            print(' You have removed all streamlines')
            return Streamlines([]), []

        if self.verbose:
            print(' Duration %0.3f sec. \n' % (time() - t, ))

        return pruned_streamlines, labels
Example #24
File: app.py Project: grlee77/dipy
    def build_scene(self):

        scene = window.Renderer()
        for (t, streamlines) in enumerate(self.tractograms):
            if self.random_colors:
                colors = self.prng.random_sample(3)
            else:
                colors = None

            if self.cluster:

                print(' Clustering threshold {} \n'.format(self.cluster_thr))
                clusters = qbx_and_merge(streamlines,
                                         [40, 30, 25, 20, self.cluster_thr])
                self.tractogram_clusters[t] = clusters
                centroids = clusters.centroids
                print(' Number of centroids is {}'.format(len(centroids)))
                sizes = np.array([len(c) for c in clusters])
                linewidths = np.interp(sizes,
                                       [sizes.min(), sizes.max()], [0.1, 2.])
                centroid_lengths = np.array([length(c) for c in centroids])

                print(' Minimum number of streamlines in cluster {}'
                      .format(sizes.min()))

                print(' Maximum number of streamlines in cluster {}'
                      .format(sizes.max()))

                print(' Construct cluster actors')
                for (i, c) in enumerate(centroids):

                    centroid_actor = actor.streamtube([c], colors,
                                                      linewidth=linewidths[i],
                                                      lod=False)
                    scene.add(centroid_actor)

                    cluster_actor = actor.line(clusters[i],
                                               lod=False)
                    cluster_actor.GetProperty().SetRenderLinesAsTubes(1)
                    cluster_actor.GetProperty().SetLineWidth(6)
                    cluster_actor.GetProperty().SetOpacity(1)
                    cluster_actor.VisibilityOff()

                    scene.add(cluster_actor)

                    # Every centroid actor (cea) is paired to a cluster actor
                    # (cla).

                    self.cea[centroid_actor] = {
                        'cluster_actor': cluster_actor,
                        'cluster': i, 'tractogram': t,
                        'size': sizes[i], 'length': centroid_lengths[i],
                        'selected': 0, 'expanded': 0}

                    self.cla[cluster_actor] = {
                        'centroid_actor': centroid_actor,
                        'cluster': i, 'tractogram': t,
                        'size': sizes[i], 'length': centroid_lengths[i],
                        'selected': 0}
                    apply_shader(self, cluster_actor)
                    apply_shader(self, centroid_actor)

            else:

                streamline_actor = actor.line(streamlines, colors=colors)
                streamline_actor.GetProperty().SetEdgeVisibility(1)
                streamline_actor.GetProperty().SetRenderLinesAsTubes(1)
                streamline_actor.GetProperty().SetLineWidth(6)
                streamline_actor.GetProperty().SetOpacity(1)
                scene.add(streamline_actor)
        return scene
Example #25
    def _prune_what_not_in_model(self,
                                 model_centroids,
                                 transf_streamlines,
                                 neighb_indices,
                                 mdf_thr=5,
                                 pruning_thr=10,
                                 pruning_distance='mdf'):

        if pruning_thr < 0:
            print('Pruning_thr has to be greater or equal to 0')

        if self.verbose:
            print('# Prune streamlines using the MDF distance')
            print(' Pruning threshold %0.3f' % (pruning_thr, ))
            print(' Pruning distance {}'.format(pruning_distance))
            t = time()

        thresholds = [40, 30, 20, 10, mdf_thr]
        rtransf_cluster_map = qbx_and_merge(transf_streamlines,
                                            thresholds,
                                            nb_pts=20,
                                            select_randomly=500000,
                                            rng=self.rng,
                                            verbose=self.verbose)

        if self.verbose:
            print(' QB Duration %0.3f sec. \n' % (time() - t, ))

        rtransf_centroids = rtransf_cluster_map.centroids

        if pruning_distance.lower() == 'mdf':
            if self.verbose:
                print(' Using MDF')
            dist_matrix = bundles_distances_mdf(model_centroids,
                                                rtransf_centroids)
        elif pruning_distance.lower() == 'mam':
            if self.verbose:
                print(' Using MAM')
            dist_matrix = bundles_distances_mam(model_centroids,
                                                rtransf_centroids)
        else:
            raise ValueError('Given pruning distance is not available')
        dist_matrix[np.isnan(dist_matrix)] = np.inf
        dist_matrix[dist_matrix > pruning_thr] = np.inf

        pruning_matrix = dist_matrix.copy()

        if self.verbose:
            print(' Pruning matrix size is (%d, %d)' % pruning_matrix.shape)

        mins = np.min(pruning_matrix, axis=0)
        pruned_indices = [
            rtransf_cluster_map[i].indices for i in np.where(mins != np.inf)[0]
        ]
        pruned_indices = list(chain(*pruned_indices))
        idx = np.array(pruned_indices)
        if len(idx) == 0:
            print(' You have removed all streamlines')
            return Streamlines([]), []

        pruned_streamlines = transf_streamlines[idx]

        initial_indices = list(chain(*neighb_indices))
        final_indices = [initial_indices[i] for i in pruned_indices]
        labels = final_indices

        if self.verbose:
            msg = ' Number of centroids: %d'
            print(msg % (len(rtransf_centroids), ))
            msg = ' Number of streamlines after pruning: %d'
            print(msg % (len(pruned_streamlines), ))

        if self.verbose:
            print(' Duration %0.3f sec. \n' % (time() - t, ))

        return pruned_streamlines, labels
Example #26
def remove_loops_and_sharp_turns(streamlines,
                                 max_angle,
                                 use_qb=False,
                                 qb_threshold=15.,
                                 qb_seed=0):
    """
    Remove loops and sharp turns from a list of streamlines.
    Parameters
    ----------
    streamlines: list of ndarray
        The list of streamlines from which to remove loops and sharp turns.
    use_qb: bool
        Set to True if the additional QuickBundles pass is done.
        This will help remove sharp turns. Should only be used on
        bundled streamlines, not on whole-brain tractograms.
    max_angle: float
        Maximal winding angle a streamline can have before
        being classified as a loop.
    qb_threshold: float
        Quickbundles distance threshold, only used if use_qb is True.
    Returns
    -------
    A tuple containing
        list of ndarray: the clean streamlines
        list of ndarray: the list of removed streamlines, if any
    """

    loops = []
    streamlines_clean = []
    for s in streamlines:
        if tm.winding(s) >= max_angle:
            loops.append(s)
        else:
            streamlines_clean.append(s)

    if use_qb:
        if len(streamlines_clean) > 1:
            streamlines = streamlines_clean
            curvature = []
            streamlines_clean = []

            rng = np.random.RandomState(qb_seed)
            clusters = qbx_and_merge(streamlines, [40, 30, 20, qb_threshold],
                                     rng=rng, verbose=False)

            for cc in clusters.centroids:
                curvature.append(tm.mean_curvature(cc))
            mean_curvature = sum(curvature)/len(curvature)

            for i in range(len(clusters.centroids)):
                if tm.mean_curvature(clusters.centroids[i]) > mean_curvature:
                    for indice in clusters[i].indices:
                        loops.append(streamlines[indice])
                else:
                    for indice in clusters[i].indices:
                        streamlines_clean.append(streamlines[indice])
        else:
            logging.debug("Impossible to use the use_qb option because " +
                          "not more than one streamline left from the\n" +
                          "input file.")

    return streamlines_clean, loops
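A hedged sketch of the optional QuickBundles pass described above, on a synthetic bundle plus one looping streamline; the function and its module imports (numpy as np, dipy.tracking.metrics as tm, qbx_and_merge) are assumed to be in scope, and all values are illustrative.

import numpy as np

rng = np.random.RandomState(1)
t = np.linspace(0., 4. * np.pi, 100)
# A small bundle of gently curved streamlines plus one streamline that loops.
bundle = [np.column_stack([10. * t, np.sin(t) + y,
                           np.zeros_like(t)]).astype(np.float32)
          for y in rng.uniform(-2., 2., size=20)]
loop = (10. * np.column_stack([np.cos(t), np.sin(t), 0.01 * t])).astype(np.float32)

clean, loops = remove_loops_and_sharp_turns(bundle + [loop], max_angle=330.,
                                            use_qb=True, qb_threshold=15.)
print(len(clean), 'kept,', len(loops), 'flagged as loops or sharp turns')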
Example #27
def slr_with_qbx(static, moving,
                 x0='affine',
                 rm_small_clusters=50,
                 maxiter=100,
                 select_random=None,
                 verbose=False,
                 greater_than=50,
                 less_than=250,
                 qbx_thr=[40, 30, 20, 15],
                 nb_pts=20,
                 progressive=True, rng=None, num_threads=None):
    """ Utility function for registering large tractograms.

    For efficiency we apply the registration on cluster centroids and remove
    small clusters.

    Parameters
    ----------
    static : Streamlines
    moving : Streamlines

    x0 : str
        rigid, similarity or affine transformation model (default affine)

    rm_small_clusters : int
        Remove clusters that have fewer than `rm_small_clusters` streamlines
        (default 50)

    select_random : int
        If not None, randomly select this many streamlines before clustering.
        Default None.

    verbose : bool
        If True then information about the optimization is shown.

    greater_than : int, optional
            Keep streamlines that have length greater than
            this value (default 50)

    less_than : int, optional
            Keep streamlines that have length less than this value
            (default 250)

    qbx_thr : list of int, optional
            Thresholds for QuickBundlesX (default [40, 30, 20, 15])

    nb_pts : int, optional
            Number of points for discretizing each streamline (default 20)

    progressive : boolean, optional
            (default True)

    rng : RandomState
        If None creates RandomState in function.

    num_threads : int
        Number of threads. If None (default) then all available threads
        will be used. Only metrics using OpenMP will use this variable.

    Notes
    -----
    The order of operations is the following. First short or long streamlines
    are removed. Second the tractogram or a random selection of the tractogram
    is clustered with QuickBundles. Then SLR [Garyfallidis15]_ is applied.

    References
    ----------
    .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
    registration of white-matter fascicles in the space of streamlines",
    NeuroImage, 117, 124--140, 2015
    .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
            bundle alignment for group comparisons", ISMRM, 2014.
    .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
    bundles using local and global streamline-based registration and
    clustering, Neuroimage, 2017.
    """
    if rng is None:
        rng = np.random.RandomState()

    if verbose:
        print('Static streamlines size {}'.format(len(static)))
        print('Moving streamlines size {}'.format(len(moving)))

    def check_range(streamline, gt=greater_than, lt=less_than):

        if (length(streamline) > gt) & (length(streamline) < lt):
            return True
        else:
            return False

    streamlines1 = Streamlines(static[np.array([check_range(s)
                                                for s in static])])
    streamlines2 = Streamlines(moving[np.array([check_range(s)
                                                for s in moving])])

    if verbose:

        print('Static streamlines after length reduction {}'
              .format(len(streamlines1)))
        print('Moving streamlines after length reduction {}'
              .format(len(streamlines2)))

    if select_random is not None:
        rstreamlines1 = select_random_set_of_streamlines(streamlines1,
                                                         select_random,
                                                         rng=rng)
    else:
        rstreamlines1 = streamlines1

    rstreamlines1 = set_number_of_points(rstreamlines1, nb_pts)

    rstreamlines1._data = rstreamlines1._data.astype('f4')

    cluster_map1 = qbx_and_merge(rstreamlines1, thresholds=qbx_thr, rng=rng)
    qb_centroids1 = remove_clusters_by_size(cluster_map1, rm_small_clusters)

    if select_random is not None:
        rstreamlines2 = select_random_set_of_streamlines(streamlines2,
                                                         select_random,
                                                         rng=rng)
    else:
        rstreamlines2 = streamlines2

    rstreamlines2 = set_number_of_points(rstreamlines2, nb_pts)
    rstreamlines2._data = rstreamlines2._data.astype('f4')

    cluster_map2 = qbx_and_merge(rstreamlines2, thresholds=qbx_thr, rng=rng)

    qb_centroids2 = remove_clusters_by_size(cluster_map2, rm_small_clusters)

    if verbose:
        t = time()

    if not progressive:
        slr = StreamlineLinearRegistration(x0=x0,
                                           options={'maxiter': maxiter},
                                           num_threads=num_threads)
        slm = slr.optimize(qb_centroids1, qb_centroids2)
    else:
        bounds = DEFAULT_BOUNDS

        slm = progressive_slr(qb_centroids1, qb_centroids2,
                              x0=x0, metric=None,
                              bounds=bounds, num_threads=num_threads)

    if verbose:
        print('QB static centroids size %d' % len(qb_centroids1,))
        print('QB moving centroids size %d' % len(qb_centroids2,))
        duration = time() - t
        print('SLR finished in  %0.3f seconds.' % (duration,))
        if slm.iterations is not None:
            print('SLR iterations: %d ' % (slm.iterations,))

    moved = slm.transform(moving)

    return moved, slm.matrix, qb_centroids1, qb_centroids2
Example #28
def slr_with_qbx(static,
                 moving,
                 x0='affine',
                 rm_small_clusters=50,
                 maxiter=100,
                 select_random=None,
                 verbose=False,
                 greater_than=50,
                 less_than=250,
                 qbx_thr=[40, 30, 20, 15],
                 nb_pts=20,
                 progressive=True,
                 rng=None,
                 num_threads=None):
    """ Utility function for registering large tractograms.

    For efficiency, we apply the registration on cluster centroids and remove
    small clusters.

    Parameters
    ----------
    static : Streamlines
    moving : Streamlines

    x0 : str, optional.
        rigid, similarity or affine transformation model (default affine)

    rm_small_clusters : int, optional
        Remove clusters that have fewer than `rm_small_clusters` streamlines
        (default 50)

    select_random : int, optional
        If not None, randomly select this many streamlines before clustering.
        Default None.

    verbose : bool, optional
        If True, logs information about optimization. Default: False

    greater_than : int, optional
            Keep streamlines that have length greater than
            this value (default 50)

    less_than : int, optional
            Keep streamlines that have length less than this value
            (default 250)

    qbx_thr : list of int, optional
            Thresholds for QuickBundlesX (default [40, 30, 20, 15])

    nb_pts : int, optional
            Number of points for discretizing each streamline (default 20)

    progressive : boolean, optional
            (default True)

    rng : RandomState
        If None creates RandomState in function.

    num_threads : int
        Number of threads. If None (default) then all available threads
        will be used. Only metrics using OpenMP will use this variable.

    Notes
    -----
    The order of operations is the following. First short or long streamlines
    are removed. Second, the tractogram or a random selection of the tractogram
    is clustered with QuickBundles. Then SLR [Garyfallidis15]_ is applied.

    References
    ----------
    .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
    registration of white-matter fascicles in the space of streamlines",
    NeuroImage, 117, 124--140, 2015
    .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
            bundle alignment for group comparisons", ISMRM, 2014.
    .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
    bundles using local and global streamline-based registration and
    clustering, Neuroimage, 2017.
    """
    if rng is None:
        rng = np.random.RandomState()

    if verbose:
        logger.info('Static streamlines size {}'.format(len(static)))
        logger.info('Moving streamlines size {}'.format(len(moving)))

    def check_range(streamline, gt=greater_than, lt=less_than):

        if (length(streamline) > gt) & (length(streamline) < lt):
            return True
        else:
            return False

    streamlines1 = Streamlines(static[np.array(
        [check_range(s) for s in static])])
    streamlines2 = Streamlines(moving[np.array(
        [check_range(s) for s in moving])])
    if verbose:
        logger.info('Static streamlines after length reduction {}'.format(
            len(streamlines1)))
        logger.info('Moving streamlines after length reduction {}'.format(
            len(streamlines2)))

    if select_random is not None:
        rstreamlines1 = select_random_set_of_streamlines(streamlines1,
                                                         select_random,
                                                         rng=rng)
    else:
        rstreamlines1 = streamlines1

    rstreamlines1 = set_number_of_points(rstreamlines1, nb_pts)
    # Cast the resampled points to float32
    rstreamlines1._data = rstreamlines1._data.astype('f4')

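    # Cluster the static streamlines with QuickBundlesX and keep only the
    # centroids of clusters that are large enough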
    cluster_map1 = qbx_and_merge(rstreamlines1, thresholds=qbx_thr, rng=rng)
    qb_centroids1 = remove_clusters_by_size(cluster_map1, rm_small_clusters)

    if select_random is not None:
        rstreamlines2 = select_random_set_of_streamlines(streamlines2,
                                                         select_random,
                                                         rng=rng)
    else:
        rstreamlines2 = streamlines2

    rstreamlines2 = set_number_of_points(rstreamlines2, nb_pts)
    rstreamlines2._data = rstreamlines2._data.astype('f4')

    cluster_map2 = qbx_and_merge(rstreamlines2, thresholds=qbx_thr, rng=rng)

    qb_centroids2 = remove_clusters_by_size(cluster_map2, rm_small_clusters)

    if verbose:
        t = time()

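    # Register the moving centroids onto the static ones, either with a
    # single SLR run or progressively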
    if not progressive:
        slr = StreamlineLinearRegistration(x0=x0,
                                           options={'maxiter': maxiter},
                                           num_threads=num_threads)
        slm = slr.optimize(qb_centroids1, qb_centroids2)
    else:
        bounds = DEFAULT_BOUNDS

        slm = progressive_slr(qb_centroids1,
                              qb_centroids2,
                              x0=x0,
                              metric=None,
                              bounds=bounds,
                              num_threads=num_threads)

    if verbose:
        logger.info('QB static centroids size %d' % len(qb_centroids1))
        logger.info('QB moving centroids size %d' % len(qb_centroids2))
        duration = time() - t
        logger.info('SLR finished in %0.3f seconds.' % (duration, ))
        if slm.iterations is not None:
            logger.info('SLR iterations: %d ' % (slm.iterations, ))

    moved = slm.transform(moving)

    return moved, slm.matrix, qb_centroids1, qb_centroids2
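
A minimal usage sketch for slr_with_qbx (not part of the original example): the import paths are those of recent DIPY releases, the variables static_streamlines and moving_streamlines are assumed to be pre-loaded sequences of (N, 3) point arrays, and every parameter value below is only illustrative.

import numpy as np

from dipy.align.streamlinear import slr_with_qbx
from dipy.tracking.streamline import Streamlines

# Wrap the raw point arrays (assumed to be loaded elsewhere) so they can be
# filtered, clustered and transformed as Streamlines objects.
static = Streamlines(static_streamlines)
moving = Streamlines(moving_streamlines)

# Register `moving` onto `static`; the function also returns the affine and
# the centroids it used on both sides, as in the example above.
moved, affine, centroids_static, centroids_moving = slr_with_qbx(
    static, moving,
    x0='affine',
    rm_small_clusters=50,
    qbx_thr=[40, 30, 20, 15],
    progressive=True,
    rng=np.random.RandomState(42))
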
Example #29
0
    def multi_recognize(self,
                        input_tractogram_path,
                        tractogram_clustering_thr,
                        nb_points=20,
                        nbr_processes=1,
                        seeds=None):
        """
        Parameters
        ----------
        input_tractogram_path : str
            Filepath of the whole brain tractogram to segment
        tractogram_clustering_thr : list of int
            Distances in mm (for QBx) used to cluster the input tractogram
        nb_points : int
            Number of points used for all resampling of streamlines
        nbr_processes : int
            Number of processes used for the parallel bundle recognition
        seeds : list
            List of seeds for the RandomState
        """

        # Load the subject tractogram
        timer = time()
        tractogram = nib.streamlines.load(input_tractogram_path)
        wb_streamlines = tractogram.streamlines
        logging.debug('Tractogram {0} with {1} streamlines '
                      'is loaded in {2} seconds'.format(
                          input_tractogram_path, len(wb_streamlines),
                          round(time() - timer, 2)))

        # Prepare all tags to read the atlas properly
        bundle_names, bundles_filepath = self._init_bundles_tag()

        total_timer = time()
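        # Flat lists of job arguments: one entry per (seed, bundle, model,
        # parameter set) combination that will be dispatched to the pool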
        processing_dict = {}
        processing_dict['bundle_id'] = []
        processing_dict['tag'] = []
        processing_dict['model_bundle'] = []
        processing_dict['tct'] = []
        processing_dict['mct'] = []
        processing_dict['bpt'] = []
        processing_dict['slr_transform_type'] = []
        processing_dict['seed'] = []

        # Each type of bundle is processed separately
        for seed in seeds:
            for bundle_id in range(len(bundle_names)):
                random.seed(seed)
                bundle_parameters = self.config[bundle_names[bundle_id]]
                model_cluster_thr = bundle_parameters['model_clustering_thr']
                bundle_pruning_thr = bundle_parameters['bundle_pruning_thr']
                slr_transform_type = bundle_parameters['slr_transform_type']
                potential_parameters = list(
                    product(tractogram_clustering_thr, model_cluster_thr,
                            bundle_pruning_thr))
                random.shuffle(potential_parameters)

                if self.multi_parameters > len(potential_parameters):
                    logging.error('More multi-parameters executions than '
                                  'potential parameters, not enough parameter '
                                  'choices for bundle {0}'.format(
                                      bundle_names[bundle_id]))
                    raise ValueError('Multi-parameters option is too high')

                # Generate a set of parameters for each run
                picked_parameters = potential_parameters[
                    0:self.multi_parameters]

                logging.debug('Parameters choice for {0}, for the {1}'
                              ' executions are {2}'.format(
                                  bundle_names[bundle_id],
                                  self.multi_parameters, picked_parameters))

                # Using the tag previously generated, load the appropriate
                # model bundles
                model_bundles_dict = self._load_bundles_dictionary(
                    bundles_filepath[bundle_id])

                # Each run (can) have their unique set of parameters
                for parameters in picked_parameters:
                    tct, mct, bpt = parameters

                    # Each bundle (can) have multiple models
                    for tag in bundles_filepath[bundle_id]:
                        model_bundle = model_bundles_dict[tag]
                        processing_dict['bundle_id'] += [bundle_id]
                        processing_dict['tag'] += [tag]
                        processing_dict['model_bundle'] += [model_bundle]
                        processing_dict['tct'] += [tct]
                        processing_dict['mct'] += [mct]
                        processing_dict['bpt'] += [bpt]
                        processing_dict['slr_transform_type'] += [
                            slr_transform_type
                        ]
                        processing_dict['seed'] += [seed]

        # Cluster the whole tractogram only once per (seed, clustering
        # threshold) pair
        rbx_all = {}
        base_thresholds = [45, 35, 25]
        for seed in seeds:
            rng = np.random.RandomState(seed)
            for clustering_thr in tractogram_clustering_thr:
                timer = time()
                # For small thresholds, add an intermediate 15 mm level to
                # the QBx hierarchy
                if clustering_thr < 15:
                    current_thr_list = base_thresholds + [15, clustering_thr]
                else:
                    current_thr_list = base_thresholds + [clustering_thr]

                cluster_map = qbx_and_merge(wb_streamlines,
                                            current_thr_list,
                                            nb_pts=nb_points,
                                            rng=rng,
                                            verbose=False)

                rbx_all[(seed,
                         clustering_thr)] = RecobundlesX(wb_streamlines,
                                                         cluster_map,
                                                         nb_points=nb_points,
                                                         rng=rng)

                logging.info('QBx with seed {0} at {1} mm took {2} sec and '
                             'gave {3} centroids'.format(
                                 seed, current_thr_list,
                                 round(time() - timer, 2),
                                 len(cluster_map.centroids)))

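        # Recognize each bundle for every (model, parameters, seed)
        # combination in parallel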
        pool = multiprocessing.Pool(nbr_processes)
        all_measures_dict = pool.map(
            single_recognize,
            zip(repeat(rbx_all), processing_dict['bundle_id'],
                processing_dict['tag'], processing_dict['model_bundle'],
                processing_dict['tct'], processing_dict['mct'],
                processing_dict['bpt'], processing_dict['slr_transform_type'],
                processing_dict['seed']))
        pool.close()
        pool.join()

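        # Sparse vote matrices, incremented each time a streamline is
        # recognized as part of a bundle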
        streamlines_wise_vote = dok_matrix(
            (len(wb_streamlines), len(bundle_names)), dtype=np.int16)
        bundles_wise_vote = dok_matrix(
            (len(bundle_names), len(wb_streamlines)), dtype=np.int16)

        for bundle_id, recognized_indices in all_measures_dict:
            if recognized_indices is not None:
                streamlines_wise_vote[recognized_indices.T, bundle_id] += 1
                bundles_wise_vote[bundle_id, recognized_indices.T] += 1

        nb_exec = len(self.atlas_dir) * self.multi_parameters * len(seeds) * \
            len(bundle_names)
        logging.info('RBx took {0} sec. for a total of '
                     '{1} executions'.format(round(time() - total_timer, 2),
                                             nb_exec))
        logging.debug('{0} tractogram clustering, {1} seeds, '
                      '{2} multi-parameters, {3} sub-model directory, '
                      '{4} bundles'.format(len(tractogram_clustering_thr),
                                           len(seeds), self.multi_parameters,
                                           len(self.atlas_dir),
                                           len(bundle_names)))

        # Once everything was run, save the results using a voting system
        minimum_vote = round(
            len(self.atlas_dir) * self.multi_parameters * len(seeds) *
            self.minimal_vote_ratio)
        minimum_vote = max(minimum_vote, 1)

        extension = os.path.splitext(input_tractogram_path)[1]
        self._save_recognized_bundles(tractogram, bundle_names,
                                      streamlines_wise_vote, bundles_wise_vote,
                                      minimum_vote, extension)
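
A minimal, hypothetical sketch of the thresholding that the voting scheme above implies; the helper name winning_streamline_indices is an illustrative assumption and is not taken from _save_recognized_bundles.

import numpy as np
from scipy.sparse import dok_matrix

def winning_streamline_indices(bundles_wise_vote, bundle_id, minimum_vote):
    # Keep only the streamlines recognized at least `minimum_vote` times
    # for this bundle across all seeds, models and parameter sets.
    votes = bundles_wise_vote[bundle_id, :].toarray().ravel()
    return np.flatnonzero(votes >= minimum_vote)

# Toy example: votes for 2 bundles over 5 streamlines, minimum vote of 2.
votes = dok_matrix((2, 5), dtype=np.int16)
votes[0, 1] = 3
votes[0, 4] = 1
print(winning_streamline_indices(votes, 0, 2))  # [1]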