def subsample_streamlines(streamlines, clustering_threshold=6.,
                          removal_distance=2.):
    """ Thin out a group of streamlines by dropping near-duplicates.

    Meant for streamlines coming from a single bundle or a similar
    structure. The streamlines are clustered with QuickBundles using
    `clustering_threshold`; each cluster is then handed to
    `remove_similar_streamlines` with `removal_distance`.

    Parameters
    ----------
    streamlines : `ArraySequence` object
        Streamlines to subsample
    clustering_threshold : float
        distance threshold for clustering (in the space of the tracks)
    removal_distance : float
        distance threshold for removal (in the space of the tracks)

    Returns
    -------
    list
        The streamlines that were kept, concatenated cluster by cluster
    """
    qb = QuickBundles(streamlines, dist_thr=clustering_threshold, pts=20)
    n_clusters = len(qb.centroids)
    # Cluster labels are taken to be 0 .. n_clusters - 1, one per centroid.
    return [
        kept
        for label in range(n_clusters)
        for kept in remove_similar_streamlines(
            qb.label2tracks(streamlines, label),
            removal_distance=removal_distance)
    ]
def test_qbundles():
    """Cluster the 'fornix' data with QuickBundles(10., 12) and expect 4 clusters."""
    records, _header = nib.trackvis.read(get_data('fornix'))
    fornix = []
    for record in records:
        # The first item of every record is the streamline itself.
        fornix.append(record[0])

    bundles = QuickBundles(fornix, 10., 12)
    # Both representations are requested only to check that they run.
    bundles.virtuals()
    bundles.exemplars()
    assert_equal(4, bundles.total_clusters)
def bundle_centroids(self, streamlines=None, cluster_thre=10, dist_thre=10.0,
                     pts=12):
    """
    Cluster streamlines with QuickBundles and return the cluster centroids.

    Parameters
    ----------
    streamlines: streamline data; when None, the data returned by
        `self._fasciculus.get_data()` is used
    cluster_thre: size threshold handed to `remove_small_clusters`
    dist_thre: clustering threshold (distance mm)
    pts: number of points handed to QuickBundles for each streamline

    Return
    ------
    centroids: `nibas.ArraySequence` built from the centroids of the
        clusters that remain after the small ones were removed
    """
    data = self._fasciculus.get_data() if streamlines is None else streamlines
    qb = QuickBundles(data, dist_thre, pts)
    qb.remove_small_clusters(cluster_thre)
    return nibas.ArraySequence(qb.centroids)
def bundle_seg(self, streamlines=None, dist_thre=10.0, pts=12):
    """
    Segment streamlines into QuickBundles clusters.

    Parameters
    ----------
    streamlines: streamline data; when None, the data returned by
        `self._fasciculus.get_data()` is used
    dist_thre: clustering threshold (distance mm)
    pts: number of points handed to QuickBundles for each streamline

    Return
    ------
    labels: object array with one entry per streamline holding its
        1-based cluster number (entries never assigned stay None)
    data_clusters: list with the streamlines of each cluster
    N_list: list with the 'N' entry (size) of each cluster
    """
    data = self._fasciculus.get_data() if streamlines is None else streamlines
    clusters = QuickBundles(data, dist_thre, pts).clusters()

    labels = np.array([None] * len(data))
    N_list = []
    data_clusters = []
    # Clusters are looked up by the integer keys 0 .. len(clusters) - 1.
    for number in range(len(clusters)):
        cluster = clusters[number]
        members = cluster['indices']
        N_list.append(cluster['N'])
        labels[members] = number + 1
        data_clusters.append(data[members])
    return labels, data_clusters, N_list
def bundle_thre_seg(self, streamlines=None, cluster_thre=10, dist_thre=10.0,
                    pts=12):
    """
    Segment streamlines with QuickBundles after dropping small clusters.

    Parameters
    ----------
    streamlines: streamline data; when None, the data returned by
        `self._fasciculus.get_data()` is used
    cluster_thre: size threshold handed to `remove_small_clusters`
    dist_thre: clustering threshold (distance mm)
    pts: number of points handed to QuickBundles for each streamline

    Return
    ------
    sort_index: argsort of the mean of column 1 (y) of every centroid
    data_clusters: list with the streamlines of each cluster, in the order
        of `clusters.keys()` (it is not reordered by sort_index)
    """
    data = self._fasciculus.get_data() if streamlines is None else streamlines
    qb = QuickBundles(data, dist_thre, pts)
    qb.remove_small_clusters(cluster_thre)

    clusters = qb.clusters()
    data_clusters = [data[clusters[key]['indices']]
                     for key in clusters.keys()]

    y_means = []
    for centroid in qb.centroids:
        y_means.append(centroid[:, 1].mean())
    return np.argsort(y_means), data_clusters
def freeze(self):
    """Keep only the tracks of the selected virtuals and re-cluster them.

    The tracks represented by every selected virtual are gathered, a Tk
    dialog (`ThresholdSelector`, pre-set to half of the current distance
    threshold) is shown and waited upon, and QuickBundles is rerun on the
    kept tracks with the value read from the selector. The display buffers
    are rebuilt, the selection is cleared and the new state is appended to
    `self.history`.
    """
    print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
    print("Selected virtuals: %s" % self.selected)
    tracks_frozen = []
    tracks_frozen_ids = []
    for tid in self.selected:
        print(tid)
        part_tracks = self.qb.label2tracks(self.tracks, tid)
        part_tracks_ids = self.qb.label2tracksids(tid)
        print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
        tracks_frozen += part_tracks
        tracks_frozen_ids += part_tracks_ids
    print("frozen tracks size: %s" % len(tracks_frozen))
    print("Computing quick bundles...")
    self.unselect_track('all')
    self.tracks = tracks_frozen
    self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

    # Ask the user for the new clustering threshold.
    root = Tkinter.Tk()
    root.wm_title('QuickBundles threshold')
    ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
    root.wait_window()

    self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
    self.qb.dist_thr = ts.value
    if self.reps == 'virtuals':
        # Was `qb.virtuals()`: no local `qb` exists here, so this branch
        # raised NameError.
        self.virtuals = self.qb.virtuals()
    if self.reps == 'exemplars':
        self.virtuals, self.ex_ids = self.qb.exemplars()
    print("%s virtuals" % len(self.virtuals))

    # NOTE(review): compute_buffers is called with (tracks, alpha); confirm
    # this matches its signature in the enclosing class.
    (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
     self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                 self.virtuals_alpha)
    (self.tracks_buffer, self.tracks_colors, self.tracks_first,
     self.tracks_count) = self.compute_buffers(self.tracks,
                                               self.tracks_alpha)

    self.selected = []
    self.old_color = {}
    self.expand = False
    self.history.append([
        self.qb, self.tracks, self.tracks_ids,
        self.virtuals_buffer, self.virtuals_colors,
        self.virtuals_first, self.virtuals_count,
        self.tracks_buffer, self.tracks_colors,
        self.tracks_first, self.tracks_count,
    ])

    if self.vol_shape is not None:
        print("Shifting!")
        half_shape = np.array(self.vol_shape) / 2.
        self.virtuals_shifted = [downsample(t + half_shape, 30)
                                 for t in self.virtuals]
    else:
        self.virtuals_shifted = None
def test_qbundles():
    """Cluster the 'fornix' data with QuickBundles(10., 12) and expect 4 clusters.

    `virtuals()` and `exemplars()` are called only to check that they run;
    their results are not inspected.
    """
    streams, hdr = nib.trackvis.read(get_data('fornix'))
    T = [s[0] for s in streams]
    # The former unused `np.array(T, dtype=np.object)` was dropped:
    # `np.object` no longer exists in recent NumPy and made the test error
    # out before reaching the assertion.
    qb = QuickBundles(T, 10., 12)
    qb.virtuals()
    qb.exemplars()
    assert_equal(4, qb.total_clusters)
def bench_quickbundles():
    """Time `QB_Old` against `QB_New` and check they agree.

    Eight shifted copies of the 'fornix' data (one per octant) are
    clustered with `QB_Old`, with `QB_New` using its default metric and
    with `QB_New` using `metric=MDFpy()`. Timings are printed, and the
    number of clusters, the cluster sizes and the cluster indices of the
    new implementation are asserted to match those of the old one.
    """
    dtype = "float32"
    repeat = 10
    nb_points = 18

    streams, hdr = nib.trackvis.read(get_data('fornix'))
    fornix = [s[0].astype(dtype) for s in streams]
    fornix = streamline_utils.set_number_of_points(fornix, nb_points)

    # Create eight copies of the fornix to be clustered (one in each octant).
    # The order of the offsets is significant: it fixes the streamline
    # indices reported by the clusterings.
    octant_offsets = (
        [100, 100, 100], [100, -100, 100],
        [100, 100, -100], [100, -100, -100],
        [-100, 100, 100], [-100, -100, 100],
        [-100, 100, -100], [-100, -100, -100],
    )
    streamlines = []
    for offset in octant_offsets:
        shift = np.array(offset, dtype)
        streamlines += [s + shift for s in fornix]

    # The expected number of clusters of the fornix using threshold=10 is 4.
    threshold = 10.
    expected_nb_clusters = 4 * 8

    # NOTE: the code strings given to `measure` refer to the local names
    # `streamlines`, `threshold`, `nb_points`, `qb2` and `qb`; do not rename
    # them.
    print("Timing QuickBundles 1.0 vs. 2.0")

    qb = QB_Old(streamlines, threshold, pts=None)
    qb1_time = measure("QB_Old(streamlines, threshold, nb_points)", repeat)
    print("QuickBundles time: {0:.4}sec".format(qb1_time))
    assert_equal(qb.total_clusters, expected_nb_clusters)
    partitions = qb.partitions()
    sizes1 = [partitions[i]['N'] for i in range(qb.total_clusters)]
    indices1 = [partitions[i]['indices'] for i in range(qb.total_clusters)]

    qb2 = QB_New(threshold)
    qb2_time = measure("clusters = qb2.cluster(streamlines)", repeat)
    print("QuickBundles2 time: {0:.4}sec".format(qb2_time))
    print("Speed up of {0}x".format(qb1_time / qb2_time))
    clusters = qb2.cluster(streamlines)
    # Real lists (not `map` objects) so the comparison also works when
    # `map` returns an iterator.
    sizes2 = [len(c) for c in clusters]
    indices2 = [c.indices for c in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes2, sizes1)
    assert_arrays_equal(indices2, indices1)

    qb = QB_New(threshold, metric=MDFpy())
    qb3_time = measure("clusters = qb.cluster(streamlines)", repeat)
    print("QuickBundles2_python time: {0:.4}sec".format(qb3_time))
    print("Speed up of {0}x".format(qb1_time / qb3_time))
    clusters = qb.cluster(streamlines)
    sizes3 = [len(c) for c in clusters]
    indices3 = [c.indices for c in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes3, sizes1)
    assert_arrays_equal(indices3, indices1)
def bench_quickbundles():
    """Benchmark `QB_Old` against `QB_New` on eight copies of the fornix.

    The 'fornix' streamlines are resampled to 12 points and replicated
    once per octant. `QB_Old`, `QB_New` (default metric) and `QB_New` with
    `metric=MDFpy()` are timed with `measure`; the cluster count, sizes and
    indices of both `QB_New` runs are asserted to equal those of `QB_Old`.
    """
    dtype = "float32"
    repeat = 10
    nb_points = 12

    records, _hdr = nib.trackvis.read(get_fnames('fornix'))
    fornix = streamline_utils.set_number_of_points(
        [record[0].astype(dtype) for record in records], nb_points)

    # One shifted copy of the fornix per octant; x is the slowest-varying
    # sign, then z, then y (this order fixes the streamline indices).
    streamlines = []
    for dx in (100, -100):
        for dz in (100, -100):
            for dy in (100, -100):
                streamlines.extend(
                    s + np.array([dx, dy, dz], dtype) for s in fornix)

    # Four clusters are expected per fornix copy at threshold=10.
    threshold = 10.
    expected_nb_clusters = 4 * 8

    # The strings given to `measure` reference the locals `streamlines`,
    # `threshold`, `nb_points`, `qb2` and `qb`, so those names are kept.
    print("Timing QuickBundles 1.0 vs. 2.0")

    qb = QB_Old(streamlines, threshold, pts=None)
    qb1_time = measure("QB_Old(streamlines, threshold, nb_points)", repeat)
    print("QuickBundles time: {0:.4}sec".format(qb1_time))
    assert_equal(qb.total_clusters, expected_nb_clusters)
    sizes1 = []
    for i in range(qb.total_clusters):
        sizes1.append(qb.partitions()[i]['N'])
    indices1 = []
    for i in range(qb.total_clusters):
        indices1.append(qb.partitions()[i]['indices'])

    qb2 = QB_New(threshold)
    qb2_time = measure("clusters = qb2.cluster(streamlines)", repeat)
    print("QuickBundles2 time: {0:.4}sec".format(qb2_time))
    print("Speed up of {0}x".format(qb1_time / qb2_time))
    clusters = qb2.cluster(streamlines)
    sizes2 = [len(cluster) for cluster in clusters]
    indices2 = [cluster.indices for cluster in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes2, sizes1)
    assert_arrays_equal(indices2, indices1)

    qb = QB_New(threshold, metric=MDFpy())
    qb3_time = measure("clusters = qb.cluster(streamlines)", repeat)
    print("QuickBundles2_python time: {0:.4}sec".format(qb3_time))
    print("Speed up of {0}x".format(qb1_time / qb3_time))
    clusters = qb.cluster(streamlines)
    sizes3 = [len(cluster) for cluster in clusters]
    indices3 = [cluster.indices for cluster in clusters]
    assert_equal(len(clusters), expected_nb_clusters)
    assert_array_equal(sizes3, sizes1)
    assert_arrays_equal(indices3, indices1)
