Beispiel #1
0
def subsample_streamlines(streamlines,
                          clustering_threshold=6.,
                          removal_distance=2.):
    """Thin out a group of streamlines by dropping near-duplicates.

    Meant for streamlines that belong to a single bundle or a similar
    structure. The input is clustered with QuickBundles and
    `remove_similar_streamlines` is then applied to each cluster separately.

    Parameters
    ----------
    streamlines : `ArraySequence` object
        Streamlines to subsample.
    clustering_threshold : float
        Distance threshold passed to QuickBundles as `dist_thr` (in the space
        of the tracks).
    removal_distance : float
        Distance threshold passed to `remove_similar_streamlines` (in the
        space of the tracks).

    Returns
    -------
    list
        The streamlines kept from every cluster, concatenated in cluster
        order.
    """
    bundles = QuickBundles(streamlines, dist_thr=clustering_threshold, pts=20)
    n_clusters = len(bundles.centroids)

    kept = []
    for label in range(n_clusters):
        members = bundles.label2tracks(streamlines, label)
        survivors = remove_similar_streamlines(
            members, removal_distance=removal_distance)
        kept.extend(survivors)

    return kept
Beispiel #2
0
def test_qbundles():
    """Cluster the 'fornix' dataset and check that four clusters are found."""
    fornix_path = get_data('fornix')
    streams, _hdr = nib.trackvis.read(fornix_path)
    # Each record's first item is the streamline itself.
    tracks = [record[0] for record in streams]

    qb = QuickBundles(tracks, 10., 12)
    # Both representative getters are exercised; their results are not checked.
    qb.virtuals()
    qb.exemplars()

    assert_equal(4, qb.total_clusters)
Beispiel #3
0
    def bundle_centroids(self,
                         streamlines=None,
                         cluster_thre=10,
                         dist_thre=10.0,
                         pts=12):
        """
        Compute QuickBundles centroids of a set of streamlines.

        Parameters
        ----------
        streamlines: streamline data; when None, the data returned by
            ``self._fasciculus.get_data()`` is used
        cluster_thre: value passed to ``remove_small_clusters``
        dist_thre: clustering threshold (distance mm)
        pts: number of points passed to QuickBundles

        Return
        ------
        centroids: centroids of the remaining clusters, wrapped in an
            ``ArraySequence``
        """
        if streamlines is None:
            streamlines = self._fasciculus.get_data()

        qb = QuickBundles(streamlines, dist_thre, pts)
        qb.remove_small_clusters(cluster_thre)

        return nibas.ArraySequence(qb.centroids)
Beispiel #4
0
    def bundle_seg(self, streamlines=None, dist_thre=10.0, pts=12):
        """
        Segment streamlines into QuickBundles clusters.

        Parameters
        ----------
        streamlines: streamline data; when None, the data returned by
            ``self._fasciculus.get_data()`` is used
        dist_thre: clustering threshold (distance mm)
        pts: number of points passed to QuickBundles

        Return
        ------
        labels: object array with one entry per streamline holding the
            1-based cluster number (entries never assigned stay None)
        data_clusters: list with the streamlines of each cluster
        N_list: the 'N' value of each cluster
        """
        if streamlines is None:
            streamlines = self._fasciculus.get_data()

        clusters = QuickBundles(streamlines, dist_thre, pts).clusters()
        n_clusters = len(clusters)

        labels = np.array([None] * len(streamlines))
        # Clusters are addressed by the integer keys 0 .. n_clusters - 1.
        N_list = [clusters[k]['N'] for k in range(n_clusters)]

        data_clusters = []
        for k in range(n_clusters):
            members = clusters[k]['indices']
            labels[members] = k + 1
            data_clusters.append(streamlines[members])

        return labels, data_clusters, N_list
Beispiel #5
0
    def bundle_thre_seg(self,
                        streamlines=None,
                        cluster_thre=10,
                        dist_thre=10.0,
                        pts=12):
        """
        Segment streamlines with QuickBundles after pruning clusters.

        Parameters
        ----------
        streamlines: streamline data; when None, the data returned by
            ``self._fasciculus.get_data()`` is used
        cluster_thre: value passed to ``remove_small_clusters``
        dist_thre: clustering threshold (distance mm)
        pts: number of points passed to QuickBundles

        Return
        ------
        sort_index: argsort of the mean of column 1 (y) of each centroid
        data_clusters: list with the streamlines of each remaining cluster
        """
        if streamlines is None:
            streamlines = self._fasciculus.get_data()

        qb = QuickBundles(streamlines, dist_thre, pts)
        qb.remove_small_clusters(cluster_thre)

        clusters = qb.clusters()
        data_clusters = [
            streamlines[clusters[key]['indices']] for key in clusters.keys()
        ]

        y_means = [centroid[:, 1].mean() for centroid in qb.centroids]

        return np.argsort(y_means), data_clusters
Beispiel #6
0
    def freeze(self):
        print(
            "Freezing current expanded real tracks, then doing QB on them, then restarting."
        )
        print("Selected virtuals: %s" % self.selected)
        tracks_frozen = []
        tracks_frozen_ids = []
        for tid in self.selected:
            print tid
            part_tracks = self.qb.label2tracks(self.tracks, tid)
            part_tracks_ids = self.qb.label2tracksids(tid)
            print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
            tracks_frozen += part_tracks
            tracks_frozen_ids += part_tracks_ids
        print "frozen tracks size:", len(tracks_frozen)
        print "Computing quick bundles...",
        self.unselect_track('all')
        self.tracks = tracks_frozen
        self.tracks_ids = self.tracks_ids[
            tracks_frozen_ids]  # range(len(self.tracks))

        root = Tkinter.Tk()
        root.wm_title('QuickBundles threshold')
        ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
        root.wait_window()

        #print "Threshold value ",ts.value
        #self.qb = QuickBundles(self.tracks, dist_thr=qb.dist_thr/2., pts=self.qb.pts)
        self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
        #self.qb.dist_thr = qb.dist_thr/2.
        self.qb.dist_thr = ts.value
        if self.reps == 'virtuals':
            self.virtuals = qb.virtuals()
        if self.reps == 'exemplars':
            self.virtuals, self.ex_ids = self.qb.exemplars()
        print len(self.virtuals), 'virtuals'
        self.virtuals_buffer, self.virtuals_colors, self.virtuals_first, self.virtuals_count = self.compute_buffers(
            self.virtuals, self.virtuals_alpha)
        #compute buffers
        self.tracks_buffer, self.tracks_colors, self.tracks_first, self.tracks_count = self.compute_buffers(
            self.tracks, self.tracks_alpha)
        # self.unselect_track('all')
        self.selected = []
        self.old_color = {}
        self.expand = False
        self.history.append([
            self.qb, self.tracks, self.tracks_ids, self.virtuals_buffer,
            self.virtuals_colors, self.virtuals_first, self.virtuals_count,
            self.tracks_buffer, self.tracks_colors, self.tracks_first,
            self.tracks_count
        ])
        if self.vol_shape is not None:
            print("Shifting!")
            self.virtuals_shifted = [
                downsample(t + np.array(self.vol_shape) / 2., 30)
                for t in self.virtuals
            ]
        else:
            self.virtuals_shifted = None
Beispiel #7
0
def test_qbundles():
    """Cluster the 'fornix' dataset and check that four clusters are found.

    The unused object-array copy of the tracks was removed: it relied on the
    `np.object` alias, which recent NumPy releases no longer provide, and its
    result was never read.
    """
    streams, hdr = nib.trackvis.read(get_data('fornix'))
    # Each record's first item is the streamline itself.
    T = [s[0] for s in streams]
    qb = QuickBundles(T, 10., 12)
    # Both representative getters are exercised; their results are not checked.
    qb.virtuals()
    qb.exemplars()
    assert_equal(4, qb.total_clusters)
Beispiel #8
0
def bench_quickbundles():
    """Time the old and new QuickBundles and check that they agree.

    Eight shifted copies of the 'fornix' dataset are clustered with the old
    implementation, the new implementation, and the new implementation using
    the `MDFpy` metric. Cluster counts, sizes and member indices of the new
    runs are compared against the old run.

    NOTE(review): `measure` receives code as strings; the names used in those
    strings (`QB_Old`, `streamlines`, `threshold`, `nb_points`, `qb2`, `qb`)
    presumably have to resolve in this function's scope, so the local names
    are kept as they are.
    """
    dtype = "float32"
    repeat = 10
    nb_points = 18

    streams, hdr = nib.trackvis.read(get_data('fornix'))
    fornix = [s[0].astype(dtype) for s in streams]
    fornix = streamline_utils.set_number_of_points(fornix, nb_points)

    # Create eight copies of the fornix to be clustered (one in each octant).
    streamlines = []
    streamlines += [s + np.array([100, 100, 100], dtype) for s in fornix]
    streamlines += [s + np.array([100, -100, 100], dtype) for s in fornix]
    streamlines += [s + np.array([100, 100, -100], dtype) for s in fornix]
    streamlines += [s + np.array([100, -100, -100], dtype) for s in fornix]
    streamlines += [s + np.array([-100, 100, 100], dtype) for s in fornix]
    streamlines += [s + np.array([-100, -100, 100], dtype) for s in fornix]
    streamlines += [s + np.array([-100, 100, -100], dtype) for s in fornix]
    streamlines += [s + np.array([-100, -100, -100], dtype) for s in fornix]

    # The expected number of clusters of the fornix using threshold=10 is 4.
    threshold = 10.
    expected_nb_clusters = 4 * 8

    print("Timing QuickBundles 1.0 vs. 2.0")

    qb = QB_Old(streamlines, threshold, pts=None)
    qb1_time = measure("QB_Old(streamlines, threshold, nb_points)", repeat)
    print("QuickBundles time: {0:.4}sec".format(qb1_time))
    assert_equal(qb.total_clusters, expected_nb_clusters)
    sizes1 = [qb.partitions()[i]['N'] for i in range(qb.total_clusters)]
    indices1 = [
        qb.partitions()[i]['indices'] for i in range(qb.total_clusters)
    ]

    qb2 = QB_New(threshold)
    qb2_time = measure("clusters = qb2.cluster(streamlines)", repeat)
    print("QuickBundles2 time: {0:.4}sec".format(qb2_time))
    print("Speed up of {0}x".format(qb1_time / qb2_time))
    clusters = qb2.cluster(streamlines)
    # Real lists rather than `map(...)`: on Python 3 `map` is a one-shot
    # iterator, which `assert_array_equal` cannot compare element-wise.
    sizes2 = [len(c) for c in clusters]
    indices2 = [c.indices for c in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes2, sizes1)
    assert_arrays_equal(indices2, indices1)

    qb = QB_New(threshold, metric=MDFpy())
    qb3_time = measure("clusters = qb.cluster(streamlines)", repeat)
    print("QuickBundles2_python time: {0:.4}sec".format(qb3_time))
    print("Speed up of {0}x".format(qb1_time / qb3_time))
    clusters = qb.cluster(streamlines)
    sizes3 = [len(c) for c in clusters]
    indices3 = [c.indices for c in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes3, sizes1)
    assert_arrays_equal(indices3, indices1)
def bench_quickbundles():
    """Time the old and new QuickBundles and check that they agree.

    Eight shifted copies of the 'fornix' dataset are clustered with the old
    implementation, the new implementation, and the new implementation using
    the `MDFpy` metric. Cluster counts, sizes and member indices of the new
    runs are compared against the old run.

    NOTE(review): `measure` receives code as strings; the names used in those
    strings (`QB_Old`, `streamlines`, `threshold`, `nb_points`, `qb2`, `qb`)
    presumably have to resolve in this function's scope, so those local names
    must not be renamed.
    """
    dtype = "float32"
    repeat = 10
    nb_points = 12

    streams, hdr = nib.trackvis.read(get_fnames('fornix'))
    fornix = streamline_utils.set_number_of_points(
        [s[0].astype(dtype) for s in streams], nb_points)

    # One shifted copy of the fornix per octant. The loops are nested so that
    # y flips fastest, then z, then x.
    octant_offsets = [[x, y, z]
                      for x in (100, -100)
                      for z in (100, -100)
                      for y in (100, -100)]
    streamlines = []
    for offset in octant_offsets:
        streamlines.extend(s + np.array(offset, dtype) for s in fornix)

    # Clustering the fornix with threshold=10 is expected to give 4 clusters,
    # hence 4 per copy.
    threshold = 10.
    expected_nb_clusters = 4 * 8

    print("Timing QuickBundles 1.0 vs. 2.0")

    # --- old implementation: the reference result ---
    qb = QB_Old(streamlines, threshold, pts=None)
    qb1_time = measure("QB_Old(streamlines, threshold, nb_points)", repeat)
    print("QuickBundles time: {0:.4}sec".format(qb1_time))
    assert_equal(qb.total_clusters, expected_nb_clusters)
    sizes1 = [qb.partitions()[k]['N'] for k in range(qb.total_clusters)]
    indices1 = [qb.partitions()[k]['indices']
                for k in range(qb.total_clusters)]

    # --- new implementation, default metric ---
    qb2 = QB_New(threshold)
    qb2_time = measure("clusters = qb2.cluster(streamlines)", repeat)
    print("QuickBundles2 time: {0:.4}sec".format(qb2_time))
    print("Speed up of {0}x".format(qb1_time / qb2_time))
    clusters = qb2.cluster(streamlines)
    sizes2 = [len(cluster) for cluster in clusters]
    indices2 = [cluster.indices for cluster in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes2, sizes1)
    assert_arrays_equal(indices2, indices1)

    # --- new implementation, MDFpy metric ---
    qb = QB_New(threshold, metric=MDFpy())
    qb3_time = measure("clusters = qb.cluster(streamlines)", repeat)
    print("QuickBundles2_python time: {0:.4}sec".format(qb3_time))
    print("Speed up of {0}x".format(qb1_time / qb3_time))
    clusters = qb.cluster(streamlines)
    sizes3 = [len(cluster) for cluster in clusters]
    indices3 = [cluster.indices for cluster in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes3, sizes1)
    assert_arrays_equal(indices3, indices1)
