def bundle_thre_seg(self, streamlines=None, cluster_thre=10, dist_thre=10.0, pts=12):
    """QuickBundles-based segmentation with small clusters removed.

    Parameters
    ----------
    streamlines : streamline data, optional
        Defaults to ``self._fasciculus.get_data()`` when None.  Must support
        indexing with a list of indices (e.g. an object ndarray).
    cluster_thre : int
        Size threshold passed to ``remove_small_clusters``.
    dist_thre : float
        QuickBundles distance threshold (mm).
    pts : int
        Number of points each streamline is resampled to for clustering.

    Returns
    -------
    sort_index : ndarray
        ``np.argsort`` of the mean y coordinate (column 1) of every centroid.
    data_clusters : list
        Streamlines of each remaining cluster, in the iteration order of the
        clusters dict.
    """
    data = self._fasciculus.get_data() if streamlines is None else streamlines

    qb = QuickBundles(data, dist_thre, pts)
    qb.remove_small_clusters(cluster_thre)

    members = [data[info['indices']] for info in qb.clusters().values()]
    y_means = [centroid[:, 1].mean() for centroid in qb.centroids]
    return np.argsort(y_means), members
def bundle_seg(self, streamlines=None, dist_thre=10.0, pts=12):
    """QuickBundles-based segmentation that labels every streamline.

    Parameters
    ----------
    streamlines : streamline data, optional
        Defaults to ``self._fasciculus.get_data()`` when None.  Must support
        indexing with a list of indices (e.g. an object ndarray).
    dist_thre : float
        QuickBundles distance threshold (mm).
    pts : int
        Number of points each streamline is resampled to for clustering.

    Returns
    -------
    labels : ndarray of object
        1-based cluster label of each streamline (None if unassigned).
    data_cluster : list
        Streamlines of each cluster, ordered by cluster key 0..n-1.
    N_list : list
        Size (``'N'``) of each cluster, in the same order.
    """
    data = self._fasciculus.get_data() if streamlines is None else streamlines

    qb = QuickBundles(data, dist_thre, pts)
    clustering = qb.clusters()
    n_clusters = len(clustering)

    sizes = [clustering[idx]['N'] for idx in range(n_clusters)]

    labels = np.array([None] * len(data))
    members = []
    for idx in range(n_clusters):
        member_idx = clustering[idx]['indices']
        labels[member_idx] = idx + 1
        members.append(data[member_idx])
    return labels, members, sizes
def QB_reps(limits=None, reps=1):
    """Cluster every subject's tracks with QuickBundles on ``reps`` random shuffles.

    For each subject index in ``range(len(test_tractography_sizes))``, the
    tracks from ``get_tracks(subject, limits)`` are shuffled ``reps`` times.
    Each shuffle is clustered with
    ``QuickBundles(tracks, qb_threshold, downsampling)``.  All repetitions of
    one subject are pickled together into ``subj_<id>_QB_reps.pkl``.

    Parameters
    ----------
    limits : sequence of two numbers, optional
        Passed to ``get_tracks`` and stored in the result.  Defaults to
        ``[0, np.inf]``.
    reps : int
        Number of shuffled repetitions per subject.

    Returns
    -------
    list of int
        ``qb.total_clusters`` of every repetition of every subject, in run
        order.

    Notes
    -----
    Uses module-level names defined outside this function:
    ``test_tractography_sizes``, ``get_tracks``, ``qb_threshold``,
    ``downsampling``, ``QuickBundles`` and ``pickle``.
    """
    if limits is None:
        limits = [0, np.inf]
    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    sizes = []
    for subj in range(len(test_tractography_sizes)):
        filename = "subj_" + ids[subj] + "_QB_reps.pkl"
        ur_tracks = get_tracks(subj, limits)
        res = {}
        res["qb_threshold"] = qb_threshold
        res["limits"] = limits
        # The track count is not stored in ``res``; report it directly.
        print("Dataset %s %s filtered tracks" % (subj, len(ur_tracks)))
        for key in ("shuffle", "clusters", "nclusters", "centroids", "cluster_sizes"):
            res[key] = {}
        for rep in range(reps):
            print("Subject %s shuffle %s" % (ids[subj], rep))
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"][rep] = shuffle
            # Distinct comprehension variable: under Python 2 it leaks into the
            # enclosing scope and must not overwrite the repetition index.
            tracks = [ur_tracks[t] for t in shuffle]
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            clustering = qb.clusters()
            res["clusters"][rep] = {k: clustering[k]["indices"] for k in clustering}
            res["centroids"][rep] = qb.centroids
            res["nclusters"][rep] = qb.total_clusters
            res["cluster_sizes"][rep] = qb.clusters_sizes()
            print("QB for has %s clusters" % qb.total_clusters)
            sizes.append(qb.total_clusters)
        # Pickle is a binary format; open in "wb" and close even on error.
        with open(filename, "wb") as f:
            pickle.dump(res, f)
        print("Results written to %s" % filename)
    return sizes