def load_tracks(method="pmt"):
    """Load a tractography, cluster it with QuickBundles and show it.

    Parameters
    ----------
    method : str
        Which tractography to load: "pmt" reads a fixed TrackVis file in
        voxel space; "dti", "dsi", "gqs" and "eit" read the matching
        ``*_tracks.dpy`` file through `Dpy`.

    Raises
    ------
    ValueError
        If `method` is not one of the names above (this used to surface
        later as an unbound `tracks` variable).

    Notes
    -----
    The input paths are hard-coded. The function does not return: it ends
    by running the window manager.
    """
    dname = "/home/eg309/Data/orbital_phantoms/dwi_dir/subject1/"
    dpy_names = {
        "dti": "dti_tracks.dpy",
        "dsi": "dsi_tracks.dpy",
        "gqs": "gqi_tracks.dpy",
        "eit": "eit_tracks.dpy",
    }
    if method != "pmt" and method not in dpy_names:
        raise ValueError(
            "unknown method %r; expected one of %s"
            % (method, ["pmt"] + sorted(dpy_names)))

    if method == "pmt":
        from nibabel import trackvis as tv
        fname = "/home/eg309/Data/orbital_phantoms/dwi_dir/workflow/tractography/_subject_id_subject1/cam2trk_pico_twoten/data_fit_pdfs_tracked.trk"
        streams, hdr = tv.read(fname, points_space="voxel")
        tracks = [s[0] for s in streams]
    else:
        dpr_linear = Dpy(dname + dpy_names[method], "r")
        try:
            tracks = dpr_linear.read_tracks()
        finally:
            # Close the reader even when reading fails.
            dpr_linear.close()
        # Only the .dpy tractographies are shifted; the shift is applied to
        # the tracks that pass the range filter.
        shift = np.array([96 / 2.0, 96 / 2.0, 55 / 2.0])
        tracks = [t - shift for t in tracks
                  if track_range(t, 100 / 2.5, 150 / 2.5)]

    tracks = [t for t in tracks if track_range(t, 100 / 2.5, 150 / 2.5)]
    print("final no of tracks  %s" % len(tracks))

    qb = QuickBundles(tracks, 25.0 / 2.5, 18)
    # The full-resolution tracks are no longer needed once clustered.
    del tracks

    tl = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=None,
                      tracks_line_width=3.0, tracks_alpha=1)
    w = World()
    w.add(tl)
    # create window
    wi = Window(caption="Fos", bgcolor=(1.0, 1.0, 1.0, 1.0),
                width=1600, height=900)
    wi.attach(w)
    # create window manager
    wm = WindowManager()
    wm.add(wi)
    wm.run()
def half_split_comparisons():
    """Compare the QuickBundles clustering of one half of each tractography
    with that half itself, with the other half and with random streamlines.

    For every dataset index in `range(len(tractography_sizes))` the
    tractography is split with `split_halves`, the first half is clustered,
    and `count_close_tracks` is evaluated for (virtuals, first half),
    (virtuals, second half) and (random streamlines of the first half,
    second half).

    Returns
    -------
    res : dict
        Maps each dataset index to a dict with the keys "lengths",
        "nclusters", "totals", "counts", "missed_fractions" and "means".
    """
    res = {}
    for subject in range(len(tractography_sizes)):
        first, second = split_halves(subject)
        print("%s %s" % (len(first), len(second)))

        first_qb = QuickBundles(first, qb_threshold, downsampling)
        n_clus = first_qb.total_clusters
        print("QB for first half has %s clusters" % n_clus)

        second_down = [downsample(s, downsampling) for s in second]
        matched_random = get_random_streamlines(
            first_qb.downsampled_tracks(), n_clus)
        neighbours_first = count_close_tracks(
            first_qb.virtuals(), first_qb.downsampled_tracks(),
            adjacency_threshold)
        neighbours_second = count_close_tracks(
            first_qb.virtuals(), second_down, adjacency_threshold)
        neighbours_random = count_close_tracks(
            matched_random, second_down, adjacency_threshold)

        # `int` replaces the `np.int` alias, which recent NumPy removed.
        maxclose = int(np.max(np.hstack(
            (neighbours_first, neighbours_second, neighbours_random))))

        # Row n: (n, number of entries equal to n in each neighbour array).
        counts = np.array(
            [
                (
                    n,
                    len(find(neighbours_first == n)),
                    len(find(neighbours_second == n)),
                    len(find(neighbours_random == n)),
                )
                for n in range(maxclose + 1)
            ],
            dtype="f",
        )
        totals = np.sum(counts[:, 1:], axis=0)
        # Fraction with zero close tracks, per comparison.
        missed_fractions = counts[0, 1:] / totals
        # Mean number of close tracks, per comparison.
        means = np.sum(counts[:, 1:] * counts[:, [0, 0, 0]], axis=0) / totals

        res[subject] = {
            "lengths": [len(first), len(second)],
            "nclusters": n_clus,
            "totals": totals,
            "counts": counts,
            "missed_fractions": missed_fractions,
            "means": means,
        }
    return res
def freeze(self):
    """Restrict the actor to the selected clusters and re-run QuickBundles.

    Collects the tracks (and their ids) behind every selected virtual,
    opens a `ThresholdSelector` Tk dialog whose default is half of the
    current distance threshold, waits for it, then clusters the collected
    tracks again with the value read from the selector. Afterwards the GL
    buffers are recomputed, the selection state is reset and the new state
    is pushed onto `self.history`.
    """
    print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
    print("Selected virtuals: %s" % self.selected)
    tracks_frozen = []
    tracks_frozen_ids = []
    for tid in self.selected:
        print(tid)
        part_tracks = self.qb.label2tracks(self.tracks, tid)
        part_tracks_ids = self.qb.label2tracksids(tid)
        print("virtual %s represents %s tracks." % (tid, len(part_tracks)))
        tracks_frozen += part_tracks
        tracks_frozen_ids += part_tracks_ids
    print("frozen tracks size: %s" % len(tracks_frozen))
    print("Computing quick bundles...")
    self.unselect_track('all')
    self.tracks = tracks_frozen
    self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

    # Let the user pick the threshold for the new clustering.
    root = Tkinter.Tk()
    root.wm_title('QuickBundles threshold')
    ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
    root.wait_window()

    self.qb = QuickBundles(self.tracks, dist_thr=ts.value, pts=self.qb.pts)
    self.qb.dist_thr = ts.value
    if self.reps == 'virtuals':
        # Fixed: the bare name `qb` is not defined in this method.
        self.virtuals = self.qb.virtuals()
    if self.reps == 'exemplars':
        self.virtuals, self.ex_ids = self.qb.exemplars()
    print("%s virtuals" % len(self.virtuals))

    # NOTE(review): compute_buffers is called with (tracks, alpha); confirm
    # this matches its signature in the enclosing class.
    (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
     self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                 self.virtuals_alpha)
    (self.tracks_buffer, self.tracks_colors, self.tracks_first,
     self.tracks_count) = self.compute_buffers(self.tracks,
                                               self.tracks_alpha)

    self.selected = []
    self.old_color = {}
    self.expand = False
    self.history.append([
        self.qb, self.tracks, self.tracks_ids,
        self.virtuals_buffer, self.virtuals_colors,
        self.virtuals_first, self.virtuals_count,
        self.tracks_buffer, self.tracks_colors,
        self.tracks_first, self.tracks_count,
    ])

    if self.vol_shape is not None:
        print("Shifting!")
        half_shape = np.array(self.vol_shape) / 2.
        self.virtuals_shifted = [downsample(t + half_shape, 30)
                                 for t in self.virtuals]
    else:
        self.virtuals_shifted = None
def load_PX_tracks():
    """Load the 'LH_premotor' track samples, cluster them and display them.

    The TrackVis file is read in voxel space from a hard-coded location,
    clustered with QuickBundles, clusters are pruned with
    `remove_small_clusters(20)` and the result is shown in a window. The
    function ends by running the window manager.
    """
    from nibabel import trackvis as tv

    roi = "LH_premotor"
    dn = "/home/hadron/from_John_mon12thmarch"
    dname = ("/extra_probtrackX_analyses/_subject_id_subj05_101_32/"
             "particle2trackvis_" + roi + "_native/")
    fname = dn + dname + "tract_samples.trk"

    # The original kept an unused list of the candidate `points_space`
    # values (None, "voxel", "rasmm"); "voxel" is the one used.
    streamlines, hdr = tv.read(fname, as_generator=True,
                               points_space="voxel")
    tracks = []
    for record in streamlines:
        tracks.append(record[0])
    del streamlines

    qb = QuickBundles(tracks, 25.0 / 2.5, 18)
    # The raw tracks are not needed once they have been clustered.
    del tracks
    qb.remove_small_clusters(20)

    labeler = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=None,
                           tracks_line_width=3.0, tracks_alpha=1)

    world = World()
    world.add(labeler)

    # create window
    window = Window(caption="Fos", bgcolor=(0.3, 0.3, 0.6, 1.0),
                    width=1600, height=900)
    window.attach(world)

    # create window manager
    manager = WindowManager()
    manager.add(window)
    manager.run()
def visualization(streamlines_file):
    """Cluster streamlines into bundles and plot the bundle centroids.

    Parameters
    ----------
    streamlines_file : str or file-like
        Archive readable by `np.load` whose 'arr_0' entry holds the
        streamlines.

    Notes
    -----
    Nothing is returned; one tube per centroid is drawn with
    `mlab.plot3d` on a new figure with a black background.
    """
    # Close the archive as soon as the array has been read.
    with np.load(streamlines_file) as archive:
        streamlines = archive['arr_0']

    qb = QuickBundles(streamlines, dist_thr=10., pts=18)
    centroids = qb.centroids

    # `float` replaces the `np.float` alias, which recent NumPy removed.
    colors = line_colors(centroids).astype(float)

    mlab.figure(bgcolor=(0., 0., 0.))
    for streamline, color in zip(centroids, colors):
        # Rows of the transpose are the x, y and z coordinates.
        xs, ys, zs = streamline.T[0], streamline.T[1], streamline.T[2]
        mlab.plot3d(xs, ys, zs, line_width=1., tube_radius=.5,
                    color=tuple(color))
def QB_reps(limits=[0, np.inf], reps=1):
    """Run QuickBundles on shuffled copies of every test tractography.

    For each dataset index in `range(len(test_tractography_sizes))` the
    tracks are fetched with `get_tracks`, shuffled `reps` times and
    clustered; all repetitions of a subject are pickled together to
    ``subj_<id>_QB_reps.pkl``.

    Parameters
    ----------
    limits : list
        Passed to `get_tracks` and stored in the results. The default list
        is never modified here.
    reps : int
        Number of shuffles per subject.

    Returns
    -------
    sizes : list
        Number of clusters of every QuickBundles run, in execution order
        (previously collected but not returned).
    """
    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    sizes = []
    for subject in range(len(test_tractography_sizes)):
        filename = "subj_" + ids[subject] + "_QB_reps.pkl"
        ur_tracks = get_tracks(subject, limits)
        res = {
            # This entry was commented out although it is printed below,
            # which raised KeyError.
            "filtered": len(ur_tracks),
            "qb_threshold": qb_threshold,
            "limits": limits,
            "shuffle": {},
            "clusters": {},
            "nclusters": {},
            "centroids": {},
            "cluster_sizes": {},
        }
        print("Dataset %s %s filtered tracks" % (subject, res["filtered"]))

        for rep in range(reps):
            print("Subject %s shuffle %s" % (ids[subject], rep))
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"][rep] = shuffle
            # A distinct comprehension variable: on Python 2 reusing the
            # repetition index here overwrote it with the last shuffled
            # position.
            tracks = [ur_tracks[j] for j in shuffle]
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            clusters = qb.clusters()
            res["clusters"][rep] = dict(
                (k, clusters[k]["indices"]) for k in clusters.keys())
            res["centroids"][rep] = qb.centroids
            res["nclusters"][rep] = qb.total_clusters
            res["cluster_sizes"][rep] = qb.clusters_sizes()
            print("QB for has %s clusters" % qb.total_clusters)
            sizes.append(qb.total_clusters)

        # Binary mode is required for pickles; the file is opened only once
        # the results exist and is closed even if dumping fails.
        with open(filename, "wb") as f:
            pickle.dump(res, f)
        print("Results written to %s" % filename)
    return sizes