Beispiel #10
0
def load_tracks(method="pmt"):
    from nibabel import trackvis as tv

    dname = "/home/eg309/Data/orbital_phantoms/dwi_dir/subject1/"

    if method == "pmt":
        fname = "/home/eg309/Data/orbital_phantoms/dwi_dir/workflow/tractography/_subject_id_subject1/cam2trk_pico_twoten/data_fit_pdfs_tracked.trk"
        streams, hdr = tv.read(fname, points_space="voxel")
        tracks = [s[0] for s in streams]
    if method == "dti":
        fname = dname + "dti_tracks.dpy"
    if method == "dsi":
        fname = dname + "dsi_tracks.dpy"
    if method == "gqs":
        fname = dname + "gqi_tracks.dpy"
    if method == "eit":
        fname = dname + "eit_tracks.dpy"
    if method in ["dti", "dsi", "gqs", "eit"]:
        dpr_linear = Dpy(fname, "r")
        tracks = dpr_linear.read_tracks()
        dpr_linear.close()

    if method != "pmt":
        tracks = [t - np.array([96 / 2.0, 96 / 2.0, 55 / 2.0]) for t in tracks if track_range(t, 100 / 2.5, 150 / 2.5)]
    tracks = [t for t in tracks if track_range(t, 100 / 2.5, 150 / 2.5)]

    print "final no of tracks ", len(tracks)
    qb = QuickBundles(tracks, 25.0 / 2.5, 18)
    # from dipy.viz import fvtk
    # r=fvtk.ren()
    # fvtk.add(r,fvtk.line(qb.virtuals(),fvtk.red))
    # fvtk.show(r)
    # show_tracks(tracks)#qb.exemplars()[0])
    # qb.remove_small_clusters(40)
    del tracks
    # load
    tl = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=None, tracks_line_width=3.0, tracks_alpha=1)

    # return tracks
    w = World()
    w.add(tl)
    # create window
    wi = Window(caption="Fos", bgcolor=(1.0, 1.0, 1.0, 1.0), width=1600, height=900)
    wi.attach(w)
    # create window manager
    wm = WindowManager()
    wm.add(wi)
    wm.run()
Beispiel #11
0
def half_split_comparisons():

    res = {}

    for id in range(len(tractography_sizes)):
        res[id] = {}
        first, second = split_halves(id)
        res[id]["lengths"] = [len(first), len(second)]
        print len(first), len(second)
        first_qb = QuickBundles(first, qb_threshold, downsampling)
        n_clus = first_qb.total_clusters
        res[id]["nclusters"] = n_clus
        print "QB for first half has", n_clus, "clusters"
        second_down = [downsample(s, downsampling) for s in second]
        matched_random = get_random_streamlines(first_qb.downsampled_tracks(), n_clus)
        neighbours_first = count_close_tracks(first_qb.virtuals(), first_qb.downsampled_tracks(), adjacency_threshold)
        neighbours_second = count_close_tracks(first_qb.virtuals(), second_down, adjacency_threshold)
        neighbours_random = count_close_tracks(matched_random, second_down, adjacency_threshold)

        maxclose = np.int(np.max(np.hstack((neighbours_first, neighbours_second, neighbours_random))))

        # The numbers of tracks 0, 1, 2, ... 'close' subset tracks
        counts = np.array(
            [
                (
                    np.int(n),
                    len(find(neighbours_first == n)),
                    len(find(neighbours_second == n)),
                    len(find(neighbours_random == n)),
                )
                for n in range(maxclose + 1)
            ],
            dtype="f",
        )
        totals = np.sum(counts[:, 1:], axis=0)
        res[id]["totals"] = totals
        res[id]["counts"] = counts
        # print totals
        # print counts
        missed_fractions = counts[0, 1:] / totals
        res[id]["missed_fractions"] = missed_fractions
        means = np.sum(counts[:, 1:] * counts[:, [0, 0, 0]], axis=0) / totals
        # print means
        res[id]["means"] = means
        # print res
    return res
Beispiel #12
0
 def freeze(self):
     print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
     print("Selected virtuals: %s" % self.selected)
     tracks_frozen = []
     tracks_frozen_ids = []
     for tid in self.selected:
         print tid
         part_tracks = self.qb.label2tracks(self.tracks, tid)
         part_tracks_ids = self.qb.label2tracksids(tid)
         print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
         tracks_frozen += part_tracks
         tracks_frozen_ids += part_tracks_ids
     print "frozen tracks size:", len(tracks_frozen)
     print "Computing quick bundles...",
     self.unselect_track('all')
     self.tracks = tracks_frozen
     self.tracks_ids = self.tracks_ids[tracks_frozen_ids] # range(len(self.tracks))
     
     root = Tkinter.Tk()
     root.wm_title('QuickBundles threshold')
     ts = ThresholdSelector(root, default_value=self.qb.dist_thr/2.0)
     root.wait_window()
     
     #print "Threshold value ",ts.value
     #self.qb = QuickBundles(self.tracks, dist_thr=qb.dist_thr/2., pts=self.qb.pts)
     self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
     #self.qb.dist_thr = qb.dist_thr/2.
     self.qb.dist_thr = ts.value
     if self.reps=='virtuals':
         self.virtuals=qb.virtuals()
     if self.reps=='exemplars':
         self.virtuals,self.ex_ids = self.qb.exemplars()
     print len(self.virtuals), 'virtuals'
     self.virtuals_buffer, self.virtuals_colors, self.virtuals_first, self.virtuals_count = self.compute_buffers(self.virtuals, self.virtuals_alpha)
     #compute buffers
     self.tracks_buffer, self.tracks_colors, self.tracks_first, self.tracks_count = self.compute_buffers(self.tracks, self.tracks_alpha)
     # self.unselect_track('all')
     self.selected = []
     self.old_color = {}
     self.expand = False
     self.history.append([self.qb, 
                         self.tracks, 
                         self.tracks_ids, 
                         self.virtuals_buffer, 
                         self.virtuals_colors, 
                         self.virtuals_first, 
                         self.virtuals_count, 
                         self.tracks_buffer, 
                         self.tracks_colors, 
                         self.tracks_first, 
                         self.tracks_count])
     if self.vol_shape is not None:
         print("Shifting!")
         self.virtuals_shifted = [downsample(t + np.array(self.vol_shape) / 2., 30) for t in self.virtuals]
     else:
         self.virtuals_shifted = None
Beispiel #13
0
def load_PX_tracks():
    """Load the LH_premotor track samples, cluster them and display them.

    The TrackVis file is read in voxel space and clustered with QuickBundles,
    ``remove_small_clusters(20)`` is applied, and the result is wrapped in a
    `TrackLabeler` shown in a window captioned "Fos".
    """
    from nibabel import trackvis as tv

    roi = "LH_premotor"
    dn = "/home/hadron/from_John_mon12thmarch"
    dname = "/extra_probtrackX_analyses/_subject_id_subj05_101_32/particle2trackvis_" + roi + "_native/"
    fname = "".join([dn, dname, "tract_samples.trk"])

    streamlines, hdr = tv.read(fname, as_generator=True, points_space="voxel")
    # Each record's first item is the streamline itself.
    tracks = [record[0] for record in streamlines]
    del streamlines

    qb = QuickBundles(tracks, 25.0 / 2.5, 18)
    # The labeler works on qb's downsampled copy, so the originals can go.
    del tracks
    qb.remove_small_clusters(20)

    labeler = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=None, tracks_line_width=3.0, tracks_alpha=1)

    # Scene -> window -> window manager.
    world = World()
    world.add(labeler)

    window = Window(caption="Fos", bgcolor=(0.3, 0.3, 0.6, 1.0), width=1600, height=900)
    window.attach(world)

    manager = WindowManager()
    manager.add(window)
    manager.run()
def visualization(streamlines_file):
    """Cluster streamlines with QuickBundles and plot the cluster centroids.

    Parameters
    ----------
    streamlines_file : str
        Path passed to ``np.load``; the streamlines are read from the
        ``'arr_0'`` entry of the loaded archive.
    """
    # Read the array and close the archive right away instead of leaving the
    # underlying file handle open.
    with np.load(streamlines_file) as archive:
        streamlines = archive['arr_0']

    qb = QuickBundles(streamlines, dist_thr=10., pts=18)
    centroids = qb.centroids
    # The builtin `float` is what the removed `np.float` alias referred to.
    colors = line_colors(centroids).astype(float)

    mlab.figure(bgcolor=(0., 0., 0.))
    for streamline, color in zip(centroids, colors):
        mlab.plot3d(streamline.T[0], streamline.T[1], streamline.T[2],
                    line_width=1., tube_radius=.5, color=tuple(color))
Beispiel #15
0
def QB_reps(limits=[0, np.Inf], reps=1):

    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]

    sizes = []

    for id in range(len(test_tractography_sizes)):
        filename = "subj_" + ids[id] + "_QB_reps.pkl"
        f = open(filename, "w")
        ur_tracks = get_tracks(id, limits)
        res = {}
        # res['filtered'] = len(ur_tracks)
        res["qb_threshold"] = qb_threshold
        res["limits"] = limits
        # res['ur_tracks'] = ur_tracks
        print "Dataset", id, res["filtered"], "filtered tracks"
        res["shuffle"] = {}
        res["clusters"] = {}
        res["nclusters"] = {}
        res["centroids"] = {}
        res["cluster_sizes"] = {}
        for i in range(reps):
            print "Subject", ids[id], "shuffle", i
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"][i] = shuffle
            tracks = [ur_tracks[i] for i in shuffle]
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            res["clusters"][i] = {}
            for k in qb.clusters().keys():
                # this would be improved if
                # we 'enumerated' QB's keys and used the enumerator as
                # as the key in the result
                res["clusters"][i][k] = qb.clusters()[k]["indices"]
            res["centroids"][i] = qb.centroids
            res["nclusters"][i] = qb.total_clusters
            res["cluster_sizes"][i] = qb.clusters_sizes()
            print "QB for has", qb.total_clusters, "clusters"
            sizes.append(qb.total_clusters)
        pickle.dump(res, f)
        f.close()
        print "Results written to", filename
Beispiel #16
0
def QB_reps_singly(limits=[0, np.Inf], reps=1):

    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    replabs = [str(i) for i in range(reps)]

    for id in range(len(test_tractography_sizes)):
        ur_tracks = get_tracks(id, limits)
        for i in range(reps):
            res = {}
            # res['filtered'] = len(ur_tracks)
            res["qb_threshold"] = qb_threshold
            res["limits"] = limits
            res["shuffle"] = {}
            res["clusters"] = {}
            res["nclusters"] = {}
            res["centroids"] = {}
            res["cluster_sizes"] = {}
            print "Subject", ids[id], "shuffle", i
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"] = shuffle
            tracks = [ur_tracks[j] for j in shuffle]
            print "... starting QB"
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            print "... finished QB"
            res["clusters"] = {}
            for k in qb.clusters().keys():
                # this would be improved if
                # we 'enumerated' QB's keys and used the enumerator as
                # as the key in the result
                res["clusters"][k] = qb.clusters()[k]["indices"]
            res["centroids"] = qb.centroids
            res["nclusters"] = qb.total_clusters
            res["cluster_sizes"] = qb.clusters_sizes()
            print "QB for has", qb.total_clusters, "clusters"
            filename = "subj_" + ids[id] + "_QB_rep_" + replabs[i] + ".pkl"
            f = open(filename, "w")
            pickle.dump(res, f)
            f.close()
        print "Results written to", filename
    return sizes
