예제 #1
0
파일: cluster.py 프로젝트: felfranke/BOTMpy
    def plot(self, data, views=2, show=False, filename=None):
        """Plot the clustering result and the goodness-of-fit trace.

        The top row holds up to `views` 2-D projections of `data`, each
        using a consecutive pair of columns (0/1, 2/3, ...), with the rows
        grouped by ``self.labels``. The bottom panel draws ``self._gof`` as
        a step curve. Entries come in groups of ``self.repeats`` per cluster
        count in ``self.crange``, and the entry ``self._winner`` is shaded.

        :param data: observations, one row per sample (presumably PCA
            features given the 'PC %d' axis labels -- TODO confirm)
        :param views: maximum number of 2-D projections; capped at
            ``data.shape[1] // 2``
        :param show: if exactly ``True``, call ``plt.show()`` at the end
        :param filename: if not None, save the figure via
            ``spikeplot.save_figure``
        :returns: None if `spikeplot` is not importable, else True
        """

        # plotting is optional: without spikeplot available, do nothing
        try:
            from spikeplot import plt, cluster, save_figure
        except ImportError:
            return None

        # init
        # every view consumes two data columns, so cap views at the number
        # of available column pairs
        views = min(views, data.shape[1] // 2)
        fig = plt.figure()
        fig.suptitle('clustering [%s]' % self.clus_type)
        # top row: one axes per view; bottom half: goodness-of-fit panel
        ax = [fig.add_subplot(2, views, v + 1) for v in xrange(views)]
        axg = fig.add_subplot(212)
        ncmp = int(self.labels.max() + 1)
        # map each label 0..ncmp-1 to the rows of data carrying that label
        cdata = dict(
            zip(xrange(ncmp), [data[self.labels == c] for c in xrange(ncmp)]))

        # plot clustering
        for v in xrange(views):
            cluster(cdata,
                    data_dim=(2 * v, 2 * v + 1),
                    plot_handle=ax[v],
                    plot_mean=sp.sqrt(self.sigma_factor),
                    xlabel='PC %d' % int(2 * v),
                    ylabel='PC %d' % int(2 * v + 1),
                    show=False)

        # plot gof
        # the step style must go through `drawstyle`: folding it into the
        # linestyle (ls='steps') was deprecated in matplotlib 3.1 and is
        # rejected since 3.3
        axg.plot(self._gof, drawstyle='steps')
        # dashed separators between the groups of repeats of each count
        for i in xrange(1, len(self.crange)):
            axg.axvline(i * self.repeats - 0.5, c='y', ls='--')
        axg.axvspan(self._winner - 0.5,
                    self._winner + 0.5,
                    fc='gray',
                    alpha=0.2)
        # label the first repeat of each group with its cluster count and
        # mark the remaining repeats with '.'
        labels = []
        for k in self.crange:
            labels += ['%d' % k]
            labels += ['.'] * (self.repeats - 1)
        axg.set_xticks(sp.arange(len(labels)))
        axg.set_xticklabels(labels)
        axg.set_xlabel('cluster count and repeats')
        axg.set_ylabel(str(self.gof_type).upper())
        axg.set_xlim(-1, len(labels))

        # handle the resulting plot
        if filename is not None:
            save_figure(fig, filename, '')
        if show is True:
            plt.show()
        return True
예제 #2
0
파일: cluster.py 프로젝트: pmeier82/BOTMpy
    def plot(self, data, views=2, show=False, filename=None):
        """Plot the clustering result and the goodness-of-fit trace.

        The top row holds up to `views` 2-D projections of `data`, each
        using a consecutive pair of columns (0/1, 2/3, ...), with the rows
        grouped by ``self.labels``. The bottom panel draws ``self._gof`` as
        a step curve. Entries come in groups of ``self.repeats`` per cluster
        count in ``self.crange``, and the entry ``self._winner`` is shaded.

        :param data: observations, one row per sample (presumably PCA
            features given the 'PC %d' axis labels -- TODO confirm)
        :param views: maximum number of 2-D projections; capped at
            ``data.shape[1] // 2``
        :param show: if exactly ``True``, call ``plt.show()`` at the end
        :param filename: if not None, save the figure via
            ``spikeplot.save_figure``
        :returns: None if `spikeplot` is not importable, else True
        """

        # plotting is optional: without spikeplot available, do nothing
        try:
            from spikeplot import plt, cluster, save_figure
        except ImportError:
            return None

        # init
        # every view consumes two data columns, so cap views at the number
        # of available column pairs
        views = min(views, data.shape[1] // 2)
        fig = plt.figure()
        fig.suptitle('clustering [%s]' % self.clus_type)
        # top row: one axes per view; bottom half: goodness-of-fit panel
        ax = [fig.add_subplot(2, views, v + 1) for v in xrange(views)]
        axg = fig.add_subplot(212)
        ncmp = int(self.labels.max() + 1)
        # map each label 0..ncmp-1 to the rows of data carrying that label
        cdata = dict(zip(xrange(ncmp),
                         [data[self.labels == c] for c in xrange(ncmp)]))

        # plot clustering
        for v in xrange(views):
            cluster(
                cdata,
                data_dim=(2 * v, 2 * v + 1),
                plot_handle=ax[v],
                plot_mean=sp.sqrt(self.sigma_factor),
                xlabel='PC %d' % int(2 * v),
                ylabel='PC %d' % int(2 * v + 1),
                show=False)

        # plot gof
        # the step style must go through `drawstyle`: folding it into the
        # linestyle (ls='steps') was deprecated in matplotlib 3.1 and is
        # rejected since 3.3
        axg.plot(self._gof, drawstyle='steps')
        # dashed separators between the groups of repeats of each count
        for i in xrange(1, len(self.crange)):
            axg.axvline(i * self.repeats - 0.5, c='y', ls='--')
        axg.axvspan(self._winner - 0.5, self._winner + 0.5, fc='gray',
                    alpha=0.2)
        # label the first repeat of each group with its cluster count and
        # mark the remaining repeats with '.'
        labels = []
        for k in self.crange:
            labels += ['%d' % k]
            labels += ['.'] * (self.repeats - 1)
        axg.set_xticks(sp.arange(len(labels)))
        axg.set_xticklabels(labels)
        axg.set_xlabel('cluster count and repeats')
        axg.set_ylabel(str(self.gof_type).upper())
        axg.set_xlim(-1, len(labels))

        # handle the resulting plot
        if filename is not None:
            save_figure(fig, filename, '')
        if show is True:
            plt.show()
        return True
예제 #3
0
def _split_by_label(labels, obs, wfs):
    """Group the rows of `obs` and `wfs` by cluster label.

    :param labels: per-row cluster labels; presumably a non-negative integer
        ndarray aligned with the rows of `obs` and `wfs` -- TODO confirm
    :param obs: feature rows, indexed with a boolean mask per label
    :param wfs: waveform rows, indexed with the same masks as `obs`
    :returns: ``(obs_by_label, wf_by_label)``, two dicts mapping every label
        in ``0 .. labels.max()`` to the matching rows (possibly empty)
    """
    obs_by_label = {}
    wf_by_label = {}
    for i in xrange(labels.max() + 1):
        mask = labels == i
        obs_by_label[i] = obs[mask]
        wf_by_label[i] = wfs[mask]
    return obs_by_label, wf_by_label


def main():
    """Cluster spike data with k-means, GMM and Ward; optionally plot each."""
    TF, SNR, PCADIM = 65, 0.5, 8
    NTRL = 10
    LOAD = False
    if LOAD is True:
        spks, spks_info, ndet = load_data()
    else:
        # without this call `spks` and `ndet` would be undefined below and
        # main() would die with a NameError
        spks, spks_info, ndet = get_data(tf=TF, trials=NTRL, snr=SNR,
                                         mean_correct=False, save=True)

    input_obs = pre_processing(spks, ndet, TF, pca_dim=PCADIM)
    if WITH_PLOT:
        plot.cluster(input_obs, show=False)

    # Each method is split by its OWN label range: GMM and Ward can return a
    # different number of clusters than k-means. The waveform plot gets the
    # waveforms for each cluster, not the feature rows.
    clusterers = [('kmeans', cluster_kmeans),
                  ('gmm', cluster_gmm),
                  ('ward', cluster_ward)]
    for title, clusterer in clusterers:
        labels = clusterer(input_obs)
        obs_by_label, wf_by_label = _split_by_label(labels, input_obs, spks)
        if WITH_PLOT:
            plot.cluster(obs_by_label, title=title, show=False)
            plot.waveforms(wf_by_label, tf=TF, title=title, show=False)

    # spectral clustering is currently disabled:
    # cluster_spectral(spks)

    if WITH_PLOT:
        plot.plt.show()
예제 #4
0
def _split_by_label(labels, obs, wfs):
    """Group the rows of `obs` and `wfs` by cluster label.

    :param labels: per-row cluster labels; presumably a non-negative integer
        ndarray aligned with the rows of `obs` and `wfs` -- TODO confirm
    :param obs: feature rows, indexed with a boolean mask per label
    :param wfs: waveform rows, indexed with the same masks as `obs`
    :returns: ``(obs_by_label, wf_by_label)``, two dicts mapping every label
        in ``0 .. labels.max()`` to the matching rows (possibly empty)
    """
    obs_by_label = {}
    wf_by_label = {}
    for i in xrange(labels.max() + 1):
        mask = labels == i
        obs_by_label[i] = obs[mask]
        wf_by_label[i] = wfs[mask]
    return obs_by_label, wf_by_label


def main():
    """Cluster spike data with k-means, GMM and Ward; optionally plot each."""
    TF, SNR, PCADIM = 65, 0.5, 8
    NTRL = 10
    LOAD = False
    if LOAD is True:
        spks, spks_info, ndet = load_data()
    else:
        # without this call `spks` and `ndet` would be undefined below and
        # main() would die with a NameError
        spks, spks_info, ndet = get_data(tf=TF, trials=NTRL, snr=SNR,
                                         mean_correct=False, save=True)

    input_obs = pre_processing(spks, ndet, TF, pca_dim=PCADIM)
    if WITH_PLOT:
        plot.cluster(input_obs, show=False)

    # Each method is split by its OWN label range: GMM and Ward can return a
    # different number of clusters than k-means. The waveform plot gets the
    # waveforms for each cluster, not the feature rows.
    clusterers = [('kmeans', cluster_kmeans),
                  ('gmm', cluster_gmm),
                  ('ward', cluster_ward)]
    for title, clusterer in clusterers:
        labels = clusterer(input_obs)
        obs_by_label, wf_by_label = _split_by_label(labels, input_obs, spks)
        if WITH_PLOT:
            plot.cluster(obs_by_label, title=title, show=False)
            plot.waveforms(wf_by_label, tf=TF, title=title, show=False)

    # spectral clustering is currently disabled:
    # cluster_spectral(spks)

    if WITH_PLOT:
        plot.plt.show()
예제 #5
0
import numpy as np
import scipy as sp
from spikeplot import cluster

# get some data: two 2-D standard-normal blobs, the second shifted by +2 on
# both axes (numpy.random.randn replaces scipy's deprecated `sp.randn` alias)
my_data = {0: np.random.randn(500, 2), 1: np.random.randn(300, 2) + 2}

# plot both clusters with spikeplot's cluster helper; no axes handle is
# passed here, and `plot_mean` is forwarded as-is (see spikeplot.cluster
# for its exact meaning)
cluster(
    my_data,
    title='Test Plot',
    plot_mean=2)