def QB_reps_singly(limits=[0, np.inf], reps=1):
    """Run QuickBundles on shuffled tractographies, one pickle per run.

    For each dataset index in `range(len(test_tractography_sizes))` the
    tracks are fetched with `get_tracks`, then shuffled and clustered
    `reps` times. Every run is pickled on its own to
    ``subj_<id>_QB_rep_<rep>.pkl``.

    Parameters
    ----------
    limits : list
        Passed to `get_tracks` and stored in the results. The default list
        is never modified here.
    reps : int
        Number of shuffles per subject.

    Returns
    -------
    sizes : list
        Number of clusters of every QuickBundles run, in execution order.
        (The original returned an undefined name `sizes`.)
    """
    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    sizes = []
    for subject in range(len(test_tractography_sizes)):
        ur_tracks = get_tracks(subject, limits)
        for rep in range(reps):
            print("Subject %s shuffle %s" % (ids[subject], rep))
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            tracks = [ur_tracks[j] for j in shuffle]

            print("... starting QB")
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            print("... finished QB")

            clusters = qb.clusters()
            res = {
                "qb_threshold": qb_threshold,
                "limits": limits,
                "shuffle": shuffle,
                "clusters": dict(
                    (k, clusters[k]["indices"]) for k in clusters.keys()),
                "centroids": qb.centroids,
                "nclusters": qb.total_clusters,
                "cluster_sizes": qb.clusters_sizes(),
            }
            print("QB for has %s clusters" % qb.total_clusters)
            sizes.append(qb.total_clusters)

            filename = "subj_" + ids[subject] + "_QB_rep_" + str(rep) + ".pkl"
            # Binary mode is required for pickles; `with` closes the file
            # even if dumping fails.
            with open(filename, "wb") as f:
                pickle.dump(res, f)
            print("Results written to %s" % filename)
    return sizes
def test_qbundles():
    """QuickBundles(10., 12) on the 'fornix' data must yield 4 clusters."""
    loaded = nib.trackvis.read(get_fnames('fornix'))
    records, _header = loaded
    # Keep only the first item of each record (the streamline).
    fornix = list(record[0] for record in records)

    clustering = QuickBundles(fornix, 10., 12)
    # Exercised for their side effects only; the results are not checked.
    clustering.virtuals()
    clustering.exemplars()

    assert_equal(4, clustering.total_clusters)
def qb_wrapper(data, qb_threshold, streamlines_ids=None, qb_n_points=None):
    """A wrapper for qb with the correct API for the Labeler.

    Parameters
    ----------
    data : streamlines to cluster; must expose `shape` and `len()`
    qb_threshold : distance threshold handed to QuickBundles
    streamlines_ids : iterable of ids, optional
        Ids of the streamlines in `data`; they are sorted before use.
        Defaults to ``0 .. len(data) - 1``.
    qb_n_points : int or None
        Handed to QuickBundles; None means 'do not downsample the data
        again'.

    Returns
    -------
    clusters : dict
        Maps the streamline id of each cluster's exemplar to the set of
        streamline ids in that cluster.
    """
    print("qb_wrapper: starting")
    # `int` replaces the `np.int` alias, which recent NumPy removed.
    if streamlines_ids is None:
        streamlines_ids = np.arange(len(data), dtype=int)
    else:
        streamlines_ids = np.sort(list(streamlines_ids)).astype(int)

    print("streamlines_ids: %s" % len(streamlines_ids))
    print("data: %s" % (data.shape,))
    print("Calling QuickBundles.")
    qb = QuickBundles(data, qb_threshold, qb_n_points)
    # This call is necessary to let qb compute qb.exempsi; the second
    # result holds, per cluster, the position of the exemplar inside it.
    tmpa, qb_internal_id = qb.exemplars()

    clusters = {}
    print("Creating new clusters dictionary")
    for i, clusterid in enumerate(qb.clustering.keys()):
        indices = streamlines_ids[qb.clustering[clusterid]['indices']]
        tmp = indices[qb_internal_id[i]]
        clusters[tmp] = set(indices)
        print("%s -> %s" % (tmp, len(clusters[tmp])))
    return clusters
def bundle_tracks(in_file, dist_thr=40., pts=16, skip=80.):
    """Split a TrackVis file into QuickBundles clusters and merge them.

    Every cluster is written to ``quickbundle_<label>.trk`` in the current
    directory, the external ``track_merge`` command joins those files into
    ``MergedBundles.trk`` and `write_trackvis_scene` produces
    ``NewScene.scene``.

    Parameters
    ----------
    in_file : TrackVis file to read
    dist_thr : clustering distance threshold (converted with `float`)
    pts : number of points per streamline for clustering (converted with `int`)
    skip : forwarded to `write_trackvis_scene`

    Returns
    -------
    out_files : list of absolute paths of the per-cluster files
    out_merged_file : name of the merged track file
    out_scene_file : value returned by `write_trackvis_scene`
    """
    import subprocess
    import os.path as op
    from nibabel import trackvis as tv
    from dipy.segment.quickbundles import QuickBundles

    streams, hdr = tv.read(in_file)
    streamlines = [record[0] for record in streams]
    clusters = QuickBundles(streamlines, float(dist_thr), int(pts)).clustering

    prefix = "quickbundle_"
    print("%d clusters found" % len(clusters.keys()))

    # NOTE(review): this header is built but never used; the files below
    # are written with the input header `hdr`. Confirm which was intended.
    new_hdr = tv.empty_header()
    new_hdr['n_scalars'] = 1

    out_files = []
    for label in clusters:
        cluster_trk = op.abspath(prefix + str(label) + ".trk")
        print("Writing cluster %d to %s" % (label, cluster_trk))
        out_files.append(cluster_trk)
        members = clusters[label]['indices']
        # TrackVis records are 3-tuples; only the streamline is filled in.
        for_save = [(streamlines[idx], None, None) for idx in members]
        tv.write(cluster_trk, for_save, hdr)

    out_merged_file = "MergedBundles.trk"
    subprocess.call(["track_merge"] + out_files + [out_merged_file])

    out_scene_file = write_trackvis_scene(out_merged_file,
                                          n_clusters=len(clusters),
                                          skip=skip,
                                          names=None,
                                          out_file="NewScene.scene")
    print("Merged track file written to %s" % out_merged_file)
    print("Scene file written to %s" % out_scene_file)
    return out_files, out_merged_file, out_scene_file