Beispiel #17
0
def test_qbundles():
    """Cluster the 'fornix' dataset and check that four clusters are found."""
    fornix_path = get_fnames('fornix')
    streams, _hdr = nib.trackvis.read(fornix_path)
    # Each record's first item is the streamline itself.
    tracks = [record[0] for record in streams]

    qb = QuickBundles(tracks, 10., 12)
    # Both representative getters are exercised; their results are not checked.
    qb.virtuals()
    qb.exemplars()

    assert_equal(4, qb.total_clusters)
Beispiel #18
0
def qb_wrapper(data, qb_threshold, streamlines_ids=None, qb_n_points=None):
    """A wrapper for qb with the correct API for the Labeler.

    Note: qb_n_points = None means 'do not dowsample again data'.
    """
    print "qb_wrapper: starting"
    if streamlines_ids is None:
        streamlines_ids = np.arange(len(data), dtype=np.int)
    else:
        streamlines_ids = np.sort(list(streamlines_ids)).astype(np.int)

    print "streamlines_ids:", len(streamlines_ids)
    print "data:", data.shape
    print "Calling QuickBundles."
    qb = QuickBundles(data, qb_threshold, qb_n_points)
    tmpa, qb_internal_id = qb.exemplars() # this function call is necessary to let qb compute qb.exempsi
    clusters = {}
    print "Creating new clusters dictionary"
    for i, clusterid in enumerate(qb.clustering.keys()):
        indices = streamlines_ids[qb.clustering[clusterid]['indices']]
        tmp = indices[qb_internal_id[i]]
        clusters[tmp] = set(list(indices))
        print tmp, '->', len(clusters[tmp])
    return clusters
Beispiel #19
0
def bundle_tracks(in_file, dist_thr=40., pts=16, skip=80.):
    """Split a TrackVis file into one file per QuickBundles cluster.

    Every cluster is written to ``quickbundle_<cluster>.trk`` using the
    header of the input file, the per-cluster files are merged with the
    external ``track_merge`` command, and a scene file is produced with
    `write_trackvis_scene`.

    Parameters
    ----------
    in_file : str
        TrackVis file to read.
    dist_thr : float
        Threshold passed to QuickBundles (converted with ``float``).
    pts : int
        Number of points passed to QuickBundles (converted with ``int``).
    skip : float
        Forwarded to `write_trackvis_scene`.

    Returns
    -------
    tuple
        ``(out_files, out_merged_file, out_scene_file)``: the absolute paths
        of the per-cluster files, the merged file name, and the value
        returned by `write_trackvis_scene`.
    """
    import subprocess
    import os.path as op
    from nibabel import trackvis as tv
    from dipy.segment.quickbundles import QuickBundles

    streams, hdr = tv.read(in_file)
    # Each record's first item is the streamline itself.
    streamlines = [record[0] for record in streams]
    clusters = QuickBundles(streamlines, float(dist_thr), int(pts)).clustering

    print("%d clusters found" % len(clusters.keys()))

    prefix = "quickbundle_"
    out_files = []
    for label in clusters:
        trk_path = op.abspath(prefix + str(label) + ".trk")
        print("Writing cluster %d to %s" % (label, trk_path))
        out_files.append(trk_path)
        # Records are (streamline, None, None) triples.
        records = [(streamlines[idx], None, None)
                   for idx in clusters[label]['indices']]
        tv.write(trk_path, records, hdr)

    out_merged_file = "MergedBundles.trk"
    merge_command = ["track_merge"] + out_files + [out_merged_file]
    subprocess.call(merge_command)

    out_scene_file = write_trackvis_scene(out_merged_file,
                                          n_clusters=len(clusters),
                                          skip=skip,
                                          names=None,
                                          out_file="NewScene.scene")
    print("Merged track file written to %s" % out_merged_file)
    print("Scene file written to %s" % out_scene_file)
    return out_files, out_merged_file, out_scene_file
