Example #1
0
    def draw_dimension(self, dimNumber):
        '''
        Method to draw the points on the axes using the current dimension number.

        Redraws the example waveforms on every axis of self.wavaxes (axis i shows
        channel index i of the aligned waveforms) and the 2-D point cloud on
        self.cloudax for the dimension pair self.combinations[self.dimNumber].
        Spikes inside the cluster are drawn in green, the others in light grey.

        The example waveforms are `ntraces` spikes drawn at random (with
        replacement, via np.random.randint) from each group. The draw and the
        alignment are done once per group, so every channel axis shows the
        same spikes.

        NOTE(review): the `dimNumber` argument is not used; the dimensions are
        read from self.dimNumber. Confirm callers update self.dimNumber first.
        '''

        #Clear the plot and any saved mouse click data for the old dimension
        self.cloudax.cla()
        self.cloudMouseClickData = []

        ntraces = 40

        def random_aligned_waveforms(spikeMask):
            '''Align ntraces random spikes of self.samples[spikeMask]; None if the mask selects none.'''
            if np.count_nonzero(spikeMask) == 0:
                return None
            groupSamples = self.samples[spikeMask]
            spikesToPlot = np.random.randint(groupSamples.shape[0], size=ntraces)
            return spikesorting.align_waveforms(groupSamples[spikesToPlot, :, :])

        # Select and align the example spikes once, outside the per-channel loop,
        # so all channels show the same spikes and alignment is not repeated.
        alignedWaveformsIn = random_aligned_waveforms(self.inCluster)
        alignedWaveformsOut = random_aligned_waveforms(self.outsideCluster)

        #Clear the waveform plots and then plot the waveforms
        for indWave, waveax in enumerate(self.wavaxes):
            waveax.cla()

            if alignedWaveformsIn is not None:
                for wave in alignedWaveformsIn[:, indWave, :]:
                    waveax.plot(wave, 'g', zorder=1)

            if alignedWaveformsOut is not None:
                for wave in alignedWaveformsOut[:, indWave, :]:
                    waveax.plot(wave, color='0.8', zorder=0)

        #Find the point array indices for the dimensions to be plotted
        dim0 = self.combinations[self.dimNumber][0]
        dim1 = self.combinations[self.dimNumber][1]

        #Plot the points in the cluster in green, and points outside as light grey
        self.cloudax.plot(self.points[:, dim0][self.inCluster],
                          self.points[:, dim1][self.inCluster],
                          'g.',
                          zorder=1)
        self.cloudax.plot(self.points[:, dim0][self.outsideCluster],
                          self.points[:, dim1][self.outsideCluster],
                          marker='.',
                          color='0.8',
                          linestyle='None',
                          zorder=0)

        #Label the axes and draw
        self.cloudax.set_xlabel('Dimension {}'.format(dim0))
        self.cloudax.set_ylabel('Dimension {}'.format(dim1))
        plt.suptitle('press c to cut, u to undo last cut, < or > to switch dimensions')
        self.fig.canvas.draw()