class StreamlineLabeler(Actor):
    """Actor to explore and select subsets of streamlines.

    The scene is simplified through QuickBundles: one representative
    ('virtual') per cluster is drawn, virtuals can be picked, expanded to
    the tracks they represent, and the selection can be re-clustered
    ('frozen') or saved.
    """

    def __init__(self, name, qb, tracks, reps='exemplars', colors=None,
                 vol_shape=None, virtuals_line_width=5.0,
                 tracks_line_width=2.0, virtuals_alpha=1.0,
                 tracks_alpha=0.6, affine=None, verbose=False):
        """StreamlineLabeler is meant to explore and select subsets of the
        tracks. The exploration occurs through QuickBundles (qb) in
        order to simplify the scene.

        Parameters
        ----------
        name : passed to the `Actor` constructor
        qb : QuickBundles object computed on `tracks`
        tracks : list of tracks (see `compute_colors`, which requires a list)
        reps : 'virtuals' or 'exemplars', the representatives to display
        colors : None or a per-track sequence of colours; None colours
            every track with `track2rgb`
        vol_shape : None or a 3-item shape (I, J, K) used to re-centre the
            affine and to shift the virtuals for `maskout_tracks`
        virtuals_line_width, tracks_line_width : GL line widths
        virtuals_alpha, tracks_alpha : alpha values of the colours
        affine : None or 4x4 array; None means identity. NOTE: when
            `vol_shape` is given, the array passed in is modified in place.
        verbose : print extra information while selecting
        """
        super(StreamlineLabeler, self).__init__(name)

        # Rebind the local name as well: the code below uses `affine`
        # directly and used to crash when it was left as None.
        if affine is None:
            affine = np.eye(4, dtype=np.float32)
        self.affine = affine

        if vol_shape is not None:
            I, J, K = vol_shape
            centershift = img_to_ras_coords(
                np.array([[I / 2., J / 2., K / 2.]]), affine)
            centeraffine = from_matvec(np.eye(3), centershift.squeeze())
            affine[:3, 3] = affine[:3, 3] - centeraffine[:3, 3]

        self.glaffine = (GLfloat * 16)(*tuple(affine.T.ravel()))
        self.glaff = affine

        self.mouse_x = None
        self.mouse_y = None
        self.cache = {}
        self.qb = qb
        self.reps = reps

        # virtual tracks
        self._refresh_virtuals()
        self.virtuals_alpha = virtuals_alpha
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(
            self.virtuals, colors, self.virtuals_alpha)

        # full tractography (downsampled at 12 pts per track)
        self.tracks = tracks
        self.tracks_alpha = tracks_alpha
        # `int` replaces the `np.int` alias, which recent NumPy removed.
        self.tracks_ids = np.arange(len(self.tracks), dtype=int)
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(
            self.tracks, colors, self.tracks_alpha)

        # calculate boundary box for entire tractography
        self.min = np.min(self.tracks_buffer, axis=0)
        self.max = np.max(self.tracks_buffer, axis=0)
        self.vertices = self.tracks_buffer

        # show size of tractography buffer
        print('MBytes %f' % (self.tracks_buffer.nbytes / 2. ** 20,))
        self.position = (0, 0, 0)

        # buffer for selected virtual tracks
        self.selected = []
        self.virtuals_line_width = virtuals_line_width
        self.tracks_line_width = tracks_line_width
        self.old_color = {}
        self.hide_virtuals = False
        self.expand = False
        self.verbose = verbose
        self.tracks_visualized_first = np.array([], dtype='i4')
        self.tracks_visualized_count = np.array([], dtype='i4')
        self.history = [self._snapshot()]

        # shifting of track is necessary for
        # dipy.tracking.vox2track.track_counts; we also upsample using 30
        # points in order to increase the accuracy of track counts
        self.vol_shape = vol_shape
        self.virtuals_shifted = self._shifted_virtuals()

    def _refresh_virtuals(self):
        """Load the representatives selected by `self.reps` from `self.qb`."""
        if self.reps == 'virtuals':
            self.virtuals = self.qb.virtuals()
        if self.reps == 'exemplars':
            self.virtuals, self.ex_ids = self.qb.exemplars()

    def _snapshot(self):
        """Return the current state in the layout stored in `self.history`."""
        return [self.qb, self.tracks, self.tracks_ids,
                self.virtuals_buffer, self.virtuals_colors,
                self.virtuals_first, self.virtuals_count,
                self.tracks_buffer, self.tracks_colors,
                self.tracks_first, self.tracks_count]

    def _shifted_virtuals(self):
        """Return the virtuals shifted by half of `vol_shape` and resampled
        to 30 points, or None when no volume shape is known."""
        if self.vol_shape is None:
            return None
        half_shape = np.array(self.vol_shape) / 2.
        return [downsample(t + half_shape, 30) for t in self.virtuals]

    def compute_buffers(self, tracks, colors, alpha):
        """Compute buffers for GL compilation.

        Returns
        -------
        tracks_buffer : contiguous 'f4' array with all points concatenated
        tracks_colors : contiguous array from `compute_colors`
        tracks_first : contiguous 'i4' array, offset of each track's first point
        tracks_count : contiguous 'i4' array, number of points of each track
        """
        tracks_buffer = np.ascontiguousarray(
            np.concatenate(tracks).astype('f4'))
        tracks_colors = np.ascontiguousarray(
            self.compute_colors(tracks, colors, alpha))
        tracks_count = np.ascontiguousarray(
            np.array([len(v) for v in tracks], dtype='i4'))
        tracks_first = np.ascontiguousarray(
            np.r_[0, np.cumsum(tracks_count)[:-1]].astype('i4'))
        return tracks_buffer, tracks_colors, tracks_first, tracks_count

    def compute_colors(self, tracks, colors, alpha):
        """Compute colors for a list of tracks.

        Returns an (n_points, 4) 'f4' array: RGB from `track2rgb` when
        `colors` is None, otherwise from `colors[j]` for the j-th track;
        the fourth column is `alpha`.
        """
        assert isinstance(tracks, list)
        # Built-in sum keeps the total an int even for an empty list.
        tot_vertices = sum(len(curve) for curve in tracks)
        color = np.empty((tot_vertices, 4), dtype='f4')
        counter = 0
        for j, curve in enumerate(tracks):
            stop = counter + len(curve)
            # Identity test: `colors == None` is elementwise on arrays.
            if colors is None:
                color[counter:stop, :3] = track2rgb(curve).astype('f4')
            else:
                color[counter:stop, :3] = colors[j].astype('f4')
            counter = stop
        color[:, 3] = alpha
        return color

    def draw(self):
        """Draw virtual and real tracks.

        This is done at every frame and therefore must be real fast.
        """
        glDisable(GL_LIGHTING)
        # virtuals
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glEnableClientState(GL_VERTEX_ARRAY)
        glEnableClientState(GL_COLOR_ARRAY)
        if not self.hide_virtuals:
            glVertexPointer(3, GL_FLOAT, 0, self.virtuals_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.virtuals_colors.ctypes.data)
            glLineWidth(self.virtuals_line_width)
            glPushMatrix()
            glMultMatrixf(self.glaffine)
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.virtuals_first.ctypes.data,
                                   self.virtuals_count.ctypes.data,
                                   len(self.virtuals))
            glPopMatrix()
        # reals:
        if self.expand and self.tracks_visualized_first.size > 0:
            glVertexPointer(3, GL_FLOAT, 0, self.tracks_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.tracks_colors.ctypes.data)
            glLineWidth(self.tracks_line_width)
            glPushMatrix()
            glMultMatrixf(self.glaffine)
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.tracks_visualized_first.ctypes.data,
                                   self.tracks_visualized_count.ctypes.data,
                                   len(self.tracks_visualized_count))
            glPopMatrix()
        glDisableClientState(GL_COLOR_ARRAY)
        glDisableClientState(GL_VERTEX_ARRAY)
        glLineWidth(1.)
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def process_mouse_position(self, x, y):
        """Remember the latest mouse position (used by `picking_virtuals`)."""
        self.mouse_x = x
        self.mouse_y = y

    def process_pickray(self, near, far):
        """Do nothing; picking is handled in `picking_virtuals`."""
        pass

    def update(self, dt):
        """Do nothing; this actor has no time-dependent state."""
        pass

    def select_track(self, ids):
        """Do visual selection of given virtuals.

        `ids` is 'all', a single id or an iterable of ids. Newly selected
        virtuals are painted white; their previous colours are kept in
        `self.old_color`.
        """
        # WARNING: we assume that no tracks can ever have color_selected
        # as original color
        color_selected = np.array([1.0, 1.0, 1.0, 1.0], dtype='f4')
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid in self.old_color:
                # Already selected.
                continue
            first = self.virtuals_first[vid]
            stop = first + self.virtuals_count[vid]
            self.old_color[vid] = self.virtuals_colors[first:stop, :].copy()
            new_color = (np.ones(self.old_color[vid].shape, dtype='f4')
                         * color_selected)
            if self.verbose:
                print("Storing old color: %s" % self.old_color[vid][0])
            self.virtuals_colors[first:stop, :] = new_color
            self.selected.append(vid)

    def unselect_track(self, ids):
        """Do visual un-selection of given virtuals.

        `ids` is 'all', a single id or an iterable of ids. The colours
        saved by `select_track` are restored.
        """
        if ids == 'all':
            ids = range(len(self.virtuals))
        elif np.isscalar(ids):
            ids = [ids]
        for vid in ids:
            if vid not in self.old_color:
                continue
            first = self.virtuals_first[vid]
            stop = first + self.virtuals_count[vid]
            self.virtuals_colors[first:stop, :] = self.old_color[vid]
            if self.verbose:
                print("Setting old color: %s" % self.old_color[vid][0])
            self.old_color.pop(vid)
            if vid in self.selected:
                self.selected.remove(vid)
            else:
                print('WARNING: unselecting id %s but not in %s'
                      % (vid, self.selected))

    def invert_tracks(self):
        """Invert the selection: selected virtuals become unselected and
        vice versa."""
        tmp_selected = list(
            set(range(len(self.virtuals))).difference(set(self.selected)))
        self.unselect_track('all')
        self.selected = []
        self.select_track(tmp_selected)

    def process_messages(self, messages):
        """Dispatch the 'key_pressed' and 'mouse_position' messages."""
        msg = messages['key_pressed']
        if msg is not None:
            self.process_keys(msg, None)
        msg = messages['mouse_position']
        if msg is not None:
            self.process_mouse_position(*msg)

    def process_keys(self, symbol, modifiers):
        """Bind actions to key press.

        P picks the virtual under the mouse, E expands/collapses the
        selection, F freezes it, A selects/unselects all, I inverts the
        selection, H hides/shows the virtuals, S saves the selected track
        ids, ? prints `question_message`, B goes back in the freezing
        history and G selects the virtuals going through the mask.
        """
        prev_selected = copy.copy(self.selected)
        if symbol == Qt.Key_P:
            print('P')
            vid = self.picking_virtuals(symbol, modifiers)
            print('Track id %d' % vid)
            if prev_selected.count(vid) == 0:
                self.select_track(vid)
            else:
                self.unselect_track(vid)
            if self.verbose:
                print('Selected:')
                print(self.selected)

        if symbol == Qt.Key_E:
            print('E')
            if self.verbose:
                print("Expand/collapse selected clusters.")
            if not self.expand and len(self.selected) > 0:
                tracks_selected = []
                for tid in self.selected:
                    tracks_selected += self.qb.label2tracksids(tid)
                # tracks_first / tracks_count are 1-D (see
                # compute_buffers), so a single index is used; the former
                # `[ids, :]` has one index too many for 1-D arrays.
                self.tracks_visualized_first = np.ascontiguousarray(
                    self.tracks_first[tracks_selected])
                self.tracks_visualized_count = np.ascontiguousarray(
                    self.tracks_count[tracks_selected])
                self.expand = True
            else:
                self.expand = False
        # Freeze and restart:
        elif symbol == Qt.Key_F and len(self.selected) > 0:
            print('F')
            self.freeze()
        elif symbol == Qt.Key_A:
            print('A')
            print('Select/unselect all virtuals')
            if len(self.selected) < len(self.virtuals):
                self.select_track('all')
            else:
                self.unselect_track('all')
        elif symbol == Qt.Key_I:
            print('I')
            print('Invert selection')
            print(self.selected)
            self.invert_tracks()
        elif symbol == Qt.Key_H:
            print('H')
            print('Hide/show virtuals.')
            self.hide_virtuals = not self.hide_virtuals
        elif symbol == Qt.Key_S:
            print('S')
            print('Save selected tracks_ids as pickle file.')
            self.tracks_ids_to_be_saved = self.tracks_ids
            if len(self.selected) > 0:
                self.tracks_ids_to_be_saved = self.tracks_ids[
                    np.concatenate([self.qb.label2tracksids(tid)
                                    for tid in self.selected])]
            print("Saving %s tracks." % len(self.tracks_ids_to_be_saved))
            root = Tkinter.Tk()
            root.withdraw()
            # Binary mode is needed for HIGHEST_PROTOCOL pickles. The
            # dialog returns None when it is cancelled.
            out_file = tkFileDialog.asksaveasfile(mode='wb')
            if out_file is not None:
                try:
                    pickle.dump(self.tracks_ids_to_be_saved, out_file,
                                protocol=pickle.HIGHEST_PROTOCOL)
                finally:
                    out_file.close()
        elif symbol == Qt.Key_Question:
            print(question_message)
        elif symbol == Qt.Key_B:
            print('B')
            print('Go back in the freezing history.')
            if len(self.history) > 1:
                self.history.pop()
                (self.qb, self.tracks, self.tracks_ids,
                 self.virtuals_buffer, self.virtuals_colors,
                 self.virtuals_first, self.virtuals_count,
                 self.tracks_buffer, self.tracks_colors,
                 self.tracks_first, self.tracks_count) = self.history[-1]
                # Was `qb.virtuals()`, an undefined name in this method.
                self._refresh_virtuals()
                print("%s virtuals" % len(self.virtuals))
                self.selected = []
                self.old_color = {}
                self.expand = False
                self.hide_virtuals = False
        elif symbol == Qt.Key_G:
            print('G')
            print('Get tracks from mask.')
            ids = self.maskout_tracks()
            self.select_track(ids)

    def freeze(self):
        """Keep only the tracks of the selected virtuals and re-cluster them.

        A `ThresholdSelector` Tk dialog (default: half of the current
        distance threshold) provides the new threshold. Buffers are
        rebuilt, the selection is cleared and the new state is appended to
        `self.history`.
        """
        print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
        print("Selected virtuals: %s" % self.selected)
        tracks_frozen = []
        tracks_frozen_ids = []
        for tid in self.selected:
            print(tid)
            part_tracks = self.qb.label2tracks(self.tracks, tid)
            part_tracks_ids = self.qb.label2tracksids(tid)
            print("virtual %s represents %s tracks."
                  % (tid, len(part_tracks)))
            tracks_frozen += part_tracks
            tracks_frozen_ids += part_tracks_ids
        print("frozen tracks size: %s" % len(tracks_frozen))
        print("Computing quick bundles...")
        self.unselect_track('all')
        self.tracks = tracks_frozen
        self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

        root = Tkinter.Tk()
        root.wm_title('QuickBundles threshold')
        ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
        root.wait_window()

        self.qb = QuickBundles(self.tracks, dist_thr=ts.value,
                               pts=self.qb.pts)
        self.qb.dist_thr = ts.value
        # Was `qb.virtuals()`, an undefined name in this method.
        self._refresh_virtuals()
        print("%s virtuals" % len(self.virtuals))

        # compute_buffers takes (tracks, colors, alpha); the colours
        # argument was missing here. None colours the new subset with
        # track2rgb.
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(
            self.virtuals, None, self.virtuals_alpha)
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(
            self.tracks, None, self.tracks_alpha)

        self.selected = []
        self.old_color = {}
        self.expand = False
        self.history.append(self._snapshot())

        if self.vol_shape is not None:
            print("Shifting!")
        self.virtuals_shifted = self._shifted_virtuals()

    def picking_virtuals(self, symbol, modifiers, min_dist=1e-3):
        """Compute the id of the closest track to the mouse pointer.

        The virtual minimising the sum of its distance to the pick line
        and its distance to the screen is returned. `symbol`, `modifiers`
        and `min_dist` are not used; `min_dist` belonged to an earlier
        thresholded variant that was disabled (`if False:`) and is kept
        for call compatibility.
        """
        x, y = self.mouse_x, self.mouse_y
        # Define two points in model space from mouse+screen(=0) position
        # and mouse+horizon(=1) position
        near = screen_to_model(x, y, 0)
        far = screen_to_model(x, y, 1)
        # Compute distance of virtuals from screen and from the line
        # defined by the two points above
        tmp = np.array([
            cll.mindistance_segment2track_info(
                near, far, apply_transformation(xyz, self.glaff))
            for xyz in self.virtuals])
        line_distance, screen_distance = tmp[:, 0], tmp[:, 1]
        return np.argmin(line_distance + screen_distance)

    def maskout_tracks(self):
        """Retrieve ids of virtuals which go through the mask.

        Reads the mask from `self.slicer.mask`; `self.slicer` is not set in
        `__init__` and has to be assigned by other code beforehand. Uses
        `self.virtuals_shifted`, which is None unless `vol_shape` was given.
        """
        mask = self.slicer.mask
        tracks = self.virtuals_shifted
        tcs, tes = track_counts(tracks, mask.shape, (1, 1, 1), True)
        # volume indices of the mask's voxels as an (N, 3) array
        roiinds = np.array(np.where(mask == 1)).T
        # get tracks going through the roi
        found = set()
        for voxel in roiinds:
            try:
                found.update(tes[tuple(voxel)])
            except KeyError:
                # No track goes through this voxel.
                pass
        mask_tracks_inds = list(found)
        print("Masked tracks %d" % len(mask_tracks_inds))
        print("mask_tracks_inds: %s" % mask_tracks_inds)
        return mask_tracks_inds
# T = [downsample(t, 18) - np.array(data.shape[:3]) / 2. for t in T] # axis = np.array([1, 0, 0]) # theta = - 90. # T = np.dot(T,rotation_matrix(axis, theta)) # axis = np.array([0, 1, 0]) # theta = 180. # T = np.dot(T, rotation_matrix(axis, theta)) # # load initial QuickBundles with threshold 30mm # fpkl = dname+'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl' # qb=QuickBundles(T, 10., 18) # save_pickle(fpkl,qb) # qb=load_pickle(fpkl) qb = QuickBundles(T, 20.0, 12) # save_pickle(fpkl,qb) # qb=load_pickle(fpkl) Init() title = "[F]oS Streamline Interaction and Segmentation" w = Window(caption=title, width=1200, height=800, bgcolor=(0.5, 0.5, 0.9), right_panel=True) scene = Scene(scenename="Main Scene", activate_aabb=False) # create the interaction system for tracks tl = StreamlineLabeler( "Bundle Picker", qb, qb.downsampled_tracks(), vol_shape=data.shape[:3], tracks_alpha=1, affine=affine ) guil = Guillotine("Volume Slicer", data, affine)
# Keep the tracks selected by seg_inds in T2, then continue with T1 as the
# working set T.
T2=[T[i] for i in seg_inds]
T=T1

if reduce_length:
    # NOTE(review): track_range(100,200) does not receive the loop variable
    # `t`, so the condition is the same for every track (all kept or all
    # dropped). Presumably the track was meant to be passed as well --
    # verify against the definition of track_range.
    T=[t for t in T if track_range(100,200)]

#iT=np.random.randint(0,len(T),5000)
#T=[T[i] for i in iT]
#stop

#center
# Subtract (shape - 1) / 2 from every track.
shift=(np.array(data.shape)-1)/2.
T=[t-shift for t in T]

#load initial QuickBundles with threshold 30mm
#fpkl = 'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl'
# NOTE(review): the value actually passed below is 25., not 30 -- the comment
# above looks stale; confirm.
qb=QuickBundles(T,25.,30)
print len(qb.clustering)
#qb=load_pickle(fpkl)
# Report the clustering size before and after remove_small_clusters(1000).
qb.remove_small_clusters(1000)
print len(qb.clustering)