def QB_reps_singly(limits=None, reps=1):
    """Like ``QB_reps`` but pickle every repetition to its own file.

    For each subject index in ``range(len(test_tractography_sizes))``, the
    tracks from ``get_tracks(subject, limits)`` are shuffled ``reps`` times.
    Each shuffle is clustered with
    ``QuickBundles(tracks, qb_threshold, downsampling)``.  Each repetition is
    written to ``subj_<id>_QB_rep_<rep>.pkl``.

    Parameters
    ----------
    limits : sequence of two numbers, optional
        Passed to ``get_tracks`` and stored in the result.  Defaults to
        ``[0, np.inf]``.
    reps : int
        Number of shuffled repetitions per subject.

    Returns
    -------
    list of int
        ``qb.total_clusters`` of every repetition of every subject, in run
        order.

    Notes
    -----
    Uses module-level names defined outside this function:
    ``test_tractography_sizes``, ``get_tracks``, ``qb_threshold``,
    ``downsampling``, ``QuickBundles`` and ``pickle``.
    """
    if limits is None:
        limits = [0, np.inf]
    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    sizes = []
    for subj in range(len(test_tractography_sizes)):
        ur_tracks = get_tracks(subj, limits)
        for rep in range(reps):
            print("Subject %s shuffle %s" % (ids[subj], rep))
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            tracks = [ur_tracks[t] for t in shuffle]
            print("... starting QB")
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            print("... finished QB")
            clustering = qb.clusters()
            res = {
                "qb_threshold": qb_threshold,
                "limits": limits,
                "shuffle": shuffle,
                "clusters": {k: clustering[k]["indices"] for k in clustering},
                "nclusters": qb.total_clusters,
                "centroids": qb.centroids,
                "cluster_sizes": qb.clusters_sizes(),
            }
            print("QB for has %s clusters" % qb.total_clusters)
            sizes.append(qb.total_clusters)
            filename = "subj_" + ids[subj] + "_QB_rep_" + str(rep) + ".pkl"
            # Pickle is a binary format; open in "wb" and close even on error.
            with open(filename, "wb") as f:
                pickle.dump(res, f)
            print("Results written to %s" % filename)
    return sizes