Example #2
0
def plot_colored_waveforms(waveforms, color='k', ntraces=40, ax=None):
    '''
    Plot mean waveform and its spread (mean +/- 1 standard deviation) as a colored area.

    The statistics are computed over `ntraces` spikes drawn at random (with
    replacement) and aligned with spikesorting.align_waveforms. Channels are
    drawn side by side, separated by a 2-sample gap: the mean as a white line
    on top of the colored band. A grey vertical scale bar, whose height is the
    absolute value of the minimum of the mean waveform, is drawn at x=-7 and
    labelled in uV. The axis is always turned off, even when there are no spikes.

    Args:
        waveforms (array): waveform array of shape (nSpikes,nChannels,nSamplesPerSpike)
        color (str): matplotlib color of the band
        ntraces (int): Number of randomly-selected traces to use
        ax (matplotlib Axes object): The axis to plot on (default: current axes)
    '''
    if ax is None:
        ax = plt.gca()
    (nSpikes, nChannels, nSamplesPerSpike) = waveforms.shape
    if nSpikes > 0:
        spikesToPlot = np.random.randint(nSpikes, size=ntraces)
        alignedWaveforms = spikesorting.align_waveforms(waveforms[spikesToPlot, :, :])
        meanWaveforms = np.mean(alignedWaveforms, axis=0)
        waveStd = np.std(alignedWaveforms, axis=0)
        varUpper = meanWaveforms + waveStd
        varLower = meanWaveforms - waveStd
        scalebarSize = abs(meanWaveforms.min())
        xRange = np.arange(nSamplesPerSpike)
        for indc in range(nChannels):
            # Offset each channel so channels are drawn next to each other
            newXrange = xRange + indc * (nSamplesPerSpike + 2)
            ax.plot(newXrange, meanWaveforms[indc, :], color='w', lw=1, clip_on=False, zorder=1)
            ax.fill_between(newXrange, varLower[indc, :], varUpper[indc, :], color=color, zorder=0)
            # Successive plot calls accumulate on `ax`; no hold() call is needed
            # (pyplot.hold was removed in matplotlib 3.0).
        fontsize = 8
        ax.plot(2 * [-7], [0, -scalebarSize], color='0.5', lw=2)
        ax.text(-10, -scalebarSize / 2, '{0:0.0f}uV'.format(np.round(scalebarSize)),
                ha='right', va='center', ma='center', fontsize=fontsize)
    # Hide the frame of the axes we drew on (not whatever pyplot's current axes is)
    ax.axis('off')
Example #3
0
def calculate_avg_waveforms(subject, ephysSession, tetrode, clustersPerTetrode=12, wavesize=160):
    '''
    Calculate the average (aligned) waveform of clusters 1..clustersPerTetrode on one tetrode.

    The tetrode's spikes file and its cluster file ('Tetrode{N}.clu.1' in the
    '<session>_kk' folder, both under settings.EPHYS_PATH) are each loaded once.
    The spikes of each cluster are aligned with spikesorting.align_waveforms
    and averaged over spikes.

    NOTE: This method should loop through sessions, not clusters.
          The idea is to compare clusters within a tetrode, and then across sessions
          but still within a tetrode.

    Args:
        subject (str): subject name, used as a subfolder of settings.EPHYS_PATH.
        ephysSession (str): session folder name.
        tetrode (int): tetrode number.
        clustersPerTetrode (int): clusters 1..clustersPerTetrode are averaged.
        wavesize (int): length of each flattened mean waveform.
    Returns:
        array of shape (clustersPerTetrode, wavesize); row i is the flattened
        mean waveform of cluster i+1. NOTE(review): a cluster with no spikes
        presumably yields a NaN row (mean of an empty array), provided
        align_waveforms accepts empty input -- confirm.
    '''
    ephysDir = os.path.join(settings.EPHYS_PATH, subject, ephysSession)
    ephysFilename = os.path.join(ephysDir, 'Tetrode{}.spikes'.format(tetrode))
    spikes = loadopenephys.DataSpikes(ephysFilename)

    clustersDir = '{}_kk'.format(ephysDir)
    clusterFilename = os.path.join(clustersDir, 'Tetrode{}.clu.1'.format(tetrode))
    # The first value of the .clu file is skipped (presumably the number of clusters)
    clusters = np.fromfile(clusterFilename, dtype='int32', sep=' ')[1:]

    allWaveforms = np.empty((clustersPerTetrode, wavesize))
    for indc in range(clustersPerTetrode):
        print('Estimating average waveform for {0} T{1}c{2}'.format(ephysSession, tetrode, indc + 1))

        #Add 1 to the cluster index because clusters start from 1
        waveforms = spikes.samples[clusters == indc + 1, :, :]

        alignedWaveforms = spikesorting.align_waveforms(waveforms)
        meanWaveforms = np.mean(alignedWaveforms, axis=0)
        allWaveforms[indc, :] = meanWaveforms.flatten()
    return allWaveforms