#create the interaction system for tracks
tl = TrackLabeler(qb,qb.downsampled_tracks(),vol_shape=data.shape,tracks_line_width=3.,tracks_alpha=1.)
#add a interactive slicing/masking tool
sl = Slicer(affine,data)
#add one way communication between tl and sl
# The labeler gets a reference to the slicer; nothing is set on the slicer.
tl.slicer=sl
#OpenGL coordinate system axes
ax = Axes(100)
x,y,z=data.shape
def main():
    """Sample a point scalar along tract centroids and save the raw table.

    Command line: ``prog [options] <tract.vtp> [output]``.

    The tract is read from a VTK XML polydata file.  Centroids are either
    computed with QuickBundles (``--dist``, ``--num``) or, when
    ``--curvepoints`` is given, a single curve sampled from a cardinal spline
    through the control points of that csv file.  For every centroid, each
    considered vertex is labelled with the index of a centroid point by one
    ``kmeans2`` iteration seeded with the centroid, and one row per vertex is
    collected with the columns ``value``, ``position``, ``group``, ``pid``,
    ``sid`` and ``centroid``.

    With ``--local`` only the streamlines QuickBundles assigned to a centroid
    are used for it; otherwise every streamline is used for every centroid.
    ``--local`` is ignored when ``--curvepoints`` is given.

    The rows of all centroids are written to
    ``<filebase>_<scalar>_rawdata.csv`` unless a second positional argument
    names the output file.  ``.csv`` and ``.xls`` are the supported
    extensions; for any other extension nothing is written and a warning is
    printed to stderr.

    Exits with status 2, after printing the help text, when no tract file is
    given.
    """
    parser = OptionParser(
        usage="Usage: %prog [options] <tract.vtp> <output.csv>")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float',
                      help="Quickbundle distance threshold")
    parser.add_option("-n", "--num", dest="num", default=50, type='int',
                      help="Number of subdivisions along centroids")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA",
                      help="Scalar to measure")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points "
                           "are defined in a csv file in the same space as "
                           "the tract points. The curve is the vtk cardinal "
                           "spline implementation, which is a catmull-rom "
                           "spline.")
    parser.add_option("-l", "--local", dest="is_local", action="store_true",
                      default=False,
                      help="Measure from Quickbundle assigned streamlines. "
                           "Default is to measure from all streamlines")
    parser.add_option('--reverse', dest='is_reverse', action='store_true',
                      default=False,
                      help='Reverse the centroid measure stepping order')
    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    # One list of point ids per cell: [[ids], [ids...], ...]
    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        pids = polydata.GetCell(i).GetPointIds()
        tract_ids.append([pids.GetId(p)
                          for p in range(pids.GetNumberOfIds())])
    print('tracks: %d' % len(tract_ids))

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: %d' % len(verts))

    # Pick the per-point arrays of interest out of the point data.
    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si))
            groups = groups.astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si))
            subjects = subjects.astype(int)

    # Split every per-point array into per-streamline pieces.
    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for ids in tract_ids:
        # index np.array by a list will get all the respective indices
        streamlines.append(verts[ids])
        stream_scalars.append(scalars[ids])
        stream_groups.append(groups[ids])
        stream_pids.append(ids)
        stream_sids.append(subjects[ids])

    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    centroids = []
    if options.curvepoints_file:
        # A user-defined curve has no cluster assignment to restrict to.
        LOCAL_POINT_ASSIGN = False
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # have a separate vtkCardinalSpline interpreter for x,y,z
        curve = [vtk.vtkCardinalSpline() for _ in range(3)]
        for spline in curve:
            spline.ClosedOff()
        for pi, point in enumerate(ctrlpoints):
            for axis, val in enumerate(point):
                curve[axis].AddPoint(pi, val)

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)
        t_start, t_end = param_range
        step = (t_end - t_start) / (QB_NPOINTS)
        # Sample by index instead of accumulating t += step: repeated float
        # addition could yield one sample more than requested.
        n_samples = QB_NPOINTS if t_end > t_start else 0
        cpoints = [[spline.Evaluate(t_start + k * step) for spline in curve]
                   for k in range(n_samples)]
        centroids.append(cpoints)
        centroids = np.array(centroids)
    else:
        # Use quickbundles to find centroids
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        centroids = qb.centroids
        clusters = qb.clusters()

    # unify centroid list orders to point in the same general direction
    avg_d = np.zeros(3)
    for i, line in enumerate(centroids):
        ori = np.array(tm.mean_orientation(line))
        if i == 0:
            # The first centroid defines the reference direction.
            avg_d = ori.copy()
        dotprod = ori.dot(avg_d)
        print('dotprod %s' % dotprod)
        if dotprod < 0:
            print('reverse %s' % dotprod)
            centroids[i] = line[::-1]
            ori *= -1
        avg_d += ori

    if options.is_reverse:
        for i, c in enumerate(centroids):
            centroids[i] = c[::-1]

    # One DataFrame per centroid; they are concatenated once at the end so
    # that every centroid ends up in the output.
    frames = []
    for ci, cent in enumerate(centroids):
        print('---- centroid:')
        if LOCAL_POINT_ASSIGN:
            # apply centroid to only their point assignments through
            # quickbundles
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centroid to all the points
            # instead of only their centroid assignments
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)

        # A single k-means iteration seeded with the centroid points; only
        # the per-vertex labels are used.
        _, labels = kmeans2(cent_verts, cent, iter=1)

        frames.append(pd.DataFrame(data={
            'value': cent_scalars,
            'position': labels,
            'group': cent_groups,
            'pid': cent_pids,
            'sid': cent_sids,
            'centroid': ci,
        }))

    if not frames:
        sys.stderr.write('No centroids were found; nothing to write.\n')
        return
    DATADF = pd.concat(frames, ignore_index=True)

    outfilename = '_'.join([filebase, SCALAR_NAME, 'rawdata.csv'])
    if len(args) > 1:
        outfilename = args[1]

    if outfilename.endswith('.csv'):
        DATADF.to_csv(outfilename, index=False)
    elif outfilename.endswith('.xls'):
        DATADF.to_excel(outfilename, index=False)
    else:
        sys.stderr.write(
            'Unsupported output extension for %s (expected .csv or .xls); '
            'nothing was written.\n' % outfilename)
# Keep only the first element of every stream entry.
T = [i[0] for i in streams]

""" Downsample tracks to 12 points: """

tracks = [tm.downsample(t, 12) for t in T]

""" Delete unnecessary data: """

del streams, hdr

""" Perform QuickBundles clustering with a 10mm threshold: """

# The tracks passed in are the ones already resampled to 12 points above.
# NOTE(review): pts=None presumably tells QuickBundles not to resample
# again -- confirm against the QuickBundles documentation.
qb = QuickBundles(tracks, dist_thr=10., pts=None)

""" Show the initial *Fornix* dataset: """

# The full-resolution tracks T (not the downsampled ones) are rendered.
r = fvtk.ren()
fvtk.add(r, fvtk.line(T, fvtk.white, opacity=1, linewidth=3))
#fvtk.show(r)
# Record a single frame to disk, then clear the renderer.
fvtk.record(r, n_frames=1, out_path='fornix_initial', size=(600, 600))
fvtk.clear(r)

""" .. figure:: fornix_initial1000000.png :align: center **Initial Fornix dataset**. """
def terminus2hemi_surface_density_map(streamlines, geo_path, hemi):
    """
    Streamline endpoints areas mapping to hemisphere surface

    For every surface vertex, count the streamline endpoints that lie
    closer than 5 (in the units of the surface coordinates) to it.

    NOTE(review): the previous docstring also said "streamlines > 1000",
    presumably the expected size of the input -- confirm.

    Parameters
    ----------
    streamlines: streamline data
    geo_path: surface data path; the text after the last '.' of the file
        name must be 'white', 'inflated' or 'pial' (read with
        nib.freesurfer.read_geometry) or 'gii' (read with nib.load)
    hemi: 'lh' uses the first point of every streamline as its endpoint,
        any other value uses the last point

    Return
    ------
    endpoints areas map on surface: 1-D float array with one count per
        surface vertex

    Raises
    ------
    ImageFileError
        if the surface file format is not supported
    """
    # Load the surface first so that an unsupported format is reported
    # before the clustering work is done.
    suffix = os.path.split(geo_path)[1].split('.')[-1]
    if suffix in ('white', 'inflated', 'pial'):
        coords, _ = nib.freesurfer.read_geometry(geo_path)
    elif suffix == 'gii':
        gii_data = nib.load(geo_path).darrays
        coords = gii_data[0].data
    else:
        raise ImageFileError(
            'This file format-{} is not supported at present.'.format(suffix))

    streamlines = _sort_streamlines(streamlines)
    bundles = QuickBundles(streamlines, 10, 12)
    # bundles.remove_small_clusters(10)
    clusters = bundles.clusters()

    # Index of the endpoint to use on every streamline.
    end = 0 if hemi == 'lh' else -1

    # Accumulate the counts cluster by cluster so the distance matrix only
    # ever covers one cluster. With no clusters the map stays all zeros.
    vert_value = np.zeros(len(coords), dtype=float)
    for key in clusters.keys():
        data = streamlines[clusters[key]['indices']]
        stream_terminus = np.array([s[end] for s in data])
        dist = cdist(coords, stream_terminus)
        vert_value += (dist < 5).sum(axis=1)
    return vert_value