class StreamlineLabeler(Actor):
    """Actor used to explore and select subsets of the tracks.

    The exploration occurs through QuickBundles (``qb``) in order to
    simplify the scene: every cluster is drawn through one representative
    track (called a "virtual" in this class) and the real tracks of the
    selected clusters can be expanded on demand.
    """

    def __init__(self, name, qb, tracks, reps='exemplars',
                 colors=None, vol_shape=None,
                 virtuals_line_width=5.0, tracks_line_width=2.0,
                 virtuals_alpha=1.0, tracks_alpha=0.6,
                 affine=None, verbose=False):
        """Build the GL buffers for the virtuals and for the full tractography.

        Parameters
        ----------
        name : str
            Name forwarded to the ``Actor`` base class.
        qb : QuickBundles
            Clustering of ``tracks``.
        tracks : list
            The tracks that were clustered by ``qb``.
        reps : {'exemplars', 'virtuals'}
            Which QuickBundles representatives are displayed.
        colors : sequence or None
            Optional per-track colors; when None colors come from
            ``track2rgb``.
        vol_shape : tuple of 3 ints or None
            Shape of the associated volume. When given, the translation of
            ``affine`` is shifted by the volume centre.
        virtuals_line_width, tracks_line_width : float
            GL line widths.
        virtuals_alpha, tracks_alpha : float
            Alpha channel of the two sets of lines.
        affine : (4, 4) array or None
            Transformation applied when drawing. Defaults to the identity.
            NOTE: when ``vol_shape`` is given the array passed by the caller
            is modified in place (kept for backward compatibility).
        verbose : bool
            Print extra information while interacting.
        """
        super(StreamlineLabeler, self).__init__(name)

        # The original code kept using the ``None`` argument after storing an
        # identity matrix on ``self.affine``; use one matrix everywhere.
        if affine is None:
            affine = np.eye(4, dtype=np.float32)
        self.affine = affine
        if vol_shape is not None:
            I, J, K = vol_shape
            centershift = img_to_ras_coords(np.array([[I/2., J/2., K/2.]]), affine)
            centeraffine = from_matvec(np.eye(3), centershift.squeeze())
            affine[:3, 3] = affine[:3, 3] - centeraffine[:3, 3]
        self.glaffine = (GLfloat * 16)(*tuple(affine.T.ravel()))
        self.glaff = affine
        self.mouse_x = None
        self.mouse_y = None
        self.cache = {}
        self.qb = qb
        self.reps = reps
        # virtual tracks
        self._refresh_virtuals()
        self.virtuals_alpha = virtuals_alpha
        (self.virtuals_buffer, self.virtuals_colors,
         self.virtuals_first, self.virtuals_count) = self.compute_buffers(
            self.virtuals, colors, self.virtuals_alpha)
        # full tractography
        self.tracks = tracks
        self.tracks_alpha = tracks_alpha
        # ``np.int`` was only an alias of the builtin and no longer exists in
        # recent NumPy releases.
        self.tracks_ids = np.arange(len(self.tracks), dtype=int)
        (self.tracks_buffer, self.tracks_colors,
         self.tracks_first, self.tracks_count) = self.compute_buffers(
            self.tracks, colors, self.tracks_alpha)
        # calculate boundary box for entire tractography
        self.min = np.min(self.tracks_buffer, axis=0)
        self.max = np.max(self.tracks_buffer, axis=0)
        self.vertices = self.tracks_buffer
        # show size of tractography buffer
        print('MBytes %f' % (self.tracks_buffer.nbytes/2.**20,))
        self.position = (0, 0, 0)
        # buffer for selected virtual tracks
        self.selected = []
        self.virtuals_line_width = virtuals_line_width
        self.tracks_line_width = tracks_line_width
        self.old_color = {}
        self.hide_virtuals = False
        self.expand = False
        self.verbose = verbose
        self.tracks_visualized_first = np.array([], dtype='i4')
        self.tracks_visualized_count = np.array([], dtype='i4')
        self.history = [self._snapshot()]
        self.vol_shape = vol_shape
        self._refresh_virtuals_shifted()

    def _refresh_virtuals(self):
        """Read the representative tracks of ``self.qb`` according to ``self.reps``."""
        if self.reps == 'virtuals':
            self.virtuals = self.qb.virtuals()
        if self.reps == 'exemplars':
            self.virtuals, self.ex_ids = self.qb.exemplars()

    def _refresh_virtuals_shifted(self):
        """Recompute ``self.virtuals_shifted`` from the current virtuals.

        Shifting of the tracks is necessary for
        dipy.tracking.vox2track.track_counts; the virtuals are also resampled
        to 30 points in order to increase the accuracy of the track counts.
        Without a volume shape there is nothing to shift and the attribute is
        set to None.
        """
        if self.vol_shape is not None:
            half_shape = np.array(self.vol_shape) / 2.
            self.virtuals_shifted = [downsample(t + half_shape, 30)
                                     for t in self.virtuals]
        else:
            self.virtuals_shifted = None

    def _snapshot(self):
        """Return the state stored in (and restored from) ``self.history``."""
        return [self.qb, self.tracks, self.tracks_ids,
                self.virtuals_buffer, self.virtuals_colors,
                self.virtuals_first, self.virtuals_count,
                self.tracks_buffer, self.tracks_colors,
                self.tracks_first, self.tracks_count]

    def compute_buffers(self, tracks, colors, alpha):
        """Compute buffers for GL compilation.

        Returns
        -------
        tracks_buffer : float32 array
            All the points of ``tracks`` concatenated.
        tracks_colors : float32 array
            One RGBA row per point (see ``compute_colors``).
        tracks_first : int32 array
            Index of the first point of every track in ``tracks_buffer``.
        tracks_count : int32 array
            Number of points of every track.
        """
        tracks_buffer = np.ascontiguousarray(np.concatenate(tracks).astype('f4'))
        tracks_colors = np.ascontiguousarray(self.compute_colors(tracks, colors, alpha))
        tracks_count = np.ascontiguousarray(np.array([len(v) for v in tracks], dtype='i4'))
        tracks_first = np.ascontiguousarray(np.r_[0, np.cumsum(tracks_count)[:-1]].astype('i4'))
        return tracks_buffer, tracks_colors, tracks_first, tracks_count

    def compute_colors(self, tracks, colors, alpha):
        """Compute colors for a list of tracks.

        Every point of track ``j`` gets ``colors[j]`` as RGB, or the color
        returned by ``track2rgb`` when ``colors`` is None; the alpha channel
        of every point is ``alpha``.
        """
        assert isinstance(tracks, list)
        tot_vertices = sum(len(curve) for curve in tracks)
        color = np.empty((tot_vertices, 4), dtype='f4')
        counter = 0
        for j, curve in enumerate(tracks):
            # ``is None`` on purpose: ``colors == None`` is evaluated
            # element-wise when an array of colors is given.
            if colors is None:
                color[counter:counter+len(curve), :3] = track2rgb(curve).astype('f4')
            else:
                color[counter:counter+len(curve), :3] = colors[j].astype('f4')
            counter += len(curve)
        color[:, 3] = alpha
        return color

    def draw(self):
        """Draw virtual and real tracks.

        This is done at every frame and therefore must be real fast.
        """
        glDisable(GL_LIGHTING)
        # virtuals
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glEnableClientState(GL_VERTEX_ARRAY)
        glEnableClientState(GL_COLOR_ARRAY)
        if not self.hide_virtuals:
            glVertexPointer(3, GL_FLOAT, 0, self.virtuals_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.virtuals_colors.ctypes.data)
            glLineWidth(self.virtuals_line_width)
            glPushMatrix()
            glMultMatrixf(self.glaffine)
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.virtuals_first.ctypes.data,
                                   self.virtuals_count.ctypes.data,
                                   len(self.virtuals))
            glPopMatrix()
        # reals:
        if self.expand and self.tracks_visualized_first.size > 0:
            glVertexPointer(3, GL_FLOAT, 0, self.tracks_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.tracks_colors.ctypes.data)
            glLineWidth(self.tracks_line_width)
            glPushMatrix()
            glMultMatrixf(self.glaffine)
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.tracks_visualized_first.ctypes.data,
                                   self.tracks_visualized_count.ctypes.data,
                                   len(self.tracks_visualized_count))
            glPopMatrix()
        glDisableClientState(GL_COLOR_ARRAY)
        glDisableClientState(GL_VERTEX_ARRAY)
        glLineWidth(1.)
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def process_mouse_position(self, x, y):
        """Remember the last mouse position (used for picking)."""
        self.mouse_x = x
        self.mouse_y = y

    def process_pickray(self, near, far):
        """Nothing to do for this actor."""
        pass

    def update(self, dt):
        """Nothing to do for this actor."""
        pass

    def select_track(self, ids):
        """Do visual selection of given virtuals.

        ``ids`` is the string 'all', a single id or an iterable of ids.
        The original color of every newly selected virtual is stored in
        ``self.old_color`` so that it can be restored when unselecting.
        """
        # WARNING: we assume that no tracks can ever have color_selected as original color
        color_selected = np.array([1.0, 1.0, 1.0, 1.0], dtype='f4')
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid in self.old_color:
                continue
            first = self.virtuals_first[vid]
            last = first + self.virtuals_count[vid]
            self.old_color[vid] = self.virtuals_colors[first:last, :].copy()
            new_color = np.ones(self.old_color[vid].shape, dtype='f4') * color_selected
            if self.verbose:
                print("Storing old color: %s" % self.old_color[vid][0])
            self.virtuals_colors[first:last, :] = new_color
            self.selected.append(vid)

    def unselect_track(self, ids):
        """Do visual un-selection of given virtuals.

        ``ids`` is the string 'all', a single id or an iterable of ids.
        """
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid not in self.old_color:
                continue
            first = self.virtuals_first[vid]
            last = first + self.virtuals_count[vid]
            self.virtuals_colors[first:last, :] = self.old_color[vid]
            if self.verbose:
                print("Setting old color: %s" % self.old_color[vid][0])
            self.old_color.pop(vid)
            if vid in self.selected:
                self.selected.remove(vid)
            else:
                print('WARNING: unselecting id %s but not in %s' % (vid, self.selected))

    def invert_tracks(self):
        """Invert the selection: selected virtuals become unselected and vice versa."""
        tmp_selected = list(set(range(len(self.virtuals))).difference(set(self.selected)))
        self.unselect_track('all')
        self.selected = []
        self.select_track(tmp_selected)

    def process_messages(self, messages):
        """Dispatch the 'key_pressed' and 'mouse_position' messages."""
        msg = messages['key_pressed']
        if msg is not None:
            self.process_keys(msg, None)
        msg = messages['mouse_position']
        if msg is not None:
            self.process_mouse_position(*msg)

    def process_keys(self, symbol, modifiers):
        """Bind actions to key press.

        P pick the virtual under the mouse and toggle its selection,
        E expand/collapse the real tracks of the selected virtuals,
        F freeze the selection and re-cluster it (see ``freeze``),
        A select/unselect all virtuals, I invert the selection,
        H hide/show the virtuals, S save the selected track ids as pickle,
        ? print the help message, B go back in the freezing history,
        G select the virtuals going through the slicer mask.
        """
        prev_selected = copy.copy(self.selected)
        if symbol == Qt.Key_P:
            print('P')
            picked = self.picking_virtuals(symbol, modifiers)
            print('Track id %d' % picked)
            if prev_selected.count(picked) == 0:
                self.select_track(picked)
            else:
                self.unselect_track(picked)
            if self.verbose:
                print('Selected:')
                print(self.selected)

        elif symbol == Qt.Key_E:
            print('E')
            if self.verbose:
                print("Expand/collapse selected clusters.")
            if not self.expand and len(self.selected) > 0:
                tracks_selected = []
                for tid in self.selected:
                    tracks_selected += self.qb.label2tracksids(tid)
                # tracks_first/tracks_count are 1-D (see compute_buffers), so
                # they take a single index; the former ``[sel, :]`` raised
                # IndexError.
                self.tracks_visualized_first = np.ascontiguousarray(self.tracks_first[tracks_selected])
                self.tracks_visualized_count = np.ascontiguousarray(self.tracks_count[tracks_selected])
                self.expand = True
            else:
                self.expand = False

        # Freeze and restart:
        elif symbol == Qt.Key_F and len(self.selected) > 0:
            print('F')
            self.freeze()

        elif symbol == Qt.Key_A:
            print('A')
            print('Select/unselect all virtuals')
            if len(self.selected) < len(self.virtuals):
                self.select_track('all')
            else:
                self.unselect_track('all')

        elif symbol == Qt.Key_I:
            print('I')
            print('Invert selection')
            print(self.selected)
            self.invert_tracks()

        elif symbol == Qt.Key_H:
            print('H')
            print('Hide/show virtuals.')
            self.hide_virtuals = not self.hide_virtuals

        elif symbol == Qt.Key_S:
            print('S')
            print('Save selected tracks_ids as pickle file.')
            self.tracks_ids_to_be_saved = self.tracks_ids
            if len(self.selected) > 0:
                self.tracks_ids_to_be_saved = self.tracks_ids[np.concatenate([self.qb.label2tracksids(tid) for tid in self.selected])]
            print("Saving %s tracks." % len(self.tracks_ids_to_be_saved))
            root = Tkinter.Tk()
            root.withdraw()
            # Binary mode because a binary pickle protocol is used; the
            # dialog returns None when it is cancelled.
            out_file = tkFileDialog.asksaveasfile(mode='wb')
            if out_file is None:
                print('Save cancelled.')
            else:
                try:
                    pickle.dump(self.tracks_ids_to_be_saved,
                                out_file,
                                protocol=pickle.HIGHEST_PROTOCOL)
                finally:
                    out_file.close()

        elif symbol == Qt.Key_Question:
            print(question_message)

        elif symbol == Qt.Key_B:
            print('B')
            print('Go back in the freezing history.')
            if len(self.history) > 1:
                self.history.pop()
                (self.qb, self.tracks, self.tracks_ids,
                 self.virtuals_buffer, self.virtuals_colors,
                 self.virtuals_first, self.virtuals_count,
                 self.tracks_buffer, self.tracks_colors,
                 self.tracks_first, self.tracks_count) = self.history[-1]
                self._refresh_virtuals()
                # Keep the shifted copies in sync with the restored virtuals,
                # otherwise maskout_tracks returns ids of the dropped state.
                self._refresh_virtuals_shifted()
                print('%d virtuals' % len(self.virtuals))
                self.selected = []
                self.old_color = {}
                self.expand = False
                self.hide_virtuals = False

        elif symbol == Qt.Key_G:
            print('G')
            print('Get tracks from mask.')
            ids = self.maskout_tracks()
            self.select_track(ids)

    def freeze(self):
        """Keep only the tracks of the selected virtuals and re-cluster them.

        The tracks represented by the selected virtuals become the new
        tractography, a Tk dialog asks for the new QuickBundles threshold
        (default: half of the current one), QuickBundles is run again and all
        the GL buffers are rebuilt. The new state is appended to
        ``self.history``.
        """
        print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
        print("Selected virtuals: %s" % self.selected)
        tracks_frozen = []
        tracks_frozen_ids = []
        for tid in self.selected:
            print(tid)
            part_tracks = self.qb.label2tracks(self.tracks, tid)
            part_tracks_ids = self.qb.label2tracksids(tid)
            print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
            tracks_frozen += part_tracks
            tracks_frozen_ids += part_tracks_ids
        print("frozen tracks size: %d" % len(tracks_frozen))
        self.unselect_track('all')
        self.tracks = tracks_frozen
        self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

        root = Tkinter.Tk()
        root.wm_title('QuickBundles threshold')
        ts = ThresholdSelector(root, default_value=self.qb.dist_thr/2.0)
        root.wait_window()

        print("Computing quick bundles...")
        self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
        self.qb.dist_thr = ts.value
        self._refresh_virtuals()
        print('%d virtuals' % len(self.virtuals))
        # The colors given to the constructor were indexed on the original
        # tracks, so the re-clustered tracks get the default coloring.
        (self.virtuals_buffer, self.virtuals_colors,
         self.virtuals_first, self.virtuals_count) = self.compute_buffers(
            self.virtuals, None, self.virtuals_alpha)
        # compute buffers
        (self.tracks_buffer, self.tracks_colors,
         self.tracks_first, self.tracks_count) = self.compute_buffers(
            self.tracks, None, self.tracks_alpha)
        self.selected = []
        self.old_color = {}
        self.expand = False
        self.history.append(self._snapshot())
        if self.vol_shape is not None:
            print("Shifting!")
        self._refresh_virtuals_shifted()

    def picking_virtuals(self, symbol, modifiers, min_dist=1e-3):
        """Compute the id of the closest track to the mouse pointer.

        ``symbol``, ``modifiers`` and ``min_dist`` are not used by the current
        algorithm and are kept only for interface compatibility.
        """
        x, y = self.mouse_x, self.mouse_y
        # Define two points in model space from mouse+screen(=0) position and mouse+horizon(=1) position
        near = screen_to_model(x, y, 0)
        far = screen_to_model(x, y, 1)

        # Compute distance of virtuals from screen and from the line defined by the two points above
        tmp = np.array([cll.mindistance_segment2track_info(near, far,
            apply_transformation(xyz, self.glaff)) for xyz in self.virtuals])
        line_distance, screen_distance = tmp[:, 0], tmp[:, 1]
        # The virtual minimising the sum of the two distances is picked (an
        # alternative based on thresholding the distance to the line with
        # ``min_dist`` was disabled in the original code).
        return np.argmin(line_distance + screen_distance)

    def maskout_tracks(self):
        """Retrieve ids of virtuals which go through the mask.

        The mask is read from ``self.slicer.mask`` (``slicer`` is attached to
        the instance by the caller).
        """
        mask = self.slicer.mask
        tracks = self.virtuals_shifted
        tcs, tes = track_counts(tracks, mask.shape, (1, 1, 1), True)
        # find volume indices of mask's voxels as a Nx3 array
        roiinds = np.array(np.where(mask == 1)).T
        # get tracks going through the roi
        mask_tracks_inds = []
        for voxel in roiinds:
            # voxels not crossed by any track have no entry in ``tes``
            mask_tracks_inds += tes.get(tuple(voxel), [])
        mask_tracks_inds = list(set(mask_tracks_inds))
        print("Masked tracks %d" % len(mask_tracks_inds))
        print("mask_tracks_inds: %s" % mask_tracks_inds)
        return mask_tracks_inds
Beispiel #21
0
    #    T = [downsample(t, 18) - np.array(data.shape[:3]) / 2. for t in T]
    #    axis = np.array([1, 0, 0])
    #    theta = - 90.
    #    T = np.dot(T,rotation_matrix(axis, theta))
    #    axis = np.array([0, 1, 0])
    #    theta = 180.
    #    T = np.dot(T, rotation_matrix(axis, theta))
    #
    # load initial QuickBundles with threshold 30mm
    # fpkl = dname+'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl'
    # qb=QuickBundles(T, 10., 18)
    # save_pickle(fpkl,qb)
    # qb=load_pickle(fpkl)

    qb = QuickBundles(T, 20.0, 12)
    # save_pickle(fpkl,qb)
    # qb=load_pickle(fpkl)
    Init()

    title = "[F]oS Streamline Interaction and Segmentation"
    w = Window(caption=title, width=1200, height=800, bgcolor=(0.5, 0.5, 0.9), right_panel=True)

    scene = Scene(scenename="Main Scene", activate_aabb=False)

    # create the interaction system for tracks
    tl = StreamlineLabeler(
        "Bundle Picker", qb, qb.downsampled_tracks(), vol_shape=data.shape[:3], tracks_alpha=1, affine=affine
    )

    guil = Guillotine("Volume Slicer", data, affine)
Beispiel #22
0
        T2=[T[i] for i in seg_inds]
        T=T1
    
    if reduce_length:
        T=[t for t in T if track_range(100,200)]
        #iT=np.random.randint(0,len(T),5000)
        #T=[T[i] for i in iT]
    #stop
    
    #center
    shift=(np.array(data.shape)-1)/2.    
    T=[t-shift for t in T]
    
    #load initial QuickBundles with threshold 30mm
    #fpkl = 'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl'
    qb=QuickBundles(T,25.,30)    
    print len(qb.clustering)
    #qb=load_pickle(fpkl)
    qb.remove_small_clusters(1000)
    print len(qb.clustering)

        
    #create the interaction system for tracks 
    tl = TrackLabeler(qb,qb.downsampled_tracks(),vol_shape=data.shape,tracks_line_width=3.,tracks_alpha=1.)   
    #add a interactive slicing/masking tool
    sl = Slicer(affine,data)    
    #add one way communication between tl and sl
    tl.slicer=sl
    #OpenGL coordinate system axes    
    ax = Axes(100)
    x,y,z=data.shape
