Example #1
    def bundle_thre_seg(self,
                        streamlines=None,
                        cluster_thre=10,
                        dist_thre=10.0,
                        pts=12):
        """
        QuickBundles-based segmentation
        Parameters
        ----------
        streamlines: streamline data
        cluster_thre: remove small cluster
        dist_thre: clustering threshold (distance mm)
        pts: each streamlines are divided into sections

        Return
        ------
        sort_index: sort of clusters according to y mean of cluster's centroids
        data_cluster: cluster data corresponding to sort_index
        """
        if streamlines is None:
            streamlines = self._fasciculus.get_data()
        bundles = QuickBundles(streamlines, dist_thre, pts)
        bundles.remove_small_clusters(cluster_thre)
        clusters = bundles.clusters()
        data_clusters = []
        for key in clusters.keys():
            data_clusters.append(streamlines[clusters[key]['indices']])
        centroids = bundles.centroids
        clusters_y_mean = [clu[:, 1].mean() for clu in centroids]
        sort_index = np.argsort(clusters_y_mean)

        return sort_index, data_clusters
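
A minimal usage sketch for the method above (hedged: it assumes the legacy dipy.segment.quickbundles.QuickBundles API used throughout this page, and `fasc` is a hypothetical object that defines bundle_thre_seg):

# Hedged usage sketch; `fasc` stands in for whatever object defines bundle_thre_seg.
sort_index, data_clusters = fasc.bundle_thre_seg(cluster_thre=10, dist_thre=10.0, pts=12)
for rank, idx in enumerate(sort_index):
    print('rank %d: cluster %d, %d streamlines' % (rank, idx, len(data_clusters[idx])))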
Example #2
    def bundle_seg(self, streamlines=None, dist_thre=10.0, pts=12):
        """
        QuickBundles-based segmentation
        Parameters
        ----------
        streamlines: streamline data
        dist_thre: clustering threshold (distance mm)
        pts: each streamlines are divided into sections

        Return
        ------
        labels: label of each streamline
        data_cluster: cluster data
        N_list: size of each cluster
        """
        if streamlines is None:
            streamlines = self._fasciculus.get_data()
        bundles = QuickBundles(streamlines, dist_thre, pts)
        clusters = bundles.clusters()
        labels = np.array(len(streamlines) * [None])
        N_list = []
        for i in range(len(clusters)):
            N_list.append(clusters[i]['N'])
        # show(N_list, title='N histogram', xlabel='N')
        data_clusters = []
        for i in range(len(clusters)):
            labels[clusters[i]['indices']] = i + 1
            data_clusters.append(streamlines[clusters[i]['indices']])

        return labels, data_clusters, N_list
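
A short companion sketch (same assumptions as above: the legacy QuickBundles API and a hypothetical `fasc` object) showing one way the returned values might be used, here to pull out the largest cluster:

labels, data_clusters, N_list = fasc.bundle_seg(dist_thre=10.0, pts=12)
largest = int(np.argmax(N_list))          # position of the most populated cluster
largest_bundle = data_clusters[largest]   # its streamline data
print('largest cluster: %d streamlines' % N_list[largest])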
Example #3
def QB_reps(limits=[0, np.Inf], reps=1):

    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]

    sizes = []

    for id in range(len(test_tractography_sizes)):
        filename = "subj_" + ids[id] + "_QB_reps.pkl"
        f = open(filename, "wb")
        ur_tracks = get_tracks(id, limits)
        res = {}
        res["filtered"] = len(ur_tracks)
        res["qb_threshold"] = qb_threshold
        res["limits"] = limits
        # res['ur_tracks'] = ur_tracks
        print "Dataset", id, res["filtered"], "filtered tracks"
        res["shuffle"] = {}
        res["clusters"] = {}
        res["nclusters"] = {}
        res["centroids"] = {}
        res["cluster_sizes"] = {}
        for i in range(reps):
            print "Subject", ids[id], "shuffle", i
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"][i] = shuffle
            tracks = [ur_tracks[j] for j in shuffle]
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            res["clusters"][i] = {}
            for k in qb.clusters().keys():
                # this would be improved if we 'enumerated' QB's keys
                # and used the enumerator as the key in the result
                res["clusters"][i][k] = qb.clusters()[k]["indices"]
            res["centroids"][i] = qb.centroids
            res["nclusters"][i] = qb.total_clusters
            res["cluster_sizes"][i] = qb.clusters_sizes()
            print "QB for has", qb.total_clusters, "clusters"
            sizes.append(qb.total_clusters)
        pickle.dump(res, f)
        f.close()
        print "Results written to", filename
Example #4
def QB_reps_singly(limits=[0, np.Inf], reps=1):

    ids = ["02", "03", "04", "05", "06", "08", "09", "10", "11", "12"]
    replabs = [str(i) for i in range(reps)]
    sizes = []

    for id in range(len(test_tractography_sizes)):
        ur_tracks = get_tracks(id, limits)
        for i in range(reps):
            res = {}
            # res['filtered'] = len(ur_tracks)
            res["qb_threshold"] = qb_threshold
            res["limits"] = limits
            res["shuffle"] = {}
            res["clusters"] = {}
            res["nclusters"] = {}
            res["centroids"] = {}
            res["cluster_sizes"] = {}
            print "Subject", ids[id], "shuffle", i
            shuffle = np.random.permutation(np.arange(len(ur_tracks)))
            res["shuffle"] = shuffle
            tracks = [ur_tracks[j] for j in shuffle]
            print "... starting QB"
            qb = QuickBundles(tracks, qb_threshold, downsampling)
            print "... finished QB"
            res["clusters"] = {}
            for k in qb.clusters().keys():
                # this would be improved if we 'enumerated' QB's keys
                # and used the enumerator as the key in the result
                res["clusters"][k] = qb.clusters()[k]["indices"]
            res["centroids"] = qb.centroids
            res["nclusters"] = qb.total_clusters
            res["cluster_sizes"] = qb.clusters_sizes()
            print "QB for has", qb.total_clusters, "clusters"
            filename = "subj_" + ids[id] + "_QB_rep_" + replabs[i] + ".pkl"
            f = open(filename, "w")
            pickle.dump(res, f)
            f.close()
        print "Results written to", filename
    return sizes
Example #5

# Compute EuDX (Euler Delta Crossings) tracking with FA
eu = EuDX(a=fa, ind=ind, seeds=100000, odf_vertices=sphere.vertices,
          a_low=0.2)  # FA uses a_low = 0.2
streamlines = [line for line in eu]
print('Number of streamlines %i' % len(streamlines))
'''
for line in streamlines:
  print(line.shape)
'''
# Cluster the streamlines from the EuDX run with QuickBundles (QB).
# dist_thr (distance threshold) affects the number of clusters and their size.
# pts (points per streamline) is used for downsampling before clustering.
# Default values: dist_thr = 4 and pts = 12
qb = QuickBundles(streamlines, dist_thr=20, pts=20)
clusters = qb.clusters()
print('Number of clusters %i' % qb.total_clusters)
print('Cluster size', qb.clusters_sizes())
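
As a hedged aside, the streamlines of any one cluster can be recovered the same way the earlier examples do:

# Sketch: pick the most populated cluster and collect its streamlines,
# using the legacy dict-based clusters() API shown throughout this page.
biggest = max(clusters, key=lambda k: clusters[k]['N'])
big_bundle = [streamlines[idx] for idx in clusters[biggest]['indices']]
print('Largest cluster has %d streamlines' % clusters[biggest]['N'])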

# Display streamlines
ren = window.Renderer()
ren.add(actor.streamtube(streamlines, window.colors.white))
window.show(ren)
window.record(ren, out_path=filename + '_stream_lines_eu.png', size=(600, 600))

# Display centroids
window.clear(ren)
colormap = actor.create_colormap(np.arange(qb.total_clusters))
ren.add(actor.streamtube(streamlines, window.colors.white, opacity=0.1))
ren.add(actor.streamtube(qb.centroids, colormap, linewidth=0.5))
window.show(ren)
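
Optionally, the centroid view can be saved as well, reusing the window.record call already present above (`filename` is the same variable the snippet assumes):

# Save the centroid rendering alongside the streamline snapshot.
window.record(ren, out_path=filename + '_centroids_eu.png', size=(600, 600))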
Example #6
def main():
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp>")
    parser.add_option("-s", "--scalar", dest="scalar",
                      default="FA", help="Scalar to measure")
    parser.add_option("-n", "--num", dest="num", default=50,
                      type='int', help="Number of subdivisions along centroids")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False,
                      help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option("-d", "--dist", dest="dist", default=20,
                      type='float', help="Quickbundle distance threshold")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option('--yrange', dest='yrange')
    parser.add_option('--xrange', dest='xrange')
    parser.add_option('--reverse', dest='is_reverse', action='store_true',
                      default=False, help='Reverse the centroid measure stepping order')
    parser.add_option('--pairplot', dest='pairplot',)
    parser.add_option('--noviz', dest='is_viz',
                      action='store_false', default=True)
    parser.add_option('--hide-centroid', dest='show_centroid',
                      action='store_false', default=True)
    parser.add_option('--config', dest='config')
    parser.add_option('--background', dest='bg_file',
                      help='Background NIFTI image')
    parser.add_option('--annot', dest='annot')

    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    name_mapping = None
    if options.config:
        config = GtsConfig(options.config, configure=False)
        name_mapping = config.group_labels

    annotations = None
    if options.annot:
        with open(options.annot, 'r') as fp:
            annotations = yaml.safe_load(fp)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()

    polydata = reader.GetOutput()

    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        # get point ids in [[ids][ids...]...] format
        pids = polydata.GetCell(i).GetPointIds()
        ids = [pids.GetId(p) for p in range(pids.GetNumberOfIds())]
        tract_ids.append(ids)
    print 'tracks:', len(tract_ids)

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print 'verts:', len(verts)

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print sname
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si))
            groups = groups.astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si))
            subjects = subjects.astype(int)

    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []

    for i in tract_ids:
        # indexing a np.array with a list gathers all the corresponding rows
        streamlines.append(verts[i])
        stream_scalars.append(scalars[i])
        stream_pids.append(i)
        stream_sids.append(subjects[i])
        try:
            stream_groups.append(groups[i])
        except Exception:
            # group might not exist
            pass

    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    # get total average direction (where majority point towards)
    avg_d = np.zeros(3)
    # for line in streams:
    #     d = np.array(line[-1]) - np.array(line[0])
    #     d = d / la.norm(d)
    #     avg_d += d
    #     avg_d /= la.norm(avg_d)

    avg_com = np.zeros(3)
    avg_mid = np.zeros(3)

    strl_len = [len(l) for l in streamlines]
    stl_ori = np.array([np.abs(tm.mean_orientation(l)) for l in streamlines])

    centroids = []
    if options.curvepoints_file:
        LOCAL_POINT_ASSIGN = False
        cpoints = []
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # use a separate vtkCardinalSpline interpolator for each of x, y, z
        curve = [vtk.vtkCardinalSpline() for i in range(3)]
        for c in curve:
            c.ClosedOff()

        for pi, point in enumerate(ctrlpoints):
            for i, val in enumerate(point):
                curve[i].AddPoint(pi, val)

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)

        t = param_range[0]
        step = (param_range[1] - param_range[0]) / (QB_NPOINTS - 1.0)

        while t < param_range[1]:
            cp = [c.Evaluate(t) for c in curve]
            cpoints.append(cp)
            t = t + step

        centroids.append(cpoints)
        centroids = np.array(centroids)

    else:
        """
            Use quickbundles to find centroids
        """
        # streamlines = newlines
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        # bundle_distance_mam

        centroids = qb.centroids
        clusters = qb.clusters()

        avg_d = np.zeros(3)
        avg_com = np.zeros(3)
        avg_mid = np.zeros(3)

        # unify centroid list orders to point in the same general direction
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            #d = np.array(line[-1]) - np.array(line[0])
            #print line[-1],line[0],d
            # get the unit vector of the mean orientation
            if i == 0:
                avg_d = ori

            #d = d / la.norm(d)
            dotprod = ori.dot(avg_d)
            print 'dotprod', dotprod
            if dotprod < 0:
                print 'reverse', dotprod
                centroids[i] = line[::-1]
                line = centroids[i]
                ori *= -1
            avg_d += ori

        if options.is_reverse:
            for i, c in enumerate(centroids):
                centroids[i] = c[::-1]

    # prepare mayavi 3d viz

    if options.is_viz:
        bg_val = 0.
        fig = mlab.figure(bgcolor=(bg_val, bg_val, bg_val))
        scene = mlab.gcf().scene
        fig.scene.render_window.aa_frames = 4
        mlab.draw()

        if options.bg_file:
            mrsrc, bgdata = getNiftiAsScalarField(options.bg_file)
            orie = 'z_axes'

            opacity = 0.5
            slice_index = 0

            mlab.pipeline.image_plane_widget(mrsrc, opacity=opacity, plane_orientation=orie, slice_index=int(
                slice_index), colormap='black-white', line_width=0, reset_zoom=False)

    # prepare the plt plot
    len_cent = len(centroids)
    pal = sns.color_palette("bright", len_cent)

    DATADF = None

    """
        CENTROIDS
    """
    for ci, cent in enumerate(centroids):
        print '---- centroid:'

        if LOCAL_POINT_ASSIGN:
            """
                apply centroid to only their point assignments
                through quickbundles
            """
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centroid to all the points
            # instead of only its QuickBundles cluster assignments
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids

        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)
        cent_color = np.array(pal[ci])

        c, labels = kmeans2(cent_verts, cent, iter=1)

        cid = np.ones(len(labels))
        d = {'value': cent_scalars, 'position': labels,
             'group': cent_groups, 'pid': cent_pids, 'sid': cent_sids}

        df = pd.DataFrame(data=d)
        if DATADF is None:
            DATADF = df
        else:
            DATADF = pd.concat([DATADF, df])

        UNIQ_GROUPS = df.group.unique()
        UNIQ_GROUPS.sort()

        # UNIQ_GROUPS = [0,1]

        grppal = sns.color_palette("Set2", len(UNIQ_GROUPS))

        print '# UNIQ GROUPS', UNIQ_GROUPS

        # print df
        # df = df[df['sid'] != 15]
        # df = df[df['sid'] != 16]
        # df = df[df['sid'] != 17]
        # df = df[df['sid'] != 18]
        """ 
            plot each group by their position 
        """

        fig = plt.figure(figsize=(14, 7))
        ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)
        ax2 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
        axes = [ax1, ax2]

        plt.xlabel('Position Index')

        if len(centroids) > 1:
            cent_patch = mpatches.Patch(
                color=cent_color, label='Centroid {}'.format(ci + 1))
            cent_legend = axes[0].legend(handles=[cent_patch], loc=9)
            axes[0].add_artist(cent_legend)

        """
            Perform stats
        """

        if len(UNIQ_GROUPS) > 1:
            # df = resample_data(df, num_sample_per_pos=120)
            # print df
            pvalsDf = position_stats(df, name_mapping=name_mapping)
            logpvals = np.log(pvalsDf) * -1
            # print logpvals

            pvals = logpvals.mask(pvalsDf >= 0.05)

            import matplotlib.ticker as mticker
            print pvals
            cmap = mcmap.Reds
            cmap.set_bad('w', 1.)
            axes[1].pcolormesh(pvals.values.T, cmap=cmap,
                               vmin=0, vmax=10, edgecolors='face', alpha=0.8)
            #axes[1].yaxis.set_major_locator(mticker.MultipleLocator(base=1.0))
            axes[1].set_yticks(
                np.arange(pvals.values.shape[1]) + 0.5, minor=False)
            axes[1].set_yticklabels(
                pvalsDf.columns.values.tolist(), minor=False)

        legend_handles = []
        for gi, GRP in enumerate(UNIQ_GROUPS):
            print '-------------------- GROUP ', gi, '----------------------'
            subgrp = df[df['group'] == GRP]
            print len(subgrp)

            if options.xrange:
                x0, x1 = options.xrange.split(',')
                x0 = int(x0)
                x1 = int(x1)
                subgrp = subgrp[(subgrp['position'] >= x0) &
                                (subgrp['position'] < x1)]

            posGrp = subgrp.groupby('position', sort=True)

            cent_stats = posGrp.apply(lambda x: stats_per_group(x))

            if len(cent_stats) == 0:
                continue

            cent_stats = cent_stats.unstack()
            cent_median_scalar = cent_stats['median'].tolist()

            x = np.array([i for i in posGrp.groups])
            # print x

            # print cent_stats['median'].tolist()
            mcolor = np.array(grppal[gi])
            # if gi>0:
            #     mcolor*= 1./(1+gi)

            cent_color = tuple(cent_color)
            mcolor = tuple(mcolor)

            if type(axes) is list:
                cur_axe = axes[0]
            else:
                cur_axe = axes

            cur_axe.set_ylabel(SCALAR_NAME)
            # cur_axe.yaxis.label.set_color(cent_color)
            # cur_axe.tick_params(axis='y', colors=cent_color)

            #cur_axe.fill_between(x, [s[0] for s in cent_ci], [t[1] for t in cent_ci], alpha=0.3, color=mcolor)

            # cur_axe.fill_between(x, [s[0] for s in cent_stats['whisk'].tolist()],
            #     [t[1] for t in cent_stats['whisk'].tolist()], alpha=0.1, color=mcolor)

            qtile_top = np.array([s[0] for s in cent_stats['ci'].tolist()])
            qtile_bottom = np.array([t[1] for t in cent_stats['ci'].tolist()])

            x_new, qtop_sm = smooth(x, qtile_top)
            x_new, qbottom_sm = smooth(x, qtile_bottom)
            cur_axe.fill_between(x_new, qtop_sm, qbottom_sm,
                                 alpha=0.25, color=mcolor)

            # cur_axe.errorbar(x, cent_stats['median'].tolist(), yerr=[[s[0] for s in cent_stats['err'].tolist()],
            #     [t[1] for t in cent_stats['err'].tolist()]], color=mcolor, alpha=0.1)

            x_new, median_sm = smooth(x, cent_stats['median'])
            hnd, = cur_axe.plot(x_new, median_sm, c=mcolor)
            legend_handles.append(hnd)

            # cur_axe.scatter(x,cent_stats['median'].tolist(), c=mcolor)

            if options.yrange:
                plotrange = options.yrange.split(',')
                cur_axe.set_ylim([float(plotrange[0]), float(plotrange[1])])

        legend_labels = UNIQ_GROUPS
        if name_mapping is not None:
            legend_labels = [name_mapping[str(i)] for i in UNIQ_GROUPS]
        cur_axe.legend(legend_handles, legend_labels)

        if annotations:
            for key, val in annotations.iteritems():
                # print key
                cur_axe.axvspan(val[0], val[1], fill=False, linestyle='dashed')
                axis_to_data = cur_axe.transAxes + cur_axe.transData.inverted()
                data_to_axis = axis_to_data.inverted()
                axpoint = data_to_axis.transform((val[0], 0))
                # print axpoint
                cur_axe.text(axpoint[0], 1.02, key,
                             transform=cur_axe.transAxes)

        """
            Plot 3D Viz 
        """

        if options.is_viz:
            scene.disable_render = True
            # scene.renderer.render_window.set(alpha_bit_planes=1,multi_samples=0)
            # scene.renderer.set(use_depth_peeling=True,maximum_number_of_peels=4,occlusion_ratio=0.1)
            # ran_colors = np.random.random_integers(255, size=(len(cent),4))
            # ran_colors[:,-1] = 255
            mypts = mlab.points3d(cent_verts[:, 0], cent_verts[:, 1], cent_verts[:, 2], labels,
                                  opacity=0.3,
                                  scale_mode='none',
                                  scale_factor=2,
                                  line_width=2,
                                  colormap='blue-red',
                                  mode='point')

            # print mypts.module_manager.scalar_lut_manager.lut.table.to_array()
            # mypts.module_manager.scalar_lut_manager.lut.table = ran_colors
            # mypts.module_manager.scalar_lut_manager.lut.number_of_colors = len(ran_colors)

            delta = len(cent) - len(cent_median_scalar)
            if delta > 0:
                cent_median_scalar = np.pad(
                    cent_median_scalar, (0, delta), mode='constant', constant_values=0)

            # displacement vector from each centroid point to the next
            # (forward differences; the final arrow gets zero length)
            uvw = cent - np.roll(cent, 1, axis=0)
            uvw[0] *= 0
            uvw = np.roll(uvw, -1, axis=0)
            arrow_plot = mlab.quiver3d(
                cent[:, 0], cent[:, 1], cent[:, 2],
                uvw[:, 0], uvw[:, 1], uvw[:, 2],
                scalars=cent_median_scalar,
                scale_factor=1,
                #color=mcolor,
                mode='arrow')

            gsource = arrow_plot.glyph.glyph_source.glyph_source

            # for name, thing in inspect.getmembers(gsource):
            #      print name

            arrow_plot.glyph.color_mode = 'color_by_scalar'
            #arrow_plot.glyph.scale_mode = 'scale_by_scalar'
            #arrow_plot.glyph.glyph.clamping = True
            #arrow_plot.glyph.glyph.scale_factor = 5
            #print arrow_plot.glyph.glyph.glyph_source
            gsource.tip_length = 0.4
            gsource.shaft_radius = 0.2
            gsource.tip_radius = 0.3

            if options.show_centroid:
                tube_plot = mlab.plot3d(cent[:, 0], cent[:, 1], cent[
                                        :, 2], cent_median_scalar, color=cent_color, tube_radius=0.2, opacity=0.25)
                tube_filter = tube_plot.parent.parent.filter
                tube_filter.vary_radius = 'vary_radius_by_scalar'
                tube_filter.radius_factor = 10

            # label every 10th position index along the centroid, plus the last
            def plot_pos_index(p):
                pos = cent[p]
                mlab.text3d(pos[0], pos[1], pos[2], str(p), scale=0.8)

            for p in xrange(0, len(cent) - 1, 10):
                plot_pos_index(p)
            plot_pos_index(len(cent) - 1)

            scene.disable_render = False

    DATADF.to_csv(
        '_'.join([filebase, SCALAR_NAME, 'rawdata.csv']), index=False)
    outfile = '_'.join([filebase, SCALAR_NAME])
    print 'save to {}'.format(outfile)
    plt.savefig('{}.pdf'.format(outfile), dpi=300)

    if options.is_viz:
        plt.show(block=False)
        mlab.show()
Example #7
def main():
    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp> <output.csv>")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float', help="Quickbundle distance threshold")
    parser.add_option("-n", "--num", dest="num", default=50, type='int', help="Number of subdivisions along centroids")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA", help="Scalar to measure")
    parser.add_option("--curvepoints", dest="curvepoints_file", help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False, help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option('--reverse', dest='is_reverse', action='store_true', default=False, help='Reverse the centroid measure stepping order')
    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(2)
        
    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local

    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()

    polydata = reader.GetOutput()


    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        # get point ids in [[ids][ids...]...] format
        pids = polydata.GetCell(i).GetPointIds()
        ids = [pids.GetId(p) for p in range(pids.GetNumberOfIds())]
        tract_ids.append(ids)
    print 'tracks:', len(tract_ids)

    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print 'verts:', len(verts)

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print sname
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si))
            groups = groups.astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si))
            subjects = subjects.astype(int)


    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []

    for i in tract_ids:
        # indexing a np.array with a list gathers all the corresponding rows
        streamlines.append(verts[i])
        stream_scalars.append(scalars[i])
        stream_groups.append(groups[i])
        stream_pids.append(i)
        stream_sids.append(subjects[i])

    streamlines = np.array(streamlines)
    stream_scalars = np.array(stream_scalars)
    stream_groups = np.array(stream_groups)
    stream_pids = np.array(stream_pids)
    stream_sids = np.array(stream_sids)

    # get total average direction (where majority point towards)
    avg_d = np.zeros(3)
    # for line in streams:
    #     d = np.array(line[-1]) - np.array(line[0])
    #     d = d / la.norm(d)
    #     avg_d += d
    #     avg_d /= la.norm(avg_d)

    avg_com = np.zeros(3)
    avg_mid = np.zeros(3)

    strl_len = [len(l) for l in streamlines]
    stl_ori = np.array([np.abs(tm.mean_orientation(l)) for l in streamlines])

    centroids = []
    if options.curvepoints_file: 
        LOCAL_POINT_ASSIGN = False
        cpoints = []
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # use a separate vtkCardinalSpline interpolator for each of x, y, z
        curve = [vtk.vtkCardinalSpline() for i in range(3)]
        for c in curve:
            c.ClosedOff()

        for pi, point in enumerate(ctrlpoints):
            for i, val in enumerate(point):
                curve[i].AddPoint(pi, val)

        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)

        t = param_range[0]
        step = (param_range[1] - param_range[0]) / QB_NPOINTS

        while t < param_range[1]:
            cp = [c.Evaluate(t) for c in curve]
            cpoints.append(cp)
            t = t + step

        centroids.append(cpoints)
        centroids = np.array(centroids)


    else:
        """
            Use quickbundles to find centroids
        """
        # streamlines = newlines
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        # bundle_distance_mam

        centroids = qb.centroids
        clusters = qb.clusters()

        avg_d = np.zeros(3)
        avg_com = np.zeros(3)
        avg_mid = np.zeros(3)

        # unify centroid list orders to point in the same general direction
        for i, line in enumerate(centroids):
            ori = np.array(tm.mean_orientation(line))
            # d = np.array(line[-1]) - np.array(line[0])
            # print line[-1], line[0], d
            # get the unit vector of the mean orientation
            if i == 0:
                avg_d = ori

            # d = d / la.norm(d)
            dotprod = ori.dot(avg_d)
            print 'dotprod', dotprod
            if dotprod < 0:
                print 'reverse', dotprod
                centroids[i] = line[::-1]
                line = centroids[i]
                ori *= -1
            avg_d += ori

        if options.is_reverse:
            for i, c in enumerate(centroids):
                centroids[i] = c[::-1]



    DATADF = None

    """
        CENTROIDS
    """
    for ci, cent in enumerate(centroids):
        print '---- centroid:'

        if LOCAL_POINT_ASSIGN:
            """
                apply centroid to only their point assignments
                through quickbundles
            """
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centroid to all the points
            # instead of only its QuickBundles cluster assignments
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids


        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)
        # cent_color = np.array(pal[ci])

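        # One k-means assignment pass seeded with the centroid's points:
        # `labels` gives each vertex the index of its nearest centroid subdivision.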
        c, labels = kmeans2(cent_verts, cent, iter=1)

        cid = np.ones(len(labels))
        d = {'value': cent_scalars, 'position': labels, 'group': cent_groups,
             'pid': cent_pids, 'sid': cent_sids, 'centroid': ci}

        df = pd.DataFrame(data=d)
        if DATADF is None:
            DATADF = df
        else:
            DATADF = pd.concat([DATADF, df])


    outfilename = '_'.join([filebase, SCALAR_NAME, 'rawdata.csv'])
    if len(args) > 1:
        outfilename = args[1]

    if outfilename.endswith('.csv'):
        DATADF.to_csv(outfilename, index=False)

    if outfilename.endswith('.xls'):
        DATADF.to_excel(outfilename, index=False)
Example #8
# (loop header reconstructed: iterate over all cells, as in the previous examples)
streamlines = []
for i in range(polydata.GetNumberOfCells()):
    pts = polydata.GetCell(i).GetPoints()
    npts = np.array([pts.GetPoint(j) for j in range(pts.GetNumberOfPoints())])
    streamlines.append(npts)

import scipy.io
scipy.io.savemat("1.mat", {'streamlines': streamlines})
# need to transpose each stream array for AFQ in MATLAB
# run with cmd
# @MATLAB: ts = cellfun(@transpose, streamlines, 'UniformOutput', false)

#print streamlines

qb = QuickBundles(streamlines, dist_thr=10., pts=20)

centroids = qb.centroids
clusters = qb.clusters()
colormap = np.random.rand(len(centroids), 3)


#print npts

#ren = vtk.vtkRenderer()
#renwin = vtk.vtkRenderWindow()
#renwin.AddRenderer(ren)

#c1 = clusters[0]
#print c1['hidden']
ren = fvtk.ren()
cam = fvtk.camera(ren, pos=(0, 0, -1), viewup=(0, 1, 0))

fvtk.clear(ren)
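
The snippet stops before drawing anything. A minimal continuation sketch, assuming the legacy dipy.viz.fvtk API the snippet already imports:

# Sketch: draw each cluster's centroid with its random color, then show.
for ci, centroid in enumerate(centroids):
    fvtk.add(ren, fvtk.line([centroid], colormap[ci], linewidth=3.))
fvtk.show(ren)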
Example #9
def terminus2hemi_surface_density_map(streamlines, geo_path, hemi):
    """
    Streamline endpoints areas mapping to hemisphere surface
    streamlines > 1000
    Parameters
    ----------
    streamlines: streamline data
    geo_path: surface data path

    Return
    ------
    endpoints areas map on surface
    """
    streamlines = _sort_streamlines(streamlines)
    bundles = QuickBundles(streamlines, 10, 12)
    # bundles.remove_small_clusters(10)
    clusters = bundles.clusters()
    data_clusters = []
    for key in clusters.keys():
        data_clusters.append(streamlines[clusters[key]['indices']])

    data0 = data_clusters[0]
    stream_terminus_lh0 = np.array([s[0] for s in data0])
    stream_terminus_rh0 = np.array([s[-1] for s in data0])

    suffix = os.path.split(geo_path)[1].split('.')[-1]
    if suffix in ('white', 'inflated', 'pial'):
        coords, faces = nib.freesurfer.read_geometry(geo_path)
    elif suffix == 'gii':
        gii_data = nib.load(geo_path).darrays
        coords, faces = gii_data[0].data, gii_data[1].data
    else:
        raise ImageFileError(
            'This file format-{} is not supported at present.'.format(suffix))
    if hemi == 'lh':
        dist_lh0 = cdist(coords, stream_terminus_lh0)
        vert_value = (dist_lh0 < 5).sum(axis=1).astype(float)
    else:
        dist_rh0 = cdist(coords, stream_terminus_rh0)
        vert_value = (dist_rh0 < 5).sum(axis=1).astype(float)

    # cluster 0 is already counted above; accumulate the remaining clusters
    for data in data_clusters[1:]:
        if hemi == 'lh':
            stream_terminus = np.array([s[0] for s in data])
        else:
            stream_terminus = np.array([s[-1] for s in data])

        dist = cdist(coords, stream_terminus)
        vert_value += (dist < 5).sum(axis=1)

    return vert_value
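
A hedged usage sketch: computing the map for the left hemisphere and saving it as a FreeSurfer curv-style overlay. nibabel's freesurfer.write_morph_data is assumed available, and the paths are placeholders:

import nibabel as nib

# Hypothetical inputs: `streamlines` already loaded, and a FreeSurfer white
# surface at a placeholder path.
density = terminus2hemi_surface_density_map(streamlines, '/path/to/lh.white', 'lh')
nib.freesurfer.write_morph_data('lh.terminus_density.curv', density)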