class TrackLabeler(Actor):
    """Actor to explore and select subsets of a tractography.

    One representative track ("virtual") is drawn per QuickBundles cluster.
    Virtuals can be picked, expanded into the real tracks they represent,
    re-clustered ("frozen") and saved from the keyboard; see
    ``process_keys`` for the key bindings.
    """

    def __init__(self, name, qb, tracks, reps='exemplars', colors=None,
                 vol_shape=None, virtuals_line_width=5.0,
                 tracks_line_width=2.0, virtuals_alpha=1.0, tracks_alpha=0.6,
                 affine=None, verbose=False):
        """TrackLabeler is meant to explore and select subsets of the
        tracks. The exploration occurs through QuickBundles (qb) in
        order to simplify the scene.

        Parameters
        ----------
        name : forwarded to the ``Actor`` constructor
        qb : QuickBundles clustering of ``tracks``
        tracks : list of arrays, the tracks clustered by ``qb``
        reps : 'virtuals' or 'exemplars', which representative of each
            cluster is drawn
        colors : unused
        vol_shape : shape of the volume, or None. When given, the virtuals
            are shifted by half of it and resampled to 30 points for
            ``maskout_tracks``
        virtuals_line_width, tracks_line_width : GL line widths
        virtuals_alpha, tracks_alpha : alpha stored in the colour buffers
        affine : 4x4 affine; identity when None
        verbose : print extra information while selecting
        """
        super(TrackLabeler, self).__init__(name)
        if affine is None:
            self.affine = np.eye(4, dtype=np.float32)
        else:
            self.affine = affine
        self.mouse_x = None
        self.mouse_y = None
        self.cache = {}
        self.qb = qb
        self.reps = reps
        #virtual tracks
        self._refresh_virtuals()
        self.virtuals_alpha = virtuals_alpha
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                     self.virtuals_alpha)
        #full tractography (downsampled at 12 pts per track)
        self.tracks = tracks
        self.tracks_alpha = tracks_alpha
        # ``np.int`` was only an alias of the builtin and has been removed
        # from recent NumPy releases.
        self.tracks_ids = np.arange(len(self.tracks), dtype=int)
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(self.tracks,
                                                   self.tracks_alpha)
        #calculate boundary box for entire tractography
        self.min = np.min(self.tracks_buffer, axis=0)
        self.max = np.max(self.tracks_buffer, axis=0)
        self.vertices = self.tracks_buffer
        #show size of tractography buffer
        print('MBytes %f' % (self.tracks_buffer.nbytes / 2.**20, ))
        self.position = (0, 0, 0)
        #buffer for selected virtual tracks
        self.selected = []
        self.virtuals_line_width = virtuals_line_width
        self.tracks_line_width = tracks_line_width
        self.old_color = {}
        self.hide_virtuals = False
        self.expand = False
        self.verbose = verbose
        self.tracks_visualized_first = np.array([], dtype='i4')
        self.tracks_visualized_count = np.array([], dtype='i4')
        self.history = [self._snapshot()]
        #shifting of track is necessary for dipy.tracking.vox2track.track_counts
        #we also upsample using 30 points in order to increase the accuracy of track counts
        self.vol_shape = vol_shape
        self._refresh_virtuals_shifted()

    def _refresh_virtuals(self):
        """Recompute the representative tracks from ``self.qb``."""
        if self.reps == 'virtuals':
            self.virtuals = self.qb.virtuals()
        if self.reps == 'exemplars':
            self.virtuals, self.ex_ids = self.qb.exemplars()

    def _refresh_virtuals_shifted(self):
        """Recompute the shifted, 30-point virtuals used by maskout_tracks."""
        if self.vol_shape is not None:
            half_shape = np.array(self.vol_shape) / 2.
            self.virtuals_shifted = [downsample(t + half_shape, 30)
                                     for t in self.virtuals]
        else:
            self.virtuals_shifted = None

    def _snapshot(self):
        """Return the state stored in (and restored from) ``self.history``."""
        return [self.qb, self.tracks, self.tracks_ids,
                self.virtuals_buffer, self.virtuals_colors,
                self.virtuals_first, self.virtuals_count,
                self.tracks_buffer, self.tracks_colors,
                self.tracks_first, self.tracks_count]

    def _resolve_ids(self, ids):
        """Turn 'all', a scalar id or a sequence of ids into an iterable."""
        # Only compare against 'all' when a string was passed, so that ids
        # given as a numpy array are never compared with a string.
        if isinstance(ids, (str, type(u''))) and ids == 'all':
            return range(len(self.virtuals))
        if np.isscalar(ids):
            return [ids]
        return ids

    def compute_buffers(self, tracks, alpha):
        """Compute buffers for GL compilation.

        Returns the tuple ``(buffer, colors, first, count)``: the
        concatenated float32 vertices, the per-vertex colours, and for every
        track the offset of its first vertex and its number of vertices
        (both int32).
        """
        tracks_buffer = np.ascontiguousarray(
            np.concatenate(tracks).astype('f4'))
        tracks_colors = np.ascontiguousarray(
            self.compute_colors(tracks, alpha))
        tracks_count = np.ascontiguousarray(
            np.array([len(v) for v in tracks], dtype='i4'))
        tracks_first = np.ascontiguousarray(
            np.r_[0, np.cumsum(tracks_count)[:-1]].astype('i4'))
        return tracks_buffer, tracks_colors, tracks_first, tracks_count

    def compute_colors(self, tracks, alpha):
        """Compute colors for a list of tracks.

        The first three channels of every vertex come from ``track2rgb`` of
        its track, the fourth channel is ``alpha``.
        """
        assert isinstance(tracks, list)
        tot_vertices = np.sum([len(curve) for curve in tracks])
        color = np.empty((tot_vertices, 4), dtype='f4')
        counter = 0
        for curve in tracks:
            color[counter:counter + len(curve), :3] = \
                track2rgb(curve).astype('f4')
            counter += len(curve)
        color[:, 3] = alpha
        return color

    def draw(self):
        """Draw virtual and real tracks.

        This is done at every frame and therefore must be real fast.
        """
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glEnableClientState(GL_VERTEX_ARRAY)
        glEnableClientState(GL_COLOR_ARRAY)
        # virtuals
        if not self.hide_virtuals:
            glVertexPointer(3, GL_FLOAT, 0, self.virtuals_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.virtuals_colors.ctypes.data)
            glLineWidth(self.virtuals_line_width)
            glPushMatrix()
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.virtuals_first.ctypes.data,
                                   self.virtuals_count.ctypes.data,
                                   len(self.virtuals))
            glPopMatrix()
        # reals:
        if self.expand and self.tracks_visualized_first.size > 0:
            glVertexPointer(3, GL_FLOAT, 0, self.tracks_buffer.ctypes.data)
            glColorPointer(4, GL_FLOAT, 0, self.tracks_colors.ctypes.data)
            glLineWidth(self.tracks_line_width)
            glPushMatrix()
            glib.glMultiDrawArrays(GL_LINE_STRIP,
                                   self.tracks_visualized_first.ctypes.data,
                                   self.tracks_visualized_count.ctypes.data,
                                   len(self.tracks_visualized_count))
            glPopMatrix()
        glDisableClientState(GL_COLOR_ARRAY)
        glDisableClientState(GL_VERTEX_ARRAY)
        glLineWidth(1.)
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def process_mouse_position(self, x, y):
        """Remember the last mouse position; it is used for picking."""
        self.mouse_x = x
        self.mouse_y = y

    def process_pickray(self, near, far):
        pass

    def update(self, dt):
        pass

    def select_track(self, ids):
        """Do visual selection of given virtuals.

        ``ids`` is 'all', a single id or a sequence of ids.  The colour of a
        newly selected virtual is saved in ``self.old_color`` and replaced by
        white; virtuals that are already selected are left untouched.
        """
        # WARNING: we assume that no tracks can ever have color_selected as original color
        color_selected = np.array([1.0, 1.0, 1.0, 1.0], dtype='f4')
        for vid in self._resolve_ids(ids):
            if vid in self.old_color:
                continue
            first = self.virtuals_first[vid]
            stop = first + self.virtuals_count[vid]
            self.old_color[vid] = self.virtuals_colors[first:stop, :].copy()
            new_color = np.ones(self.old_color[vid].shape,
                                dtype='f4') * color_selected
            if self.verbose:
                print("Storing old color: %s" % self.old_color[vid][0])
            self.virtuals_colors[first:stop, :] = new_color
            self.selected.append(vid)

    def unselect_track(self, ids):
        """Do visual un-selection of given virtuals.

        ``ids`` is 'all', a single id or a sequence of ids.  The colour saved
        by ``select_track`` is restored for every virtual that has one.
        """
        for vid in self._resolve_ids(ids):
            if vid not in self.old_color:
                continue
            first = self.virtuals_first[vid]
            stop = first + self.virtuals_count[vid]
            self.virtuals_colors[first:stop, :] = self.old_color[vid]
            if self.verbose:
                print("Setting old color: %s" % self.old_color[vid][0])
            self.old_color.pop(vid)
            if vid in self.selected:
                self.selected.remove(vid)
            else:
                print('WARNING: unselecting id %s but not in %s' %
                      (vid, self.selected))

    def invert_tracks(self):
        """ invert selected tracks to unselected """
        tmp_selected = list(
            set(range(len(self.virtuals))).difference(set(self.selected)))
        self.unselect_track('all')
        self.selected = []
        self.select_track(tmp_selected)

    def process_messages(self, messages):
        """Dispatch the 'key_pressed' and 'mouse_position' messages."""
        msg = messages['key_pressed']
        if msg is not None:
            self.process_keys(msg, None)
        msg = messages['mouse_position']
        if msg is not None:
            self.process_mouse_position(*msg)

    def process_keys(self, symbol, modifiers):
        """Bind actions to key press.

        P pick the virtual under the mouse (toggle its selection),
        E expand/collapse the selected virtuals into their real tracks,
        F freeze (re-cluster) the selected virtuals, A select/unselect all,
        I invert the selection, H hide/show virtuals, S save the ids of the
        selected tracks as a pickle, ? print the help message, B go back in
        the freezing history, G select the virtuals going through the mask.
        """
        prev_selected = copy.copy(self.selected)
        if symbol == Qt.Key_P:
            print('P')
            vid = self.picking_virtuals(symbol, modifiers)
            print('Track id %d' % vid)
            if prev_selected.count(vid) == 0:
                self.select_track(vid)
            else:
                self.unselect_track(vid)
            if self.verbose:
                print('Selected:')
                print(self.selected)
        elif symbol == Qt.Key_E:
            print('E')
            if self.verbose:
                print("Expand/collapse selected clusters.")
            if not self.expand and len(self.selected) > 0:
                tracks_selected = []
                for tid in self.selected:
                    tracks_selected += self.qb.label2tracksids(tid)
                # tracks_first / tracks_count are 1-D (see compute_buffers),
                # so they take a single index.
                self.tracks_visualized_first = np.ascontiguousarray(
                    self.tracks_first[tracks_selected])
                self.tracks_visualized_count = np.ascontiguousarray(
                    self.tracks_count[tracks_selected])
                self.expand = True
            else:
                self.expand = False
        # Freeze and restart:
        elif symbol == Qt.Key_F and len(self.selected) > 0:
            print('F')
            self.freeze()
        elif symbol == Qt.Key_A:
            print('A')
            print('Select/unselect all virtuals')
            if len(self.selected) < len(self.virtuals):
                self.select_track('all')
            else:
                self.unselect_track('all')
        elif symbol == Qt.Key_I:
            print('I')
            print('Invert selection')
            print(self.selected)
            self.invert_tracks()
        elif symbol == Qt.Key_H:
            print('H')
            print('Hide/show virtuals.')
            self.hide_virtuals = not self.hide_virtuals
        elif symbol == Qt.Key_S:
            print('S')
            print('Save selected tracks_ids as pickle file.')
            self.tracks_ids_to_be_saved = self.tracks_ids
            if len(self.selected) > 0:
                self.tracks_ids_to_be_saved = self.tracks_ids[np.concatenate(
                    [self.qb.label2tracksids(tid) for tid in self.selected])]
            print("Saving %s tracks." % len(self.tracks_ids_to_be_saved))
            root = Tkinter.Tk()
            root.withdraw()
            # Binary mode: the pickle is written with a binary protocol.
            out_file = tkFileDialog.asksaveasfile(mode='wb')
            if out_file is None:
                # The dialog returns None when it is cancelled.
                print('Save cancelled.')
            else:
                try:
                    pickle.dump(self.tracks_ids_to_be_saved, out_file,
                                protocol=pickle.HIGHEST_PROTOCOL)
                finally:
                    out_file.close()
        elif symbol == Qt.Key_Question:
            print(question_message)
        elif symbol == Qt.Key_B:
            print('B')
            print('Go back in the freezing history.')
            if len(self.history) > 1:
                self.history.pop()
                (self.qb, self.tracks, self.tracks_ids,
                 self.virtuals_buffer, self.virtuals_colors,
                 self.virtuals_first, self.virtuals_count,
                 self.tracks_buffer, self.tracks_colors,
                 self.tracks_first, self.tracks_count) = self.history[-1]
                self._refresh_virtuals()
                # Keep the mask lookup in sync with the restored virtuals.
                self._refresh_virtuals_shifted()
                print('%d virtuals' % len(self.virtuals))
                self.selected = []
                self.old_color = {}
                self.expand = False
                self.hide_virtuals = False
        elif symbol == Qt.Key_G:
            print('G')
            print('Get tracks from mask.')
            ids = self.maskout_tracks()
            self.select_track(ids)

    def freeze(self):
        """Restart the exploration on the tracks of the selected virtuals.

        The real tracks represented by the selected virtuals become the new
        working set, a dialog asks for a new QuickBundles threshold (default:
        half the current one), the tracks are clustered again, all buffers
        are rebuilt and the new state is pushed on ``self.history``.
        """
        if not self.selected:
            print("Nothing selected: nothing to freeze.")
            return
        print("Freezing current expanded real tracks, then doing QB on them, then restarting.")
        print("Selected virtuals: %s" % self.selected)
        tracks_frozen = []
        tracks_frozen_ids = []
        for tid in self.selected:
            print(tid)
            part_tracks = self.qb.label2tracks(self.tracks, tid)
            part_tracks_ids = self.qb.label2tracksids(tid)
            print("virtual %s represents %s tracks." %
                  (tid, len(part_tracks)))
            tracks_frozen += part_tracks
            tracks_frozen_ids += part_tracks_ids
        print("frozen tracks size: %d" % len(tracks_frozen))
        print("Computing quick bundles...")
        self.unselect_track('all')
        self.tracks = tracks_frozen
        self.tracks_ids = self.tracks_ids[tracks_frozen_ids]

        root = Tkinter.Tk()
        root.wm_title('QuickBundles threshold')
        ts = ThresholdSelector(root, default_value=self.qb.dist_thr / 2.0)
        root.wait_window()

        self.qb = QuickBundles(self.tracks, dist_thr=ts.value,
                               pts=self.qb.pts)
        self.qb.dist_thr = ts.value
        self._refresh_virtuals()
        print('%d virtuals' % len(self.virtuals))
        (self.virtuals_buffer, self.virtuals_colors, self.virtuals_first,
         self.virtuals_count) = self.compute_buffers(self.virtuals,
                                                     self.virtuals_alpha)
        #compute buffers
        (self.tracks_buffer, self.tracks_colors, self.tracks_first,
         self.tracks_count) = self.compute_buffers(self.tracks,
                                                   self.tracks_alpha)
        self.selected = []
        self.old_color = {}
        self.expand = False
        self.history.append(self._snapshot())
        if self.vol_shape is not None:
            print("Shifting!")
        self._refresh_virtuals_shifted()

    def picking_virtuals(self, symbol, modifiers, min_dist=1e-3):
        """Compute the id of the closest track to the mouse pointer.

        ``symbol``, ``modifiers`` and ``min_dist`` are accepted for interface
        compatibility and are not used.  The returned id minimises the sum of
        the two distances reported by ``cll.mindistance_segment2track_info``
        for the pick segment.
        """
        x, y = self.mouse_x, self.mouse_y
        # Define two points in model space from mouse+screen(=0) position and mouse+horizon(=1) position
        near = screen_to_model(x, y, 0)
        far = screen_to_model(x, y, 1)
        # Compute distance of virtuals from screen and from the line defined by the two points above
        tmp = np.array([cll.mindistance_segment2track_info(near, far, xyz)
                        for xyz in self.virtuals])
        line_distance, screen_distance = tmp[:, 0], tmp[:, 1]
        return np.argmin(line_distance + screen_distance)

    def set_state(self):
        """Tell hardware what to do with the scene."""
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_LINE_SMOOTH)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)

    def unset_state(self):
        """Close communication with hardware.

        Disable what was enabled during set_state().
        """
        glDisable(GL_DEPTH_TEST)
        glDisable(GL_BLEND)
        glDisable(GL_LINE_SMOOTH)

    def delete(self):
        pass

    def maskout_tracks(self):
        """ retrieve ids of virtuals which go through the mask

        Reads the mask from ``self.slicer.mask`` (``self.slicer`` is set by
        the caller after construction) and returns the list of ids of the
        virtuals found in at least one voxel where the mask equals 1.
        """
        mask = self.slicer.mask
        tracks = self.virtuals_shifted
        tcs, tes = track_counts(tracks, mask.shape, (1, 1, 1), True)
        #find volume indices of mask's voxels and
        #make it a nice 2d numpy array (Nx3)
        roiinds = np.array(np.where(mask == 1)).T
        #get tracks going through the roi
        mask_tracks_inds = []
        for voxel in roiinds:
            try:
                mask_tracks_inds += tes[tuple(voxel)]
            except KeyError:
                # No virtual crosses this voxel.
                pass
        mask_tracks_inds = list(set(mask_tracks_inds))
        print("Masked tracks %d" % len(mask_tracks_inds))
        print("mask_tracks_inds: %s" % mask_tracks_inds)
        return mask_tracks_inds