Beispiel #23
0
def main():
    """Sample point data of a tract along its centroid(s) and save a table.

    Command line: ``[options] <tract.vtp> [<output>]``.

    The streamlines and their point arrays are read from a VTK XML polydata
    file. Centroids come either from a spline through user supplied control
    points (``--curvepoints``) or from QuickBundles. For every centroid each
    streamline point is assigned to the closest centroid point (one
    ``kmeans2`` iteration seeded with the centroid) and one row per point is
    produced with the columns ``value`` (the selected scalar), ``position``
    (index along the centroid), ``group`` (point array 'group'), ``pid``
    (point id), ``sid`` (point array 'tid') and ``centroid``.

    The rows of all the centroids are written to the output file, as CSV or
    Excel depending on its extension. The default output name is
    ``<input basename>_<scalar>_rawdata.csv``.
    """
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp> <output.csv>")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float', help="Quickbundle distance threshold")
    parser.add_option("-n", "--num", dest="num", default=50, type='int', help="Number of subdivisions along centroids")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA", help="Scalar to measure")
    parser.add_option("--curvepoints", dest="curvepoints_file", help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False, help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option('--reverse', dest='is_reverse', action='store_true', default=False, help='Reverse the centroid measure stepping order')
    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    # point ids of every cell, in [[ids], [ids...], ...] format
    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        pids = polydata.GetCell(i).GetPointIds()
        tract_ids.append([pids.GetId(p) for p in range(pids.GetNumberOfIds())])
    print('tracks: %d' % len(tract_ids))

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: %d' % len(verts))

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si)).astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si)).astype(int)

    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for ids in tract_ids:
        # indexing an array by a list of ids selects all those points at once
        streamlines.append(verts[ids])
        stream_scalars.append(scalars[ids])
        stream_groups.append(groups[ids])
        stream_pids.append(ids)
        stream_sids.append(subjects[ids])

    # NOTE(review): streamlines normally differ in length; recent NumPy
    # versions refuse to build an array from such ragged lists unless
    # dtype=object is requested -- left unchanged here, confirm against the
    # NumPy version in use.
    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    centroids = []
    if options.curvepoints_file:
        # A user defined curve has no cluster assignments to restrict to.
        LOCAL_POINT_ASSIGN = False
        cpoints = []
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # have a separate vtkCardinalSpline interpreter for x,y,z
        curve = [vtk.vtkCardinalSpline() for i in range(3)]
        for c in curve:
            c.ClosedOff()

        for pi, point in enumerate(ctrlpoints):
            for i, val in enumerate(point):
                curve[i].AddPoint(pi, val)

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)

        t = param_range[0]
        step = (param_range[1]-param_range[0])/(QB_NPOINTS)

        while t < param_range[1]:
            cpoints.append([c.Evaluate(t) for c in curve])
            t = t + step

        centroids.append(cpoints)
        centroids = np.array(centroids)

    else:
        # Use quickbundles to find centroids
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)

        centroids = qb.centroids
        clusters = qb.clusters()

        avg_d = np.zeros(3)

        # unify centroid list orders to point in the same general direction
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            if i == 0:
                avg_d = ori

            dotprod = ori.dot(avg_d)
            print('dotprod %s' % dotprod)
            if dotprod < 0:
                print('reverse %s' % dotprod)
                centroids[i] = line[::-1]
                ori *= -1
            avg_d += ori

        if options.is_reverse:
            for i, c in enumerate(centroids):
                centroids[i] = c[::-1]

    # One table per centroid; they are concatenated after the loop (the
    # original discarded the result of pd.concat, so only the first centroid
    # was ever written).
    frames = []
    for ci, cent in enumerate(centroids):
        print('---- centroid:')

        if LOCAL_POINT_ASSIGN:
            # apply centroid to only their point assignments
            # through quickbundles
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centroid to all the points
            # instead of only their centroid assignments
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)

        # Only the labels (index of the assigned centroid point) are needed.
        _, labels = kmeans2(cent_verts, cent, iter=1)

        d = {'value': cent_scalars, 'position': labels, 'group': cent_groups, 'pid': cent_pids, 'sid': cent_sids, 'centroid': ci}
        frames.append(pd.DataFrame(data=d))

    if not frames:
        print('No centroids found, nothing to write.')
        return
    DATADF = pd.concat(frames, ignore_index=True)

    outfilename = '_'.join([filebase, SCALAR_NAME, 'rawdata.csv'])
    if len(args) > 1:
        outfilename = args[1]

    if outfilename.endswith('.csv'):
        DATADF.to_csv(outfilename, index=False)
    elif outfilename.endswith('.xls'):
        DATADF.to_excel(outfilename, index=False)
    else:
        print('Unsupported output extension, nothing written: %s' % outfilename)
# Keep only the first item of every entry of `streams` (loaded earlier in the
# script, outside this excerpt).
T = [i[0] for i in streams]
"""
Downsample tracks to 12 points:
"""

tracks = [tm.downsample(t, 12) for t in T]
"""
Delete unnecessary data:
"""

# Neither name is used again in the statements that follow.
del streams, hdr
"""
Perform QuickBundles clustering with a 10mm threshold:
"""

# NOTE(review): pts=None presumably tells QuickBundles not to resample, since
# `tracks` were already downsampled to 12 points above -- confirm against the
# QuickBundles API.
qb = QuickBundles(tracks, dist_thr=10., pts=None)
"""
Show the initial *Fornix* dataset:
"""

# The original tracks `T` (not the downsampled ones) are drawn in white.
r = fvtk.ren()
fvtk.add(r, fvtk.line(T, fvtk.white, opacity=1, linewidth=3))
#fvtk.show(r)
# Record a single 600x600 frame with output path 'fornix_initial', then empty
# the renderer.
fvtk.record(r, n_frames=1, out_path='fornix_initial', size=(600, 600))
fvtk.clear(r)
"""
.. figure:: fornix_initial1000000.png
   :align: center

   **Initial Fornix dataset**.
"""
Beispiel #25
0
def terminus2hemi_surface_density_map(streamlines, geo_path, hemi, radius=5):
    """
    Streamline endpoints areas mapping to hemisphere surface

    The streamlines are sorted with ``_sort_streamlines`` and clustered with
    QuickBundles; for every surface vertex the function counts, over all the
    clusters, the streamline endpoints closer than ``radius``.

    Parameters
    ----------
    streamlines: streamline data
    geo_path: surface data path
        FreeSurfer geometry ('white', 'inflated' or 'pial' suffix) or a
        GIFTI file ('gii' suffix).
    hemi: 'lh' uses the first point of every streamline as endpoint,
        any other value uses the last point
    radius: distance threshold between a vertex and an endpoint
        (same units as the surface coordinates), default 5

    Return
    ------
    endpoints areas map on surface: float array with one value per vertex
        (all zeros when the clustering returns no cluster)

    Raises
    ------
    ImageFileError if the suffix of ``geo_path`` is not supported
    """
    streamlines = _sort_streamlines(streamlines)
    bundles = QuickBundles(streamlines, 10, 12)
    clusters = bundles.clusters()
    data_clusters = [streamlines[clusters[key]['indices']]
                     for key in clusters.keys()]

    suffix = os.path.split(geo_path)[1].split('.')[-1]
    if suffix in ('white', 'inflated', 'pial'):
        coords, faces = nib.freesurfer.read_geometry(geo_path)
    elif suffix == 'gii':
        gii_data = nib.load(geo_path).darrays
        coords, faces = gii_data[0].data, gii_data[1].data
    else:
        raise ImageFileError(
            'This file format-{} is not supported at present.'.format(suffix))

    # Index of the endpoint of interest inside each streamline.
    terminus_index = 0 if hemi == 'lh' else -1

    vert_value = np.zeros(len(coords), dtype=float)
    for data in data_clusters:
        stream_terminus = np.array([s[terminus_index] for s in data])
        # dist[v, s]: distance between vertex v and the endpoint of streamline
        # s; each distance matrix is computed exactly once (the first cluster
        # used to be processed twice).
        dist = cdist(coords, stream_terminus)
        vert_value += (dist < radius).sum(axis=1)

    return vert_value