Example #4
0
def get_all_waveforms_one_session(celldb, wavesize=160, sessionToUse='behavior'):
    '''
    Load waveforms for each row of a celldb for one type of session.

    For every cell, the first session of type `sessionToUse` is loaded, its
    spike samples are aligned with spikesorting.align_waveforms and averaged
    over spikes, and the flattened mean is stored as one row of the output.

    Args:
        celldb: a pandas dataframe containing only good quality cells (containing n cells).
        wavesize: number of samples in a waveform, usually 160 for all 4 channels of a tetrode.
        sessionToUse: which type of session we want to load the waveform from.
    Returns:
        allWaveforms: a nd-array in the shape of (nCells, mSamples).
            Rows of cells whose processing raised ValueError (reported as having
            no spikes) are left as zeros.
    Raises:
        IndexError: presumably when a cell has no session of type `sessionToUse`,
            since get_session_inds(...)[0] is outside the try -- confirm.
    '''
    numCells = len(celldb)
    allWaveforms = np.zeros((numCells, wavesize))
    for indc, (ind, cell) in enumerate(celldb.iterrows()):
        cellObj = ephyscore.Cell(cell)
        sessionInd = cellObj.get_session_inds(sessionToUse)[0]
        try:
            ephysData = cellObj.load_ephys_by_index(sessionInd)
            samples = ephysData['samples']
            alignedWaveforms = spikesorting.align_waveforms(samples)
            meanWaveforms = np.mean(alignedWaveforms, axis=0)
            allWaveforms[indc, :] = meanWaveforms.flatten()
        except ValueError:
            # NOTE(review): assumed to be raised when the session has no spikes;
            # the row for this cell stays all zeros.
            print('Cell {} did not have any spikes in the {} session'.format(ind, sessionToUse))

    return allWaveforms
Example #5
0
def average_waveform_in_timerange(spikeSamples, spikeTimes, eventOnsetTimes,
                                  timeRange):
    '''
    Mean and standard deviation of the aligned waveforms of event-locked spikes.

    The spikes are selected with spikesanalysis.eventlocked_spiketimes (using
    the spike indices it returns with spikeindex=True), aligned with
    spikesorting.align_waveforms, and summarized over spikes (axis 0).

    Args:
        spikeSamples (array): waveforms of all spikes, indexed by spike along axis 0.
        spikeTimes (array): time of each spike.
        eventOnsetTimes (array): onset time of each event.
        timeRange: window relative to event onset, as accepted by
            spikesanalysis.eventlocked_spiketimes.
    Returns:
        (allWaveforms, allStd): the flattened mean and the flattened standard
        deviation over spikes of the aligned waveforms.
    '''
    # Only the indices of the selected spikes are needed here
    (_, _, _, spikeIndices) = spikesanalysis.eventlocked_spiketimes(
        spikeTimes, eventOnsetTimes, timeRange, spikeindex=True)
    selectedSamples = spikeSamples[spikeIndices]
    alignedWaveforms = spikesorting.align_waveforms(selectedSamples)
    meanWaveforms = np.mean(alignedWaveforms, axis=0)
    stdWaveforms = np.std(alignedWaveforms, axis=0)
    allWaveforms = meanWaveforms.flatten()
    allStd = stdWaveforms.flatten()
    return allWaveforms, allStd
def calculate_avg_waveforms(subject,
                            ephysSession,
                            tetrode,
                            clustersPerTetrode=12,
                            wavesize=160):
    '''
    Return the mean aligned waveform of clusters 1..clustersPerTetrode of one tetrode.

    The spikes file 'Tetrode{N}.spikes' of the session folder (under
    settings.EPHYS_PATH) and the cluster file 'Tetrode{N}.clu.1' of the
    matching '<session>_kk' folder are read once. For each cluster number the
    matching spikes are aligned with spikesorting.align_waveforms, averaged
    over spikes and flattened into one row of the result.

    NOTE: This method should loop through sessions, not clusters.
          The idea is to compare clusters within a tetrode, and then across
          sessions but still within a tetrode.

    Returns:
        array of shape (clustersPerTetrode, wavesize); row k-1 holds cluster k.
    '''
    sessionDir = os.path.join(settings.EPHYS_PATH, subject, ephysSession)
    spikes = loadopenephys.DataSpikes(
        os.path.join(sessionDir, 'Tetrode{}.spikes'.format(tetrode)))

    cluFile = os.path.join('{}_kk'.format(sessionDir),
                           'Tetrode{}.clu.1'.format(tetrode))
    # Drop the first value of the .clu file (presumably the number of clusters)
    spikeClusters = np.fromfile(cluFile, dtype='int32', sep=' ')[1:]

    avgWaves = np.empty((clustersPerTetrode, wavesize))
    # Cluster numbers start at 1; row index is clusterNumber - 1
    for clusterNumber in range(1, clustersPerTetrode + 1):
        print('Estimating average waveform for {0} T{1}c{2}'.format(
            ephysSession, tetrode, clusterNumber))
        clusterSamples = spikes.samples[spikeClusters == clusterNumber, :, :]
        clusterMean = np.mean(spikesorting.align_waveforms(clusterSamples), axis=0)
        avgWaves[clusterNumber - 1, :] = clusterMean.flatten()
    return avgWaves