# Compute Eular Delta Crossing with FA eu = EuDX(a=fa, ind=ind, seeds=100000, odf_vertices=sphere.vertices, a_low=0.2) # FA uses a_low = 0.2 streamlines = [line for line in eu] print('Number of streamlines %i' % len(streamlines)) ''' for line in streamlines: print(line.shape) ''' # Do steamline clustering using QuickBundles (QB) using Eular's Method # dist_thr (distance threshold) which affects number of clusters and their size # pts (number of points in each streamline) which will be used for downsampling before clustering # Default values : dist_thr = 4 & pts = 12 qb = QuickBundles(streamlines, dist_thr=20, pts=20) clusters = qb.clusters() print('Number of clusters %i' % qb.total_clusters) print('Cluster size', qb.clusters_sizes()) # Display streamlines ren = window.Renderer() ren.add(actor.streamtube(streamlines, window.colors.white)) window.show(ren) window.record(ren, out_path=filename + '_stream_lines_eu.png', size=(600, 600)) # Display centroids window.clear(ren) colormap = actor.create_colormap(np.arange(qb.total_clusters)) ren.add(actor.streamtube(streamlines, window.colors.white, opacity=0.1)) ren.add(actor.streamtube(qb.centroids, colormap, linewidth=0.5)) window.show(ren)
def main():
    """Measure a per-point scalar along tract centroids and plot it per group.

    Command-line entry point: ``%prog [options] <tract.vtp>``.

    Input and centroids:
    - Reads a VTK XML polydata tract file whose point data carries the scalar
      named by ``--scalar`` and a ``tid`` (subject id) array, plus an optional
      ``group`` array.
    - Centroids come either from a cardinal spline through ``--curvepoints``
      control points, or from QuickBundles.
    - Every point is labelled with a centroid position by one ``kmeans2``
      iteration seeded with the centroid points.

    Output:
    - Per-position statistics of each group are plotted with matplotlib.
    - Optionally the centroids are rendered in 3D with mayavi.

    Side effects:
    - Writes ``<filebase>_<SCALAR>_rawdata.csv`` with the rows of all
      centroids.
    - Saves the current matplotlib figure as ``<filebase>_<SCALAR>.pdf``.
    - Exits with status 2 when no tract file is given.
    """
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp>")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA",
                      help="Scalar to measure")
    parser.add_option("-n", "--num", dest="num", default=50, type='int',
                      help="Number of subdivisions along centroids")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False,
                      help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float',
                      help="Quickbundle distance threshold")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option('--yrange', dest='yrange')
    parser.add_option('--xrange', dest='xrange')
    parser.add_option('--reverse', dest='is_reverse', action='store_true', default=False,
                      help='Reverse the centroid measure stepping order')
    parser.add_option('--pairplot', dest='pairplot')
    parser.add_option('--noviz', dest='is_viz', action='store_false', default=True)
    parser.add_option('--hide-centroid', dest='show_centroid', action='store_false', default=True)
    parser.add_option('--config', dest='config')
    parser.add_option('--background', dest='bg_file', help='Background NIFTI image')
    parser.add_option('--annot', dest='annot')
    (options, args) = parser.parse_args()
    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    def as_object_array(items):
        # np.array() on a ragged list raises on NumPy >= 1.24, so build the
        # 1-D object array element by element.
        arr = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            arr[k] = item
        return arr

    name_mapping = None
    if options.config:
        config = GtsConfig(options.config, configure=False)
        name_mapping = config.group_labels

    annotations = None
    if options.annot:
        with open(options.annot, 'r') as fp:
            # Used below as a mapping of label -> [start, end].  safe_load avoids
            # arbitrary object construction.  Bare yaml.load(fp) is a TypeError
            # on PyYAML >= 6.
            annotations = yaml.safe_load(fp)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    # Point ids of every cell, in [[ids], [ids], ...] format.
    tract_ids = []
    for cell_idx in range(polydata.GetNumberOfCells()):
        pids = polydata.GetCell(cell_idx).GetPointIds()
        tract_ids.append([pids.GetId(p) for p in range(pids.GetNumberOfIds())])
    print('tracks: %d' % len(tract_ids))

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: %d' % len(verts))

    # Per-point arrays; an array absent from the file stays an empty list.
    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si)).astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si)).astype(int)

    has_groups = len(groups) > 0  # the 'group' array is optional
    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for pt_ids in tract_ids:
        # Indexing an ndarray with a list gathers all the listed points.
        streamlines.append(verts[pt_ids])
        stream_scalars.append(scalars[pt_ids])
        stream_pids.append(pt_ids)
        stream_sids.append(subjects[pt_ids])
        if has_groups:
            stream_groups.append(groups[pt_ids])

    streamlines = as_object_array(streamlines)
    stream_scalars = as_object_array(stream_scalars)
    stream_groups = as_object_array(stream_groups)
    stream_pids = as_object_array(stream_pids)
    stream_sids = as_object_array(stream_sids)

    clusters = None
    if options.curvepoints_file:
        # A user-defined spline replaces the QuickBundles centroids.  It has no
        # cluster assignment, so always measure from all streamlines.
        LOCAL_POINT_ASSIGN = False
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # A separate vtkCardinalSpline interpolates each of x, y and z.
        curve = [vtk.vtkCardinalSpline() for _ in range(3)]
        for spline in curve:
            spline.ClosedOff()
        for pi, point in enumerate(ctrlpoints):
            for axis, val in enumerate(point):
                curve[axis].AddPoint(pi, val)
        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)
        t = param_range[0]
        step = (param_range[1] - param_range[0]) / (QB_NPOINTS - 1.0)
        cpoints = []
        while t < param_range[1]:
            cpoints.append([spline.Evaluate(t) for spline in curve])
            t = t + step
        centroids = np.array([cpoints])
    else:
        # Use QuickBundles to find centroids.
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        centroids = qb.centroids
        clusters = qb.clusters()
        # Flip centroids so they all point in the same general direction as the
        # running sum of the (already oriented) mean orientations.
        avg_d = np.zeros(3)
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            if i == 0:
                avg_d = ori
            dotprod = ori.dot(avg_d)
            print('dotprod %s' % dotprod)
            if dotprod < 0:
                print('reverse %s' % dotprod)
                centroids[i] = line[::-1]
                ori *= -1
            avg_d += ori

    if options.is_reverse:
        for i, c in enumerate(centroids):
            centroids[i] = c[::-1]

    # Prepare the mayavi 3D viz.
    if options.is_viz:
        bg_val = 0.
        mfig = mlab.figure(bgcolor=(bg_val, bg_val, bg_val))
        scene = mlab.gcf().scene
        mfig.scene.render_window.aa_frames = 4
        mlab.draw()
        if options.bg_file:
            mrsrc, _ = getNiftiAsScalarField(options.bg_file)
            mlab.pipeline.image_plane_widget(mrsrc, opacity=0.5, plane_orientation='z_axes',
                                             slice_index=0, colormap='black-white',
                                             line_width=0, reset_zoom=False)

    pal = sns.color_palette("bright", len(centroids))

    frames = []
    for ci, cent in enumerate(centroids):
        print('---- centroid:')
        if LOCAL_POINT_ASSIGN:
            # Measure each centroid only from the streamlines QuickBundles assigned to it.
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # Measure each centroid from all streamlines.
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)

        cent_color = tuple(pal[ci])
        # One kmeans2 iteration seeded with the centroid points labels every
        # vertex with a position index along the centroid.
        _, labels = kmeans2(cent_verts, cent, iter=1)

        df = pd.DataFrame(data={'value': cent_scalars, 'position': labels,
                                'group': cent_groups, 'pid': cent_pids,
                                'sid': cent_sids})
        frames.append(df)

        UNIQ_GROUPS = df.group.unique()
        UNIQ_GROUPS.sort()
        grppal = sns.color_palette("Set2", len(UNIQ_GROUPS))
        print('# UNIQ GROUPS %s' % (UNIQ_GROUPS,))

        # Plot each group by position: profiles on top, p-value strip below.
        plt.figure(figsize=(14, 7))
        ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)
        ax2 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
        plt.xlabel('Position Index')
        if len(centroids) > 1:
            cent_patch = mpatches.Patch(color=cent_color, label='Centroid {}'.format(ci + 1))
            cent_legend = ax1.legend(handles=[cent_patch], loc=9)
            ax1.add_artist(cent_legend)

        # Per-position statistics between groups.
        if len(UNIQ_GROUPS) > 1:
            pvalsDf = position_stats(df, name_mapping=name_mapping)
            logpvals = np.log(pvalsDf) * -1
            # Only show -log(p) where p < 0.05.
            pvals = logpvals.mask(pvalsDf >= 0.05)
            print(pvals)
            cmap = mcmap.Reds
            cmap.set_bad('w', 1.)
            ax2.pcolormesh(pvals.values.T, cmap=cmap, vmin=0, vmax=10,
                           edgecolors='face', alpha=0.8)
            ax2.set_yticks(np.arange(pvals.values.shape[1]) + 0.5, minor=False)
            ax2.set_yticklabels(pvalsDf.columns.values.tolist(), minor=False)

        cur_axe = ax1
        cent_median_scalar = []  # never reuse the previous centroid's medians
        legend_handles = []
        for gi, GRP in enumerate(UNIQ_GROUPS):
            print('%s %s %s' % ('-------------------- GROUP ', gi, '----------------------'))
            subgrp = df[df['group'] == GRP]
            print(len(subgrp))
            if options.xrange:
                x0, x1 = [int(v) for v in options.xrange.split(',')]
                subgrp = subgrp[(subgrp['position'] >= x0) & (subgrp['position'] < x1)]
            posGrp = subgrp.groupby('position', sort=True)
            cent_stats = posGrp.apply(stats_per_group)
            if len(cent_stats) == 0:
                continue
            cent_stats = cent_stats.unstack()
            cent_median_scalar = cent_stats['median'].tolist()
            x = np.array(list(posGrp.groups))
            mcolor = tuple(grppal[gi])

            cur_axe.set_ylabel(SCALAR_NAME)
            # Shade the smoothed 'ci' band (first/second element per position).
            qtile_top = np.array([s[0] for s in cent_stats['ci'].tolist()])
            qtile_bottom = np.array([s[1] for s in cent_stats['ci'].tolist()])
            x_new, qtop_sm = smooth(x, qtile_top)
            x_new, qbottom_sm = smooth(x, qtile_bottom)
            cur_axe.fill_between(x_new, qtop_sm, qbottom_sm, alpha=0.25, color=mcolor)
            x_new, median_sm = smooth(x, cent_stats['median'])
            hnd, = cur_axe.plot(x_new, median_sm, c=mcolor)
            legend_handles.append(hnd)
            if options.yrange:
                plotrange = options.yrange.split(',')
                cur_axe.set_ylim([float(plotrange[0]), float(plotrange[1])])

        legend_labels = UNIQ_GROUPS
        if name_mapping is not None:
            legend_labels = [name_mapping[str(g)] for g in UNIQ_GROUPS]
        cur_axe.legend(legend_handles, legend_labels)
        if annotations:
            for key, val in annotations.items():
                cur_axe.axvspan(val[0], val[1], fill=False, linestyle='dashed')
                # Put the label at the span start (data x), just above the axes (axes y).
                axis_to_data = cur_axe.transAxes + cur_axe.transData.inverted()
                data_to_axis = axis_to_data.inverted()
                axpoint = data_to_axis.transform((val[0], 0))
                cur_axe.text(axpoint[0], 1.02, key, transform=cur_axe.transAxes)

        # 3D viz of this centroid.
        if options.is_viz:
            scene.disable_render = True
            mlab.points3d(cent_verts[:, 0], cent_verts[:, 1], cent_verts[:, 2], labels,
                          opacity=0.3, scale_mode='none', scale_factor=2, line_width=2,
                          colormap='blue-red', mode='point')
            # Zero-pad the medians so there is one value per centroid point.
            delta = len(cent) - len(cent_median_scalar)
            if delta > 0:
                cent_median_scalar = np.pad(cent_median_scalar, (0, delta),
                                            mode='constant', constant_values=0)
            # Displacement from each centroid point to the next (zero for the last).
            uvw = cent - np.roll(cent, 1, axis=0)
            uvw[0] *= 0
            uvw = np.roll(uvw, -1, axis=0)
            arrow_plot = mlab.quiver3d(cent[:, 0], cent[:, 1], cent[:, 2],
                                       uvw[:, 0], uvw[:, 1], uvw[:, 2],
                                       scalars=cent_median_scalar, scale_factor=1,
                                       mode='arrow')
            gsource = arrow_plot.glyph.glyph_source.glyph_source
            arrow_plot.glyph.color_mode = 'color_by_scalar'
            gsource.tip_length = 0.4
            gsource.shaft_radius = 0.2
            gsource.tip_radius = 0.3
            if options.show_centroid:
                tube_plot = mlab.plot3d(cent[:, 0], cent[:, 1], cent[:, 2], cent_median_scalar,
                                        color=cent_color, tube_radius=0.2, opacity=0.25)
                tube_filter = tube_plot.parent.parent.filter
                tube_filter.vary_radius = 'vary_radius_by_scalar'
                tube_filter.radius_factor = 10
            # Label every 10th position index, plus the last one (no duplicate).
            for p in list(range(0, len(cent) - 1, 10)) + [len(cent) - 1]:
                pos = cent[p]
                mlab.text3d(pos[0], pos[1], pos[2], str(p), scale=0.8)
            scene.disable_render = False

    # pd.concat returns a new frame; collect once so every centroid is saved.
    DATADF = pd.concat(frames, ignore_index=True)
    DATADF.to_csv('_'.join([filebase, SCALAR_NAME, 'rawdata.csv']), index=False)
    outfile = '_'.join([filebase, SCALAR_NAME])
    print('save to {}'.format(outfile))
    plt.savefig('{}.pdf'.format(outfile), dpi=300)
    if options.is_viz:
        plt.show(block=False)
        mlab.show()