Beispiel #26
0
class TrackLabeler(Actor):
    """Actor for exploring and selecting subsets of tracks.

    The scene is simplified through a QuickBundles object (``qb``): only one
    representative track per cluster (called a "virtual" throughout this
    class) is drawn, and the real tracks of the selected clusters can be
    expanded on demand.

    Notes
    -----
    ``maskout_tracks`` reads ``self.slicer.mask``; ``slicer`` is not assigned
    anywhere in this class, so it is presumably attached by the caller or the
    base class -- TODO confirm.
    """

    def __init__(self,
                 name,
                 qb,
                 tracks,
                 reps='exemplars',
                 colors=None,
                 vol_shape=None,
                 virtuals_line_width=5.0,
                 tracks_line_width=2.0,
                 virtuals_alpha=1.0,
                 tracks_alpha=0.6,
                 affine=None,
                 verbose=False):
        """TrackLabeler is meant to explore and select subsets of the
        tracks. The exploration occurs through QuickBundles (qb) in
        order to simplify the scene.

        Parameters
        ----------
        name : passed unchanged to the base class constructor
        qb : QuickBundles-like object providing ``virtuals()``,
            ``exemplars()``, ``label2tracks()``, ``label2tracksids()``,
            ``dist_thr`` and ``pts``
        tracks : list of arrays
            The full tractography (must be a ``list``, see
            ``compute_colors``).
        reps : {'exemplars', 'virtuals'}
            Which cluster representatives are used as virtuals.
        colors : unused, kept for interface compatibility
        vol_shape : sequence or None
            When given, shifted and resampled copies of the virtuals are kept
            in ``virtuals_shifted`` for ``maskout_tracks``.
        virtuals_line_width, tracks_line_width : float
            GL line widths used in ``draw``.
        virtuals_alpha, tracks_alpha : float
            Alpha channel written into the colour buffers.
        affine : array or None
            Defaults to a 4x4 float32 identity.
        verbose : bool
            Print extra information while selecting.

        Raises
        ------
        ValueError
            If ``reps`` is neither ``'virtuals'`` nor ``'exemplars'``.
        """
        super(TrackLabeler, self).__init__(name)

        if affine is None:
            self.affine = np.eye(4, dtype=np.float32)
        else:
            self.affine = affine

        self.mouse_x = None
        self.mouse_y = None
        self.cache = {}
        self.qb = qb
        self.reps = reps
        # virtual tracks
        self._load_representatives()
        self.virtuals_alpha = virtuals_alpha
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                     self.virtuals_alpha)
        # full tractography
        self.tracks = tracks
        self.tracks_alpha = tracks_alpha
        # builtin int is what the removed ``np.int`` alias pointed to
        self.tracks_ids = np.arange(len(self.tracks), dtype=int)
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(self.tracks,
                                                   self.tracks_alpha)
        # bounding box of the entire tractography
        self.min = np.min(self.tracks_buffer, axis=0)
        self.max = np.max(self.tracks_buffer, axis=0)
        self.vertices = self.tracks_buffer
        # show size of tractography buffer
        print('MBytes %f' % (self.tracks_buffer.nbytes / 2.**20, ))
        self.position = (0, 0, 0)
        # ids of the currently selected virtuals
        self.selected = []
        self.virtuals_line_width = virtuals_line_width
        self.tracks_line_width = tracks_line_width
        # original colours of selected virtuals, keyed by virtual id
        self.old_color = {}
        self.hide_virtuals = False
        self.expand = False
        self.verbose = verbose
        self.tracks_visualized_first = np.array([], dtype='i4')
        self.tracks_visualized_count = np.array([], dtype='i4')
        # stack of states; freeze() pushes, the 'B' key pops
        self.history = [self._snapshot()]
        self.vol_shape = vol_shape
        self._shift_virtuals()

    def _load_representatives(self):
        """Refresh ``self.virtuals`` (and ``self.ex_ids``) from ``self.qb``."""
        if self.reps == 'virtuals':
            self.virtuals = self.qb.virtuals()
        elif self.reps == 'exemplars':
            self.virtuals, self.ex_ids = self.qb.exemplars()
        else:
            raise ValueError(
                "reps must be 'virtuals' or 'exemplars', got %r" %
                (self.reps, ))

    def _shift_virtuals(self):
        """Rebuild ``self.virtuals_shifted`` from the current virtuals.

        Each virtual is translated by half of ``vol_shape`` and resampled to
        30 points with ``downsample``; the result is consumed by
        ``maskout_tracks``. Without a ``vol_shape`` the attribute is ``None``.
        """
        if self.vol_shape is not None:
            offset = np.array(self.vol_shape) / 2.
            self.virtuals_shifted = [
                downsample(t + offset, 30) for t in self.virtuals
            ]
        else:
            self.virtuals_shifted = None

    def _snapshot(self):
        """Return the state stored in (and restored from) ``self.history``."""
        return [
            self.qb, self.tracks, self.tracks_ids, self.virtuals_buffer,
            self.virtuals_colors, self.virtuals_first, self.virtuals_count,
            self.tracks_buffer, self.tracks_colors, self.tracks_first,
            self.tracks_count
        ]

    def compute_buffers(self, tracks, alpha):
        """Compute buffers for GL compilation.

        Returns
        -------
        tracks_buffer : float32 array, all points of all tracks concatenated
        tracks_colors : float32 array, one RGBA row per point
        tracks_first : int32 array, offset of each track's first point
        tracks_count : int32 array, number of points of each track
        """
        tracks_buffer = np.ascontiguousarray(
            np.concatenate(tracks).astype('f4'))
        tracks_colors = np.ascontiguousarray(self.compute_colors(
            tracks, alpha))
        tracks_count = np.ascontiguousarray(
            np.array([len(v) for v in tracks], dtype='i4'))
        tracks_first = np.ascontiguousarray(
            np.r_[0, np.cumsum(tracks_count)[:-1]].astype('i4'))
        return tracks_buffer, tracks_colors, tracks_first, tracks_count

    def compute_colors(self, tracks, alpha):
        """Compute colors for a list of tracks.

        The RGB part of every point of a track comes from ``track2rgb``; the
        fourth channel is set to ``alpha`` everywhere.
        """
        assert isinstance(tracks, list)
        tot_vertices = sum(len(curve) for curve in tracks)
        color = np.empty((tot_vertices, 4), dtype='f4')
        counter = 0
        for curve in tracks:
            stop = counter + len(curve)
            color[counter:stop, :3] = track2rgb(curve).astype('f4')
            counter = stop

        color[:, 3] = alpha
        return color

    def draw(self):
        """Draw virtual and real tracks.

        This is done at every frame and therefore must be real fast.
        """
        # virtuals
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glEnableClientState(GL_VERTEX_ARRAY)
        glEnableClientState(GL_COLOR_ARRAY)
        if not self.hide_virtuals:
            glVertexPointer(3, GL_FLOAT, 0, self.virtuals_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.virtuals_colors.ctypes.data)
            glLineWidth(self.virtuals_line_width)
            glPushMatrix()
            if isinstance(self.virtuals_first, tuple):
                print('>> first Tuple')
            if isinstance(self.virtuals_count, tuple):
                print('>> count Tuple')
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.virtuals_first.ctypes.data,
                                   self.virtuals_count.ctypes.data,
                                   len(self.virtuals))
            glPopMatrix()
        # reals: only the tracks of the expanded clusters
        if self.expand and self.tracks_visualized_first.size > 0:
            glVertexPointer(3, GL_FLOAT, 0, self.tracks_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.tracks_colors.ctypes.data)
            glLineWidth(self.tracks_line_width)
            glPushMatrix()
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.tracks_visualized_first.ctypes.data,
                                   self.tracks_visualized_count.ctypes.data,
                                   len(self.tracks_visualized_count))
            glPopMatrix()
        glDisableClientState(GL_COLOR_ARRAY)
        glDisableClientState(GL_VERTEX_ARRAY)
        glLineWidth(1.)
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def process_mouse_position(self, x, y):
        """Remember the last mouse position (used by ``picking_virtuals``)."""
        self.mouse_x = x
        self.mouse_y = y

    def process_pickray(self, near, far):
        pass

    def update(self, dt):
        pass

    def select_track(self, ids):
        """Do visual selection of given virtuals.

        ``ids`` may be the string ``'all'``, a scalar id or an iterable of
        ids. Already selected virtuals are left untouched.
        """
        # WARNING: we assume that no tracks can ever have color_selected as original color
        color_selected = np.array([1.0, 1.0, 1.0, 1.0], dtype='f4')
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid in self.old_color:
                continue
            start = self.virtuals_first[vid]
            stop = start + self.virtuals_count[vid]
            # keep a copy of the original colours so they can be restored
            self.old_color[vid] = self.virtuals_colors[start:stop, :].copy()
            new_color = np.ones(self.old_color[vid].shape,
                                dtype='f4') * color_selected
            if self.verbose:
                print("Storing old color: %s" % self.old_color[vid][0])
            self.virtuals_colors[start:stop, :] = new_color
            self.selected.append(vid)

    def unselect_track(self, ids):
        """Do visual un-selection of given virtuals.

        ``ids`` may be the string ``'all'``, a scalar id or an iterable of
        ids. Virtuals that are not selected are ignored.
        """
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid not in self.old_color:
                continue
            start = self.virtuals_first[vid]
            stop = start + self.virtuals_count[vid]
            self.virtuals_colors[start:stop, :] = self.old_color[vid]
            if self.verbose:
                print("Setting old color: %s" % self.old_color[vid][0])
            self.old_color.pop(vid)
            if vid in self.selected:
                self.selected.remove(vid)
            else:
                print('WARNING: unselecting id %s but not in %s' %
                      (vid, self.selected))

    def invert_tracks(self):
        """ invert selected tracks to unselected
        """
        tmp_selected = list(
            set(range(len(self.virtuals))).difference(set(self.selected)))
        self.unselect_track('all')
        self.selected = []
        self.select_track(tmp_selected)

    def process_messages(self, messages):
        """Dispatch the 'key_pressed' and 'mouse_position' messages."""
        msg = messages['key_pressed']
        if msg is not None:
            self.process_keys(msg, None)
        msg = messages['mouse_position']
        if msg is not None:
            self.process_mouse_position(*msg)

    def process_keys(self, symbol, modifiers):
        """Bind actions to key press.

        P pick/unpick the virtual under the mouse, E expand/collapse the
        selected clusters, F freeze, A select/unselect all, I invert the
        selection, H hide/show virtuals, S save the selected track ids,
        ? print help, B go back in the freezing history, G select the
        virtuals going through the mask.
        """
        if symbol == Qt.Key_P:
            print('P')
            vid = self.picking_virtuals(symbol, modifiers)
            print('Track id %d' % vid)
            if vid not in self.selected:
                self.select_track(vid)
            else:
                self.unselect_track(vid)
            if self.verbose:
                print('Selected:')
                print(self.selected)

        elif symbol == Qt.Key_E:
            print('E')
            if self.verbose:
                print("Expand/collapse selected clusters.")
            if not self.expand and len(self.selected) > 0:
                tracks_selected = []
                for tid in self.selected:
                    tracks_selected += self.qb.label2tracksids(tid)
                # tracks_first / tracks_count are 1-D, so a single index is
                # all they accept (a trailing ``, :`` raises IndexError).
                self.tracks_visualized_first = np.ascontiguousarray(
                    self.tracks_first[tracks_selected])
                self.tracks_visualized_count = np.ascontiguousarray(
                    self.tracks_count[tracks_selected])
                self.expand = True
            else:
                self.expand = False

        # Freeze and restart:
        elif symbol == Qt.Key_F and len(self.selected) > 0:
            print('F')
            self.freeze()

        elif symbol == Qt.Key_A:
            print('A')
            print('Select/unselect all virtuals')
            if len(self.selected) < len(self.virtuals):
                self.select_track('all')
            else:
                self.unselect_track('all')

        elif symbol == Qt.Key_I:
            print('I')
            print('Invert selection')
            print(self.selected)
            self.invert_tracks()

        elif symbol == Qt.Key_H:
            print('H')
            print('Hide/show virtuals.')
            self.hide_virtuals = not self.hide_virtuals

        elif symbol == Qt.Key_S:
            print('S')
            print('Save selected tracks_ids as pickle file.')
            self.tracks_ids_to_be_saved = self.tracks_ids
            if len(self.selected) > 0:
                self.tracks_ids_to_be_saved = self.tracks_ids[np.concatenate(
                    [self.qb.label2tracksids(tid) for tid in self.selected])]
            print("Saving %s tracks." % len(self.tracks_ids_to_be_saved))
            root = Tkinter.Tk()
            root.withdraw()
            try:
                # binary mode: pickle protocols above 0 are not text
                out_file = tkFileDialog.asksaveasfile(mode='wb')
            finally:
                root.destroy()
            if out_file is None:
                # the dialog was cancelled, nothing to write to
                print('Save cancelled.')
            else:
                try:
                    pickle.dump(self.tracks_ids_to_be_saved,
                                out_file,
                                protocol=pickle.HIGHEST_PROTOCOL)
                finally:
                    out_file.close()

        elif symbol == Qt.Key_Question:
            print(question_message)

        elif symbol == Qt.Key_B:
            print('B')
            print('Go back in the freezing history.')
            if len(self.history) > 1:
                self.history.pop()
                (self.qb, self.tracks, self.tracks_ids, self.virtuals_buffer,
                 self.virtuals_colors, self.virtuals_first,
                 self.virtuals_count, self.tracks_buffer, self.tracks_colors,
                 self.tracks_first, self.tracks_count) = self.history[-1]
                # representatives must come from the restored qb
                self._load_representatives()
                print('%s virtuals' % len(self.virtuals))
                self.selected = []
                self.old_color = {}
                self.expand = False
                self.hide_virtuals = False
                # keep the mask lookup in sync with the restored virtuals
                self._shift_virtuals()

        elif symbol == Qt.Key_G:
            print('G')
            print('Get tracks from mask.')
            ids = self.maskout_tracks()
            self.select_track(ids)

    def freeze(self):
        """Restrict the actor to the selected clusters and re-cluster them.

        The real tracks of all selected virtuals become the new tractography,
        a threshold is asked for through a ``ThresholdSelector`` window, a new
        QuickBundles is computed with it and the resulting state is pushed on
        ``self.history``. Does nothing when no virtual is selected.
        """
        print(
            "Freezing current expanded real tracks, then doing QB on them, then restarting."
        )
        print("Selected virtuals: %s" % self.selected)
        if not self.selected:
            # nothing to freeze; continuing would discard every track
            return
        tracks_frozen = []
        tracks_frozen_ids = []
        for tid in self.selected:
            print(tid)
            part_tracks = self.qb.label2tracks(self.tracks, tid)
            part_tracks_ids = self.qb.label2tracksids(tid)
            print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
            tracks_frozen += part_tracks
            tracks_frozen_ids += part_tracks_ids
        print("frozen tracks size: %s" % len(tracks_frozen))
        print("Computing quick bundles...")
        self.unselect_track('all')
        self.tracks = tracks_frozen
        self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

        root = Tkinter.Tk()
        root.wm_title('QuickBundles threshold')
        ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
        root.wait_window()

        self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
        self.qb.dist_thr = ts.value
        self._load_representatives()
        print('%s virtuals' % len(self.virtuals))
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                     self.virtuals_alpha)
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(self.tracks,
                                                   self.tracks_alpha)
        self.selected = []
        self.old_color = {}
        self.expand = False
        self.history.append(self._snapshot())
        if self.vol_shape is not None:
            print("Shifting!")
        self._shift_virtuals()

    def picking_virtuals(self, symbol, modifiers, min_dist=1e-3):
        """Compute the id of the closest track to the mouse pointer.

        ``symbol``, ``modifiers`` and ``min_dist`` are accepted for interface
        compatibility and are not used.
        """
        x, y = self.mouse_x, self.mouse_y
        # Define two points in model space from mouse+screen(=0) position and mouse+horizon(=1) position
        near = screen_to_model(x, y, 0)
        far = screen_to_model(x, y, 1)

        # Compute distance of virtuals from screen and from the line defined by the two points above
        tmp = np.array([
            cll.mindistance_segment2track_info(near, far, xyz)
            for xyz in self.virtuals
        ])
        line_distance, screen_distance = tmp[:, 0], tmp[:, 1]
        # the virtual minimising the sum of both distances wins
        return np.argmin(line_distance + screen_distance)

    def set_state(self):
        """Tell hardware what to do with the scene.
        """
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)

    def unset_state(self):
        """Close communication with hardware.

        Disable what was enabled during set_state().
        """
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def delete(self):
        pass

    def maskout_tracks(self):
        """ retrieve ids of virtuals which go through the mask

        Uses ``self.slicer.mask`` and ``self.virtuals_shifted``; the latter is
        ``None`` when no ``vol_shape`` was given.
        """
        mask = self.slicer.mask
        tracks = self.virtuals_shifted
        tcs, tes = track_counts(tracks, mask.shape, (1, 1, 1), True)
        # volume indices of the mask's voxels as an Nx3 array
        roiinds = np.array(np.where(mask == 1)).T
        # collect the tracks registered for each mask voxel
        mask_tracks_inds = []
        for voxel in roiinds:
            try:
                mask_tracks_inds += tes[tuple(voxel)]
            except KeyError:
                # no track registered for this voxel
                pass
        mask_tracks_inds = list(set(mask_tracks_inds))
        print("Masked tracks %d" % len(mask_tracks_inds))
        print("mask_tracks_inds: %s" % mask_tracks_inds)
        return mask_tracks_inds
Beispiel #27
0
from dipy.io.pickles import save_pickle
from dipy.data import get_data
from dipy.viz import fvtk
import nipype.interfaces.mrtrix as mrt
import os
import sys

# Read the tractogram named by the first command-line argument and keep the
# first element of every record as the streamline.
streams, hdr = tv.read(sys.argv[1])
streamlines = [i[0] for i in streams]
"""
Perform QuickBundles clustering with a 10mm distance threshold after having
downsampled the streamlines to have only 12 points.
"""