def calculate_avg_waveforms(subject, cellDB, ephysSession, tetrode,  wavesize=160):
    '''
    Calculate the average (aligned) waveform of each cell of one tetrode in one session.

    Cells are selected from `cellDB` by the date part of `ephysSession` (the
    text before the first '_') and by `tetrode`. The tetrode's spikes file and
    its cluster file ('Tetrode{N}.clu.1' in the '<session>_kk' folder), both
    under settings.EPHYS_PATH_REMOTE, are each loaded once.

    NOTE: This method should loop through sessions, not clusters.
          The idea is to compare clusters within a tetrode, and then across sessions
          but still within a tetrode.

    Args:
        subject (str): subject name, used as a subfolder of settings.EPHYS_PATH_REMOTE.
        cellDB: dataframe with 'date', 'tetrode' and 'cluster' columns.
        ephysSession (str): session folder name, beginning with the date.
        tetrode (int): tetrode number.
        wavesize (int): length of each flattened mean waveform.
    Returns:
        None if cellDB has no cells for this date and tetrode; otherwise an
        array of shape (nCells, wavesize) whose row i is the flattened mean of
        the aligned waveforms of cluster cells.cluster.values[i].
    '''
    date = ephysSession.split('_')[0]
    cells = cellDB.loc[(cellDB.date == date) & (cellDB.tetrode == tetrode)]
    if len(cells) == 0:
        # No cells in cellDB for this tetrode in this session
        return None

    ephysDir = os.path.join(settings.EPHYS_PATH_REMOTE, subject, ephysSession)
    ephysFilename = os.path.join(ephysDir, 'Tetrode{}.spikes'.format(tetrode))
    spikes = loadopenephys.DataSpikes(ephysFilename)

    clustersDir = '{}_kk'.format(ephysDir)
    clusterFilename = os.path.join(clustersDir, 'Tetrode{}.clu.1'.format(tetrode))
    # The first value of the .clu file is skipped (presumably the number of clusters)
    clusters = np.fromfile(clusterFilename, dtype='int32', sep=' ')[1:]

    # One row per cell listed in cellDB; the clusters present in the .clu file
    # may not include every cluster of this tetrode.
    allWaveforms = np.empty((len(cells), wavesize))
    for indc, cluster in enumerate(cells.cluster.values):
        print('Estimating average waveform for {0} T{1}c{2}'.format(ephysSession, tetrode, cluster))
        waveforms = spikes.samples[clusters == cluster, :, :]
        alignedWaveforms = spikesorting.align_waveforms(waveforms)
        meanWaveforms = np.mean(alignedWaveforms, axis=0)
        allWaveforms[indc, :] = meanWaveforms.flatten()
    return allWaveforms