from dipy.io.pickles import save_pickle from dipy.data import get_data from dipy.viz import fvtk import nipype.interfaces.mrtrix as mrt import os import sys streams, hdr = tv.read(sys.argv[1]) streamlines = [i[0] for i in streams] """ Perform QuickBundles clustering with a 10mm distance threshold after having downsampled the streamlines to have only 12 points. """ print("Computing bundles") qb = QuickBundles(streamlines, dist_thr=int(sys.argv[3]), pts=int(sys.argv[4])) print("Completed") """ qb has attributes like `centroids` (cluster representatives), `total_clusters` (total number of clusters) and methods like `partitions` (complete description of all clusters) and `label2tracksids` (provides the indices of the streamlines which belong in a specific cluster). """ centroids = qb.centroids print(len(centroids)) streamlines = [i[0] for i in streams] for i, centroid in enumerate(centroids): print(i) inds = qb.label2tracksids(i) list1 = [] for number in range(1, len(inds)):
from fos.data import get_track_filename from pyglet.window import key from fos.core.utils import screen_to_model import fos.core.collision as cll from pyglet.gl import * #dipy modules from dipy.segment.quickbundles import QuickBundles streams,hdr = nib.trackvis.read(get_track_filename()) #center the data T=[s[0] for s in streams] mean_T=np.mean(np.concatenate(T),axis=0) T=[t-mean_T for t in T] qb=QuickBundles(T,10.,12) Tqb=qb.virtuals() Tqbe,Tqbei=qb.exemplars() class TrackLabeler(Actor): def __init__(self,qb,tracks,colors=None,line_width=2.,affine=None): self.virtuals=qb.virtuals() self.tracks=tracks if affine is None: self.affine = np.eye(4, dtype = np.float32) else: self.affine = affine #aabb - bounding box things ccurves=np.concatenate(tracks)
# T = [downsample(t, 18) - np.array(data.shape[:3]) / 2. for t in T] # axis = np.array([1, 0, 0]) # theta = - 90. # T = np.dot(T,rotation_matrix(axis, theta)) # axis = np.array([0, 1, 0]) # theta = 180. # T = np.dot(T, rotation_matrix(axis, theta)) # #load initial QuickBundles with threshold 30mm #fpkl = dname+'data/subj_05/101_32/DTI/qb_gqi_1M_linear_30.pkl' #qb=QuickBundles(T, 10., 18) #save_pickle(fpkl,qb) #qb=load_pickle(fpkl) qb=QuickBundles(T, 20., 12) #save_pickle(fpkl,qb) #qb=load_pickle(fpkl) Init() title = '[F]oS Streamline Interaction and Segmentation' w = Window(caption = title, width = 1200, height = 800, bgcolor = (.5, .5, 0.9), right_panel=True) scene = Scene(scenename = 'Main Scene', activate_aabb = False) #create the interaction system for tracks tl = StreamlineLabeler('Bundle Picker', qb,qb.downsampled_tracks(),
def main():
    """Plot a point scalar along the centroid(s) of a VTK tract file.

    Command line: ``<prog> [options] <tract.vtp>`` (see the option help
    strings below). The function

    1. reads the polydata and collects, per streamline, its vertices, the
       requested scalar, and the optional ``group`` / ``tid`` point arrays;
    2. obtains centroids either from a spline through user supplied control
       points (``--curvepoints``) or from QuickBundles;
    3. flips centroids so they point the same general way, then labels every
       vertex with a centroid position via ``kmeans2``;
    4. for each centroid draws the per-group profile with matplotlib,
       optionally adds the position statistics strip, annotations and a
       mayavi 3D view;
    5. writes ``<filebase>_<scalar>_rawdata.csv`` with the rows of every
       centroid and ``<filebase>_<scalar>.pdf`` with the current figure.

    Exits with status 2 after printing the help when no tract file is given.
    """
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp>")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA",
                      help="Scalar to measure")
    parser.add_option("-n", "--num", dest="num", default=50, type='int',
                      help="Number of subdivisions along centroids")
    parser.add_option("-l", "--local", dest="is_local", action="store_true",
                      default=False,
                      help="Measure from Quickbundle assigned streamlines. "
                           "Default is to measure from all streamlines")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float',
                      help="Quickbundle distance threshold")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points "
                           "are defined in a csv file in the same space as "
                           "the tract points. The curve is the vtk cardinal "
                           "spline implementation, which is a catmull-rom "
                           "spline.")
    parser.add_option('--yrange', dest='yrange')
    parser.add_option('--xrange', dest='xrange')
    parser.add_option('--reverse', dest='is_reverse', action='store_true',
                      default=False,
                      help='Reverse the centroid measure stepping order')
    parser.add_option('--pairplot', dest='pairplot',)
    parser.add_option('--noviz', dest='is_viz', action='store_false',
                      default=True)
    parser.add_option('--hide-centroid', dest='show_centroid',
                      action='store_false', default=True)
    parser.add_option('--config', dest='config')
    parser.add_option('--background', dest='bg_file',
                      help='Background NIFTI image')
    parser.add_option('--annot', dest='annot')

    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    name_mapping = None
    if options.config:
        config = GtsConfig(options.config, configure=False)
        name_mapping = config.group_labels

    annotations = None
    if options.annot:
        with open(options.annot, 'r') as fp:
            # safe_load: the annotation file is plain data (label -> span);
            # yaml.load without a Loader can build arbitrary objects and is
            # rejected by current PyYAML.
            annotations = yaml.safe_load(fp)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    # ------------------------------------------------------------------
    # Read the tract polydata
    # ------------------------------------------------------------------
    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    # Point ids of every cell, in [[ids], [ids], ...] form.
    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        pids = polydata.GetCell(i).GetPointIds()
        tract_ids.append([pids.GetId(p) for p in range(pids.GetNumberOfIds())])
    print('tracks: {}'.format(len(tract_ids)))

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: {}'.format(len(verts)))

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si)).astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si)).astype(int)

    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for ids in tract_ids:
        # Indexing an ndarray with a list of ids selects all of those rows.
        streamlines.append(verts[ids])
        stream_scalars.append(scalars[ids])
        stream_pids.append(ids)
        stream_sids.append(subjects[ids])
        try:
            stream_groups.append(groups[ids])
        except Exception:
            # Deliberately best effort: the 'group' array might not exist.
            pass

    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    # ------------------------------------------------------------------
    # Centroids: user supplied curve, or QuickBundles
    # ------------------------------------------------------------------
    centroids = []
    clusters = None
    if options.curvepoints_file:
        # A single curve has no cluster membership to restrict to.
        LOCAL_POINT_ASSIGN = False
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # One vtkCardinalSpline per coordinate (x, y, z).
        curve = [vtk.vtkCardinalSpline() for _ in range(3)]
        for spline in curve:
            spline.ClosedOff()
        for pi, point in enumerate(ctrlpoints):
            for axis, val in enumerate(point):
                curve[axis].AddPoint(pi, val)

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)

        # NOTE(review): the strict `<` stops before the end of the range
        # unless rounding intervenes, so the curve usually has QB_NPOINTS - 1
        # samples -- confirm whether the end point was meant to be included.
        cpoints = []
        t = param_range[0]
        step = (param_range[1] - param_range[0]) / (QB_NPOINTS - 1.0)
        while t < param_range[1]:
            cpoints.append([spline.Evaluate(t) for spline in curve])
            t = t + step

        centroids.append(cpoints)
        centroids = np.array(centroids)
    else:
        # Use quickbundles to find centroids.
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        centroids = qb.centroids
        clusters = qb.clusters()

    # Unify centroid point order so all centroids point in the same general
    # direction as the running sum of the orientations seen so far.
    avg_d = np.zeros(3)
    for i, line in enumerate(centroids):
        ori = np.array(tm.mean_orientation(line))
        if i == 0:
            avg_d = ori
        dotprod = ori.dot(avg_d)
        print('dotprod {}'.format(dotprod))
        if dotprod < 0:
            print('reverse {}'.format(dotprod))
            centroids[i] = line[::-1]
            ori *= -1
        avg_d += ori

    if options.is_reverse:
        for i, c in enumerate(centroids):
            centroids[i] = c[::-1]

    # ------------------------------------------------------------------
    # Prepare the mayavi 3D view
    # ------------------------------------------------------------------
    if options.is_viz:
        bg_val = 0.
        fig = mlab.figure(bgcolor=(bg_val, bg_val, bg_val))
        scene = mlab.gcf().scene
        fig.scene.render_window.aa_frames = 4
        mlab.draw()

        if options.bg_file:
            mrsrc, bgdata = getNiftiAsScalarField(options.bg_file)
            orie = 'z_axes'
            opacity = 0.5
            slice_index = 0
            mlab.pipeline.image_plane_widget(
                mrsrc, opacity=opacity, plane_orientation=orie,
                slice_index=int(slice_index), colormap='black-white',
                line_width=0, reset_zoom=False)

    # ------------------------------------------------------------------
    # One matplotlib figure (and 3D overlay) per centroid
    # ------------------------------------------------------------------
    pal = sns.color_palette("bright", len(centroids))
    frames = []  # per-centroid tables, concatenated once for the CSV

    for ci, cent in enumerate(centroids):
        print('---- centroid:')
        if LOCAL_POINT_ASSIGN:
            # Only the streamlines quickbundles assigned to this centroid.
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # Apply each centroid to all the points.
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)
        cent_color = tuple(pal[ci])

        # Seeded with the centroid points and run for one iteration; the
        # returned per-vertex label is used as the position index.
        _, labels = kmeans2(cent_verts, cent, iter=1)

        df = pd.DataFrame(data={'value': cent_scalars,
                                'position': labels,
                                'group': cent_groups,
                                'pid': cent_pids,
                                'sid': cent_sids})
        # The original discarded the result of pd.concat, so only the first
        # centroid ever reached the CSV.
        frames.append(df)

        UNIQ_GROUPS = df.group.unique()
        UNIQ_GROUPS.sort()
        grppal = sns.color_palette("Set2", len(UNIQ_GROUPS))
        print('# UNIQ GROUPS {}'.format(UNIQ_GROUPS))

        # Upper axis: profiles by position; lower axis: statistics strip.
        fig = plt.figure(figsize=(14, 7))
        ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)
        ax2 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
        axes = [ax1, ax2]
        plt.xlabel('Position Index')

        if len(centroids) > 1:
            cent_patch = mpatches.Patch(
                color=cent_color, label='Centroid {}'.format(ci + 1))
            cent_legend = axes[0].legend(handles=[cent_patch], loc=9)
            axes[0].add_artist(cent_legend)

        # Statistics strip: -log(p) per position, hidden where p >= 0.05.
        if len(UNIQ_GROUPS) > 1:
            pvalsDf = position_stats(df, name_mapping=name_mapping)
            logpvals = np.log(pvalsDf) * -1
            pvals = logpvals.mask(pvalsDf >= 0.05)
            print(pvals)

            # Work on a copy: set_bad() on mcmap.Reds would alter the shared
            # colormap object for everything else that uses it.
            import copy
            cmap = copy.copy(mcmap.Reds)
            cmap.set_bad('w', 1.)
            axes[1].pcolormesh(pvals.values.T, cmap=cmap, vmin=0, vmax=10,
                               edgecolors='face', alpha=0.8)
            axes[1].set_yticks(
                np.arange(pvals.values.shape[1]) + 0.5, minor=False)
            axes[1].set_yticklabels(
                pvalsDf.columns.values.tolist(), minor=False)

        # Defined before the group loop so the code after it still works
        # when every group is skipped.
        cur_axe = axes[0]
        cur_axe.set_ylabel(SCALAR_NAME)
        cent_median_scalar = []
        legend_handles = []
        legend_labels = []

        for gi, GRP in enumerate(UNIQ_GROUPS):
            print('-------------------- GROUP  {} ----------------------'
                  .format(gi))
            subgrp = df[df['group'] == GRP]
            print(len(subgrp))

            if options.xrange:
                x0, x1 = options.xrange.split(',')
                x0 = int(x0)
                x1 = int(x1)
                subgrp = subgrp[(subgrp['position'] >= x0) &
                                (subgrp['position'] < x1)]

            posGrp = subgrp.groupby('position', sort=True)
            cent_stats = posGrp.apply(stats_per_group)
            if len(cent_stats) == 0:
                continue

            cent_stats = cent_stats.unstack()
            cent_median_scalar = cent_stats['median'].tolist()
            x = np.array(list(posGrp.groups))
            mcolor = tuple(grppal[gi])

            # Shaded band between the two ends of the 'ci' statistic.
            ci_pairs = cent_stats['ci'].tolist()
            qtile_top = np.array([pair[0] for pair in ci_pairs])
            qtile_bottom = np.array([pair[1] for pair in ci_pairs])
            x_new, qtop_sm = smooth(x, qtile_top)
            x_new, qbottom_sm = smooth(x, qtile_bottom)
            cur_axe.fill_between(x_new, qtop_sm, qbottom_sm, alpha=0.25,
                                 color=mcolor)

            x_new, median_sm = smooth(x, cent_stats['median'])
            hnd, = cur_axe.plot(x_new, median_sm, c=mcolor)

            # Record the label together with its handle: pairing handles
            # with the full group list by position mislabels the legend
            # whenever a group is skipped above.
            legend_handles.append(hnd)
            if name_mapping is not None:
                legend_labels.append(name_mapping[str(GRP)])
            else:
                legend_labels.append(GRP)

        if options.yrange:
            plotrange = options.yrange.split(',')
            cur_axe.set_ylim([float(plotrange[0]), float(plotrange[1])])

        cur_axe.legend(legend_handles, legend_labels)

        if annotations:
            for key, val in annotations.items():
                cur_axe.axvspan(val[0], val[1], fill=False,
                                linestyle='dashed')
                # Place the label just above the axes at the span start.
                axis_to_data = (cur_axe.transAxes +
                                cur_axe.transData.inverted())
                data_to_axis = axis_to_data.inverted()
                axpoint = data_to_axis.transform((val[0], 0))
                cur_axe.text(axpoint[0], 1.02, key,
                             transform=cur_axe.transAxes)

        # ---------------------------- 3D view -------------------------
        if options.is_viz:
            scene.disable_render = True

            # Vertices coloured by their position label.
            mlab.points3d(cent_verts[:, 0], cent_verts[:, 1],
                          cent_verts[:, 2], labels,
                          opacity=0.3, scale_mode='none', scale_factor=2,
                          line_width=2, colormap='blue-red', mode='point')

            # Pad with zeros so there is one scalar per centroid point.
            delta = len(cent) - len(cent_median_scalar)
            if delta > 0:
                cent_median_scalar = np.pad(
                    cent_median_scalar, (0, delta), mode='constant',
                    constant_values=0)

            # Displacement from each centroid point to the next one; the
            # last point gets a zero vector.
            uvw = cent - np.roll(cent, 1, axis=0)
            uvw[0] *= 0
            uvw = np.roll(uvw, -1, axis=0)

            arrow_plot = mlab.quiver3d(
                cent[:, 0], cent[:, 1], cent[:, 2],
                uvw[:, 0], uvw[:, 1], uvw[:, 2],
                scalars=cent_median_scalar,
                scale_factor=1,
                mode='arrow')
            gsource = arrow_plot.glyph.glyph_source.glyph_source
            arrow_plot.glyph.color_mode = 'color_by_scalar'
            gsource.tip_length = 0.4
            gsource.shaft_radius = 0.2
            gsource.tip_radius = 0.3

            if options.show_centroid:
                tube_plot = mlab.plot3d(
                    cent[:, 0], cent[:, 1], cent[:, 2], cent_median_scalar,
                    color=cent_color, tube_radius=0.2, opacity=0.25)
                tube_filter = tube_plot.parent.parent.filter
                tube_filter.vary_radius = 'vary_radius_by_scalar'
                tube_filter.radius_factor = 10

            # Label every 10th position and the last one. The original
            # bound was len(cent - 1), which equals len(cent) and could
            # label the last position twice.
            last = len(cent) - 1
            for p in list(range(0, last, 10)) + [last]:
                pos = cent[p]
                mlab.text3d(pos[0], pos[1], pos[2], str(p), scale=0.8)

            scene.disable_render = False

    # ------------------------------------------------------------------
    # Outputs
    # ------------------------------------------------------------------
    if frames:
        DATADF = pd.concat(frames)
        DATADF.to_csv(
            '_'.join([filebase, SCALAR_NAME, 'rawdata.csv']), index=False)

    outfile = '_'.join([filebase, SCALAR_NAME])
    print('save to {}'.format(outfile))
    # NOTE(review): savefig writes the current figure only, i.e. the one of
    # the last centroid -- confirm that is intended for multi-centroid runs.
    plt.savefig('{}.pdf'.format(outfile), dpi=300)

    if options.is_viz:
        plt.show(block=False)
        mlab.show()