# NOTE(review): the threshold and point count actually come from argv[3] and
# argv[4] (both cast to int), not the fixed 10mm / 12 points mentioned in the
# text above; argv[2] is not read in the visible part of this script.
print("Computing bundles")
qb = QuickBundles(streamlines, dist_thr=int(sys.argv[3]), pts=int(sys.argv[4]))
print("Completed")
"""
qb has attributes like `centroids` (cluster representatives), `total_clusters`
(total number of clusters) and methods like `partitions` (complete description
of all clusters) and `label2tracksids` (provides the indices of the streamlines
which belong in a specific cluster).
"""
centroids = qb.centroids
print(len(centroids))
# NOTE(review): same expression as the assignment near the top of the script;
# looks redundant -- confirm nothing in between relies on a fresh list.
streamlines = [i[0] for i in streams]
for i, centroid in enumerate(centroids):
    print(i)
    inds = qb.label2tracksids(i)
    list1 = []
    for number in range(1, len(inds)):
Beispiel #28
0
from fos.data import get_track_filename
from pyglet.window import key
from fos.core.utils import screen_to_model
import fos.core.collision as cll
from pyglet.gl import *
#dipy modules
from dipy.segment.quickbundles import QuickBundles


# Load the track file whose path is returned by get_track_filename().
streams,hdr = nib.trackvis.read(get_track_filename())
# Center the data: keep the first element of each record as the track, then
# subtract the mean of all track points from every track.
T=[s[0] for s in streams]
mean_T=np.mean(np.concatenate(T),axis=0)
T=[t-mean_T for t in T]

# Cluster the centered tracks and fetch both kinds of cluster representatives.
# The positional arguments 10. and 12 are presumably the distance threshold
# and the number of points per track -- confirm against QuickBundles.
qb=QuickBundles(T,10.,12)
Tqb=qb.virtuals()
Tqbe,Tqbei=qb.exemplars()

class TrackLabeler(Actor):   
    
    def __init__(self,qb,tracks,colors=None,line_width=2.,affine=None):
        
        self.virtuals=qb.virtuals()
        self.tracks=tracks           
        if affine is None:
            self.affine = np.eye(4, dtype = np.float32)
        else:
            self.affine = affine            
        #aabb - bounding box things
        ccurves=np.concatenate(tracks)
#    T = [downsample(t, 18) - np.array(data.shape[:3]) / 2. for t in T]
#    axis = np.array([1, 0, 0])
#    theta = - 90. 
#    T = np.dot(T,rotation_matrix(axis, theta))
#    axis = np.array([0, 1, 0])
#    theta = 180. 
#    T = np.dot(T, rotation_matrix(axis, theta))
#    
    #load initial QuickBundles with threshold 30mm
    #fpkl = dname+'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl'
    #qb=QuickBundles(T, 10., 18)
    #save_pickle(fpkl,qb)
    #qb=load_pickle(fpkl)

    qb=QuickBundles(T, 20., 12)
    #save_pickle(fpkl,qb)
    #qb=load_pickle(fpkl)
    Init()

    title = '[F]oS Streamline Interaction and Segmentation'
    w = Window(caption = title, 
                width = 1200, 
                height = 800, 
                bgcolor = (.5, .5, 0.9), right_panel=True)

    scene = Scene(scenename = 'Main Scene', activate_aabb = False)
    
    #create the interaction system for tracks 
    tl = StreamlineLabeler('Bundle Picker', 
                        qb,qb.downsampled_tracks(), 
Beispiel #30
0
def main():
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp>")
    parser.add_option("-s", "--scalar", dest="scalar",
                      default="FA", help="Scalar to measure")
    parser.add_option("-n", "--num", dest="num", default=50,
                      type='int', help="Number of subdivisions along centroids")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False,
                      help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option("-d", "--dist", dest="dist", default=20,
                      type='float', help="Quickbundle distance threshold")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option('--yrange', dest='yrange')
    parser.add_option('--xrange', dest='xrange')
    parser.add_option('--reverse', dest='is_reverse', action='store_true',
                      default=False, help='Reverse the centroid measure stepping order')
    parser.add_option('--pairplot', dest='pairplot',)
    parser.add_option('--noviz', dest='is_viz',
                      action='store_false', default=True)
    parser.add_option('--hide-centroid', dest='show_centroid',
                      action='store_false', default=True)
    parser.add_option('--config', dest='config')
    parser.add_option('--background', dest='bg_file',
                      help='Background NIFTI image')
    parser.add_option('--annot', dest='annot')

    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    name_mapping = None
    if options.config:
        config = GtsConfig(options.config, configure=False)
        name_mapping = config.group_labels

    annotations = None
    if options.annot:
        with open(options.annot, 'r') as fp:
            annotations = yaml.load(fp)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()

    polydata = reader.GetOutput()

    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        # get point ids in [[ids][ids...]...] format
        pids = polydata.GetCell(i).GetPointIds()
        ids = [pids.GetId(p) for p in range(pids.GetNumberOfIds())]
        tract_ids.append(ids)
    print 'tracks:', len(tract_ids)

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print 'verts:', len(verts)

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print sname
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si))
            groups = groups.astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si))
            subjects = subjects.astype(int)

    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []

    for i in tract_ids:
        # index np.array by a list will get all the respective indices
        streamlines.append(verts[i])
        stream_scalars.append(scalars[i])
        stream_pids.append(i)
        stream_sids.append(subjects[i])
        try:
            stream_groups.append(groups[i])
        except Exception:
            # group might not exist
            pass

    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    # get total average direction (where majority point towards)
    avg_d = np.zeros(3)
    # for line in streams:
    #     d = np.array(line[-1]) - np.array(line[0])
    #     d = d / la.norm(d)
    #     avg_d += d
    #     avg_d /= la.norm(avg_d)

    avg_com = np.zeros(3)
    avg_mid = np.zeros(3)

    strl_len = [len(l) for l in streamlines]
    stl_ori = np.array([np.abs(tm.mean_orientation(l)) for l in streamlines])

    centroids = []
    if options.curvepoints_file:
        LOCAL_POINT_ASSIGN = False
        cpoints = []
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # have a separate vtkCardinalSpline interpreter for x,y,z
        curve = [vtk.vtkCardinalSpline() for i in range(3)]
        for c in curve:
            c.ClosedOff()

        for pi, point in enumerate(ctrlpoints):
            for i, val in enumerate(point):
                curve[i].AddPoint(pi, point[i])

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)

        t = param_range[0]
        step = (param_range[1] - param_range[0]) / (QB_NPOINTS - 1.0)

        while t < param_range[1]:
            cp = [c.Evaluate(t) for c in curve]
            cpoints.append(cp)
            t = t + step

        centroids.append(cpoints)
        centroids = np.array(centroids)

    else:
        """
            Use quickbundles to find centroids
        """
        # streamlines = newlines
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        # bundle_distance_mam

        centroids = qb.centroids
        clusters = qb.clusters()

        avg_d = np.zeros(3)
        avg_com = np.zeros(3)
        avg_mid = np.zeros(3)

        #unify centroid list orders to point in the same general direction
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            #d = np.array(line[-1]) - np.array(line[0])
            #print line[-1],line[0],d
            # get the unit vector of the mean orientation
            if i == 0:
                avg_d = ori

            #d = d / la.norm(d)
            dotprod = ori.dot(avg_d)
            print 'dotprod', dotprod
            if dotprod < 0:
                print 'reverse', dotprod
                centroids[i] = line[::-1]
                line = centroids[i]
                ori *= -1
            avg_d += ori

        if options.is_reverse:
            for i, c in enumerate(centroids):
                centroids[i] = c[::-1]

    # prepare mayavi 3d viz

    if options.is_viz:
        bg_val = 0.
        fig = mlab.figure(bgcolor=(bg_val, bg_val, bg_val))
        scene = mlab.gcf().scene
        fig.scene.render_window.aa_frames = 4
        mlab.draw()

        if options.bg_file:
            mrsrc, bgdata = getNiftiAsScalarField(options.bg_file)
            orie = 'z_axes'

            opacity = 0.5
            slice_index = 0

            mlab.pipeline.image_plane_widget(mrsrc, opacity=opacity, plane_orientation=orie, slice_index=int(
                slice_index), colormap='black-white', line_width=0, reset_zoom=False)

    # prepare the plt plot
    len_cent = len(centroids)
    pal = sns.color_palette("bright", len_cent)

    DATADF = None

    """
        CENTROIDS
    """
    for ci, cent in enumerate(centroids):
        print '---- centroid:'

        if LOCAL_POINT_ASSIGN:
            """
                apply centroid to only their point assignments
                through quickbundles
            """
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centriod to all the points
            # instead of only their centroid assignments
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)
        cent_color = np.array(pal[ci])

        c, labels = kmeans2(cent_verts, cent, iter=1)

        cid = np.ones(len(labels))
        d = {'value': cent_scalars, 'position': labels,
             'group': cent_groups, 'pid': cent_pids, 'sid': cent_sids}

        df = pd.DataFrame(data=d)
        if DATADF is None:
            DATADF = df
        else:
            pd.concat([DATADF, df])

        UNIQ_GROUPS = df.group.unique()
        UNIQ_GROUPS.sort()

        # UNIQ_GROUPS = [0,1]

        grppal = sns.color_palette("Set2", len(UNIQ_GROUPS))

        print '# UNIQ GROUPS', UNIQ_GROUPS

        # print df
        # df = df[df['sid'] != 15]
        # df = df[df['sid'] != 16]
        # df = df[df['sid'] != 17]
        # df = df[df['sid'] != 18]
        """ 
            plot each group by their position 
        """

        fig = plt.figure(figsize=(14, 7))
        ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)
        ax2 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
        axes = [ax1, ax2]

        plt.xlabel('Position Index')

        if len(centroids) > 1:
            cent_patch = mpatches.Patch(
                color=cent_color, label='Centroid {}'.format(ci + 1))
            cent_legend = axes[0].legend(handles=[cent_patch], loc=9)
            axes[0].add_artist(cent_legend)

        """
            Perform stats
        """

        if len(UNIQ_GROUPS) > 1:
            # df = resample_data(df, num_sample_per_pos=120)
            # print df
            pvalsDf = position_stats(df, name_mapping=name_mapping)
            logpvals = np.log(pvalsDf) * -1
            # print logpvals

            pvals = logpvals.mask(pvalsDf >= 0.05)

            import matplotlib.ticker as mticker
            print pvals
            cmap = mcmap.Reds
            cmap.set_bad('w', 1.)
            axes[1].pcolormesh(pvals.values.T, cmap=cmap,
                               vmin=0, vmax=10, edgecolors='face', alpha=0.8)
            #axes[1].yaxis.set_major_locator(mticker.MultipleLocator(base=1.0))
            axes[1].set_yticks(
                np.arange(pvals.values.shape[1]) + 0.5, minor=False)
            axes[1].set_yticklabels(
                pvalsDf.columns.values.tolist(), minor=False)

        legend_handles = []
        for gi, GRP in enumerate(UNIQ_GROUPS):
            print '-------------------- GROUP ', gi, '----------------------'
            subgrp = df[df['group'] == GRP]
            print len(subgrp)

            if options.xrange:
                x0, x1 = options.xrange.split(',')
                x0 = int(x0)
                x1 = int(x1)
                subgrp = subgrp[(subgrp['position'] >= x0) &
                                (subgrp['position'] < x1)]

            posGrp = subgrp.groupby('position', sort=True)

            cent_stats = posGrp.apply(lambda x: stats_per_group(x))

            if len(cent_stats) == 0:
                continue

            cent_stats = cent_stats.unstack()
            cent_median_scalar = cent_stats['median'].tolist()

            x = np.array([i for i in posGrp.groups])
            # print x

            # print cent_stats['median'].tolist()
            mcolor = np.array(grppal[gi])
            # if gi>0:
            #     mcolor*= 1./(1+gi)

            cent_color = tuple(cent_color)
            mcolor = tuple(mcolor)

            if type(axes) is list:
                cur_axe = axes[0]
            else:
                cur_axe = axes

            cur_axe.set_ylabel(SCALAR_NAME)
            # cur_axe.yaxis.label.set_color(cent_color)
            # cur_axe.tick_params(axis='y', colors=cent_color)

            #cur_axe.fill_between(x, [s[0] for s in cent_ci], [t[1] for t in cent_ci], alpha=0.3, color=mcolor)

            # cur_axe.fill_between(x, [s[0] for s in cent_stats['whisk'].tolist()],
            #     [t[1] for t in cent_stats['whisk'].tolist()], alpha=0.1, color=mcolor)

            qtile_top = np.array([s[0] for s in cent_stats['ci'].tolist()])
            qtile_bottom = np.array([t[1] for t in cent_stats['ci'].tolist()])

            x_new, qtop_sm = smooth(x, qtile_top)
            x_new, qbottom_sm = smooth(x, qtile_bottom)
            cur_axe.fill_between(x_new, qtop_sm, qbottom_sm,
                                 alpha=0.25, color=mcolor)

            # cur_axe.errorbar(x, cent_stats['median'].tolist(), yerr=[[s[0] for s in cent_stats['err'].tolist()],
            #     [t[1] for t in cent_stats['err'].tolist()]], color=mcolor, alpha=0.1)

            x_new, median_sm = smooth(x, cent_stats['median'])
            hnd, = cur_axe.plot(x_new, median_sm, c=mcolor)
            legend_handles.append(hnd)

            # cur_axe.scatter(x,cent_stats['median'].tolist(), c=mcolor)

            if options.yrange:
                plotrange = options.yrange.split(',')
                cur_axe.set_ylim([float(plotrange[0]), float(plotrange[1])])

        legend_labels = UNIQ_GROUPS
        if name_mapping is not None:
            legend_labels = [name_mapping[str(i)] for i in UNIQ_GROUPS]
        cur_axe.legend(legend_handles, legend_labels)

        if annotations:
            for key, val in annotations.iteritems():
                # print key
                cur_axe.axvspan(val[0], val[1], fill=False, linestyle='dashed')
                axis_to_data = cur_axe.transAxes + cur_axe.transData.inverted()
                data_to_axis = axis_to_data.inverted()
                axpoint = data_to_axis.transform((val[0], 0))
                # print axpoint
                cur_axe.text(axpoint[0], 1.02, key,
                             transform=cur_axe.transAxes)

        """
            Plot 3D Viz 
        """

        if options.is_viz:
            scene.disable_render = True
            # scene.renderer.render_window.set(alpha_bit_planes=1,multi_samples=0)
            # scene.renderer.set(use_depth_peeling=True,maximum_number_of_peels=4,occlusion_ratio=0.1)
            # ran_colors = np.random.random_integers(255, size=(len(cent),4))
            # ran_colors[:,-1] = 255
            mypts = mlab.points3d(cent_verts[:, 0], cent_verts[:, 1], cent_verts[:, 2], labels,
                                  opacity=0.3,
                                  scale_mode='none',
                                  scale_factor=2,
                                  line_width=2,
                                  colormap='blue-red',
                                  mode='point')

            # print mypts.module_manager.scalar_lut_manager.lut.table.to_array()
            # mypts.module_manager.scalar_lut_manager.lut.table = ran_colors
            # mypts.module_manager.scalar_lut_manager.lut.number_of_colors = len(ran_colors)

            delta = len(cent) - len(cent_median_scalar)
            if delta > 0:
                cent_median_scalar = np.pad(
                    cent_median_scalar, (0, delta), mode='constant', constant_values=0)

            # calculate the displacement vector for all pairs
            uvw = cent - np.roll(cent, 1, axis=0)
            uvw[0] *= 0
            uvw = np.roll(uvw, -1, axis=0)
            arrow_plot = mlab.quiver3d(
                cent[:, 0], cent[:, 1], cent[:, 2],
                uvw[:, 0], uvw[:, 1], uvw[:, 2],
                scalars=cent_median_scalar,
                scale_factor=1,
                #color=mcolor,
                mode='arrow')

            gsource = arrow_plot.glyph.glyph_source.glyph_source

            # for name, thing in inspect.getmembers(gsource):
            #      print name

            arrow_plot.glyph.color_mode = 'color_by_scalar'
            #arrow_plot.glyph.scale_mode = 'scale_by_scalar'
            #arrow_plot.glyph.glyph.clamping = True
            #arrow_plot.glyph.glyph.scale_factor = 5
            #print arrow_plot.glyph.glyph.glyph_source
            gsource.tip_length = 0.4
            gsource.shaft_radius = 0.2
            gsource.tip_radius = 0.3

            if options.show_centroid:
                tube_plot = mlab.plot3d(cent[:, 0], cent[:, 1], cent[
                                        :, 2], cent_median_scalar, color=cent_color, tube_radius=0.2, opacity=0.25)
                tube_filter = tube_plot.parent.parent.filter
                tube_filter.vary_radius = 'vary_radius_by_scalar'
                tube_filter.radius_factor = 10

            # plot first and last
            def plot_pos_index(p):
                """Place the text ``str(p)`` at the coordinates of ``cent[p]``."""
                x, y, z = cent[p][:3]
                mlab.text3d(x, y, z, str(p), scale=0.8)

            for p in xrange(0, len(cent - 1), 10):
                plot_pos_index(p)
            plot_pos_index(len(cent) - 1)

            scene.disable_render = False

    DATADF.to_csv(
        '_'.join([filebase, SCALAR_NAME, 'rawdata.csv']), index=False)
    outfile = '_'.join([filebase, SCALAR_NAME])
    print 'save to {}'.format(outfile)
    plt.savefig('{}.pdf'.format(outfile), dpi=300)

    if options.is_viz:
        plt.show(block=False)
        mlab.show()