def main():
    """Export per-point scalar values labelled by centroid position.

    Command-line entry point: ``%prog [options] <tract.vtp> [<output>]``.

    Input and centroids:
    - Reads a VTK XML polydata tract file whose point data carries the scalar
      named by ``--scalar`` and a ``tid`` (subject id) array, plus an optional
      ``group`` array.
    - Centroids come from a cardinal spline through ``--curvepoints`` control
      points, or from QuickBundles.
    - Each point is labelled with a centroid position by one ``kmeans2``
      iteration seeded with the centroid points.

    Output:
    - One row per point per centroid (value, position, group, pid, sid,
      centroid).
    - Written to ``<output>`` when given, else to
      ``<filebase>_<SCALAR>_rawdata.csv``.
    - ``.csv`` is written with ``to_csv``; ``.xls``/``.xlsx`` with
      ``to_excel``.

    Exit status:
    - Any other output extension is reported on stderr and exits with
      status 1.
    - Exits with status 2 when no tract file is given.
    """
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp> <output.csv>")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float',
                      help="Quickbundle distance threshold")
    parser.add_option("-n", "--num", dest="num", default=50, type='int',
                      help="Number of subdivisions along centroids")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA",
                      help="Scalar to measure")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False,
                      help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option('--reverse', dest='is_reverse', action='store_true', default=False,
                      help='Reverse the centroid measure stepping order')
    (options, args) = parser.parse_args()
    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    def as_object_array(items):
        # np.array() on a ragged list raises on NumPy >= 1.24, so build the
        # 1-D object array element by element.
        arr = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            arr[k] = item
        return arr

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    # Point ids of every cell, in [[ids], [ids], ...] format.
    tract_ids = []
    for cell_idx in range(polydata.GetNumberOfCells()):
        pids = polydata.GetCell(cell_idx).GetPointIds()
        tract_ids.append([pids.GetId(p) for p in range(pids.GetNumberOfIds())])
    print('tracks: %d' % len(tract_ids))

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: %d' % len(verts))

    # Per-point arrays; an array absent from the file stays an empty list.
    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si)).astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si)).astype(int)

    has_groups = len(groups) > 0  # the 'group' array is optional
    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for pt_ids in tract_ids:
        # Indexing an ndarray with a list gathers all the listed points.
        streamlines.append(verts[pt_ids])
        stream_scalars.append(scalars[pt_ids])
        stream_pids.append(pt_ids)
        stream_sids.append(subjects[pt_ids])
        if has_groups:
            stream_groups.append(groups[pt_ids])

    streamlines = as_object_array(streamlines)
    stream_scalars = as_object_array(stream_scalars)
    stream_groups = as_object_array(stream_groups)
    stream_pids = as_object_array(stream_pids)
    stream_sids = as_object_array(stream_sids)

    clusters = None
    if options.curvepoints_file:
        # A user-defined spline replaces the QuickBundles centroids.  It has no
        # cluster assignment, so always measure from all streamlines.
        LOCAL_POINT_ASSIGN = False
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # A separate vtkCardinalSpline interpolates each of x, y and z.
        curve = [vtk.vtkCardinalSpline() for _ in range(3)]
        for spline in curve:
            spline.ClosedOff()
        for pi, point in enumerate(ctrlpoints):
            for axis, val in enumerate(point):
                curve[axis].AddPoint(pi, val)
        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)
        t = param_range[0]
        step = (param_range[1] - param_range[0]) / (QB_NPOINTS)
        cpoints = []
        while t < param_range[1]:
            cpoints.append([spline.Evaluate(t) for spline in curve])
            t = t + step
        centroids = np.array([cpoints])
    else:
        # Use QuickBundles to find centroids.
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        centroids = qb.centroids
        clusters = qb.clusters()
        # Flip centroids so they all point in the same general direction as the
        # running sum of the (already oriented) mean orientations.
        avg_d = np.zeros(3)
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            if i == 0:
                avg_d = ori
            dotprod = ori.dot(avg_d)
            print('dotprod %s' % dotprod)
            if dotprod < 0:
                print('reverse %s' % dotprod)
                centroids[i] = line[::-1]
                ori *= -1
            avg_d += ori

    if options.is_reverse:
        for i, c in enumerate(centroids):
            centroids[i] = c[::-1]

    frames = []
    for ci, cent in enumerate(centroids):
        print('---- centroid:')
        if LOCAL_POINT_ASSIGN:
            # Measure each centroid only from the streamlines QuickBundles assigned to it.
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # Measure each centroid from all streamlines.
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)

        # One kmeans2 iteration seeded with the centroid points labels every
        # vertex with a position index along the centroid.
        _, labels = kmeans2(cent_verts, cent, iter=1)
        frames.append(pd.DataFrame(data={'value': cent_scalars, 'position': labels,
                                         'group': cent_groups, 'pid': cent_pids,
                                         'sid': cent_sids, 'centroid': ci}))

    # pd.concat returns a new frame; collect once so every centroid is saved.
    DATADF = pd.concat(frames, ignore_index=True)

    outfilename = '_'.join([filebase, SCALAR_NAME, 'rawdata.csv'])
    if len(args) > 1:
        outfilename = args[1]
    if outfilename.endswith('.csv'):
        DATADF.to_csv(outfilename, index=False)
    elif outfilename.endswith(('.xls', '.xlsx')):
        # NOTE: pandas >= 2.0 dropped the xlwt writer, so legacy .xls output may
        # fail there; .xlsx uses openpyxl.
        DATADF.to_excel(outfilename, index=False)
    else:
        sys.stderr.write('Unsupported output extension (expected .csv, .xls or .xlsx): {}\n'.format(outfilename))
        sys.exit(1)