Downsample tracks to 12 points: """
# Every track in T is resampled to 12 points.
tracks=[tm.downsample(t, 12) for t in T]

""" Delete unnecessary data: """
del streams,hdr

""" Perform QuickBundles clustering with a 10mm threshold: """
# pts=None is passed because the tracks were already downsampled above
# (presumably None disables QuickBundles' own resampling -- confirm).
qb=QuickBundles(tracks, dist_thr=10., pts=None)

""" Show the initial *Fornix* dataset: """
# Draw the original (not downsampled) tracks in white, write one frame to
# disk, then empty the renderer so it can be reused.
r=fvtk.ren()
fvtk.add(r,fvtk.line(T, fvtk.white, opacity=1, linewidth=3))
#fvtk.show(r)
fvtk.record(r,n_frames=1,out_path='fornix_initial',size=(600,600))
fvtk.clear(r)

""" .. figure:: fornix_initial1000000.png :align: center **Initial Fornix dataset**.
# Turn every polydata cell into an array holding the coordinates of its
# points; one array per streamline.
streamlines = []
for i in range(polydata.GetNumberOfCells()):
    pts = polydata.GetCell(i).GetPoints()
    # The inner index gets its own name: reusing `i` shadowed the cell index
    # and, on Python 2, overwrote it when the comprehension finished.
    npts = np.array([pts.GetPoint(j) for j in range(pts.GetNumberOfPoints())])
    streamlines.append(npts)

# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io` has been loaded.
import scipy.io
scipy.io.savemat("1.mat", {'streamlines': streamlines})
# Each stream array must be transposed for AFQ in MATLAB. Run there:
# @MATLAB: ts = cellfun(@transpose,streamlines,'UniformOutput',false)
#print streamlines

# Cluster with a 10. threshold and 20 points per streamline, then pick one
# random RGB colour per centroid.
qb = QuickBundles(streamlines, dist_thr=10., pts=20)
centroids = qb.centroids
clusters = qb.clusters()
colormap = np.random.rand(len(centroids), 3)
#print npts
#ren = vtk.vtkRenderer()
#renwin = vtk.vtkRenderWindow()
#renwin.AddRenderer(ren)
#c1 = clusters[0]
#print c1['hidden']
ren = fvtk.ren()
# 'response_dhollander/101107/Diffusion/1M_20_01_20dynamic250_SD_Stream_occipital8_lr5.tck'
img_path = '/home/brain/workingdir/data/dwi/hcp/preprocessed/' \
           'response_dhollander/101107/Structure/T1w_acpc_dc_restore_brain1.25.nii.gz'
img = nib.load(img_path)

# Keep only the streamlines whose length exceeds 10 and store them back on
# the Fasciculus object before building the clustering helper.
fa = Fasciculus(fib)
streamlines = fa.get_data()
length_t = fa.get_lengths()
ind = length_t > 10
streamlines = streamlines[ind]
fa.set_data(streamlines)
fibcluster = FibClustering(fa)
# print(x) with one argument prints the same text on Python 2 and 3.
print(len(streamlines))

# 1: first pass -- cluster with threshold 2 and collect the streamline
# indices of every cluster holding at least 400 streamlines. Clusters are
# visited by label, so the collected order matches the label order.
qb = QuickBundles(streamlines, 2)
clusters = qb.clusters()
print(qb.clusters_sizes())
indexs = []
for label in range(len(clusters)):
    cluster = clusters[label]
    if cluster['N'] >= 400:
        indexs.extend(cluster['indices'])

# 2: second pass -- re-cluster only the streamlines kept above and measure
# the length of every resulting centroid.
streamlines = streamlines[indexs]
qb = QuickBundles(streamlines, 2)
clusters = qb.clusters()
centroids = qb.centroids
centroids_lengths = np.array(list(length(centroids)))
print(centroids_lengths)
"""
# Path of the 'fornix' dataset as returned by get_data.
fname = get_data('fornix')

""" Load fornix streamlines. """
# Keep the first element of every record returned by the reader.
streams, hdr = tv.read(fname)
streamlines = [i[0] for i in streams]

""" Perform QuickBundles clustering with a 10mm distance threshold after having downsampled the streamlines to have only 12 points. """
# NOTE(review): the text above says 12 points but the call below passes
# pts=18 -- one of the two needs updating.
qb = QuickBundles(streamlines, dist_thr=10., pts=18)

""" qb has attributes like `centroids` (cluster representatives), `total_clusters` (total number of clusters) and methods like `partitions` (complete description of all clusters) and `label2tracksids` (provides the indices of the streamlines which belong in a specific cluster). Lets first show the initial dataset. """
# White background, streamlines drawn as white tubes, one frame written to
# fornix_initial.png at 600x600.
ren = fvtk.ren()
ren.SetBackground(1, 1, 1)
fvtk.add(ren, fvtk.streamtube(streamlines, fvtk.colors.white))
fvtk.record(ren, n_frames=1, out_path='fornix_initial.png', size=(600, 600))

""" .. figure:: fornix_initial.png
print('ind.shape (%d, %d, %d)' % ind.shape)

# Track with EuDX (Euler Delta Crossings) on the FA volume; 0.2 is the
# a_low cut-off used for FA.
eu = EuDX(a=fa, ind=ind, seeds=100000,
          odf_vertices=sphere.vertices, a_low=0.2)
streamlines = list(eu)
print('Number of streamlines %i' % len(streamlines))
# (debug aid, disabled) print line.shape for every line in streamlines

# Cluster the streamlines with QuickBundles (QB):
#   dist_thr -- distance threshold; it drives the number and size of clusters
#   pts      -- points per streamline used for downsampling before clustering
qb = QuickBundles(streamlines, dist_thr=20, pts=20)
clusters = qb.clusters()
print('Number of clusters %i' % qb.total_clusters)
print('Cluster size', qb.clusters_sizes())

# Display streamlines: show them interactively, then save a 600x600 image.
ren = window.Renderer()
white = window.colors.white
ren.add(actor.streamtube(streamlines, white))
window.show(ren)
window.record(ren,
              out_path=filename + '_stream_lines_eu.png',
              size=(600, 600))

# Display centroids: faint streamlines with the coloured centroids on top.
window.clear(ren)
colormap = actor.create_colormap(np.arange(qb.total_clusters))
ren.add(actor.streamtube(streamlines, white, opacity=0.1))
ren.add(actor.streamtube(qb.centroids, colormap, linewidth=0.5))
# T=[track for track in euler if track_range(track,100/2.5,200/2.5)]
# T=tracks_double_mask(seeds,seeds2,gqs,mask.shape,mask,mask2,)

# Transform the tracks, keep those accepted by is_close(track, lT, 5), and
# shift them by half of (ref_shape - 1).
T = transform_tracks(T, mat)
# print(x) with one argument prints the same text on Python 2 and 3.
print(len(T))
T = [track for track in T if is_close(track, lT, 5)]
shift = (np.array(ref_shape) - 1) / 2.0
T = [t - shift for t in T]
print(len(T))

# save tracks; close the writer even if writing fails so the file handle
# is not leaked
dpr_linear = Dpy(ftracks, "w")
try:
    dpr_linear.write_tracks(T)
finally:
    dpr_linear.close()

# cluster tracks
qb = QuickBundles(T, 25.0, 25)
# qb.remove_small_clusters(40)
# The full-resolution tracks are dropped; only the clustering's downsampled
# tracks are used from here on.
del T

# load
tl = TrackLabeler(qb, qb.downsampled_tracks(), vol_shape=ref_shape,
                  tracks_line_width=3.0, tracks_alpha=1)
fT1 = "data/subj_" + subject + "/MPRAGE_32/T1_flirt_out.nii.gz"
# fT1_ref = '/usr/share/fsl/data/standard/MNI152_T1_1mm_brain.nii.gz'
img = nib.load(fT1)
# NOTE(review): get_affine()/get_data() are deprecated in recent nibabel
# releases -- confirm the installed version before upgrading.
sl = Slicer(img.get_affine(), img.get_data())
tl.slicer = sl

# Reference tracks shifted the same way as T.
luigi = Line([t - shift for t in lT], line_width=2)

# put the seeds together
seeds = np.vstack((seeds, seeds2))