Downsample tracks to 12 points:
"""

tracks = [tm.downsample(t, 12) for t in T]

"""
Free the data that is no longer needed:
"""

del streams, hdr

"""
Run the QuickBundles clustering with a 10mm distance threshold:
"""

# The tracks were downsampled above, so no point count is requested here
# (presumably pts=None makes QuickBundles skip its own resampling; confirm).
qb = QuickBundles(tracks, dist_thr=10., pts=None)

"""
Display the initial *Fornix* dataset:
"""

r = fvtk.ren()
fornix_lines = fvtk.line(T, fvtk.white, opacity=1, linewidth=3)
fvtk.add(r, fornix_lines)
# Uncomment to inspect the scene interactively:
# fvtk.show(r)
fvtk.record(r, n_frames=1, out_path='fornix_initial', size=(600, 600))
fvtk.clear(r)
"""
.. figure:: fornix_initial1000000.png
   :align: center

   **Initial Fornix dataset**.
Beispiel #32
0
# Collect the points of every cell of `polydata` into one NumPy array per cell.
streamlines = []
for i in range(polydata.GetNumberOfCells()):
    pts = polydata.GetCell(i).GetPoints()
    # Use a distinct index for the points: reusing `i` here would clobber the
    # cell index under Python 2, where comprehension variables leak.
    npts = np.array([pts.GetPoint(j) for j in range(pts.GetNumberOfPoints())])
    streamlines.append(npts)

# `import scipy` alone does not guarantee that the `io` submodule is loaded,
# so import it explicitly before calling savemat.
import scipy.io
scipy.io.savemat("1.mat", {'streamlines': streamlines})
# Each streamline array needs to be transposed for AFQ in MATLAB;
# run this command there:
# @MATLAB: ts = cellfun(@transpose,streamlines,'UniformOutput',false)

#print streamlines

# Cluster the streamlines (distance threshold 10, 20 points per streamline).
qb = QuickBundles(streamlines, dist_thr=10., pts=20)

centroids = qb.centroids
clusters = qb.clusters()
# One random RGB triple per centroid.
colormap = np.random.rand(len(centroids), 3)


#print npts

#ren = vtk.vtkRenderer()
#renwin = vtk.vtkRenderWindow()
#renwin.AddRenderer(ren)

#c1 = clusters[0]
#print c1['hidden']
ren = fvtk.ren()
Beispiel #33
0
#            'response_dhollander/101107/Diffusion/1M_20_01_20dynamic250_SD_Stream_occipital8_lr5.tck'
img_path = ('/home/brain/workingdir/data/dwi/hcp/preprocessed/'
            'response_dhollander/101107/Structure/T1w_acpc_dc_restore_brain1.25.nii.gz')

img = nib.load(img_path)

# Keep only the streamlines whose reported length exceeds 10, store the
# filtered set back on the Fasciculus, then build the clustering helper.
fa = Fasciculus(fib)
streamlines = fa.get_data()
length_t = fa.get_lengths()
ind = length_t > 10
streamlines = streamlines[ind]
fa.set_data(streamlines)
fibcluster = FibClustering(fa)
print(len(streamlines))

# Pass 1: cluster everything and gather the member indices of every cluster
# whose 'N' entry is at least 400.
qb = QuickBundles(streamlines, 2)
clusters = qb.clusters()
print(qb.clusters_sizes())
indexs = []
for i in range(len(clusters)):
    members = clusters[i]
    if members['N'] < 400:
        continue
    indexs.extend(members['indices'])

# Pass 2: re-cluster only the streamlines selected in pass 1.
streamlines = streamlines[indexs]
qb = QuickBundles(streamlines, 2)
clusters = qb.clusters()

# Report the length of each centroid from the second pass.
centroids = qb.centroids
centroids_lengths = np.array(list(length(centroids)))
print(centroids_lengths)
Beispiel #34
0
"""

fname = get_data('fornix')
"""
Read the fornix streamlines from that file.
"""

streams, hdr = tv.read(fname)

# Keep only the first element of each record returned by the reader.
streamlines = [stream[0] for stream in streams]
"""
Perform QuickBundles clustering with a 10mm distance threshold after having
downsampled the streamlines to have only 18 points (``pts=18``).
"""

qb = QuickBundles(streamlines, dist_thr=10., pts=18)
"""
qb has attributes like `centroids` (cluster representatives) and
`total_clusters` (total number of clusters), and methods like `partitions`
(complete description of all clusters) and `label2tracksids` (provides the
indices of the streamlines which belong to a specific cluster).

Let's first show the initial dataset.
"""

ren = fvtk.ren()
ren.SetBackground(1, 1, 1)
initial_tubes = fvtk.streamtube(streamlines, fvtk.colors.white)
fvtk.add(ren, initial_tubes)
fvtk.record(ren, n_frames=1, out_path='fornix_initial.png', size=(600, 600))
"""
.. figure:: fornix_initial.png
print('ind.shape (%d, %d, %d)' % ind.shape)

# Run EuDX (Euler Delta Crossings) tracking on the FA volume.
# a_low = 0.2 is the value used for FA.
eu = EuDX(a=fa, ind=ind, seeds=100000, odf_vertices=sphere.vertices,
          a_low=0.2)
streamlines = list(eu)
print('Number of streamlines %i' % len(streamlines))
# Debug aid (disabled): print the shape of each streamline.
# for line in streamlines:
#     print(line.shape)

# Cluster the streamlines produced by the Euler-method tracking with
# QuickBundles (QB).
#   dist_thr: distance threshold; affects the number of clusters and their size
#   pts:      number of points each streamline is downsampled to before
#             clustering
# Default values: dist_thr = 4 and pts = 12
qb = QuickBundles(streamlines, dist_thr=20, pts=20)
clusters = qb.clusters()
print('Number of clusters %i' % qb.total_clusters)
print('Cluster size', qb.clusters_sizes())

# Show all streamlines in white, then save a snapshot.
ren = window.Renderer()
white_tubes = actor.streamtube(streamlines, window.colors.white)
ren.add(white_tubes)
window.show(ren)
window.record(ren, out_path=filename + '_stream_lines_eu.png', size=(600, 600))

# Show the cluster centroids, coloured per cluster, over faint streamlines.
window.clear(ren)
colormap = actor.create_colormap(np.arange(qb.total_clusters))
faint_tubes = actor.streamtube(streamlines, window.colors.white, opacity=0.1)
ren.add(faint_tubes)
centroid_tubes = actor.streamtube(qb.centroids, colormap, linewidth=0.5)
ren.add(centroid_tubes)
Beispiel #36
0
    # T=[track for track in euler if track_range(track,100/2.5,200/2.5)]
    # T=tracks_double_mask(seeds,seeds2,gqs,mask.shape,mask,mask2,)
    T = transform_tracks(T, mat)
    print len(T)
    T = [track for track in T if is_close(track, lT, 5)]

    shift = (np.array(ref_shape) - 1) / 2.0
    T = [t - shift for t in T]
    print len(T)

    # save tracks
    dpr_linear = Dpy(ftracks, "w")
    dpr_linear.write_tracks(T)
    dpr_linear.close()
    # cluster tracks
    qb = QuickBundles(T, 25.0, 25)
    # qb.remove_small_clusters(40)
    del T
    # load
    tl = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=ref_shape, tracks_line_width=3.0, tracks_alpha=1)
    fT1 = "data/subj_" + subject + "/MPRAGE_32/T1_flirt_out.nii.gz"
    # fT1_ref = '/usr/share/fsl/data/standard/MNI152_T1_1mm_brain.nii.gz'
    img = nib.load(fT1)
    # img = nib.load(fT1)
    sl = Slicer(img.get_affine(), img.get_data())
    tl.slicer = sl

    luigi = Line([t - shift for t in lT], line_width=2)

    # put the seeds together
    seeds = np.vstack((seeds, seeds2))