def plot_waveforms_average_all(waveforms, ntraces=40, fontsize=8):
    '''
    Plot waveforms given array of shape (nSpikes,nChannels,nSamplesPerSpike)

    All spikes are aligned with spikesorting.align_waveforms. For each channel
    (channels drawn side by side, separated by a 2-sample gap), `ntraces`
    spikes chosen at random (with replacement) are plotted in black, and the
    mean over ALL spikes (not only the random subset) is plotted in light grey.
    A grey vertical scale bar whose height is the absolute value of the
    minimum of the mean is drawn at x=-7 and labelled in uV.

    Raises:
        ValueError: from np.random.randint when there are no spikes.
    '''
    (nSpikes, nChannels, nSamplesPerSpike) = waveforms.shape
    spikesToPlot = np.random.randint(nSpikes, size=ntraces)
    #NOTE: We are now aligning all waveforms
    alignedWaveforms = spikesorting.align_waveforms(waveforms)
    print('Calculating mean of all waveforms')
    meanWaveforms = np.mean(alignedWaveforms, axis=0)
    scalebarSize = abs(meanWaveforms.min())

    # Successive plot calls accumulate on the current axes; no hold() calls are
    # needed (pyplot.hold was removed in matplotlib 3.0).
    xRange = np.arange(nSamplesPerSpike)
    for indc in range(nChannels):
        newXrange = xRange + indc * (nSamplesPerSpike + 2)
        #NOTE: spikesToPlot indexes the aligned array, so the traces are aligned
        wavesToPlot = alignedWaveforms[spikesToPlot, indc, :].T
        plt.plot(newXrange, wavesToPlot, color='k', lw=0.4, clip_on=False)
        plt.plot(newXrange,
                 meanWaveforms[indc, :],
                 color='0.75',
                 lw=1.5,
                 clip_on=False)
    plt.plot(2 * [-7], [0, -scalebarSize], color='0.5', lw=2)
    plt.text(-10,
             -scalebarSize / 2,
             '{0:0.0f}uV'.format(np.round(scalebarSize)),
             ha='right',
             va='center',
             ma='center',
             fontsize=fontsize)
    plt.axis('off')