# NOTE(review): the first three statements appear to be the body of a loop over
# the polydata cells that binds ``i`` (its header is not visible here).
# ``polydata``, ``streamlines``, ``np``, ``QuickBundles`` and ``fvtk`` are also
# defined elsewhere.
pts = polydata.GetCell(i).GetPoints()
# ``j`` (not ``i``): a Python 2 comprehension variable leaks and would clobber the loop index.
npts = np.array([pts.GetPoint(j) for j in range(pts.GetNumberOfPoints())])
streamlines.append(npts)
# ``import scipy`` alone does not guarantee the ``scipy.io`` submodule is loaded.
import scipy.io
# savemat writes a 1-D object ndarray as a MATLAB cell array.  Building it
# explicitly avoids np.asarray failing on streamlines with different point
# counts, or silently stacking equal-length ones into a numeric 3-D array.
streamlines_cell = np.empty(len(streamlines), dtype=object)
for k, sl in enumerate(streamlines):
    streamlines_cell[k] = sl
scipy.io.savemat("1.mat", {'streamlines': streamlines_cell})
# Each streamline is saved as (n_points, 3); AFQ in MATLAB needs each transposed:
# @MATLAB: ts = cellfun(@transpose,streamlines,'UniformOutput',false)
qb = QuickBundles(streamlines, dist_thr=10., pts=20)
centroids = qb.centroids
clusters = qb.clusters()
# One random RGB colour per centroid.
colormap = np.random.rand(len(centroids), 3)
ren = fvtk.ren()
cam = fvtk.camera(ren, pos=(0, 0, -1), viewup=(0, 1, 0))
fvtk.clear(ren)
def terminus2hemi_surface_density_map(streamlines, geo_path, hemi, radius=5):
    """Map streamline endpoint density onto a hemisphere surface.

    The streamlines are reordered with ``_sort_streamlines`` and grouped with
    ``QuickBundles(streamlines, 10, 12)``.  For every surface vertex, the
    function counts the streamline termini lying strictly closer than
    ``radius``, summed over all clusters.  The first point of each streamline
    is used when ``hemi == 'lh'``; the last point otherwise.  The original
    note says "streamlines > 1000", presumably the intended input size.

    Parameters
    ----------
    streamlines : streamline data
        Must support indexing with a list of indices (e.g. an object ndarray)
        after ``_sort_streamlines``.
    geo_path : str
        Surface path: a FreeSurfer geometry (``.white``, ``.inflated``,
        ``.pial``) or a GIfTI file (``.gii``).
    hemi : str
        ``'lh'`` uses the first point of each streamline; anything else uses
        the last point.
    radius : float, optional
        Distance threshold, in surface coordinate units (default 5).

    Returns
    -------
    vert_value : ndarray of float, shape (n_vertices,)
        Endpoint count per surface vertex.  All zeros when there are no
        clusters.

    Raises
    ------
    ImageFileError
        If the surface file format is not supported.
    """
    # Load (and validate) the surface before the comparatively expensive clustering.
    suffix = os.path.split(geo_path)[1].split('.')[-1]
    if suffix in ('white', 'inflated', 'pial'):
        coords, faces = nib.freesurfer.read_geometry(geo_path)
    elif suffix == 'gii':
        gii_data = nib.load(geo_path).darrays
        coords, faces = gii_data[0].data, gii_data[1].data
    else:
        raise ImageFileError(
            'This file format-{} is not supported at present.'.format(suffix))

    streamlines = _sort_streamlines(streamlines)
    bundles = QuickBundles(streamlines, 10, 12)
    clusters = bundles.clusters()

    endpoint = 0 if hemi == 'lh' else -1
    vert_value = np.zeros(len(coords))
    for cluster in clusters.values():
        data = streamlines[cluster['indices']]
        stream_terminus = np.array([s[endpoint] for s in data])
        # Row v holds the distances from vertex v to every terminus of the cluster.
        dist = cdist(coords, stream_terminus)
        vert_value += (dist < radius).sum(axis=1)
    return vert_value