Example #9
0
ephysSession = '20150228a'
tetrode = 3
cluster = 11
'''

# -- Load some spike data --
import allcells_test055 as allcells
cellID = allcells.cellDB.findcell('test055','20150228a',3,9) # 11 #6
oneCell = allcells.cellDB[cellID]

spkData = ephyscore.CellData(oneCell)
waveforms = spkData.spikes.samples
samplingRate = spkData.spikes.samplingRate

# -- Align waveforms --
waveforms = spikesorting.align_waveforms(waveforms)

# -- Get spike shape --
N_INTERP_SAMPLES = 200
avWaveforms = np.mean(waveforms,0)
avWaveforms = avWaveforms - 2**15 # FIXME: this is specific to OpenEphys
energyEachChannel = np.sum(np.abs(avWaveforms),1)
maxChannel = np.argmax(energyEachChannel)
spikeShape = avWaveforms[maxChannel,:]
sampVals = np.arange(0,len(spikeShape)/samplingRate,1/samplingRate)

interpFun = interp1d(sampVals, spikeShape, kind='cubic')
interpSampVals = np.linspace(0,sampVals[-1],N_INTERP_SAMPLES)
interpSpikeShape = interpFun(interpSampVals)

# NOTE: the peaks of the action potential are: (1) capacitive, (2) Na+, (3) K+
def calculate_ave_waveform(waveforms):
    '''
    Return the average of the aligned waveforms.

    Args:
        waveforms (array): spike waveforms, passed to spikesorting.align_waveforms.
    Returns:
        array: the aligned waveforms averaged over axis 0 (the spike axis).
    '''
    return np.mean(spikesorting.align_waveforms(waveforms), axis=0)
Example #11
0
    def draw_dimension(self, dimNumber):
        '''
        Method to draw the points on the axes using the current dimension number.

        Redraws the example waveforms on every axis of self.wavaxes (axis i shows
        channel index i of the aligned waveforms) and the 2-D point cloud on
        self.cloudax for the dimension pair self.combinations[self.dimNumber].
        Spikes inside the cluster are drawn in green, the others in light grey.

        The example waveforms are `ntraces` spikes drawn at random (with
        replacement, via np.random.randint) from each group. The draw and the
        alignment are done once per group, so every channel axis shows the
        same spikes.

        NOTE(review): the `dimNumber` argument is not used; the dimensions are
        read from self.dimNumber. Confirm callers update self.dimNumber first.
        '''

        #Clear the plot and any saved mouse click data for the old dimension
        self.cloudax.cla()
        self.cloudMouseClickData = []

        ntraces = 40

        def random_aligned_waveforms(spikeMask):
            '''Align ntraces random spikes of self.samples[spikeMask]; None if the mask selects none.'''
            if np.count_nonzero(spikeMask) == 0:
                return None
            groupSamples = self.samples[spikeMask]
            spikesToPlot = np.random.randint(groupSamples.shape[0], size=ntraces)
            return spikesorting.align_waveforms(groupSamples[spikesToPlot, :, :])

        # Select and align the example spikes once, outside the per-channel loop,
        # so all channels show the same spikes and alignment is not repeated.
        alignedWaveformsIn = random_aligned_waveforms(self.inCluster)
        alignedWaveformsOut = random_aligned_waveforms(self.outsideCluster)

        #Clear the waveform plots and then plot the waveforms
        for indWave, waveax in enumerate(self.wavaxes):
            waveax.cla()

            if alignedWaveformsIn is not None:
                for wave in alignedWaveformsIn[:, indWave, :]:
                    waveax.plot(wave, 'g', zorder=1)

            if alignedWaveformsOut is not None:
                for wave in alignedWaveformsOut[:, indWave, :]:
                    waveax.plot(wave, color='0.8', zorder=0)

        #Find the point array indices for the dimensions to be plotted
        dim0 = self.combinations[self.dimNumber][0]
        dim1 = self.combinations[self.dimNumber][1]

        #Plot the points in the cluster in green, and points outside as light grey
        self.cloudax.plot(self.points[:, dim0][self.inCluster],
                          self.points[:, dim1][self.inCluster],
                          'g.',
                          zorder=1)
        self.cloudax.plot(self.points[:, dim0][self.outsideCluster],
                          self.points[:, dim1][self.outsideCluster],
                          marker='.',
                          color='0.8',
                          linestyle='None',
                          zorder=0)

        #Label the axes and draw
        self.cloudax.set_xlabel('Dimension {}'.format(dim0))
        self.cloudax.set_ylabel('Dimension {}'.format(dim1))
        plt.suptitle(
            'press c to cut, u to undo last cut, < or > to switch dimensions')
        self.fig.canvas.draw()
Example #12
0
ephysDir = os.path.join(settings.EPHYS_PATH, subject, session)
clusterDir = os.path.join(settings.EPHYS_PATH, subject,
                          '{}_kk'.format(session))

ephysFn = os.path.join(ephysDir, 'Tetrode{}.spikes'.format(tetrode))
clusterFn = os.path.join(clusterDir, 'Tetrode{}.clu.1'.format(tetrode))

#Load the samples from the tetrode's spikes file
dataSpikes = loadopenephys.DataSpikes(ephysFn)
# Remove the 2**15 offset of the raw samples
dataSpikes.samples = dataSpikes.samples.astype(
    float) - 2**15  # FIXME: this is specific to OpenEphys
# Scale by 1000/gain; only gain[0, 0] is used, so all channels are assumed
# to share the same gain -- NOTE(review): confirm
dataSpikes.samples = (1000.0 / dataSpikes.gain[0, 0]) * dataSpikes.samples

#Set the clusters
dataSpikes.set_clusters(clusterFn)

#Select which cluster to use
spikesThisCluster = dataSpikes.clusters == cluster
dataSpikes.samples = dataSpikes.samples[spikesThisCluster, :, :]
dataSpikes.timestamps = dataSpikes.timestamps[spikesThisCluster]

#Align the waveforms
alignedWaves = spikesorting.align_waveforms(dataSpikes.samples)

#Calculate the mean waveform over the spikes of this cluster
meanWaveform = np.mean(alignedWaves, axis=0)

figure()
clf()
plot(meanWaveform[0, :])
Example #13
0
    # nextWave = ravel(squeeze(cluster8Samples[violation+1, 1, :]))

    plot(vWave, 'g')
    hold(1)
    plot(nextWave, 'r')

    # subplot(2, 1, 2)
    # waveCorr = np.correlate(vWave, nextWave, 'full')
    # plot(waveCorr)

    waitforbuttonpress()


# -- Plot violation waves (top, red) and the waves right after them (bottom, green) --
# Both sets are aligned first; each waveform is flattened so its channels are
# drawn one after the other along the x axis.
vWaves = spikesorting.align_waveforms(cluster8Samples[violationInds, :, :])
vnWaves = spikesorting.align_waveforms(cluster8Samples[violationInds + 1, :, :])

figure()
for panelIndex, panelWaves, panelColor in ((1, vWaves, 'r'), (2, vnWaves, 'g')):
    subplot(2, 1, panelIndex)
    for wave in panelWaves:
        plot(ravel(squeeze(wave)), panelColor)
    ylabel('microvolts')
xlabel('Samples (over 4 channels, 40 samples each)')